%
% Problem EPage 393
%
% Independent test-sample to pick the classification tree
%
% Examples 9.12 and 9.13 (epage 371) implements this ...
%
% Written by:
% --
% John L. Weatherwax                2008-02-20
%
% email: wax@alum.mit.edu
%
% Please send comments and especially bug reports to the
% above email address.
%
%-----

clear all;
close all;
clc;

addpath('../../Code/CSTool');

% "bank" supplies the matrices forge and genuine used below.
load bank;

% Standardize both classes together, then split the standardized rows back
% into the two classes (the slicing assumes 100 rows per class).
Z       = zscore( [forge; genuine] );
forge   = Z(1:100,:);
genuine = Z(101:200,:);

% Remove a feature that seems particularly informative, so that this
% problem is more susceptible to over fitting.
forge(:,6)   = [];
genuine(:,6) = [];

% Split the data randomly in half: the first 50 permuted row indices form
% the learning sample; the remaining rows are saved as an independent
% test set further below.  The same row indices are used for both classes.
rand('seed',0);
inds       = randperm(100);
train_rows = inds(1:50);
data1      = forge(train_rows,:);
data2      = genuine(train_rows,:);

n  = 100;
n1 = 50;
n2 = 50;

% These are the inputs to the function csgrowc.
maxn = 5;             % maximum number points in the terminal nodes
clas = [1 2];         % class labels: 1 = forge; 2 = genuine
pies = [0.5 0.5];     % optional prior probabilities
Nk   = [n1, n2];      % number in each class

% Learning sample: the features with the class label appended as the last column.
labels = [ repmat(clas(1), n1, 1); repmat(clas(2), n2, 1) ];
X      = [ [data1; data2], labels ];

% Grow the initial (large) tree and save a picture of it.
tree = csgrowc(X, maxn, clas, Nk, pies);
figure;
csplotreec(tree);
saveas( gcf, '../../WriteUp/Graphics/Chapter9/prob_9_9_initial_tree', 'epsc' );

% Prune to obtain the sequence of subtrees.
treeseq = csprunec(tree);
K       = length(treeseq);

% Find the sequence of alphas.
% Note that the root node corresponds to position K, i.e. the last one in
% the sequence.
alpha = zeros(1,K);
for i = 1:K
  subtree  = treeseq{i};
  alpha(i) = subtree.alpha;
end

% Create an independent test set using the other data points:
inds2 = setdiff( 1:100, train_rows );
data1 = forge(inds2,:);
data2 = genuine(inds2,:);
n1    = size(data1,1);
n2    = size(data2,1);
d     = size(data2,2);

% Now check these trees using our independent test cases in data1 and data2.
% Keep track of the ones misclassified.
% Estimate the misclassification rate of each pruned subtree on the
% independent test sample (data1 = class 1 cases, data2 = class 2 cases),
% then extract and plot the subtree with the smallest test error.
%
% Reads from the workspace: treeseq, data1, data2, n1, n2.
K  = length(treeseq);
Rk = zeros(1,K-1);    % we do not check the root (last entry of treeseq)

% Size of the test sample.  Use the actual number of test cases rather than
% the training-set constant "n" defined earlier in the script, so the error
% rate stays correct if the split is ever changed.
ntest = n1 + n2;

for k = 1:K-1
  nmis  = 0;
  treek = treeseq{k};

  % Loop through the cases from class 1.
  % The prediction is stored in its own variable so that the class-label
  % vector "clas" in the workspace is not overwritten.
  for i = 1:n1
    pred = cstreec(data1(i,:), treek);
    if pred ~= 1
      nmis = nmis + 1;   % misclassified
    end
  end

  % Loop through the cases from class 2.
  for i = 1:n2
    pred = cstreec(data2(i,:), treek);
    if pred ~= 2
      nmis = nmis + 1;   % misclassified
    end
  end

  Rk(k) = nmis/ntest;
end

% Find the minimum Rk (min returns the first index when there are ties).
[mrk,ind] = min(Rk);

% Now find the standard error for that estimate ...
semrk = sqrt(mrk*(1-mrk)/ntest);
% ... and add it to min(Rk).
% NOTE(review): Rk2 is computed but not used to choose the tree below; the
% subtree with the minimum test error is the one extracted and plotted.
Rk2 = mrk + semrk;

% Extract this tree and plot:
%
best_tree = treeseq{ind};
figure;
csplotreec(best_tree);
saveas( gcf, '../../WriteUp/Graphics/Chapter9/prob_9_9_pruned_tree', 'epsc' );

return;