Hyperparameter tuning of a Support Vector Machine in MATLAB
Let's write code for hyperparameter optimization of an SVM classifier using the Harris hawks optimization (HHO) algorithm.
What is Hyperparameter Optimization?
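Hyperparameters are the settings that are fixed before training rather than learned from the data. Hyperparameter optimization is the search for the values of these settings (here, the SVM's) that give the best classification performance.
Here, the Harris hawks optimization algorithm is adapted to choose the SVM hyperparameters: the kernel function, the kernel scale and the box constraint.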
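For readers new to HHO, its core loop can be sketched as follows. This is a minimal sketch that follows the published HHO update rules (exploration while the rabbit's escaping energy |E| >= 1, then soft/hard besiege with optional Levy-flight dives), run on a toy sphere function instead of the SVM error. The objective, bounds, N, T and all variable names are illustrative assumptions; it is not the SVM_HHO routine called in the main code.

```matlab
% Minimal Harris hawks optimisation (HHO) sketch on a toy objective.
% Assumptions: sphere function as objective, 10 hawks, 100 iterations.
rng(1);                                    % reproducible run
fobj = @(x) sum(x.^2);                     % toy objective (stand-in for the SVM error)
dim = 3; lb = -5*ones(1,dim); ub = 5*ones(1,dim);
N = 10; T = 100;
beta = 1.5;                                % Levy-flight exponent
sigma = (gamma(1+beta)*sin(pi*beta/2) / ...
        (gamma((1+beta)/2)*beta*2^((beta-1)/2)))^(1/beta);
levy = @() 0.01 * (randn(1,dim)*sigma) ./ abs(randn(1,dim)).^(1/beta);

X = lb + rand(N,dim).*(ub - lb);           % initial hawk positions
rabbit = zeros(1,dim); rabbitFit = inf;    % best position found so far
CNVG = zeros(1,T);                         % convergence curve

for t = 1:T
    % clamp hawks to the bounds, evaluate them and update the rabbit
    for i = 1:N
        X(i,:) = min(max(X(i,:), lb), ub);
        f = fobj(X(i,:));
        if f < rabbitFit
            rabbitFit = f; rabbit = X(i,:);
        end
    end
    E1 = 2*(1 - t/T);                      % energy bound shrinks from 2 to 0
    for i = 1:N
        E = E1*(2*rand - 1);               % escaping energy of the rabbit
        if abs(E) >= 1                     % exploration phase
            if rand >= 0.5                 % perch relative to a random hawk
                Xr = X(randi(N),:);
                X(i,:) = Xr - rand*abs(Xr - 2*rand*X(i,:));
            else                           % perch relative to rabbit and mean
                X(i,:) = (rabbit - mean(X,1)) - rand*(lb + rand*(ub - lb));
            end
        else                               % exploitation phase
            r = rand; J = 2*(1 - rand);    % J: random jump strength of the rabbit
            if r >= 0.5 && abs(E) >= 0.5   % soft besiege
                X(i,:) = (rabbit - X(i,:)) - E*abs(J*rabbit - X(i,:));
            elseif r >= 0.5                % hard besiege
                X(i,:) = rabbit - E*abs(rabbit - X(i,:));
            else                           % besiege with progressive rapid dives
                if abs(E) >= 0.5
                    Y = rabbit - E*abs(J*rabbit - X(i,:));
                else
                    Y = rabbit - E*abs(J*rabbit - mean(X,1));
                end
                Y = min(max(Y, lb), ub);
                Z = min(max(Y + rand(1,dim).*levy(), lb), ub);  % Levy dive
                if fobj(Y) < fobj(X(i,:))          % keep a dive only if it helps
                    X(i,:) = Y;
                elseif fobj(Z) < fobj(X(i,:))
                    X(i,:) = Z;
                end
            end
        end
    end
    CNVG(t) = rabbitFit;
end
fprintf('best objective %.3g at [%s]\n', rabbitFit, num2str(rabbit, '%.3f '));
```

For SVM tuning, fobj would be the fitness function and lb/ub the hyperparameter bounds; because that fitness function rounds each coordinate, the hawks can move continuously.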
Fitness function
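The fitness function returns the classification error of the SVM on the validation set, and HHO minimizes it: the best parameter setting is the one with the minimum validation error. Each hawk position p is a vector of three real numbers that the function rounds to select the kernel function, the kernel scale and the box constraint.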
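Written out, with n_val validation samples, true labels y_i and SVM predictions ŷ_i, the quantity being minimized is one minus the validation accuracy:

```latex
\mathrm{err} = 1 - \frac{1}{n_{\mathrm{val}}} \sum_{i=1}^{n_{\mathrm{val}}} \mathbb{1}\left[\hat{y}_i = y_i\right]
```

For instance, with hypothetical numbers, 45 correct predictions out of 50 validation samples give err = 1 - 45/50 = 0.1.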
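function [err,svmModel] = fitness_fun(p)
% Save as fitness_fun.m (or place at the end of the main script).
% Decodes a hawk position p = [kernel index, kernel scale, box constraint],
% trains an SVM on the training set and returns its validation error.
global traindata trainlabel valdata vallabela
kernel = {'gaussian', 'polynomial', 'linear'};
op1 = kernel{round(p(1))};      % 1, 2 or 3 selects the kernel function
kernelScale = round(p(2));      % integer kernel scale
boxx = round(p(3));             % integer box constraint
svmModel = fitcsvm(traindata, trainlabel, ...
    'BoxConstraint', boxx, ...
    'KernelFunction', op1, ...
    'KernelScale', kernelScale, ...
    'Standardize', true);
out = predict(svmModel, valdata);
accuracy = mean(out == vallabela);   % fraction of correctly classified validation samples
err = 1 - accuracy;                  % classification error to be minimised
end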
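To see how a hawk position turns into SVM settings, the decoding steps of fitness_fun can be run on their own. The position p and the labels below are made-up values for illustration (numeric labels are used for brevity, whereas the post uses categorical ones), and no SVM is trained:

```matlab
% Hypothetical hawk position inside lb = [1 1 2], ub = [3 30 3]
p = [2.4 12.7 2.6];
kernel = {'gaussian', 'polynomial', 'linear'};
op1 = kernel{round(p(1))};       % round(2.4) = 2 -> 'polynomial'
kernelScale = round(p(2));       % round(12.7) = 13
boxx = round(p(3));              % round(2.6) = 3
fprintf('kernel = %s, KernelScale = %d, BoxConstraint = %d\n', op1, kernelScale, boxx);

% Validation error computed the same way as in fitness_fun, on toy labels
out       = [1; 1; 0; 1];        % hypothetical predictions
truelabel = [1; 0; 0; 1];        % hypothetical true labels
err = 1 - mean(out == truelabel);   % 3 of 4 correct -> err = 0.25
```

Because rounding maps [1, 1.5) to the Gaussian kernel, [1.5, 2.5) to the polynomial kernel and [2.5, 3] to the linear kernel, the middle option covers a slice of the search interval twice as wide as either end.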
Main code:
% clear the environment
clc; clear; close all;
% training/validation data are shared with fitness_fun through global variables
global traindata trainlabel valdata vallabela
%% Load dataset - ionosphere (built-in sample dataset in MATLAB)
load ionosphere
dataX = X;
dataY = categorical(Y);
%% partition data into training, validation and testing sets
% 70% training, 10% validation, 20% testing
% dividerand (Deep Learning Toolbox) returns random, disjoint index sets
[trainInd,valInd,testInd] = dividerand(numel(dataY),0.7,0.1,0.2);
% training data & label
traindata = dataX(trainInd,:);
trainlabel = dataY(trainInd,:);
% validation data & label
valdata = dataX(valInd,:);
vallabela = dataY(valInd,:);
% testing data & label
testdata = dataX(testInd,:);
testlabel = dataY(testInd,:);
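%% optimise hyperparameters
N = 3;                 % number of search agents (hawks)
T = 50;                % maximum number of iterations
fobj = @fitness_fun;   % handle to the objective (fitness) function
% search space: [kernel index, kernel scale, box constraint]
lb = [1  1 2];         % lower bounds
ub = [3 30 3];         % upper bounds
dim = numel(lb);       % number of hyperparameters (3)
% SVM_HHO runs the HHO search on fobj and returns the best SVM model, the best
% hyperparameter vector (Rabbit_Location) and the convergence curve (CNVG)
[SVM_model,Rabbit_Location,CNVG] = SVM_HHO(N,T,lb,ub,dim,fobj);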
%% predict the output of test data
out=predict(SVM_model,testdata);
accuracy = mean(out == testlabel);   % fraction of correctly classified test samples
fprintf('Accuracy of HHO-optimised SVM: %.2f%%\n', 100*accuracy)
%% plot convergence curve
figure;
plot(CNVG,'-ob','linewidth',2)
xlabel('Iterations');ylabel('objective value')
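The 70/10/20 split above relies on dividerand from the Deep Learning Toolbox. If that toolbox is not installed, a split with the same proportions can be sketched with base MATLAB's randperm. This is an assumed alternative, not the post's method; n = 351 is the number of observations in the ionosphere data, and the rounding of the subset sizes here may differ from dividerand's.

```matlab
% Toolbox-free 70/10/20 split (illustrative alternative to dividerand)
rng(0);                          % reproducible shuffle
n = 351;                         % number of observations in ionosphere
idx = randperm(n);               % random permutation of 1..n
nTrain = round(0.7*n);           % 246 training samples
nVal   = round(0.1*n);           % 35 validation samples
trainInd = idx(1:nTrain);
valInd   = idx(nTrain+1 : nTrain+nVal);
testInd  = idx(nTrain+nVal+1 : end);   % remaining 70 samples for testing
fprintf('train %d, validation %d, test %d\n', ...
    numel(trainInd), numel(valInd), numel(testInd));
```

The three index vectors are disjoint and together cover every sample, so they can replace the dividerand output in the main code unchanged.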