import os
from re import A
import sys
import numpy as np
import argparse
import copy
import random
import json
import pickle
from more_itertools import chunked
#Sklearn
import sklearn
from sklearn.manifold import TSNE

#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
#from transformers import data

#robustdg
from utils.helper import *
from utils.match_function import *

import random
import torchvision
import datetime

# Default test domains. Original note: "rotmnistspur" (presumably the rotated
# MNIST with spurious features setup). Alternatives previously tried here:
# [0, 90], [-1], ["0"].
# NOTE(review): these defaults are ints, but --test_domains below is declared
# type=str, so the element type depends on whether the flag is passed on the
# command line — confirm downstream code handles both.
dom_test=[0,90] #rotmnistspur

# Input Parsing
# NOTE: every option below is parsed at import time by parse_args() at the end
# of this section, so importing this module reads (and validates) sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='fashion_mnist', #'fashion_mnist'
                    help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='matchdg_ctr', # default used to be erm_match
                    help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='resnet18', 
                    help='Architecture of the model to be trained')
# NOTE(review): default is [0] (int) while type=str; values from the command line are strings.
parser.add_argument('--train_domains', nargs='+', type=str, default=[0], # needs to be adapted per dataset
                    help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=dom_test, # needs to be adapted per dataset
                    help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10, 
                    help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default= 1, # used to be 1
                    help='Number of channels of the image in dataset')
parser.add_argument('--img_h', type=int, default= 224, 
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224, 
                    help='Width of the image in dataset')
parser.add_argument('--fc_layer', type=int, default= 1, 
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')# fc layer for classification
parser.add_argument('--match_layer', type=str, default='logit_match', 
                    help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='l2', 
                    help='Cost to function to evaluate distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250, 
                    help='Representation dimension for contrsative learning')
parser.add_argument('--pre_trained',type=int, default=0, 
                    help='0: No Pretrained Architecture; 1: Pretrained Architecture')
parser.add_argument('--perfect_match', type=int, default=1, 
                    help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
parser.add_argument('--opt', type=str, default='sgd', 
                    help='Optimizer Choice: sgd; adam') 
parser.add_argument('--weight_decay', type=float, default=5e-4,
                   help='Weight Decay in SGD')
parser.add_argument('--lr', type=float, default=0.01, 
                    help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=50, 
                    help='Batch size foe training the model')
parser.add_argument('--epochs', type=int, default=30, # default 30
                    help='Total number of epochs for training the model')
# With the default of -1 the threshold is already passed at epoch 0.
parser.add_argument('--penalty_s', type=int, default=-1, 
                    help='Epoch threshold over which Matching Loss to be optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0, 
                    help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_aug', type=float, default=1.0, 
                    help='Penalty weight for Augmentation in Hybrid approach loss')
parser.add_argument('--penalty_ws', type=float, default=0.1, 
                    help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0, 
                    help='Penalty weight for Contrastive Loss')
parser.add_argument('--tau', type=float, default=0.05, 
                    help='Temperature hyper param for NTXent contrastive loss ')
parser.add_argument('--match_flag', type=int, default=0, 
                    help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=1.0, 
                    help='0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--match_interrupt', type=int, default=5, 
                    help='Number of epochs before inferring the match strategy')
parser.add_argument('--ctr_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--match_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--n_runs', type=int, default=1, # default used to be 3
                    help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1, 
                    help='Number of iterations to repeat training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18', 
                    help='(For matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match', 
                    help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1, 
                    help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01, 
                    help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--ctr_match_interrupt', type=int, default=5, 
                    help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0, 
                    help='Change it between 0-6 for different subsets of Mnist and Fashion Mnist dataset')
# NOTE(review): declared type=float and the help text names values 0 and 2 —
# confirm which value actually enables fine-tuning from phase 1.
parser.add_argument('--retain', type=float, default=0, 
                    help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from MatchDG Phase 1 in MatchDG is Phase 2')
# NOTE(review): defaults to device id 3, which must exist on the host.
parser.add_argument('--cuda_device', type=int, default=3, 
                    help='Select the cuda device by id among the avaliable devices' )
parser.add_argument('--os_env', type=int, default=0, 
                    help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )


# Differential Privacy
parser.add_argument('--dp_noise', type=int, default=0, 
                    help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0, 
                    help='Epsilon value for Differential Privacy')
# Special case when you want to check results with the dp setting for the infinite epsilon case
parser.add_argument('--dp_attach_opt', type=int, default=1, 
                    help='0: Infinite Epsilon; 1: Finite Epsilion')


# MMD, DANN
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=1)
parser.add_argument('--gaussian', type=int, default=1)

# Fish
parser.add_argument('--meta_lr', type=float, default=0.01)
parser.add_argument('--meta_steps', type=int, default=5)


# Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default= 2, 
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)


# Differentiate between resnet, lenet, domainbed cases of mnist
parser.add_argument('--mnist_case', type=str, default='resnet18', 
                    help='MNIST Dataset Case: resnet18; lenet, domainbed')
parser.add_argument('--mnist_aug', type=int, default=0, 
                    help='MNIST Data Augmentation: 0 (MNIST, FMNIST Privacy Evaluation); 1 (FMNIST)')


# Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1, 
                    help='Multiple random matches')


# Evaluation specific
parser.add_argument('--test_metric', type=str, default='match_score', 
                    help='Evaluation Metrics: acc; match_score, t_sne, mia')
parser.add_argument('--acc_data_case', type=str, default='test', 
                    help='Dataset Train/Val/Test for the accuracy evaluation metric')
parser.add_argument('--top_k', type=int, default=10, 
                    help='Top K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=0, 
                    help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='val', 
                    help='Dataset Train/Val/Test for the match score evaluation metric')

# Parsed once at import time; the classes below read their configuration from `args` passed in.
args = parser.parse_args()


class Identity(nn.Module):
    """No-op layer that still advertises the feature width it replaces.

    It is installed as a ResNet's ``fc`` when no classification head is
    wanted, so ``model.fc.in_features`` keeps reporting the width of the
    features that pass through.
    """

    def __init__(self, n_inputs):
        super().__init__()
        # Width of the (unchanged) features flowing through this layer.
        self.in_features = n_inputs

    def forward(self, x):
        """Hand ``x`` back untouched (the very same tensor object)."""
        return x
class Model_erm(nn.Module):
    """Two-stage network: a featurizer whose output feeds a classifier head."""

    def __init__(self, feat, fc):
        super().__init__()
        # Registered as submodules so their parameters belong to this model.
        self.featurizer = feat
        self.classifier = fc

    def forward(self, input_data):
        """Return ``classifier(featurizer(input_data))``."""
        return self.classifier(self.featurizer(input_data))


class BaseAlgo_matchdg():
    """Shared state and helpers for the MatchDG training phases.

    Holds the data loaders, the network ``phi`` with its optimizer and LR
    scheduler, the per-epoch accuracy histories, and helpers that build
    matched batches and evaluate accuracy. Subclasses implement the training
    loops.
    """

    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, cuda):
        """Store datasets/config and build the model, optimizer and scheduler.

        Args:
            args: parsed command-line namespace.
            train_dataset: dict providing 'data_loader', 'domain_list',
                'total_domains', 'base_domain_size' and 'domain_size_list'.
            val_dataset: dict providing 'data_loader'. For the 'matchdg_ctr'
                method the whole dict is kept instead of just the loader.
            test_dataset: dict providing 'data_loader'.
            base_res_dir: directory prefix for saved weights/accuracies.
            cuda: device passed to ``.to(...)`` for the model and batches.
        """
        self.args = args
        self.train_dataset = train_dataset['data_loader']
        if args.method_name == 'matchdg_ctr':
            self.val_dataset = val_dataset
        else:
            self.val_dataset = val_dataset['data_loader']
        self.test_dataset = test_dataset['data_loader']

        # Original author's note: for rot_mnist_spur this is a dict.
        self.train_domains = train_dataset['domain_list']
        # Original author's note: equals len(domains).
        self.total_domains = train_dataset['total_domains']
        self.domain_size = train_dataset['base_domain_size']
        self.training_list_size = train_dataset['domain_size_list']

        self.base_res_dir = base_res_dir

        self.cuda = cuda

        # get_model reads self.args and self.cuda, get_opt reads self.phi.
        self.phi = self.get_model()
        self.opt = self.get_opt()
        # NOTE(review): scheduler.step() is never called inside this class.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=25)

        self.final_acc = []
        self.val_acc = []
        self.train_acc = []

    def get_model(self):
        """Build the ResNet ``phi`` via ``models.resnet.get_resnet`` and move it to ``self.cuda``.

        The classification fc layer is disabled (``fc_layer=0``) for the 'csd'
        and 'matchdg_ctr' methods; otherwise ``args.fc_layer`` decides.
        """
        from models.resnet import get_resnet
        if self.args.method_name in ['csd', 'matchdg_ctr']:
            fc_layer = 0
        else:
            fc_layer = self.args.fc_layer
        phi = get_resnet(self.args.model_name, self.args.out_classes, fc_layer,
                         self.args.img_c, self.args.pre_trained, self.args.dp_noise, self.args.os_env)
        return phi.to(self.cuda)

    def save_model(self):
        """Save phi's weights and the val/test accuracy histories under ``base_res_dir``."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')
        np.save(self.base_res_dir + '/Val_Acc' + '.npy', np.array(self.val_acc))
        np.save(self.base_res_dir + '/Test_Acc' + '.npy', np.array(self.final_acc))

    def get_opt(self):
        """Build the optimizer over the trainable parameters of ``self.phi``.

        'sgd' uses Nesterov momentum 0.9 and ``args.weight_decay``; 'adam'
        uses only ``args.lr``.

        Raises:
            ValueError: if ``args.opt`` is neither 'sgd' nor 'adam'
                (previously this surfaced as an UnboundLocalError).
        """
        trainable = [p for p in self.phi.parameters() if p.requires_grad]
        if self.args.opt == 'sgd':
            return optim.SGD([{'params': trainable}], lr=self.args.lr,
                             weight_decay=self.args.weight_decay, momentum=0.9, nesterov=True)
        if self.args.opt == 'adam':
            return optim.Adam([{'params': trainable}], lr=self.args.lr)
        raise ValueError(f"Unsupported optimizer {self.args.opt!r}; expected 'sgd' or 'adam'")

    def get_match_function(self, inferred_match, phi):
        """Compute cross-domain matches with ``phi`` and split them into batches.

        Delegates to ``get_matched_pairs`` with the match settings from
        ``args``, shuffles the returned match list in place (``random.shuffle``)
        and chunks it into lists of at most ``args.batch_size`` entries.

        Returns:
            (data_matched, domain_data): the list of matched batches and the
            per-domain data returned by ``get_matched_pairs``.
        """
        data_matched, domain_data, _ = get_matched_pairs(
            self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains,
            self.training_list_size, phi, self.args.match_case, self.args.perfect_match, inferred_match)

        # Randomly shuffle the matched entries and divide them into batches.
        random.shuffle(data_matched)
        data_matched = list(chunked(data_matched, self.args.batch_size))

        return data_matched, domain_data

    def get_match_function_batch(self, batch_idx):
        """Materialise matched batch ``batch_idx`` of ``self.data_matched`` as tensors.

        Each entry of the batch is a group holding, per domain index ``d_i``, a
        list of candidate keys. One key per domain is drawn with
        ``random.choice`` and the data/label at that key are read from
        ``self.domain_data[d_i]``.

        Returns:
            (data_match_tensor, label_match_tensor, curr_batch_size) where the
            tensors are stacked as [group, domain, ...].
        """
        curr_data_matched = self.data_matched[batch_idx]
        curr_batch_size = len(curr_data_matched)

        data_match_tensor = []
        label_match_tensor = []
        for group in curr_data_matched:
            data_temp = []
            label_temp = []
            for d_i, candidate_keys in enumerate(group):
                key = random.choice(candidate_keys)
                data_temp.append(self.domain_data[d_i]['data'][key])
                label_temp.append(self.domain_data[d_i]['label'][key])
            data_match_tensor.append(torch.stack(data_temp))
            label_match_tensor.append(torch.stack(label_temp))

        return torch.stack(data_match_tensor), torch.stack(label_match_tensor), curr_batch_size

    def get_test_accuracy(self, case):
        """Return the accuracy (percent) of ``self.phi`` on the val or test loader.

        The model is put in eval mode for the pass, so layers such as
        BatchNorm use and keep their running statistics instead of updating
        them with evaluation data. Its previous mode is restored afterwards.

        Args:
            case: 'val' or 'test'.

        Raises:
            ValueError: for any other ``case``.
        """
        if case == 'val':
            dataset = self.val_dataset
        elif case == 'test':
            dataset = self.test_dataset
        else:
            raise ValueError(f"case must be 'val' or 'test', got {case!r}")

        test_acc = 0.0
        test_size = 0
        was_training = self.phi.training
        self.phi.eval()
        try:
            with torch.no_grad():
                for x_e, y_e, _d_e, _idx_e, _obj_e in dataset:
                    # Drop any gradients still held from training.
                    self.opt.zero_grad()
                    x_e = x_e.to(self.cuda)
                    # Labels arrive one-hot encoded; convert to class indices.
                    y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                    out = self.phi(x_e)
                    test_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                    test_size += y_e.shape[0]

                    # To avoid CUDA memory issues
                    if self.args.dp_noise:
                        self.opt.zero_grad()
        finally:
            self.phi.train(was_training)

        print(' Accuracy: ', case, 100*test_acc/test_size)
        return 100*test_acc/test_size
    
class MatchDG(BaseAlgo_matchdg):
    """Two-phase MatchDG trainer.

    Phase 1 (``ctr_phase`` truthy) trains ``phi`` with a cross-domain
    contrastive loss and keeps the checkpoint with the best top-K perfect-match
    score. Phase 2 (``ctr_phase`` falsy) infers matches with that checkpoint
    and trains ``phi`` with cross-entropy plus a matching penalty.
    """

    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, cuda,ctr_phase=1):
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, cuda)
        self.ctr_phase = ctr_phase
        # Phase-1 weights, read from disk once by init_erm_phase (see there).
        self._ctr_state_dict = None

    def train(self):
        """Run the phase selected by ``self.ctr_phase``."""
        if self.ctr_phase:
            self.train_ctr_phase()
        else:
            self.train_erm_phase()

    def save_model_ctr_phase(self, epoch):
        """Save phi's weights to ``<base_res_dir>/Model.pth`` (``epoch`` is unused)."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')

    def save_model_erm_phase(self, run):
        """Save phi's weights to ``<base_res_dir>/Model.pth`` (``run`` is unused)."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')

    def init_erm_phase(self):
        """Infer the initial phase-2 matches with the phase-1 (contrastive) model.

        Builds a ResNet without the classification fc layer, loads the phase-1
        weights from ``<base_res_dir>/Model.pth`` into it and computes matches:
        inferred matching when ``args.match_case == -1``, otherwise the
        x%-correct initial strategy.

        The weights are read from disk only once and then reused. This matters
        because ``save_model_erm_phase`` writes to the same path; without the
        cache, a second ERM run would load phase-2 weights as the match model.

        Returns:
            (data_matched, domain_data) from ``get_match_function``.
        """
        from models.resnet import get_resnet
        fc_layer = 0
        ctr_phi = get_resnet(self.args.ctr_model_name, self.args.out_classes, fc_layer, self.args.img_c,
                             self.args.pre_trained, self.args.dp_noise, self.args.os_env).to(self.cuda)

        # Load the MatchDG ctr-phase model from the saved weights.
        if getattr(self, '_ctr_state_dict', None) is None:
            save_path = self.base_res_dir + '/Model' + '.pth'
            self._ctr_state_dict = torch.load(save_path, map_location='cpu')
        ctr_phi.load_state_dict(self._ctr_state_dict)
        ctr_phi.eval()

        # -1: inferred match; otherwise the x% correct initial match strategy.
        inferred_match = 1 if self.args.match_case == -1 else 0
        return self.get_match_function(inferred_match, ctr_phi)

    def _forward_matched_batch(self, batch_idx):
        """Run ``phi`` on matched batch ``batch_idx``.

        Returns ``(feat_match, label_match)`` with the group and domain axes of
        ``get_match_function_batch``'s tensors flattened into one (group-major).
        """
        data_match_tensor, label_match_tensor, _ = self.get_match_function_batch(batch_idx)
        data_match = data_match_tensor.to(self.cuda).flatten(start_dim=0, end_dim=1)
        feat_match = self.phi(data_match)

        label_match = label_match_tensor.to(self.cuda)
        label_match = torch.squeeze(label_match.flatten(start_dim=0, end_dim=1))
        return feat_match, label_match

    def _positive_match_loss(self, feat_match):
        """Mean pairwise distance between a group's features in different domains.

        ``feat_match`` is indexed [group, domain, feature]. For every domain
        pair d_i < d_j, the per-group distance selected by ``args.pos_metric``
        ('l2': squared L2, 'l1': L1, 'cos': ``cosine_similarity``) is summed
        over groups. The total is divided by the number of (group, pair) terms.
        With fewer than two domains there are no pairs, so a zero tensor is
        returned instead of 0/0 = NaN.
        """
        loss = torch.tensor(0.0).to(self.cuda)
        n_groups, n_dom = feat_match.shape[0], feat_match.shape[1]
        pair_terms = 0
        for d_i in range(n_dom):
            for d_j in range(d_i + 1, n_dom):
                feat_i = feat_match[:, d_i, :]
                feat_j = feat_match[:, d_j, :]
                if self.args.pos_metric == 'l2':
                    loss += torch.sum(torch.sum((feat_i - feat_j)**2, dim=1))
                elif self.args.pos_metric == 'l1':
                    loss += torch.sum(torch.sum(torch.abs(feat_i - feat_j), dim=1))
                elif self.args.pos_metric == 'cos':
                    # NOTE(review): added as-is; if cosine_similarity returns a
                    # similarity rather than a distance, minimising this pushes
                    # matches apart — confirm against its definition.
                    loss += torch.sum(cosine_similarity(feat_i, feat_j))
                pair_terms += n_groups
        if pair_terms:
            loss = loss / pair_terms
        return loss

    def _contrastive_loss(self, feat_match, label_match):
        """Cross-domain contrastive loss over one batch of matched groups.

        ``feat_match`` is indexed [group, domain, feature] and ``label_match``
        [group, domain]. A group's class is read from its domain-0 label.

        For each class and each anchor domain d_i:
        - positives are the same group's features in every other domain d_j;
        - negatives are all features of groups from other classes.

        With ``pos = (1 - embedding_dist(anchor, positive)) / tau`` and ``neg``
        from ``embedding_dist(..., xent=True)``, each term adds
        ``-sum(pos - log(exp(pos) + neg))``.

        The process exits (``sys.exit``) if a NaN appears in the negatives or
        in any intermediate distance.

        Returns:
            (diff_hinge_loss, diff_ctr_loss, diff_neg_counter): the summed
            loss, the summed ``neg`` values, and 1 + the number of
            anchor/positive rows.
        """
        diff_ctr_loss = torch.tensor(0.0).to(self.cuda)
        diff_hinge_loss = torch.tensor(0.0).to(self.cuda)
        diff_neg_counter = 1
        for y_c in range(self.args.out_classes):
            pos_feat_match = feat_match[label_match[:, 0] == y_c]
            neg_feat_match = feat_match[label_match[:, 0] != y_c]
            if pos_feat_match.shape[0] == 0 or neg_feat_match.shape[0] == 0:
                continue

            if torch.sum(torch.isnan(neg_feat_match)):
                print('Non Reshaped X2 is Nan')
                sys.exit()
            # All negatives, flattened to [group * domain, feature].
            diff_neg_feat_match = neg_feat_match.view(
                neg_feat_match.shape[0]*neg_feat_match.shape[1], neg_feat_match.shape[2])
            if torch.sum(torch.isnan(diff_neg_feat_match)):
                print('Reshaped X2 is Nan')
                sys.exit()

            # Iterate over anchors from different domains.
            n_dom = pos_feat_match.shape[1]
            for d_i in range(n_dom):
                anchor = pos_feat_match[:, d_i, :]
                neg_dist = embedding_dist(anchor, diff_neg_feat_match, self.args.pos_metric, self.args.tau, xent=True)
                if torch.sum(torch.isnan(neg_dist)):
                    print('Neg Dist Nan')
                    sys.exit()

                for d_j in range(n_dom):
                    if d_j == d_i:
                        continue
                    pos_dist = 1.0 - embedding_dist(anchor, pos_feat_match[:, d_j, :], self.args.pos_metric)
                    pos_dist = pos_dist / self.args.tau
                    # Previously this re-checked neg_dist, so NaNs in pos_dist went unnoticed.
                    if torch.sum(torch.isnan(pos_dist)):
                        print('Pos Dist Nan')
                        sys.exit()

                    log_denominator = torch.log(torch.exp(pos_dist) + neg_dist)
                    if torch.sum(torch.isnan(log_denominator)):
                        print('Xent Nan')
                        sys.exit()

                    diff_hinge_loss += -1*torch.sum(pos_dist - log_denominator)
                    diff_ctr_loss += torch.sum(neg_dist)
                    diff_neg_counter += pos_dist.shape[0]

        return diff_hinge_loss, diff_ctr_loss, diff_neg_counter

    def train_erm_phase(self):
        """Phase 2: train ``phi`` with cross-entropy plus a matching penalty.

        Runs ``args.n_runs_matchdg_erm`` runs of 25 epochs each.

        Matches come from the phase-1 model at epoch 0 (``init_erm_phase``).
        If ``args.match_flag`` is set, they are re-inferred with the current
        ``phi`` every ``args.match_interrupt`` epochs.

        Every batch optimises cross-entropy on the regular batch. After epoch
        ``args.penalty_s`` it also adds cross-entropy on the matched batch and
        the positive-match penalty, weighted by
        ``penalty_ws * (epoch - penalty_s) / (epochs - penalty_s)``.

        The weights are saved whenever validation accuracy improves.
        """
        for run_erm in range(self.args.n_runs_matchdg_erm):
            self.max_epoch = -1
            self.max_val_acc = 0.0
            # The accuracy lists accumulate across runs; remember where this run starts.
            run_start = len(self.final_acc)

            # NOTE(review): hard-coded to 25 epochs rather than args.epochs,
            # although args.epochs is used in the penalty ramp below.
            for epoch in range(25):
                if epoch == 0:
                    self.data_matched, self.domain_data = self.init_erm_phase()
                elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                    # The new matches must replace data_matched/domain_data, which
                    # is what get_match_function_batch reads (they were previously
                    # assigned to unused attributes, so updates had no effect).
                    inferred_match = 1
                    self.data_matched, self.domain_data = self.get_match_function(inferred_match, self.phi)
                # Original author's note (translated): the ctr phase does not pick
                # batches of the model being trained for matching; the erm phase
                # samples batches.

                penalty_erm = 0
                penalty_erm_extra = 0
                penalty_ws = 0
                train_acc = 0.0
                train_size = 0

                # Batch iteration over a single epoch.
                for batch_idx, (x_e, y_e, _d_e, _idx_e, _obj_e) in enumerate(self.train_dataset):
                    print('Batch Idx: ', batch_idx)

                    self.opt.zero_grad()
                    loss_e = torch.tensor(0.0).to(self.cuda)

                    x_e = x_e.to(self.cuda)
                    # Labels arrive one-hot encoded; convert to class indices.
                    y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                    # Plain ERM loss on the regular batch.
                    out = self.phi(x_e)
                    erm_loss_extra = F.cross_entropy(out, y_e.long()).to(self.cuda)
                    penalty_erm_extra += float(erm_loss_extra)

                    wasserstein_loss = torch.tensor(0.0).to(self.cuda)
                    erm_loss = torch.tensor(0.0).to(self.cuda)
                    if epoch > self.args.penalty_s:
                        # There may be fewer matched batches than loader batches.
                        if batch_idx >= len(self.data_matched):
                            break

                        feat_match, label_match = self._forward_matched_batch(batch_idx)

                        erm_loss += F.cross_entropy(feat_match, label_match.long()).to(self.cuda)
                        penalty_erm += float(erm_loss)

                        train_acc += torch.sum(torch.argmax(feat_match, dim=1) == label_match).item()
                        train_size += label_match.shape[0]

                        # Regroup to [group, domain, feature] for the positive-match loss.
                        feat_match = torch.stack(torch.split(feat_match, len(self.train_domains)))
                        wasserstein_loss = self._positive_match_loss(feat_match)
                        penalty_ws += float(wasserstein_loss)

                        ws_weight = self.args.penalty_ws*(epoch - self.args.penalty_s)/(self.args.epochs - self.args.penalty_s)
                        loss_e += ws_weight*wasserstein_loss
                        loss_e += erm_loss
                        loss_e += erm_loss_extra
                    else:
                        # Below the penalty_s threshold only the plain ERM loss is
                        # optimised; previously loss_e had no graph here and
                        # backward() raised.
                        loss_e += erm_loss_extra

                    loss_e.backward(retain_graph=False)
                    self.opt.step()

                    del erm_loss_extra, erm_loss, wasserstein_loss, loss_e
                    torch.cuda.empty_cache()

                # No matched samples this epoch (e.g. before penalty_s): report 0.
                epoch_train_acc = 100*train_acc/train_size if train_size else 0.0
                print('Train Loss Basic : ', penalty_erm_extra, penalty_erm, penalty_ws)
                print('Train Acc Env : ', epoch_train_acc)
                print('Done Training for epoch: ', epoch)

                self.train_acc.append(epoch_train_acc)
                self.val_acc.append(self.get_test_accuracy('val'))
                self.final_acc.append(self.get_test_accuracy('test'))

                # Model selection uses validation accuracy, not test accuracy.
                if self.val_acc[-1] > self.max_val_acc:
                    self.max_val_acc = self.val_acc[-1]
                    self.max_epoch = epoch
                    self.save_model_erm_phase(run_erm)

                best_test_acc = self.final_acc[run_start + self.max_epoch] if self.max_epoch >= 0 else None
                print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', best_test_acc)

    def train_ctr_phase(self):
        """Phase 1: contrastive training of ``phi`` on matched groups.

        Trains for 30 epochs. Matches are computed at epoch 0 with the
        non-inferred strategy. If ``args.match_flag`` is set, they are
        re-inferred with the current ``phi`` every ``args.match_interrupt``
        epochs.

        After epoch ``args.penalty_s``, each batch optimises the contrastive
        loss weighted by ``(epoch - penalty_s) / (epochs - penalty_s)``. Every
        5 epochs the top-K perfect-match score is evaluated with ``MatchEval``,
        and the weights are saved when it improves.
        """
        self.max_epoch = -1
        self.max_val_score = 0.0
        # NOTE(review): hard-coded to 30 epochs rather than args.epochs,
        # although args.epochs is used in the loss ramp below.
        for epoch in range(30):
            if epoch == 0:
                inferred_match = 0
                self.data_matched, self.domain_data = self.get_match_function(inferred_match, self.phi)
            elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                inferred_match = 1
                self.data_matched, self.domain_data = self.get_match_function(inferred_match, self.phi)

            penalty_same_ctr = 0
            penalty_diff_ctr = 0
            penalty_same_hinge = 0
            penalty_diff_hinge = 0

            # Only the batch index is used: training data comes from data_matched.
            for batch_idx, _ in enumerate(self.train_dataset):
                self.opt.zero_grad()
                loss_e = torch.tensor(0.0).to(self.cuda)

                # The "same" terms are never accumulated; they stay zero and are
                # only reported.
                same_ctr_loss = torch.tensor(0.0).to(self.cuda)
                diff_ctr_loss = torch.tensor(0.0).to(self.cuda)
                same_hinge_loss = torch.tensor(0.0).to(self.cuda)
                diff_hinge_loss = torch.tensor(0.0).to(self.cuda)

                if epoch > self.args.penalty_s:
                    # There may be fewer matched batches than loader batches.
                    if batch_idx >= len(self.data_matched):
                        break

                    feat_match, label_match = self._forward_matched_batch(batch_idx)
                    # Regroup to [group, domain, ...].
                    n_dom = len(self.train_domains)
                    feat_match = torch.stack(torch.split(feat_match, n_dom))
                    label_match = torch.stack(torch.split(label_match, n_dom))

                    same_neg_counter = 1
                    diff_hinge_loss, diff_ctr_loss, diff_neg_counter = self._contrastive_loss(feat_match, label_match)

                    same_ctr_loss = same_ctr_loss / same_neg_counter
                    diff_ctr_loss = diff_ctr_loss / diff_neg_counter
                    same_hinge_loss = same_hinge_loss / same_neg_counter
                    diff_hinge_loss = diff_hinge_loss / diff_neg_counter

                    penalty_same_ctr += float(same_ctr_loss)
                    penalty_diff_ctr += float(diff_ctr_loss)
                    penalty_same_hinge += float(same_hinge_loss)
                    penalty_diff_hinge += float(diff_hinge_loss)

                    ramp = (epoch - self.args.penalty_s)/(self.args.epochs - self.args.penalty_s)
                    loss_e += ramp*diff_hinge_loss

                # Nothing to optimise (before penalty_s, or no class had both
                # positives and negatives in this batch).
                if not loss_e.requires_grad:
                    continue

                loss_e.backward(retain_graph=False)
                self.opt.step()

                del same_ctr_loss, diff_ctr_loss, same_hinge_loss, diff_hinge_loss
                torch.cuda.empty_cache()

            print('Train Loss Ctr : ', penalty_same_ctr, penalty_diff_ctr, penalty_same_hinge, penalty_diff_hinge)
            print('Done Training for epoch: ', epoch)

            if (epoch + 1) % 5 == 0:
                from evaluation.match_eval import MatchEval
                test_method = MatchEval(
                    self.args, self.train_dataset, self.val_dataset,
                    self.test_dataset, self.base_res_dir, self.cuda
                )
                # Compute the match-score metrics with the current model.
                was_training = self.phi.training
                test_method.phi = self.phi
                test_method.get_metric_eval()
                # MatchEval may change phi's train/eval mode; restore it for training.
                self.phi.train(was_training)

                # Keep the checkpoint with the best top-K perfect-match score.
                score = test_method.metric_score['TopK Perfect Match Score']
                if score > self.max_val_score:
                    self.max_val_score = score
                    self.max_epoch = epoch
                    self.save_model_ctr_phase(epoch)
        
class BaseAlgo_erm():
    """Shared scaffolding for the plain-ERM baseline.

    Unpacks the dataset dictionary built by the loader helpers, builds the
    model (``self.phi``), its optimiser (``self.opt``) and a StepLR scheduler.
    """

    def __init__(self, args, train_dataset, base_res_dir, cuda):
        """
        Args:
            args: parsed argument namespace; reads ``out_classes``, ``img_c``,
                ``opt``, ``lr`` and ``weight_decay``.
            train_dataset: dict with keys ``'data_loader'``, ``'domain_list'``,
                ``'total_domains'``, ``'base_domain_size'`` and
                ``'domain_size_list'``.
            base_res_dir: directory into which ``save_model`` writes ``Model.pth``.
            cuda: torch device the model is moved to.
        """
        self.args = args
        self.train_dataset = train_dataset['data_loader']

        # NOTE(original author): for rot_mnist_spur this entry is a dict.
        self.train_domains = train_dataset['domain_list']
        self.total_domains = train_dataset['total_domains']  # == len(domains)
        self.domain_size = train_dataset['base_domain_size']
        self.training_list_size = train_dataset['domain_size_list']

        self.base_res_dir = base_res_dir
        self.run = 0
        self.cuda = cuda

        self.phi = self.get_model()
        self.opt = self.get_opt()
        # StepLR with its default gamma: the lr is decayed every 25 scheduler steps.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=25)

        self.train_acc = []

    def get_resnet(self, classes, fc_layer, num_ch, pre_trained):
        """Build a torchvision ResNet-18 with a configurable head and input stem.

        Args:
            classes: output size of the linear head when ``fc_layer`` is truthy.
            fc_layer: truthy -> head is ``nn.Linear(model.fc.in_features, classes)``;
                falsy -> head is replaced by the project's ``Identity`` module
                (presumably so the network returns its pooled features — confirm).
            num_ch: when 1, ``conv1`` is replaced by a single-channel 7x7 conv.
            pre_trained: passed positionally to ``torchvision.models.resnet18``.
                NOTE(review): newer torchvision releases expect ``weights=``;
                confirm this still behaves as intended.

        Returns:
            The configured ``torchvision`` ResNet-18 module.
        """
        model = torchvision.models.resnet18(pre_trained)

        n_inputs = model.fc.in_features
        n_outputs = classes

        if fc_layer:
            model.fc = nn.Linear(n_inputs, n_outputs)
        else:
            print('Here')
            model.fc = Identity(n_inputs)

        if num_ch == 1:
            model.conv1 = nn.Conv2d(1, 64,
                                    kernel_size=(7, 7),
                                    stride=(2, 2),
                                    padding=(3, 3),
                                    bias=False)
        return model

    def get_model(self):
        """Return a ``Model_erm`` (feature extractor + linear classifier) on ``self.cuda``.

        The feature extractor is a head-less ResNet-18. The classifier is the
        freshly initialised ``fc`` layer of a second ResNet-18.
        """
        feat = self.get_resnet(self.args.out_classes, 0, self.args.img_c, 0)
        resnet_fc = self.get_resnet(self.args.out_classes, 1, self.args.img_c, 0)
        fc = resnet_fc.fc
        phi = Model_erm(feat, fc)
        phi = phi.to(self.cuda)
        return phi

    def save_model(self):
        """Write ``self.phi``'s state dict to ``<base_res_dir>/Model.pth``."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')

    def get_opt(self):
        """Create the optimiser selected by ``args.opt`` over trainable parameters.

        ``'sgd'`` -> Nesterov SGD (momentum 0.9, ``args.weight_decay``);
        ``'adam'`` -> Adam. Both use ``args.lr``.

        Raises:
            ValueError: if ``args.opt`` is neither ``'sgd'`` nor ``'adam'``
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if self.args.opt == 'sgd':
            opt = optim.SGD([
                {'params': filter(lambda p: p.requires_grad, self.phi.parameters())},
            ], lr=self.args.lr, weight_decay=self.args.weight_decay, momentum=0.9, nesterov=True)
        elif self.args.opt == 'adam':
            opt = optim.Adam([
                {'params': filter(lambda p: p.requires_grad, self.phi.parameters())},
            ], lr=self.args.lr)
        else:
            raise ValueError(
                "Unsupported optimizer {!r}: expected 'sgd' or 'adam'".format(self.args.opt)
            )

        return opt

class Erm(BaseAlgo_erm):
    """Plain empirical-risk-minimisation trainer.

    Uses the model (``self.phi``), optimiser (``self.opt``), device
    (``self.cuda``) and training loader (``self.train_dataset``) prepared by
    ``BaseAlgo_erm``.
    """

    def __init__(self, args, train_dataset, base_res_dir, cuda):
        super().__init__(args, train_dataset, base_res_dir, cuda)

    def train(self):
        """Run 50 epochs of cross-entropy training, then save the weights once.

        Each loader batch is unpacked as ``(x, y_onehot, d, idx, obj)``; class
        targets are recovered with ``argmax`` over dim 1. Every epoch prints
        the summed loss and the accuracy in percent, and appends the accuracy
        to ``self.train_acc``. Note: the epoch count is fixed at 50 and
        ``self.scheduler`` is not stepped here.
        """
        self.max_epoch = -1
        self.max_val_acc = 0.0
        for epoch in range(50):
            running_loss = 0
            n_correct = 0.0
            n_seen = 0

            for batch in self.train_dataset:
                x_e, y_e, _d_e, _idx_e, _obj_e = batch

                self.opt.zero_grad()
                inputs = x_e.to(self.cuda)
                targets = torch.argmax(y_e, dim=1).to(self.cuda)

                # Forward pass; the loss is accumulated into a zero scalar.
                logits = self.phi(inputs)
                loss = torch.tensor(0.0).to(self.cuda)
                loss += F.cross_entropy(logits, targets.long()).to(self.cuda)
                running_loss += float(loss)

                # Backward pass and parameter update.
                loss.backward(retain_graph=False)
                self.opt.step()

                del loss
                torch.cuda.empty_cache()

                predictions = torch.argmax(logits, dim=1)
                n_correct += torch.sum(predictions == targets).item()
                n_seen += targets.shape[0]

            print('Train Loss Basic : ', running_loss)
            epoch_acc = 100 * n_correct / n_seen
            print('Train Acc Env : ', epoch_acc)
            print('Done Training for epoch: ', epoch)

            # Training-set accuracy (percent) for this epoch.
            self.train_acc.append(epoch_acc)

        # Persist the weights after the final epoch.
        self.save_model()


def build_similary_matrix(cov_function, items):
    """Return the symmetric ``len(items) x len(items)`` matrix of pairwise scores.

    ``cov_function(a, b)`` is evaluated once for every pair with row <= col
    (diagonal included) and the value is mirrored into the lower triangle.
    """
    size = len(items)
    sim = np.zeros((size, size))
    for row, left in enumerate(items):
        for col in range(row, size):
            sim[row, col] = cov_function(left, items[col])
            sim[col, row] = sim[row, col]
    return sim

def get_similar_cov(inx):
    """Select a pairwise similarity function by numeric code.

    1 -> L1 distance, 2 -> L2 distance, 3 -> rescaled cosine similarity;
    any other value -> exponential-quadratic kernel with its default sigma.
    """
    candidates = (
        (1, get_l1_similar),
        (2, get_l2_similar),
        (3, get_cos_similar),
    )
    for code, similarity in candidates:
        if inx == code:
            return similarity
    return exp_quadratic()

def get_l2_similar(v1, v2):
    """Return the Euclidean (L2) norm of ``v1 - v2``."""
    difference = v1 - v2
    return np.linalg.norm(difference)

def get_l1_similar(v1, v2):
    """Return ``np.linalg.norm(v1 - v2, ord=1)`` (the L1 norm for 1-D inputs)."""
    difference = v1 - v2
    return np.linalg.norm(difference, ord=1)

def get_cos_similar(v1, v2):
    """Cosine similarity of two vectors, rescaled from [-1, 1] to [0, 1].

    Returns the integer 0 when the product of the two norms is zero.
    """
    dot_product = float(np.dot(v1, v2))
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    if norm_product == 0:
        return 0
    return 0.5 + 0.5 * (dot_product / norm_product)

def exp_quadratic(sigma=0.1):
    """Build an exponential-quadratic (RBF) similarity function.

    The returned callable computes ``exp(-0.5 * ||p1 - p2||^2 / sigma^2)``.
    """
    def kernel(p1, p2):
        squared_distance = ((p1 - p2) ** 2).sum()
        return np.exp(-0.5 * squared_distance / sigma ** 2)

    return kernel
    
def sample_k(items, L, k, max_nb_iterations=1000, rng=np.random):
    """Approximately sample a size-``k`` subset from the k-DPP with kernel ``L``.

    Runs a Metropolis swap chain. It starts from a uniformly random size-``k``
    subset. Each iteration proposes replacing one selected element ``u`` with
    one unselected element ``v``, and accepts with probability
    ``min(1, det(L[Y+v]) / det(L[Y+u]))``, where ``Y`` is the current subset
    without ``u``. Both determinants are compared through their Schur
    complements with respect to ``L[Y, Y]``.

    Args:
        items: candidate sequence; only ``len(items)`` is used.
        L: ``(len(items), len(items))`` similarity matrix (assumed symmetric).
        k: subset size, ``0 <= k <= len(items)``.
        max_nb_iterations: number of proposed swaps.
        rng: object providing ``choice`` and ``uniform`` (default: ``np.random``).

    Returns:
        Boolean numpy mask of length ``len(items)`` with exactly ``k`` True
        entries marking the sampled items.
    """
    n_items = len(items)
    initial = rng.choice(range(n_items), size=k, replace=False)
    X = np.zeros(n_items, dtype=bool)
    for i in initial:
        X[i] = True

    # With k == 0 or k == n there is exactly one subset and no swap is possible.
    if not X.any() or X.all():
        return X

    positions = np.arange(n_items)
    for _ in range(max_nb_iterations):
        u = rng.choice(positions[X])
        v = rng.choice(positions[~X])
        Y = X.copy()
        Y[u] = False
        L_Y_inv = np.linalg.inv(L[Y, :][:, Y])

        b_v = L[Y, v]
        b_u = L[Y, u]
        # Schur complements: det(L[Y+v]) / det(L[Y]) and det(L[Y+u]) / det(L[Y]).
        # The whole numerator must be divided by the whole denominator.
        gain_v = L[v, v] - b_v @ L_Y_inv @ b_v
        gain_u = L[u, u] - b_u @ L_Y_inv.T @ b_u
        p = min(1.0, float(gain_v / gain_u))
        if rng.uniform() <= p:
            X = Y  # Y is a fresh copy made this iteration
            X[v] = True
    return X

class BaseLoader(data_utils.Dataset):
    """Minimal in-memory dataset shared by the MNIST-style loaders.

    Subclasses fill ``data``, ``labels``, ``domains``, ``indices`` and
    ``objects`` with aligned index-able containers. Items come back as
    ``(x, y, d, idx, obj)``, with ``transform`` (if any) applied to ``x``.
    """

    def __init__(self, args, list_domains, root, transform=None, data_case='train'):
        self.args = args
        self.list_domains = list_domains
        # ``root`` is appended verbatim to a fixed relative prefix.
        self.root = 'data/datasets' + root
        self.transform = transform
        self.data_case = data_case

        # Bookkeeping and sample containers, populated by subclasses.
        self.base_domain_size = 0
        self.list_size = []
        self.data, self.labels, self.domains = [], [], []
        self.indices, self.objects = [], []

    def __len__(self):
        """Number of samples (first dimension of ``labels``)."""
        return self.labels.shape[0]

    def __getitem__(self, index):
        """Return ``(x, y, d, idx, obj)`` for ``index``; ``x`` is transformed if configured."""
        containers = (self.data, self.labels, self.domains, self.indices, self.objects)
        x, y, d, idx, objs = [container[index] for container in containers]
        if self.transform is not None:
            x = self.transform(x)
        return x, y, d, idx, objs

    def get_size(self):
        """Number of samples (first dimension of ``labels``)."""
        return self.labels.shape[0]
     
class L2loaderEval(BaseLoader):
    """Evaluation loader pairing every image with its augmented counterpart.

    For each domain in ``list_domains`` it loads the augmented images
    (``*_data.pt``), the original images (``*_org_data.pt``) and the labels,
    and exposes them as two pseudo-domains: 0 = ``'aug'``, 1 = ``'org'``.
    An augmented image and its original share the same value in ``indices``
    and ``objects``.
    """

    def __init__(self, args, list_domains, mnist_subset, root, match_func, transform=None, data_case='train', download=True):
        super().__init__(args, list_domains, root, transform, data_case)
        self.mnist_subset = mnist_subset
        self.download = download
        self.match_func = match_func

        self.data, self.labels, self.domains, self.indices, self.objects = self._get_data()

    def _get_data(self):
        """Load, stack, label and shuffle the aug/org image pairs.

        Side effects: sets ``self.training_list_size`` to ``[n_aug, n_org]``.
        When ``match_func`` is set, it also adds to ``self.base_domain_size``,
        for each of the 10 classes, the larger per-class count of the two
        pseudo-domains.

        Returns:
            ``(data_imgs, data_labels, data_domains, data_indices, data_objects)``:
            images with a channel dim added if they were 3-D, one-hot labels over
            10 classes, one-hot pseudo-domain labels over 2, and integer
            index/object arrays. All are shuffled with the same permutation.
        """
        keys = ('aug', 'org')
        list_img = {key: [] for key in keys}
        list_labels = {key: [] for key in keys}
        list_idx = {key: [] for key in keys}
        list_size = {key: 0 for key in keys}
        data_dir = self.root + self.args.dataset_name + '_' + self.args.mnist_case + '/'

        image_counter = 0
        for domain in self.list_domains:
            load_dir = data_dir + self.data_case + '/' + 'seed_' + str(self.mnist_subset) + '_domain_' + str(domain)
            mnist_imgs = torch.load(load_dir + '_data.pt')
            mnist_imgs_org = torch.load(load_dir + '_org_data.pt')
            mnist_labels = torch.load(load_dir + '_label.pt')
            # Ids are unique across domains and shared by an image and its augmentation.
            mnist_idx = list(range(image_counter, image_counter + len(mnist_imgs)))
            image_counter += len(mnist_imgs)

            list_img['aug'].append(mnist_imgs)
            list_img['org'].append(mnist_imgs_org)

            list_labels['aug'].append(mnist_labels)
            list_labels['org'].append(mnist_labels)

            list_idx['aug'].append(mnist_idx)
            list_idx['org'].append(mnist_idx)

            list_size['aug'] += mnist_imgs.shape[0]
            list_size['org'] += mnist_imgs_org.shape[0]

        if self.match_func:
            num_classes = 10
            for y_c in range(num_classes):
                # Largest per-class count over the two pseudo-domains. Both hold
                # the same label tensors here, so the two counts are equal.
                base_class_size = 0
                for key in keys:
                    curr_class_size = sum(int((labels == y_c).sum()) for labels in list_labels[key])
                    base_class_size = max(base_class_size, curr_class_size)
                self.base_domain_size += base_class_size

        # Stack
        data_imgs = torch.cat(list_img['aug'] + list_img['org'])
        data_labels = torch.cat(list_labels['aug'] + list_labels['org'])
        # Concatenate the per-domain id lists directly. Wrapping them in
        # np.array first raises on NumPy >= 1.24 when domains differ in size.
        data_indices = np.concatenate(
            [np.asarray(idx, dtype=int) for idx in list_idx['aug'] + list_idx['org']]
        )
        self.training_list_size = [list_size['aug'], list_size['org']]

        # For rotated MNIST the object ids are the same as the data indices.
        data_objects = copy.deepcopy(data_indices)

        # Pseudo-domain labels: 0 for the 'aug' block, 1 for the 'org' block.
        data_domains = torch.zeros(data_labels.size())
        domain_start = 0
        for idx, curr_domain_size in enumerate(self.training_list_size):
            data_domains[domain_start: domain_start + curr_domain_size] += idx
            domain_start += curr_domain_size

        # Shuffle everything with one shared permutation.
        inds = np.arange(data_labels.size()[0])
        np.random.shuffle(inds)
        data_imgs = data_imgs[inds]
        data_labels = data_labels[inds]
        data_domains = data_domains[inds].long()
        data_indices = data_indices[inds]
        data_objects = data_objects[inds]

        # Convert labels to one-hot.
        y = torch.eye(10)
        data_labels = y[data_labels]

        # Convert domains to one-hot.
        d = torch.eye(len(self.training_list_size))
        data_domains = d[data_domains]

        # If shape is (B, H, W), change it to (B, C, H, W) with C = 1.
        if len(data_imgs.shape) == 3:
            data_imgs = data_imgs.unsqueeze(1)

        return data_imgs, data_labels, data_domains, data_indices, data_objects

class L2loader(BaseLoader):
    """Level-2 loader: original images, restricted to selected chunks when training.

    With ``data_case == 'train'``, each domain's images are split into
    consecutive chunks of ``_CHUNK_SIZE`` rows. Only chunks whose flag in
    ``indecies`` equals 1 are kept. There are ``_CHUNKS_PER_DOMAIN`` flags per
    domain, laid out in ``list_domains`` order. For any other ``data_case``
    every image is loaded.
    """

    # Layout of ``indecies``: one flag per chunk of 50 images, 40 chunks per domain.
    _CHUNKS_PER_DOMAIN = 40
    _CHUNK_SIZE = 50

    def __init__(self, args, list_domains, mnist_subset, root, indecies, match_func, transform=None, data_case='train', download=True):
        super().__init__(args, list_domains, root, transform, data_case)
        self.mnist_subset = mnist_subset
        self.download = download
        self.match_func = match_func
        # Chunk flags; only read when data_case == 'train'.
        self.indecies = indecies
        self.data, self.labels, self.domains, self.indices, self.objects = self._get_data()

    def _selected_rows(self, domain_pos):
        """Return a long tensor of row ids of domain ``domain_pos`` whose chunk flag equals 1."""
        n_chunks, size = self._CHUNKS_PER_DOMAIN, self._CHUNK_SIZE
        flags = self.indecies[domain_pos * n_chunks:(domain_pos + 1) * n_chunks]
        rows = [np.arange(chunk * size, (chunk + 1) * size)
                for chunk, flag in enumerate(flags) if flag == 1]
        if not rows:
            # An empty float64 array (what np.array([]) gives) cannot index a tensor.
            return torch.zeros(0, dtype=torch.long)
        return torch.as_tensor(np.concatenate(rows), dtype=torch.long)

    def _get_data(self):
        """Load, optionally subsample, label and shuffle the images of all domains.

        Side effects: sets ``self.training_list_size`` to the per-domain sample
        counts. When ``match_func`` is set, it also adds to
        ``self.base_domain_size``, for each of the 10 classes, the largest
        per-domain count of that class.

        Returns:
            ``(data_imgs, data_labels, data_domains, data_indices, data_objects)``:
            images with a channel dim added if they were 3-D, one-hot labels over
            10 classes, one-hot domain labels, and per-domain integer
            index/object arrays. All are shuffled with the same permutation.
        """
        list_img = []
        list_labels = []
        list_idx = []
        list_size = []
        data_dir = self.root + self.args.dataset_name + '_' + self.args.mnist_case + '/'

        for domain_pos, domain in enumerate(self.list_domains):
            load_dir = data_dir + self.data_case + '/' + 'seed_' + str(self.mnist_subset) + '_domain_' + str(domain)
            mnist_imgs = torch.load(load_dir + '_org_data.pt')
            mnist_labels = torch.load(load_dir + '_label.pt')

            if self.data_case == 'train':
                keep = self._selected_rows(domain_pos)
                mnist_labels = mnist_labels[keep]
                mnist_imgs = mnist_imgs[keep]

            # Indices restart at 0 in every domain.
            mnist_idx = list(range(len(mnist_imgs)))

            list_img.append(mnist_imgs)
            list_labels.append(mnist_labels)
            list_idx.append(mnist_idx)
            list_size.append(mnist_imgs.shape[0])

        if self.match_func:
            num_classes = 10
            for y_c in range(num_classes):
                # Largest count of class y_c over the domains.
                self.base_domain_size += max(
                    (int((labels == y_c).sum()) for labels in list_labels), default=0
                )

        # Merge the data of all domains.
        data_imgs = torch.cat(list_img)
        data_labels = torch.cat(list_labels)
        # Concatenate directly: np.array on per-domain lists of different lengths
        # (the usual case after chunk selection) raises on NumPy >= 1.24.
        data_indices = np.concatenate([np.asarray(idx, dtype=int) for idx in list_idx])
        self.training_list_size = list_size

        # For rotated MNIST the object ids are the same as the data indices.
        data_objects = copy.deepcopy(data_indices)

        # Domain label = position of the domain in list_domains.
        data_domains = torch.zeros(data_labels.size())
        domain_start = 0
        for idx in range(len(self.list_domains)):
            curr_domain_size = self.training_list_size[idx]
            data_domains[domain_start: domain_start + curr_domain_size] += idx
            domain_start += curr_domain_size

        # Shuffle everything with one shared permutation.
        inds = np.arange(data_labels.size()[0])
        np.random.shuffle(inds)
        data_imgs = data_imgs[inds]
        data_labels = data_labels[inds]
        data_domains = data_domains[inds].long()
        data_indices = data_indices[inds]
        data_objects = data_objects[inds]

        # Convert labels to one-hot.
        y = torch.eye(10)
        data_labels = y[data_labels]

        # Convert domains to one-hot.
        d = torch.eye(len(self.list_domains))
        data_domains = d[data_domains]

        # If shape is (B, H, W), change it to (B, C, H, W) with C = 1.
        if len(data_imgs.shape) == 3:
            data_imgs = data_imgs.unsqueeze(1)

        return data_imgs, data_labels, data_domains, data_indices, data_objects
class L1loader(BaseLoader):
    """Level-1 loader: every original image of every domain, in file order.

    Unlike the level-2 loaders, this one does not shuffle. Consecutive rows
    therefore follow the on-disk order of each domain, one domain after the
    other.
    """

    def __init__(self, args, list_domains, mnist_subset, root, transform=None, data_case='train', download=True):
        super().__init__(args, list_domains, root, transform, data_case)
        self.mnist_subset = mnist_subset
        self.download = download

        self.data, self.labels, self.domains, self.indices, self.objects = self._get_data()

    def _get_data(self):
        """Load and label the images of all domains without shuffling.

        Side effects: sets ``self.training_list_size`` to the per-domain sample
        counts.

        Returns:
            ``(data_imgs, data_labels, data_domains, data_indices, data_objects)``:
            images with a channel dim added if they were 3-D, one-hot labels over
            10 classes, one-hot domain labels, and per-domain integer
            index/object arrays.
        """
        list_img = []
        list_labels = []
        list_idx = []
        list_size = []
        data_dir = self.root + self.args.dataset_name + '_' + self.args.mnist_case + '/'

        for domain in self.list_domains:
            load_dir = data_dir + self.data_case + '/' + 'seed_' + str(self.mnist_subset) + '_domain_' + str(domain)
            mnist_imgs = torch.load(load_dir + '_org_data.pt')
            mnist_labels = torch.load(load_dir + '_label.pt')
            # Indices restart at 0 in every domain.
            mnist_idx = list(range(len(mnist_imgs)))

            list_img.append(mnist_imgs)
            list_labels.append(mnist_labels)
            list_idx.append(mnist_idx)
            list_size.append(mnist_imgs.shape[0])

        # Merge the data of all domains.
        data_imgs = torch.cat(list_img)
        data_labels = torch.cat(list_labels)
        # Concatenate directly (with an explicit int dtype). np.array on
        # per-domain lists of different lengths raises on NumPy >= 1.24.
        data_indices = np.concatenate([np.asarray(idx, dtype=int) for idx in list_idx])
        self.training_list_size = list_size

        # For rotated MNIST the object ids are the same as the data indices.
        data_objects = copy.deepcopy(data_indices)

        # Domain label = position of the domain in list_domains.
        data_domains = torch.zeros(data_labels.size())
        domain_start = 0
        for idx in range(len(self.list_domains)):
            curr_domain_size = self.training_list_size[idx]
            data_domains[domain_start: domain_start + curr_domain_size] += idx
            domain_start += curr_domain_size

        data_domains = data_domains.long()
        # Convert labels to one-hot.
        y = torch.eye(10)
        data_labels = y[data_labels]

        # Convert domains to one-hot.
        d = torch.eye(len(self.list_domains))
        data_domains = d[data_domains]

        # If shape is (B, H, W), change it to (B, C, H, W) with C = 1.
        if len(data_imgs.shape) == 3:
            data_imgs = data_imgs.unsqueeze(1)

        return data_imgs, data_labels, data_domains, data_indices, data_objects

def get_dataloader_l2(args, domains, data_case, eval_case, kwargs, indecies):
    """Build the dataset dictionary for the level-2 (chunk-selected) stage.

    Training data use match_func=True and batches of 50; other splits use
    batches of 512. The 'test' split reads seed 9, everything else seed 0.
    When ``eval_case`` is set, ``args.test_metric`` is 'match_score' and
    ``args.match_func_aug_case`` is truthy, the aug/org evaluation loader is
    used. In that case the result describes two pseudo-domains.

    Returns:
        dict with 'data_loader' (shuffling DataLoader), 'data_obj',
        'total_domains', 'domain_list', 'base_domain_size' and
        'domain_size_list'.
    """
    is_train = data_case == 'train'
    match_func = is_train
    batch_size = 50 if is_train else 512
    mnist_subset = 9 if data_case == 'test' else 0

    use_aug_eval = eval_case and args.test_metric in ['match_score'] and args.match_func_aug_case

    if use_aug_eval:
        print('Match Function evaluation on self augmentations')
        data_obj = L2loaderEval(args, domains, mnist_subset, '/mnist/', data_case=data_case, match_func=True)
    else:
        data_obj = L2loader(args, domains, mnist_subset, '/mnist/', indecies, match_func=match_func, data_case=data_case)

    dataset = {
        'data_loader': data_utils.DataLoader(data_obj, batch_size=batch_size, shuffle=True, **kwargs),
        'data_obj': data_obj,
        'total_domains': len(domains),
        'domain_list': domains,
        'base_domain_size': data_obj.base_domain_size,
        'domain_size_list': data_obj.training_list_size,
    }

    if use_aug_eval:
        dataset['total_domains'] = 2
        dataset['domain_list'] = ['aug', 'org']

    return dataset
def get_dataloader_l1(args, domains, data_case, eval_case, kwargs):
    """Build the dataset dictionary for the level-1 (full, unshuffled) stage.

    Always reads seed 0 from the '/matchdg/mnist/' root. Batches hold 50
    images, and the DataLoader does not shuffle, so each batch covers a fixed
    consecutive range of images. NOTE(original author): this root needs
    changing for rot_mnist_spur.

    Returns:
        dict with 'data_loader', 'data_obj', 'total_domains', 'domain_list',
        'base_domain_size' and 'domain_size_list'. When ``eval_case`` is set,
        ``args.test_metric`` is 'match_score' and ``args.match_func_aug_case``
        is truthy, the domain entries are overridden with the two aug/org
        pseudo-domains.
    """
    mnist_subset = 0
    batch_size = 50

    data_obj = L1loader(args, domains, mnist_subset, '/matchdg/mnist/', data_case=data_case)

    dataset = {
        'data_loader': data_utils.DataLoader(data_obj, batch_size=batch_size, shuffle=False, **kwargs),
        'data_obj': data_obj,
        'total_domains': len(domains),
        'domain_list': domains,
        'base_domain_size': data_obj.base_domain_size,
        'domain_size_list': data_obj.training_list_size,
    }

    if eval_case and args.test_metric in ['match_score'] and args.match_func_aug_case:
        dataset['total_domains'] = 2
        dataset['domain_list'] = ['aug', 'org']

    return dataset


# ---------------------------------------------------------------------------
# Script driver (runs at import time).
# For every training-domain set in domains_list_fm: load the precomputed
# chunk-selection flags, train MatchDG with ctr_phase=1 (method 'matchdg_ctr'),
# then switch parser defaults to 'matchdg_erm', train again with ctr_phase=0,
# and record test accuracies under two model-selection rules.
# ---------------------------------------------------------------------------
# NOTE(review): a torch.device object is always truthy, so the `if cuda:`
# branch below is always taken; torch.cuda.is_available() may have been intended.
cuda= torch.device("cuda:" + str(args.cuda_device))
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': False} 
else:
    kwargs= {}
# Per-domain-set results: test accuracy chosen by max test accuracy
# ("target validation") and at the best validation epoch ("source validation").
final_accuracy_target_val=[]
final_accuracy_source_val=[]
#fashion mnist
# Each inner list is one set of five training domains (presumably rotation
# angles in degrees — TODO confirm). This is the list the loop iterates.
domains_list_fm=[ [15, 16, 18, 24, 57], [29, 33, 59, 69, 70], [17, 18, 23, 46, 67], [26, 33, 41, 54, 73], [22, 26, 30, 35, 43], [20, 33, 62, 67, 74], [16, 42, 51, 53, 74],[22, 52, 58, 64, 65], [28, 43, 54, 64, 72], [16, 35, 45, 68, 72], [19, 21, 39, 57, 68], [25, 38, 39, 61, 65],[25, 38, 53, 65, 70],[17, 22, 37, 43, 46],[15, 20, 31, 45, 50], [25, 33, 38, 47, 50], [20, 37, 38, 64, 66], [17, 38, 47, 61, 69], [16, 29, 53, 55, 74],[24, 51, 62, 67, 68]]
#[[42, 46, 57, 63, 74], [15, 16, 18, 24, 57], [29, 33, 59, 69, 70], [17, 18, 23, 46, 67], [26, 33, 41, 54, 73], [22, 26, 30, 35, 43], [20, 33, 62, 67, 74], [16, 42, 51, 53, 74],[22, 52, 58, 64, 65], [28, 43, 54, 64, 72], [16, 35, 45, 68, 72], [19, 21, 39, 57, 68], [25, 38, 39, 61, 65], [49, 50, 55, 57, 60],[25, 38, 53, 65, 70],[17, 22, 37, 43, 46]]
#[[42, 46, 57, 63, 74], [15, 16, 18, 24, 57], [29, 33, 59, 69, 70], [17, 18, 23, 46, 67], [26, 33, 41, 54, 73], [22, 26, 30, 35, 43], [20, 33, 62, 67, 74], [16, 42, 51, 53, 74],[22, 52, 58, 64, 65], [28, 43, 54, 64, 72], [16, 35, 45, 68, 72], [19, 21, 39, 57, 68], [25, 38, 39, 61, 65], [49, 50, 55, 57, 60],[25, 38, 53, 65, 70],[17, 22, 37, 43, 46],[15, 20, 31, 45, 50], [25, 33, 38, 47, 50], [20, 37, 38, 64, 66], [17, 38, 47, 61, 69], [16, 29, 53, 55, 74],[24, 51, 62, 67, 68]]

# Rotated-MNIST domain sets; not used by the loop below.
domains_list_rm=[[17, 39, 51, 62, 68],[15, 26, 37, 42, 70], [24, 33, 38, 42, 50], [32, 39, 44, 60, 65], [16, 25, 34, 37, 67], [17, 33, 41, 54, 61], [22, 33, 35, 53, 71], [26, 38, 59, 60, 70],[22, 28, 35, 45, 70], [15, 16, 22, 33, 69], [25, 47, 64, 68, 73],[40, 41, 50, 63, 74], [16, 19, 23, 47, 65], [28, 40, 49, 53, 55], [20, 54, 55, 64, 70], [25, 35, 39, 60, 68],[24, 45, 48, 52, 75], [19, 31, 35, 54, 73], [37, 49, 61, 68, 74], [22, 32, 33, 72, 75]]
#[[28, 37, 43, 64, 74], [17, 19, 21, 42, 57], [47, 48, 54, 65, 70], [20, 23, 34, 41, 68], [24, 49, 54, 67, 74], [28, 30, 51, 61, 69], [22, 26, 41, 49, 51], [16, 17, 20, 66, 71], [21, 41, 43, 49, 75], [39, 51, 56, 65, 73]]    
for domains in domains_list_fm:


    print(domains)
    

    # Results go to results_matchdg/l2_180/<dataset_name><domains>.
    res_dir= 'results_matchdg/l2_180/'
    base_res_dir=(                    # results path is built here
                    res_dir + args.dataset_name + '/' #+ args.method_name + '/' + args.match_layer + '/' + 'train_' 
                    + str(domains)
                )
    
    if not os.path.exists(base_res_dir):
        os.makedirs(base_res_dir) 

    # Disabled level-1 stage (kept for reference): train an ERM model on all
    # data, embed each batch by its mean feature, build a similarity matrix and
    # draw a DPP sample of batches. Its result is now loaded from disk below.
    # dataset1=get_dataloader_l1(args, domains, 'train', 0, kwargs)
    # train_method= Erm(args, dataset1,  base_res_dir, cuda)

    # train_method.train()

    # data_loc=[]
    # for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(dataset1['data_loader']):
    # # for data in data_imgs:
    #     with torch.no_grad():
    #         x_e= x_e.to(cuda)
    #         location=train_method.phi.featurizer(x_e)                
    #         location=torch.mean(location,0) 
    #         location=location.cpu().numpy()# the second level could also switch to a different similarity measure
    #     data_loc.append(location)
    # print(batch_idx)
    # L=build_similary_matrix(exp_quadratic(0.1),data_loc)# pick 100 out of 160
    # X=sample_k(data_loc,L,115) ##182 rm, 105 fm 

    # Load the precomputed per-chunk selection flags (1 = keep chunk) for this
    # domain set; get_dataloader_l2 uses them for the training split.
    base_load_index_dir='/home/skyfall/rdd/rdd/indecies_180'
    residual_dir='/'+str(domains)+'/indecies.npy'
    indecies=np.load(base_load_index_dir+residual_dir)
    # Print how many chunks were selected.
    count=0
    for i in indecies:
        count+= i==1
    print(count)
    
    #print(X)
    
    # indecies=[]
    # for i in range(0,len(X)):
    #             if X[i]==True:
    #                 indecies.append(i)
    #print(indecies)
    # Training split restricted to the selected chunks.
    dataset2=get_dataloader_l2(args, domains, 'train', 0, kwargs,indecies)#args, domains, data_case, eval_case, kwargs,indecies
    #dataset2=dataset1[['data_loader']][X]

    #val_dataset= get_dataloader( args, 0, domains, 'val', 0, kwargs ) 
    # eval_case=1: when args.test_metric is 'match_score' and match_func_aug_case
    # is set, this builds the aug/org evaluation loader.
    val_dataset= get_dataloader_l2( args, domains, 'val', 1, kwargs,indecies)  
    #print('finish val data loading')
        # if args.method_name=='dannin':
        #     test_dataset= get_dataloader( args, (run+1)%3, train_domains, 'val', 0, kwargs )  
        # else:
    test_dataset= get_dataloader_l2( args, args.test_domains, 'test', 0, kwargs,indecies)  
    # Phase 1: MatchDG with ctr_phase=1.
    ctr_phase=1
    train_method= MatchDG(
                            args,dataset2,val_dataset,
                            test_dataset, base_res_dir, 
                            cuda, ctr_phase,
                        )     
       
            
        #Train the method: It will save the model's weights post training and evalute it on test accuracy
    train_method.train()
    # Switch parser defaults to the 'matchdg_erm' phase and re-parse; flags
    # given explicitly on the command line still override these defaults.
    parser.set_defaults(match_func_aug_case=0)
    parser.set_defaults(method_name='matchdg_erm')
    parser.set_defaults(batch_size=16)# change 'dom' inside args
    parser.set_defaults(epochs=25)
    parser.set_defaults(match_flag=0)
    parser.set_defaults(match_case=1.0)
    parser.set_defaults(ctr_match_case=0.0)
    args=parser.parse_args()
    # Phase 2: MatchDG with ctr_phase=0, reusing the training and test loaders.
    ctr_phase=0
    val_dataset= get_dataloader_l2( args, domains, 'val', 0, kwargs,indecies ) # may also need to be restricted by indices
    train_method= MatchDG(
                            args, dataset2, val_dataset,
                            test_dataset, base_res_dir,
                            cuda, ctr_phase

                            )     
    train_method.train()   
    #Final Report Accuacy
    
    np.save( base_res_dir + '/Val_Acc' + '.npy', np.array(train_method.val_acc) )
    np.save( base_res_dir + '/Test_Acc' +  '.npy', np.array(train_method.final_acc)) 
        #print('11111111111111')
    # Target-validation selection: best test accuracy over all epochs.
    final_acc= np.max(train_method.final_acc)
    final_accuracy_target_val.append( final_acc )
    
    # Source-validation selection: test accuracy at the epoch with best val accuracy.
    idx= np.argmax(train_method.val_acc)
    final_acc= train_method.final_acc[idx]
    final_accuracy_source_val.append( final_acc  )
    print(final_accuracy_source_val)  
    print(final_accuracy_target_val)
    # Restore the phase-1 ('matchdg_ctr') defaults for the next domain set.
    parser.set_defaults(match_func_aug_case=1)
    parser.set_defaults(method_name='matchdg_ctr')
    parser.set_defaults(batch_size=64)# change 'dom' inside args
    parser.set_defaults(epochs=1)
    parser.set_defaults(match_flag=1)
    parser.set_defaults(match_case=0.0)
    parser.set_defaults(ctr_match_case=0.01)
    args=parser.parse_args()

# Summary over all domain sets: mean and standard deviation of test accuracy.
print('\n')
print('Done for the Model..')
print('Final Test Accuracy (Source Validation)', np.mean(final_accuracy_source_val), np.std(final_accuracy_source_val) )
print('Final Test Accuracy (Target Validation)', np.mean(final_accuracy_target_val), np.std(final_accuracy_target_val) )
print('\n')

#python level-2-matchdg.py