MATLABでSVM（Support Vector Machine）を実装する方法

SVMはカーネル関数を用いることで非線形な識別問題にも対応できる判別器で、高い識別性能を示すことが知られています。

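参考コード（学習データを返す関数 getGaosiData はこのコードに含まれていないため、x に n×dim のサンプル行列、y に ±1 のラベル（1×n の行ベクトル）を返すものを別途用意してください。また quadprog を使うため Optimization Toolbox が必要です）: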

function model = svm(kerType, C)
% SVM  Train a soft-margin (C-)SVM by solving the dual QP with quadprog.
%
%   model = svm()              linear kernel, C = 10
%   model = svm(kerType)       kerType is 'linear' or 'rbf' (see kernel)
%   model = svm(kerType, C)    C is the box constraint (upper bound on each a_i)
%
%   Training data comes from getGaosiData(), expected to return
%   x : n-by-dim sample matrix and y : n labels in {+1,-1}.
%   NOTE(review): getGaosiData is defined elsewhere -- confirm its outputs.
%
%   Returned fields:
%     model.a       Lagrange multipliers of the support vectors (column)
%     model.Xsv     support vectors (svnum-by-dim)
%     model.Ysv     labels of the support vectors (row)
%     model.svnum   number of support vectors
%     model.W       sum_i a_i*y_i*Xsv(i,:) (1-by-dim); only meaningful
%                   as the primal weight vector for the linear kernel
%     model.b       bias term
%     model.kerType, model.C   hyper-parameters used for training
%   Decision value for a row z: (model.a .* model.Ysv')' * kernel(model.Xsv, z, kerType) + model.b

if nargin < 1 || isempty(kerType)
    kerType = 'linear';
end
if nargin < 2 || isempty(C)
    C = 10;
end

[x, y] = getGaosiData();
n = size(x, 1);
% quadprog needs Aeq as a 1-by-n row, so force y into that shape
% (this also errors out if y does not have one label per sample).
y = reshape(double(y), 1, n);
if any(y ~= 1 & y ~= -1) || ~any(y == 1) || ~any(y == -1)
    error('svm:labels', 'Labels must be +1/-1 and both classes must be present.');
end

% Dual problem:
%   min_a  1/2 * a' * H * a - sum(a)
%   s.t.   y * a = 0,  0 <= a <= C
% with H(i,j) = y_i * y_j * K(x_i, x_j).
H = (y' * y) .* kernel(x, x, kerType);
H = (H + H') / 2;           % remove rounding asymmetry so quadprog does not warn
f = -ones(n, 1);
A = [];
b = [];
Aeq = y;
beq = 0;
lb = zeros(n, 1);
ub = C * ones(n, 1);
a0 = zeros(n, 1);           % starting point for the solver

options = optimset('LargeScale', 'off', 'Display', 'off');
[a, ~, exitflag] = quadprog(H, f, A, b, Aeq, beq, lb, ub, a0, options);
if exitflag <= 0
    warning('svm:quadprog', 'quadprog did not converge (exitflag = %d).', exitflag);
end

% Support vectors: samples whose multiplier is nonzero (up to solver tolerance).
epsilon = 1e-8;
sv_label = find(a > epsilon);
a = a(sv_label);
Xsv = x(sv_label, :);
Ysv = y(sv_label);
svnum = numel(sv_label);

model.a = a;
model.Xsv = Xsv;
model.Ysv = Ysv;
model.svnum = svnum;

ay = a .* Ysv';             % a_i * y_i as a column
model.W = ay' * Xsv;        % same as summing a_i*y_i*Xsv(i,:) over the SVs

% Bias: for margin SVs (0 < a_i < C) the KKT conditions give
%   y_i * (sum_j a_j*y_j*K(x_j,x_i) + b) = 1  =>  b = y_i - sum_j a_j*y_j*K(x_j,x_i).
% Average over all of them for stability; fall back to every SV if none is free.
free = find(a < C - epsilon);
if isempty(free)
    free = 1:svnum;
end
g = ay' * kernel(Xsv, Xsv(free, :), kerType);   % 1-by-numel(free)
model.b = mean(Ysv(free) - g);

model.kerType = kerType;
model.C = C;
end

function K = kernel(X, Y, type, sigma)
% KERNEL  Gram matrix between the rows of X and the rows of Y.
%
%   K = kernel(X, Y, type)           X is m-by-d, Y is p-by-d, K is m-by-p
%   K = kernel(X, Y, 'rbf', sigma)   Gaussian width sigma (default 5)
%
%   'linear' : K(i,j) = X(i,:) * Y(j,:)'
%   'rbf'    : K(i,j) = exp(-||X(i,:) - Y(j,:)||^2 / sigma^2)
%   Any other type raises an error.
if nargin < 4 || isempty(sigma)
    sigma = 5;
end
switch type
    case 'linear'
        K = X * Y';
    case 'rbf'
        XX = sum(X .* X, 2);    % m-by-1 squared row norms
        YY = sum(Y .* Y, 2);    % p-by-1 squared row norms
        % ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x*y'; clamp tiny negatives from rounding.
        D = max(bsxfun(@plus, XX, YY') - 2 * (X * Y'), 0);
        K = exp(-D ./ (sigma * sigma));
    otherwise
        error('kernel:unknownType', 'Unknown kernel type "%s".', type);
end
end

Development

Posted by arkgame