import numpy as np

import torch
import pdb
import random
from random import sample
import os
from utils import Initialize_Seed,load_dataset,predict_regression,write_csv,predict_cls


from model_cpm.ori_cca import Ori_CCA_fit_transform



from model_cpm.MyDCCA_full_mnist_sample import MyDCCA_full_fit_transform



from model_syn.MyMVTCAE import MVTCAE_fit_transform

from numpy import random



def Concat_Method(multi_view):
    """Join the per-view matrices side by side (along axis 1).

    Args:
        multi_view: a sequence of arrays that agree on every axis except
            axis 1 (callers in this script appear to pass one 2-D matrix per
            view -- confirm).

    Returns:
        numpy.ndarray: all views concatenated column-wise.
    """
    joined = np.concatenate(multi_view, axis=1)
    return joined





# Extra keyword arguments passed to MyDCCA_full_fit_transform for each deep /
# linear (G)CCA variant selectable with --method. Every variant additionally
# receives train/test/epochs/num_views/lr (see _fit_method).
_DCCA_VARIANTS = {
    'dcca_with_noise': dict(noise='normal', loss_name='cca'),
    'dcca': dict(noise='none', loss_name='cca'),
    'linear_cca': dict(noise='none', loss_name='cca', linear=True),
    'linear_gcca': dict(noise='none', loss_name='gcca', linear=True),
    'dccae': dict(noise='none', loss_name='cca', recon=True),
    'dcca_private': dict(noise='none', loss_name='cca', recon=True, private=True),
    'dgcca': dict(noise='none', loss_name='gcca'),
    'dgcca_with_noise': dict(noise='normal', loss_name='gcca'),
    'dgccae': dict(noise='none', loss_name='gcca', recon=True),
    'dgcca_private': dict(noise='none', loss_name='gcca', recon=True, private=True),
}
# Variants that also receive the --a value as the ``a`` keyword.
_NOISY_DCCA_VARIANTS = frozenset({'dcca_with_noise', 'dgcca_with_noise'})
# Methods handled by Ori_CCA_fit_transform (forwarded as its ``method`` arg).
_CLASSICAL_CCA_METHODS = frozenset({'cca', 'kcca', 'prcca'})
_KNOWN_METHODS = (
    frozenset(_DCCA_VARIANTS) | _CLASSICAL_CCA_METHODS | {'mvtcae', 'concat'}
)

# Result keys, in the insertion order of the dict handed to write_csv.
_METRIC_NAMES = (
    'God_method',
    'Concat_method',
    'average view score',
    'average nesum in feature',
    'average cor in feature',
    'average reconstruction loss to input',
    'average nesum in DNNs',
    'average cor in DNNs',
    'average denoising loss',
    'average noise score',
)


def _fit_method(args, train, test):
    """Run the method named by ``args.method``; return ``(case, no_concat)``.

    ``case`` is whatever the underlying fit function returns (for 'concat' it
    is ``test`` itself). ``no_concat`` is True only for 'mvtcae', whose output
    is scored with predict_cls directly instead of through Concat_Method.

    Raises:
        ValueError: if ``args.method`` is not in _KNOWN_METHODS.
    """
    method = args.method
    if method in _DCCA_VARIANTS:
        kwargs = dict(_DCCA_VARIANTS[method], lr=args.lr)
        if method in _NOISY_DCCA_VARIANTS:
            kwargs['a'] = args.a
        case = MyDCCA_full_fit_transform(
            train=train, test=test, epochs=args.epoch,
            num_views=args.num_views, **kwargs)
        return case, False
    if method in _CLASSICAL_CCA_METHODS:
        case = Ori_CCA_fit_transform(
            multi_view_train=train, multi_view_test=test,
            dim=test[0].shape[1], method=method)
        if method == 'cca':
            # NOTE(review): this 3-element list can never unpack into main's
            # 9 values, so it is always scored as Concat_Method([case, 0, 0]);
            # np.concatenate appears to reject the scalar 0 entries -- confirm
            # what Ori_CCA_fit_transform returns for 'cca'.
            case = [case, 0, 0]
        return case, False
    if method == 'mvtcae':
        return MVTCAE_fit_transform(train=train, test=test, epochs=args.epoch), True
    if method == 'concat':
        return test, False
    raise ValueError('unknown method: {}'.format(method))


def main():
    """Fit one multi-view method on a CPM dataset and write its scores.

    Flags: --method, --dataset, --epoch, --num_views, --lr, --a, --gpu.
    Loads ``./CPM_data/<dataset>_views.npy`` (whose ``.item()`` is indexed by
    'train', 'test' and 'test_label'), scores the result with predict_cls,
    prints progress, and writes the metric table with write_csv under
    ``./CPM_OUTPUT_NEW/``.

    An unknown --method prints 'bad method' and the method name, then exits
    with status 1 before any data is loaded.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--a", type=float, default=5, help="noise loss")
    parser.add_argument("--method", type=str, default="dcca",
                        help="method (concat,cca,dcca,dcca with noise)")
    parser.add_argument("--dataset", type=str, default="CUB_2_0", help="dataset")
    parser.add_argument("--epoch", type=int, default=200, help="epochs")
    parser.add_argument("--num_views", type=int, default=2, help="num_of_views")
    parser.add_argument("--lr", type=float, default=10 * 1e-5, help="learning rate")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    args = parser.parse_args()

    # Fail fast on a typo, and exit non-zero so calling scripts see it.
    if args.method not in _KNOWN_METHODS:
        print('bad method')
        print(args.method)
        raise SystemExit(1)

    # GPU selection only takes effect before CUDA is initialised, so it is
    # done before seeding in case Initialize_Seed touches CUDA (TODO confirm).
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    Initialize_Seed(2)

    # allow_pickle=True unpickles arbitrary objects: only load trusted files.
    dataset = "./CPM_data/{}_views.npy".format(args.dataset)
    Experiment_Data = np.load(dataset, allow_pickle=True).item()
    train = Experiment_Data['train']
    test = Experiment_Data['test']
    test_label = Experiment_Data['test_label']

    print(args.dataset)
    print('in_dims: ', [len(train[v][0]) for v in range(len(train))])
    print('train: {} test: {}'.format(len(train[0]), len(test[0])))

    tasks_num, cases_num = 1, 1
    Task_AVG_Performance = {
        name: np.zeros((tasks_num, cases_num)) for name in _METRIC_NAMES
    }

    for j in range(cases_num):
        case, no_concat = _fit_method(args, train, test)
        for i in range(tasks_num):
            Task_AVG_Performance['God_method'][i][j] = 0

            if no_concat:
                # mvtcae: its output is scored as-is.
                Task_AVG_Performance['Concat_method'][i][j] = predict_cls(case, test_label)
                continue

            try:
                fea, f_score, loss_1, loss_2, loss_3, loss_4, loss_5, loss_6, loss_7 = case
            except (TypeError, ValueError):
                # Anything that does not unpack into these 9 values (e.g. the
                # raw test views for 'concat') is scored by concatenating it.
                Task_AVG_Performance['Concat_method'][i][j] = predict_cls(
                    Concat_Method(case), test_label)
                continue

            row = {
                'Concat_method': predict_cls(Concat_Method(fea), test_label),
                'average view score': f_score,
                'average nesum in feature': loss_3,
                'average cor in feature': loss_2,
                'average reconstruction loss to input': loss_1,
                'average nesum in DNNs': loss_5,
                'average cor in DNNs': loss_4,
                'average denoising loss': loss_6,
                'average noise score': loss_7,
            }
            for name, value in row.items():
                Task_AVG_Performance[name][i][j] = value

        print(Task_AVG_Performance['Concat_method'][0:1, j].mean(0))

    # NOTE(review): both entries of each pair are the same mean over the first
    # task row; confirm this is the layout write_csv(..., half=True) expects.
    Task_AVG_Performance_1 = {}
    for key, table in Task_AVG_Performance.items():
        Task_AVG_Performance_1[key] = [
            table[0:1, :].mean(0).tolist(),
            table[0:1, :].mean(0).tolist(),
        ]

    # NOTE(review): --a is not part of the file name, so runs that differ
    # only in --a overwrite each other's CSV.
    out_path = "./CPM_OUTPUT_NEW/{}_{}_True_{}_{}.csv".format(
        args.dataset, args.method, args.num_views, args.epoch)
    write_csv(Task_AVG_Performance_1, out_path, normal=True, half=True)


if __name__ == "__main__":
    # Floating-point tensors created while main() runs default to 64-bit
    # precision (torch.double is the same dtype object as torch.float64).
    torch.set_default_dtype(torch.double)
    main()
