function W=PCF_ica(x,varargin)

% PCF_ICA - estimate an ICA demixing matrix with the prewhitened CHFICA
%           contrast (Fast-CHFICA, using incomplete Cholesky decomposition).
%
% ICA model : x = A*s, where s has independent components and A is called
%             the mixing matrix.
% Returns a demixing matrix W such that the components of s = W*x are as
% close as possible to mutually independent. The data are first centered
% and whitened, then the contrast function is minimized over orthogonal
% matrices, and finally the whitening transform is folded back into W.
%
% Usage      : W = PCF_ica(x)
%              W = PCF_ica(x,'W0',W0,'Nrestart',Nrestart)
% x          : observed mixtures (one row per mixture, one column per sample);
% Nrestart   : number of initializations (default 1), forwarded to globalopt;
% W0         : demixing matrix initialization (default: random orthogonal);
% Option names that are not listed above are ignored.

% Version 1.0, CopyRight: Aiyou Chen;
% Jan. 2004 in UC Village at Albany, CA

% Acknowledgement:
% Some subroutines like Gold_Search(), Bracket_IN() are motivated
% by Francis Bach's KERNEL_ICA algorithm;

% default number of restarts
Nrestart=1;

[m,N]=size(x);

% center the data, then whiten it with the inverse square root of the
% sample covariance
centered=x-repmat(mean(x,2),1,N);
covmat=(centered*centered')/N;
sqcovmat=sqrtm(covmat);
invsqcovmat=inv(sqcovmat);
white=invsqcovmat*centered;

% default starting point: a random orthogonal matrix (drawn even when the
% caller supplies W0 below)
W0=rand_orth(m);

% optional name/value pairs
nopt=length(varargin);
if rem(nopt,2)==1
   error('Optional parameters should always go by pairs');
end
for k=1:2:(nopt-1)
   optval=varargin{k+1};
   switch varargin{k}
   case 'W0'
      W0=optval;
   case 'Nrestart'
      Nrestart=optval;
   end
end

% express the starting point in whitened coordinates and replace it by the
% orthogonal factor U*V' of its singular value decomposition
[U,S,V]=svd(W0*sqcovmat);
Wstart=U*V';

[J,Wwhite]=globalopt(white,Wstart,Nrestart);

% undo the whitening so that W applies to the centered data directly
W=Wwhite*invsqcovmat;
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [Jopt,Wopt,OptDetails] = globalopt(x,W,Nrestart);
% GLOBALOPT  -  global minimization of the empirical characteristic contrast
%               function with random restarts. The data are assumed whitened
%               (i.e. with identity covariance matrix). The output is such
%               that Wopt*x are the independent sources.
%
% x          - data (mixtures)
% W          - orthogonal matrix, starting point of the first local search
% Nrestart   - maximum number of local searches (default 1)
%
% Jopt       - smallest contrast value found
% Wopt       - matrix achieving Jopt
% OptDetails - struct with fields Wloc (local minimizers), Wstart (starting
%              points), Jloc (local contrast values) and totalneval
%              (accumulated evaluation count)
%
% The search stops early as soon as the best contrast value is no larger
% than Jaccept.

% initialization
verbose=0;
Jaccept=1e-3;
% The function has three inputs, so the default must be applied when fewer
% than three are given (testing nargin<4 would always discard the caller's
% value).
if nargin<3, Nrestart=1; end
if (verbose), fprintf('Starting optimization, with %d restarts\n',Nrestart); end

iter=1;
m=size(W,1);
Wmin=W;
Jmin=spectchf(W*x);
totalneval=1;
% defined up front so the third output exists even if no local search runs
OptDetails=struct('Wloc',{{}},'Wstart',{{}},'Jloc',[]);

% starting restarts loop
while (iter<=Nrestart) && (Jmin>Jaccept)
   if (verbose) fprintf('\nStarting a new local search, #%d\n',iter); end

   if (iter>1)
      % selects a new random restart as far as possible (in Amari distance)
      % from all previous starting points and local minima
      NWs=m*m*4;
      distances=zeros(1,2*iter-2);
      maxmindist=0;
      W=[];
      for i=1:NWs
         Wrandrest=rand_orth(m);
         for j=1:iter-1
            distances(2*j-1)=amari(Wrandrest,OptDetails.Wloc{j});
            distances(2*j  )=amari(Wrandrest,OptDetails.Wstart{j});
         end
         mindist=min(distances);
         if (mindist>maxmindist) maxmindist=mindist; W=Wrandrest; end
      end
      % no candidate was strictly better than distance 0: fall back to the
      % last candidate instead of handing an empty matrix to localopt
      if isempty(W), W=Wrandrest; end
   end

   % performs local search to local minimum (requires a non transposed matrix)
   [Jloc,Wloc,detailsloc] = localopt(x,W);

   % the first local search always replaces the initial point
   if (iter==1) || (Jloc<Jmin)
      Wmin=Wloc;
      Jmin=Jloc;
   end

   totalneval=totalneval+detailsloc.totalneval;
   OptDetails.Wloc{iter}=Wloc;
   OptDetails.Wstart{iter}=W;
   OptDetails.Jloc(iter)=Jloc;
   iter=iter+1;
end

OptDetails.totalneval=totalneval;
Jopt=Jmin;
Wopt=Wmin;
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [Jopt,Wopt,OptDetails] = localopt(x,W)
% LOCAL-OPT  -  Gradient-based descent with geodesic line search for finding
%               a minimum over orthogonal matrices Wopt such that Wopt*x are
%               the independent sources.
% x          - data (mixtures);
% W          - orthogonal matrix, starting point of the search;
%
% Jopt       - contrast value at the returned matrix;
% Wopt       - last accepted iterate;
% OptDetails - per-iteration diagnostics, filled only when a third output
%              is requested.
%
% Fixed settings used below:
% tolW    :    precision in amari distance in est. demixing matrix;
% tolJ    :    precision in (relative decrease of the) objective function;
% maxit   :    maximum number of iterations;
% verbose :    1 if verbose required.

% initializations
wantDetails=(nargout>2);
m=size(W,1);
tolW=1e-2;
tolJ=0.01;        % original note: linear form in 1/(N*m)
maxit=10*m;
verbose=0;

tmin=1;           % initial step guess, reused from one iteration to the next
iter=0;
errW=2*tolW;      % start above both tolerances so the loop is entered
errJ=2*tolJ;
fret=spectchf(W*x);
totalneval=1;

% starting minimization
while (iter<maxit) && ((errW>tolW) || (errJ>tolJ))
   Jold=fret;
   iter=iter+1;
   if (verbose), fprintf('iter %d, J=%.5e',iter,fret); end

   % Euclidean gradient and the search direction derived from it
   gradJ=spectgradchf(x,W);
   iterneval=m*(m-1)/2+1;
   dirSearch=gradJ-W*gradJ'*W;
   normdirSearch=sqrt(.5*trace(dirSearch'*dirSearch));

   % bracket the minimum along the geodesic, then refine by golden search
   [ax,bx,cx,fax,fbx,fcx,neval]=bracket_min(W,dirSearch,x,0,tmin,Jold);
   iterneval=iterneval+neval;
   goldTol=max(abs([tolW/normdirSearch, mean([ax,bx,cx])/10]));
   [tmin,Jmin,neval]=golden_search(W,dirSearch,x,ax,bx,cx,goldTol,20);
   iterneval=iterneval+neval;
   totalneval=totalneval+iterneval;

   % candidate iterate at the selected step
   Wnew=stiefel_geod(W',dirSearch',tmin);
   gradNorm=sqrt(.5*trace(gradJ'*gradJ));

   errW=amari(W,Wnew);
   errJ=Jold/Jmin-1;     % relative decrease of the contrast
   if (verbose)
      fprintf(', dJ= %.1e',errJ);
      fprintf(',errW=%.1e,dW= %.3f, neval=%d\n',errW,tmin*normdirSearch,iterneval);
   end

   if wantDetails
      % debugging details
      OptDetails.Ws{iter}=W;
      OptDetails.numeval(iter)=totalneval;
      OptDetails.numgoldit(iter)=neval;
      OptDetails.ts(iter)=tmin;
      OptDetails.normdirsearch(iter)=normdirSearch;
      OptDetails.normgrad(iter)=gradNorm;
      OptDetails.amaridist(iter)=errW;
      OptDetails.geoddist(iter)=tmin*normdirSearch;
   end

   % the step is accepted only when the contrast strictly decreased
   if (errJ>0)
      W=Wnew;
      fret=Jmin;
   end
end

Jopt=fret;
Wopt=W;

if wantDetails
   OptDetails.totalneval=totalneval;
end
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [Wt,Ht]=stiefel_geod(W,H,t)

% STIEFEL_GEOD - parameterizes a geodesic along a Stiefel manifold
%
% W  - origin of the geodesic
% H  - tangent vector
% t  - step along the geodesic (default 1)
% Wt - point reached at step t, returned TRANSPOSED (i.e. (W*expm(t*A))')
% Ht - tangent vector at "arrival" (computed only when requested)
% Alan Edelman, Tomas Arias, Steven Smith (1999)

if nargin<3
   t=1;
end

% skew-symmetric part of W'*H generates the rotation
A=W'*H;
skewPart=(A-A')/2;
rotation=expm(t*skewPart);

if nargout>1
   Ht=H*rotation;
end
Wt=(W*rotation)';
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [xmin,fmin,neval] = golden_search(W,dirT,x,ax,bx,cx,tol,maxiter)

% GOLDEN_SEARCH - Minimize contrast function along a geodesic of the Stiefel
%                 manifold using golden section search.
% W              - initial value
% dirT           - direction of the geodesic
% x              - mixed components
% ax,bx,cx       - three abscissas such that the minimum is bracketed between
%                  ax and cx, as given by bracket_min
% tol            - the search stops once the bracket width |x3-x0| is <= tol
% maxiter        - bound on the iteration counter (which starts at 1)
% xmin,fmin      - best abscissa found and the contrast value there
% neval          - outputs the number of evaluations of the contrast function

% contrast value at step t along the geodesic
contrastAt=@(t) spectchf(stiefel_geod(W',dirT',t)*x);

% golden ratios
C=(3-sqrt(5))/2;
R=1-C;

x0=ax;
x3=cx;

% put the new interior point inside the larger of the two segments
if abs(cx-bx)>abs(bx-ax)
   x1=bx;
   x2=bx+C*(cx-bx);
else
   x2=bx;
   x1=bx-C*(bx-ax);
end
f1=contrastAt(x1);
f2=contrastAt(x2);
neval=2;
k=1;

% starts iterations: one new evaluation per pass
while (abs(x3-x0)>tol) && (k<maxiter)
   if f2<f1
      % minimum lies in [x1,x3]: drop x0
      x0=x1;
      x1=x2;
      x2=R*x1+C*x3;
      f1=f2;
      f2=contrastAt(x2);
   else
      % minimum lies in [x0,x2]: drop x3
      x3=x2;
      x2=x1;
      x1=R*x2+C*x0;
      f2=f1;
      f1=contrastAt(x1);
   end
   neval=neval+1;
   k=k+1;
end

% best of the two possible
if f1<f2
   xmin=x1;
   fmin=f1;
else
   xmin=x2;
   fmin=f2;
end

return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [ ax, bx, cx,fax,fbx,fcx,neval] = bracket_min(W,dirT,x,ax, bx,fax)

% BRACKET_MIN - Brackets a minimum by searching in both directions along a
%               geodesic in the Stiefel manifold
% W              - initial value
% dirT           - direction of the geodesic
% x              - mixed components
% ax,bx          - Initial guesses (abscissas along the geodesic)
% fax            - contrast value at ax (supplied by the caller)
% ax,bx,cx       - (outputs) bracketing triple
% fax,fbx,fcx    - (outputs) contrast values at the triple
% neval          - outputs the number of evaluations of the contrast function

% contrast value at step t along the geodesic
contrastAt=@(t) spectchf(stiefel_geod(W',dirT',t)*x);

neval=0;
GOLD=1.618034;   % default magnification ratio for successive intervals
TINY=1e-10;      % guards the parabolic fit against a vanishing denominator
GLIMIT=100;      % maximum magnification allowed for a parabolic step

fbx=contrastAt(bx);
neval=neval+1;

% make sure we go downhill from a to b
if (fbx>fax)
   temp=ax;  ax=bx;   bx=temp;
   temp=fax; fax=fbx; fbx=temp;
end

% first guess for c
cx=bx+GOLD*(bx-ax);
fcx=contrastAt(cx);
neval=neval+1;

while (fbx>fcx)

   % parabolic extrapolation from a, b, c
   r=(bx-ax)*(fbx-fcx);
   q=(bx-cx)*(fbx-fax);
   d=q-r;
   % sign(0) is 0 and would zero the denominator; treat d==0 as positive
   if d>=0, sgn=1; else sgn=-1; end
   u=bx-((bx-cx)*q-(bx-ax)*r)/(2.0*max([abs(d),TINY])*sgn);
   ulim=bx+GLIMIT*(cx-bx);

   if ((bx-u)*(u-cx)>0.0)
      % parabolic u lies between b and c
      fux=contrastAt(u);
      neval=neval+1;

      if (fux<fcx)
         % minimum between b and c
         ax=bx;
         bx=u;
         fax=fbx;
         fbx=fux;
         return;
      elseif (fux>fbx)
         % minimum between a and u
         cx=u;
         fcx=fux;
         return;
      end

      % parabolic fit was of no use: default magnification
      u=cx+GOLD*(cx-bx);
      fux=contrastAt(u);
      neval=neval+1;

   elseif ((cx-u)*(u-ulim)>0.0)
      % parabolic u lies between c and its allowed limit
      fux=contrastAt(u);
      neval=neval+1;

      if (fux<fcx)
         bx=cx;
         cx=u;
         u=cx+GOLD*(cx-bx);

         fbx=fcx;
         fcx=fux;
         fux=contrastAt(u);
         neval=neval+1;   % this evaluation was previously not counted
      end

   elseif ((u-ulim)*(ulim-cx)>=0.0)
      % limit parabolic u to its maximum allowed value
      u=ulim;
      fux=contrastAt(u);
      neval=neval+1;

   else
      % reject parabolic u: default magnification
      u=cx+GOLD*(cx-bx);
      fux=contrastAt(u);
      neval=neval+1;
   end

   % eliminate the oldest point and continue
   ax=bx;
   bx=cx;
   cx=u;

   fax=fbx;
   fbx=fcx;
   fcx=fux;

end
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function grad=spectgradchf(X,W);

% SPECTGRADCHF - Euclidean gradient of the CHFICA contrast function.
% Makes use of the incomplete Cholesky decomposition (chol_gauss) to
% approximate the Gram matrix generated by the Gaussian kernel.
% We don't need the first term.
%
% X    - data (m rows, n columns)
% W    - current demixing matrix (row II gives the II-th recovered source)
% grad - m-by-m matrix, grad(II,JJ) is the derivative with respect to W(II,JJ)

tol=0.01; sig=1;            % settings passed to chol_gauss
[m,n]=size(X);
H=zeros(n,m);               % to store elements for the third term;
Z=zeros(1,m);               % to store elements for the second term;
HH=zeros(n,m*m);            % column m*(II-1)+JJ holds the derivative terms for W(II,JJ);
grad=zeros(m);

for II=1:m,
    s=W(II,:)*X;            % the II-th recovered source;
    [G,Pvec]=chol_gauss(s,sig,tol);
    % reorder the rows of G by the sort index of Pvec
    % (presumably undoing the pivoting of chol_gauss - TODO confirm)
    [a,Pvec]=sort(Pvec); R=G(Pvec,:);
    % sum of all rows;
    h=sum(R,1);
    Z(II)=(h*h')/n^2;

    % products that do not depend on JJ are computed once per source
    Rh=R*h';
    RRs=R*(s*R)';
    H(:,II)=Rh;

    for JJ=1:m,
        xj=X(JJ,:);
        L=s.*xj;

        % same four contributions, accumulated in the same order as before
        col=Rh.*L';
        col=col+R*(L*R)';
        col=col-(R*(xj*R)').*s';
        col=col-RRs.*xj';
        HH(:,m*(II-1)+JJ)=col;
    end
end

for II=1:m,
    % products over all sources except the II-th
    others=[1:(II-1),(II+1):m];
    z2=prod(Z(others));
    z3=prod(H(:,others)/n,2);
    for JJ=1:m,
        grad(II,JJ)=-z2*sum(HH(:,m*(II-1)+JJ)/n)/n+2*sum(HH(:,m*(II-1)+JJ)/n.*z3)/n;
    end
end
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function z=spectchf(s);

% SPECTCHF approximates the CHFICA contrast function by using incomplete
% Cholesky decomposition (chol_gauss).
% We don't need the first term.
%
% s - recovered sources (m rows, n columns)
% z - contrast value: 1 + (product of the per-source second terms)
%     - 2/n * sum over samples of the product of the per-source partial sums

tol=0.01; sig=1;            % settings passed to chol_gauss
[m,n]=size(s);
partial=zeros(n,m);         % per-source partial sums for the third term
secondTerm=1;

for k=1:m,
    [G,Pvec]=chol_gauss(s(k,:),sig,tol);
    % reorder the rows of G by the sort index of Pvec
    [a,order]=sort(Pvec);
    R=G(order,:);

    % sum of all rows;
    rowsum=sum(R,1);

    % second term: running product over the sources;
    secondTerm=secondTerm*(rowsum*rowsum')/n^2;

    % third term: store partial sums;
    partial(:,k)=R*rowsum';
end

thirdTerm=prod(partial/n,2);
z=1+secondTerm-2*sum(thirdTerm)/n;
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function W=rand_orth(n,m);
% RAND_ORTH - random matrix with orthogonal columns
%
% n - number of columns kept
% m - number of rows (default n)
% W - first n columns of the Q factor of an m-by-m matrix with entries
%     drawn uniformly from (-0.5,0.5)

if nargin<2
   m=n;
end

[Q,upper]=qr(rand(m)-.5);   % only the orthogonal factor is used
W=Q(1:m,1:n);

return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function w=rndmat(m,n)

%RNDMAT - generate an m-by-n matrix (n defaults to m) with entries drawn
%         uniformly from (0,1) and each row scaled to norm 1;

if nargin<2, n=m; end

% rand(m,n) draws from the same uniform (0,1) distribution as
% unifrnd(0,1,m,n) without requiring the Statistics Toolbox
w=rand(m,n);
for i=1:m
   w(i,:)=w(i,:)/norm(w(i,:));
end

return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function d=amari(V,W)

% AMARI - amari distance between two square matrices; the rows of V and W
%         are scaled to unit norm first, and the result is zero when
%         inv(V')*W' has a single non-zero entry in every row and column.

m=size(V,1);

% normalize each row of V and W before calculating the amari distance;
for i=1:m
   V(i,:)=V(i,:)/norm(V(i,:));
end
for i=1:m
   W(i,:)=W(i,:)/norm(W(i,:));
end

P=abs(inv(V')*W');
Pt=P';

% for every column (then every row): sum of entries over the largest, minus 1
colScore=sum(P)./max(P)-1;
rowScore=sum(Pt)./max(Pt)-1;

d=mean([colScore,rowScore]);

return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function nw=rerow(w)

% REROW - scale every row of w to unit norm, then reorder and sign-flip the
%         rows so that, column by column, the entry of largest magnitude
%         (among the rows not yet placed) sits on the diagonal and is
%         non-negative.

     nrows=size(w,1);

% normalize each row;
     for r=1:nrows
        w(r,:)=w(r,:)/norm(w(r,:));
     end

% reorder rows;
     for c=1:nrows-1
        % first row at or below c with the largest magnitude in column c
        best=c;
        bestval=abs(w(c,c));
        for k=c+1:nrows
           if abs(w(k,c))>bestval
              best=k;
              bestval=abs(w(k,c));
           end
        end
        if best~=c
           w([c best],:)=w([best c],:);
        end
        % make the diagonal entry non-negative
        if w(c,c)<0
           w(c,:)=-w(c,:);
        end
     end

% the last row only needs its sign fixed;
     if w(nrows,nrows)<0
        w(nrows,:)=-w(nrows,:);
     end
     nw=w;

     return;