#Common imports
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
import pickle

#Sklearn
import sklearn
from sklearn.manifold import TSNE

#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

#robustdg
from utils.helper import *
from utils.match_function import *

import random
import torchvision
import datetime
# Usage: python level-2-mmd-coral.py [--option value ...]
# Default value for --test_domains (comment in the original: rot_mnist_spur setting).
# NOTE(review): for rot_mnist these look like rotation angles; for rot_mnist_spur
# MnistRotated decodes them as (angle, colour) ids -- confirm which is intended.
dom_test=[0,90] # rot_mnist_spur
# Other values tried previously: [-1], ["0"]



# (Entry point: python level-2-mmd-coral.py)



# Input parsing.
# NOTE: parser.parse_args() below runs at import time, so importing this module
# consumes sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='rot_mnist', # alternative: 'fashion_mnist'
                    help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='erm_match',
                    help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='resnet18', 
                    help='Architecture of the model to be trained')
# NOTE(review): --train_domains / --test_domains are declared with type=str but
# their defaults are ints, so values given on the command line arrive as strings
# while the defaults stay ints.
parser.add_argument('--train_domains', nargs='+', type=str, default=[0], # TODO: adjust per experiment
                    help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=dom_test, # TODO: adjust per experiment
                    help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10, 
                    help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default= 1, # used to be 1
                    help='Number of channels of the image in dataset')
parser.add_argument('--img_h', type=int, default= 224, 
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224, 
                    help='Width of the image in dataset')
parser.add_argument('--fc_layer', type=int, default= 1, 
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')# fc layer for classification
parser.add_argument('--match_layer', type=str, default='logit_match', 
                    help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='l2', 
                    help='Cost to function to evaluate distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250, 
                    help='Representation dimension for contrsative learning')
parser.add_argument('--pre_trained',type=int, default=0, 
                    help='0: No Pretrained Architecture; 1: Pretrained Architecture')
parser.add_argument('--perfect_match', type=int, default=1, 
                    help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
parser.add_argument('--opt', type=str, default='sgd', 
                    help='Optimizer Choice: sgd; adam') 
parser.add_argument('--weight_decay', type=float, default=5e-4,
                   help='Weight Decay in SGD')
parser.add_argument('--lr', type=float, default=0.01, 
                    help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=16, 
                    help='Batch size foe training the model')
parser.add_argument('--epochs', type=int, default=30, # 30
                    help='Total number of epochs for training the model')
parser.add_argument('--penalty_s', type=int, default=-1, 
                    help='Epoch threshold over which Matching Loss to be optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0, 
                    help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_aug', type=float, default=1.0, 
                    help='Penalty weight for Augmentation in Hybrid approach loss')
parser.add_argument('--penalty_ws', type=float, default=0.1, 
                    help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0, 
                    help='Penalty weight for Contrastive Loss')
parser.add_argument('--tau', type=float, default=0.05, 
                    help='Temperature hyper param for NTXent contrastive loss ')
parser.add_argument('--match_flag', type=int, default=0, 
                    help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=1.0, 
                    help='0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--match_interrupt', type=int, default=5, 
                    help='Number of epochs before inferring the match strategy')
parser.add_argument('--ctr_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--match_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--n_runs', type=int, default=1, # previously defaulted to 3
                    help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1, 
                    help='Number of iterations to repeat training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18', 
                    help='(For matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match', 
                    help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1, 
                    help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01, 
                    help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--ctr_match_interrupt', type=int, default=5, 
                    help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0, 
                    help='Change it between 0-6 for different subsets of Mnist and Fashion Mnist dataset')
parser.add_argument('--retain', type=float, default=0, 
                    help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from MatchDG Phase 1 in MatchDG is Phase 2')
parser.add_argument('--cuda_device', type=int, default=4, 
                    help='Select the cuda device by id among the avaliable devices' )
parser.add_argument('--os_env', type=int, default=0, 
                    help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )


# Differential Privacy
parser.add_argument('--dp_noise', type=int, default=0, 
                    help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0, 
                    help='Epsilon value for Differential Privacy')
# Special case when you want to check results with the dp setting for the infinite epsilon case
parser.add_argument('--dp_attach_opt', type=int, default=1, 
                    help='0: Infinite Epsilon; 1: Finite Epsilion')


# MMD, DANN
# (MMD below reads --gaussian to choose Gaussian-kernel MMD vs. CORAL-style
# mean/covariance matching, and --conditional for per-class matching.)
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=0)
parser.add_argument('--gaussian', type=int, default=1)

# Fish
parser.add_argument('--meta_lr', type=float, default=0.01)
parser.add_argument('--meta_steps', type=int, default=5)


# Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default= 2, 
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)


# Differentiate between resnet, lenet, domainbed cases of mnist
# (also used as part of the data directory name by MnistRotated)
parser.add_argument('--mnist_case', type=str, default='resnet18', 
                    help='MNIST Dataset Case: resnet18; lenet, domainbed')
parser.add_argument('--mnist_aug', type=int, default=0, 
                    help='MNIST Data Augmentation: 0 (MNIST, FMNIST Privacy Evaluation); 1 (FMNIST)')


# Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1, 
                    help='Multiple random matches')


# Evaluation specific
parser.add_argument('--test_metric', type=str, default='match_score', 
                    help='Evaluation Metrics: acc; match_score, t_sne, mia')
parser.add_argument('--acc_data_case', type=str, default='test', 
                    help='Dataset Train/Val/Test for the accuracy evaluation metric')
parser.add_argument('--top_k', type=int, default=10, 
                    help='Top K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=0, 
                    help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='val', 
                    help='Dataset Train/Val/Test for the match score evaluation metric')

# Parsed at import time (see note above).
args = parser.parse_args()



class BaseDataLoader(data_utils.Dataset):
    """In-memory dataset base class shared by the loaders in this script.

    Subclasses fill ``data``, ``labels``, ``domains``, ``indices`` and
    ``objects`` with index-aligned containers. ``labels`` must expose
    ``.shape`` (e.g. a tensor) because the length is read from
    ``labels.shape[0]``.

    Args:
        args: parsed command-line namespace.
        list_domains: domains to load.
        root: suffix appended to ``'data/datasets'`` by plain string
            concatenation (so it should start with '/').
        transform: optional callable applied to each sample ``x`` in
            ``__getitem__``.
        data_case: split name used when building file paths
            (e.g. 'train', 'val', 'test').
    """

    def __init__(self, args, list_domains, root, transform=None, data_case='train'):
        self.args = args
        self.list_domains = list_domains

        self.root = 'data/datasets' + root
        self.transform = transform
        self.data_case = data_case

        self.base_domain_size = 0
        self.list_size = []
        self.data = []
        self.labels = []
        self.domains = []
        self.indices = []
        self.objects = []

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, index):
        """Return ``(x, y, d, idx, obj)``; ``transform`` (if any) is applied to ``x`` only."""
        x = self.data[index]
        y = self.labels[index]
        d = self.domains[index]
        idx = self.indices[index]
        objs = self.objects[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y, d, idx, objs

    def get_size(self):
        """Return the number of samples (same value as ``len(self)``)."""
        return self.labels.shape[0]


class MnistRotated(BaseDataLoader):
    """Rotated-MNIST domains loaded from pre-saved ``.pt`` tensors.

    For every domain in ``list_domains`` the files
    ``<root><dataset_name>_<mnist_case>/<data_case>/seed_<mnist_subset>_domain_<suffix>``
    followed by ``_org_data.pt`` (images) and ``_label.pt`` (integer labels)
    are loaded with ``torch.load`` and concatenated in domain order.

    Args:
        args: namespace providing ``dataset_name``, ``mnist_case`` and
            (optionally) ``out_classes`` (one-hot depth, default 10).
        list_domains: domain ids; for ``rot_mnist_spur`` they are decoded into
            an (angle, colour) pair, otherwise used verbatim in the file name.
        mnist_subset: seed/subset number used in the file name.
        root, transform, data_case: see ``BaseDataLoader``.
        download: stored only; nothing in this class downloads data.
    """

    def __init__(self, args, list_domains, mnist_subset, root, transform=None, data_case='train',  download=True):
        super().__init__(args, list_domains, root, transform, data_case)
        self.mnist_subset = mnist_subset
        self.download = download

        self.data, self.labels, self.domains, self.indices, self.objects = self._get_data()

    def _domain_suffix(self, domain):
        """Return the part of the file name that identifies ``domain``."""
        if self.args.dataset_name != 'rot_mnist_spur':
            return str(domain)

        # rot_mnist_spur packs (rotation angle, colour) into one integer id:
        # id < 0 -> 0 degrees/white, id > 77 -> 90 degrees/white, otherwise
        # angle index id // 6 into 15..75 (step 5) and colour index id % 6.
        # int() also accepts ids given as strings (--*_domains use type=str).
        domain = int(domain)
        if domain < 0:
            degree, color = 0, 'white'
        elif domain > 77:
            degree, color = 90, 'white'
        else:
            degrees = list(range(15, 76, 5))
            colors = ['red', 'blue', 'green', 'brown', 'pink', 'yellow']
            degree = degrees[domain // 6]
            color = colors[domain % 6]
        return str(degree) + '_color_' + color

    def _get_data(self):
        """Load, concatenate and one-hot encode every domain in ``list_domains``.

        Returns:
            ``(data_imgs, data_labels, data_domains, data_indices, data_objects)``:
            images (a channel axis is added when the loaded tensor is 3-D),
            one-hot labels, one-hot domain labels (position of the domain in
            ``list_domains``), per-domain sample indices as a numpy array, and
            a copy of those indices used as object ids.

        Side effect: sets ``self.training_list_size`` to the per-domain sizes.
        """
        list_img = []
        list_labels = []
        list_idx = []
        list_size = []
        data_dir = self.root + self.args.dataset_name + '_' + self.args.mnist_case + '/'

        for domain in self.list_domains:
            load_dir = (data_dir + self.data_case + '/' + 'seed_' + str(self.mnist_subset)
                        + '_domain_' + self._domain_suffix(domain))

            mnist_imgs = torch.load(load_dir + '_org_data.pt')
            mnist_labels = torch.load(load_dir + '_label.pt')

            list_img.append(mnist_imgs)
            list_labels.append(mnist_labels)
            list_idx.append(list(range(len(mnist_imgs))))
            list_size.append(mnist_imgs.shape[0])

        # Stack the domains one after another (no shuffling here; the
        # DataLoader is responsible for shuffling).
        data_imgs = torch.cat(list_img)
        data_labels = torch.cat(list_labels)
        # hstack handles domains of different sizes; np.array(list_idx) would
        # build a ragged array, which recent NumPy versions reject.
        data_indices = np.hstack(list_idx)
        self.training_list_size = list_size

        # For rotated MNIST the object id of a sample is its in-domain index.
        data_objects = data_indices.copy()

        # Domain label = position of the domain in list_domains.
        data_domains = torch.zeros(data_labels.size())
        domain_start = 0
        for idx, curr_domain_size in enumerate(self.training_list_size):
            data_domains[domain_start: domain_start + curr_domain_size] += idx
            domain_start += curr_domain_size

        # One-hot labels. .long() makes integer labels of any dtype act as row
        # indices (a uint8 tensor would otherwise be treated as a mask).
        num_classes = getattr(self.args, 'out_classes', 10)
        data_labels = torch.eye(num_classes)[data_labels.long()]

        # One-hot domain labels.
        data_domains = torch.eye(len(self.list_domains))[data_domains.long()]

        # If shape is (B, H, W), change it to (B, C, H, W) with C=1.
        if len(data_imgs.shape) == 3:
            data_imgs = data_imgs.unsqueeze(1)

        return data_imgs, data_labels, data_domains, data_indices, data_objects


def get_dataloader_level_1(args, domains, data_case, kwargs, batch_size=50):
    """Build a shuffled DataLoader over rotated-MNIST ``domains`` for one split.

    Args:
        args: parsed command-line namespace, forwarded to ``MnistRotated``.
        domains: list of domain ids to load.
        data_case: split name ('train', 'val' or 'test').
        kwargs: extra keyword arguments forwarded to ``DataLoader``
            (``None`` is treated as no extra arguments).
        batch_size: mini-batch size. Defaults to 50, the value previously
            hard-coded here; note that ``args.batch_size`` is not consulted.

    Returns:
        dict with keys 'data_loader', 'data_obj', 'total_domains',
        'domain_list', 'base_domain_size' and 'domain_size_list'.
    """
    # MNIST subset 0 and the '/mnist/' root are fixed for this script.
    # NOTE: the original author marked this call as needing changes for rot_mnist_spur.
    data_obj = MnistRotated(args, domains, 0, '/mnist/', data_case=data_case)

    loader = data_utils.DataLoader(data_obj, batch_size=batch_size, shuffle=True, **(kwargs or {}))

    return {
        'data_loader': loader,
        'data_obj': data_obj,
        'total_domains': len(domains),
        'domain_list': domains,
        'base_domain_size': data_obj.base_domain_size,
        'domain_size_list': data_obj.training_list_size,
    }
class Identity(nn.Module):
    """Pass-through module that still records the width of its input.

    Used as a ResNet ``fc`` replacement so the network returns its features
    unchanged while ``in_features`` remains readable.
    """

    def __init__(self, n_inputs):
        super().__init__()
        self.in_features = n_inputs

    def forward(self, x):
        """Return ``x`` itself, untouched."""
        return x
class Model_erm(nn.Module):
    """Composite network: a featurizer whose output feeds a classifier head.

    Args:
        feat: module mapping inputs to features (stored as ``featurizer``).
        fc: module mapping features to outputs (stored as ``classifier``).
    """

    def __init__(self, feat, fc):
        super().__init__()
        self.featurizer = feat
        self.classifier = fc

    def forward(self, input_data):
        """Return ``classifier(featurizer(input_data))``."""
        return self.classifier(self.featurizer(input_data))
class BaseAlgo():
    """Common setup for the training algorithms in this script.

    Unpacks the training-data dict, then builds the model (``get_model``),
    its optimizer (``get_opt``) and a StepLR scheduler with step_size=25.

    Args:
        args: parsed command-line namespace (reads out_classes and img_c).
        train_dataset: dict with keys 'data_loader', 'domain_list',
            'total_domains', 'base_domain_size' and 'domain_size_list'.
        base_res_dir: directory where ``save_model`` writes 'Model.pth'.
        cuda: device passed to ``.to()`` for the model.
    """

    def __init__(self, args, train_dataset, base_res_dir, cuda):
        self.args = args
        self.train_dataset = train_dataset['data_loader']

        # Bookkeeping copied from the dataset dict.
        self.train_domains = train_dataset['domain_list']
        self.total_domains = train_dataset['total_domains']
        self.domain_size = train_dataset['base_domain_size']
        self.training_list_size = train_dataset['domain_size_list']

        self.base_res_dir, self.run, self.cuda = base_res_dir, 0, cuda

        self.phi = self.get_model()
        self.opt = self.get_opt()
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=25)

        self.train_acc = []

    def get_resnet(self, classes, fc_layer, num_ch, pre_trained):
        """Return a torchvision ResNet-18 with a customised head and stem.

        A truthy ``fc_layer`` installs a fresh ``Linear(in_features, classes)``
        head; otherwise ``fc`` becomes an ``Identity`` so the network outputs
        its pooled features. ``num_ch == 1`` swaps ``conv1`` for a
        single-channel 7x7 convolution.
        """
        net = torchvision.models.resnet18(pre_trained)
        width = net.fc.in_features

        if not fc_layer:
            print('Here')
            net.fc = Identity(width)
        else:
            net.fc = nn.Linear(width, classes)

        if num_ch == 1:
            net.conv1 = nn.Conv2d(
                1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
            )
        return net

    def get_model(self):
        """Return ``nn.Sequential(backbone, head)`` moved to ``self.cuda``.

        The backbone is a headless ResNet-18; the head is the Linear ``fc``
        taken from a second, separately built ResNet-18.
        """
        n_cls = self.args.out_classes
        n_ch = self.args.img_c
        backbone = self.get_resnet(n_cls, 0, n_ch, 0)
        # Only the Linear head of this second network is kept.
        head = self.get_resnet(n_cls, 1, n_ch, 0).fc
        return nn.Sequential(backbone, head).to(self.cuda)

    def save_model(self):
        """Write ``phi``'s state dict to ``<base_res_dir>/Model.pth``."""
        target = self.base_res_dir + '/Model.pth'
        torch.save(self.phi.state_dict(), target)

    def get_opt(self):
        """Return Adam over all of ``phi``'s parameters (lr 1e-3, no weight decay).

        NOTE: ``args.opt``, ``args.lr`` and ``args.weight_decay`` are not used.
        """
        return optim.Adam(self.phi.parameters(), lr=0.001, weight_decay=0.0)

class Erm(BaseAlgo):
    """Plain empirical risk minimisation on the pooled training loader.

    Construction is inherited unchanged from ``BaseAlgo``.
    """

    def train(self):
        """Train ``phi`` with cross-entropy for ``args.epochs`` epochs.

        Prints the summed loss and training accuracy every epoch, appends the
        accuracy to ``self.train_acc`` and saves the final weights.
        """
        self.max_epoch = -1
        self.max_val_acc = 0.0
        for epoch in range(self.args.epochs):
            epoch_loss = 0
            n_correct = 0.0
            n_seen = 0

            for inputs, onehot, _dom, _idx, _obj in self.train_dataset:
                self.opt.zero_grad()
                total = torch.tensor(0.0).to(self.cuda)

                inputs = inputs.to(self.cuda)
                targets = torch.argmax(onehot, dim=1).to(self.cuda)

                # Forward pass and loss.
                logits = self.phi(inputs)
                ce = F.cross_entropy(logits, targets.long()).to(self.cuda)
                total += ce
                epoch_loss += float(total)

                # Backward pass and parameter update.
                total.backward(retain_graph=False)
                self.opt.step()

                del ce
                del total
                torch.cuda.empty_cache()

                n_correct += torch.sum(torch.argmax(logits, dim=1) == targets).item()
                n_seen += targets.shape[0]

            print('Train Loss Basic : ', epoch_loss)
            epoch_acc = 100*n_correct/n_seen
            print('Train Acc Env : ', epoch_acc)
            print('Done Training for epoch: ', epoch)

            # Training-set accuracy history.
            self.train_acc.append(epoch_acc)

        # Save the model's weights after training.
        self.save_model()


class MMD():
    """ERM plus a pairwise domain-alignment penalty on the featurizer output.

    The penalty is chosen by ``args.gaussian``:
      * truthy -> kernel MMD using a sum of Gaussian kernels over several
        bandwidths (``kernel_type == "gaussian"``);
      * falsy  -> CORAL-style penalty: mean squared difference of the feature
        means plus mean squared difference of the feature covariances
        (``kernel_type == "mean_cov"``).
    With ``args.conditional`` truthy the penalty is computed separately for each
    class present in the batch and summed.

    Args:
        args: parsed command-line namespace (reads epochs, gaussian,
            conditional, out_classes, img_c).
        train_dataset: dict with keys 'data_loader', 'domain_list',
            'total_domains', 'base_domain_size', 'domain_size_list'.
        indecies: per-batch flags indexed by batch position; a training batch
            whose flag equals 0 is skipped.
        val_dataset, test_dataset: dicts with a 'data_loader' key.
        base_res_dir: directory where the best model is saved as 'Model.pth'.
        cuda: device passed to ``.to()`` for the model and batches.
    """

    def __init__(self, args, train_dataset, indecies, val_dataset, test_dataset, base_res_dir,  cuda):
        self.args = args
        self.train_dataset = train_dataset['data_loader']
        self.val_dataset = val_dataset['data_loader']
        self.test_dataset = test_dataset['data_loader']

        self.train_domains = train_dataset['domain_list']
        self.total_domains = train_dataset['total_domains']
        self.domain_size = train_dataset['base_domain_size']
        self.training_list_size = train_dataset['domain_size_list']

        self.base_res_dir = base_res_dir
        self.run = 0
        self.cuda = cuda

        self.indecies = indecies

        self.final_acc = []
        self.val_acc = []
        self.train_acc = []

        self.phi = self.get_model()
        self.opt = self.get_opt()
        # Weight of the matching penalty (fixed; args.penalty_ws is not used).
        self.mmd_gamma = 1.0
        self.gaussian = bool(self.args.gaussian)
        self.conditional = bool(self.args.conditional)
        if self.gaussian:
            self.kernel_type = "gaussian"
        else:
            self.kernel_type = "mean_cov"
        print(self.kernel_type)
        # Aliases to the two halves of phi (they share parameters with phi).
        self.featurizer = self.phi.featurizer
        self.classifier = self.phi.classifier

    def get_resnet(self, classes, fc_layer, num_ch, pre_trained):
        """Return a torchvision ResNet-18 with a custom head and optional 1-channel stem.

        Truthy ``fc_layer`` -> ``fc`` is a new ``Linear(in_features, classes)``;
        otherwise ``fc`` is an ``Identity`` so the network outputs features.
        """
        model = torchvision.models.resnet18(pre_trained)

        n_inputs = model.fc.in_features
        n_outputs = classes

        if fc_layer:
            model.fc = nn.Linear(n_inputs, n_outputs)
        else:
            print('Here')
            model.fc = Identity(n_inputs)

        if num_ch == 1:
            model.conv1 = nn.Conv2d(1, 64,
                                    kernel_size=(7, 7),
                                    stride=(2, 2),
                                    padding=(3, 3),
                                    bias=False)
        return model

    def get_model(self):
        """Return ``Model_erm(headless ResNet-18, Linear head)`` on ``self.cuda``."""
        feat = self.get_resnet(self.args.out_classes, 0, self.args.img_c, 0)
        # Only the Linear head of this second network is kept.
        fc = self.get_resnet(self.args.out_classes, 1, self.args.img_c, 0).fc
        phi = Model_erm(feat, fc)
        phi = phi.to(self.cuda)
        return phi

    def save_model(self):
        """Write ``phi``'s state dict to ``<base_res_dir>/Model.pth``."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')

    def get_opt(self):
        """Return Adam over all of ``phi``'s parameters (lr 1e-3, no weight decay).

        NOTE: ``args.opt``, ``args.lr`` and ``args.weight_decay`` are not used.
        """
        opt = optim.Adam(
            self.phi.parameters(),
            lr=0.001,
            weight_decay=0.0
        )
        return opt

    def my_cdist(self, x1, x2):
        """Pairwise squared Euclidean distances between rows of x1 and x2.

        Uses ||a||^2 + ||b||^2 - 2 a.b and clamps the result to >= 1e-30.
        """
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        res = torch.addmm(x2_norm.transpose(-2, -1),
                          x1,
                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
        return res.clamp_min_(1e-30)

    def gaussian_kernel(self, x, y, gamma=(0.001, 0.01, 0.1, 1, 10, 100,
                                           1000)):
        """Sum over ``gamma`` of exp(-g * squared_distance(x_i, y_j))."""
        D = self.my_cdist(x, y)
        K = torch.zeros_like(D)

        for g in gamma:
            K.add_(torch.exp(D.mul(-g)))

        return K

    def mmd(self, x, y):
        """Discrepancy between two feature sets (rows are samples).

        Gaussian kernel type: mean(Kxx) + mean(Kyy) - 2 * mean(Kxy).
        Otherwise (CORAL-style): mean squared mean difference plus mean squared
        covariance difference; with a single sample on either side only the
        mean term is returned (the covariance would divide by zero).
        An empty side contributes 0 instead of NaN.
        """
        if len(x) == 0 or len(y) == 0:
            return x.new_zeros(())

        if self.kernel_type == "gaussian":
            Kxx = self.gaussian_kernel(x, x).mean()
            Kyy = self.gaussian_kernel(y, y).mean()
            Kxy = self.gaussian_kernel(x, y).mean()
            return Kxx + Kyy - 2 * Kxy

        mean_x = x.mean(0, keepdim=True)
        mean_y = y.mean(0, keepdim=True)
        mean_diff = ((mean_x - mean_y) ** 2).mean()
        if len(x) == 1 or len(y) == 1:
            return mean_diff

        cent_x = x - mean_x
        cent_y = y - mean_y
        cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
        cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)
        cova_diff = ((cova_x - cova_y) ** 2).mean()
        return mean_diff + cova_diff

    def mmd_regularization(self, features, d, nmb):
        """Sum ``mmd`` over every unordered pair of domains.

        Args:
            features: samples x feature_dim tensor.
            d: per-sample domain ids (same length as ``features``).
            nmb: the domain ids to compare (e.g. ``torch.unique(d).tolist()``);
                an int ``n`` is accepted as shorthand for ``range(n)``.
        """
        if isinstance(nmb, int):
            nmb = list(range(nmb))
        penalty = torch.tensor(0.0).to(self.cuda)
        for i in range(len(nmb)):
            for j in range(i + 1, len(nmb)):
                f_i = features[d == nmb[i]]
                f_j = features[d == nmb[j]]
                penalty = penalty + self.mmd(f_i, f_j)
        return penalty

    def _batch_match_loss(self, features, y_e, d_e):
        """Return the matching penalty of one batch, averaged over domain pairs.

        Uses the actual domain ids present in ``d_e``. In conditional mode the
        penalty is summed over the actual class labels present in ``y_e``;
        classes missing from some domain are skipped with a printed warning.
        """
        match_domains = torch.unique(d_e).tolist()
        mmd_loss = torch.tensor(0.0).to(self.cuda)

        if self.conditional:
            for y_c in torch.unique(y_e).tolist():
                in_class = y_e == y_c
                features_c = features[in_class]
                d_c = d_e[in_class]
                if len(torch.unique(d_c)) != len(match_domains):
                    print('*********************************')
                    print('Error: Some classes not distributed across all the domains; issues for class conditional methods')
                    continue
                mmd_loss = mmd_loss + self.mmd_regularization(features_c, d_c, match_domains)
        else:
            mmd_loss = mmd_loss + self.mmd_regularization(features, d_e, match_domains)

        n_dom = len(match_domains)
        if n_dom > 1:
            mmd_loss = mmd_loss / (n_dom * (n_dom - 1) / 2)
        return mmd_loss

    def get_test_accuracy(self, case):
        """Return the accuracy (in %) of ``phi`` on the 'val' or 'test' loader.

        The model is put in eval mode so BatchNorm uses its running statistics
        and those statistics are not updated with val/test data; the previous
        train/eval mode is restored afterwards.

        Raises:
            ValueError: if ``case`` is neither 'val' nor 'test'.
        """
        if case == 'val':
            dataset = self.val_dataset
        elif case == 'test':
            dataset = self.test_dataset
        else:
            raise ValueError("case must be 'val' or 'test', got %r" % (case,))

        test_acc = 0.0
        test_size = 0
        was_training = self.phi.training
        self.phi.eval()
        try:
            with torch.no_grad():
                for x_e, y_e, d_e, idx_e, obj_e in dataset:
                    x_e = x_e.to(self.cuda)
                    y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                    out = self.phi(x_e)

                    test_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                    test_size += y_e.shape[0]
        finally:
            self.phi.train(was_training)

        print(' Accuracy: ', case, 100*test_acc/test_size)
        return 100*test_acc/test_size

    def train(self):
        """Train with cross-entropy + ``mmd_gamma`` * matching penalty.

        After each epoch records train/val/test accuracy and saves the model
        whenever validation accuracy improves.

        Raises:
            RuntimeError: if an epoch processes no training samples (e.g. all
                batches are disabled by ``indecies``).
        """
        self.max_epoch = -1
        self.max_val_acc = 0.0
        for epoch in range(self.args.epochs):
            self.phi.train()
            penalty_erm = 0
            penalty_mmd = 0
            train_acc = 0.0
            train_size = 0

            for batch_idx, (x_e, y_e, d_e, idx_e, obj_e) in enumerate(self.train_dataset):
                # Batches flagged with 0 are excluded from training.
                if self.indecies[batch_idx] == 0:
                    continue

                self.opt.zero_grad()
                loss_e = torch.tensor(0.0).to(self.cuda)

                x_e = x_e.to(self.cuda)
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)
                # Keep domain ids on the same device as labels/features so the
                # boolean masks in the matching penalty can be combined.
                d_e = torch.argmax(d_e, dim=1).to(self.cuda)

                # Forward pass
                features = self.featurizer(x_e)
                out = self.classifier(features)

                # ERM
                erm_loss = F.cross_entropy(out, y_e.long()).to(self.cuda)
                loss_e += erm_loss
                penalty_erm += float(loss_e)

                # Matching penalty (MMD or CORAL)
                mmd_loss = self._batch_match_loss(features, y_e, d_e)
                penalty_mmd += float(mmd_loss)

                # Backward pass
                loss_e += self.mmd_gamma*mmd_loss
                loss_e.backward(retain_graph=False)
                self.opt.step()

                del erm_loss
                del mmd_loss
                del loss_e
                torch.cuda.empty_cache()

                train_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                train_size += y_e.shape[0]

            print('Train Loss Basic : ', penalty_erm, penalty_mmd)
            if train_size == 0:
                raise RuntimeError('No training samples were used in epoch ' + str(epoch) + '; check indecies')
            print('Train Acc Env : ', 100*train_acc/train_size)
            print('Done Training for epoch: ', epoch)

            # Train Dataset Accuracy
            self.train_acc.append(100*train_acc/train_size)

            # Val Dataset Accuracy
            self.val_acc.append(self.get_test_accuracy('val'))

            # Test Dataset Accuracy
            self.final_acc.append(self.get_test_accuracy('test'))

            # Save the model if this is the best epoch so far by validation accuracy
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc = self.val_acc[-1]
                self.max_epoch = epoch
                self.save_model()

            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])



def build_similary_matrix(cov_function, items):
    """
    Build a square pairwise matrix by applying ``cov_function`` to pairs of ``items``.

    Only the upper triangle (including the diagonal) is evaluated; each value
    is mirrored into the lower triangle, so ``cov_function`` is assumed to be
    symmetric. Pairs are visited row by row: (0,0), (0,1), ..., (1,1), ...

    :param cov_function: callable taking two items and returning a number
        (a similarity or a distance -- the matrix holds whatever it returns).
    :param items: indexable sequence of items.
    :return: ``numpy`` float array of shape ``(len(items), len(items))``.
    """
    size = len(items)
    sim = np.zeros((size, size))
    for row in range(size):
        left = items[row]
        for col in range(row, size):
            sim[row, col] = cov_function(left, items[col])
            # Mirror the stored (already float-converted) value.
            sim[col, row] = sim[row, col]
    return sim

def get_similar_cov(inx):
    """
    Select a pairwise comparison function by number.

    1 -> ``get_l1_similar`` (L1 distance of the difference),
    2 -> ``get_l2_similar`` (L2 norm of the difference),
    3 -> ``get_cos_similar`` (cosine similarity rescaled to [0, 1]),
    anything else -> a fresh ``exp_quadratic()`` kernel with its default sigma.

    Note that choices 1 and 2 return distances (larger means less alike),
    while 3 and the default return similarities (larger means more alike).
    """
    choices = (
        (1, get_l1_similar),
        (2, get_l2_similar),
        (3, get_cos_similar),
    )
    for key, metric in choices:
        if inx == key:
            return metric
    return exp_quadratic()

def get_l2_similar(v1,v2):
    """
    Return ``numpy.linalg.norm(v1 - v2)`` with the default norm order.

    For 1-D arrays this is the Euclidean distance. Despite the function name
    it is a distance, not a similarity: identical inputs give 0.
    """
    difference = v1 - v2
    return np.linalg.norm(difference)

def get_l1_similar(v1,v2):
    """
    Return ``numpy.linalg.norm(v1 - v2, ord=1)``.

    For 1-D arrays this is the L1 (Manhattan) distance. Despite the function
    name it is a distance, not a similarity: identical inputs give 0.
    """
    difference = v1 - v2
    return np.linalg.norm(difference, ord=1)

def get_cos_similar(v1, v2):
    """
    Cosine similarity of ``v1`` and ``v2`` rescaled from [-1, 1] to [0, 1].

    Returns ``0.5 + 0.5 * cos(v1, v2)``, or ``0`` when either vector has zero
    norm (the cosine is undefined there).
    """
    dot = float(np.dot(v1, v2))  # dot product of the two vectors
    magnitudes = np.linalg.norm(v1) * np.linalg.norm(v2)  # product of their norms
    if magnitudes == 0:
        return 0
    return 0.5 + 0.5 * (dot / magnitudes)

def exp_quadratic(sigma=0.1):
    """
    Build an exponential-quadratic (Gaussian / RBF) covariance function.

    The returned callable computes
    ``exp(-0.5 * sum((p1 - p2) ** 2) / sigma ** 2)``, which is 1 for identical
    inputs and decays toward 0 as they move apart.

    :param sigma: length scale of the kernel.
    :return: a two-argument callable ``(p1, p2) -> float``.
    """
    def kernel(p1, p2):
        squared_distance = ((p1 - p2) ** 2).sum()
        return np.exp(-0.5 * squared_distance / sigma ** 2)

    return kernel
    
def sample_k(items, L, k, max_nb_iterations=1000, rng=np.random):
    """
    Sample a subset of exactly ``k`` items from a k-DPP defined by the
    similarity matrix ``L``, using a swap Markov chain.

    The chain starts from a uniformly random k-subset. Each iteration proposes
    swapping a selected item ``u`` for an unselected item ``v``. With ``Y`` the
    current subset minus ``u``, the proposal is accepted with probability

        min(1, det(L_{Y+v}) / det(L_{Y+u}))
          = min(1, (L_vv - b_v^T L_Y^{-1} b_v) / (L_uu - b_u^T L_Y^{-1} b_u))

    where ``b_x`` is the column of ``L`` for item ``x`` restricted to ``Y``
    (ratio of Schur complements). ``L`` is expected to be symmetric.

    :param items: sequence of candidate items; only its length is used.
    :param L: square ``numpy`` array of shape ``(len(items), len(items))``.
    :param k: number of items to select.
    :param max_nb_iterations: number of swap proposals to run.
    :param rng: object providing ``choice`` and ``uniform``, e.g. ``numpy.random``
        or a ``numpy.random.RandomState``.
    :return: boolean ``numpy`` mask of length ``len(items)`` with exactly ``k``
        True entries marking the selected items.
    :raises ValueError: from ``rng.choice`` when ``k > len(items)``.
    :raises numpy.linalg.LinAlgError: if a principal submatrix of ``L`` is singular.
    """
    n = len(items)
    initial = rng.choice(range(n), size=k, replace=False)
    X = np.zeros(n, dtype=bool)
    for i in initial:
        X[i] = True

    # With nothing (or everything) selected there is no (u, v) pair to swap,
    # and rng.choice on an empty candidate set would raise.
    if not X.any() or X.all():
        return X

    positions = np.arange(n)
    for _ in range(max_nb_iterations):
        u = rng.choice(positions[X])
        v = rng.choice(positions[~X])
        Y = X.copy()
        Y[u] = False

        rows_Y = L[Y, :]
        L_Y_inv = np.linalg.inv(rows_Y[:, Y])
        b_v = rows_Y[:, v:v + 1]
        b_u = rows_Y[:, u:u + 1]

        # Schur complements det(L_{Y+x}) / det(L_Y) for x = v and x = u.
        gain_v = L[v, v] - np.dot(np.dot(b_v.T, L_Y_inv), b_v).item()
        gain_u = L[u, u] - np.dot(np.dot(b_u.T, L_Y_inv), b_u).item()
        # Divide the whole Schur complement of v by that of u (not just the
        # quadratic term) to get the determinant ratio.
        p = min(1.0, gain_v / gain_u)
        if rng.uniform() <= p:
            X = Y  # Y is already a fresh copy
            X[v] = True
    return X

# Timestamp-based run identifier with ':' replaced by '_' (filesystem-safe).
# NOTE(review): runId is not used anywhere in the code visible here.
runId = datetime.datetime.now().isoformat().replace(':', '_')
# Device used for training; args.cuda_device selects the GPU index.
cuda= torch.device("cuda:" + str(args.cuda_device))
# Extra keyword arguments passed to the get_dataloader* helpers in the loop below.
# NOTE(review): a torch.device object appears to always be truthy, which would
# make the `else` branch unreachable. If the intent was to detect CUDA
# availability, torch.cuda.is_available() was probably meant -- confirm.
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': False} 
else:
    kwargs= {}
# One entry per training configuration:
#   target_val: best test accuracy over all epochs (selection by test accuracy);
#   source_val: test accuracy at the epoch with the best validation accuracy.
final_accuracy_target_val=[]
final_accuracy_source_val=[]
# Training-domain configurations: each inner list holds five domain values and
# is passed as `domains` to the data loaders in the training loop.
# Judging by the suffixes, `_rm` is presumably rotated MNIST and `_fm`
# Fashion-MNIST (this header originally read "fashion mnist") -- confirm.
# NOTE(review): only domains_l1_rm is iterated by the loop below;
# domains_l1_fm and domains_list are not used in the visible code.

domains_l1_rm=[ [17, 39, 51, 62, 68],[15, 26, 37, 42, 70],[24, 33, 38, 42, 50], [32, 39, 44, 60, 65], [16, 25, 34, 37, 67], [17, 33, 41, 54, 61], [22, 33, 35, 53, 71], [26, 38, 59, 60, 70],[22, 28, 35, 45, 70], [15, 16, 22, 33, 69], [25, 47, 64, 68, 73],[40, 41, 50, 63, 74], [16, 19, 23, 47, 65], [28, 40, 49, 53, 55], [20, 54, 55, 64, 70], [25, 35, 39, 60, 68],[24, 45, 48, 52, 75], [19, 31, 35, 54, 73], [37, 49, 61, 68, 74], [22, 32, 33, 72, 75]]

domains_l1_fm=[[15, 16, 18, 24, 57], [29, 33, 59, 69, 70], [17, 18, 23, 46, 67], [26, 33, 41, 54, 73], [22, 26, 30, 35, 43], [20, 33, 62, 67, 74], [16, 42, 51, 53, 74],[22, 52, 58, 64, 65], [28, 43, 54, 64, 72], [16, 35, 45, 68, 72],  [19, 21, 39, 57, 68], [25, 38, 39, 61, 65],[25, 38, 53, 65, 70],[17, 22, 37, 43, 46],[15, 20, 31, 45, 50], [25, 33, 38, 47, 50], [20, 37, 38, 64, 66], [17, 38, 47, 61, 69], [16, 29, 53, 55, 74],[24, 51, 62, 67, 68]]#[42, 46, 57, 63, 74],[49, 50, 55, 57, 60] -- originally 22 entries; these two were removed so there are 20, the same as rm
#Entries that just finished running; remember to add them back: [15, 16, 18, 24, 57], [29, 33, 59, 69, 70], [17, 18, 23, 46, 67], [26, 33, 41, 54, 73], [22, 26, 30, 35, 43], [20, 33, 62, 67, 74], [16, 42, 51, 53, 74],[22, 52, 58, 64, 65], [28, 43, 54, 64, 72], [16, 35, 45, 68, 72], 
domains_list=[[15, 26, 37, 42, 70], [24, 33, 38, 42, 50], [32, 39, 44, 60, 65], [16, 25, 34, 37, 67], [17, 33, 41, 54, 61]]
#Previous values of domains_list: [[22, 28, 35, 45, 70], [15, 16, 22, 33, 69], [25, 47, 64, 68, 73],[40, 41, 50, 63, 74], [16, 19, 23, 47, 65], [28, 40, 49, 53, 55], [20, 54, 55, 64, 70], [25, 35, 39, 60, 68],[24, 45, 48, 52, 75], [19, 31, 35, 54, 73], [37, 49, 61, 68, 74], [22, 32, 33, 72, 75]]#[[15, 26, 37, 42, 70], [24, 33, 38, 42, 50], [32, 39, 44, 60, 65], [16, 25, 34, 37, 67], [17, 33, 41, 54, 61], [22, 33, 35, 53, 71], [26, 38, 59, 60, 70]]
# Train and evaluate one model per training-domain configuration in
# domains_l1_rm. For each configuration:
#   - build the train/val/test loaders;
#   - load a precomputed per-configuration `indecies` array from disk;
#   - train an MMD model with it;
#   - save its per-epoch val/test accuracy curves;
#   - record the test accuracy under two model-selection rules.
for domains in domains_l1_rm:
#Earlier domain list used for this loop: [[17, 33, 41, 54, 61], [24, 45, 48, 52, 75], [25, 35, 39, 60, 68], [21, 29, 32, 36, 38], [26, 38, 59, 60, 70], [15, 18, 35, 53, 60], [22, 33, 35, 53, 71], [21, 24, 48, 49, 57], [37, 49, 61, 68, 74], [40, 41, 50, 63, 74], [15, 36, 38, 46, 56], [19, 31, 35, 54, 73], [21, 30, 41, 42, 44], [46, 49, 58, 62, 75], [15, 16, 22, 33, 69], [22, 23, 33, 41, 65], [27, 29, 35, 39, 59], [28, 43, 50, 64, 68], [22, 28, 35, 45, 70], [16, 18, 33, 37, 44]]:


    print(domains)
    # Output root. NOTE(review): the directory says "coral/l2" while the
    # trainer constructed below is MMD -- confirm this is the intended location.
    res_dir= 'results_coral/l2/'

    # Per-configuration output directory, e.g.
    # results_coral/l2/<dataset_name>_180_2.0/[17, 39, 51, 62, 68]
    base_res_dir=(
            res_dir + args.dataset_name + '_180_2.0/' + #args.method_name + '/' + 'train_' + 
            str(domains)
        )
    if not os.path.exists(base_res_dir):
        os.makedirs(base_res_dir) 


    # --- Disabled: on-the-fly index selection ---
    # The commented-out code below trained an ERM model and embedded each
    # training batch with its featurizer. It then built an exp-quadratic
    # similarity matrix over the batch embeddings and drew a k-DPP subset with
    # sample_k. The index array is now loaded from disk further down instead.
    #start select indecies


    # train_method= Erm(args, dataset1,  base_res_dir, cuda)

    # train_method.train()

    # data_loc=[]
    # for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(dataset1['data_loader']):
    # # for data in data_imgs:
    #     with torch.no_grad():
    #         x_e= x_e.to(cuda)
    #         location=train_method.phi.featurizer(x_e)                
    #         location=torch.mean(location,0) 
    #         location=location.cpu().numpy()
    #     data_loc.append(location)
    # print(batch_idx)
    # L=build_similary_matrix(exp_quadratic(0.1),data_loc)# select 100 of 160
    # X=sample_k(data_loc,L,130) ##100 43.65  #80 for colored #65 for others

    #print(X)

    # indecies=[]
    # for i in range(0,len(X)):
    #             if X[i]==True:
    #                 indecies.append(i)


    #finish indecies


    #dataset2=dataset1[['data_loader']][X]
    # Training data over the selected domains, validation data on the same
    # domains, and test data on args.test_domains.
    dataset1=get_dataloader_level_1(args, domains, 'train', kwargs)
    val_dataset= get_dataloader( args, 0, domains, 'val', 0, kwargs ) 
        # if args.method_name=='dannin':
        #     test_dataset= get_dataloader( args, (run+1)%3, train_domains, 'val', 0, kwargs )  
        # else:
    test_dataset= get_dataloader( args, 0,args.test_domains, 'test', 0, kwargs )  
    # Precomputed index array for this configuration, read from
    # <base_load_index_dir>/<str(domains)>/indecies.npy.
    # NOTE(review): machine-specific absolute path; consider making it a CLI argument.
    base_load_index_dir='/home/skyfall/rdd/rdd/indecies_180_no_shuffle_rm'
    #residual_dir='/'+str(domains)+'indecies.npy'
    residual_dir='/'+str(domains)+'/indecies.npy'
    indecies=np.load(base_load_index_dir+residual_dir)
    # Count and print the entries equal to 1. This is presumably the number of
    # selected items, if the file holds a 0/1 or boolean mask -- confirm the
    # saved format.
    count=0
    for i in indecies:
        count+= i==1
    print(count)
    # MMD is defined elsewhere in this file and receives the index array
    # alongside the loaders. Its train() presumably appends one value per epoch
    # to val_acc and final_acc, which are used below.
    train_method= MMD(
                            args, dataset1,indecies, val_dataset,
                            test_dataset, base_res_dir, 
                            cuda
                            )
    train_method.train()  
    # Persist the per-epoch accuracy curves in the configuration's output directory.
    np.save( base_res_dir + '/Val_Acc' + '.npy', np.array(train_method.val_acc) )
    np.save( base_res_dir + '/Test_Acc' +  '.npy', np.array(train_method.final_acc)) 
    if args.method_name != 'matchdg_ctr':
            # "Target validation": the best test accuracy over all epochs.
            final_acc= np.max(train_method.final_acc)
            final_accuracy_target_val.append(final_acc)
            print(final_accuracy_target_val)
            
            # "Source validation": test accuracy at the first epoch with the
            # highest validation accuracy.
            idx= np.argmax(train_method.val_acc)
            final_acc= train_method.final_acc[idx]
            final_accuracy_source_val.append(final_acc)
            print(final_accuracy_source_val)
            print('final-acc')
            print(train_method.final_acc)
            print('val acc')
            print(train_method.val_acc)



print('\n')
print('Done for the Model..')
# Summary over all configurations: mean and standard deviation of the test
# accuracy picked by each model-selection rule.
print(
    'Final Test Accuracy (Source Validation)',
    np.mean(final_accuracy_source_val),
    np.std(final_accuracy_source_val),
)
print(
    'Final Test Accuracy (Target Validation)',
    np.mean(final_accuracy_target_val),
    np.std(final_accuracy_target_val),
)
print('\n')
    
        



