#Common imports
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
import pickle

#Sklearn
import sklearn
from sklearn.manifold import TSNE

#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

#robustdg
from utils.helper import *
from utils.match_function import *

import random
import torchvision
import datetime


# Default test domains (used as the --test_domains default below).
# Other settings that have been used: [-1, 90] for rot_mnist_spur, [-1], ["0"].
dom_test=[0,90]
# Training domains used to be drawn at random, e.g.
# dom_train=random.sample((list(range(15,76))),40)  # sample size is tunable (original note mentions 20)

# Input Parsing
# NOTE: parse_args() at the end of this block runs at import time and reads sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='fashion_mnist', # alternative: 'rot_mnist_spur'
                    help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='erm_match', # default: erm_match
                    help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='resnet18', 
                    help='Architecture of the model to be trained')
# NOTE(review): type=str but the default is [0] (an int); this file's main loop builds its
# own training-domain range instead of reading this flag.
parser.add_argument('--train_domains', nargs='+', type=str, default=[0],
                    help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=dom_test, # default taken from dom_test above
                    help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10, 
                    help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default= 1,
                    help='Number of channels of the image in dataset')
parser.add_argument('--img_h', type=int, default= 224, 
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224, 
                    help='Width of the image in dataset')
parser.add_argument('--fc_layer', type=int, default= 1, 
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')# fc layer for classification
parser.add_argument('--match_layer', type=str, default='logit_match', 
                    help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='l2', 
                    help='Cost to function to evaluate distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250, 
                    help='Representation dimension for contrsative learning')
parser.add_argument('--pre_trained',type=int, default=0, 
                    help='0: No Pretrained Architecture; 1: Pretrained Architecture')
parser.add_argument('--perfect_match', type=int, default=1, 
                    help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
parser.add_argument('--opt', type=str, default='sgd', 
                    help='Optimizer Choice: sgd; adam') 
parser.add_argument('--weight_decay', type=float, default=5e-4,
                   help='Weight Decay in SGD')
parser.add_argument('--lr', type=float, default=0.01, 
                    help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=64, 
                    help='Batch size foe training the model')
parser.add_argument('--epochs', type=int, default=60, 
                    help='Total number of epochs for training the model')
parser.add_argument('--penalty_s', type=int, default=-1, 
                    help='Epoch threshold over which Matching Loss to be optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0, 
                    help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_aug', type=float, default=1.0, 
                    help='Penalty weight for Augmentation in Hybrid approach loss')
parser.add_argument('--penalty_ws', type=float, default=0.1, 
                    help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0, 
                    help='Penalty weight for Contrastive Loss')
parser.add_argument('--tau', type=float, default=0.05, 
                    help='Temperature hyper param for NTXent contrastive loss ')
parser.add_argument('--match_flag', type=int, default=0, 
                    help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=1.0, 
                    help='0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--match_interrupt', type=int, default=5, 
                    help='Number of epochs before inferring the match strategy')
parser.add_argument('--ctr_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--match_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--n_runs', type=int, default=1, # previously defaulted to 3
                    help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1, 
                    help='Number of iterations to repeat training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18', 
                    help='(For matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match', 
                    help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1, 
                    help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01, 
                    help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--ctr_match_interrupt', type=int, default=5, 
                    help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0, 
                    help='Change it between 0-6 for different subsets of Mnist and Fashion Mnist dataset')
parser.add_argument('--retain', type=float, default=0, 
                    help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from MatchDG Phase 1 in MatchDG is Phase 2')
parser.add_argument('--cuda_device', type=int, default=1, 
                    help='Select the cuda device by id among the avaliable devices' )
parser.add_argument('--os_env', type=int, default=0, 
                    help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )


# Differential Privacy
parser.add_argument('--dp_noise', type=int, default=0, 
                    help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0, 
                    help='Epsilon value for Differential Privacy')
# Special case when you want to check results with the dp setting for the infinite epsilon case
parser.add_argument('--dp_attach_opt', type=int, default=1, 
                    help='0: Infinite Epsilon; 1: Finite Epsilion')


# MMD, DANN (d_steps_per_g_step, grad_penalty and conditional are read by DANNIN)
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=1)
parser.add_argument('--gaussian', type=int, default=1)

# Fish
parser.add_argument('--meta_lr', type=float, default=0.01)
parser.add_argument('--meta_steps', type=int, default=5)


# Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default= 2, 
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)


# Differentiate between resnet, lenet, domainbed cases of mnist
# (mnist_case is also part of the on-disk data directory name).
parser.add_argument('--mnist_case', type=str, default='resnet18', 
                    help='MNIST Dataset Case: resnet18; lenet, domainbed')
parser.add_argument('--mnist_aug', type=int, default=0, 
                    help='MNIST Data Augmentation: 0 (MNIST, FMNIST Privacy Evaluation); 1 (FMNIST)')


# Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1, 
                    help='Multiple random matches')


# Evaluation specific
parser.add_argument('--test_metric', type=str, default='match_score', 
                    help='Evaluation Metrics: acc; match_score, t_sne, mia')
parser.add_argument('--acc_data_case', type=str, default='test', 
                    help='Dataset Train/Val/Test for the accuracy evaluation metric')
parser.add_argument('--top_k', type=int, default=10, 
                    help='Top K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=0, 
                    help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='val', 
                    help='Dataset Train/Val/Test for the match score evaluation metric')

# Parsed at import time from sys.argv.
args = parser.parse_args()

class BaseDataLoader(data_utils.Dataset):
    """Minimal in-memory dataset of ``(image, label, domain)`` triples.

    Subclasses fill ``self.data``, ``self.labels`` and ``self.domains`` (indexed in
    parallel); until they do, all three are empty lists.

    Args:
        args: parsed CLI namespace, stored for subclasses.
        list_domains: domains this dataset covers.
        root: path suffix appended to ``'data/datasets'``.
        transform: optional callable applied to each image in ``__getitem__``.
        data_case: split name (e.g. 'train', 'val', 'test').
    """

    def __init__(self, args, list_domains, root, transform=None, data_case='train'):
        self.args= args
        self.list_domains = list_domains

        self.root = 'data/datasets' + root
        self.transform = transform
        self.data_case = data_case

        self.base_domain_size= 0
        self.list_size=[]
        self.data= []
        self.labels= []
        self.domains= []

    def __len__(self):
        # len() works for the initial empty list as well as for the tensors subclasses
        # assign (for a tensor it equals shape[0]); ``.shape`` would fail on a list.
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(x, y, d)`` for ``index``, applying ``transform`` to ``x`` if set."""
        x = self.data[index]
        y = self.labels[index]
        d = self.domains[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y, d

    def get_size(self):
        """Number of examples (same as ``len(self)``)."""
        return len(self.labels)
     

class MnistRotated(BaseDataLoader):
    """Rotated (Fashion-)MNIST dataset assembled from pre-saved per-domain tensors.

    For each entry of ``list_domains`` the files ``<prefix>_org_data.pt`` (images) and
    ``<prefix>_label.pt`` (integer labels) are loaded. A random subset of at most
    ``samples_per_domain`` examples is drawn *without replacement*. All domains are then
    concatenated and shuffled.

    ``data``: images with a channel axis added when the stored tensor is (N, H, W).
    ``labels``: one-hot over 10 classes.
    ``domains``: one-hot over ``len(list_domains)``; the domain label is the position of
    the domain in ``list_domains``.
    """

    # Maximum number of examples kept from each domain.
    samples_per_domain = 500

    # Rotation angles and colours indexed by a rot_mnist_spur domain id.
    _SPUR_DEGREES = list(range(15, 76, 5))
    _SPUR_COLORS = ['red', 'blue', 'green', 'brown', 'pink', 'yellow']

    def __init__(self, args, list_domains, mnist_subset, root, transform=None, data_case='train',  download=True):
        super().__init__(args, list_domains, root, transform, data_case)
        self.mnist_subset = mnist_subset
        self.download = download  # stored only; nothing in this class reads it

        self.data, self.labels, self.domains = self._get_data()

    def _load_dir(self, data_dir, domain):
        """Return the file prefix (without the ``_org_data.pt``/``_label.pt`` suffix) for ``domain``."""
        prefix = data_dir + self.data_case + '/' + 'seed_' + str(self.mnist_subset) + '_domain_'
        if self.args.dataset_name == 'rot_mnist_spur':
            # Spurious-colour variant: the domain id encodes (angle index, colour index).
            degree = self._SPUR_DEGREES[int(domain / 6)]
            color = self._SPUR_COLORS[domain % 6]
            return prefix + str(degree) + '_color_' + color
        return prefix + str(domain)

    def _get_data(self):
        """Load, subsample, concatenate and shuffle all domains.

        Returns:
            ``(data_imgs, data_labels_onehot, data_domains_onehot)``.
        """
        data_img = []
        data_labels = []
        list_size = []
        data_dir = self.root + self.args.dataset_name + '_' + self.args.mnist_case + '/'

        for domain in self.list_domains:
            load_dir = self._load_dir(data_dir, domain)
            mnist_imgs = torch.load(load_dir + '_org_data.pt')
            mnist_labels = torch.load(load_dir + '_label.pt')

            # Sample without replacement, so no example is duplicated. Size the draw from
            # the stored tensor instead of assuming every domain holds 2000 examples.
            n_total = mnist_imgs.shape[0]
            n_keep = min(self.samples_per_domain, n_total)
            res = np.random.choice(n_total, n_keep, replace=False)
            mnist_imgs = mnist_imgs[res]
            mnist_labels = mnist_labels[res]

            data_img.append(mnist_imgs)
            data_labels.append(mnist_labels)
            list_size.append(mnist_imgs.shape[0])

        # Merge the data of all domains.
        data_imgs = torch.cat(data_img)
        data_labels = torch.cat(data_labels)
        self.training_list_size = list_size

        # Domain label of each example = position of its domain in list_domains.
        data_domains = torch.zeros(data_labels.size())
        domain_start = 0
        for idx, curr_domain_size in enumerate(self.training_list_size):
            data_domains[domain_start: domain_start + curr_domain_size] += idx
            domain_start += curr_domain_size

        # Shuffle everything one more time (same permutation for all three tensors).
        inds = np.arange(data_labels.size()[0])
        np.random.shuffle(inds)
        data_imgs = data_imgs[inds]
        data_labels = data_labels[inds]
        data_domains = data_domains[inds].long()

        # One-hot encode labels (10 classes) and domains.
        data_labels = torch.eye(10)[data_labels]
        data_domains = torch.eye(len(self.list_domains))[data_domains]

        # If shape (B,H,W) change it to (B,C,H,W) with C=1.
        if len(data_imgs.shape) == 3:
            data_imgs = data_imgs.unsqueeze(1)

        return data_imgs, data_labels, data_domains


def get_dataloader_level_0(args, domains, data_case, kwargs):
    """Build a shuffled DataLoader over ``MnistRotated`` for the given domains.

    Args:
        args: parsed CLI namespace; ``args.batch_size`` sets the loader batch size.
        domains: domain ids to load (passed as ``list_domains``).
        data_case: split name used in the data path (e.g. 'train', 'val').
        kwargs: extra keyword arguments forwarded to ``DataLoader``.

    Returns:
        dict with keys 'data_loader', 'total_domains' (== len(domains)) and 'domain_list'.
    """
    # Honour --batch_size (its default, 64, was previously hard-coded here).
    batch_size = args.batch_size

    # mnist_subset=0 selects the 'seed_0' tensors.
    # NOTE: the data root is hard-coded; adjust it for your setup.
    data_obj = MnistRotated(args, domains, 0, '/matchdg/mnist/', data_case=data_case)

    dataset = {
        'data_loader': data_utils.DataLoader(data_obj, batch_size=batch_size, shuffle=True, **kwargs),
        'total_domains': len(domains),
        'domain_list': domains,
    }
    return dataset


class get_each_domain_data(data_utils.Dataset):
    """In-memory dataset over a single domain's images and labels.

    Args:
        data: image tensor. A 3-D ``(N, H, W)`` input gets a channel axis inserted,
            so items come out as ``(1, H, W)``.
        lables: labels indexed in parallel with ``data``.
    """

    def __init__(self, data, lables):
        # Insert the singleton channel dimension for (N, H, W) inputs.
        self.data = data.unsqueeze(1) if len(data.shape) == 3 else data
        self.lables = lables
        self.transform = None  # optional per-item transform hook

    def __len__(self):
        return self.get_size()

    def __getitem__(self, index):
        sample, target = self.data[index], self.lables[index]
        if self.transform is None:
            return sample, target
        return self.transform(sample), target

    def get_size(self):
        return self.lables.shape[0]

def build_similary_matrix(cov_function, items):
    """Build the symmetric pairwise matrix ``L[i, j] = cov_function(items[i], items[j])``.

    Only the upper triangle (diagonal included) is evaluated and then mirrored, so
    ``cov_function`` is assumed to be symmetric.

    Args:
        cov_function: callable taking two items and returning a scalar.
        items: sequence of items (e.g. per-domain feature vectors).

    Returns:
        float ``np.ndarray`` of shape ``(len(items), len(items))``.
    """
    n = len(items)
    L = np.zeros((n, n))
    for i in range(n):
        for j in range(i, n):
            L[i, j] = L[j, i] = cov_function(items[i], items[j])
    return L

def get_similar_cov(inx):
    """Return the pairwise comparison function selected by ``inx``.

    1 -> ``get_l1_similar``, 2 -> ``get_l2_similar``, 3 -> ``get_cos_similar``;
    any other value -> the kernel returned by ``exp_quadratic()`` (default sigma).
    """
    for key, func in ((1, get_l1_similar), (2, get_l2_similar), (3, get_cos_similar)):
        if inx == key:
            return func
    return exp_quadratic()

def get_l2_similar(v1,v2):
    """Euclidean (L2 / Frobenius) norm of ``v1 - v2``."""
    difference = v1 - v2
    return np.linalg.norm(difference, ord=None)

def get_l1_similar(v1,v2):
    """``np.linalg.norm(v1 - v2, ord=1)``: the L1 distance for 1-D inputs."""
    difference = v1 - v2
    return np.linalg.norm(difference, ord=1)

def get_cos_similar(v1, v2):
    """Cosine similarity of ``v1`` and ``v2`` rescaled from [-1, 1] to [0, 1].

    Returns 0 when the product of the two norms is zero.
    """
    dot_product = float(np.dot(v1, v2))
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    if norm_product == 0:
        return 0
    return 0.5 + 0.5 * (dot_product / norm_product)

def exp_quadratic(sigma=0.1):
    """Return the squared-exponential kernel ``k(p1, p2) = exp(-||p1 - p2||^2 / (2 sigma^2))``.

    The squared norm is the sum of squared element-wise differences.
    """
    def kernel(p1, p2):
        squared_distance = ((p1 - p2) ** 2).sum()
        return np.exp(-0.5 * squared_distance / sigma ** 2)

    return kernel
    
def _dpp_swap_ratio(L, Y, u, v):
    """Return ``det(L[Y+v]) / det(L[Y+u])`` via Schur complements (Kang, NIPS 2013).

    ``Y`` is a boolean mask of the items kept after removing ``u``; ``u`` and ``v`` are
    integer indices not in ``Y``. ``L`` is assumed symmetric.
    """
    L_Y_inv = np.linalg.inv(L[np.ix_(Y, Y)])
    b_v = L[Y, v]
    b_u = L[Y, u]
    numerator = L[v, v] - b_v @ L_Y_inv @ b_v
    denominator = L[u, u] - b_u @ L_Y_inv @ b_u
    # The WHOLE Schur complement of v is divided by that of u; the earlier code divided
    # only the quadratic term, giving a wrong acceptance probability.
    return float(numerator / denominator)


def sample_k(items, L, k, max_nb_iterations=1000, rng=np.random):
    """Approximately sample ``k`` of ``items`` from the k-DPP with kernel ``L``.

    Uses the swap-based MCMC sampler of Kang (NIPS 2013). Starting from a random k-subset,
    each iteration proposes replacing a member ``u`` with a non-member ``v``. The swap is
    accepted with probability ``min(1, det(L[Y+v]) / det(L[Y+u]))``.

    Args:
        items: sequence of items; only its length is used.
        L: symmetric ``(len(items), len(items))`` similarity matrix.
        k: subset size.
        max_nb_iterations: number of MCMC swap proposals.
        rng: object exposing ``choice`` and ``uniform`` (the ``np.random`` module or a
            ``RandomState``).

    Returns:
        Boolean ``np.ndarray`` mask of length ``len(items)`` with exactly ``k`` True entries.
    """
    n = len(items)
    initial = rng.choice(range(n), size=k, replace=False)
    X = np.zeros(n, dtype=bool)
    for idx in initial:
        X[idx] = True
    # No swap is possible when the subset is empty or contains every item.
    if k == 0 or k == n:
        return X

    indices = np.arange(n)
    for _ in range(max_nb_iterations):
        u = rng.choice(indices[X])
        v = rng.choice(indices[~X])
        Y = X.copy()
        Y[u] = False
        p = min(1.0, _dpp_swap_ratio(L, Y, u, v))
        if rng.uniform() <= p:
            X = Y
            X[v] = True
    return X


class DANNIN():
    """Adversarial trainer with the roles of class and domain labels swapped ("inverted DANN").

    As implemented below, the ``classifier`` head is trained to predict each sample's
    *domain*. The adversarial ``discriminator`` predicts the *class* label from the
    features, optionally after adding a learned embedding of the domain when
    ``args.conditional`` is set. The generator (featurizer + classifier) minimises
    ``classifier_loss - lambda_ * disc_loss``, i.e. it rewards domain prediction and
    penalises class predictability.

    Args:
        args: parsed CLI namespace.
        train_dataset: dict with 'data_loader' (yielding ``(x, y_onehot, d_onehot)``),
            'domain_list' and 'total_domains'.
        val_dataset: dict whose 'data_loader' yields 5-tuples ``(x, y, d, idx, obj)``;
            only ``x`` and ``d`` (argmax'ed over dim 1) are used.
        base_res_dir: directory where ``Model.pth`` is saved after training.
        cuda: torch device used for training and validation.
    """

    def __init__(self, args, train_dataset, val_dataset,base_res_dir, cuda):
        self.args= args
        self.train_dataset=train_dataset['data_loader']
        self.val_dataset=val_dataset['data_loader']

        # NOTE(review): an earlier note says this is a dict for rot_mnist_spur; here it is
        # whatever the loader builder stored under 'domain_list'.
        self.train_domains= train_dataset['domain_list']
        self.total_domains= train_dataset['total_domains']

        self.base_res_dir= base_res_dir
        self.run= 0
        self.cuda= cuda

        self.phi= self.get_model()

        self.train_acc=[]

        self.conditional = bool(self.args.conditional)
        # Class balancing of the discriminator loss is disabled.
        self.class_balance = False

        # Sub-modules of phi; they are split between the two optimisers below.
        self.featurizer = self.phi.featurizer
        self.classifier = self.phi.classifier
        self.discriminator = self.phi.discriminator
        self.class_embeddings = self.phi.class_embeddings

        self.grad_penalty= self.args.grad_penalty  # stored; not used by train()
        self.lambda_= self.args.penalty_ws  # weight of the adversarial term
        self.d_steps_per_g_step= self.args.d_steps_per_g_step
        # NOTE(review): lr and weight decay are hard-coded; --lr / --weight_decay are not read here.
        self.initial_lr= 0.01

        # Discriminator side: discriminator + domain embeddings.
        self.disc_opt = torch.optim.SGD(
            (list(self.discriminator.parameters()) +
                list(self.class_embeddings.parameters())),
            lr=self.initial_lr,
            weight_decay=5e-4)

        # Generator side: featurizer + domain classifier.
        self.gen_opt = torch.optim.SGD(
            (list(self.featurizer.parameters()) +
                list(self.classifier.parameters())),
            lr=self.initial_lr,
            weight_decay=5e-4)

    def get_model(self):
        """Build the network via ``models.net_for_dannin`` and move it to ``self.cuda``."""
        from models import net_for_dannin
        phi=net_for_dannin.get_model_for_DANNIN(self.args.model_name, self.args.out_classes,
                        self.args.img_c, self.args.pre_trained, self.args.os_env)

        phi=phi.to(self.cuda)
        return phi

    def save_model(self):
        """Store the model weights at ``<base_res_dir>/Model.pth``."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')

    def get_test_accuracy(self, case):
        """Return (and print) the domain-classification accuracy, in percent, on the validation loader.

        The model is put in eval mode for the pass, so BatchNorm/Dropout layers (if any) use
        inference behaviour and BatchNorm running statistics are not updated with
        validation data. Its previous train/eval mode is restored afterwards.

        Args:
            case: label used only in the printed message.

        Returns:
            Accuracy in percent; 0.0 if the loader yields no samples.
        """
        test_acc= 0.0
        test_size=0

        was_training = self.phi.training
        self.phi.eval()
        try:
            with torch.no_grad():
                for x_e, _, d_e, _, _ in self.val_dataset:
                    x_e= x_e.to(self.cuda)
                    d_e= torch.argmax(d_e, dim=1).to(self.cuda)

                    # Forward pass: featurizer -> domain classifier.
                    z_e=self.phi.featurizer(x_e)
                    out= self.phi.classifier(z_e)

                    test_acc+= torch.sum( torch.argmax(out, dim=1) == d_e ).item()
                    test_size+= d_e.shape[0]
        finally:
            self.phi.train(was_training)

        accuracy = 100*test_acc/test_size if test_size else 0.0
        print(' Accuracy: ', case, accuracy )
        return accuracy

    def train(self):
        """Train for ``args.epochs`` epochs, validating after each epoch, then save the model.

        Prints per epoch: the summed classifier and discriminator losses, the training
        domain accuracy, and the validation accuracy.
        """
        self.max_epoch=-1
        self.max_val_acc=0.0
        d_steps_per_g = self.d_steps_per_g_step
        for epoch in range(self.args.epochs):
            # Back to training mode (validation switches to eval mode).
            self.phi.train()

            penalty_erm=0
            penalty_dann=0
            train_acc= 0.0
            train_size=0

            # NOTE(review): the discriminator/generator alternation is keyed on the EPOCH
            # index, so a whole epoch updates only one side (d_steps_per_g discriminator
            # epochs, then one generator epoch). DomainBed's DANN alternates per step.
            disc_epoch = epoch % (1+d_steps_per_g) < d_steps_per_g

            for x_e, y_e, d_e in self.train_dataset:
                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)
                d_e= torch.argmax(d_e, dim=1).to(self.cuda)

                all_x = x_e
                all_d = d_e
                all_z = self.featurizer(all_x)
                # Conditional variant: shift features by a learned embedding of the domain.
                if self.conditional:
                    disc_input = all_z + self.class_embeddings(all_d)
                else:
                    disc_input = all_z
                disc_out = self.discriminator(disc_input)
                # The discriminator predicts the CLASS label (roles swapped w.r.t. DANN).
                disc_labels = y_e
                if self.class_balance:
                    d_counts = F.one_hot(all_d).sum(dim=0)
                    weights = 1. / (d_counts[all_d] * d_counts.shape[0]).float()
                    disc_loss = F.cross_entropy(disc_out, disc_labels, reduction='none')
                    disc_loss = (weights * disc_loss).sum()
                else:
                    disc_loss = F.cross_entropy(disc_out, disc_labels)

                # Generator loss: predict the domain while maximising the discriminator loss.
                all_preds = self.classifier(all_z)
                classifier_loss = F.cross_entropy(all_preds, all_d)
                gen_loss = (classifier_loss +
                            (self.lambda_ * -disc_loss))

                penalty_erm += float(classifier_loss)
                penalty_dann += float(disc_loss)

                if disc_epoch:
                    self.disc_opt.zero_grad()
                    disc_loss.backward()
                    self.disc_opt.step()
                else:
                    self.disc_opt.zero_grad()
                    self.gen_opt.zero_grad()
                    gen_loss.backward()
                    self.gen_opt.step()

                del classifier_loss
                del gen_loss
                del disc_loss
                torch.cuda.empty_cache()

                # Training domain accuracy after this step's update. No gradients are needed,
                # so do not build an autograd graph for this extra forward pass.
                with torch.no_grad():
                    features = self.featurizer(x_e)
                    out = self.classifier(features)
                    train_acc+= torch.sum(torch.argmax(out, dim=1) == d_e ).item()
                train_size+= d_e.shape[0]

            print('Train Loss Basic : ',  penalty_erm, penalty_dann )
            print('Train Acc Env : ', 100*train_acc/train_size if train_size else 0.0 )
            print('Done Training for epoch: ', epoch)
            self.get_test_accuracy('val')

        self.save_model()
    
       
#GPU
cuda= torch.device("cuda:" + str(args.cuda_device))
# NOTE(review): a torch.device is always truthy, so the first branch is always taken;
# both dicts match DataLoader's defaults (num_workers=0, pin_memory=False) anyway.
if cuda:
    kwargs = {'num_workers': 0, 'pin_memory': False}
else:
    kwargs= {}

# One entry per repetition: the domain ids selected by the DPP in that repetition.
domains=[]
for repetition in range(20):
    # Train on every domain in [15, 75]. A random subset was used before, e.g.
    # random.sample((list(range(15,76))),55)
    dom_train=range(15,76)

    # List of Train; Test domains
    train_domains= dom_train
    test_domains= args.test_domains

    # Output directory for the saved model weights.
    res_dir= 'DIVERSE_DOMAINS/'
    base_res_dir=(res_dir + args.dataset_name )
    if not os.path.exists(base_res_dir):
        os.makedirs(base_res_dir)

    print('get started')

    # DataLoader
    train_dataset= get_dataloader_level_0( args, train_domains, 'train',  kwargs )
    val_dataset= get_dataloader( args, 0, train_domains, 'val', 0, kwargs )

    train_method= DANNIN(args, train_dataset, val_dataset, base_res_dir, cuda)

    # Train; DANNIN.train() reports validation accuracy every epoch and saves the weights.
    train_method.train()

    # Domains whose "location" is computed; the DPP selects positions in this list.
    total_domain=list(range(15,76))

    batch_size=2000
    # location_dom_num[j] = mean of [featurizer output, classifier logits] over the
    # training samples of total_domain[j]; used to build the DPP similarity matrix L.
    location_dom_num=[]

    # Extract features on the CPU in eval mode, so any BatchNorm layers use their running
    # statistics and are not updated by this data.
    phi=train_method.phi.to("cpu")
    phi.eval()

    data_dir= 'data/datasets'+'/mnist/' + args.dataset_name + '_' + args.mnist_case + '/'
    dom=list(range(15,76,5))
    color_list=['red', 'blue', 'green', 'brown', 'pink','yellow']

    for domain in total_domain:
        if args.dataset_name=='rot_mnist_spur':
            # The domain id encodes (rotation index, colour index).
            degree=dom[int(domain/6)]
            color=color_list[domain%6]
            load_dir= data_dir + 'train' + '/' + 'seed_' + str(0) + '_domain_' + str(degree)+'_color_'+color
        else:
            load_dir= data_dir + 'train' + '/' + 'seed_' + str(0) + '_domain_' + str(domain)

        mnist_imgs= torch.load( load_dir +  '_org_data.pt')
        mnist_labels= torch.load( load_dir +  '_label.pt')
        mnist_labels = torch.eye(10)[mnist_labels]
        data_domain_obj=get_each_domain_data( mnist_imgs, mnist_labels)
        data_domain=data_utils.DataLoader(data_domain_obj, batch_size=batch_size,  **kwargs )

        # Feed every sample of this domain through the model.
        batch_outputs=[]
        with torch.no_grad():
            for x_e, _ in data_domain:
                out_feat=phi.featurizer(x_e)
                out_dom=phi.classifier(out_feat)
                batch_outputs.append(np.concatenate((out_feat.cpu().numpy(), out_dom.cpu().numpy()), 1))

        # Exactly one location per domain, averaged over ALL batches. Appending one mean per
        # batch would give a domain several rows whenever it has more than batch_size samples
        # and break the position -> domain mapping below.
        location_dom_num.append(np.mean(np.concatenate(batch_outputs, 0), 0))

    L=build_similary_matrix(get_similar_cov(4),location_dom_num)
    # Draw 5 mutually dissimilar domains from the k-DPP defined by L.
    X=sample_k(location_dom_num,L,5)

    # Map the selected positions back to domain ids.
    domain_name=[total_domain[idx] for idx in np.flatnonzero(X)]
    domains.append(domain_name)
    print(domains)
    np.save('domain.npy',domains)
            
    
