import copy
import math
import torch
from torch import nn, autograd
import utils
from torch.nn.utils import parameters_to_vector
from torch.utils.data import DataLoader,RandomSampler
import torch.nn as nn
import numpy as np
import logging
from utils import KL_between_normals,product_of_experts_two
from scipy.stats import bernoulli
import cv2
import random
from math import floor
from torch.autograd import Variable
import models
from mask_filter import Masked_BN2d
from collections import OrderedDict
import os

# Maximum total gradient L2 norm passed to torch.nn.utils.clip_grad_norm_ in Agent.local_train.
max_norm = 10

class Agent:
    """Federated-learning client.

    Each agent wraps a local split of the training set (handed to
    ``utils.poison_dataset`` when ``id < args.num_corrupt``), a running
    per-class estimate of the latent ``Z`` distribution (``dir_Z_u`` /
    ``dir_Z_sigma``) and helpers for local training, neuron-mask fine-tuning,
    evaluation and neuron pruning.
    """

    def __init__(self, model, id, args, train_dataset=None, data_idxs=None, mask=None, val_dataset=None):
        """Build the agent's data loaders and per-class latent buffers.

        Args:
            model: local model; ``self.W`` maps its parameter names to the
                parameter tensors themselves (not copies).
            id: agent index; agents with ``id < args.num_corrupt`` are malicious.
            args: experiment configuration namespace.
            train_dataset: training set the agent's split is taken from. For
                malicious agents ``utils.poison_dataset`` receives
                ``DatasetSplit(...).dataset`` -- presumably this same object,
                so poisoning may be visible to others sharing it (TODO confirm).
            data_idxs: indices of ``train_dataset`` owned by this agent.
            mask: optional mask stored as ``self.mask`` (read by
                ``high_l2_layers``; replaced by neurotoxin attackers).
            val_dataset: optional dataset wrapped in ``self.poisoned_val_loader``.
        """
        self.id = id
        self.args = args
        self.error = 0
        self.data_idxs = data_idxs
        self.model = model
        self.W = dict(self.model.named_parameters())
        self.n_cls = 100 if self.args.data == "cifar100" else 10

        if val_dataset is not None:
            self.poisoned_val_loader = DataLoader(val_dataset, batch_size=self.args.bs, shuffle=False,
                                                  num_workers=args.num_workers, pin_memory=False)

        # Per-class latent estimate, initialised to zeros (u) / ones (sigma). local_train
        # treats a class whose entry is still exactly (0, 1) as "not observed yet".
        self.dir_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
        self.dir_Z_sigma = torch.ones(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)

        self.train_dataset = utils.DatasetSplit(train_dataset, data_idxs)
        if self.id < args.num_corrupt:
            # Poisoning happens in place, before the loader below is built.
            self.poison_idxs = utils.poison_dataset(self.train_dataset.dataset, args, data_idxs, agent_idx=self.id)

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.args.bs, shuffle=True,
                                       num_workers=args.num_workers, pin_memory=False, drop_last=True)
        # Keep a handle on the (possibly poisoned) loader so check_poison_timing can switch back to it.
        self.train_poison_loader = self.train_loader

        self.n_data = len(self.data_idxs)
        # Previously never assigned, which made every read of self.mask raise AttributeError.
        self.mask = mask
        self.num_remove = None

        if self.args.attack == "DBA":
            self.dba_class = self.id % 4

    def check_poison_timing(self, round):
        """Select the clean or poisoned training loader for communication ``round``.

        A clean loader is used when ``round > args.cease_poison`` or ``round < 6``;
        the poisoned loader is used when ``round >= 7`` (otherwise). At round 6
        the current ``self.train_loader`` is left unchanged.

        Raises:
            AttributeError: if a clean loader is needed but ``self.clean_backup_dataset``
                (an un-poisoned copy of the training set) was never set; ``__init__``
                does not create it.
        """
        if round > self.args.cease_poison or round < 6:
            clean_backup = getattr(self, "clean_backup_dataset", None)
            if clean_backup is None:
                raise AttributeError(
                    "check_poison_timing needs self.clean_backup_dataset, "
                    "an un-poisoned copy of the training set, which __init__ does not create")
            clean_split = utils.DatasetSplit(clean_backup, self.data_idxs)
            self.train_loader = DataLoader(clean_split, batch_size=self.args.bs, shuffle=True,
                                           num_workers=self.args.num_workers, pin_memory=False, drop_last=True)
        elif round >= 7:
            # A DataLoader can be iterated repeatedly; sharing it avoids deep-copying its dataset.
            self.train_loader = self.train_poison_loader

    def test(self, model, device, testloader, num_avg=32):
        """Return the top-1 accuracy (percent) of ``model`` on ``testloader``.

        ``model(images, num_avg)`` must return a 3-tuple whose second element is
        the decoder logits; predictions come from their mean over dim 0 (the
        ``num_avg`` samples). Accuracy is weighted by samples, so a short last
        batch counts proportionally. Leaves ``model`` in eval mode on ``device``.

        Returns:
            ``100 * correct / total``, or NaN if the loader yields no samples.
        """
        model.eval()
        model.to(device)
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                _, decoder_logits, _ = model(images, num_avg)
                pred_labels = torch.max(torch.mean(decoder_logits, dim=0), dim=1)[1]
                correct += (pred_labels == labels).sum().item()
                total += labels.numel()
        if total == 0:
            return float("nan")
        return 100.0 * correct / total

    def high_l2_layers(self, model):
        """Return names of ``weight`` parameters whose L2 norm exceeds the median.

        The median is the element at index ``len // 2`` of the norms sorted in
        descending order. Names are returned in ``named_parameters`` order.
        Also records ``self.model_depth = len(self.mask)`` when a mask is set.
        """
        mask = getattr(self, "mask", None)
        if mask is not None:
            self.model_depth = len(mask)

        l2_norms = [(name, param.norm(2).item()) for name, param in model.named_parameters() if "weight" in name]
        if not l2_norms:
            return []

        sorted_norms = sorted((norm for _, norm in l2_norms), reverse=True)
        median_l2_norm = sorted_norms[len(sorted_norms) // 2]
        # Each name is compared against its OWN norm (the old code zipped names with the sorted norms).
        return [name for name, norm in l2_norms if norm > median_l2_norm]

    def local_train(self, global_model, criterion, round=None, temparature=10, alpha=0.3, neurotoxin_mask=None,
                    updates_dict=None, last=False, server=None, poi_loader=None):
        """Train ``global_model`` in place on the local data and return the update.

        For benign agents (``id >= args.num_corrupt``) with ``round > 0``, the
        model is first fine-tuned for ``round`` epochs by ``_mask_finetune``.
        Then ``args.local_ep`` epochs of SGD follow. Benign agents minimise
        cross-entropy + ``beta`` * KL(prior || encoder) + ``L2R`` * regulariser;
        malicious agents minimise cross-entropy only.

        Args:
            global_model: model trained in place.
            criterion: unused by the training loop (kept for interface compatibility).
            round: communication round; ``None`` is treated as 0. Sets the learning
                rate ``client_lr * lr_decay ** round``.
            temparature, alpha, updates_dict, last, poi_loader: unused.
            neurotoxin_mask: for neurotoxin attackers, replaces ``self.mask`` when non-empty.
            server: must expose ``dir_global_Z_u`` / ``dir_global_Z_sigma`` indexable
                by class label (required for benign agents).

        Returns:
            Flat tensor ``state_dict(after) - state_dict(before)``, also stored in
            ``self.update``.

        NOTE(review): ``global_model.train()`` is only called on the mask fine-tune
        path; otherwise the caller must put the model in train mode.
        """
        if round is None:
            round = 0
        device = self.args.device
        initial_global_model_params = parameters_to_vector(list(global_model.state_dict().values())).detach()

        is_benign = self.id >= self.args.num_corrupt
        if round > 0 and is_benign:
            self._mask_finetune(global_model, nb_iterations=round)

        lr = self.args.client_lr * (self.args.lr_decay) ** round
        optimizer = torch.optim.SGD(global_model.parameters(), lr=lr, weight_decay=self.args.wd)
        ce_loss = nn.CrossEntropyLoss(reduction='none')
        last_labels, last_encoder_Z_distr = None, None

        for _ in range(self.args.local_ep):
            for inputs, labels in self.train_loader:
                optimizer.zero_grad()
                inputs, labels = inputs.to(device=device), labels.to(device=device)

                encoder_Z_distr, decoder_logits, regL2R = global_model(inputs, self.args.num_avg_train)
                # (samples, batch, classes) -> (batch, classes, samples): CrossEntropyLoss's K-dim layout.
                decoder_logits = decoder_logits.permute(1, 2, 0)
                cross_entropy_loss = ce_loss(decoder_logits, labels[:, None].expand(-1, self.args.num_avg_train))
                cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1)
                minusI_ZY_bound = torch.mean(cross_entropy_loss_montecarlo, dim=0)

                if is_benign:
                    prior_Z_distr = self._gather_class_prior(server, labels)
                    I_ZX_bound = torch.mean(KL_between_normals(prior_Z_distr, encoder_Z_distr))
                    regL2R = regL2R / len(labels)
                    total_loss = torch.mean(minusI_ZY_bound + self.args.beta * I_ZX_bound + self.args.L2R * regL2R)
                else:
                    total_loss = minusI_ZY_bound

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters=global_model.parameters(), max_norm=max_norm)
                optimizer.step()
                last_labels, last_encoder_Z_distr = labels, encoder_Z_distr

        # Fold the encoder outputs of the final batch of the final epoch into the per-class estimate.
        if last_labels is not None and self.n_data > self.args.bs:
            self._update_class_distributions(last_labels, last_encoder_Z_distr)

        if self.id < self.args.num_corrupt and self.args.attack == "neurotoxin":
            if neurotoxin_mask is not None and len(neurotoxin_mask):
                self.mask = neurotoxin_mask

        with torch.no_grad():
            after_train = parameters_to_vector(list(global_model.state_dict().values())).detach()
            self.update = after_train - initial_global_model_params

        del after_train, initial_global_model_params
        torch.cuda.empty_cache()
        import gc
        gc.collect()

        return self.update

    def _mask_finetune(self, model, nb_iterations):
        """Run ``nb_iterations`` NFT_Train epochs (SGD, momentum 0.95, cosine LR down to 0.02)."""
        device = self.args.device
        model.to(device)
        model.train()
        criterion = nn.CrossEntropyLoss().to(device)
        # All parameters are optimised, not just the neuron masks.
        mask_optimizer = torch.optim.SGD(model.parameters(), lr=self.args.client_lr, momentum=0.95)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(mask_optimizer, nb_iterations, 0.02)
        for _ in range(nb_iterations):
            self.NFT_Train(model=model, criterion=criterion, data_loader=self.train_loader, mask_opt=mask_optimizer)
            scheduler.step()

    def _gather_class_prior(self, server, labels):
        """Return (u, sigma): the server's per-class entries for each label, concatenated along dim 0."""
        u_table, sigma_table = server.dir_global_Z_u, server.dir_global_Z_sigma
        # Equivalent to torch.cat([table[l] for l in labels], 0), in a single gather.
        prior_u = u_table[labels.to(u_table.device)].flatten(0, 1).detach()
        prior_sigma = sigma_table[labels.to(sigma_table.device)].flatten(0, 1).detach()
        return prior_u, prior_sigma

    def _update_class_distributions(self, labels, encoder_Z_distr):
        """Merge each sample's encoder (u, sigma) into ``dir_Z_u``/``dir_Z_sigma`` for its label.

        Unobserved classes (still exactly zeros/ones) are overwritten. Otherwise
        the stored pair is combined with ``product_of_experts_two``. Samples are
        processed in order, because each merge depends on the previous one.
        """
        zeros = torch.zeros(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
        ones = torch.ones(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
        enc_u, enc_sigma = encoder_Z_distr[0], encoder_Z_distr[1]
        for cls, label in enumerate(labels):
            sample_distr = enc_u[cls].clone().detach(), enc_sigma[cls].clone().detach()
            if self.dir_Z_u[label].equal(zeros) and self.dir_Z_sigma[label].equal(ones):
                self.dir_Z_u[label], self.dir_Z_sigma[label] = sample_distr
            else:
                q_distr = self.dir_Z_u[label], self.dir_Z_sigma[label]
                self.dir_Z_u[label], self.dir_Z_sigma[label] = product_of_experts_two(q_distr, sample_distr)

    def synchronize_with_server(self, global_model_rnd, w_glob_keys):
        """Copy the global model into this client.

        For methods other than fedavg/ditto, only parameters listed in
        ``w_glob_keys`` are copied into ``self.W``: representation layers are
        shared, and the final layer is not distributed. For fedavg/ditto the
        whole model is deep-copied into ``self.model``.

        NOTE(review): in the fedavg/ditto branch ``self.W`` refers to
        ``global_model_rnd``'s parameters, not to the copy in ``self.model``.
        """
        global_w_rnd = dict(global_model_rnd.named_parameters())
        if self.args.method != 'fedavg' and self.args.method != 'ditto':
            for name in self.W:
                if name in w_glob_keys:
                    self.W[name].data = global_w_rnd[name].data.clone()
        else:
            self.model = copy.deepcopy(global_model_rnd)
            self.W = global_w_rnd

    def NFT_Train(self, model, criterion, mask_opt, data_loader):
        """Run one epoch of mixup training with an L1 penalty that pulls neuron masks toward 1.

        The loss is ``mixup_loss + 0.001 * L1 / 50``. After each step the masks
        are clamped via ``mask_clip``.

        Returns:
            ``(mean loss per batch, mixup-weighted accuracy in [0, 1])``; ``(0.0, 0.0)``
            if the loader is empty.
        """
        model.train()
        device = self.args.device
        total_loss, correct, nb_samples = 0.0, 0.0, 0

        for images, labels in data_loader:
            nb_samples += images.size(0)
            inputs, targets = images.to(device), labels.to(device)

            # Clean forward pass. Its output is unused, but it is kept because in train mode it can
            # update normalisation statistics / consume RNG if the model does so. no_grad avoids
            # building a graph that is never backpropagated.
            with torch.no_grad():
                model(inputs, self.args.num_avg_train)

            mixed_inputs, targets_a, targets_b, lam = self.mixup_data(inputs, targets, alpha=1)
            _, decoder_logits, _ = model(mixed_inputs, self.args.num_avg_train)
            decoder_logits_mean = torch.mean(decoder_logits, dim=0)
            loss = self.mixup_criterion(criterion, decoder_logits_mean, targets_a, targets_b, lam)

            mask_opt.zero_grad()
            L1, _, _ = self.Regularization(model)
            tot_loss = loss + 0.001 * L1 / 50
            tot_loss.backward()
            mask_opt.step()
            self.mask_clip(model)

            predicted = torch.max(decoder_logits_mean, dim=1)[1]
            correct += (lam * predicted.eq(targets_a).cpu().sum().float()
                        + (1 - lam) * predicted.eq(targets_b).cpu().sum().float())
            total_loss += tot_loss.item()

        if nb_samples == 0:
            return 0.0, 0.0
        return total_loss / len(data_loader), float(correct) / nb_samples

    def mixup_data(self, x, y, alpha=1.0):
        """Return ``(mixed_x, y_a, y_b, lam)`` with ``lam ~ Beta(alpha, alpha)`` (``1`` if ``alpha <= 0``)."""
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
        index = torch.randperm(x.size()[0]).to(device=self.args.device)
        mixed_x = lam * x + (1 - lam) * x[index, :]
        return mixed_x, y, y[index], lam

    def mixup_criterion(self, criterion, pred, y_a, y_b, lam):
        """Return ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

    def mask_clip(self, model, upper=1):
        """Clamp every ``neuron_mask`` parameter in place.

        The k-th mask (1-based, in ``named_parameters`` order) is clamped to
        ``[alpha_pru * exp(-beta_pru * k), upper]``.
        """
        params = [param for name, param in model.named_parameters() if 'neuron_mask' in name]
        with torch.no_grad():
            for layer_no, param in enumerate(params, start=1):
                param.clamp_(self.args.alpha_pru * math.exp(-self.args.beta_pru * layer_no), upper)

    def Regularization(self, model):
        """Return ``(sum |1 - m|, sum ||m||_2, sum max|1 - m|)`` over all ``neuron_mask`` parameters ``m``.

        Each term is the integer ``0`` when the model has no neuron masks.
        """
        L1 = 0
        L2 = 0
        L_inf = 0
        for name, param in model.named_parameters():
            if 'neuron_mask' in name:
                L1 += torch.sum(torch.abs(1 - param))
                L2 += torch.norm(param, 2)
                L_inf += torch.max(torch.abs(1 - param))
        return L1, L2, L_inf

    def load_model(self, net, orig_state_dict):
        """Load ``orig_state_dict`` into ``net``, tolerating extra/missing keys.

        Up to two levels of ``'state_dict'`` wrapping are unwrapped. Keys that
        ``net`` has but the checkpoint lacks are filled as follows:
        ``*_noisy`` BN buffers are copied from the key with the 6-character
        ``_noisy`` suffix removed, and anything else keeps ``net``'s current value.
        """
        for _ in range(2):
            if 'state_dict' in orig_state_dict.keys():
                orig_state_dict = orig_state_dict['state_dict']

        new_state_dict = OrderedDict()
        for k, v in net.state_dict().items():
            if k in orig_state_dict.keys():
                new_state_dict[k] = orig_state_dict[k]
            elif 'running_mean_noisy' in k or 'running_var_noisy' in k or 'num_batches_tracked_noisy' in k:
                new_state_dict[k] = orig_state_dict[k[:-6]].clone().detach()
            else:
                new_state_dict[k] = v

        net.load_state_dict(new_state_dict)

    def split_dataset(self, dataset, val_frac=0.1, perm=None):
        """Split ``dataset`` into (train, validation) deep copies.

        :param dataset: dataset with array-indexable ``.data`` and a ``.targets`` sequence.
        :param val_frac: fraction of samples placed in the validation set.
        :param perm: predefined permutation of indices; a random one is drawn when None.
        :return: ``(train_set, val_set)``
        """
        if perm is None:
            perm = np.arange(len(dataset))
            np.random.shuffle(perm)
        nb_val = int(val_frac * len(dataset))

        train_set = copy.deepcopy(dataset)
        train_set.data = train_set.data[perm[nb_val:]]
        train_set.targets = np.array(train_set.targets)[perm[nb_val:]].tolist()

        val_set = copy.deepcopy(dataset)
        val_set.data = val_set.data[perm[:nb_val]]
        val_set.targets = np.array(val_set.targets)[perm[:nb_val]].tolist()

        return train_set, val_set

    def save_mask_scores(self, state_dict, file_name):
        """Write one tab-separated row per neuron of every ``neuron_mask`` entry to ``file_name``.

        Columns: running number, layer name (key minus its last component),
        neuron index, and mask score formatted to 4 decimals.
        """
        mask_values = []
        count = 0
        for name, param in state_dict.items():
            if 'neuron_mask' not in name:
                continue
            neuron_name = '.'.join(name.split('.')[:-1])
            for idx in range(param.size(0)):
                mask_values.append('{} \t {} \t {} \t {:.4f} \n'.format(count, neuron_name, idx, param[idx].item()))
                count += 1
        with open(file_name, "w") as f:
            f.write('No \t Layer Name \t Neuron Idx \t Mask Score \n')
            f.writelines(mask_values)

    def read_data(self, file_name):
        """Read a ``save_mask_scores`` file back as a list of ``(layer, idx, score)`` tuples."""
        import pandas as pd
        # Raw string: '\s' is an invalid escape sequence in a normal string literal.
        tempt = pd.read_csv(file_name, sep=r'\s+', skiprows=1, header=None)
        layer = tempt.iloc[:, 1]
        idx = tempt.iloc[:, 2]
        value = tempt.iloc[:, 3]
        return list(zip(layer, idx, value))

    def pruning(self, net, neuron):
        """Zero row ``neuron[1]`` of ``<neuron[0]>.weight`` in ``net``."""
        state_dict = net.state_dict()
        weight_name = '{}.{}'.format(neuron[0], 'weight')
        state_dict[weight_name][int(neuron[1])].zero_()
        net.load_state_dict(state_dict)

    def evaluate_by_threshold(self, model, mask_values, pruning_max, pruning_step, criterion, clean_loader,
                              poison_loader):
        """Prune neurons by increasing score threshold and report accuracy after each step.

        ``mask_values`` is assumed sorted ascending by score (element ``[2]``).
        For each threshold in ``0, pruning_step, ..., pruning_max``, every not-yet-pruned
        neuron with score < threshold is pruned. Poison and clean accuracy are then
        printed. ``criterion`` is unused.

        Returns:
            The printed report lines, one per threshold (empty if ``mask_values`` is empty).
        """
        results = []
        if not mask_values:
            return results
        thresholds = np.arange(0.0, pruning_max + pruning_step, pruning_step)
        start = 0
        for threshold in thresholds:
            while start < len(mask_values) and float(mask_values[start][2]) < threshold:
                self.pruning(model, mask_values[start])
                start += 1
            # Report the first un-pruned neuron (or the last one once everything is pruned).
            report = mask_values[min(start, len(mask_values) - 1)]
            layer_name, neuron_idx = report[0], report[1]
            cl_acc = self.test(model, self.args.device, clean_loader)
            po_acc = self.test(model, self.args.device, poison_loader)
            line = '{:.2f} \t {} \t {} \t {} \t {:.4f} \t {:.4f}'.format(
                start, layer_name, neuron_idx, threshold, po_acc, cl_acc)
            print(line)
            results.append(line)
        return results

    def evaluate_by_number(self, model, mask_values, pruning_max, pruning_step, criterion, clean_loader,
                           poison_loader, nb_step=100):
        """Prune neurons in fixed-size chunks and report accuracy after each chunk.

        Starting at 0 and stepping by ``nb_step`` up to ``int(pruning_max * len(mask_values))``,
        each step prunes the next ``nb_step`` entries of ``mask_values``, clipped to the
        list length. Poison and clean accuracy are then printed. ``pruning_step`` and
        ``criterion`` are unused.

        Returns:
            The printed report lines, one per step.
        """
        results = []
        mask_len = len(mask_values)
        prun_count = int(pruning_max * mask_len)
        for start in range(0, prun_count + 1, nb_step):
            end = min(start + nb_step, mask_len)
            if end <= start:
                break
            for i in range(start, end):
                self.pruning(model, mask_values[i])
            last = mask_values[end - 1]
            layer_name, neuron_idx, value = last[0], last[1], last[2]
            cl_acc = self.test(model, self.args.device, clean_loader)
            po_acc = self.test(model, self.args.device, poison_loader)
            line = '{:.2f} \t {} \t {} \t {} \t {:.4f} \t {:.4f}'.format(
                start, layer_name, neuron_idx, value, po_acc, cl_acc)
            print(line)
            results.append(line)
        return results