%
% Written by:
% --
% John L. Weatherwax                2007-07-01
%
% email: wax@alum.mit.edu
%
% Please send comments and especially bug reports to the
% above email address.
%
% EM algo: Epage 57
% Problem: Epage
%
% Fits a one-dimensional, three-component Gaussian mixture with the EM
% algorithm.  Samples are drawn from three normal densities with known
% means/variances/priors, EM is iterated from a random start until the
% parameter updates fall below a tolerance (or an iteration cap is hit),
% and the estimates are printed next to the true values.
%
%-----

close all; clc; clear;

rand('seed',0); randn('seed',0);

% Generate the requested data:
%
% Each pass through the loop appends four samples in the ratio 2:1:1, so
% the empirical priors match p_true below.  The second argument of mvnrnd
% is the (co)variance, so it matches s2_true.
%
maxNSamples = 500;
nSamples    = 0;
X           = [];
while( nSamples < maxNSamples )
  % draw two samples from second density
  X = [X; mvnrnd( 3.0, 0.1, 2 ) ];
  % draw one sample from first density
  X = [X; mvnrnd( 1.0, 0.1, 1 ) ];
  % draw one sample from last density
  X = [X; mvnrnd( 2.0, 0.2, 1 ) ];
  nSamples = nSamples + 4;
end
mu_true = [ 1. , 3. , 2.  ].';
s2_true = [ 0.1, 0.1, 0.2 ].';
p_true  = [ 1/4, 2/4, 1/4 ].';

% Initialize the EM algorithm:
%
N        = size(X,1);
J        = 3;                 % initial number of assumed clusters
mu_j     = rand(J,1);         % initial means
sigma2_j = rand(J,1).^2;      % initial variances
P_j      = (1/J) * ones(J,1); % initial values of the priors (uniform)

% Iterate the EM algorithm:
%
abs_tol   = 1.e-6;            % stop when every parameter update is below this
max_iters = 10000;            % safety cap on the iteration counter
ii        = 1;
while( 1 )
  % E-step:
  % Implement Eq 2.87 = P(j|x_k;\Theta(t)) and Eq 2.88 = p(x_k;\Theta(t))
  %
  P_j_xk = zeros(N,J);
  for jj=1:J
    % unnormalized posterior: component likelihood times its prior
    P_j_xk(:,jj) = mvnpdf( X, mu_j(jj), sigma2_j(jj) ) * P_j(jj);
  end
  p_xk   = sum(P_j_xk,2);                    % mixture density at each sample
  P_j_xk = P_j_xk ./ repmat( p_xk, 1, J );   % normalize across the J clusters

  % M-step:
  %
  mu_j_new     = zeros(J,1);
  sigma2_j_new = zeros(J,1);
  P_j_new      = zeros(J,1);
  for jj=1:J
    denom        = sum( P_j_xk(:,jj) );
    numer        = sum( P_j_xk(:,jj) .* X );
    mu_j_new(jj) = numer/denom;              % update the cluster means

    mu_j_rep         = mu_j_new(jj) * ones(N,1);
    numer            = sum( P_j_xk(:,jj) .* ((X-mu_j_rep).^2) );
    sigma2_j_new(jj) = numer/denom;          % update the cluster variances

    P_j_new(jj) = mean( P_j_xk(:,jj) );      % update the cluster priors
  end

  % Size of this iteration's parameter update:
  mud = norm(mu_j-mu_j_new);
  sid = norm(sigma2_j-sigma2_j_new);
  pd  = norm(P_j-P_j_new);

  ii = ii+1;
  if( mod(ii,10)==0 )
    fprintf('norm(mu diff)= %10.6f; norm(sigma2 diff)= %10.6f; norm(P diff)= %10.6f\n',mud,sid,pd);
  end

  % Stop once all three parameter sets have converged, or when the
  % iteration cap is exceeded (e.g. if the updates became NaN).
  if( ( mud<abs_tol && sid<abs_tol && pd<abs_tol ) || ii>max_iters )
    fprintf('\nnorm(mu diff)= %10.6f; norm(sigma2 diff)= %10.6f; norm(P diff)= %10.6f\n',mud,sid,pd);
    break;
  end

  mu_j     = mu_j_new;
  sigma2_j = sigma2_j_new;
  P_j      = P_j_new;
end

% Find the permutation of our output that best matches the truth
%
% EM labels the clusters arbitrarily, so search every permutation of the J
% labels (cyclic shifts alone do not cover all J! orderings) and keep the
% one that minimizes the summed parameter mismatch against the truth.
%
all_perms = perms( 1:J );
best_cost = inf;
p         = (1:J).';
for kk=1:size(all_perms,1)
  cand = all_perms(kk,:).';
  cost = norm( mu_j(cand) - mu_true ) + ...
         norm( sigma2_j(cand) - s2_true ) + ...
         norm( P_j(cand) - p_true );
  if( cost < best_cost )
    best_cost = cost;
    p         = cand;
  end
end

% permute the results so that they match the truth values orderings:
%
mu_j     = mu_j( p );
sigma2_j = sigma2_j( p );
P_j      = P_j( p );

fprintf('mu_true = '); fprintf('%10.6f; ', mu_true.' ); fprintf('\n');
fprintf('mu_j = '); fprintf('%10.6f; ', mu_j.' ); fprintf('\n');
fprintf('\n');
fprintf('s2_true = '); fprintf('%10.6f; ', s2_true.' ); fprintf('\n');
fprintf('sigma2_j= '); fprintf('%10.6f; ', sigma2_j.' ); fprintf('\n');
fprintf('\n');
fprintf('p_true = '); fprintf('%10.6f; ', p_true.' ); fprintf('\n');
fprintf('P_j = '); fprintf('%10.6f; ', P_j.' ); fprintf('\n');
fprintf('\n');