#
# Perform ordinary least squares (OLS) on the spam data and reproduce the
# "LS" column of Table 3.2: coefficient estimates, their standard errors,
# z-scores, and the mean squared prediction error on the test set.
#
# Written by:
# --
# John L. Weatherwax                2009-04-21
#
# email: wax@alum.mit.edu
#
# Please send comments and especially bug reports to the
# above email address.
#
#-----

library(MASS)    # ginv(): Moore-Penrose pseudo-inverse
library(xtable)  # xtable(): LaTeX table of the results

source("load_spam_data.R")

# load_spam_data() is defined in load_spam_data.R; element 1 is used as the
# training set and element 2 as the testing set.
res <- load_spam_data(trainingScale = TRUE, responseScale = FALSE)
XTraining <- res[[1]]
XTesting <- res[[2]]

nTrain <- nrow(XTraining)
p <- ncol(XTraining) - 1 # the last column is the response

# Design matrix: a column of ones (intercept) followed by the p predictors.
D <- XTraining[, seq_len(p)]
Dp <- cbind(matrix(1, nTrain, 1), as.matrix(D))
response <- XTraining[, p + 1]

# Solve the normal equations once; the pseudo-inverse of X'X is reused
# below for the coefficient covariance.
XtXinv <- ginv(crossprod(Dp))
betaHat <- XtXinv %*% crossprod(Dp, as.matrix(response))

# this is basically the first column in Table 3.2:
print("first column: beta estimates")
print(betaHat, digits = 2)

# Fitted values on the training data and the unbiased estimate of the
# noise variance sigma^2 (p predictors + 1 intercept degrees of freedom):
yhat <- Dp %*% betaHat
resTrain <- response - as.vector(yhat)
sigmaHat <- sum(resTrain^2) / (nTrain - p - 1)

# Covariance of betaHat and the resulting standard errors:
covarBetaHat <- sigmaHat * XtXinv
stdBetaHat <- sqrt(diag(covarBetaHat))

# this is basically the second column in Table 3.2:
print("second column: beta standard errors")
print(as.matrix(stdBetaHat), digits = 2)

# z-scores of each coefficient:
z <- betaHat / stdBetaHat

# this is basically the third column in Table 3.2:
print("third column: beta z-scores")
print(z, digits = 2)

# Collect the results in one table. colnames() works for both data frames
# and matrices; the table is NOT named `F`, which would mask FALSE.
olsTable <- data.frame(
  Term = c("Intercept", colnames(XTraining)[seq_len(p)]),
  Coefficients = as.vector(betaHat),
  Std_Error = stdBetaHat,
  Z_Score = as.vector(z)
)
# Explicit print(): top-level values are not auto-printed under source().
print(xtable(
  olsTable,
  caption = "Table of OLS coefficients for the SPAM data set",
  digits = 2
))

# Run this full linear model on the testing data so that we can fill in the
# two lower spots in the "LS" column in Table 3.2: the mean squared test
# error and its standard error.
XTestPred <- as.matrix(XTesting[, seq_len(p), drop = FALSE])
pdt <- cbind(matrix(1, nrow(XTestPred), 1), XTestPred) %*% betaHat
responseTest <- XTesting[, p + 1]
NTest <- length(responseTest)

sqErrTest <- (responseTest - as.vector(pdt))^2
mErr <- mean(sqErrTest)
print(mErr)
sErr <- sqrt(var(sqErrTest) / NTest)
print(sErr)