This example shows how to use fitcauto to automatically try a selection of classification model types with different hyperparameter values, given training predictor and response data. The function uses Bayesian optimization to select models and their hyperparameter values, and computes the cross-validation classification error for each model. After the optimization is complete, fitcauto returns the model, trained on the entire data set, that is expected to best classify new data. Check the model performance on test data.
This example uses the 1994 census data stored in census1994.mat. The data set consists of demographic information from the US Census Bureau that can be used to predict whether an individual makes over $50,000 per year.
Load the sample data census1994, which contains the training data adultdata and the test data adulttest. Preview the first few rows of the training data set.
load census1994
head(adultdata)
ans=8×15 table
age  workClass             fnlwgt  education  education_num  marital_status         occupation         relationship   race   sex     capital_gain  capital_loss  hours_per_week  native_country  salary
___  ________________  __________  _________  _____________  _____________________  _________________  _____________  _____  ______  ____________  ____________  ______________  ______________  ______

 39  State-gov              77516  Bachelors             13  Never-married          Adm-clerical       Not-in-family  White  Male            2174             0              40  United-States   <=50K
 50  Self-emp-not-inc       83311  Bachelors             13  Married-civ-spouse     Exec-managerial    Husband        White  Male               0             0              13  United-States   <=50K
 38  Private           2.1565e+05  HS-grad                9  Divorced               Handlers-cleaners  Not-in-family  White  Male               0             0              40  United-States   <=50K
 53  Private           2.3472e+05  11th                   7  Married-civ-spouse     Handlers-cleaners  Husband        Black  Male               0             0              40  United-States   <=50K
 28  Private           3.3841e+05  Bachelors             13  Married-civ-spouse     Prof-specialty     Wife           Black  Female             0             0              40  Cuba            <=50K
 37  Private           2.8458e+05  Masters               14  Married-civ-spouse     Exec-managerial    Wife           White  Female             0             0              40  United-States   <=50K
 49  Private           1.6019e+05  9th                    5  Married-spouse-absent  Other-service      Not-in-family  Black  Female             0             0              16  Jamaica         <=50K
 52  Self-emp-not-inc  2.0964e+05  HS-grad                9  Married-civ-spouse     Exec-managerial    Husband        White  Male               0             0              45  United-States   >50K
Each row contains the demographic information for one adult. The last column, salary, shows whether a person has a salary less than or equal to $50,000 per year or greater than $50,000 per year.
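Before running the search, it can help to check how the two salary classes are balanced, because a trivial classifier that always predicts the majority class sets a baseline for the validation loss and test accuracy reported later. The following sketch is not part of the original example. It assumes that salary is a categorical variable (consistent with the categorical ClassNames output shown later), counts unweighted observations (it ignores the fnlwgt weights), and uses illustrative variable names such as trainCounts.

```matlab
% Load the training (adultdata) and test (adulttest) tables.
load census1994

% Names of the salary categories in the training data.
salaryClasses = categories(adultdata.salary)

% Unweighted number of observations in each category. countcats does not
% count undefined (missing) elements.
trainCounts = countcats(adultdata.salary)
testCounts = countcats(adulttest.salary)

% Unweighted class proportions in the training data.
trainFractions = trainCounts/sum(trainCounts)
```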
Use fitcauto to automatically find an appropriate classifier for the data in adultdata. Specify the fnlwgt variable as the observation weights, and run the Bayesian optimization in parallel, which requires Parallel Computing Toolbox™. Because parallel timing is not reproducible, parallel Bayesian optimization does not necessarily yield reproducible results.
Because of the complexity of the optimization, this process can take some time, especially for larger data sets. By default, fitcauto provides a plot of the optimization and an iterative display of the optimization results. For more information on how to interpret these results, see Verbose Display.
options = struct('UseParallel',true);   % run the Bayesian optimization in parallel
[mdl,results] = fitcauto(adultdata,'salary','Weights','fnlwgt', ...
    'HyperparameterOptimizationOptions',options);
Warning: It is recommended that you first standardize all numeric predictors when optimizing the Naive Bayes 'Width' parameter. Ignore this warning if you have done that.
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).
Copying objective function to workers...
Done copying objective function to workers.
Learner types to explore: ensemble, nb, tree
Total iterations (MaxObjectiveEvaluations): 90
Total time (MaxTime): Inf
|===========================================================================================================================================|
| Iter | Active  | Eval   | Validation | Time for training  | Observed min    | Estimated min   | Learner  | Hyperparameter: Value
|      | workers | result | loss       | & validation (sec) | validation loss | validation loss |          |
|===========================================================================================================================================|
|    1 |       5 | Accept |    0.23856 |             2.6087 |          0.1821 |         0.20863 | tree     | MinLeafSize: 10889
|    2 |       5 | Best   |     0.1821 |             2.5782 |          0.1821 |         0.20863 | tree     | MinLeafSize: 4990
|    3 |       6 | Best   |    0.14971 |              50.17 |         0.14971 |         0.14971 | nb       | DistributionNames: kernel, Width: 0.41891
|    4 |       6 | Accept |    0.17743 |             67.443 |         0.14971 |         0.14971 | ensemble | Method: Bag, NumLearningCycles: 225, MinLeafSize: 3144
|    5 |       6 | Accept |    0.17705 |             68.174 |         0.14971 |         0.14971 | ensemble | Method: Bag, NumLearningCycles: 225, MinLeafSize: 3144
|    6 |       6 | Accept |     0.1623 |              1.029 |         0.14971 |         0.15664 | nb       | DistributionNames: normal, Width: NaN
|    7 |       6 | Accept |     0.1745 |             1.2602 |         0.14971 |         0.15664 | tree     | MinLeafSize: 1292
|    8 |       6 | Best   |    0.14014 |             2.7929 |         0.14014 |         0.15664 | tree     | MinLeafSize: 40
|    9 |       6 | Accept |    0.15209 |             98.452 |         0.14014 |         0.15664 | ensemble | Method: Bag, NumLearningCycles: 239, MinLeafSize: 161
|   10 |       6 | Accept |    0.16032 |             99.589 |         0.14014 |         0.15807 | nb       | DistributionNames: kernel, Width: 9.0144
|   11 |       6 | Accept |    0.17807 |              1.446 |         0.14014 |         0.15807 | tree     | MinLeafSize: 3849
|   12 |       6 | Accept |    0.16427 |             116.76 |         0.14014 |         0.15807 | ensemble | Method: Bag, NumLearningCycles: 294, MinLeafSize: 963
|   13 |       6 | Accept |    0.15706 |             53.771 |         0.14014 |         0.15807 | ensemble | Method: LogitBoost, NumLearningCycles: 241, MinLeafSize: 6
|   14 |       6 | Accept |    0.14443 |             1.8727 |         0.14014 |         0.15807 | tree     | MinLeafSize: 229
|   15 |       6 | Accept |    0.14941 |             49.309 |         0.14014 |         0.15523 | nb       | DistributionNames: kernel, Width: 0.26346
|   16 |       6 | Accept |    0.15139 |             105.88 |         0.14014 |         0.15523 | ensemble | Method: Bag, NumLearningCycles: 257, MinLeafSize: 2
|   17 |       6 | Accept |    0.15434 |             3.5748 |         0.14014 |         0.15262 | tree     | MinLeafSize: 8
|   18 |       6 | Accept |    0.14968 |             49.046 |         0.14014 |         0.15262 | nb       | DistributionNames: kernel, Width: 0.40347
|   19 |       6 | Accept |    0.14421 |             2.9271 |         0.14014 |         0.14824 | tree     | MinLeafSize: 23
|   20 |       6 | Accept |    0.14923 |             1.6667 |         0.14014 |         0.14673 | tree     | MinLeafSize: 354
|   21 |       6 | Best   |    0.13996 |             3.0956 |         0.13996 |           0.145 | tree     | MinLeafSize: 39
|   22 |       6 | Accept |    0.15111 |             83.972 |         0.13996 |           0.145 | ensemble | Method: Bag, NumLearningCycles: 208, MinLeafSize: 14
|   23 |       6 | Accept |    0.19568 |             98.149 |         0.13996 |           0.145 | nb       | DistributionNames: kernel, Width: 8643.5
|   24 |       6 | Accept |    0.23856 |            0.49289 |         0.13996 |         0.14508 | tree     | MinLeafSize: 8801
|   25 |       6 | Accept |     0.1623 |              0.606 |         0.13996 |         0.14508 | nb       | DistributionNames: normal, Width: NaN
|   26 |       6 | Accept |     0.1623 |            0.38862 |         0.13996 |         0.14508 | nb       | DistributionNames: normal, Width: NaN
|   27 |       6 | Accept |     0.1623 |            0.34002 |         0.13996 |         0.14508 | nb       | DistributionNames: normal, Width: NaN
|   28 |       6 | Accept |    0.15847 |             49.535 |         0.13996 |         0.14508 | ensemble | Method: LogitBoost, NumLearningCycles: 212, MinLeafSize: 376
|   29 |       6 | Accept |    0.15639 |             63.923 |         0.13996 |         0.14508 | ensemble | Method: LogitBoost, NumLearningCycles: 288, MinLeafSize: 1774
|   30 |       6 | Accept |     0.1623 |            0.79402 |         0.13996 |         0.14508 | nb       | DistributionNames: normal, Width: NaN
|   31 |       6 | Accept |     0.1453 |             3.0022 |         0.13996 |         0.14429 | tree     | MinLeafSize: 16
|   32 |       6 | Accept |     0.1623 |            0.32634 |         0.13996 |         0.14429 | nb       | DistributionNames: normal, Width: NaN
|   33 |       6 | Accept |    0.23856 |             57.127 |         0.13996 |         0.14429 | ensemble | Method: Bag, NumLearningCycles: 250, MinLeafSize: 6630
|   34 |       6 | Accept |    0.23856 |            0.52406 |         0.13996 |         0.14436 | tree     | MinLeafSize: 9982
|   35 |       6 | Accept |    0.15781 |             52.097 |         0.13996 |         0.14436 | ensemble | Method: LogitBoost, NumLearningCycles: 224, MinLeafSize: 440
|   36 |       6 | Accept |    0.15434 |             3.4554 |         0.13996 |         0.14431 | tree     | MinLeafSize: 8
|   37 |       6 | Accept |    0.15669 |             54.961 |         0.13996 |         0.14431 | ensemble | Method: LogitBoost, NumLearningCycles: 252, MinLeafSize: 4
|   38 |       6 | Accept |    0.16986 |             1.5157 |         0.13996 |         0.14367 | tree     | MinLeafSize: 791
|   39 |       6 | Accept |    0.17708 |               55.2 |         0.13996 |         0.14367 | ensemble | Method: Bag, NumLearningCycles: 220, MinLeafSize: 4265
|   40 |       6 | Accept |    0.16315 |             86.272 |         0.13996 |         0.14367 | ensemble | Method: Bag, NumLearningCycles: 223, MinLeafSize: 743
|   41 |       6 | Accept |    0.14449 |             1.9856 |         0.13996 |         0.14391 | tree     | MinLeafSize: 235
|   42 |       5 | Accept |      0.167 |             96.272 |         0.13996 |         0.14287 | nb       | DistributionNames: kernel, Width: 18.272
|   43 |       5 | Accept |    0.14377 |             2.0295 |         0.13996 |         0.14287 | tree     | MinLeafSize: 193
|   44 |       4 | Accept |    0.15699 |             56.595 |         0.13996 |         0.14152 | ensemble | Method: LogitBoost, NumLearningCycles: 272, MinLeafSize: 997
|   45 |       4 | Accept |    0.14013 |             2.4702 |         0.13996 |         0.14152 | tree     | MinLeafSize: 36
|   46 |       4 | Accept |    0.14344 |             1.7418 |         0.13996 |         0.14185 | tree     | MinLeafSize: 189
|   47 |       4 | Accept |    0.15677 |              1.425 |         0.13996 |         0.14176 | tree     | MinLeafSize: 444
|   48 |       3 | Accept |    0.15259 |             83.025 |         0.13996 |         0.14182 | nb       | DistributionNames: kernel, Width: 4.2748
|   49 |       3 | Accept |    0.14368 |             1.6998 |         0.13996 |         0.14182 | tree     | MinLeafSize: 182
|   50 |       6 | Accept |     0.1408 |              2.177 |         0.13996 |          0.1411 | tree     | MinLeafSize: 38
|   51 |       4 | Accept |    0.13996 |             2.5601 |         0.13996 |         0.14072 | tree     | MinLeafSize: 39
|   52 |       4 | Accept |    0.17244 |             1.1659 |         0.13996 |         0.14072 | tree     | MinLeafSize: 1038
|   53 |       4 | Accept |    0.17797 |            0.98188 |         0.13996 |         0.14072 | tree     | MinLeafSize: 1818
|   54 |       3 | Accept |    0.15643 |             57.195 |         0.13996 |         0.14051 | ensemble | Method: LogitBoost, NumLearningCycles: 287, MinLeafSize: 30
|   55 |       3 | Accept |    0.14048 |             2.3411 |         0.13996 |         0.14051 | tree     | MinLeafSize: 37
|   56 |       6 | Accept |     0.1408 |             2.1605 |         0.13996 |         0.14054 | tree     | MinLeafSize: 38
|   57 |       5 | Accept |    0.13996 |             2.5887 |         0.13996 |         0.14054 | tree     | MinLeafSize: 39
|   58 |       5 | Accept |     0.1623 |            0.46796 |         0.13996 |         0.14054 | nb       | DistributionNames: normal, Width: NaN
|   59 |       5 | Accept |    0.13996 |             2.4424 |         0.13996 |         0.14017 | tree     | MinLeafSize: 39
|   60 |       5 | Accept |    0.14013 |             2.5432 |         0.13996 |         0.14013 | tree     | MinLeafSize: 36
|   61 |       5 | Accept |    0.16408 |             1.4549 |         0.13996 |         0.14011 | tree     | MinLeafSize: 585
|   62 |       5 | Accept |    0.14061 |             2.1993 |         0.13996 |         0.14017 | tree     | MinLeafSize: 85
|   63 |       5 | Accept |    0.14314 |             2.6421 |         0.13996 |         0.14013 | tree     | MinLeafSize: 29
|   64 |       5 | Accept |    0.14001 |             2.2084 |         0.13996 |         0.14008 | tree     | MinLeafSize: 86
|   65 |       5 | Accept |    0.14133 |             2.5565 |         0.13996 |         0.14007 | tree     | MinLeafSize: 35
|   66 |       4 | Accept |    0.15943 |             41.891 |         0.13996 |         0.14011 | ensemble | Method: LogitBoost, NumLearningCycles: 201, MinLeafSize: 42
|   67 |       4 | Accept |    0.14256 |             2.1949 |         0.13996 |         0.14011 | tree     | MinLeafSize: 125
|   68 |       3 | Accept |    0.15631 |             57.502 |         0.13908 |         0.14002 | ensemble | Method: LogitBoost, NumLearningCycles: 291, MinLeafSize: 1162
|   69 |       3 | Best   |    0.13908 |             2.0902 |         0.13908 |         0.14002 | tree     | MinLeafSize: 98
|   70 |       6 | Accept |    0.14061 |             1.8764 |         0.13908 |         0.14006 | tree     | MinLeafSize: 85
|   71 |       5 | Accept |    0.14143 |              2.169 |         0.13908 |         0.14002 | tree     | MinLeafSize: 109
|   72 |       5 | Accept |    0.14472 |             1.9738 |         0.13908 |         0.14002 | tree     | MinLeafSize: 247
|   73 |       5 | Accept |    0.14193 |             2.5079 |         0.13908 |         0.14006 | tree     | MinLeafSize: 34
|   74 |       5 | Accept |     0.1507 |             1.5583 |         0.13908 |         0.14007 | tree     | MinLeafSize: 392
|   75 |       5 | Accept |    0.13908 |             2.0123 |         0.13908 |         0.13994 | tree     | MinLeafSize: 98
|   76 |       5 | Accept |    0.14124 |             2.0292 |         0.13908 |         0.13994 | tree     | MinLeafSize: 111
|   77 |       5 | Accept |    0.13917 |             2.0368 |         0.13908 |         0.13969 | tree     | MinLeafSize: 97
|   78 |       5 | Accept |    0.14032 |             2.0503 |         0.13908 |         0.13957 | tree     | MinLeafSize: 103
|   79 |       4 | Accept |    0.18035 |             87.193 |         0.13908 |         0.13949 | nb       | DistributionNames: kernel, Width: 483.03
|   80 |       4 | Accept |    0.13917 |             2.1151 |         0.13908 |         0.13949 | tree     | MinLeafSize: 97
|   81 |       4 | Best   |    0.13907 |             1.9149 |         0.13907 |         0.13934 | tree     | MinLeafSize: 96
|   82 |       4 | Accept |    0.13917 |             1.9271 |         0.13907 |         0.13927 | tree     | MinLeafSize: 97
|   83 |       4 | Accept |    0.13917 |             1.9196 |         0.13907 |         0.13924 | tree     | MinLeafSize: 97
|   84 |       4 | Accept |     0.1396 |             1.9448 |         0.13907 |         0.13923 | tree     | MinLeafSize: 91
|   85 |       4 | Accept |    0.13908 |             1.9816 |         0.13907 |         0.13921 | tree     | MinLeafSize: 98
|   86 |       4 | Accept |    0.14951 |             66.139 |         0.13907 |          0.1392 | nb       | DistributionNames: kernel, Width: 1.9412
|   87 |       4 | Accept |    0.18157 |            0.54864 |         0.13907 |          0.1392 | tree     | MinLeafSize: 4072
|   88 |       5 | Accept |    0.13954 |             2.1344 |         0.13907 |         0.13922 | tree     | MinLeafSize: 100
|   89 |       5 | Accept |    0.15815 |             1.5244 |         0.13907 |          0.1392 | tree     | MinLeafSize: 492
|   90 |       5 | Accept |    0.14765 |             1.7524 |         0.13907 |         0.13922 | tree     | MinLeafSize: 315
__________________________________________________________
Optimization completed.
Total iterations: 90
Total elapsed time: 478.6155 seconds
Total time for training and validation: 2024.0417 seconds

Best observed learner is a tree model with:
    MinLeafSize: 96
Observed validation loss: 0.13907
Time for training and validation: 1.9149 seconds

Best estimated learner (returned model) is a tree model with:
    MinLeafSize: 97
Estimated validation loss: 0.13922
Estimated time for training and validation: 2.0003 seconds
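Because parallel Bayesian optimization does not necessarily yield reproducible results, a serial search with a fixed random seed is one way to get repeatable results when run time matters less. The sketch below is an illustration, not part of the original example: mdlSerial and resultsSerial are hypothetical names, the budget of 30 evaluations is an arbitrary choice made here to shorten the run, and the learner it returns will generally differ from the one reported above.

```matlab
load census1994

% Seed the random number generator so that a serial run can be repeated.
rng('default')

% Serial search with a reduced evaluation budget. The value 30 is an
% arbitrary choice for this sketch; plots and the iterative display are
% turned off to keep the run quiet.
options = struct('UseParallel',false, ...
    'MaxObjectiveEvaluations',30, ...
    'ShowPlots',false, ...
    'Verbose',0);

[mdlSerial,resultsSerial] = fitcauto(adultdata,'salary','Weights','fnlwgt', ...
    'HyperparameterOptimizationOptions',options);

% Model type chosen by this shorter search; it need not match mdl.
class(mdlSerial)
```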
The final model returned by fitcauto corresponds to the best estimated learner. Before returning the model, the function retrains it using the entire training data (adultdata), the listed Learner (or model) type, and the displayed hyperparameter values.
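In the run above, the best estimated learner is a tree with MinLeafSize 97, so the retraining step is conceptually similar to calling fitctree on all of adultdata with the same observation weights. The following sketch assumes that fitcauto leaves every other fitctree option at its default value, which this example does not state. The name mdlManual is hypothetical, and the result approximates mdl rather than being a guaranteed reproduction of it.

```matlab
load census1994

% Train one classification tree on the entire training set, with the
% fnlwgt observation weights and the MinLeafSize value of the best
% estimated learner from the display above. All other options are left
% at their defaults (an assumption about what fitcauto does).
mdlManual = fitctree(adultdata,'salary','Weights','fnlwgt', ...
    'MinLeafSize',97);

% Test set classification error of the manually retrained tree.
testErrorManual = loss(mdlManual,adulttest,'salary')
```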
Evaluate the performance of the returned model mdl on the test set adulttest by using a confusion matrix and a receiver operating characteristic (ROC) curve.
Find the predicted labels and score values for the test set.
[labels,scores] = predict(mdl,adulttest);
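For a classification tree such as the one returned in this run, the second output of predict contains posterior class probabilities, with one column per class in the order of mdl.ClassNames. The sketch below checks that each row of scores sums to 1 and rebuilds the labels from the column with the largest score. With the default misclassification costs, the rebuilt labels are expected to agree with labels except possibly where the two scores tie. If mdl is not already in the workspace, the sketch fits a stand-in tree with MinLeafSize 97 (an assumption, not the fitcauto result itself); names such as labelsFromScores are illustrative.

```matlab
% Load the data and fit a stand-in model only if they are not already in
% the workspace from the fitcauto example.
if ~exist('adulttest','var')
    load census1994
end
if ~exist('mdl','var')
    mdl = fitctree(adultdata,'salary','Weights','fnlwgt','MinLeafSize',97);
end

[labels,scores] = predict(mdl,adulttest);

% Largest deviation of any row sum of scores from 1 (should be round-off).
maxRowSumError = max(abs(sum(scores,2) - 1))

% Rebuild labels from the highest-scoring column. max returns the first
% column in the case of a tie.
[~,bestColumn] = max(scores,[],2);
labelsFromScores = mdl.ClassNames(bestColumn);
fractionAgreeing = mean(labelsFromScores == labels)
```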
Create a confusion matrix from the test set results. The diagonal elements indicate the number of correctly classified observations of each class. The off-diagonal elements indicate the number of misclassified observations.
confusionchart(adulttest.salary,labels)
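confusionchart displays the counts; to work with them as a numeric matrix, confusionmat returns the same kind of table, with true classes in rows and predicted classes in columns. The sketch below reads the correct classifications off the diagonal and converts them to an unweighted percentage. Because loss can take class prior probabilities and observation weights into account, this unweighted percentage is not guaranteed to equal the accuracy value computed from loss in this example. If mdl is not already in the workspace, the sketch fits a stand-in tree with MinLeafSize 97 (an assumption, not the fitcauto result); names such as unweightedAccuracy are illustrative.

```matlab
% Load the data and fit a stand-in model only if they are not already in
% the workspace from the fitcauto example.
if ~exist('adulttest','var')
    load census1994
end
if ~exist('mdl','var')
    mdl = fitctree(adultdata,'salary','Weights','fnlwgt','MinLeafSize',97);
end
labels = predict(mdl,adulttest);

% Rows of C are true classes and columns are predicted classes, both in
% the order listed in classOrder.
[C,classOrder] = confusionmat(adulttest.salary,labels)

% Correct classifications lie on the diagonal of C.
numCorrect = sum(diag(C));
numClassified = sum(C(:));
unweightedAccuracy = 100*numCorrect/numClassified
```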
Compute the test set classification accuracy. Because loss returns the classification error by default, accuracy is the percentage of correctly classified test set observations.
accuracy = (1-loss(mdl,adulttest,'salary'))*100
accuracy = 85.4486
To plot the ROC curve for the score values corresponding to the label '<=50K', find the column of scores that corresponds to that label. The column order of scores matches the order of the classes in the trained model, which is stored in mdl.ClassNames.
mdl.ClassNames
ans = 2×1 categorical
<=50K
>50K
Because '<=50K' is listed first, the first column of scores corresponds to that label.
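Reading the class order from the display works here, but the column can also be located programmatically, which keeps the code correct if the class order differs in another run. This minimal sketch compares mdl.ClassNames with '<=50K' to build a logical column index. If mdl is not already in the workspace, it fits a stand-in tree with MinLeafSize 97 (an assumption, not the fitcauto result); names such as lowSalaryScores are illustrative.

```matlab
% Load the data and fit a stand-in model only if they are not already in
% the workspace from the fitcauto example.
if ~exist('adulttest','var')
    load census1994
end
if ~exist('mdl','var')
    mdl = fitctree(adultdata,'salary','Weights','fnlwgt','MinLeafSize',97);
end
[~,scores] = predict(mdl,adulttest);

% Logical index that is true only for the '<=50K' class.
isLowSalary = (mdl.ClassNames == '<=50K');
lowSalaryColumn = find(isLowSalary)

% Scores for '<=50K'; equal to scores(:,1) when '<=50K' is listed first.
lowSalaryScores = scores(:,isLowSalary);
```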
Plot the ROC curve, and compute the area under the curve (AUC). The ROC curve shows the true positive rate versus the false positive rate for different thresholds of the classifier output. For a perfect classifier, whose true positive rate is 1 regardless of the false positive rate, AUC = 1. For a binary classifier that randomly assigns observations to classes, AUC = 0.5. A large AUC value (close to 1) indicates good classifier performance.
[X,Y,~,AUC] = perfcurve(adulttest.salary,scores(:,1),'<=50K');   % '<=50K' is the positive class
plot(X,Y)
title('ROC Curve')
xlabel('False Positive Rate')
ylabel('True Positive Rate')
AUC
AUC = 0.8848
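The ROC curve summarizes all thresholds at once. To choose a single threshold, perfcurve can also return the thresholds T and an optimal operating point OPTROCPT as its third and fifth outputs. The sketch below finds the threshold at that point and marks the point on the curve; the optimal point depends on perfcurve's default cost and prior settings, which this sketch does not change. If mdl is not already in the workspace, it fits a stand-in tree with MinLeafSize 97 (an assumption, not the fitcauto result); names such as optimalThreshold are illustrative.

```matlab
% Load the data and fit a stand-in model only if they are not already in
% the workspace from the fitcauto example.
if ~exist('adulttest','var')
    load census1994
end
if ~exist('mdl','var')
    mdl = fitctree(adultdata,'salary','Weights','fnlwgt','MinLeafSize',97);
end
[~,scores] = predict(mdl,adulttest);
lowSalaryScores = scores(:,mdl.ClassNames == '<=50K');

% T holds the score thresholds; OPTROCPT is the optimal operating point,
% given as [false positive rate, true positive rate].
[X,Y,T,AUC,OPTROCPT] = perfcurve(adulttest.salary,lowSalaryScores,'<=50K');

% Threshold that corresponds to the optimal operating point.
optimalThreshold = T((X == OPTROCPT(1)) & (Y == OPTROCPT(2)))

% Mark the operating point on the ROC curve.
plot(X,Y)
hold on
plot(OPTROCPT(1),OPTROCPT(2),'ro')
hold off
title('ROC Curve with Optimal Operating Point')
xlabel('False Positive Rate')
ylabel('True Positive Rate')
```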
Based on the test set accuracy (about 85%) and AUC (about 0.88), the classifier performs well on the test data.
See also: BayesianOptimization | confusionchart | fitcauto | perfcurve