using LinearAlgebra, Distributions, StatsBase, Plots,MLDatasets
using Statistics
using Nemo, Images
using ToeplitzMatrices
using Random
using Kronecker
using StatPlots
using SpecialFunctions: erfc
using SpecialFunctions: erfcinv
using NLsolve
using MAT
using NLsolve
using Optim
using MultivariateStats
using FFTW
include("utility.jl")
# Choose the experiment: "synthetic" draws data through `generate_data` (defined in
# utility.jl), "real" loads pre-split MAT files from `data_dir` below.
data_type = "synthetic";
# NOTE(review): not read by any executed code in this file; only the commented-out
# MATLAB port below refers to it.
n_training_sample = 500;
if data_type == "synthetic"
    p = 256;                          # feature dimension
    # Per-class sample counts; presumably ns[1:2] belong to the source task and ns[3:4]
    # to the target task (as in the "real" branch) — confirm against generate_data.
    ns = [500, 300, 100, 150];
    n_test = convert.(Int, 10000 * ones(4, 1));
    # Class means: only the first four coordinates are non-zero.
    M = zeros(p, 4);
    M[:, 1] = [1; 1; 0; 0; zeros(p - 4, 1)];
    M[:, 2] = [-1; -1; 0; 0; zeros(p - 4, 1)];
    # NOTE(review): `orth` is not used in this file; kept in case other code reads it.
    orth = [0; 0; 1; 1; zeros(p - 4, 1)];
    # Means 3 and 4 are a β-weighted mix of means 1 and 2 (and its negation).
    β = 0.5;
    M[:, 3] = β * M[:, 1] + sqrt(1 - β^2) * M[:, 2];
    M[:, 4] = -M[:, 3];
    # Identity covariance for all four classes.
    Σ = zeros(p, p, 4);
    for c in 1:4
        Σ[:, :, c] = Matrix{Float64}(I, p, p);
    end
    n = sum(ns);
    k = 2; m = 2;
    # J=zeros(n,m*k);
    # for i=1:m*k
    #     J(sum(ns(1:i-1))+1:sum(ns(1:i)),i)=ones(ns(i),1);
    # end
    Sfts,Slabels,Tfts,Tlabels,X_test,y_test,X_test1,y_test1=generate_data(p,ns,k,m,β,M,Σ,data_type,n_test)
    λ = 1; γ = [0.1; 1];
    wanted = 10 .^ (range(-4, stop=-1, length=10));
    #wanted=logspace(-6,-2,10);
elseif data_type == "real"
    ns = [500; 500; 50; 50];
    λ = 10; γ = [0.1; 1];
    #p=550;
    p = 100;                          # number of PCA components requested
    # Directory holding x_train/y_train/x_test/y_test .mat files; override it with the
    # JMLR_DATA_DIR environment variable instead of editing the script.
    data_dir = get(ENV, "JMLR_DATA_DIR", "/Users/tiomokomalik/Downloads/JMLR_code/data");
    # The do-block form closes each MAT file once the variable is read (the previous
    # version left all four handles open).
    Sfts_init = matopen(joinpath(data_dir, "x_train.mat")) do f
        read(f, "x_train")
    end
    Slabels_init = matopen(joinpath(data_dir, "y_train.mat")) do f
        read(f, "y_train")
    end
    Test_init = matopen(joinpath(data_dir, "x_test.mat")) do f
        read(f, "x_test")
    end
    TestLabels_init = matopen(joinpath(data_dir, "y_test.mat")) do f
        read(f, "y_test")
    end
    # Fit PCA on the training set and project both splits.
    # Assumes the MAT matrices store one sample per row (they are transposed before
    # fitting) — TODO confirm.
    # NOTE(review): MultivariateStats may keep fewer than `p` components (see its
    # `pratio` keyword); `p` is not updated to match.
    # pca_model = fit(Whitening, Sfts_init');
    pca_model = fit(PCA, Sfts_init'; maxoutdim=p);
    Sfts_init = transform(pca_model, Sfts_init');        # one sample per column
    Test_init = transform(pca_model, Test_init')';       # one sample per row
    Sfts_class0 = Sfts_init[:, vec(Slabels_init .== 1)]';
    Sfts_class2 = Sfts_init[:, vec(Slabels_init .== 2)]';
    Sfts_class3 = Sfts_init[:, vec(Slabels_init .== 3)]';
    # Source task: label-1 vs label-2 samples. Target task: the next ns[3] label-1
    # samples vs label-3 samples.
    Sfts11 = Sfts_class0[1:ns[1], :]';
    Sfts12 = Sfts_class2[1:ns[2], :]';
    Tfts21 = Sfts_class0[ns[1]+1:ns[1]+ns[3], :]';
    Tfts22 = Sfts_class3[1:ns[4], :]';
    Sfts = zscore([Sfts11 Sfts12]); Tfts = zscore([Tfts21 Tfts22]);
    Slabels = [ones(ns[1], 1); -ones(ns[2], 1)];
    Tlabels = [ones(ns[3], 1); -ones(ns[4], 1)];
    Test_class0 = Test_init[vec(TestLabels_init .== 1), :]';
    # Test_class0=Test_class0[:,1:100];
    Test_class3 = Test_init[vec(TestLabels_init .== 3), :]';
    # Test_class3=Test_class3[:,1:100];
    X_test1 = zscore([Test_class0 Test_class3]);
    n_test = [ns[1:2]; size(Test_class0, 2); size(Test_class3, 2)];
    # The source-task "test" set is the training data itself.
    X_test = Sfts; y_test = Slabels;
    wanted = 10 .^ (range(-5, stop=-1, length=10));
    # Slabels=Sla[Sla.!=0];
    # Sfts=Sfts[:,1:n_training_sample];
# Legacy MATLAB preprocessing, kept for reference only (not executed):
#         x_train(y_train==0,:)=[];y_train(y_train==0)=[];
#         x_train = x_train';
#         x_test = x_test';
#         x_train=[x_train(:,1:n_training_sample) x_train(:,2223+1:2223+n_training_sample) x_train(:,2223+5788+1:2223+5788+n_training_sample) x_train(:,2223+5788+641+1:2223+5788+641+n_training_sample)];
#         y_train=[y_train(1:n_training_sample) y_train(2223+1:2223+n_training_sample) y_train(2223+5788+1:2223+5788+n_training_sample) y_train(2223+5788+641+1:2223+5788+641+n_training_sample)];
#         %y_test=y_test(1:100);
#             %p=p_vec(h);
#         %jr=3000;
#         k=2;m=2;
#         %ns(1:m*k)=randi(jr,m*k,1);
#         %ns=floor([2.7*p 1.3*p 2.3*p 1.7*p]);
#
#
#
#         % nst=1000*ones(4,1);
#         X1=[];X2=[];M11=[];M22=[];
#         c1=[3 4 1 2];
#         for task=1:k-1
#             for i=1:m
#         %        X11{i,task}=Xsr{task}(:,ysr{task}==c1(i))';
#                 X11{i,task}=x_train(:,y_train==c1(i))';
#                 X1=[X1 X11{i,task}'];
#                 ns(m*(task-1)+i)=size(X11{i,task},1);
#                 M11=[M11 mean(X11{i,task})'];
#                 C(:,:,m*(task-1)+i)=(X11{i,task}-mean(X11{i,task},1))'*(X11{i,task}-mean(X11{i,task},1))/ns(m*(task-1)+i);
#             end
#         end
#         for i=1:m
#             X22{i}=x_train(:,y_train==c1(i+2))';
#
#             ns(i+m*(k-1))=size(X22{i},1);
#             X2=[X2 X22{i}(1:ns(i+m*(k-1)),:)'];
#             M22=[M22 mean(X22{i}(1:ns(i+m*(k-1)),:))'];
#             C(:,:,i+m*(k-1))=(X22{i}-mean(X22{i},1))'*(X22{i}-mean(X22{i},1))/ns(i+m*(k-1));
#         end
#          M=[M11 M22];
#          %ns=[500;500;500;500]';
#          nst=100*ones(4,1);
#         p=size(X11{1,1},2);
# %         for j=1:2
# %             X1 = [X1 M(:,j)+C(:,:,j)^(1/2)*randn(p,ns(j))];
# %         end
#          X_test1 =  M(:,1)+C(:,:,1)^(1/2)*randn(p,nst(1));
#          X_test2 =  M(:,2)+C(:,:,2)^(1/2)*randn(p,nst(2));
# %         X2=[];X_test=[];
# %         for j=1:2
# %             X2 = [X2 M(:,2*(k-1)+j)+C(:,:,2*(k-1)+j)^(1/2)*randn(p,ns(2*(k-1)+j))];
# %         end
# %         X_test3 = M(:,3)+C(:,:,3)^(1/2)*randn(p,nst(3));
# %         X_test4 =  M(:,4)+C(:,:,4)^(1/2)*randn(p,nst(4));
#
#         %M=[M11 M22];
#         X_test3=x_test(:,y_test==1);
#         X_test4=x_test(:,y_test==2);
#         %M_test=M;
#
#         n=sum(ns);
#         nte = size(x_test, 2);
#         nst(1)=size(X_test1,2);nst(2)=size(X_test2,2);nst(3)=size(X_test3,2);nst(4)=size(X_test4,2);
#
#         %k=length(ns);
#
#         J=zeros(n,m*k);
#         for i=1:m*k
#             J(sum(ns(1:i-1))+1:sum(ns(1:i)),i)=ones(ns(i),1);
#         end
#
#         %%%%%%%%%%%%%%%%%%%%%%%%%%% TEST
#         %n = size(x_test, 2);
#         test1 = [X_test1];
#         testm1 = [X_test2];
#         test2 = [X_test3];
#         testm2 = [X_test4];
#
#         n_test=sum(nst);
#         J_test=zeros(n_test,m*k);
#         for i=1:m*k
#             J_test(sum(nst(1:i-1))+1:sum(nst(1:i)),i)=ones(nst(i),1);
#         end
#         yc(1:ns(1))=1;yc(ns(1)+1:ns(1)+ns(2))=-1;
#         yc(ns(1)+ns(2)+1:ns(1)+ns(2)+ns(3))=1;yc(ns(1)+ns(2)+ns(3)+1:sum(ns))=-1;
#         param=zeros(m*k,1);
#         n1=ns(1)+ns(2);n2=ns(3)+ns(4);
#         lambda=1e2;gamma1=1000;gamma2=100;
#         k=2;
#         init=[100,1000,2];
#         X1=X1/sqrt(2*p);X2=X2/sqrt(2*p);
#         wanted=logspace(-4,-1,10);
else
    # Fail here with a clear message instead of an UndefVarError on `wanted` later.
    throw(ArgumentError("unknown data_type \"$data_type\"; expected \"synthetic\" or \"real\""))
end
# One (length(wanted) × 1) buffer per curve, filled by the sweep loop:
#   axi1/axi2         – components 1 and 2 of the empirical target error,
#   axi1_th/axi2_th   – components 1 and 2 of the theoretical target error,
#   *_opt / *_th_opt  – the same quantities from the `_opt` outputs of RMTMTLLSSVM_PFA.
# `ntuple` calls the closure once per slot, so every name gets its own fresh array.
axi1, axi2, axi1_opt, axi2_opt, axi1_th, axi2_th, axi1_th_opt, axi2_th_opt =
    ntuple(_ -> zeros(length(wanted), 1), 8);

# For every value in `wanted`, call RMTMTLLSSVM_PFA with it as the `eta` argument and
# store the first two components of the four target-error outputs in the axi* buffers.
# Output order of RMTMTLLSSVM_PFA (as destructured here):
#   gx_s, gx_t, μ_th, μ_emp, σ_th, σ_emp, error_source_emp, error_target_emp,
#   error_source_th, error_target_th, then the same ten names with an `_opt` suffix.
for (v, eta) in enumerate(wanted)
    (_, _, _, _, _, _, _, error_target_emp, _, error_target_th,
     _, _, _, _, _, _, _, error_target_emp_opt, _, error_target_th_opt) =
        RMTMTLLSSVM_PFA(Sfts, Slabels, Tfts, Tlabels, λ, γ, X_test, X_test1, ns, n_test, eta);
    # Legacy MATLAB driver, kept for reference only (not executed):
     # yop=[y_opt(1)*ones(1,ns(1)) y_opt(2)*ones(1,ns(2)) y_opt(3)*ones(1,ns(3)) y_opt(4)*ones(1,ns(4))];
#     [error_th,error_th_task_init,error_emp_task_init,alpha_task_init, b_task,score_mean_task_init,variance_task,score_emp_task_init,var_emp_task_init,y_opt_init] = MLSSVRTrain_th1_centered_fixed_pro(X1,X2, yop', gamma1,gamma2, lambda,M,C,J,test1,testm1,test2,testm2,ones(nst(1),1),-ones(nst(2),1),ones(nst(3),1),-ones(nst(4),1),ns,'task',wanted[v],'general');
# % yop=[y_opt(1)*ones(1,ns(1)) y_opt(2)*ones(1,ns(2)) y_opt(3)*ones(1,ns(3)) y_opt(1)*ones(1,ns(4))];
# % [error]=@(x) perf_theorique_fixed_pro_2(X1,X2, x(3), x(1),x(2),M,C,ns,wanted[v],'general');
# % options = struct('MaxFunctionEvaluations', 100);
# % [param,error_opt_th]=fmincon(error,init,[-1 0 0;0 -1 0;0 0 -1],[0;0;0],[],[],[],[],[], options);
# % lambda_opt=param(1);gamma_opt1=param(2); gamma_opt2=param(3);
# % [error_th,error_th_task,error_emp_task,alpha_task, b_task,score_mean_task,variance_task,score_emp_task,var_emp_task,y_opt] = MLSSVRTrain_th1_centered_fixed_pro(X1,X2, yop', gamma_opt1,gamma_opt2, lambda_opt,M,C,J,test1,testm1,test2,testm2,ones(nst(1),1),-ones(nst(2),1),ones(nst(3),1),-ones(nst(4),1),ns,'task',wanted[v],'general');
# % [error_th_init,error_th_task_init,error_emp_task_init,alpha_task_init, b_task,score_mean_task_init,variance_task,score_emp_task_init,var_emp_task_init,y_opt_init] = MLSSVRTrain_th1_centered_fixed_pro(X1,X2, yc', gamma1,gamma2, lambda,M,C,J,test1,testm1,test2,testm2,ones(nst(1),1),-ones(nst(2),1),ones(nst(3),1),-ones(nst(4),1),ns,'task',wanted[v],'general');

    axi1[v], axi2[v] = error_target_emp[1], error_target_emp[2];
    axi1_th[v], axi2_th[v] = error_target_th[1], error_target_th[2];
    axi1_opt[v], axi2_opt[v] = error_target_emp_opt[1], error_target_emp_opt[2];
    axi1_th_opt[v], axi2_th_opt[v] = error_target_th_opt[1], error_target_th_opt[2];
end
# %plot(axi2)
# %save('axes_update.mat', 'axi1', 'axi2', 'axi1_th', 'axi2_th', 'axi1_th_init', 'axi2_th_init', 'axi1_init', 'axi2_init')
# NOTE(review): the legacy sprintf/save lines mention axi*_init, which this script
# no longer computes.
# sprintf('(%12f, %12f)', [axi1', axi2']')
# sprintf('(%12f, %12f)', [axi1_th', axi2_th']')
# sprintf('(%12f, %12f)', [axi1_th_init', axi2_th_init']')
# sprintf('(%12f, %12f)', [axi1_init', axi2_init']')

# Plot component-2 error against component-1 error for each curve on log–log axes.
# A tiny positive floor keeps zero error rates drawable on a log scale. It is
# broadcast (rather than adding `ones(10,1)`) so the plots work for any length of `wanted`.
err_floor = 1e-6;
plot(axi1 .+ err_floor, axi2 .+ err_floor, color=:red, label="Emp", xaxis=:log, yaxis=:log)
plot!(axi1_th .+ err_floor, axi2_th .+ err_floor, shape=:circle, color=:orange, label="Th", xaxis=:log, yaxis=:log)
plot!(axi1_opt .+ err_floor, axi2_opt .+ err_floor, color=:blue, label="Emp Opt", xaxis=:log, yaxis=:log)
plot!(axi1_th_opt .+ err_floor, axi2_th_opt .+ err_floor, color=:black, shape=:circle, label="Th Opt", xaxis=:log, yaxis=:log)
# sprintf('(%12f, %12f)', [1-axi2', axi1']')
# sprintf('(%12f, %12f)', [1-axi2_th', axi1_th']')
# sprintf('(%12f, %12f)', [1-axi2_th_init', axi1_th_init']')
# sprintf('(%12f, %12f)', [1-axi2_init', axi1_init']')
