#Common imports
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
import pickle

#Sklearn
import sklearn
from sklearn.manifold import TSNE

#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

#robustdg
from utils.helper import *
from utils.match_function import *

import random




# Default evaluation domains (rotation angles), used as the default of
# --test_domains below. Originally annotated "rotmnistspur".
dom_test = [0, 90]

# Input Parsing
#
# NOTE: argparse expands every help string with %-formatting, so a literal
# percent sign must be written as '%%'. The --match_case / --ctr_match_case
# help texts previously contained a bare '%', which made `--help` raise
# instead of printing the usage text.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='rot_mnist',
                    help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='matchdg_ctr',  # upstream default: erm_match
                    help=' Training Algorithm: erm_match; matchdg_erm')
parser.add_argument('--model_name', type=str, default='resnet18',
                    help='Architecture of the model to be trained')
# With nargs='+', values given on the command line arrive as a list of
# strings, while the defaults below stay lists of ints.
parser.add_argument('--train_domains', nargs='+', type=str, default=[0],  # needs to be modified
                    help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=dom_test,  # to be modified
                    help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10,
                    help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default=1,
                    help='Number of channels of the image in dataset')
parser.add_argument('--img_h', type=int, default=224,
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default=224,
                    help='Width of the image in dataset')
parser.add_argument('--fc_layer', type=int, default=1,
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')
parser.add_argument('--match_layer', type=str, default='logit_match',
                    help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='cos',
                    help='Cost to function to evaluate distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250,
                    help='Representation dimension for contrsative learning')
parser.add_argument('--pre_trained', type=int, default=0,
                    help='0: No Pretrained Architecture; 1: Pretrained Architecture')
parser.add_argument('--perfect_match', type=int, default=1,
                    help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
parser.add_argument('--opt', type=str, default='sgd',
                    help='Optimizer Choice: sgd; adam')
parser.add_argument('--weight_decay', type=float, default=5e-4,
                    help='Weight Decay in SGD')
parser.add_argument('--lr', type=float, default=5e-5,  # previously 0.01
                    help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=64,  # use 16 for the ERM phase
                    help='Batch size foe training the model')
parser.add_argument('--epochs', type=int, default=30,  # use 25 for the ERM phase
                    help='Total number of epochs for training the model')
parser.add_argument('--penalty_s', type=int, default=-1,
                    help='Epoch threshold over which Matching Loss to be optimised')

parser.add_argument('--penalty_aug', type=float, default=1.0,
                    help='Penalty weight for Augmentation in Hybrid approach loss')
parser.add_argument('--penalty_ws', type=float, default=0.1,
                    help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr', type=float, default=1.0,
                    help='Penalty weight for Contrastive Loss')
parser.add_argument('--tau', type=float, default=0.05,
                    help='Temperature hyper param for NTXent contrastive loss ')
parser.add_argument('--match_flag', type=int, default=1,  # use 0 for the ERM phase
                    help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=0.0,  # switch to 1.0 for the ERM phase
                    help='0: Random Match; 1: Perfect Match. 0.x" x%% correct Match')
parser.add_argument('--match_interrupt', type=int, default=5,
                    help='Number of epochs before inferring the match strategy')
parser.add_argument('--ctr_abl', type=int, default=0,
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--match_abl', type=int, default=0,
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--n_runs', type=int, default=1,  # default used to be 3
                    help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1,
                    help='Number of iterations to repeat training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18',
                    help='(For matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match',
                    help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1,
                    help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01,  # use 0.0 for the ERM phase
                    help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match. 0.x" x%% correct Match')
parser.add_argument('--ctr_match_interrupt', type=int, default=5,
                    help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0,
                    help='Change it between 0-6 for different subsets of Mnist and Fashion Mnist dataset')
parser.add_argument('--retain', type=float, default=0,
                    help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from MatchDG Phase 1 in MatchDG is Phase 2')
parser.add_argument('--cuda_device', type=int, default=3,
                    help='Select the cuda device by id among the avaliable devices')
parser.add_argument('--os_env', type=int, default=0,
                    help='0: Code execution on local server/machine; 1: Code execution in docker/clusters')


# Differential Privacy
parser.add_argument('--dp_noise', type=int, default=0,
                    help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0,
                    help='Epsilon value for Differential Privacy')
# Special case when you want to check results with the dp setting for the infinite epsilon case
parser.add_argument('--dp_attach_opt', type=int, default=1,
                    help='0: Infinite Epsilon; 1: Finite Epsilion')


# MMD, DANN
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=0)
parser.add_argument('--gaussian', type=int, default=1)

# fish
parser.add_argument('--meta_lr', type=float, default=0.01)
parser.add_argument('--meta_steps', type=int, default=5)


# Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default=2,
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)


# Differentiate between resnet, lenet, domainbed cases of mnist
parser.add_argument('--mnist_case', type=str, default='resnet18',
                    help='MNIST Dataset Case: resnet18; lenet, domainbed')
parser.add_argument('--mnist_aug', type=int, default=0,
                    help='MNIST Data Augmentation: 0 (MNIST, FMNIST Privacy Evaluation); 1 (FMNIST)')


# Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1,
                    help='Multiple random matches')


# Evaluation specific
parser.add_argument('--test_metric', type=str, default='match_score',
                    help='Evaluation Metrics: acc; match_score, t_sne, mia')
parser.add_argument('--acc_data_case', type=str, default='test',
                    help='Dataset Train/Val/Test for the accuracy evaluation metric')
parser.add_argument('--top_k', type=int, default=10,
                    help='Top K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=1,
                    help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='val',
                    help='Dataset Train/Val/Test for the match score evaluation metric')

args = parser.parse_args()


# Example invocation: python train.py --method_name dann


# GPU
# A torch.device object is always truthy, so testing the device itself can
# never select the CPU path. Ask the CUDA runtime instead and fall back to a
# CPU device when no GPU is usable.
use_cuda = torch.cuda.is_available()
if use_cuda:
    cuda = torch.device("cuda:" + str(args.cuda_device))
    # Extra keyword arguments forwarded to get_dataloader().
    kwargs = {'num_workers': 1, 'pin_memory': False}
else:
    cuda = torch.device("cpu")
    kwargs = {}

#List of Train; Test domains
# Accumulators filled by the training loop: one test-accuracy entry per
# (domain set, run) pair.
final_accuracy_target_val = []
final_accuracy_source_val = []

# Candidate training-domain sets; every row lists five domains.
domains_l1_rm = [
    [17, 39, 51, 62, 68],
    [15, 26, 37, 42, 70],
    [24, 33, 38, 42, 50],
    [32, 39, 44, 60, 65],
    [16, 25, 34, 37, 67],
    [17, 33, 41, 54, 61],
    [22, 33, 35, 53, 71],
    [26, 38, 59, 60, 70],
    [22, 28, 35, 45, 70],
    [15, 16, 22, 33, 69],
    [25, 47, 64, 68, 73],
    [40, 41, 50, 63, 74],
    [16, 19, 23, 47, 65],
    [28, 40, 49, 53, 55],
    [20, 54, 55, 64, 70],
    [25, 35, 39, 60, 68],
    [24, 45, 48, 52, 75],
    [19, 31, 35, 54, 73],
    [37, 49, 61, 68, 74],
    [22, 32, 33, 72, 75],
]
# high l1 rm: [ [22, 28, 35, 45, 70], [15, 16, 22, 33, 69], [25, 47, 64, 68, 73], [16, 19, 23, 47, 65],  [20, 54, 55, 64, 70], [25, 35, 39, 60, 68],[24, 45, 48, 52, 75], [19, 31, 35, 54, 73],  [22, 32, 33, 72, 75]]
# [15, 26, 37, 42, 70], [17, 39, 51, 62, 68], [18, 22, 39, 60, 63],  [16, 17, 34, 40, 62], [22, 23, 33, 41, 65],  [15, 18, 35, 53, 60], [15, 28, 30, 33, 65],  [16, 25, 34, 37, 67], [17, 33, 41, 54, 61], [22, 33, 35, 53, 71], [26, 38, 59, 60, 70],

# Originally 22 sets; [42, 46, 57, 63, 74] and [49, 50, 55, 57, 60] were
# dropped so that the count (20) matches domains_l1_rm.
domains_l1_fm = [
    [15, 16, 18, 24, 57],
    [29, 33, 59, 69, 70],
    [17, 18, 23, 46, 67],
    [26, 33, 41, 54, 73],
    [22, 26, 30, 35, 43],
    [20, 33, 62, 67, 74],
    [16, 42, 51, 53, 74],
    [22, 52, 58, 64, 65],
    [28, 43, 54, 64, 72],
    [16, 35, 45, 68, 72],
    [19, 21, 39, 57, 68],
    [25, 38, 39, 61, 65],
    [25, 38, 53, 65, 70],
    [17, 22, 37, 43, 46],
    [15, 20, 31, 45, 50],
    [25, 33, 38, 47, 50],
    [20, 37, 38, 64, 66],
    [17, 38, 47, 61, 69],
    [16, 29, 53, 55, 74],
    [24, 51, 62, 67, 68],
]

# Author's reminder: "the part that just finished running; remember to add
# it back".
domains_random_fm = [
    [17, 33, 41, 54, 61],
    [24, 45, 48, 52, 75],
    [25, 35, 39, 60, 68],
    [21, 29, 32, 36, 38],
    [26, 38, 59, 60, 70],
    [15, 18, 35, 53, 60],
    [22, 33, 35, 53, 71],
    [21, 24, 48, 49, 57],
    [37, 49, 61, 68, 74],
    [40, 41, 50, 63, 74],
    [15, 36, 38, 46, 56],
    [19, 31, 35, 54, 73],
    [21, 30, 41, 42, 44],
    [46, 49, 58, 62, 75],
    [15, 16, 22, 33, 69],
    [22, 23, 33, 41, 65],
    [27, 29, 35, 39, 59],
    [28, 43, 50, 64, 68],
    [22, 28, 35, 45, 70],
    [16, 18, 33, 37, 44],
]
# Sets that just finished running; remember to add them back: [17, 33, 41, 54, 61], [24, 45, 48, 52, 75], [25, 35, 39, 60, 68],

# high random: [[35, 49, 56, 65, 75], [19, 28, 47, 53, 58], [17, 30, 44, 63, 69], [22, 23, 30, 42, 43], [33, 42, 45, 70, 71], [36, 37, 47, 55, 64], [22, 23, 25, 52, 62], [17, 40, 64, 65, 74], [27, 46, 51, 53, 65], [20, 33, 40, 41, 45], [23, 30, 47, 65, 72], [15, 21, 53, 57, 63], [15, 25, 28, 59, 70], [25, 58, 60, 64, 74], [16, 22, 53, 56, 63], [23, 26, 54, 61, 73], [17, 25, 50, 53, 74], [27, 45, 54, 61, 65], [19, 40, 41, 61, 65], [31, 32, 48, 67, 75]]
#[[32, 42, 51, 53, 64], [25, 43, 46, 59, 64], [30, 44, 48, 62, 75], [51, 53, 60, 64, 72], [25, 28, 51, 72, 74], [21, 51, 52, 55, 66], [30, 50, 59, 66, 67], [15, 24, 43, 54, 64]]
#[[15, 26, 37, 42, 70], [24, 33, 38, 42, 50], [32, 39, 44, 60, 65], [16, 25, 34, 37, 67], [17, 33, 41, 54, 61], [22, 33, 35, 53, 71], [26, 38, 59, 60, 70]]
#for i in range(20):
#dom_train in domains_list:
#train_domains=[24, 51, 62, 67, 68]
#for i in range(20):
#for domains in [[15, 26, 37, 42, 70], [24, 33, 38, 42, 50], [32, 39, 44, 60, 65], [16, 25, 34, 37, 67], [17, 33, 41, 54, 61]]:
# ---------------------------------------------------------------------------
# Main experiment loop: train one MMD model per training-domain set in
# domains_l1_rm (args.n_runs repetitions each) and record test accuracies.
# ---------------------------------------------------------------------------

# Only MMD is wired up in this script, so args.method_name does not select
# the algorithm here (it is still used in some result paths). Earlier
# revisions dispatched on args.method_name to ErmMatch, MatchDG (ctr and erm
# phases), Hybrid, Erm, Irm, DRO, CSD, MMD, DANN, DANNIN and fish.
from algorithms.mmd import MMD

# `--train_domains all` means "train on every domain in 15..75". argparse
# stores this nargs='+' option as a list, so comparing it with the plain
# string 'all' could never succeed; accept both spellings.
train_on_all = args.train_domains in ('all', ['all'])

# Loop-invariant settings.
test_domains = args.test_domains
if args.os_env:
    # Docker/cluster execution: the output root comes from the environment.
    output_root = os.getenv('PT_OUTPUT_DIR')
    if output_root is None:
        raise RuntimeError(
            'PT_OUTPUT_DIR must be set in the environment when --os_env 1 is used'
        )
    res_dir = output_root + '/'
else:
    res_dir = 'results_mmd/l1/'

for domains in domains_l1_rm:
    train_domains = domains
    print(train_domains)

    # Directory that receives the results of this domain set.
    if args.dp_noise:
        base_res_dir = (
            res_dir + args.dataset_name + '/' + 'dp_' + str(args.dp_epsilon) + '_'
            + args.method_name + '/' + args.match_layer
            + '/' + 'train_' + str(train_domains)
        )
    elif args.dataset_name == 'rot_mnist_spur':
        # NOTE(review): `dom_print` is not defined in this file. Unless one of
        # the star imports (utils.helper / utils.match_function) provides it,
        # this branch raises NameError - confirm the intended source.
        degree = []
        color = []
        for per_domain in dom_print:
            degree.append(per_domain['degree'])
            color.append(per_domain['color'])

        base_res_dir = (
            res_dir + args.dataset_name + '/' + args.method_name + '/' + args.match_layer
            + '/' + 'train_' + str(degree) + str(color)
        )
    else:
        base_res_dir = res_dir + args.dataset_name + '_4.0/' + str(train_domains)

    if train_on_all:
        # Overrides the per-iteration domain set and its result directory.
        train_domains = list(range(15, 76))
        base_res_dir = (
            res_dir + args.dataset_name + '_new/' + args.method_name + '/' + args.match_layer
            + '/' + 'train_' + 'ground_set'
        )

    # TODO: Handle slab noise case in helper functions

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(base_res_dir, exist_ok=True)

    # Execute the method for multiple runs (total args.n_runs).
    for run in range(args.n_runs):
        print('Run', run)
        # NOTE(review): no RNG seeding (random / numpy / torch) is done here;
        # the per-run seeding that used to live at this point was disabled.

        # DataLoaders: train and validation use the training domains, test
        # uses the held-out test domains.
        train_dataset = get_dataloader(args, run, train_domains, 'train', 0, kwargs)
        val_dataset = get_dataloader(args, run, train_domains, 'val', 0, kwargs)
        test_dataset = get_dataloader(args, run, test_domains, 'test', 0, kwargs)

        # MMD is constructed without the run index (unlike the other
        # algorithms' constructors that used to be called here).
        train_method = MMD(
            args, train_dataset, val_dataset,
            test_dataset, base_res_dir,
            cuda
        )

        # Train the method; afterwards val_acc / final_acc are read from it.
        train_method.train()

        # Final report accuracy.
        # NOTE(review): the file names do not include the run index, so with
        # n_runs > 1 each run overwrites the previous run's .npy files.
        np.save(base_res_dir + '/Val_Acc' + '.npy', np.array(train_method.val_acc))
        np.save(base_res_dir + '/Test_Acc' + '.npy', np.array(train_method.final_acc))

        # "Target validation": best test accuracy among the recorded values.
        best_test_acc = np.max(train_method.final_acc)
        final_accuracy_target_val.append(best_test_acc)
        print('final-acc')
        print(train_method.final_acc)
        print('val acc')
        print(train_method.val_acc)

        # "Source validation": test accuracy at the index where the
        # validation accuracy is highest.
        best_val_idx = np.argmax(train_method.val_acc)
        test_acc_at_best_val = train_method.final_acc[best_val_idx]
        final_accuracy_source_val.append(test_acc_at_best_val)
        print(final_accuracy_source_val)
        print(final_accuracy_target_val)

# Summary over all accumulated results: mean and standard deviation of the
# test accuracy under both model-selection schemes.
print('\n')
print('Done for the Model..')
for report_label, report_scores in (
    ('Final Test Accuracy (Source Validation)', final_accuracy_source_val),
    ('Final Test Accuracy (Target Validation)', final_accuracy_target_val),
):
    print(report_label, np.mean(report_scores), np.std(report_scores))
print('\n')
