%
% Written by:
% --
% John L. Weatherwax                2005-08-04
%
% email: wax@alum.mit.edu
%
% Please send comments and especially bug reports to the
% above email address.
%
% Epage 106
%
% Demonstrates the pocket algorithm: a batch perceptron update that keeps
% ("pockets") the weight vector that has correctly classified the largest
% number of training samples seen so far.  The resulting decision line is
% plotted over the two-class data and saved as an EPS figure.
%
%-----

close all; clc; clear;

rand('seed',0);
randn('seed',0);

% Generate the required data: two isotropic Gaussian clouds in 2-D.
%
mu1 = [ 1, 1 ].';
mu2 = [ 0, 0 ].';
sigmasSquared = 0.2;
d = size(mu1,1);

nFeats = 100;
X1 = mvnrnd( mu1, sigmasSquared*eye(d), nFeats );
X2 = mvnrnd( mu2, sigmasSquared*eye(d), nFeats );

h1 = plot( X1(:,1), X1(:,2), '.b' ); hold on;
h2 = plot( X2(:,1), X2(:,2), '.r' );
legend( [h1,h2], {'class 1', 'class 2'} );

data = [X1;X2];

% delta_x = -1 for class 1 and +1 for class 2 (perceptron cost convention);
% class_sign = -delta_x flips class-2 outputs so that a sample is correctly
% classified exactly when class_sign .* (w' * x) > 0.
delta_x    = [ -1*ones(nFeats,1); +1*ones(nFeats,1) ];
class_sign = -delta_x;

rho = 0.7;          % fixed learning rate

%
% Code begins:
%---

% Append +1 to all data so the bias is the last weight component:
data = [ data, ones(size(data,1),1) ];

w_i = randn(size(data,2),1);              % the initial weight vector
w_s = w_i;                                % save it in the pocket
h_s = sum( class_sign .* (data * w_s) > 0 ); % number correctly classified by the pocket vector

maxIters = 1000;
Niters   = 0;
while( true ) % iterate until everything is classified or maxIters is exceeded

  % Find the set Y of misclassified samples for the current weight vector.
  % Points lying exactly on the boundary (margin == 0) count as misclassified,
  % consistent with the "> 0 means correct" count used for the pocket.
  margin = class_sign .* (data * w_i);
  Y = find( margin <= 0 );
  if( isempty(Y) ) % we are done ... everything is classified correctly!
    break;
  end

  % Batch perceptron step: w(t+1) = w(t) - rho * sum_{x in Y} delta_x * x
  delta_w = data(Y,:).' * delta_x(Y);
  w_ip1   = w_i - rho * delta_w;

  % Count how many vectors the new weight vector classifies correctly and
  % keep it in the pocket if it beats the best count so far:
  h = sum( class_sign .* (data * w_ip1) > 0 );
  if( h > h_s )
    w_s = w_ip1;
    h_s = h;
  end

  w_i = w_ip1;

  Niters = Niters + 1;
  if( Niters > maxIters )
    fprintf('max number of iterations= %10d exceeded\n',maxIters);
    break
  end
end

% Draw the decision boundary w_s^T [x1; x2; 1] = 0:
%
x1_grid = linspace( min(data(:,1)), max(data(:,1)), 50 );
if( w_s(2) ~= 0 )
  x2_db = ( -w_s(3) - w_s(1) * x1_grid ) / w_s(2);
  plot( x1_grid, x2_db, '-g' );
else
  % Vertical boundary: w_s(1)*x1 + w_s(3) = 0 (x2 does not appear).
  x1_db = -w_s(3) / w_s(1);
  plot( [x1_db, x1_db], [ min(data(:,2)), max(data(:,2)) ], '-g' );
end
title('pocket algorithm computed decision line');

fn = 'chap_3_pocket_algo.eps';
saveas(gcf,['../../WriteUp/Graphics/Chapter3/',fn],'epsc');