%
% problem is on epage 432, while the example is on epage 428.
%
% Grows a regression tree on a random half of the abrasion data, prunes it,
% and uses the held-out half as an independent test sample to pick a subtree:
% the tree with minimum test mean squared error and the least complex tree
% whose error is within one standard error of that minimum.
%
% Written by:
% --
% John L. Weatherwax                2008-02-20
%
% email: wax@alum.mit.edu
%
% Please send comments and especially bug reports to the
% above email address.
%
%-----

clear all; close all; clc;

addpath('../../Code/CSTool');

load abrasion;   % provides the predictors x and the response y
X = x;

% split the data in half randomly (disjoint training and testing sets):
%
n = size(X,1);
rand('seed',0);
perm    = randperm(n);
nhalf   = floor(n/2);
inds_l  = perm(1:nhalf);
inds_r  = perm((nhalf+1):end);   % start AFTER the training indices so no point is shared
nprime  = length(inds_r);        % number of test samples (= ceil(n/2))
X_train = X(inds_l,:); y_train = y(inds_l);
X_test  = X(inds_r,:); y_test  = y(inds_r);
clear X y;

maxn = 5; % <- the maximal number of points to include before stopping tree growth:
tree = csgrowr(X_train,y_train,maxn);

% now prune:
treeseq = cspruner(tree);

% For each tree in the sequence, find the total regression mean squared error
% on the test sample.  The squared residuals are kept so that the standard
% error of the selected tree can be computed without predicting again.
k        = length(treeseq);
sqres    = zeros(k,nprime);
numnodes = zeros(1,k);
for i=1:(k-1)
  t = treeseq{i};
  for j=1:nprime
    [yhat,node] = cstreer(X_test(j,:),t);
    sqres(i,j)  = (y_test(j)-yhat).^2;
  end
  [term,nt,imp] = getdata(t);
  % find the # of terminal nodes
  numnodes(i) = length(find(term==1));
end

% The last tree in the sequence is the root only: its prediction is the
% constant stored in the root node and it has exactly one terminal node.
t           = treeseq{k};
sqres(k,:)  = (y_test(:).' - t.node(1).yhat).^2;
numnodes(k) = 1;

% find the mean squared error of every tree (1 x k row vector)
msek = mean(sqres,2).';

% Find the subtree corresponding to the minimum MSE
[msemin,ind] = min(msek);
minnode      = numnodes(ind);

% Find the standard error for THAT subtree:
%   se = sqrt( ( mean(residual.^4) - msemin^2 ) / nprime )
t0 = treeseq{ind};
se = sqrt( (mean(sqres(ind,:).^2) - msemin^2) / nprime );

% what subtree has the least complexity with a value of mean square error
% no larger than msemin+se:
msek
msemin+se

within = find(msek <= msemin+se);        % never empty: the minimiser itself qualifies
[fewest,pick] = min(numnodes(within));
ind_1se   = within(pick)                 % index into treeseq of the selected subtree
nodes_1se = fewest                       % its number of terminal nodes