%Find the split point (in x or y dimension) that best separates the dataset into two classes
function [ dim, sval ] = k_find_split( patterns, targets )
%K_FIND_SPLIT Pick the axis-aligned threshold with the largest impurity drop.
%
%   [dim, sval] = k_find_split(patterns, targets)
%
%   Inputs
%     patterns - nSamples-by-numDim matrix, one sample per row.
%     targets  - labels, one per row of patterns (any orientation; it is
%                flattened to a column). Samples labelled 0 ('blue') and
%                1 ('green') are the ones counted in the impurity.
%
%   Outputs
%     dim  - column of patterns holding the best split.
%     sval - threshold in that column: rows with patterns(:,dim) <= sval
%            fall on the left side, rows with patterns(:,dim) > sval on the
%            right side.
%
%   Candidate thresholds are the midpoints between consecutive sorted values
%   of each column. Each candidate is scored by the drop in entropy impurity
%   from the parent node to the weighted impurity of the two children. On
%   ties the lowest dimension, then the lowest threshold, wins.
%
%   When no candidate exists (fewer than two samples, or patterns has no
%   columns) both outputs are returned as [].

targets  = targets(:);
numDim   = size(patterns,2);
nSamples = size(patterns,1);

% Default outputs so that every path assigns them.
dim  = [];
sval = [];

% A split needs at least two samples and at least one dimension.
if nSamples < 2 || numDim == 0
    return;
end

% Entropy of a two-class node; eps keeps log2 finite when a proportion is 0.
entropy2 = @(pb,pg) -((pb*log2(pb+eps))+(pg*log2(pg+eps)));

% Impurity of the parent node.
p_b = sum(targets==0)/nSamples;   % proportion of 'blue' samples
p_g = sum(targets==1)/nSamples;   % proportion of 'green' samples
i_N = entropy2(p_b,p_g);

% One candidate threshold per pair of neighbouring sorted samples.
nCand   = nSamples-1;
M       = zeros(nCand,numDim);    % candidate thresholds (midpoints)
delta_i = zeros(nCand,numDim);    % change in impurity for each candidate

for di = 1:numDim
    % Sort data in current dimension and carry the labels along.
    [sP,inds] = sort(patterns(:,di));
    sT = targets(inds);

    % Midpoints between consecutive sorted values.
    M(:,di) = (sP(1:end-1) + sP(2:end))./2;

    for m = 1:nCand
        isLeft  = sP <= M(m,di);
        isRight = sP >  M(m,di);
        targets_l = sT(isLeft);
        targets_r = sT(isRight);

        % Class proportions in each child; an empty child contributes 0.
        nL = numel(targets_l);
        if nL == 0
            p_l_b = 0;
            p_l_g = 0;
        else
            p_l_b = sum(targets_l == 0)/nL;
            p_l_g = sum(targets_l == 1)/nL;
        end

        nR = numel(targets_r);
        if nR == 0
            p_r_b = 0;
            p_r_g = 0;
        else
            p_r_b = sum(targets_r == 0)/nR;
            p_r_g = sum(targets_r == 1)/nR;
        end

        i_N_l = entropy2(p_l_b,p_l_g);
        i_N_r = entropy2(p_r_b,p_r_g);

        % Weight each child's impurity by its share of the samples.
        p_l = nL/nSamples;
        p_r = nR/nSamples;
        delta_i(m,di) = i_N - ( (p_l*i_N_l)+(p_r*i_N_r) );
    end
    % Debug aid: figure; plot(M(:,di),delta_i(:,di))
end

% Pick the candidate with the largest impurity drop. The strict '>' together
% with max() returning the first maximum keeps the earliest dimension and
% the lowest threshold on ties.
mxDeltaI = -Inf;
for di = 1:numDim
    [mx,ind] = max(delta_i(:,di));
    if mx > mxDeltaI
        mxDeltaI = mx;
        dim  = di;
        sval = M(ind,di);
    end
end