%
% Problem EPage 393
%
% Use an independent test sample to pick the best pruned classification tree.
%
% Examples 9.12 and 9.13 (epage 371) implement this ...
%
% Written by:
% --
% John L. Weatherwax                2008-02-20
%
% email: wax@alum.mit.edu
%
% Please send comments and especially bug reports to the
% above email address.
%
%-----

clear all; close all; clc;

% sample_discrete.m  (from the FullBNT/KPMstats toolbox)
addpath( genpath( '/nfs/stateot/weatherw/Software/FullBNT-1.0.4/KPMstats/' ) );
addpath('../../Code/CSTool');

% Load the flea-beetle data set; provides the per-species measurement
% matrices "conc", "heik", and "hept" (one row per observation).
load flea;

[nc, d] = size(conc);
[nhi,d] = size(heik);
[nhp,d] = size(hept);

% Standardize all observations jointly (zero mean, unit variance per
% feature), then split back into the three class matrices.
Z = zscore( [conc;heik;hept] );
conc = Z(1:nc,:);
heik = Z(nc+1:nc+nhi,:);
hept = Z(nc+nhi+1:end,:);

data1 = conc; n1 = nc;
data2 = heik; n2 = nhi;
data3 = hept; n3 = nhp;

% these are the inputs to function - csgrowc.
maxn = 5;           % maximum number points in the terminal nodes
clas = [1 2 3];     % class labels: 1 = conc; 2 = heik; 3 = hept
pies = [1 1 1]/3;   % optional prior probabilities
Nk = [nc, nhi, nhp]; % number in each class

% Stack the data with the class label appended as the last column.
X = [data1,ones(n1,1);data2,2*ones(n2,1);data3,3*ones(n3,1)];

% Grow the full (overfit) classification tree and save a plot of it.
tree = csgrowc(X,maxn,clas,Nk,pies);
figure; csplotreec(tree);
saveas( gcf, '../../WriteUp/Graphics/Chapter9/prob_9_13_initial_tree', 'epsc' );

% Produce the nested sequence of cost-complexity pruned subtrees.
treeseq = csprunec(tree);
K = length(treeseq);

% Find the sequence of alphas.
% Note that the root node corresponds to position K i.e. the last one in the sequence
alpha = zeros(1,K);
for i = 1:K
  alpha(i) = treeseq{i}.alpha;
end

% estimate a three term multivariate finite mixture model on this data:
% (initialized with the per-class sample means / covariances)
%
n_classes = 3;
muin = zeros(n_classes,d);
muin(1,:) = mean(conc);
muin(2,:) = mean(heik);
muin(3,:) = mean(hept);
muin = muin.';           % csfinmix expects means as columns
varin = zeros(d,d,n_classes);
varin(:,:,1) = cov(conc);
varin(:,:,2) = cov(heik);
varin(:,:,3) = cov(hept);
piesin = [1,1,1]/3;
max_it = 100; tol = 1e-3;
% EM fit; X(:,1:end-1) strips the class-label column added above.
[pies,mus,vars] = csfinmix( X(:,1:end-1), muin, varin, piesin, max_it, tol );

% generate an independent sample from this finite mixture:
%
n_is = 100;
comp_pick = sample_discrete(pies,1,n_is); % <- a random draw for each class ...
% Count how many of the n_is mixture draws fell into each component, then
% sample that many observations from each fitted Gaussian component.
n1 = length(find(comp_pick==1));
n2 = length(find(comp_pick==2));
n3 = length(find(comp_pick==3));
n = n1+n2+n3;
data1 = mvnrnd(mus(:,1),squeeze(vars(:,:,1)),n1);
data2 = mvnrnd(mus(:,2),squeeze(vars(:,:,2)),n2);
data3 = mvnrnd(mus(:,3),squeeze(vars(:,:,3)),n3);

% Now check each pruned tree using our independent test cases in data1, data2, and data3.
% Keep track of the ones misclassified.
K = length(treeseq);
Rk = zeros(1,K-1); % we do not check the root
for k = 1:K-1
  nmis = 0;                    % misclassification count for subtree k
  treek = treeseq{k};
  % loop through the cases from class 1
  for i = 1:n1
    [clas,pclass,node]=cstreec(data1(i,:),treek);
    if clas ~= 1
      nmis = nmis+1; % misclassified
    end
  end
  % Loop through the cases from class 2
  for i = 1:n2
    [clas,pclass,node]=cstreec(data2(i,:),treek);
    if clas ~= 2
      nmis = nmis+1; % misclassified
    end
  end
  % Loop through the cases from class 3
  for i = 1:n3
    [clas,pclass,node]=cstreec(data3(i,:),treek);
    if clas ~= 3
      nmis = nmis+1; % misclassified
    end
  end
  Rk(k) = nmis/n;              % estimated misclassification rate for subtree k
end

% Find the minimum Rk. The tree T_1 corresponds to the minimum Rk.
[mrk,ind]=min(Rk);
% Now find the standard error for that one (binomial proportion SE).
semrk = sqrt(mrk*(1-mrk)/n);
% We add that to min(Rk): the 1-SE rule threshold.
Rk2 = mrk+semrk;
% the tree that is just under this value of 0.0955 is the {\em third}
% NOTE(review): this deliberately overrides the argmin "ind" computed above,
% selecting the smallest subtree within one SE of the minimum.
ind = 3;

% extract this tree and plot:
%
best_tree = treeseq{ind};
figure; csplotreec(best_tree);
saveas( gcf, '../../WriteUp/Graphics/Chapter9/prob_9_13_pruned_tree', 'epsc' );

return;