import numpy as np

import torch
import pdb
import random
from random import sample
import os
from utils import Initialize_Seed,load_dataset,predict_regression,write_csv,predict_cls




from model_syn.ori_cca import Ori_CCA_fit_transform
#from model.ori_cca import Ori_CCA_fit_transform



from model_syn.MyDCCA_full_mnist_sample import MyDCCA_full_fit_transform




from model_syn.MyMVTCAE import MVTCAE_fit_transform

from numpy import random



def Concat_Method(multi_view):
    """Join per-view feature arrays side by side.

    The views in ``multi_view`` are concatenated along axis 1 (columns), so
    they must agree in every other dimension. A new array is returned.
    """
    stacked = np.concatenate(multi_view, axis=1)
    return stacked





# Extra keyword arguments each DCCA-family method forwards to
# MyDCCA_full_fit_transform on top of the shared train/test/epochs/num_views/lr.
# The boolean flags the variants that additionally forward ``--a`` as ``a=``.
_DCCA_VARIANTS = {
    'dcca_with_noise': ({'noise': 'normal', 'loss_name': 'cca'}, True),
    'dcca': ({'noise': 'none', 'loss_name': 'cca'}, False),
    'linear_cca': ({'noise': 'none', 'loss_name': 'cca', 'linear': True}, False),
    'linear_gcca': ({'noise': 'none', 'loss_name': 'gcca', 'linear': True}, False),
    'dccae': ({'noise': 'none', 'loss_name': 'cca', 'recon': True}, False),
    'dcca_private': ({'noise': 'none', 'loss_name': 'cca', 'recon': True, 'private': True}, False),
    'dgcca': ({'noise': 'none', 'loss_name': 'gcca'}, False),
    'dgcca_with_noise': ({'noise': 'normal', 'loss_name': 'gcca'}, True),
    'dgccae': ({'noise': 'none', 'loss_name': 'gcca', 'recon': True}, False),
    'dgcca_private': ({'noise': 'none', 'loss_name': 'gcca', 'recon': True, 'private': True}, False),
}

# Classical variants handled by Ori_CCA_fit_transform (passed as its ``method``).
_ORI_CCA_METHODS = ('cca', 'kcca', 'prcca')

# Every value accepted by --method.
_METHODS = tuple(_DCCA_VARIANTS) + _ORI_CCA_METHODS + ('mvtcae', 'concat')

# Rows of the results table, in the order they are written to the CSV.
_REPORT_KEYS = (
    'God_method',
    'Concat_method',
    'average view score',
    'average nesum in feature',
    'average cor in feature',
    'average reconstruction loss to input',
    'average nesum in DNNs',
    'average cor in DNNs',
    'average denoising loss',
    'average noise score',
)

# Rows filled from a DCCA-family result ``(features, f_score, loss_1, ..., loss_7)``,
# listed in the same order as the items after ``features``.
_DCCA_METRIC_KEYS = (
    'average view score',                    # f_score
    'average reconstruction loss to input',  # loss_1
    'average cor in feature',                # loss_2
    'average nesum in feature',              # loss_3
    'average cor in DNNs',                   # loss_4
    'average nesum in DNNs',                 # loss_5
    'average denoising loss',                # loss_6
    'average noise score',                   # loss_7
)


def _fit_transform(args, train, test):
    """Fit ``args.method`` on the training views and return its test-set output.

    DCCA-family methods go through MyDCCA_full_fit_transform, the classical CCA
    variants through Ori_CCA_fit_transform, ``mvtcae`` through
    MVTCAE_fit_transform, and ``concat`` returns the raw test views unchanged.

    Raises:
        ValueError: if ``args.method`` is not a known method.
    """
    method = args.method
    if method in _DCCA_VARIANTS:
        extra, forwards_a = _DCCA_VARIANTS[method]
        kwargs = dict(extra)
        if forwards_a:
            kwargs['a'] = args.a
        return MyDCCA_full_fit_transform(train=train, test=test, epochs=args.epoch,
                                         num_views=args.num_views, lr=args.lr, **kwargs)
    if method in _ORI_CCA_METHODS:
        return Ori_CCA_fit_transform(multi_view_train=train, multi_view_test=test,
                                     dim=test[0].shape[1], method=method)
    if method == 'mvtcae':
        return MVTCAE_fit_transform(train=train, test=test, epochs=args.epoch)
    if method == 'concat':
        return test
    raise ValueError("unknown method: {!r}".format(method))


def _split_dcca_result(case):
    """Split a DCCA-family result into ``(features, metrics)``.

    A result with exactly nine items is read as
    ``(features, f_score, loss_1, ..., loss_7)``. Anything else is returned
    unchanged as plain features with ``metrics`` set to None.
    """
    try:
        items = tuple(case)
    except TypeError:
        return case, None
    if len(items) != 9:
        return case, None
    return items[0], items[1:]


def main():
    """Evaluate one multi-view method on every synthetic case and write a CSV summary.

    For each case in the ``./Syn_data`` file selected by ``--num_views``:
    fit the method chosen by ``--method`` on the training views, then score the
    resulting test representation with ``predict_regression`` against each
    task's test labels.

    Scores are kept in (tasks, cases) arrays. Their mean and std over the task
    axis are written to ``./Syn_output/<method>_True_<num_views>_<epoch>.csv``.

    Raises:
        RuntimeError: if fitting the method fails on a case (the original
            exception is chained as the cause).
    """
    Initialize_Seed(2)

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--a", type=float, default=200, help="noise loss")
    parser.add_argument("--method", type=str, default="dcca", choices=_METHODS,
                        help="multi-view method to evaluate")
    parser.add_argument("--epoch", type=int, default=200, help="epochs")
    parser.add_argument("--num_views", type=int, default=2, help="num_of_views")
    parser.add_argument("--lr", type=float, default=50 * 1e-4 * 3, help="learning rate")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    args = parser.parse_args()

    # Restrict the visible GPUs before anything touches CUDA.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    args.dataset = "./Syn_data/{}_2_4000_100_50_False_True_True_split_common_rate_experiment_data.npy".format(args.num_views)
    experiment_data = np.load(args.dataset, allow_pickle=True).item()
    train_cases = experiment_data['train']
    test_cases = experiment_data['test']
    test_label = experiment_data['test_label']
    tasks_num = len(test_label)
    cases_num = len(train_cases)

    # 'God_method' is a placeholder row: nothing writes to it, so it stays 0.
    performance = {key: np.zeros((tasks_num, cases_num)) for key in _REPORT_KEYS}

    # MVTCAE returns one fused representation; every other method yields
    # per-view features that are concatenated before scoring.
    fuse_views = args.method != 'mvtcae'

    for j in range(cases_num):
        train = train_cases[j]
        test = test_cases[j]

        try:
            case = _fit_transform(args, train, test)
        except Exception as err:
            # Fail loudly (traceback, non-zero exit) instead of exiting with 0.
            raise RuntimeError(
                "method {!r} failed on case {}".format(args.method, j)) from err

        metrics = None
        if args.method in _DCCA_VARIANTS:
            case, metrics = _split_dcca_result(case)

        for i in range(tasks_num):
            # Rebuilt per task, as before, in case predict_regression modifies its input.
            features = Concat_Method(case) if fuse_views else case
            f1 = predict_regression(features, test_label[i])
            performance['Concat_method'][i][j] = f1
            if metrics is not None:
                for key, value in zip(_DCCA_METRIC_KEYS, metrics):
                    performance[key][i][j] = value
            print(j, i, f1)

    summary = {
        key: [scores.mean(0).tolist(), scores.std(0).tolist()]
        for key, scores in performance.items()
    }
    os.makedirs("./Syn_output", exist_ok=True)
    write_csv(summary, "./Syn_output/{}_True_{}_{}.csv".format(args.method, args.num_views, args.epoch), normal=True, half=True)
 


if __name__ == "__main__":
    # Switch torch's default floating dtype to double precision
    # (torch.double is the same dtype as torch.float64) before main() runs.
    torch.set_default_dtype(torch.double)
    main()
