
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
import pickle

#Sklearn
import sklearn
from sklearn.manifold import TSNE

#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

#robustdg
from utils.helper import *
from utils.match_function import *

import random
import torchvision
import datetime
#python level-2-mmd-coral.py
# Default list of held-out (test) domains; used as the default of --test_domains below.
dom_test=[0,90] #rotmnistspur
#dom_test=[0,90]
#[-1]
#["0"]   



#python level-2-mmd-coral.py



# Input Parsing
# All options are parsed once at import time into the module-level `args`
# namespace (see the parse_args() call at the end of this section).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='fashion_mnist', #'fashion_mnist'
                    help='Datasets: rot_mnist; fashion_mnist; pacs')
parser.add_argument('--method_name', type=str, default='erm_match', # default=erm_match
                    help=' Training Algorithm: erm_match; matchdg_ctr; matchdg_erm')
parser.add_argument('--model_name', type=str, default='resnet18', 
                    help='Architecture of the model to be trained')
# NOTE(review): type=str is only applied to values given on the command line;
# the defaults of the two domain options below are lists of ints.
parser.add_argument('--train_domains', nargs='+', type=str, default=[0], #need to be modified
                    help='List of train domains')
parser.add_argument('--test_domains', nargs='+', type=str, default=dom_test, #to be modified
                    help='List of test domains')
parser.add_argument('--out_classes', type=int, default=10, 
                    help='Total number of classes in the dataset')
parser.add_argument('--img_c', type=int, default= 1, #used to be 1
                    help='Number of channels of the image in dataset')
parser.add_argument('--img_h', type=int, default= 224, 
                    help='Height of the image in dataset')
parser.add_argument('--img_w', type=int, default= 224, 
                    help='Width of the image in dataset')
parser.add_argument('--fc_layer', type=int, default= 1, 
                    help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet')#fc layer for classification
parser.add_argument('--match_layer', type=str, default='logit_match', 
                    help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--pos_metric', type=str, default='l2', 
                    help='Cost to function to evaluate distance between two representations; Options: l1; l2; cos')
parser.add_argument('--rep_dim', type=int, default=250, 
                    help='Representation dimension for contrsative learning')
parser.add_argument('--pre_trained',type=int, default=0, 
                    help='0: No Pretrained Architecture; 1: Pretrained Architecture')
parser.add_argument('--perfect_match', type=int, default=1, 
                    help='0: No perfect match known (PACS); 1: perfect match known (MNIST)')
parser.add_argument('--opt', type=str, default='sgd', 
                    help='Optimizer Choice: sgd; adam') 
parser.add_argument('--weight_decay', type=float, default=5e-4,
                   help='Weight Decay in SGD')
parser.add_argument('--lr', type=float, default=0.01, 
                    help='Learning rate for training the model')
parser.add_argument('--batch_size', type=int, default=16, 
                    help='Batch size foe training the model')
parser.add_argument('--epochs', type=int, default=50, # previously 30
                    help='Total number of epochs for training the model')
parser.add_argument('--penalty_s', type=int, default=-1, 
                    help='Epoch threshold over which Matching Loss to be optimised')
parser.add_argument('--penalty_irm', type=float, default=0.0, 
                    help='Penalty weight for IRM invariant classifier loss')
parser.add_argument('--penalty_aug', type=float, default=1.0, 
                    help='Penalty weight for Augmentation in Hybrid approach loss')
parser.add_argument('--penalty_ws', type=float, default=0.1, 
                    help='Penalty weight for Matching Loss')
parser.add_argument('--penalty_diff_ctr',type=float, default=1.0, 
                    help='Penalty weight for Contrastive Loss')
parser.add_argument('--tau', type=float, default=0.05, 
                    help='Temperature hyper param for NTXent contrastive loss ')
parser.add_argument('--match_flag', type=int, default=0, 
                    help='0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--match_case', type=float, default=1.0, 
                    help='0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--match_interrupt', type=int, default=5, 
                    help='Number of epochs before inferring the match strategy')
parser.add_argument('--ctr_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--match_abl', type=int, default=0, 
                    help='0: Randomization til class level ; 1: Randomization completely')
parser.add_argument('--n_runs', type=int, default=1, # default used to be 3
                    help='Number of iterations to repeat the training process')
parser.add_argument('--n_runs_matchdg_erm', type=int, default=1, 
                    help='Number of iterations to repeat training process for matchdg_erm')
parser.add_argument('--ctr_model_name', type=str, default='resnet18', 
                    help='(For matchdg_ctr phase) Architecture of the model to be trained')
parser.add_argument('--ctr_match_layer', type=str, default='logit_match', 
                    help='(For matchdg_ctr phase) rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level')
parser.add_argument('--ctr_match_flag', type=int, default=1, 
                    help='(For matchdg_ctr phase) 0: No Update to Match Strategy; 1: Updates to Match Strategy')
parser.add_argument('--ctr_match_case', type=float, default=0.01, 
                    help='(For matchdg_ctr phase) 0: Random Match; 1: Perfect Match. 0.x" x% correct Match')
parser.add_argument('--ctr_match_interrupt', type=int, default=5, 
                    help='(For matchdg_ctr phase) Number of epochs before inferring the match strategy')
parser.add_argument('--mnist_seed', type=int, default=0, 
                    help='Change it between 0-6 for different subsets of Mnist and Fashion Mnist dataset')
parser.add_argument('--retain', type=float, default=0, 
                    help='0: Train from scratch in MatchDG Phase 2; 2: Finetune from MatchDG Phase 1 in MatchDG is Phase 2')
# The selected id is turned into the torch device string "cuda:<id>" further down in this script.
parser.add_argument('--cuda_device', type=int, default=2, 
                    help='Select the cuda device by id among the avaliable devices' )
parser.add_argument('--os_env', type=int, default=0, 
                    help='0: Code execution on local server/machine; 1: Code execution in docker/clusters' )


#Differential Privacy
parser.add_argument('--dp_noise', type=int, default=0, 
                    help='0: No DP noise; 1: Add DP noise')
parser.add_argument('--dp_epsilon', type=float, default=1.0, 
                    help='Epsilon value for Differential Privacy')
# Special case when you want to check results with the dp setting for the infinite epsilon case
parser.add_argument('--dp_attach_opt', type=int, default=1, 
                    help='0: Infinite Epsilon; 1: Finite Epsilion')


#MMD, DANN
parser.add_argument('--d_steps_per_g_step', type=int, default=1)
parser.add_argument('--grad_penalty', type=float, default=0.0)
parser.add_argument('--conditional', type=int, default=0)
parser.add_argument('--gaussian', type=int, default=1)

#fish
parser.add_argument('--meta_lr', type=float, default=0.01)
parser.add_argument('--meta_steps', type=int, default=5)


#Slab Dataset
parser.add_argument('--slab_data_dim', type=int, default= 2, 
                    help='Number of features in the slab dataset')
parser.add_argument('--slab_total_slabs', type=int, default=7)
parser.add_argument('--slab_num_samples', type=int, default=1000)
parser.add_argument('--slab_noise', type=float, default=0.1)


#Differentiate between resnet, lenet, domainbed cases of mnist
parser.add_argument('--mnist_case', type=str, default='resnet18', 
                    help='MNIST Dataset Case: resnet18; lenet, domainbed')
parser.add_argument('--mnist_aug', type=int, default=0, 
                    help='MNIST Data Augmentation: 0 (MNIST, FMNIST Privacy Evaluation); 1 (FMNIST)')


#Multiple random matches
parser.add_argument('--total_matches_per_point', type=int, default=1, 
                    help='Multiple random matches')


# Evaluation specific
parser.add_argument('--test_metric', type=str, default='match_score', 
                    help='Evaluation Metrics: acc; match_score, t_sne, mia')
parser.add_argument('--acc_data_case', type=str, default='test', 
                    help='Dataset Train/Val/Test for the accuracy evaluation metric')
parser.add_argument('--top_k', type=int, default=10, 
                    help='Top K matches to consider for the match score evaluation metric')
parser.add_argument('--match_func_aug_case', type=int, default=0, 
                    help='0: Evaluate match func on train domains; 1: Evaluate match func on self augmentations')
parser.add_argument('--match_func_data_case', type=str, default='val', 
                    help='Dataset Train/Val/Test for the match score evaluation metric')

# Parse sys.argv immediately; `args` is read as a module-level global by the rest of the script.
args = parser.parse_args()



class BaseDataLoader(data_utils.Dataset):
    """Common storage and indexing logic for the domain datasets in this file.

    The constructor only records its arguments and initialises empty
    containers; subclasses are expected to fill ``data``, ``labels``,
    ``domains``, ``indices`` and ``objects`` with index-aligned sequences.

    Args:
        args: Parsed command-line namespace, stored as ``self.args``.
        list_domains: Domains this dataset covers, stored as given.
        root: Path suffix appended to the fixed prefix ``'data/datasets'``.
        transform: Optional callable applied to each sample's data in
            ``__getitem__``.
        data_case: Split name (default ``'train'``), stored as given.
    """

    def __init__(self, args, list_domains, root, transform=None, data_case='train'):
        self.args = args
        self.list_domains = list_domains

        self.root = 'data/datasets' + root
        self.transform = transform
        self.data_case = data_case

        # Empty placeholders; nothing in this base class populates them.
        self.base_domain_size = 0
        self.list_size = []
        self.data = []
        self.labels = []
        self.domains = []
        self.indices = []
        self.objects = []

    def __len__(self):
        """Return the number of samples (length of ``self.labels``)."""
        # len() works for the initial empty list as well as for tensors and
        # arrays, whereas ``.shape[0]`` failed on the list set in __init__.
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(x, y, d, idx, objs)`` for ``index``.

        ``x`` is passed through ``self.transform`` when one is set; the other
        four entries are returned exactly as stored.
        """
        x = self.data[index]
        y = self.labels[index]
        d = self.domains[index]
        idx = self.indices[index]
        objs = self.objects[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y, d, idx, objs

    def get_size(self):
        """Return the number of samples; same value as ``len(self)``."""
        return len(self.labels)
     

class MnistRotated(BaseDataLoader):
    """Dataset assembled from pre-generated per-domain ``.pt`` tensors.

    For every entry of ``list_domains`` two files are read with ``torch.load``:
    ``<prefix>_org_data.pt`` (images) and ``<prefix>_label.pt`` (labels), where
    the prefix is built from ``root``, ``args.dataset_name``,
    ``args.mnist_case``, ``data_case``, ``mnist_subset`` and the domain.

    Args:
        args: Namespace; ``dataset_name`` and ``mnist_case`` are read here.
        list_domains: Domains to load, in order; the position in this list
            becomes the domain id.
        mnist_subset: Value inserted after ``'seed_'`` in the file prefix.
        root: Forwarded to ``BaseDataLoader`` (appended to ``'data/datasets'``).
        transform: Optional per-sample transform (see ``BaseDataLoader``).
        data_case: Split sub-directory name, default ``'train'``.
        download: Stored on the instance; not read anywhere in this class.
    """

    def __init__(self, args, list_domains, mnist_subset, root, transform=None, data_case='train',  download=True):

        super().__init__(args, list_domains, root, transform, data_case)
        self.mnist_subset = mnist_subset
        self.download = download

        self.data, self.labels, self.domains, self.indices, self.objects = self._get_data()

    def _get_data(self):
        """Load every domain and return the concatenated dataset.

        Returns:
            Tuple ``(imgs, labels, domains, indices, objects)``:
            images concatenated over domains (a channel axis is inserted when
            the stored tensor is 3-D), labels one-hot encoded over 10 classes,
            domains one-hot encoded over ``len(self.list_domains)``, and two
            NumPy arrays holding each sample's position inside its own domain.

        Side effects:
            Sets ``self.training_list_size`` to the per-domain sample counts.
        """
        list_img = []
        list_labels = []
        list_idx = []
        list_size = []
        data_dir = self.root + self.args.dataset_name + '_' + self.args.mnist_case + '/'

        # Lookup tables for the 'rot_mnist_spur' case: a domain id in 0..77 is
        # decoded into (rotation degree, colour) below.
        dom = list(range(15, 76, 5))
        color_list = ['red', 'blue', 'green', 'brown', 'pink', 'yellow']

        for domain in self.list_domains:

            if self.args.dataset_name == 'rot_mnist_spur':
                if domain < 0:
                    degree, color = 0, 'white'
                elif domain > 77:
                    degree, color = 90, 'white'
                else:
                    degree = dom[int(domain / 6)]
                    color = color_list[domain % 6]
                # The colour term is part of the file prefix for this dataset.
                load_dir = data_dir + self.data_case + '/' + 'seed_' + str(self.mnist_subset) + '_domain_' + str(degree)+'_color_'+color

            else:
                load_dir = data_dir + self.data_case + '/' + 'seed_' + str(self.mnist_subset) + '_domain_' + str(domain)

            mnist_imgs = torch.load(load_dir + '_org_data.pt')
            mnist_labels = torch.load(load_dir + '_label.pt')

            list_img.append(mnist_imgs)
            list_labels.append(mnist_labels)
            # Per-domain sample positions 0..n-1.
            list_idx.append(np.arange(len(mnist_imgs)))
            list_size.append(mnist_imgs.shape[0])

        # Merge the per-domain pieces into single tensors/arrays.
        data_imgs = torch.cat(list_img)
        data_labels = torch.cat(list_labels)
        # Concatenate the per-domain index ranges directly. Building a 2-D
        # array from the nested lists first fails on recent NumPy when the
        # domains do not all have the same number of samples.
        data_indices = np.concatenate(list_idx)
        self.training_list_size = list_size

        # For this dataset the object ids are the same as the data indices.
        data_objects = data_indices.copy()

        # Domain label of each sample = position of its domain in list_domains.
        data_domains = torch.zeros(data_labels.size())
        domain_start = 0
        for idx in range(len(self.list_domains)):
            curr_domain_size = self.training_list_size[idx]
            data_domains[domain_start: domain_start + curr_domain_size] += idx
            domain_start += curr_domain_size

        # Class labels -> one-hot rows (10 classes).
        y = torch.eye(10)
        data_labels = y[data_labels]

        # Domain labels -> one-hot rows.
        d = torch.eye(len(self.list_domains))
        data_domains = d[data_domains.long()]

        # If shape (B,H,W) change it to (B,C,H,W) with C=1
        if len(data_imgs.shape) == 3:
            data_imgs = data_imgs.unsqueeze(1)

        return data_imgs, data_labels, data_domains, data_indices, data_objects


def get_dataloader_level_1(args, domains, data_case, kwargs):
    """Build a ``MnistRotated`` dataset for ``domains`` and wrap it in a loader.

    The dataset is created with subset id ``0`` and root ``'/mnist/'``; the
    loader uses a fixed batch size of 50 with shuffling enabled, and
    ``kwargs`` is expanded into the ``DataLoader`` constructor.

    Returns:
        dict with keys ``data_loader``, ``data_obj``, ``total_domains``,
        ``domain_list``, ``base_domain_size`` and ``domain_size_list``.
    """
    batch_size = 50

    data_obj = MnistRotated(args, domains, 0, '/mnist/', data_case=data_case)
    loader = data_utils.DataLoader(data_obj, batch_size=batch_size, shuffle=True, **kwargs)

    return {
        'data_loader': loader,
        'data_obj': data_obj,
        'total_domains': len(domains),
        'domain_list': domains,
        'base_domain_size': data_obj.base_domain_size,
        'domain_size_list': data_obj.training_list_size,
    }
class Identity(nn.Module):
    """Pass-through layer that returns its input unchanged.

    ``n_inputs`` is only recorded as ``in_features``.
    """

    def __init__(self, n_inputs):
        super().__init__()
        self.in_features = n_inputs

    def forward(self, x):
        """Return ``x`` itself."""
        return x
class Model_erm(nn.Module):
    """Two-stage model: a featurizer followed by a classifier head.

    Args:
        feat: Module stored as ``self.featurizer``.
        fc: Module stored as ``self.classifier``.
    """

    def __init__(self, feat, fc):
        super().__init__()
        # Registration order (featurizer, then classifier) is kept as-is.
        self.featurizer = feat
        self.classifier = fc

    def forward(self, input_data):
        """Return ``classifier(featurizer(input_data))``."""
        return self.classifier(self.featurizer(input_data))
class BaseAlgo():
    """Shared setup for the training algorithms: model, optimizer, scheduler.

    Args:
        args: Parsed command-line namespace (``out_classes``, ``img_c``,
            ``opt``, ``lr`` and ``weight_decay`` are read in this class).
        train_dataset: Dict with the keys ``data_loader``, ``domain_list``,
            ``total_domains``, ``base_domain_size`` and ``domain_size_list``.
        base_res_dir: Directory used by ``save_model``.
        cuda: Device the model is moved to.
    """

    def __init__(self, args, train_dataset, base_res_dir, cuda):

        self.args = args
        self.train_dataset = train_dataset['data_loader']

        self.train_domains = train_dataset['domain_list']
        self.total_domains = train_dataset['total_domains']  # len(domains)
        self.domain_size = train_dataset['base_domain_size']
        self.training_list_size = train_dataset['domain_size_list']

        self.base_res_dir = base_res_dir
        self.run = 0
        self.cuda = cuda

        self.phi = self.get_model()
        self.opt = self.get_opt()
        # NOTE(review): nothing in this class calls self.scheduler.step().
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=25)

        self.train_acc = []

    def get_resnet(self, classes, fc_layer, num_ch, pre_trained):
        """Build a torchvision ResNet-18 with a customised head and stem.

        Args:
            classes: Output size of the replacement ``fc`` layer.
            fc_layer: Truthy -> ``fc`` becomes ``nn.Linear(in_features, classes)``;
                falsy -> ``fc`` becomes an ``Identity`` pass-through.
            num_ch: When equal to 1, ``conv1`` is replaced by a 1-input-channel
                convolution with the same kernel/stride/padding settings.
            pre_trained: Forwarded positionally to ``torchvision.models.resnet18``.
        """
        model = torchvision.models.resnet18(pre_trained)

        n_inputs = model.fc.in_features
        n_outputs = classes

        if fc_layer:
            model.fc = nn.Linear(n_inputs, n_outputs)
        else:
            print('Here')
            model.fc = Identity(n_inputs)

        if num_ch == 1:
            model.conv1 = nn.Conv2d(1, 64,
                                    kernel_size=(7, 7),
                                    stride=(2, 2),
                                    padding=(3, 3),
                                    bias=False)
        return model

    def get_model(self):
        """Assemble the ERM model and move it to ``self.cuda``.

        The featurizer is a ResNet-18 without a classification layer; the
        classifier is the ``fc`` layer taken from a second ResNet-18 (only
        that layer of the second network is kept).
        """
        resnet_enc = self.get_resnet(self.args.out_classes, 0, self.args.img_c, 0)
        feat = resnet_enc
        resnet_fc = self.get_resnet(self.args.out_classes, 1, self.args.img_c, 0)
        fc = resnet_fc.fc
        phi = Model_erm(feat, fc)
        phi = phi.to(self.cuda)
        return phi

    def save_model(self):
        """Write the model's state dict to ``<base_res_dir>/Model.pth``."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')

    def get_opt(self):
        """Create the optimizer selected by ``args.opt``.

        Returns:
            ``optim.SGD`` (Nesterov momentum 0.9, ``args.weight_decay``) for
            ``'sgd'`` or ``optim.Adam`` for ``'adam'``, both with
            ``lr=args.lr`` over the parameters that require gradients.

        Raises:
            ValueError: If ``args.opt`` is neither ``'sgd'`` nor ``'adam'``.
        """
        trainable = [p for p in self.phi.parameters() if p.requires_grad]

        if self.args.opt == 'sgd':
            return optim.SGD(
                [{'params': trainable}],
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
                momentum=0.9,
                nesterov=True,
            )
        if self.args.opt == 'adam':
            # Mirrors the previously commented-out branch: learning rate only.
            return optim.Adam([{'params': trainable}], lr=self.args.lr)

        # Previously any other value fell through to an unbound local.
        raise ValueError(
            "Unsupported optimizer '%s'; expected 'sgd' or 'adam'" % (self.args.opt,)
        )

class Erm(BaseAlgo):
    """Plain cross-entropy (ERM) training on top of ``BaseAlgo``."""

    def __init__(self, args, train_dataset,  base_res_dir,  cuda):
        super().__init__(args, train_dataset, base_res_dir, cuda)

    def train(self):
        """Train ``self.phi`` for ``args.epochs`` epochs.

        Each batch: zero the gradients, forward pass, cross-entropy against
        the argmax of the one-hot labels, backward pass, optimizer step.
        After every epoch the summed loss and the training accuracy are
        printed, and the accuracy (in percent) is appended to
        ``self.train_acc``.
        """
        self.max_epoch = -1
        self.max_val_acc = 0.0

        for epoch in range(self.args.epochs):
            epoch_loss = 0
            n_correct = 0.0
            n_seen = 0

            # Each batch is (images, one-hot labels, domains, indices, objects);
            # only the first two entries are used for training.
            for inputs, onehot, _dom, _idx, _obj in self.train_dataset:
                self.opt.zero_grad()

                inputs = inputs.to(self.cuda)
                targets = torch.argmax(onehot, dim=1).to(self.cuda)

                # Forward pass
                logits = self.phi(inputs)
                batch_loss = F.cross_entropy(logits, targets.long()).to(self.cuda)
                epoch_loss += float(batch_loss)

                # Backward pass and parameter update
                batch_loss.backward(retain_graph=False)
                self.opt.step()

                del batch_loss
                torch.cuda.empty_cache()

                predictions = torch.argmax(logits, dim=1)
                n_correct += torch.sum(predictions == targets).item()
                n_seen += targets.shape[0]

            print('Train Loss Basic : ',  epoch_loss )
            epoch_acc = 100 * n_correct / n_seen
            print('Train Acc Env : ', epoch_acc )
            print('Done Training for epoch: ', epoch)

            # Train Dataset Accuracy
            self.train_acc.append(epoch_acc)

        # Saving the weights after training is currently disabled.
        #self.save_model()






def build_similary_matrix(cov_function, items):
    """
    Build the symmetric pairwise matrix of ``cov_function`` over ``items``.

    ``cov_function(items[i], items[j])`` is evaluated once for every pair
    with ``i <= j`` and the result is mirrored into the lower triangle.
    Returns a float array of shape ``(len(items), len(items))``.
    """
    size = len(items)
    matrix = np.zeros((size, size))
    for row, left in enumerate(items):
        for col in range(row, size):
            value = cov_function(left, items[col])
            matrix[row, col] = value
            matrix[col, row] = value
    return matrix

def get_similar_cov(inx):
    """
    Select a pairwise function by index.

    1 -> ``get_l1_similar``, 2 -> ``get_l2_similar``, 3 -> ``get_cos_similar``;
    any other value returns ``exp_quadratic()`` with its default sigma.
    """
    if inx == 1:
        return get_l1_similar
    if inx == 2:
        return get_l2_similar
    if inx == 3:
        return get_cos_similar
    return exp_quadratic()

def get_l2_similar(v1,v2):
    """Return ``np.linalg.norm`` (default order) of ``v1 - v2``."""
    difference = v1 - v2
    return np.linalg.norm(difference)

def get_l1_similar(v1,v2):
    """Return ``np.linalg.norm`` with ``ord=1`` of ``v1 - v2``."""
    difference = v1 - v2
    return np.linalg.norm(difference, ord=1)

def get_cos_similar(v1, v2):
    """Return the cosine of ``v1`` and ``v2`` rescaled as ``0.5 + 0.5 * cos``.

    Returns 0 when the product of the two norms is zero.
    """
    dot_product = float(np.dot(v1, v2))  # vector dot product
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)  # product of the two norms
    if norm_product == 0:
        return 0
    return 0.5 + 0.5 * (dot_product / norm_product)

def exp_quadratic(sigma=0.1):
    """
    Return an exponential quadratic covariance function.

    The returned callable maps ``(p1, p2)`` to
    ``exp(-0.5 * sum((p1 - p2)**2) / sigma**2)``.
    """
    def kernel(p1, p2):
        squared_distance = ((p1 - p2)**2).sum()
        return np.exp(-0.5 * squared_distance / sigma**2)
    return kernel
    
def sample_k(items, L, k, max_nb_iterations=1000, rng=np.random):
    """
    Sample a list of k items from a DPP defined
    by the similarity matrix L.

    A random size-k subset is drawn first; then, for ``max_nb_iterations``
    steps, one selected index ``u`` and one unselected index ``v`` are drawn
    and the swap ``u -> v`` is accepted with probability

        min(1, (L_vv - b_v^T L_Y^{-1} b_v) / (L_uu - b_u^T L_Y^{-1} b_u))

    where ``Y`` is the current selection without ``u`` and ``b_x = L[Y, x]``.

    Args:
        items: Sequence of candidates; only ``len(items)`` is used.
        L: Square NumPy array of pairwise similarities, indexed like ``items``.
        k: Number of items to select.
        max_nb_iterations: Number of swap proposals.
        rng: Object providing ``choice`` and ``uniform`` (defaults to the
            ``np.random`` module).

    Returns:
        Boolean NumPy array of length ``len(items)`` with exactly ``k``
        ``True`` entries.
    """
    n_items = len(items)

    # Nothing to select: return an all-False mask without touching the RNG.
    if k == 0:
        return np.zeros(n_items, dtype=bool)

    initial = rng.choice(range(n_items), size=k, replace=False)
    X = np.zeros(n_items, dtype=bool)
    X[initial] = True

    # Every item is selected, so there is no unselected index to swap in.
    if k == n_items:
        return X

    all_indices = np.arange(n_items)
    for _ in range(max_nb_iterations):
        u = rng.choice(all_indices[X])
        v = rng.choice(all_indices[~X])

        # Y = current selection with u removed.
        Y = X.copy()
        Y[u] = False
        L_Y = L[Y, :]
        L_Y = L_Y[:, Y]
        L_Y_inv = np.linalg.inv(L_Y)

        b_v = L[Y, v]
        b_u = L[Y, u]
        # Schur complements of L_Y when v (resp. u) is added to Y.
        gain_v = L[v, v] - np.dot(np.dot(b_v, L_Y_inv), b_v)
        gain_u = L[u, u] - np.dot(np.dot(b_u, L_Y_inv), b_u)

        # The whole numerator must be divided by the denominator; without
        # the grouping only the quadratic form was divided.
        if gain_u <= 0:
            # Current selection has no positive weight: always move away.
            p = 1.0
        else:
            p = min(1.0, gain_v / gain_u)

        if rng.uniform() <= p:
            X = Y
            X[v] = True
    return X

# ---------------------------------------------------------------------------
# Driver: for every domain combination, train an ERM model, embed each
# training batch with the trained featurizer, sample a size-150 subset of the
# batches with the k-DPP sampler and save the boolean selection mask.
# ---------------------------------------------------------------------------
# Timestamp-based run id (':' replaced so it is usable in file names).
runId = datetime.datetime.now().isoformat().replace(':', '_')
cuda= torch.device("cuda:" + str(args.cuda_device))
# NOTE(review): `cuda` is a torch.device object, so this condition presumably
# always takes the first branch; it does not test whether CUDA is available.
if cuda:
    kwargs = {'num_workers': 2, 'pin_memory': False} 
else:
    kwargs= {}
final_accuracy_target_val=[]
final_accuracy_source_val=[]

# Hard-coded lists of 5-domain training combinations. Only domains_l1_fm is
# iterated below; domains_l1_rm is defined but not used in this section.
domains_l1_rm=[ [17, 39, 51, 62, 68],[15, 26, 37, 42, 70], [24, 33, 38, 42, 50], [32, 39, 44, 60, 65], [16, 25, 34, 37, 67], [17, 33, 41, 54, 61], [22, 33, 35, 53, 71], [26, 38, 59, 60, 70],[22, 28, 35, 45, 70], [15, 16, 22, 33, 69], [25, 47, 64, 68, 73],[40, 41, 50, 63, 74], [16, 19, 23, 47, 65],[28, 40, 49, 53, 55], [20, 54, 55, 64, 70], [25, 35, 39, 60, 68],[24, 45, 48, 52, 75], [19, 31, 35, 54, 73], [37, 49, 61, 68, 74], [22, 32, 33, 72, 75]]
#
domains_l1_fm=[ [15, 16, 18, 24, 57], [29, 33, 59, 69, 70], [17, 18, 23, 46, 67], [26, 33, 41, 54, 73], [22, 26, 30, 35, 43], [20, 33, 62, 67, 74], [16, 42, 51, 53, 74],[22, 52, 58, 64, 65], [28, 43, 54, 64, 72], [16, 35, 45, 68, 72], [19, 21, 39, 57, 68], [25, 38, 39, 61, 65],[25, 38, 53, 65, 70],[17, 22, 37, 43, 46],[15, 20, 31, 45, 50], [25, 33, 38, 47, 50], [20, 37, 38, 64, 66], [17, 38, 47, 61, 69], [16, 29, 53, 55, 74],[24, 51, 62, 67, 68]]# [42, 46, 57, 63, 74],[49, 50, 55, 57, 60] -- originally 22 entries; two were removed so there are 20, the same count as the rm list
#/221019056/rdd/indecies/[15, 16, 18, 24, 57]/indecies.npy
for domains in domains_l1_fm:

    print(domains)

    res_dir= './rdd/indecies_150_no_shuffle_fm/'

    # Output directory: <res_dir>/<str(domains)>/
    base_res_dir=(
            res_dir  + #args.method_name + '/' + 'train_' + 
            str(domains)+'/'
        )
    if not os.path.exists(base_res_dir):
        os.makedirs(base_res_dir) 


    # Start selecting indices: build the loader and train an ERM model on it.
    dataset1=get_dataloader_level_1(args, domains, 'train', kwargs)

    train_method= Erm(args, dataset1,  base_res_dir, cuda)

    train_method.train()

    # One feature vector per batch: featurizer output averaged over axis 0.
    # NOTE(review): the loader is created with shuffle=True in
    # get_dataloader_level_1, so the batch composition in this pass comes from
    # a fresh shuffle -- confirm that the saved mask is meant to index these
    # batches. Also, no eval() call is made on the model before this pass --
    # confirm that is intended.
    data_loc=[]
    for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(dataset1['data_loader']):
    # for data in data_imgs:
        with torch.no_grad():
            x_e= x_e.to(cuda)
            location=train_method.phi.featurizer(x_e)                
            location=torch.mean(location,0) 
            location=location.cpu().numpy()
        data_loc.append(location)
    # Index of the last batch processed.
    print(batch_idx)
    # Pairwise kernel matrix over the per-batch features (sigma = 0.1).
    L=build_similary_matrix(exp_quadratic(0.1),data_loc)# earlier setting: pick 100 out of 160
    # Boolean mask with 150 selected batches.
    X=sample_k(data_loc,L,150) ##100 43.65  #80 for colored #65 for others

    print(X)

    # indecies=[]
    # for i in range(0,len(X)):
    #             if X[i]==True:
    #                 indecies.append(i)
    # The boolean mask itself (not a list of integer positions) is saved.
    np.save(base_res_dir+'indecies.npy',X)
