import sys
import numpy as np
import argparse
import copy
import random
import json
import time

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
import torch.autograd as autograd

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity,get_dataloader

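###########  With a ResNet stripped of its fc layer as the feature extractor, the robustdg
###########  approach yields features of shape [batch_size, 512], whereas the Fish approach
###########  yields [batch_size, 512, 1, 1].
##########
##########  dann_inverse: treat the domain as the label and the label as the domain, so the
##########  model learns to extract domain features.
#########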

class DANNIN(BaseAlgo):
    """Inverse DANN ("dann_inverse").

    The roles of class labels and domains are swapped relative to standard
    DANN:

    * featurizer + classifier are trained to predict the *domain* ``d``;
    * the discriminator is trained to predict the *class label* ``y`` from the
      features, while the featurizer is pushed (through the
      ``classifier_loss - lambda_ * disc_loss`` objective) to make ``y`` hard
      to predict.
    """

    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda)

        # If set, the discriminator input is conditioned on the domain (the
        # "label" in this inverse setting) through a learned embedding.
        self.conditional = bool(self.args.conditional)
        # Re-weighting of the discriminator loss by domain frequency (disabled).
        self.class_balance = False

        # Sub-modules of the model built by BaseAlgo.
        self.featurizer = self.phi.featurizer
        self.classifier = self.phi.classifier
        self.discriminator = self.phi.discriminator
        self.class_embeddings = self.phi.class_embeddings

        # NOTE(review): grad_penalty is stored but never applied in train();
        # no gradient penalty on the discriminator input is computed here.
        self.grad_penalty = self.args.grad_penalty
        self.lambda_ = self.args.penalty_ws
        self.d_steps_per_g_step = self.args.d_steps_per_g_step
        self.initial_lr = 0.01

        # Discriminator side: discriminator + conditional embeddings.
        self.disc_opt = torch.optim.SGD(
            list(self.discriminator.parameters()) + list(self.class_embeddings.parameters()),
            lr=self.initial_lr,
            weight_decay=5e-4,
        )
        # Generator side: featurizer + (domain) classifier.
        self.gen_opt = torch.optim.SGD(
            list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            lr=self.initial_lr,
            weight_decay=5e-4,
        )

        # Per-epoch test accuracy, kept apart from val_acc so that model
        # selection never looks at test data. Reuse the base list if present.
        if not hasattr(self, 'final_acc'):
            self.final_acc = []

    def get_test_accuracy(self, case):
        """Return domain-prediction accuracy (percent) on the val or test split.

        The classifier predicts the domain here, so predictions are compared
        against ``argmax(d_e)``. Featurizer and classifier are evaluated in
        eval mode and their previous train/eval modes are restored afterwards.

        Args:
            case: ``'val'`` or ``'test'``.

        Returns:
            Accuracy in percent (0.0 if the split yields no samples).

        Raises:
            ValueError: if ``case`` is neither ``'val'`` nor ``'test'``.
        """
        if case == 'val':
            dataset = self.val_dataset
        elif case == 'test':
            dataset = self.test_dataset
        else:
            raise ValueError(f"case must be 'val' or 'test', got {case!r}")

        # Eval mode: BatchNorm uses running stats (and is not updated with
        # val/test batches), Dropout is disabled. train() never switches back
        # to train mode itself, so the previous modes are restored below.
        modules = (self.featurizer, self.classifier)
        prev_modes = [module.training for module in modules]
        for module in modules:
            module.eval()

        correct = 0.0
        total = 0
        try:
            # Clear leftover gradients before evaluating (they are not
            # modified inside no_grad, so one call covers the whole loop).
            self.opt.zero_grad()
            with torch.no_grad():
                for x_e, _, d_e, _, _ in dataset:
                    x_e = x_e.to(self.cuda)
                    d_e = torch.argmax(d_e, dim=1).to(self.cuda)

                    out = self.classifier(self.featurizer(x_e))

                    correct += torch.sum(torch.argmax(out, dim=1) == d_e).item()
                    total += d_e.shape[0]
        finally:
            for module, mode in zip(modules, prev_modes):
                module.train(mode)

        accuracy = 100 * correct / total if total else 0.0
        print(' Accuracy: ', case, accuracy)
        return accuracy

    def _batch_losses(self, x_e, y_e, d_e):
        """Return ``(classifier_loss, disc_loss)`` for one batch.

        ``d_e`` (domain indices) are the classifier targets, ``y_e`` (class
        indices) the discriminator targets. All intermediate activations are
        local, so the autograd graph is released once the returned losses
        are dropped.
        """
        all_z = self.featurizer(x_e)
        if self.conditional:
            disc_input = all_z + self.class_embeddings(d_e)
        else:
            disc_input = all_z
        disc_out = self.discriminator(disc_input)

        if self.class_balance:
            # Weight each sample by the inverse frequency of its domain.
            d_counts = F.one_hot(d_e).sum(dim=0)
            weights = 1. / (d_counts[d_e] * d_counts.shape[0]).float()
            per_sample = F.cross_entropy(disc_out, y_e, reduction='none')
            disc_loss = (weights * per_sample).sum()
        else:
            disc_loss = F.cross_entropy(disc_out, y_e)

        classifier_loss = F.cross_entropy(self.classifier(all_z), d_e)
        return classifier_loss, disc_loss

    def train(self, num_epochs=20):
        """Train with epoch-wise alternating discriminator/generator updates.

        In epochs where ``epoch % (1 + d_steps_per_g_step) <
        d_steps_per_g_step`` only the discriminator and class embeddings are
        updated on ``disc_loss``; in the other epochs only the featurizer and
        classifier are updated on ``classifier_loss - lambda_ * disc_loss``.
        After every epoch val and test accuracy are recorded (in ``val_acc``
        and ``final_acc``) and the model is saved when *validation* accuracy
        improves or reaches 100.

        Args:
            num_epochs: number of training epochs (default 20).
        """
        self.max_epoch = -1
        self.max_val_acc = 0.0
        d_steps_per_g = self.d_steps_per_g_step

        for epoch in range(num_epochs):
            penalty_erm = 0
            penalty_dann = 0
            train_acc = 0.0
            train_size = 0

            # NOTE(review): the D/G alternation is keyed on the epoch index,
            # not the optimisation step, so whole epochs update only one side.
            disc_epoch = epoch % (1 + d_steps_per_g) < d_steps_per_g

            for x_e, y_e, d_e, _, _ in self.train_dataset:
                x_e = x_e.to(self.cuda)
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)
                d_e = torch.argmax(d_e, dim=1).to(self.cuda)

                classifier_loss, disc_loss = self._batch_losses(x_e, y_e, d_e)
                penalty_erm += float(classifier_loss)
                penalty_dann += float(disc_loss)

                if disc_epoch:
                    self.disc_opt.zero_grad()
                    disc_loss.backward()
                    self.disc_opt.step()
                else:
                    # Minimising -disc_loss w.r.t. the featurizer makes the
                    # class label harder to predict from the features.
                    gen_loss = classifier_loss + self.lambda_ * -disc_loss
                    self.disc_opt.zero_grad()
                    self.gen_opt.zero_grad()
                    gen_loss.backward()
                    self.gen_opt.step()
                    del gen_loss

                # Drop the last references to this batch's graph before
                # releasing cached CUDA memory.
                del classifier_loss, disc_loss
                torch.cuda.empty_cache()

                # Post-update training accuracy; no graph is needed here.
                with torch.no_grad():
                    out = self.classifier(self.featurizer(x_e))
                train_acc += torch.sum(torch.argmax(out, dim=1) == d_e).item()
                train_size += d_e.shape[0]

            epoch_train_acc = 100 * train_acc / train_size if train_size else 0.0
            print('Train Loss Basic : ', penalty_erm, penalty_dann)
            print('Train Acc Env : ', epoch_train_acc)
            print('Done Training for epoch: ', epoch)

            self.train_acc.append(epoch_train_acc)

            # Model selection must use validation accuracy only; test
            # accuracy is recorded separately and never drives checkpointing.
            val_acc = self.get_test_accuracy('val')
            self.val_acc.append(val_acc)
            self.final_acc.append(self.get_test_accuracy('test'))

            if val_acc > self.max_val_acc or val_acc == 100:
                self.max_val_acc = val_acc
                self.max_epoch = epoch
                self.save_model()