kfoldfun

Cross-validate function

Description

vals = kfoldfun(CVMdl,fun) cross-validates the function fun by applying fun, once per fold, to the data stored in the cross-validated model CVMdl. You must pass fun as a function handle.
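As a minimal sketch of how fun is invoked, the following anonymous function ignores the trained model and returns only the number of held-out observations it receives. Because each observation is held out in exactly one fold, the per-fold counts should add up to the number of observations. The sketch assumes the fisheriris data and a default 10-fold tree model, as in the example further down; the name countFun is hypothetical.

```matlab
% Sketch: confirm that fun is called once per fold on that fold's held-out data.
load fisheriris                              % meas (150-by-4), species (150-by-1 cell)
rng(1);                                      % For reproducibility of the fold assignment
CVMdl = crossval(fitctree(meas,species));    % Default 10-fold cross-validation

% countFun is a hypothetical helper; it uses only Ytest and ignores the other inputs.
countFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) numel(Ytest);

vals = kfoldfun(CVMdl,countFun);             % One scalar per fold, stacked vertically
disp(vals')                                  % Test-set size of each fold
disp(sum(vals))                              % Total equals the number of observations
```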

Input Arguments

CVMdl

Cross-validated model, specified as a ClassificationPartitionedECOC model, ClassificationPartitionedEnsemble model, or ClassificationPartitionedModel model.

fun

Function to cross-validate, specified as a function handle. fun has the syntax

testvals = fun(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)
  • CMP is a compact model stored in one element of the CVMdl.Trained property.

  • Xtrain is the training matrix of predictor values.

  • Ytrain is the training array of response values.

  • Wtrain is the vector of training observation weights.

  • Xtest and Ytest are the test data, with associated weights Wtest.

  • The returned value testvals must have the same size in every fold.

Data Types: function_handle
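The sketch below shows a fun that uses the Wtest argument: it computes a weighted misclassification rate on each held-out fold. It assumes the fisheriris data and a default 10-fold tree; weightedErr is a hypothetical name, and Wtest(:) is used only to guard against orientation, since the shape of the weight vector is an assumption here.

```matlab
% Sketch: a weighted misclassification rate that uses the Wtest argument.
load fisheriris
rng(1);
CVMdl = crossval(fitctree(meas,species));

% weightedErr is a hypothetical example function, not part of the toolbox.
% predict returns a cell array of labels, so strcmp compares element by element.
weightedErr = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    sum(Wtest(:) .* ~strcmp(predict(CMP,Xtest),Ytest)) / sum(Wtest);

vals = kfoldfun(CVMdl,weightedErr);          % KFold-by-1 vector of weighted error rates
disp(vals')
```

Dividing by sum(Wtest) keeps each fold's value a rate between 0 and 1 regardless of how the weights are scaled.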

Output Arguments

vals

Cross-validation results, returned as a numeric matrix. vals contains the testvals outputs from all folds, concatenated vertically. For example, if testvals from every fold is a 1-by-N numeric row vector, kfoldfun returns a KFold-by-N numeric matrix with one row per fold.

Data Types: double
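To illustrate the vertical concatenation, the hypothetical function below returns a 1-by-2 row vector per fold (held-out and training observation counts). Assuming the default 10-fold partition of fisheriris, vals should then be a KFold-by-2 matrix, and each row should sum to the total number of observations because every fold splits the data into training and test parts.

```matlab
% Sketch: when fun returns a 1-by-N row vector, kfoldfun stacks one row per fold.
load fisheriris
rng(1);
CVMdl = crossval(fitctree(meas,species));

% sizesFun is hypothetical: it returns [number of test rows, number of training rows].
sizesFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    [size(Xtest,1), size(Xtrain,1)];

vals = kfoldfun(CVMdl,sizesFun);             % KFold-by-2 matrix, one row per fold
disp(size(vals))
disp(sum(vals,2)')                           % Training + test rows in each fold
```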

Examples

Train a classification tree classifier, and then cross-validate it using a custom k-fold loss function.

Load Fisher’s iris data set.

load fisheriris

Train a classification tree classifier.

Mdl = fitctree(meas,species);

Mdl is a ClassificationTree model.

Cross-validate Mdl using the default 10-fold cross-validation. Compute the classification error (proportion of misclassified observations) for the out-of-fold observations.

rng(1); % For reproducibility
CVMdl = crossval(Mdl);
L = kfoldLoss(CVMdl)
L =

    0.0467
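kfoldLoss averages over all folds by default, but its 'Mode','individual' option returns one loss per fold. The sketch below, assuming the same data and seed as above, places those per-fold losses next to an unweighted per-fold misclassification rate computed with kfoldfun, so you can compare the built-in and custom computations. errFun is a hypothetical name.

```matlab
% Sketch: per-fold classification error from kfoldLoss next to a kfoldfun version.
load fisheriris
rng(1);
CVMdl = crossval(fitctree(meas,species));

Lfold = kfoldLoss(CVMdl,'Mode','individual');   % One loss per fold

% errFun is a hypothetical unweighted misclassification rate for one fold.
errFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    mean(~strcmp(predict(CMP,Xtest),Ytest));
vals = kfoldfun(CVMdl,errFun);

disp([Lfold(:) vals(:)])                        % Built-in vs. custom, side by side
```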

Examine the result when the cost of misclassifying a flower as 'versicolor' is 10, and any other error costs 1. Write a function file named noversicolor.m that assigns a cost of 10 to misclassifying a flower as versicolor and a cost of 1 to any other misclassification, and save it on your MATLAB® path.

function averageCost = noversicolor(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest)
%noversicolor Example custom cross-validation function
%   Attributes a cost of 10 for misclassifying an iris as versicolor, and 1
%   for any other misclassification. This example function requires the
%   |fisheriris| data set.
Ypredict = predict(CMP,Xtest);
misclassified = not(strcmp(Ypredict,Ytest)); % True where the prediction is wrong
classifiedAsVersicolor = strcmp(Ypredict,'versicolor'); % True where versicolor is predicted
cost = sum(misclassified) + ...
    9*sum(misclassified & classifiedAsVersicolor); % Errors as versicolor cost 1 + 9 = 10
averageCost = cost/numel(Ytest); % Average cost per test observation
end


Compute the mean misclassification cost over all folds using noversicolor.

mean(kfoldfun(CVMdl,@noversicolor))
ans =

    0.2267
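If you prefer not to create a separate file, the same cost can be sketched as an anonymous function. Per test observation, the cost is misclassified*(1 + 9*predictedAsVersicolor), so summing and dividing by the fold size reproduces cost/numel(Ytest) in noversicolor. The names costFun and noversicolorAnon are hypothetical, and the sketch assumes the same data and seed as the example above.

```matlab
% Sketch: the noversicolor cost written as anonymous functions (no .m file needed).
load fisheriris
rng(1);
CVMdl = crossval(fitctree(meas,species));

% Hypothetical helper: 10 for an error predicted as versicolor, 1 for any other
% error, 0 for a correct prediction; mean divides by the number of test observations.
costFun = @(Ypred,Ytest) mean(~strcmp(Ypred,Ytest) .* ...
    (1 + 9*strcmp(Ypred,'versicolor')));
noversicolorAnon = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    costFun(predict(CMP,Xtest),Ytest);

vals = kfoldfun(CVMdl,noversicolorAnon);     % KFold-by-1 vector of average costs
mean(vals)
```

Computing predict once and passing its result to costFun avoids calling predict twice inside a single anonymous expression.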