two_class_LDA_with_optimal_cut_point <- function(X, labels, nCP=100, doPlots=FALSE){
#
# R code to compute a two class LDA classification, but where the cut point is
# chosen to minimize the prediction error over the training set.
#
# See the section entitled "Linear Discriminant Analysis" in Chapter 4 of the
# book ESLII, where this procedure is suggested.
#
# Inputs:
#   X      = training matrix of size number_of_samples by number_of_features
#   labels = training labels of the true classifications, with indices 1 and 2
#
# Written by:
# --
# John L. Weatherwax 2009-04-21
#
# email: wax@alum.mit.edu
#
# Please send comments and especially bug reports to the above email address.
#
#-----

  uLabels = sort(unique(labels))
  K = length(uLabels)  # the number of classes ... we can only have two
  stopifnot(K == 2)

  N = dim(X)[1]  # the number of samples
  p = dim(X)[2]  # the dimension of the feature space

  # TODO: Perform leave one out cross validation:
  #
  # Take out a point, perform all of the steps below, classify that point, and
  # count the number of times it is classified correctly.
  #

  # estimate the pooled covariance matrix:
  sigmaHat = cov(X)

  # estimate the class specific mean vectors (one column mean per feature):
  inds = labels == uLabels[1]
  mean_1 = colMeans(X[inds, , drop=FALSE])
  inds = labels == uLabels[2]
  mean_2 = colMeans(X[inds, , drop=FALSE])

  # compute the discriminant direction and the range of possible cut points:
  A = as.matrix(solve(sigmaHat, mean_2 - mean_1))
  discriminant = as.matrix(X) %*% A
  cpMin = min(discriminant)
  cpMax = max(discriminant)
  cpRange = cpMin + ((cpMax - cpMin) / nCP) * (0:nCP)

  # evaluate the training error rate at each candidate cut point:
  errRates = c()
  for (ii in 1:(nCP + 1)) {
    predictedLabels = mat.or.vec(N, 1) + uLabels[1]  # everything starts predicted as class #1
    thresh = cpRange[ii]
    class2 = discriminant > thresh  # but these would be classified as class #2
    predictedLabels[class2] = uLabels[2]
    er = sum(labels != predictedLabels) / length(labels)
    errRates = c(errRates, er)
  }

  if (doPlots) {
    plot(cpRange, errRates, xlab="cut point", ylab="training error rate")
  }

  # pick the cut point with the smallest training error:
  smallestErrorNIndex = which.min(errRates)
  optThreshold = cpRange[smallestErrorNIndex]
  optError = errRates[smallestErrorNIndex]

  return(list(A=A, optThreshold=optThreshold, optError=optError))
}
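
# A minimal usage sketch (not part of the original routine): the helper below,
# demo_two_class_LDA_with_optimal_cut_point, is a new illustrative function and
# the synthetic Gaussian data it generates are purely for demonstration.  It
# fits the discriminant on two shifted Gaussian clouds and prints the optimal
# cut point and its training error.
demo_two_class_LDA_with_optimal_cut_point <- function(){
  set.seed(0)
  n = 50
  X = rbind( matrix( rnorm( n*2, mean=0 ), ncol=2 ),
             matrix( rnorm( n*2, mean=2 ), ncol=2 ) )  # two shifted Gaussian clouds
  labels = c( rep( 1, n ), rep( 2, n ) )               # class indices 1 and 2
  res = two_class_LDA_with_optimal_cut_point( X, labels, nCP=100, doPlots=TRUE )
  print( res$A )             # the discriminant direction solve(sigmaHat, mean_2 - mean_1)
  print( res$optThreshold )  # the cut point with the smallest training error
  print( res$optError )      # that training error rate
  invisible( res )
}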