function [beta,F,prob] = TVLogDis(x,y,beta0,lambda,maxit)
%
%function [beta,F,prob] = TVLogDis(x,y,beta0,lambda,maxit)
%Total-Variation Regularized Logistic Discrimination (two groups)
%
%INPUT
%
%   x:      Explanatory variables, INCLUDING intercept column (n x p)
%   y:      0-1 group labels (n x 1)
%   beta0:  Initial parameter estimates, usually zero (p x 1)
%   lambda: Penalization parameter (positive scalar)
%   maxit:  Maximum number of iterations (scalar)
%
%OUTPUT
%
%   beta:   LDTV parameter estimates (p x 1)
%   F:      Penalized negative log-likelihood at the solution (scalar)
%   prob:   Estimated probabilities of belonging to group 1 (n x 1)
%

logistic = @(t)(1./(1+exp(-t)));

[n,p] = size(x);
e = ones(p,1);

%Reparametrization: gam = S*beta, so the total-variation penalty on beta
%becomes an L1 penalty on gam(2:p); the intercept gam(1) is not penalized
S = spdiags([-e e],-1:0,p,p);
S(2,1) = 0;
S = full(S);
Sinv = inv(S);
xSinv = x*Sinv;
gam = S*beta0;

%The reparametrization does not change the prior probabilities,
%so we can use beta0 to calculate them
p0 = logistic(x*beta0);

%Compute initial values of the objective function, its gradient and the
%initial (steepest descent) direction
F = sum(-y.*log(max(p0,eps))-(1-y).*log(max(1-p0,eps)))/n + lambda*abs(gam(2:p))'*e(2:p);
gradF = zeros(p,2);
gradF(:,1) = xSinv'*(p0-y)/n + lambda*[0;sign(gam(2:p))];
pk = -gradF(:,1);
err = 1;
count = 0;

%Initialize parameters for the linesearch (strong Wolfe conditions)
alphastart = 0.3;
c1 = 0.0001;
c2 = 0.4;

while count < maxit && err > 1e-6
    count2 = 0;
    Fmin = F;

    %linesearch: try alpha0, then refine by quadratic and cubic interpolation
    alpha0 = alphastart;
    gradF_pk = gradF(:,1)'*pk;
    g = gam + alpha0*pk;
    p0 = logistic(xSinv*g);
    gradF(:,2) = xSinv'*(p0-y)/n + lambda*sign(g);
    F1(1) = sum(-y.*log(max(p0,eps))-(1-y).*log(max(1-p0,eps)))/n + lambda*abs(g(2:p))'*e(2:p);
    F1(2) = F1(1);  %so that F1(2) is defined when alpha0 already satisfies the Wolfe conditions

    if F1(1) > (Fmin+c1*alpha0*gradF_pk) || abs(gradF(:,2)'*pk) > c2*abs(gradF_pk)
        %quadratic interpolation
        alpha1 = -(gradF_pk*alpha0^2)/(2*(F1(1)-Fmin-gradF_pk*alpha0));
        g = gam + alpha1*pk;
        p0 = logistic(xSinv*g);
        gradF(:,2) = xSinv'*(p0-y)/n + lambda*sign(g);
        F1(2) = sum(-y.*log(max(p0,eps))-(1-y).*log(max(1-p0,eps)))/n + lambda*abs(g(2:p))'*e(2:p);

        while (F1(2) > (Fmin+c1*alpha1*gradF_pk)) || abs(gradF(:,2)'*pk) > c2*abs(gradF_pk)
            %cubic interpolation
            count2 = count2+1;
            A = [alpha0^2 -alpha1^2; -alpha0^3 alpha1^3];
            B = [F1(2)-Fmin-gradF_pk*alpha1; F1(1)-Fmin-gradF_pk*alpha0];
            a = 1/(alpha0^2*alpha1^2*(alpha1-alpha0))*A*B;
            alpha0 = alpha1;
            if (a(2)^2-3*a(1)*gradF_pk) < eps || count2 > 50
                %no admissible minimizer of the cubic model: reject the step
                g = gam;
                p0 = logistic(xSinv*g);
                gradF(:,2) = xSinv'*(p0-y)/n + lambda*sign(g);
                F1(2) = sum(-y.*log(max(p0,eps))-(1-y).*log(max(1-p0,eps)))/n + lambda*abs(g(2:p))'*e(2:p);
                disp('There is no step length that fulfills the strong Wolfe conditions');
                break;
            end
            alpha1 = (-a(2)+sqrt(a(2)^2-3*a(1)*gradF_pk))/(3*a(1));
            if abs(alpha0-alpha1) < eps || alpha1 < eps
                %step length (or its change) is negligible: reject the step
                g = gam;
                p0 = logistic(xSinv*g);
                gradF(:,2) = xSinv'*(p0-y)/n + lambda*sign(g);
                F1(2) = sum(-y.*log(max(p0,eps))-(1-y).*log(max(1-p0,eps)))/n + lambda*abs(g(2:p))'*e(2:p);
                disp('Improvement too small');
                break;
            end
            g = gam + alpha1*pk;
            p0 = logistic(xSinv*g);
            F1(1) = F1(2);
            gradF(:,2) = xSinv'*(p0-y)/n + lambda*sign(g);
            F1(2) = sum(-y.*log(max(p0,eps))-(1-y).*log(max(1-p0,eps)))/n + lambda*abs(g(2:p))'*e(2:p);
        end
    end
    Fmin = F1(2);
    %/linesearch

    %Gradient at the accepted point (the intercept is not penalized)
    gradF(:,2) = xSinv'*(p0-y)/n + lambda*[0;sign(g(2:p))];

    if rem(count,100) == 0
        alphastart = min(1,alphastart*10);
    end

    %Test for orthogonality:
    %restart (betaFR = 0) if two consecutive gradients are far from orthogonal
    test = abs(gradF(:,2)'*gradF(:,1))/(gradF(:,2)'*gradF(:,2));
    if test > 1
        betaFR = 0;
    else
        %Fletcher-Reeves coefficient
        betaFR = gradF(:,2)'*gradF(:,2)/(gradF(:,1)'*gradF(:,1));
    end

    if count > 0
        err = norm(gam-g)/norm(gam);
    else
        err = norm(gam-g);
    end

    gam = g;
    pk = -gradF(:,2) + betaFR*pk;   %conjugate gradient direction update
    gradF(:,1) = gradF(:,2);
    count = count+1;
    if mod(count,50) == 0
        disp([num2str(count) ' iterations'])
    end
    F = Fmin;
end

disp(['Total: ' num2str(count) ' iterations'])
%Transform the estimates back to the original parametrization
beta = Sinv*gam;
prob = logistic(x*beta);
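
%Example usage (illustrative sketch only; the data and the lambda and maxit
%values below are synthetic assumptions, not part of the original code).
%Run from a separate script or the command window:
%
%   n = 200; p = 50;                                     %sample size and dimension
%   x = [ones(n,1) randn(n,p-1)];                        %explanatory variables, intercept included
%   btrue = [0; zeros(19,1); ones(15,1); zeros(15,1)];   %piecewise-constant true coefficients (p x 1)
%   y = double(rand(n,1) < 1./(1+exp(-x*btrue)));        %simulated 0-1 group labels
%   beta0 = zeros(p,1);                                  %start from zero
%   lambda = 0.05;                                       %assumed penalization parameter
%   maxit = 5000;                                        %assumed iteration cap
%   [beta,F,prob] = TVLogDis(x,y,beta0,lambda,maxit);
%   yhat = double(prob > 0.5);                           %classify by thresholding the probabilities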