%
% Written by:
% --
% John L. Weatherwax                2007-07-01
%
% email: wax@alum.mit.edu
%
% Please send comments and especially bug reports to the
% above email address.
%
%-----
%
% Two 2-D Gaussian classes with a shared isotropic covariance sigma^2*I.
% Part (a) applies the minimum-error (Bayes) rule and compares its empirical
% error with the analytic Bayes error. Part (b) applies the minimum-risk rule
% for an asymmetric loss and reports the empirical average loss (risk).

close all; clc; clear;

% Seed the (legacy) generators so runs are reproducible.
rand('seed',0);
randn('seed',0);

mu1 = [ 1, 1 ].';
mu2 = [ 1.5, 1.5 ].';
sigmasSquared = 0.2;
d = size(mu1,1);
Sigma = sigmasSquared*eye(d);

% Both classes get the same number of samples, so the priors are equal.
nFeats = 10000;
X1 = mvnrnd( mu1, Sigma, nFeats );
X2 = mvnrnd( mu2, Sigma, nFeats );

h1 = plot( X1(:,1), X1(:,2), '.b' ); hold on;
h2 = plot( X2(:,1), X2(:,2), '.r' );
legend( [h1,h2], {'class 1', 'class 2'} );

mean_diff = mu1 - mu2;
X = [ X1; X2 ];
labels = [ ones(nFeats,1); 2*ones(nFeats,1) ];

% For Sigma = sigma^2*I the log-likelihood ratio is
%   ln p(x|w1) - ln p(x|w2) = ( mean_diff'*x - 0.5*(mu1'*mu1 - mu2'*mu2) ) / sigma^2,
% so every rule below reduces to comparing the projection mean_diff'*x
% against a scalar threshold.
proj = X * mean_diff;                                    % (2*nFeats) x 1
base_threshold = 0.5 * ( dot(mu1,mu1) - dot(mu2,mu2) );

% Label 1 where the projection exceeds the threshold, label 2 otherwise.
classify = @(threshold) 2 - double( proj > threshold );

% Part (a): minimum-error (Bayes) rule. With equal priors, choose class 1
% when the likelihood ratio exceeds 1, i.e. proj > base_threshold.
chosen_class = classify( base_threshold );
P_correct = mean( chosen_class == labels );
P_error   = 1 - P_correct;

% Analytic Bayes error for two equal-prior Gaussians with common covariance:
%   P_B = 1 - Phi( Delta/2 ),  Delta = sqrt( (mu1-mu2)' * inv(Sigma) * (mu1-mu2) ).
% Delta is computed here directly rather than through an external helper.
dm  = sqrt( mean_diff' * ( Sigma \ mean_diff ) );
P_B = 1 - normcdf( 0.5 * dm );
fprintf('empirical P_e= %10.6f; analytic Bayes P_e= %10.6f\n',P_error,P_B);

% Part (b): minimum-risk rule.
%   lambda12 = loss for assigning a class-1 sample to class 2,
%   lambda21 = loss for assigning a class-2 sample to class 1.
% With equal priors, choose class 1 when p(x|w1)/p(x|w2) > lambda21/lambda12, i.e.
%   proj > base_threshold + sigma^2 * ln( lambda21/lambda12 ).
% Here lambda21/lambda12 = 0.5, so the threshold moves DOWN by sigma^2*ln(2).
% Misclassifying class 1 is costlier, so the boundary shifts to favour class 1.
lambda12 = 1;
lambda21 = 0.5;
risk_threshold = base_threshold + sigmasSquared * log( lambda21/lambda12 );

chosen_class = classify( risk_threshold );
P_correct = mean( chosen_class == labels );
P_error   = 1 - P_correct;

% Empirical expected loss (risk) of this classifier, weighting each kind of
% mistake by the same losses used to build the decision threshold:
nErr1 = sum( chosen_class(1:nFeats) ~= 1 );       % class-1 samples sent to class 2
nErr2 = sum( chosen_class(nFeats+1:end) ~= 2 );   % class-2 samples sent to class 1
r_hat = ( lambda12 * nErr1 + lambda21 * nErr2 )/(2*nFeats);
fprintf('P_correct= %10.6f; r_hat= %10.6f\n',P_correct,r_hat);