import numpy as np

import torch
import pdb
import random
from random import sample
import os
from utils import Initialize_Seed,load_dataset,predict_regression,write_csv,predict_cls


from model_poly.ori_cca import Ori_CCA_fit_transform
#from model.ori_cca import Ori_CCA_fit_transform



from model_poly.MyDCCA_full_mnist_sample import MyDCCA_full_fit_transform



from model_syn.MyMVTCAE import MVTCAE_fit_transform

from numpy import random



def Concat_Method(multi_view):
    """Join per-view feature arrays side by side.

    Args:
        multi_view: sequence of arrays whose shapes agree on every axis
            except axis 1 (presumably one ``(n_samples, d_view)`` matrix
            per view — confirm against callers).

    Returns:
        One array formed by concatenating ``multi_view`` along axis 1.
    """
    joined = np.concatenate(multi_view, axis=1)
    return joined





# Every value accepted by --method.
_METHODS = (
    'concat', 'cca', 'kcca', 'prcca', 'mvtcae',
    'dcca', 'dcca_with_noise', 'dccae', 'dcca_private',
    'dgcca', 'dgcca_with_noise', 'dgccae', 'dgcca_private',
)

# Extra keyword arguments given to MyDCCA_full_fit_transform for each deep-CCA
# variant (on top of train/test/epochs/num_views/lr).
_DCCA_VARIANTS = {
    'dcca_with_noise': dict(noise='normal', loss_name='cca'),
    'dcca': dict(noise='none', loss_name='cca'),
    'dccae': dict(noise='none', loss_name='cca', recon=True),
    'dcca_private': dict(noise='none', loss_name='cca', recon=True, private=True),
    'dgcca': dict(noise='none', loss_name='gcca'),
    'dgcca_with_noise': dict(noise='normal', loss_name='gcca'),
    'dgccae': dict(noise='none', loss_name='gcca', recon=True),
    'dgcca_private': dict(noise='none', loss_name='gcca', recon=True, private=True),
}

# Classical (non-deep) variants handled by Ori_CCA_fit_transform.
_CLASSICAL_CCA = ('cca', 'kcca', 'prcca')


def _build_parser():
    """Return the command-line parser for this experiment script."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--a", type=float, default=5, help="noise loss")
    # ``choices`` makes argparse reject unknown methods with a non-zero exit
    # status instead of silently "succeeding".
    parser.add_argument("--method", type=str, default="dcca", choices=_METHODS,
                        help="embedding method to evaluate")
    parser.add_argument("--epoch", type=int, default=200, help="epochs")
    parser.add_argument("--num_views", type=int, default=2, help="num_of_views")
    parser.add_argument("--lr", type=float, default=10 * 1e-5, help="learning rate")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    return parser


def _fit_transform(args, train, test):
    """Fit the method named by ``args.method`` and return its test representation.

    Returns:
        ``(embedding, f_score, no_concat)``:
        - ``embedding`` is the test-set output of the method.
        - ``f_score`` is the score recorded under 'average view score'; it is
          0 for methods that do not return one.
        - ``no_concat`` is True when ``embedding`` must be handed to
          ``predict_cls`` as-is rather than through ``Concat_Method``.

    Raises:
        ValueError: if ``args.method`` is not a known method.
    """
    method = args.method
    if method in _DCCA_VARIANTS:
        kwargs = dict(_DCCA_VARIANTS[method])
        # Only the noisy variants take the noise-loss weight ``a``.
        if kwargs['noise'] == 'normal':
            kwargs['a'] = args.a
        embedding, f_score = MyDCCA_full_fit_transform(
            train=train, test=test, epochs=args.epoch,
            num_views=args.num_views, lr=args.lr, **kwargs)
        return embedding, f_score, False
    if method in _CLASSICAL_CCA:
        embedding = Ori_CCA_fit_transform(
            multi_view_train=train, multi_view_test=test,
            dim=test[0].shape[1], method=method)
        return embedding, 0, False
    if method == 'mvtcae':
        embedding = MVTCAE_fit_transform(train=train, test=test, epochs=args.epoch)
        return embedding, 0, True
    if method == 'concat':
        # Baseline: classify the raw test views side by side.
        return test, 0, False
    raise ValueError('bad method: {}'.format(method))


def main():
    """Fit the selected multi-view method on PolyMNIST and record test scores.

    Steps:
    1. Read CLI arguments (see ``_build_parser``).
    2. Load ``./Polymnist_data/{num_views}_views.npy``.
    3. Classify the resulting test representation with ``predict_cls``.
    4. Write the summary to
       ``./Polymnist_output/{method}_True_{num_views}_{epoch}.csv``.
    """
    args = _build_parser().parse_args()

    # Choose the GPU before anything (seeding included) can initialise CUDA;
    # CUDA_VISIBLE_DEVICES has no effect once CUDA is initialised.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    Initialize_Seed(2)

    args.dataset = "./Polymnist_data/{}_views.npy".format(args.num_views)
    # The file holds a pickled dict (allow_pickle=True); only load trusted data.
    experiment_data = np.load(args.dataset, allow_pickle=True).item()
    train = experiment_data['train']
    test = experiment_data['test']
    test_label = experiment_data['test_label']

    # 1x1 score tables, kept in this shape for the write_csv summary below.
    # 'God_method' is a placeholder that this script always leaves at 0.
    results = {
        'God_method': np.zeros((1, 1)),
        'Concat_method': np.zeros((1, 1)),
        'average view score': np.zeros((1, 1)),
    }
    fail = []  # never populated here; printed at the end for compatibility

    # Always unpack (embedding, f_score): previously the no_concat path passed
    # the whole [embedding, 0] pair to predict_cls.
    embedding, f_score, no_concat = _fit_transform(args, train, test)

    if no_concat:
        results['Concat_method'][0, 0] = predict_cls(embedding, test_label)
    else:
        results['Concat_method'][0, 0] = predict_cls(Concat_Method(embedding), test_label)
        results['average view score'][0, 0] = f_score

    print(results['Concat_method'][0:1, 0].mean(0))

    summary = {}
    for key, table in results.items():
        row_mean = table[0:1, :].mean(0).tolist()
        # NOTE(review): the same mean is written twice; presumably
        # write_csv(..., half=True) expects two entries per key — confirm.
        summary[key] = [row_mean, list(row_mean)]

    os.makedirs("./Polymnist_output", exist_ok=True)
    write_csv(summary,
              "./Polymnist_output/{}_True_{}_{}.csv".format(args.method, args.num_views, args.epoch),
              normal=True, half=True)
    print(fail)
    
            


if __name__ == "__main__":
    # Make float64 the default dtype for newly created floating-point tensors.
    torch.set_default_dtype(torch.double)
    main()
