import sys
import numpy as np
import argparse
import copy
import random
import json
import time

import torch
import torchvision
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity
from torch import autograd###nan 检测

class Identity(nn.Module):
    """Pass-through module that returns its input unchanged.

    Installed in place of a network's ``fc`` head so the network emits
    features rather than logits, while still exposing ``in_features`` the
    way an ``nn.Linear`` head would.

    Args:
        n_inputs: feature width to report via ``in_features``.
    """

    def __init__(self, n_inputs):
        super().__init__()
        # Mirrors nn.Linear.in_features so callers can still query the width.
        self.in_features = n_inputs

    def forward(self, x):
        """Return ``x`` itself (no copy, no computation)."""
        return x
class Model_erm(nn.Module):
    """Chain a feature extractor and a classifier head into one model.

    Args:
        feat: module mapping inputs to feature vectors (stored as ``featurizer``).
        fc: module mapping feature vectors to class scores (stored as ``classifier``).
    """

    def __init__(self, feat, fc):
        super().__init__()
        self.featurizer = feat
        self.classifier = fc

    def forward(self, input_data):
        """Return ``classifier(featurizer(input_data))``."""
        return self.classifier(self.featurizer(input_data))
class MMD:
    """Trainer that adds a cross-domain feature-matching penalty to ERM.

    Two matching penalties are available, selected by ``args.gaussian``:
    a multi-bandwidth Gaussian-kernel MMD (``gaussian`` truthy) or a
    CORAL-style mean/covariance distance (``gaussian`` falsy). Both are
    feature-matching methods. With ``args.conditional`` truthy the penalty is
    computed separately within each class label present in the batch;
    otherwise it is computed over the whole batch.

    Args:
        args: configuration namespace; this class reads ``out_classes``,
            ``img_c``, ``gaussian``, ``conditional`` and ``epochs``.
        train_dataset: dict with ``'data_loader'``, ``'domain_list'``,
            ``'total_domains'``, ``'base_domain_size'`` and
            ``'domain_size_list'``.
        val_dataset, test_dataset: dicts with a ``'data_loader'`` entry.
            Every loader yields ``(x, y, d, idx, obj)`` batches; ``y`` and
            ``d`` are reduced with ``argmax(dim=1)`` (presumably one-hot
            encodings -- confirm against the data loader).
        base_res_dir: directory the best checkpoint is written to.
        cuda: torch device (or device string) used for the model and tensors.
    """

    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, cuda):
        self.args = args
        self.train_dataset = train_dataset['data_loader']
        self.val_dataset = val_dataset['data_loader']
        self.test_dataset = test_dataset['data_loader']

        # NOTE: per the original author, for rot_mnist_spur this entry is a dict.
        self.train_domains = train_dataset['domain_list']
        # Number of domains, i.e. len(domains).
        self.total_domains = train_dataset['total_domains']
        self.domain_size = train_dataset['base_domain_size']
        self.training_list_size = train_dataset['domain_size_list']

        self.base_res_dir = base_res_dir
        self.run = 0
        self.cuda = cuda

        # Per-epoch accuracy histories, in percent.
        self.final_acc = []
        self.val_acc = []
        self.train_acc = []

        self.phi = self.get_model()
        self.opt = self.get_opt()
        # Penalty weight is fixed at 1.0 rather than read from args.penalty_ws.
        self.mmd_gamma = 1.0
        self.gaussian = bool(self.args.gaussian)
        self.conditional = bool(self.args.conditional)
        if self.gaussian:
            self.kernel_type = "gaussian"
        else:
            self.kernel_type = "mean_cov"
            print('kernel_type')
            print(self.kernel_type)

        # Expose both halves of phi so train() can reuse the features for the penalty.
        self.featurizer = self.phi.featurizer
        self.classifier = self.phi.classifier

    def get_resnet(self, classes, fc_layer, num_ch, pre_trained):
        """Build a torchvision ResNet-18 with a configurable head and input channels.

        Args:
            classes: output width of the linear head when ``fc_layer`` is truthy.
            fc_layer: truthy -> ``fc`` becomes ``nn.Linear(in_features, classes)``;
                falsy -> ``fc`` becomes an ``Identity`` so the net returns features.
            num_ch: 1 replaces ``conv1`` with a single-channel 7x7 conv; any other
                value keeps torchvision's default ``conv1``.
            pre_trained: forwarded positionally to ``torchvision.models.resnet18``.

        Returns:
            The configured ResNet-18 module.
        """
        model = torchvision.models.resnet18(pre_trained)
        n_inputs = model.fc.in_features

        if fc_layer:
            model.fc = nn.Linear(n_inputs, classes)
        else:
            print('Here')
            model.fc = Identity(n_inputs)

        if num_ch == 1:
            model.conv1 = nn.Conv2d(1, 64,
                                    kernel_size=(7, 7),
                                    stride=(2, 2),
                                    padding=(3, 3),
                                    bias=False)
        return model

    def get_model(self):
        """Build ``Model_erm(featurizer, classifier)`` on ``self.cuda``.

        The featurizer is a ResNet-18 whose ``fc`` is an identity. The
        classifier is the linear ``fc`` head of a second ResNet-18. A bare
        ``nn.Linear`` would be equivalent in shape, but it would consume the
        random-number stream differently and so change the initialisation.
        """
        feat = self.get_resnet(self.args.out_classes, 0, self.args.img_c, 0)
        fc = self.get_resnet(self.args.out_classes, 1, self.args.img_c, 0).fc
        phi = Model_erm(feat, fc)
        return phi.to(self.cuda)

    def save_model(self):
        """Write ``self.phi``'s state dict to ``<base_res_dir>/Model.pth``."""
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model' + '.pth')

    def get_opt(self):
        """Return an Adam optimizer over all of ``self.phi``'s parameters.

        The learning rate (0.001) and weight decay (0.0) are hard-coded;
        ``args.opt``, ``args.lr`` and ``args.weight_decay`` are not consulted.
        """
        return optim.Adam(
            self.phi.parameters(),
            lr=0.001,
            weight_decay=0.0
        )

    def my_cdist(self, x1, x2):
        """Pairwise squared Euclidean distances between the rows of two 2-D tensors.

        Uses ``|a|^2 + |b|^2 - 2 a.b`` and clamps the result at 1e-30, so
        rounding can never make a distance negative.
        """
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        res = torch.addmm(x2_norm.transpose(-2, -1),
                          x1,
                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
        return res.clamp_min_(1e-30)

    def gaussian_kernel(self, x, y, gamma=(0.001, 0.01, 0.1, 1, 10, 100, 1000)):
        """Sum of RBF kernels ``exp(-g * ||x_i - y_j||^2)`` over the bandwidths in ``gamma``."""
        D = self.my_cdist(x, y)
        K = torch.zeros_like(D)
        for g in gamma:
            K.add_(torch.exp(D.mul(-g)))
        return K

    def mmd(self, x, y):
        """Distribution distance between two feature sets (rows are samples).

        ``kernel_type == "gaussian"``: biased MMD^2 estimate with the
        multi-bandwidth kernel. Otherwise, CORAL-style distance: the mean
        squared difference of the means plus the mean squared difference of
        the unbiased covariances. The covariance term is skipped when either
        set has a single row.

        Returns a scalar tensor. If either set is empty, a zero tensor is
        returned instead of NaN.
        """
        if len(x) == 0 or len(y) == 0:
            # An empty group has no distribution to match; the mean of an empty tensor would be NaN.
            print(len(x))
            print(len(y))
            print('length problem!!!')
            return x.new_zeros(())

        if self.kernel_type == "gaussian":
            Kxx = self.gaussian_kernel(x, x).mean()
            Kyy = self.gaussian_kernel(y, y).mean()
            Kxy = self.gaussian_kernel(x, y).mean()
            return Kxx + Kyy - 2 * Kxy

        mean_x = x.mean(0, keepdim=True)
        mean_y = y.mean(0, keepdim=True)
        mean_diff = ((mean_x - mean_y) ** 2).mean()
        if len(x) == 1 or len(y) == 1:
            # Unbiased covariance divides by n - 1, which is undefined for one sample.
            return mean_diff

        cent_x = x - mean_x
        cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
        cent_y = y - mean_y
        cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)
        cova_diff = ((cova_x - cova_y) ** 2).mean()
        return mean_diff + cova_diff

    def mmd_regularization(self, features, d, nmb):
        """Sum ``self.mmd`` over every unordered pair of domains.

        Args:
            features: per-sample feature rows.
            d: per-sample domain ids aligned with ``features``; must be able
                to index ``features`` as a boolean mask.
            nmb: list of the domain ids to match. These are the ids
                themselves, not a count, so non-contiguous ids work.

        Returns:
            Scalar tensor on ``self.cuda``.
        """
        penalty = torch.tensor(0.0).to(self.cuda)
        groups = [features[d == dom] for dom in nmb]
        for i in range(len(groups)):
            for j in range(i + 1, len(groups)):
                penalty += self.mmd(groups[i], groups[j])
        return penalty

    def get_test_accuracy(self, case):
        """Return the top-1 accuracy (percent) of ``self.phi`` on the val or test split.

        The model is put in eval mode for the pass, so BatchNorm uses and
        keeps its running statistics. Its previous train/eval mode is then
        restored.

        Args:
            case: ``'val'`` or ``'test'``.

        Raises:
            ValueError: if ``case`` is neither ``'val'`` nor ``'test'``.
        """
        if case == 'val':
            dataset = self.val_dataset
        elif case == 'test':
            dataset = self.test_dataset
        else:
            raise ValueError("case must be 'val' or 'test', got {!r}".format(case))

        test_acc = 0.0
        test_size = 0
        was_training = self.phi.training
        self.phi.eval()
        try:
            with torch.no_grad():
                for x_e, y_e, _d_e, _idx_e, _obj_e in dataset:
                    x_e = x_e.to(self.cuda)
                    y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                    out = self.phi(x_e)

                    test_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                    test_size += y_e.shape[0]
        finally:
            self.phi.train(was_training)

        print(' Accuracy: ', case, 100*test_acc/test_size)
        return 100*test_acc/test_size

    def train(self):
        """Train for ``args.epochs`` epochs on ERM loss plus ``mmd_gamma`` times the matching penalty.

        After each epoch, train/val/test accuracies are appended to their
        histories. The model is checkpointed whenever validation accuracy
        strictly improves on the best seen so far.
        """
        self.max_epoch = -1
        self.max_val_acc = 0.0
        for epoch in range(self.args.epochs):
            # get_test_accuracy restores the previous mode, but be explicit each epoch.
            self.phi.train()

            penalty_erm = 0
            penalty_mmd = 0
            train_acc = 0.0
            train_size = 0

            for x_e, y_e, d_e, _idx_e, _obj_e in self.train_dataset:
                self.opt.zero_grad()
                loss_e = torch.tensor(0.0).to(self.cuda)

                x_e = x_e.to(self.cuda)
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)
                # Same device as y_e, so masks built from y_e can index d_e.
                d_e = torch.argmax(d_e, dim=1).to(self.cuda)

                # Forward pass
                features = self.featurizer(x_e)
                out = self.classifier(features)

                # ERM
                erm_loss = F.cross_entropy(out, y_e.long()).to(self.cuda)
                loss_e += erm_loss
                penalty_erm += float(loss_e)

                # MMD / CORAL matching penalty
                mmd_loss = torch.tensor(0.0).to(self.cuda)
                # Domain ids actually present in this batch (not necessarily 0..k-1).
                nmb = torch.unique(d_e).tolist()

                if self.conditional:
                    # Iterate the class label values present, not 0..len-1.
                    for y_c in torch.unique(y_e).tolist():
                        class_mask = y_e == y_c
                        features_c = features[class_mask]
                        d_c = d_e[class_mask]
                        if len(torch.unique(d_c)) != len(nmb):
                            print('*********************************')
                            print('Error: Some classes not distributed across all the domains; issues for class conditional methods')
                            continue
                        mmd_loss += self.mmd_regularization(features_c, d_c, nmb)
                else:
                    mmd_loss += self.mmd_regularization(features, d_e, nmb)

                # Average over the number of unordered domain pairs.
                if len(nmb) > 1:
                    mmd_loss /= (len(nmb) * (len(nmb) - 1) / 2)
                penalty_mmd += float(mmd_loss)

                # Backward pass
                loss_e += self.mmd_gamma*mmd_loss
                loss_e.backward(retain_graph=False)
                self.opt.step()

                del erm_loss
                del mmd_loss
                del loss_e
                torch.cuda.empty_cache()

                train_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                train_size += y_e.shape[0]

            print('Train Loss Basic : ', penalty_erm, penalty_mmd)
            print('Train Acc Env : ', 100*train_acc/train_size)
            print('Done Training for epoch: ', epoch)

            # Train dataset accuracy
            self.train_acc.append(100*train_acc/train_size)

            # Val dataset accuracy
            self.val_acc.append(self.get_test_accuracy('val'))

            # Test dataset accuracy
            self.final_acc.append(self.get_test_accuracy('test'))

            # Save the model if this is the best epoch so far by validation accuracy.
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc = self.val_acc[-1]
                self.max_epoch = epoch
                self.save_model()

            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])

# class MMD(BaseAlgo):####这里有mmd和coral，Gaussian_kernel true 就是mmd,false 就是coral,两个都是matching类型的
#     def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        
#         super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda) 

#         self.mmd_gamma= self.args.penalty_ws
#         self.gaussian= bool(self.args.gaussian)
#         self.conditional= bool(self.args.conditional)
#         if self.gaussian: 
#             self.kernel_type = "gaussian"
#         else:
#             self.kernel_type = "mean_cov"
        
#         #enc=nn.Sequential(*list(self.phi.children())[:-1])
#         self.featurizer = self.phi.predict_conv_net
#         self.classifier = self.phi.predict_fc_net
        
#         print('Initial Params: ', )
        
#     def my_cdist(self, x1, x2):
#         x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
#         x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
#         res = torch.addmm(x2_norm.transpose(-2, -1),
#                           x1,
#                           x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
#         return res.clamp_min_(1e-30)
    
#     def gaussian_kernel(self, x, y, gamma=[0.001, 0.01, 0.1, 1, 10, 100,
#                                            1000]):
#         D = self.my_cdist(x, y)
#         K = torch.zeros_like(D)

#         for g in gamma:
#             K.add_(torch.exp(D.mul(-g)))

#         return K

#     def mmd(self, x, y):
#         if self.kernel_type == "gaussian":
#             Kxx = self.gaussian_kernel(x, x).mean()
#             Kyy = self.gaussian_kernel(y, y).mean()
#             Kxy = self.gaussian_kernel(x, y).mean()
#             return Kxx + Kyy - 2 * Kxy
#         else:
#             mean_x = x.mean(0, keepdim=True)
#             mean_y = y.mean(0, keepdim=True)
#             cent_x = x - mean_x
#             cent_y = y - mean_y
#             cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
#             cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)

#             mean_diff = (mean_x - mean_y).pow(2).mean()
#             cova_diff = (cova_x - cova_y).pow(2).mean()

#             return mean_diff + cova_diff
    
#     def mmd_regularization(self, features, d, nmb):
#         penalty= torch.tensor(0.0).to(self.cuda)
#         for d_i in range(nmb):
#             for d_j in range(d_i + 1, nmb):
#                 f_i= features[ d == d_i ]
#                 f_j= features[ d == d_j ]
#                 penalty += self.mmd(f_i, f_j)
#         return penalty
        
#     def train(self):
        
#         self.max_epoch=-1
#         self.max_val_acc=0.0
#         for epoch in range(self.args.epochs):   
                    
#             penalty_erm=0
#             penalty_mmd=0
#             train_acc= 0.0
#             train_size=0
                    
#             #Batch iteration over single epoch
#             for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
#         #         print('Batch Idx: ', batch_idx)

#                 self.opt.zero_grad()
#                 loss_e= torch.tensor(0.0).to(self.cuda)
                
#                 x_e= x_e.to(self.cuda)
#                 y_e= torch.argmax(y_e, dim=1).to(self.cuda)
#                 d_e= torch.argmax(d_e, dim=1)
                
#                 #Forward Pass
#                 features = self.featurizer(x_e)
#                 out = self.classifier(features)                
                
#                 #ERM
#                 erm_loss= F.cross_entropy(out, y_e.long()).to(self.cuda)
#                 loss_e+= erm_loss
#                 penalty_erm += float(loss_e)
                
#                 #MMD
#                 mmd_loss=torch.tensor(0.0).to(self.cuda)
#                 match_domains= torch.unique(d_e)
#                 class_labels= torch.unique(y_e)
#                 nmb = len(match_domains)

#                 if self.conditional:
#                     for y_c in range(len(class_labels)):                    
#                         features_c= features[ y_e == y_c ]
#                         d_c= d_e[ y_e == y_c ]
#                         if len(torch.unique(d_c)) != nmb:
#                             print('*********************************')
#                             print('Error: Some classes not distributed across all the domains; issues for class conditional methods')
#                             continue
#                         mmd_loss+= self.mmd_regularization(features_c, d_c, nmb)
#                 else:
#                     mmd_loss+= self.mmd_regularization(features, d_e, nmb)            

#                 if nmb > 1:
#                     mmd_loss /= (nmb * (nmb - 1) / 2)
                                
#                 penalty_mmd+= float(mmd_loss)
                
#                 #Backward Pass
#                 loss_e+= self.mmd_gamma*mmd_loss                
#                 loss_e.backward(retain_graph=False)
#                 self.opt.step()
                
#                 del erm_loss
#                 del mmd_loss 
#                 del loss_e
#                 torch.cuda.empty_cache()
                
#                 train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
#                 train_size+= y_e.shape[0]                
                        
   
#             print('Train Loss Basic : ',  penalty_erm, penalty_mmd )
#             print('Train Acc Env : ', 100*train_acc/train_size )
#             print('Done Training for epoch: ', epoch)
            
#             #Train Dataset Accuracy
#             self.train_acc.append( 100*train_acc/train_size )
            
#             #Val Dataset Accuracy
#             self.val_acc.append( self.get_test_accuracy('val') )
            
#             #Test Dataset Accuracy
#             self.final_acc.append( self.get_test_accuracy('test') )
            
#             #Save the model if current best epoch as per validation loss
#             if self.val_acc[-1] > self.max_val_acc:
#                 self.max_val_acc=self.val_acc[-1]
#                 self.max_epoch= epoch
#                 self.save_model()
                                
#             print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])