#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy

from utils import increase_loss_utils as baseline1_utils
import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
from torch import nn, autograd, optim
import torch.nn.functional as F
import numpy as np
import random
from sklearn import metrics
from utils import ada_hessain
from utils import increase_loss_utils


def extract_logits(output):
    """Pull the logits tensor out of a model's forward output.

    Supported forms, checked in this order:
      * a ``torch.Tensor`` -- returned unchanged;
      * any object exposing a ``logits`` attribute -- that attribute is returned;
      * a list or tuple -- its first element is returned.

    Raises:
        TypeError: if ``output`` matches none of the supported forms.
    """
    if not isinstance(output, torch.Tensor):
        if hasattr(output, "logits"):
            output = output.logits
        elif isinstance(output, (list, tuple)):
            output = output[0]
        else:
            raise TypeError(f"Unsupported model output type: {type(output)}")
    return output
from utils.forget_event import order_examples_of_forget
import os
from scipy.stats import truncnorm  

def generate_truncated_normal_random_numbers(mean, std_dev, size, low=0, high=1, random_state=None):
    """Draw samples from a normal distribution truncated to ``[low, high]``.

    Args:
        mean: mean of the underlying (untruncated) normal distribution.
        std_dev: standard deviation of the underlying normal distribution.
        size: number (or shape) of samples to draw.
        low: lower truncation bound (default 0, as before).
        high: upper truncation bound (default 1, as before).
        random_state: optional seed / ``numpy.random.Generator`` forwarded to
            ``rvs`` for reproducible draws; ``None`` uses numpy's global state.

    Returns:
        numpy array of samples, all within ``[low, high]``.
    """
    # truncnorm expects the bounds in units of standard deviations from the mean.
    a = (low - mean) / std_dev
    b = (high - mean) / std_dev
    distribution = truncnorm(a, b, loc=mean, scale=std_dev)
    return distribution.rvs(size=size, random_state=random_state)

def freeze_layers(model, layer_names):
    """Disable gradients for the given top-level layers of ``model`` (in place).

    A parameter is frozen when the first dotted component of its name (e.g.
    ``conv1`` in ``conv1.weight``) is one of the requested names. For every name
    of the form ``conv<suffix>`` the companion ``bn<suffix>`` layer is frozen
    as well.

    Args:
        model: module whose matching parameters get ``requires_grad = False``.
        layer_names: a single layer name or any iterable of names (list, tuple,
            set, ...). ``None`` or an empty collection freezes nothing.

    Returns:
        The same ``model`` object.
    """
    if not layer_names:
        return model
    if isinstance(layer_names, str):
        layer_names = [layer_names]

    requested = set(layer_names)
    # Assumes models name a conv's batch norm bnN when the conv is convN.
    companions = {'bn' + name[len('conv'):] for name in requested if name.startswith('conv')}
    to_freeze = requested | companions

    for name, param in model.named_parameters():
        if name.split('.')[0] in to_freeze:
            param.requires_grad = False
    return model



def clip_image(x):
    """Clamp every element of tensor ``x`` into the closed range [-1, 1]."""
    lower, upper = -1.0, 1.0
    return x.clamp(min=lower, max=upper)


def all2one_target_transform(x, attack_target=1):
    """Return a tensor shaped like ``x`` whose every entry equals ``attack_target``.

    Built as ``ones_like(x)`` scaled by ``attack_target``, so the result follows
    the usual tensor-times-scalar type promotion.
    """
    return torch.ones_like(x).mul(attack_target)

# from evaluate import test
def test(args, test_loader, net, example_stats=None):
    """Evaluate ``net`` on ``test_loader`` and optionally record forgetting statistics.

    Args:
        args: namespace; reads ``device`` and ``record_forget_event``.
        test_loader: loader yielding ``(data, target, index)`` batches; the
            elements of ``index`` are used (via ``.item()``) as keys into
            ``example_stats``.
        net: model to evaluate. Its output may be anything ``extract_logits``
            accepts (a tensor, an object with ``.logits``, or a tuple/list).
        example_stats: optional dict, updated in place when
            ``args.record_forget_event is True``. For each example id it holds
            ``[correctness flags (1/0), margins]`` with one entry appended per
            call, where margin = correct-class logit minus the largest other
            logit. Under key ``'acc'`` the second list collects the overall
            accuracy in percent.

    Returns:
        ``(mean of the per-batch losses, accuracy in [0, 1], example_stats)``.
        Both averages are 0.0 for an empty loader/dataset instead of raising.

    Note:
        ``net`` is left in eval mode; callers that keep training must call
        ``net.train()`` again.
    """
    net.eval()
    loss_fun = nn.CrossEntropyLoss()
    record = args.record_forget_event is True and example_stats is not None
    test_loss = 0.0
    correct = 0

    # Pure evaluation: no autograd graph is needed, and building one per batch
    # only keeps activations alive and wastes memory.
    with torch.no_grad():
        for data, target, index in test_loader:
            data = data.to(args.device).float()
            target = target.to(args.device).long()

            logits = extract_logits(net(data))

            test_loss += loss_fun(logits, target).item()
            pred = logits.max(1)[1]
            correct += pred.eq(target.view(-1)).sum().item()

            if not record:
                continue

            acc = pred == target
            for i, idx in enumerate(index):
                key = idx.item()
                output_correct_class = logits[i, target[i].item()]
                sorted_output, _ = torch.sort(logits[i, :])
                # Highest incorrect class: the 2nd largest logit when the
                # example is classified correctly, otherwise the largest.
                if acc[i]:
                    output_highest_incorrect_class = sorted_output[-2]
                else:
                    output_highest_incorrect_class = sorted_output[-1]
                margin = output_correct_class.item() - output_highest_incorrect_class.item()
                index_stats = example_stats.get(key, [[], []])
                index_stats[0].append(acc[i].sum().item())
                index_stats[1].append(margin)
                example_stats[key] = index_stats

    num_batches = len(test_loader)
    num_examples = len(test_loader.dataset)
    accuracy = correct / num_examples if num_examples else 0.0

    if record:
        index_stats = example_stats.get('acc', [[], []])
        index_stats[1].append(accuracy * 100)
        example_stats['acc'] = index_stats

    avg_loss = test_loss / num_batches if num_batches else 0.0
    return avg_loss, accuracy, example_stats

class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Item ``k`` of the split is item ``idxs[k]`` of the wrapped dataset, which
    must yield ``(image, label)`` pairs.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        sample = self.dataset[source_position]
        image, label = sample
        return image, label


def calculate_infonce(args, f_now, f_pos, f_neg):
    """InfoNCE loss of ``f_now`` against positive and negative prototypes.

        loss = -log( sum_pos exp(s_j / T) / sum_all exp(s_j / T) )

    where ``s_j`` is the cosine similarity (along dim 1) between ``f_now`` and
    prototype ``j`` and ``T = args.infoNCET``.

    The ratio is evaluated in log space with ``logsumexp``: the direct
    ``exp(s / T)`` form overflows float32 once ``1 / T`` exceeds ~88, turning
    the loss into ``inf / inf = nan``.

    Args:
        args: namespace; reads ``infoNCET`` (temperature). ``args.device`` is
            not needed: the result lives on the device of the inputs.
        f_now: anchor feature(s); broadcast against the prototypes by
            ``cosine_similarity`` -- presumably shape (1, d), confirm with callers.
        f_pos: positive prototypes, stacked along dim 0.
        f_neg: negative prototypes, stacked along dim 0.

    Returns:
        Tensor of shape (1,).
    """
    f_proto = torch.cat((f_pos, f_neg), dim=0)
    logits = torch.cosine_similarity(f_now, f_proto, dim=1) / args.infoNCET
    logits = logits.view(1, -1)
    # Positives come first in f_proto, so they occupy the leading columns.
    num_pos = f_pos.shape[0]
    log_sum_pos = torch.logsumexp(logits[:, :num_pos], dim=1)
    log_sum_all = torch.logsumexp(logits, dim=1)
    return log_sum_all - log_sum_pos


def hierarchical_info_loss(args, f_now, label, all_f, mean_f, all_global_protos_keys):
    """Prototype loss: an InfoNCE term plus an MSE pull toward the class mean.

    * Instance level: ``calculate_infonce`` between ``f_now`` and the first
      entry of ``all_f`` selected for ``label`` (positives) versus the
      concatenation of all other entries (negatives).
    * Cluster level: mean squared error between ``f_now`` and the selected
      ``mean_f`` entry reshaped to ``(1, -1)``; both sides are cast to float32.

    ``all_f`` / ``mean_f`` are indexed with boolean masks built from
    ``all_global_protos_keys`` compared against ``label.item()``, so they are
    presumably numpy (object) arrays aligned with the keys -- confirm with callers.

    Returns:
        The sum of both terms.
    """
    key = label.item()
    own_class = all_global_protos_keys == key
    other_classes = all_global_protos_keys != key

    positives = all_f[own_class][0].to(args.device)
    negatives = torch.cat(all_f[other_classes].tolist()).to(args.device)
    instance_term = calculate_infonce(args, f_now, positives, negatives)

    class_mean = mean_f[own_class][0].to(args.device).view(1, -1)
    cluster_term = F.mse_loss(f_now.to(torch.float32), class_mean.to(torch.float32))

    return instance_term + cluster_term



def agg_func(protos):
    """Collapse each label's list of prototype tensors into a single tensor.

    ``protos`` maps label -> list of tensors and is modified in place (every
    value is replaced) as well as returned.

    * A list with more than one tensor becomes the element-wise mean of their
      ``.data`` (so the result carries no autograd history).
    * A single-element list is replaced by that element itself.
    * An empty list raises ``IndexError``.
    """
    for label in list(protos):
        proto_list = protos[label]
        if len(proto_list) <= 1:
            protos[label] = proto_list[0]
            continue
        total = 0 * proto_list[0].data
        for proto in proto_list:
            total += proto.data
        protos[label] = total / len(proto_list)
    return protos


"""
from ada_hessian import AdaHessian
...
model = YourModel()
optimizer = AdaHessian(model.parameters())
...
for input, output in data:
  optimizer.zero_grad()
  loss = loss_function(output, model(input))
  loss.backward(create_graph=True)  # this is the important line! 🧐
  optimizer.step()
...
"""


class DomainClientUpdate_(object):
    """Plain local client training: SGD on cross-entropy for ``args.local_ep`` epochs."""

    def __init__(self, args, train_loader):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = train_loader

    def train(self, net):
        """Train ``net`` in place on this client's loader.

        The loader must yield ``(images, labels, idx)`` batches.

        Returns:
            ``(net, net.state_dict(), mean of the per-epoch mean batch losses)``.
        """
        args = self.args
        loader = self.ldr_train
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)

        per_epoch_losses = []
        for epoch in range(args.local_ep):
            torch.cuda.empty_cache()
            step_losses = []
            for step, (images, labels, _) in enumerate(loader):
                images = images.to(args.device)
                labels = labels.to(args.device)
                net.zero_grad()
                loss = self.loss_func(extract_logits(net(images)), labels)
                loss.backward()
                optimizer.step()
                if args.verbose and step % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, step * len(images), len(loader.dataset),
                        100. * step / len(loader), loss.item()))
                step_losses.append(loss.item())
            per_epoch_losses.append(sum(step_losses) / len(step_losses))

        return net, net.state_dict(), sum(per_epoch_losses) / len(per_epoch_losses)


class DomainClientUpdate_Hesian(object):
    """Local client training with the AdaHessian optimizer from ``utils.ada_hessain``."""

    def __init__(self, args, train_loader):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = train_loader

    def train(self, net):
        """Train ``net`` in place with AdaHessian for ``args.local_ep`` epochs.

        The loader must yield ``(images, labels, idx)`` batches.

        Returns:
            ``(net.state_dict(), mean of the per-epoch mean batch losses)``.
        """
        args = self.args
        loader = self.ldr_train
        net.train()
        optimizer = ada_hessain.AdaHessian(net.parameters())

        per_epoch_losses = []
        for epoch in range(args.local_ep):
            step_losses = []
            for step, (images, labels, _) in enumerate(loader):
                torch.cuda.empty_cache()
                images = images.to(args.device)
                labels = labels.to(args.device)
                optimizer.zero_grad()
                loss = self.loss_func(extract_logits(net(images)), labels)
                # AdaHessian is stepped after backward(create_graph=True), so the
                # graph of the gradients must be kept.
                loss.backward(create_graph=True)
                optimizer.step()
                if args.verbose and step % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, step * len(images), len(loader.dataset),
                        100. * step / len(loader), loss.item()))
                step_losses.append(loss.item())
            per_epoch_losses.append(sum(step_losses) / len(step_losses))
        return net.state_dict(), sum(per_epoch_losses) / len(per_epoch_losses)


class DomainClientUnlearningBl3(object):
    """Unlearning by projected gradient ascent on the forgetting client's data.

    Every step increases the cross-entropy on ``train_loader`` (the negated loss
    is minimised). Whenever ``baseline1_utils.get_distance(net, net_ref)``
    exceeds ``threshold``, ``net`` is projected back toward ``net_ref``.
    Training stops early once the distance to ``net_unlearning_client``
    exceeds ``distance_threshold``.
    """

    def __init__(self, args, train_loader, threshold):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = train_loader
        # Ball "radius" delta, compared against get_distance(net, net_ref).
        self.threshold = threshold
        # Early-stopping threshold tau on the distance to the unlearning client.
        self.distance_threshold = 20
        # Optional cap on steps per epoch; None means the whole loader.
        self.num_updates_in_epoch = None

    def _project_onto_ball(self, net, net_ref):
        """Rescale ``net - net_ref`` to length ``sqrt(threshold)`` and write it back into ``net``.

        The sqrt is consistent with ``get_distance`` returning a *squared* L2
        distance -- NOTE(review): confirm against baseline1_utils. Intended to
        be called under ``torch.no_grad()``. Returns the distance after projecting.
        """
        ref_vec = nn.utils.parameters_to_vector(net_ref.parameters())
        offset = nn.utils.parameters_to_vector(net.parameters()) - ref_vec
        offset = offset / torch.norm(offset) * np.sqrt(self.threshold)
        nn.utils.vector_to_parameters(ref_vec + offset, net.parameters())
        return baseline1_utils.get_distance(net, net_ref)

    def train(self, net, net_ref, net_unlearning_client):
        """Run projected gradient ascent on ``net`` in place and return its state_dict."""
        args = self.args
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
        net.train()
        clip_grad = args.clip_grad
        stop = False

        for epoch in range(args.num_local_unlearn_epochs):
            print('------------', epoch)
            if stop:
                break
            step_losses = []

            for step, (images, labels, _) in enumerate(self.ldr_train):
                images = images.to(args.device)
                labels = labels.to(args.device)
                net.zero_grad()
                loss = self.loss_func(extract_logits(net(images)), labels)

                # Gradient ascent: minimise the negated loss.
                (-loss).backward()
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), clip_grad)
                optimizer.step()

                with torch.no_grad():
                    distance = baseline1_utils.get_distance(net, net_ref)
                    print("/-/", distance, self.threshold)
                    if distance > self.threshold:
                        self._project_onto_ball(net, net_ref)

                dist_to_client = baseline1_utils.get_distance(net, net_unlearning_client)
                print('Distance from the unlearned model to client ', args.unlearning_client, ':', dist_to_client.item())

                if dist_to_client > self.distance_threshold:
                    stop = True
                    break
                if self.num_updates_in_epoch is not None and step >= self.num_updates_in_epoch:
                    break

                if args.verbose and step % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, step * len(images), len(self.ldr_train.dataset),
                        100. * step / len(self.ldr_train), loss.item()))
                step_losses.append(loss.item())

            if not step_losses:
                step_losses.append(0)
            print('Train Loss: {:.4f}'.format(sum(step_losses) / len(step_losses)))
        return net.state_dict()
    
class DomainClientUnlearning_maskil(object):
    """Projected gradient-ascent unlearning whose steps go through ``MaskedSGD``.

    ``MaskedSGD`` is not defined or imported in the visible part of this file;
    it is constructed with ``mask_dict`` and presumably restricts updates to
    masked parameters -- NOTE(review): confirm.

    Every step increases the cross-entropy on ``train_loader``. Whenever
    ``baseline1_utils.get_distance(net, net_ref)`` exceeds ``threshold``,
    ``net`` is projected back toward ``net_ref``. Training stops early once
    the distance to ``net_unlearning_client`` exceeds ``distance_threshold``.
    """

    def __init__(self, args, train_loader, threshold, mask_dict=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = train_loader
        # Ball "radius" delta, compared against get_distance(net, net_ref).
        self.threshold = threshold
        # Early-stopping threshold tau on the distance to the unlearning client.
        self.distance_threshold = 20
        # Optional cap on steps per epoch; None means the whole loader.
        self.num_updates_in_epoch = None
        self.mask_dict = mask_dict or {}

    def _project_onto_ball(self, net, net_ref):
        """Rescale ``net - net_ref`` to length ``sqrt(threshold)`` and write it back into ``net``.

        The sqrt is consistent with ``get_distance`` returning a *squared* L2
        distance -- NOTE(review): confirm against baseline1_utils. Intended to
        be called under ``torch.no_grad()``. Returns the distance after projecting.
        """
        ref_vec = nn.utils.parameters_to_vector(net_ref.parameters())
        offset = nn.utils.parameters_to_vector(net.parameters()) - ref_vec
        offset = offset / torch.norm(offset) * np.sqrt(self.threshold)
        nn.utils.vector_to_parameters(ref_vec + offset, net.parameters())
        return baseline1_utils.get_distance(net, net_ref)

    def train(self, net, net_ref, net_unlearning_client):
        """Run masked projected gradient ascent on ``net`` in place and return its state_dict."""
        args = self.args
        optimizer = MaskedSGD(
            params=net.parameters(),
            mask_dict=self.mask_dict,
            lr=args.lr,
            momentum=0.9,
        )
        net.train()
        clip_grad = args.clip_grad
        stop = False

        for epoch in range(args.num_local_unlearn_epochs):
            print('------------', epoch)
            if stop:
                break
            step_losses = []

            for step, (images, labels, _) in enumerate(self.ldr_train):
                images = images.to(args.device)
                labels = labels.to(args.device)
                net.zero_grad()
                loss = self.loss_func(extract_logits(net(images)), labels)

                # Gradient ascent: minimise the negated loss.
                (-loss).backward()
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), clip_grad)
                optimizer.step()

                with torch.no_grad():
                    distance = baseline1_utils.get_distance(net, net_ref)
                    print("/-/", distance, self.threshold)
                    if distance > self.threshold:
                        self._project_onto_ball(net, net_ref)

                dist_to_client = baseline1_utils.get_distance(net, net_unlearning_client)
                print('Distance from the unlearned model to client ', args.unlearning_client, ':', dist_to_client.item())

                if dist_to_client > self.distance_threshold:
                    stop = True
                    break
                if self.num_updates_in_epoch is not None and step >= self.num_updates_in_epoch:
                    break

                if args.verbose and step % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, step * len(images), len(self.ldr_train.dataset),
                        100. * step / len(self.ldr_train), loss.item()))
                step_losses.append(loss.item())

            if not step_losses:
                step_losses.append(0)
            print('Train Loss: {:.4f}'.format(sum(step_losses) / len(step_losses)))
        return net.state_dict()
    
class DomainClientUnlearningBl3_kl(object):
    """Regularised projected gradient-ascent unlearning.

    Per step the minimised objective is::

        -CE(net(x), y)
        + alpha * sum_over_tensors ||p - p_ref||_2
        + beta  * KL(softmax(ref(x)) || softmax(net(x)))   (batchmean)

    where ``ref`` is a frozen copy of ``net_ref``. Whenever
    ``baseline1_utils.get_distance(net, ref)`` exceeds ``threshold``, ``net``
    is projected back toward ``ref``. Training stops early once the distance
    to ``net_unlearning_client`` exceeds ``distance_threshold``.
    """

    def __init__(self, args, train_loader, threshold):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = train_loader
        # Ball "radius" delta, compared against get_distance(net, ref).
        self.threshold = threshold
        # Early-stopping threshold tau on the distance to the unlearning client.
        self.distance_threshold = 20
        # Optional cap on steps per epoch; None means the whole loader.
        self.num_updates_in_epoch = None

    def train(self, net, net_ref, net_unlearning_client):
        """Unlearn in place on ``net`` and return its state_dict.

        Reads ``args.lr``, ``args.alpha``, ``args.beta``, ``args.clip_grad``,
        ``args.num_local_unlearn_epochs`` and ``args.device``.
        """
        args = self.args
        # Frozen snapshot of the reference model on the training device; it
        # anchors the parameter penalty, the KL target and the projection.
        fixed_ref = copy.deepcopy(net_ref).to(args.device)
        for p in fixed_ref.parameters():
            p.requires_grad = False

        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
        net.train()
        stop = False
        clip_grad = args.clip_grad

        for epoch in range(args.num_local_unlearn_epochs):
            print('------------', epoch)
            if stop:
                break

            for batch_idx, (images, labels, idx) in enumerate(self.ldr_train):
                images, labels = images.to(args.device), labels.to(args.device)
                net.zero_grad()

                log_probs = extract_logits(net(images))
                loss_ce = self.loss_func(log_probs, labels)
                loss_asc = -loss_ce  # gradient ascent on the forgetting data

                # Sum of per-tensor L2 norms (not a single global norm).
                param_diff = sum(
                    torch.norm(p - p_ref)
                    for p, p_ref in zip(net.parameters(), fixed_ref.parameters())
                )
                loss_param = args.alpha * param_diff

                with torch.no_grad():
                    # Unwrap the reference output exactly like net's, so models
                    # returning ``.logits`` objects or tuples also work here.
                    ref_logits = extract_logits(fixed_ref(images))
                loss_kl = args.beta * F.kl_div(
                    input=F.log_softmax(log_probs, dim=1),
                    target=F.softmax(ref_logits, dim=1),
                    reduction='batchmean'
                )

                total_loss = loss_asc + loss_param + loss_kl
                total_loss.backward()

                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), clip_grad)

                optimizer.step()

                with torch.no_grad():
                    distance = baseline1_utils.get_distance(net, fixed_ref)
                    if distance > self.threshold:
                        # Radius sqrt(threshold): consistent with get_distance
                        # returning a squared L2 distance -- NOTE(review): confirm.
                        ref_vec = nn.utils.parameters_to_vector(fixed_ref.parameters())
                        dist_vec = nn.utils.parameters_to_vector(net.parameters()) - ref_vec
                        dist_vec = dist_vec / torch.norm(dist_vec) * np.sqrt(self.threshold)
                        nn.utils.vector_to_parameters(ref_vec + dist_vec, net.parameters())

                    distance_ref_unlearning = baseline1_utils.get_distance(net, net_unlearning_client)

                if distance_ref_unlearning > self.distance_threshold:
                    stop = True
                    break

                # Honour the optional per-epoch step cap like the other unlearning clients.
                if self.num_updates_in_epoch is not None and batch_idx >= self.num_updates_in_epoch:
                    break

        return net.state_dict()

class DomainClientUnlearningBl3frz(object):
    """Projected gradient-ascent unlearning with optionally frozen layers.

    The layers listed in ``freeze_layers`` (plus each ``conv<N>``'s companion
    ``bn<N>``) get ``requires_grad = False`` on the passed-in ``net`` itself,
    which stays frozen after ``train`` returns. Only the remaining parameters
    are optimised, with learning rate ``args.frzulr``.
    """

    def __init__(self, args, train_loader, threshold, freeze_layers=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = train_loader
        # Ball "radius" delta, compared against get_distance(net, net_ref).
        self.threshold = threshold
        # Early-stopping threshold tau on the distance to the unlearning client.
        self.distance_threshold = 20
        # Optional cap on steps per epoch; None means the whole loader.
        self.num_updates_in_epoch = None
        self.freeze_layers = freeze_layers

    @staticmethod
    def _freeze_layers(model, layer_names):
        """Disable gradients (in place) for the named top-level layers of ``model``.

        ``layer_names`` may be a single name or any iterable of names (list,
        tuple, set, ...); ``None``/empty freezes nothing. A parameter matches
        when the first dotted component of its name is requested; every
        ``conv<suffix>`` also freezes ``bn<suffix>``. Returns ``model``.
        """
        if not layer_names:
            return model
        if isinstance(layer_names, str):
            layer_names = [layer_names]

        requested = set(layer_names)
        companions = {'bn' + name[len('conv'):] for name in requested if name.startswith('conv')}
        to_freeze = requested | companions

        for name, param in model.named_parameters():
            if name.split('.')[0] in to_freeze:
                param.requires_grad = False
        return model

    def train(self, net, net_ref, net_unlearning_client):
        """Run projected gradient ascent on the unfrozen part of ``net`` and return its state_dict."""
        if self.freeze_layers:
            net = self._freeze_layers(net, self.freeze_layers)

        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, net.parameters()),
            lr=self.args.frzulr,
            momentum=0.9
        )

        net.train()
        flag = False
        clip_grad = self.args.clip_grad

        for iter in range(self.args.num_local_unlearn_epochs):
            print('------------', iter)
            if flag:
                break
            batch_loss = []

            for batch_idx, (images, labels, idx) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                output = net(images)
                log_probs = extract_logits(output)
                loss = self.loss_func(log_probs, labels)

                loss_joint = -loss  # negate the loss for gradient ascent

                loss_joint.backward()
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), clip_grad)

                optimizer.step()

                with torch.no_grad():
                    distance = baseline1_utils.get_distance(net, net_ref)
                    print("/-/", distance, self.threshold)
                    # Radius sqrt(threshold): consistent with get_distance
                    # returning a squared L2 distance -- NOTE(review): confirm.
                    if distance > self.threshold:
                        dist_vec = nn.utils.parameters_to_vector(net.parameters()) - nn.utils.parameters_to_vector(
                            net_ref.parameters())
                        dist_vec = dist_vec / torch.norm(dist_vec) * np.sqrt(self.threshold)
                        proj_vec = nn.utils.parameters_to_vector(net_ref.parameters()) + dist_vec
                        nn.utils.vector_to_parameters(proj_vec, net.parameters())
                        distance = baseline1_utils.get_distance(net, net_ref)

                distance_ref_unlearning_client = baseline1_utils.get_distance(net, net_unlearning_client)
                print('Distance from the unlearned model to client ', self.args.unlearning_client, ':', distance_ref_unlearning_client.item())

                # Early stopping threshold tau: distance_threshold
                if distance_ref_unlearning_client > self.distance_threshold:
                    flag = True
                    break

                if self.num_updates_in_epoch is not None and batch_idx >= self.num_updates_in_epoch:
                    break

                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            if len(batch_loss) == 0:
                batch_loss.append(0)
            print('Train Loss: {:.4f}'.format(sum(batch_loss)/len(batch_loss)))
        return net.state_dict()
    
class DomainClientUnlearningBl3frz_kl(object):
    """Regularised projected gradient-ascent unlearning with optionally frozen layers.

    The layers listed in ``freeze_layers`` (plus each ``conv<N>``'s companion
    ``bn<N>``) get ``requires_grad = False`` on the passed-in ``net`` itself.
    The remaining parameters minimise, with learning rate ``args.frzulr``::

        -CE(net(x), y)
        + alpha * sum_over_tensors ||p - p_ref||_2
        + beta  * KL(softmax(ref(x)) || softmax(net(x)))   (batchmean)

    where ``ref`` is a frozen copy of ``net_ref`` on ``args.device``. Whenever
    the distance to ``ref`` exceeds ``threshold``, ``net`` is projected back
    toward it. Training stops early once the distance to
    ``net_unlearning_client`` exceeds ``distance_threshold``.
    """

    def __init__(self, args, train_loader, threshold, freeze_layers=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = train_loader
        # Ball "radius" delta, compared against get_distance(net, ref).
        self.threshold = threshold
        # Early-stopping threshold tau on the distance to the unlearning client.
        self.distance_threshold = 20
        # Optional cap on steps per epoch; None means the whole loader.
        self.num_updates_in_epoch = None
        self.freeze_layers = freeze_layers

    @staticmethod
    def _freeze_layers(model, layer_names):
        """Disable gradients (in place) for the named top-level layers of ``model``.

        ``layer_names`` may be a single name or any iterable of names (list,
        tuple, set, ...); ``None``/empty freezes nothing. A parameter matches
        when the first dotted component of its name is requested; every
        ``conv<suffix>`` also freezes ``bn<suffix>``. Returns ``model``.
        """
        if not layer_names:
            return model
        if isinstance(layer_names, str):
            layer_names = [layer_names]

        requested = set(layer_names)
        companions = {'bn' + name[len('conv'):] for name in requested if name.startswith('conv')}
        to_freeze = requested | companions

        for name, param in model.named_parameters():
            if name.split('.')[0] in to_freeze:
                param.requires_grad = False
        return model

    def train(self, net, net_ref, net_unlearning_client):
        """Unlearn in place on the unfrozen part of ``net`` and return its state_dict."""
        args = self.args
        if self.freeze_layers:
            net = self._freeze_layers(net, self.freeze_layers)

        # Frozen, device-resident snapshot of the reference: used for the
        # penalty, the KL target, the distance check and the projection, so
        # every vector operation happens on the same device as ``net``.
        global_ref = copy.deepcopy(net_ref).to(args.device)
        for p in global_ref.parameters():
            p.requires_grad = False

        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, net.parameters()),
            lr=args.frzulr,
            momentum=0.9
        )

        net.train()
        stop = False
        clip_grad = args.clip_grad

        for iter in range(args.num_local_unlearn_epochs):
            print('------------', iter)
            if stop:
                break

            for batch_idx, (images, labels, idx) in enumerate(self.ldr_train):
                images, labels = images.to(args.device), labels.to(args.device)
                net.zero_grad()

                log_probs = extract_logits(net(images))
                loss_ce = self.loss_func(log_probs, labels)

                # Sum of per-tensor L2 norms (not a single global norm).
                param_diff = sum(
                    torch.norm(p - gp)
                    for p, gp in zip(net.parameters(), global_ref.parameters())
                )
                loss_param = args.alpha * param_diff

                with torch.no_grad():
                    # Unwrap the reference output exactly like net's.
                    ref_logits = extract_logits(global_ref(images))
                loss_kl = args.beta * F.kl_div(
                    F.log_softmax(log_probs, dim=1),
                    F.softmax(ref_logits, dim=1),
                    reduction='batchmean'
                )

                # Gradient ascent on CE, regularised toward the reference.
                total_loss = -loss_ce + loss_param + loss_kl

                total_loss.backward()
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), clip_grad)
                optimizer.step()

                with torch.no_grad():
                    distance = baseline1_utils.get_distance(net, global_ref)
                    print("/-/", distance, self.threshold)
                    if distance > self.threshold:
                        # Radius sqrt(threshold): consistent with get_distance
                        # returning a squared L2 distance -- NOTE(review): confirm.
                        ref_vec = nn.utils.parameters_to_vector(global_ref.parameters())
                        dist_vec = nn.utils.parameters_to_vector(net.parameters()) - ref_vec
                        dist_vec = dist_vec / torch.norm(dist_vec) * np.sqrt(self.threshold)
                        nn.utils.vector_to_parameters(ref_vec + dist_vec, net.parameters())

                    distance_ref_unlearning = baseline1_utils.get_distance(net, net_unlearning_client)
                    print(f'Distance to client {args.unlearning_client}:', distance_ref_unlearning.item())

                if distance_ref_unlearning > self.distance_threshold:
                    stop = True
                    break

                # Honour the optional per-epoch step cap like the other unlearning clients.
                if self.num_updates_in_epoch is not None and batch_idx >= self.num_updates_in_epoch:
                    break

                if args.verbose and batch_idx % 10 == 0:
                    print(f'Update Epoch: {iter} [{batch_idx * len(images)}/{len(self.ldr_train.dataset)}'
                          f' ({100. * batch_idx / len(self.ldr_train):.0f}%)] Loss: {loss_ce.item():.6f}')

        return net.state_dict()

class DomainClientUpdate_record_org(object):
    """Local client training that can also record per-example forgetting statistics.

    When this client is ``args.unlearning_client`` and ``args.baseline == 5``,
    it trains on a deep copy of its loader whose ``dataset.labels`` are
    replaced by uniform random labels in ``[0, args.num_classes)``.
    """

    def __init__(self, args, train_loader, test_loader, example_stats_local, client):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.client = client

        relabel = self.args.unlearning_client == self.client and args.baseline == 5
        if not relabel:
            self.ldr_train = train_loader
        else:
            print("rl")
            self.ldr_train = copy.deepcopy(train_loader)
            original_labels = self.ldr_train.dataset.labels
            random_labels = np.random.randint(0, self.args.num_classes, original_labels.shape)
            self.ldr_train.dataset.labels = random_labels.astype(np.int64).squeeze()

        self.ldr_test = test_loader
        # Indexed as [0][client] (train stats) and [1][client] (test stats).
        self.example_stats_local = example_stats_local

    @staticmethod
    def _apply_grad_mask(net, mask):
        """Multiply every existing gradient in place by ``mask[parameter name]``."""
        for name, param in net.named_parameters():
            if param.grad is not None:
                param.grad *= mask[name]

    def train(self, net, mask=None, bk=True):
        """Train ``net`` in place for up to ``args.local_ep`` epochs.

        Args:
            net: model to train.
            mask: optional dict mapping parameter names to gradient multipliers.
            bk: when True, the unlearning client stops after its first epoch
                unless ``args.remove_sample is True``.

        Returns:
            ``(net, net.state_dict(), mean per-epoch loss, example_stats_local)``.
        """
        args = self.args
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)

        per_epoch_losses = []
        for epoch in range(args.local_ep):
            # Re-enter train mode: the evaluation below switches net to eval.
            net.train()
            torch.cuda.empty_cache()
            step_losses = []

            for step, (images, labels, _) in enumerate(self.ldr_train):
                images = images.to(args.device)
                labels = labels.to(args.device)
                net.zero_grad()
                loss = self.loss_func(extract_logits(net(images)), labels)
                loss.backward()

                if mask:
                    self._apply_grad_mask(net, mask)
                optimizer.step()

                if args.verbose and step % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, step * len(images), len(self.ldr_train.dataset),
                        100. * step / len(self.ldr_train), loss.item()))
                step_losses.append(loss.item())

            per_epoch_losses.append(sum(step_losses) / len(step_losses))

            if epoch == 0 and args.record_forget_event is True and self.example_stats_local is not None:
                # Record forgetting events on both the local train and test sets.
                _, _, self.example_stats_local[0][self.client] = test(
                    args=args, test_loader=self.ldr_train, net=net,
                    example_stats=self.example_stats_local[0][self.client])
                _, _, self.example_stats_local[1][self.client] = test(
                    args=args, test_loader=self.ldr_test, net=net,
                    example_stats=self.example_stats_local[1][self.client])

            if args.unlearning_client == self.client and args.remove_sample is not True and bk is True:
                print(f"client {self.client}.......Break")
                break

        return net, net.state_dict(), sum(per_epoch_losses) / len(per_epoch_losses), self.example_stats_local
    
class DomainClientUpdate_avg:
    """Plain FedAvg local trainer.

    Runs ``args.local_ep`` epochs of SGD (momentum 0.9, weight decay 1e-4) on
    the cross-entropy loss, with optional gradient-norm clipping when
    ``args`` has a ``clip_grad`` attribute.
    """

    def __init__(self, args, train_loader):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _split_batch(batch):
        """Return ``(inputs, labels)`` from a 2- or 3-element batch.

        A third element (presumably a sample index) is discarded.
        """
        n_items = len(batch)
        if n_items == 3:
            inputs, labels, _ = batch
            return inputs, labels
        if n_items == 2:
            inputs, labels = batch
            return inputs, labels
        raise ValueError(f"Unexpected batch format with {n_items} elements")

    def train(self, net):
        """Train ``net`` in place; return ``(deepcopy(net), mean loss)``.

        The mean loss is the average over epochs of each epoch's mean
        per-batch loss.
        """
        device = self.args.device
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr,
                                    momentum=0.9, weight_decay=1e-4)
        net.train()

        epoch_means = []
        for _ in range(self.args.local_ep):
            running = 0.0
            for batch in self.train_loader:
                inputs, labels = self._split_batch(batch)
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                loss = self.criterion(extract_logits(net(inputs)), labels)
                loss.backward()
                if hasattr(self.args, 'clip_grad'):
                    torch.nn.utils.clip_grad_norm_(net.parameters(), self.args.clip_grad)
                optimizer.step()

                running += loss.item()
            epoch_means.append(running / len(self.train_loader))

        return copy.deepcopy(net), sum(epoch_means) / self.args.local_ep
    
class DomainClientUpdate_avg_frozenlast:
    """FedAvg local trainer that keeps the final classification layer frozen.

    The head is located as the last layer of ``net.classifier`` (or
    ``net.classifier`` itself when it is a single module), else ``net.fc3``,
    else ``net.fc``. Its parameters are excluded from optimisation for the
    whole local run. Their original ``requires_grad`` flags are restored
    afterwards, even if training raises.
    """

    def __init__(self, args, train_loader):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _classifier_params(net):
        """Return the parameters of the model's last classification layer.

        Raises:
            ValueError: if none of ``classifier``/``fc3``/``fc`` exists.
        """
        if hasattr(net, 'classifier'):
            head = net.classifier
            # Only a container head is indexed; a bare layer (e.g. nn.Linear)
            # is not subscriptable and is itself the last layer.
            if isinstance(head, (nn.Sequential, nn.ModuleList)):
                head = head[-1]
            return list(head.parameters())
        if hasattr(net, 'fc3'):
            return list(net.fc3.parameters())
        if hasattr(net, 'fc'):
            return list(net.fc.parameters())
        raise ValueError("模型没有识别到分类层，请检查模型结构")

    def train(self, net):
        """Train all but the classifier head; return ``(deepcopy(net), mean loss)``.

        Raises:
            ValueError: if no classifier layer is found or a batch does not
                have 2 or 3 elements.
        """
        classifier_params = self._classifier_params(net)
        # Remember the caller's flags so exactly those are restored; forcing
        # True would unfreeze parameters the caller had frozen on purpose.
        saved_flags = [param.requires_grad for param in classifier_params]
        for param in classifier_params:
            param.requires_grad = False

        try:
            trainable_params = [p for p in net.parameters() if p.requires_grad]
            optimizer = torch.optim.SGD(
                trainable_params,
                lr=self.args.lr,
                momentum=0.9,
                weight_decay=1e-4
            )

            net.train()
            total_loss = 0.0

            for _ in range(self.args.local_ep):
                epoch_loss = 0.0
                for batch in self.train_loader:
                    if len(batch) == 3:
                        data, target, _ = batch
                    elif len(batch) == 2:
                        data, target = batch
                    else:
                        raise ValueError(f"Unexpected batch format with {len(batch)} elements")

                    data, target = data.to(self.args.device), target.to(self.args.device)
                    optimizer.zero_grad()
                    output = net(data)
                    logits = extract_logits(output)
                    loss = self.criterion(logits, target)
                    loss.backward()

                    if hasattr(self.args, 'clip_grad'):
                        torch.nn.utils.clip_grad_norm_(trainable_params, self.args.clip_grad)

                    optimizer.step()
                    epoch_loss += loss.item()

                total_loss += epoch_loss / len(self.train_loader)
        finally:
            # Restore even on error so the caller's model is not left frozen.
            for param, flag in zip(classifier_params, saved_flags):
                param.requires_grad = flag

        return copy.deepcopy(net), total_loss / self.args.local_ep

    
class MaskedSGD(torch.optim.SGD):
    """SGD that multiplies each parameter's gradient by a mask before stepping.

    Args:
        params: iterable of parameters, of ``(name, parameter)`` pairs (e.g.
            ``model.named_parameters()``), or of param-group dicts.
        mask_dict: ``{param_name: mask}``. Each mask must be multipliable into
            the corresponding gradient in place. Parameters without a mask
            are updated normally.
        lr, momentum, dampening, weight_decay, nesterov: forwarded to
            ``torch.optim.SGD``.
        param_names: optional names aligned one-to-one with the parameters
            across all param groups, in order. Takes precedence over names
            taken from ``(name, parameter)`` pairs.

    When no names are available, the legacy pairing is used: the first param
    group's parameters are zipped positionally with ``mask_dict``'s keys.
    That is only correct when ``mask_dict`` lists every parameter in order.
    """

    def __init__(self, params, mask_dict, lr=0.01, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, param_names=None):
        if not isinstance(params, torch.Tensor):
            params = list(params)
            # Accept named_parameters(): split the names from the tensors so
            # masks are matched by name rather than by position.
            if params and all(isinstance(item, tuple) and len(item) == 2
                              and isinstance(item[0], str) for item in params):
                if param_names is None:
                    param_names = [name for name, _ in params]
                params = [tensor for _, tensor in params]
        super().__init__(params, lr=lr, momentum=momentum,
                         dampening=dampening, weight_decay=weight_decay,
                         nesterov=nesterov)
        self.mask_dict = mask_dict

        if param_names is None:
            names = list(mask_dict.keys())
            tensors = self.param_groups[0]['params']
        else:
            names = list(param_names)
            tensors = [p for group in self.param_groups for p in group['params']]
            if len(names) != len(tensors):
                raise ValueError(
                    f"param_names has {len(names)} entries but the optimizer "
                    f"holds {len(tensors)} parameters")
        # id(param) -> name, used to look up each parameter's mask in step().
        self._param_id_to_name = {id(p): name for p, name in zip(tensors, names)}

    @torch.no_grad()
    def step(self, closure=None):
        """Mask gradients, take one SGD step, and return the closure's loss.

        The closure (if any) runs first, so the gradients it produces are
        the ones that get masked. Passing it to ``SGD.step`` instead would
        recompute gradients after masking and discard the mask.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                name = self._param_id_to_name.get(id(p))
                if name is not None and name in self.mask_dict:
                    p.grad.mul_(self.mask_dict[name].to(p.device))

        super().step()
        return loss

def generate_federated_mask(delta_history, mask_ratio, model_param_names, net_glob, global_size, num_clients, client_id, device):
    """Build a per-parameter 0/1 mask for one client from historical update sizes.

    Every entry of the parameters named in ``model_param_names`` is ranked by
    the mean absolute value of its recorded deltas, and the top
    ``mask_ratio`` fraction of those entries is selected. The flattened
    layout is split into ``num_clients`` contiguous blocks of
    ``global_size // num_clients`` entries, with the last client's block
    extending to ``global_size``. Only the selected entries inside
    ``client_id``'s block are set to 1.

    ``mask_ratio`` is used directly as given; it is not read from ``args``.

    Args:
        delta_history: iterable of dicts mapping parameter name -> delta tensor.
        mask_ratio: fraction of ranked entries to select. The resulting count
            is clamped to ``[0, number of ranked entries]``.
        model_param_names: ordered names that define the flattened layout.
        net_glob: model whose ``state_dict()`` provides the parameter shapes.
        global_size: flattened length used to partition the clients.
        num_clients: number of client blocks.
        client_id: index of the client block to build the mask for.
        device: device for intermediate and returned tensors.

    Returns:
        dict mapping name -> float tensor of 0/1 values shaped like the
        parameter. Parameters whose flattened span does not overlap the
        client's block are omitted.
    """
    # state_dict() rebuilds its dict on every call, so fetch it once.
    state = net_glob.state_dict()

    aggregated_deltas = {}
    for delta in delta_history:
        for name in delta:
            aggregated_deltas.setdefault(name, []).append(delta[name].float().to(device))

    mean_abs_deltas = {}
    for name in model_param_names:
        if name in aggregated_deltas:
            mean_abs_deltas[name] = torch.mean(
                torch.stack([torch.abs(d) for d in aggregated_deltas[name]]), dim=0
            )
        else:
            mean_abs_deltas[name] = torch.zeros(state[name].shape, device=device)

    all_values = torch.cat([mean_abs_deltas[name].flatten() for name in model_param_names]).to(device)

    # Size the selection from the entries actually ranked. Counting
    # net_glob.parameters() is wrong when model_param_names is a subset of
    # the parameters or includes buffers.
    num_ranked = all_values.numel()
    num_selected = max(0, min(int(num_ranked * mask_ratio), num_ranked))
    sorted_indices = torch.argsort(all_values, descending=True)
    selected = torch.zeros_like(all_values, dtype=torch.bool)
    selected[sorted_indices[:num_selected]] = True

    base = global_size // num_clients
    client_ranges = [
        (i * base, global_size if i == num_clients - 1 else (i + 1) * base)
        for i in range(num_clients)
    ]
    client_start, client_end = client_ranges[client_id]

    mask_dict = {}
    offset = 0  # flattened start of the current parameter
    for name in model_param_names:
        param = state[name]
        size = param.numel()
        lo = max(offset, client_start)
        hi = min(offset + size, client_end)
        if lo < hi:
            param_mask = torch.zeros(size, device=device)
            # selected[lo:hi] holds exactly this parameter's selected entries
            # that fall inside the client's block.
            param_mask[lo - offset:hi - offset] = selected[lo:hi].float()
            mask_dict[name] = param_mask.reshape(param.shape)
        offset += size

    return mask_dict

class DomainClientUpdate_avg_mask:
    """FedAvg local trainer that re-applies a fixed weight mask after every step.

    ``mask`` maps parameter names to tensors that are multiplied into the
    parameter values (not the gradients) after each optimizer step.
    Parameters whose names are not in ``mask`` are left untouched.
    """

    def __init__(self, args, train_loader, mask):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()
        self.mask = mask

    def train(self, net):
        """Train ``net`` in place; return ``(deepcopy(net), mean loss)``."""
        optimizer = torch.optim.SGD(
            net.parameters(),
            lr=self.args.lr,
            momentum=0.9,
            weight_decay=1e-4
        )

        # Pair every masked parameter with its mask once, already on the
        # parameter's device, instead of re-transferring each batch.
        masked_params = [
            (param, self.mask[name].to(param.device))
            for name, param in net.named_parameters()
            if name in self.mask
        ]

        net.train()
        total_loss = 0.0
        for _ in range(self.args.local_ep):
            epoch_loss = 0.0
            for batch in self.train_loader:
                data, target = batch[:2]
                data, target = data.to(self.args.device), target.to(self.args.device)

                optimizer.zero_grad()
                output = net(data)
                logits = extract_logits(output)
                loss = self.criterion(logits, target)

                loss.backward()
                optimizer.step()

                # Zero the masked-out weights again after the update.
                with torch.no_grad():
                    for param, mask in masked_params:
                        param.mul_(mask)

                epoch_loss += loss.item()

            total_loss += epoch_loss / len(self.train_loader)

        return copy.deepcopy(net), total_loss / self.args.local_ep

class DomainClientUpdate_avg_mask2:
    """FedAvg local trainer whose weight mask is re-drawn every local epoch.

    At the start of each epoch a mask ratio is sampled from a normal
    distribution (mean 0.5, std ``args.tg_std``) truncated to [0, 1]. That
    ratio is passed to ``generate_federated_mask``, and the resulting mask
    is multiplied into the matching parameters after every optimizer step.
    """

    def __init__(self, args, train_loader, delta_history, model_param_names, net_glob, global_size, num_clients, client_id):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()
        # The following are forwarded unchanged to generate_federated_mask.
        self.delta_history = delta_history
        self.model_param_names = model_param_names
        self.net_glob = net_glob
        self.global_size = global_size
        self.num_clients = num_clients
        self.client_id = client_id
        # Standard deviation of the truncated normal used for mask ratios.
        self.std = args.tg_std

    def _sample_epoch_mask(self):
        """Draw a fresh mask ratio and build this client's mask from it."""
        ratio = generate_truncated_normal_random_numbers(
            mean=0.5,
            std_dev=self.std,
            size=1
        )[0]
        return generate_federated_mask(
            delta_history=self.delta_history,
            mask_ratio=ratio,
            model_param_names=self.model_param_names,
            net_glob=self.net_glob,
            global_size=self.global_size,
            num_clients=self.num_clients,
            client_id=self.client_id,
            device=self.args.device
        )

    def train(self, net):
        """Train ``net`` in place; return ``(deepcopy(net), mean loss)``."""
        device = self.args.device
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr,
                                    momentum=0.9, weight_decay=1e-4)
        net.train()

        per_epoch = []
        for _ in range(self.args.local_ep):
            epoch_mask = self._sample_epoch_mask()

            running = 0.0
            for batch in self.train_loader:
                inputs, labels = batch[:2]
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                loss = self.criterion(extract_logits(net(inputs)), labels)
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    for name, param in net.named_parameters():
                        keep = epoch_mask.get(name)
                        if keep is not None:
                            param.data.mul_(keep.to(param.device))

                running += loss.item()
            per_epoch.append(running / len(self.train_loader))

        return copy.deepcopy(net), sum(per_epoch) / self.args.local_ep
    

class DomainClientUpdate_avg_nlp:
    """FedAvg local trainer for text models, run for ``args.local_ep_nlp`` epochs.

    The objective matches the other trainers: cross-entropy, SGD with
    momentum 0.9 and weight decay 1e-4, and optional gradient-norm clipping
    when ``args`` has ``clip_grad``. Only the epoch budget differs.
    """

    def __init__(self, args, train_loader):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()

    def train(self, net):
        """Train ``net`` in place; return ``(deepcopy(net), mean loss)``.

        The mean loss averages each epoch's mean batch loss over the
        ``args.local_ep_nlp`` epochs actually run, or is 0.0 if none ran.

        Raises:
            ValueError: if a batch does not have 2 or 3 elements.
        """
        optimizer = torch.optim.SGD(
            net.parameters(),
            lr=self.args.lr,
            momentum=0.9,
            weight_decay=1e-4
        )

        net.train()
        num_epochs = self.args.local_ep_nlp
        total_loss = 0.0

        for _ in range(num_epochs):
            epoch_loss = 0.0
            for batch in self.train_loader:
                if len(batch) == 3:
                    data, target, _ = batch
                elif len(batch) == 2:
                    data, target = batch
                else:
                    raise ValueError(f"Unexpected batch format with {len(batch)} elements")

                data, target = data.to(self.args.device), target.to(self.args.device)

                optimizer.zero_grad()

                output = net(data)
                logits = extract_logits(output)
                loss = self.criterion(logits, target)

                loss.backward()

                if hasattr(self.args, 'clip_grad'):
                    torch.nn.utils.clip_grad_norm_(net.parameters(), self.args.clip_grad)

                optimizer.step()

                epoch_loss += loss.item()

            total_loss += epoch_loss / len(self.train_loader)

        # Normalise by the epoch count actually used (local_ep_nlp). Dividing
        # by args.local_ep mis-scales the loss whenever the two differ.
        mean_loss = total_loss / num_epochs if num_epochs else 0.0
        return copy.deepcopy(net), mean_loss
    

    
class DomainClientUpdate_avg_kl:
    """FedAvg local trainer regularised toward the incoming (global) model.

    The per-batch objective is::

        CE(logits, y)
        + alpha * sum_p ||p_local - p_ref||_2
        + beta  * KL(softmax(ref_logits) || softmax(logits))

    where the reference is a frozen copy of ``net`` taken on entry, and
    ``alpha``/``beta`` are read from ``args``.
    """

    def __init__(self, args, train_loader):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()

    def train(self, net):
        """Train ``net`` in place; return ``(deepcopy(net), mean loss)``.

        The reported loss is the full regularised objective, averaged over
        the batches of each epoch and then over ``args.local_ep`` epochs.

        Raises:
            ValueError: if a batch does not have 2 or 3 elements.
        """
        global_model = copy.deepcopy(net)
        for p in global_model.parameters():
            p.requires_grad = False
        global_model = global_model.to(self.args.device)
        # Evaluate the frozen reference deterministically (no dropout,
        # running BN statistics) so the KL target is not train-mode noise.
        global_model.eval()

        optimizer = torch.optim.SGD(
            net.parameters(),
            lr=self.args.lr,
            momentum=0.9,
            weight_decay=1e-4
        )

        net.train()
        total_loss = 0.0

        for _ in range(self.args.local_ep):
            epoch_loss = 0.0
            for batch in self.train_loader:
                if len(batch) == 3:
                    data, target, _ = batch
                elif len(batch) == 2:
                    data, target = batch
                else:
                    raise ValueError(f"Unexpected batch format with {len(batch)} elements")

                data, target = data.to(self.args.device), target.to(self.args.device)

                optimizer.zero_grad()

                # Use the extracted logits everywhere. Raw model output may be
                # a tuple or an object with `.logits`, which log_softmax rejects.
                logits = extract_logits(net(data))
                loss_ce = self.criterion(logits, target)

                param_diff = torch.tensor(0., device=self.args.device)
                for p_local, p_global in zip(net.parameters(), global_model.parameters()):
                    param_diff += torch.norm(p_local - p_global)
                loss_param = self.args.alpha * param_diff

                with torch.no_grad():
                    global_logits = extract_logits(global_model(data))
                loss_kl = self.args.beta * F.kl_div(
                    input=F.log_softmax(logits, dim=1),
                    target=F.softmax(global_logits, dim=1),
                    reduction='batchmean'
                )

                # Keep the per-batch tensor separate from the float
                # accumulator `total_loss`; reusing that name clobbered it.
                batch_loss = loss_ce + loss_param + loss_kl
                batch_loss.backward()

                if hasattr(self.args, 'clip_grad'):
                    torch.nn.utils.clip_grad_norm_(net.parameters(), self.args.clip_grad)

                optimizer.step()

                epoch_loss += batch_loss.item()

            total_loss += epoch_loss / len(self.train_loader)

        return copy.deepcopy(net), total_loss / self.args.local_ep
    

class DomainClientUpdate_avgpBN:
    """FedAvg local trainer that keeps batch-norm parameters out of the update.

    Parameters whose names contain ``'bn'`` or ``'batch_norm'`` are not given
    to the optimizer. After training, every ``state_dict`` entry with such a
    name (affine parameters and buffers) is cloned into ``self.bn_params``
    and also returned.
    """

    def __init__(self, args, train_loader):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()
        self.bn_params = None

    def train(self, net):
        """Train non-BN parameters; return ``(deepcopy(net), mean loss, bn_params)``."""
        non_bn_params = [p for n, p in net.named_parameters() if not ('bn' in n or 'batch_norm' in n)]
        optimizer = torch.optim.SGD(
            non_bn_params,
            lr=self.args.lr,
            momentum=0.9,
            weight_decay=1e-4
        )

        net.train()
        total_loss = 0.0

        for _ in range(self.args.local_ep):
            epoch_loss = 0.0
            for batch in self.train_loader:
                if len(batch) == 3:
                    data, target, _ = batch
                else:
                    data, target = batch[:2]

                data, target = data.to(self.args.device), target.to(self.args.device)

                # Clear every gradient, not only the optimizer's. BN params are
                # outside the optimizer, so optimizer.zero_grad() would let
                # their grads accumulate across batches.
                net.zero_grad()

                output = net(data)
                logits = extract_logits(output)
                loss = self.criterion(logits, target)

                loss.backward()

                if hasattr(self.args, 'clip_grad'):
                    # Clip only what is actually updated so BN gradients do
                    # not inflate the total norm used for scaling.
                    torch.nn.utils.clip_grad_norm_(non_bn_params, self.args.clip_grad)

                optimizer.step()
                epoch_loss += loss.item()

            total_loss += epoch_loss / len(self.train_loader)

        self.bn_params = {
            k: v.clone().detach()
            for k, v in net.state_dict().items()
            if 'bn' in k or 'batch_norm' in k
        }

        return copy.deepcopy(net), total_loss / self.args.local_ep, self.bn_params

class DomainClientUpdate_frzavg:
    """FedAvg local trainer that freezes named layers before training.

    The module-level ``freeze_layers`` helper is applied when layer names are
    given. Only parameters still requiring grad are optimised, with
    learning rate ``args.frzulr``.
    """

    def __init__(self, args, train_loader, freeze_layers=None):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()
        # Layer names for the freeze_layers helper; falsy means freeze nothing.
        self.freeze_layers = freeze_layers

    def _train_batch(self, net, optimizer, params_to_update, batch):
        """Run one optimisation step on ``batch`` and return its loss value."""
        if len(batch) == 3:
            inputs, labels, _ = batch
        elif len(batch) == 2:
            inputs, labels = batch
        else:
            raise ValueError(f"Unexpected batch format with {len(batch)} elements")

        device = self.args.device
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = self.criterion(extract_logits(net(inputs)), labels)
        loss.backward()
        if hasattr(self.args, 'clip_grad'):
            torch.nn.utils.clip_grad_norm_(params_to_update, self.args.clip_grad)
        optimizer.step()
        return loss.item()

    def train(self, net):
        """Train unfrozen parameters; return ``(deepcopy(net), mean loss)``."""
        if self.freeze_layers:
            net = freeze_layers(net, self.freeze_layers)

        params_to_update = list(filter(lambda p: p.requires_grad, net.parameters()))
        optimizer = torch.optim.SGD(params_to_update, lr=self.args.frzulr,
                                    momentum=0.9, weight_decay=1e-4)
        net.train()

        running = 0.0
        for _ in range(self.args.local_ep):
            epoch_sum = 0.0
            for batch in self.train_loader:
                epoch_sum += self._train_batch(net, optimizer, params_to_update, batch)
            running += epoch_sum / len(self.train_loader)

        return copy.deepcopy(net), running / self.args.local_ep
    
class DomainClientUpdate_frzavg_lr:
    """Frozen-layer trainer regularised toward the incoming model.

    Named layers are frozen with the module-level ``freeze_layers`` helper.
    The trainable parameters minimise::

        CE + alpha * sum_{trainable p} ||p - p_ref||_2 + beta * KL(ref || local)

    where the reference is a frozen copy of ``net`` taken on entry.
    """

    def __init__(self, args, train_loader, freeze_layers=None):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()
        self.freeze_layers = freeze_layers

    def train(self, net):
        """Train ``net`` in place; return ``(deepcopy(net), mean loss)``.

        The reported loss is the full objective, averaged over batches and
        then over ``args.local_ep`` epochs.

        Raises:
            ValueError: if a batch does not have 2 or 3 elements.
        """
        global_model = copy.deepcopy(net)
        for p in global_model.parameters():
            p.requires_grad = False
        global_model.to(self.args.device)
        # Deterministic reference outputs (no dropout, running BN stats).
        global_model.eval()

        if self.freeze_layers:
            net = freeze_layers(net, self.freeze_layers)

        trainable_params = [p for p in net.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(
            trainable_params,
            lr=self.args.frzulr,
            momentum=0.9,
            weight_decay=1e-4
        )

        net.train()
        total_loss = 0.0

        for _ in range(self.args.local_ep):
            epoch_loss = 0.0
            for batch in self.train_loader:
                if len(batch) == 3:
                    data, target, _ = batch
                elif len(batch) == 2:
                    data, target = batch
                else:
                    raise ValueError(f"Unexpected batch format with {len(batch)} elements")

                data, target = data.to(self.args.device), target.to(self.args.device)

                optimizer.zero_grad()

                # Extracted logits feed both CE and KL; raw outputs may be
                # tuples or objects with `.logits`.
                logits = extract_logits(net(data))
                loss_ce = self.criterion(logits, target)

                param_diff = torch.tensor(0., device=self.args.device)
                for p_local, p_global in zip(net.parameters(), global_model.parameters()):
                    if p_local.requires_grad:
                        param_diff += torch.norm(p_local - p_global.detach())
                loss_param = self.args.alpha * param_diff

                with torch.no_grad():
                    global_logits = extract_logits(global_model(data))
                prob_local = F.log_softmax(logits, dim=1)
                prob_global = F.softmax(global_logits, dim=1)
                loss_kl = self.args.beta * F.kl_div(
                    input=prob_local,
                    target=prob_global,
                    reduction='batchmean'
                )

                # Separate name so the float accumulator is not overwritten.
                batch_loss = loss_ce + loss_param + loss_kl
                batch_loss.backward()

                if hasattr(self.args, 'clip_grad'):
                    torch.nn.utils.clip_grad_norm_(
                        trainable_params,
                        self.args.clip_grad
                    )

                optimizer.step()

                epoch_loss += batch_loss.item()

            total_loss += epoch_loss / len(self.train_loader)

        return copy.deepcopy(net), total_loss / self.args.local_ep
    
class DomainClientUpdate_frzavg_ulr:
    """Frozen-layer trainer that pushes the model *away* from the incoming model.

    Named layers are frozen with the module-level ``freeze_layers`` helper.
    The trainable parameters minimise::

        CE + 1 / (alpha * sum ||p - p_ref||_2 + eps) + 1 / (beta * KL + eps)

    so larger drift and larger divergence from the reference lower the loss.
    ``eps`` is ``args.ulr_eps`` if present, otherwise 1e-8.
    """

    def __init__(self, args, train_loader, freeze_layers=None):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()
        self.freeze_layers = freeze_layers

    def train(self, net):
        """Train ``net`` in place; return ``(deepcopy(net), mean loss)``.

        Raises:
            ValueError: if a batch does not have 2 or 3 elements.
        """
        global_model = copy.deepcopy(net)
        for p in global_model.parameters():
            p.requires_grad = False
        global_model.to(self.args.device)
        # Deterministic reference outputs (no dropout, running BN stats).
        global_model.eval()

        if self.freeze_layers:
            net = freeze_layers(net, self.freeze_layers)

        trainable_params = [p for p in net.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(
            trainable_params,
            lr=self.args.frzulr,
            momentum=0.9,
            weight_decay=1e-4
        )

        # param_diff is exactly 0 on the first batch (net == reference), and
        # the KL can be 0 whenever both forward passes agree. Without eps the
        # reciprocals give inf losses and NaN gradients.
        eps = getattr(self.args, 'ulr_eps', 1e-8)

        net.train()
        total_loss = 0.0

        for _ in range(self.args.local_ep):
            epoch_loss = 0.0
            for batch in self.train_loader:
                if len(batch) == 3:
                    data, target, _ = batch
                elif len(batch) == 2:
                    data, target = batch
                else:
                    raise ValueError(f"Unexpected batch format with {len(batch)} elements")

                data, target = data.to(self.args.device), target.to(self.args.device)

                optimizer.zero_grad()

                # Extracted logits feed both CE and KL; raw outputs may be
                # tuples or objects with `.logits`.
                logits = extract_logits(net(data))
                loss_ce = self.criterion(logits, target)

                param_diff = torch.tensor(0., device=self.args.device)
                for p_local, p_global in zip(net.parameters(), global_model.parameters()):
                    if p_local.requires_grad:
                        param_diff += torch.norm(p_local - p_global.detach())
                loss_param = self.args.alpha * param_diff

                with torch.no_grad():
                    global_logits = extract_logits(global_model(data))
                prob_local = F.log_softmax(logits, dim=1)
                prob_global = F.softmax(global_logits, dim=1)
                loss_kl = self.args.beta * F.kl_div(
                    input=prob_local,
                    target=prob_global,
                    reduction='batchmean'
                )

                # Separate name so the float accumulator is not overwritten.
                batch_loss = loss_ce + 1 / (loss_param + eps) + 1 / (loss_kl + eps)
                batch_loss.backward()

                if hasattr(self.args, 'clip_grad'):
                    torch.nn.utils.clip_grad_norm_(
                        trainable_params,
                        self.args.clip_grad
                    )

                optimizer.step()

                epoch_loss += batch_loss.item()

            total_loss += epoch_loss / len(self.train_loader)

        return copy.deepcopy(net), total_loss / self.args.local_ep

class DomainClientUpdate_avg_freeze:
    """Standard FedAvg local update.

    Uses cross-entropy and SGD with momentum 0.9 and weight decay 1e-4, with
    optional gradient-norm clipping when ``args`` has ``clip_grad``. No
    layers are frozen by this class itself.
    """

    def __init__(self, args, train_loader):
        self.args, self.train_loader = args, train_loader
        self.criterion = torch.nn.CrossEntropyLoss()

    def train(self, net):
        """Train ``net`` in place; return ``(deepcopy(net), mean loss)``."""
        cfg = self.args
        optimizer = optim.SGD(net.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=1e-4)
        net.train()

        summed_epoch_means = 0.0
        for _epoch in range(cfg.local_ep):
            batch_losses_total = 0.0
            for batch in self.train_loader:
                n = len(batch)
                if n not in (2, 3):
                    raise ValueError(f"Unexpected batch format with {n} elements")
                if n == 3:
                    x, y, _ = batch
                else:
                    x, y = batch
                x, y = x.to(cfg.device), y.to(cfg.device)

                optimizer.zero_grad()
                loss = self.criterion(extract_logits(net(x)), y)
                loss.backward()
                if hasattr(cfg, 'clip_grad'):
                    torch.nn.utils.clip_grad_norm_(net.parameters(), cfg.clip_grad)
                optimizer.step()

                batch_losses_total += loss.item()
            summed_epoch_means += batch_losses_total / len(self.train_loader)

        return copy.deepcopy(net), summed_epoch_means / cfg.local_ep
    

class DomainClientUpdate_avg_sal:
    """FedAvg local trainer that also reports how far the model moved.

    After training, the importance score is the sum over ``state_dict``
    entries of ``||w_new - w_ref||_2``, divided by the sum of the
    reference's parameter norms and clamped to [0, 1].
    """

    def __init__(self, args, train_loader):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()

    def train(self, net, global_model=None):
        """Train ``net`` in place.

        Args:
            net: model to train.
            global_model: reference model to measure the update against. When
                None, the weights ``net`` had on entry are used as the reference.

        Returns:
            ``(deepcopy(net), mean loss, normalized_importance)``, where
            ``normalized_importance`` is a Python float in [0, 1].

        Raises:
            ValueError: if a batch does not have 2 or 3 elements.
        """
        optimizer = torch.optim.SGD(
            net.parameters(),
            lr=self.args.lr,
            momentum=0.9,
            weight_decay=1e-4
        )

        net.train()
        total_loss = 0.0
        # Snapshot of the incoming weights: the fallback reference when no
        # global_model is supplied.
        initial_weights = copy.deepcopy(net.state_dict())
        param_names = [name for name, _ in net.named_parameters()]

        for _ in range(self.args.local_ep):
            epoch_loss = 0.0
            for batch in self.train_loader:
                if len(batch) == 3:
                    data, target, _ = batch
                elif len(batch) == 2:
                    data, target = batch
                else:
                    raise ValueError(f"Unexpected batch format with {len(batch)} elements")

                data, target = data.to(self.args.device), target.to(self.args.device)

                optimizer.zero_grad()

                output = net(data)
                logits = extract_logits(output)
                loss = self.criterion(logits, target)

                loss.backward()

                if hasattr(self.args, 'clip_grad'):
                    torch.nn.utils.clip_grad_norm_(net.parameters(), self.args.clip_grad)

                optimizer.step()

                epoch_loss += loss.item()

            total_loss += epoch_loss / len(self.train_loader)

        if global_model is not None:
            reference_state = global_model.state_dict()
            reference_params = list(global_model.parameters())
        else:
            reference_state = initial_weights
            reference_params = [initial_weights[name] for name in param_names]

        current_state = net.state_dict()
        importance = float(sum(
            torch.norm((value - reference_state[key].to(value.device)).float())
            for key, value in current_state.items()
        ))
        reference_norm = sum(torch.norm(p).item() for p in reference_params)

        if reference_norm > 0:
            ratio = importance / reference_norm
        else:
            # Zero-norm reference: any movement counts as maximal importance.
            ratio = 1.0 if importance > 0 else 0.0
        normalized_importance = min(1.0, max(0.0, ratio))

        return copy.deepcopy(net), total_loss / self.args.local_ep, normalized_importance

class DomainClientUpdate_frzavgpBN:
    """Frozen-layer FedAvg trainer that always keeps normalisation layers trainable.

    Any parameter whose name contains one of ``freeze_layers`` (substring
    match) is frozen, except that names containing ``'bn'`` or
    ``'batch_norm'`` are (re-)enabled. Flags are not restored afterwards.
    """

    def __init__(self, args, train_loader, freeze_layers=None):
        self.args = args
        self.train_loader = train_loader
        self.criterion = nn.CrossEntropyLoss()
        self.freeze_layers = freeze_layers

    def train(self, net):
        """Train ``net`` in place; return ``(deepcopy(net), mean loss)``."""
        patterns = self.freeze_layers
        if patterns:
            for pname, tensor in net.named_parameters():
                frozen = any(pat in pname for pat in patterns)
                norm_layer = ('bn' in pname) or ('batch_norm' in pname)
                if frozen or norm_layer:
                    # Normalisation layers win even when a pattern matched.
                    tensor.requires_grad = norm_layer

        cfg = self.args
        optimizer = torch.optim.SGD(net.parameters(), lr=cfg.frzulr,
                                    momentum=0.9, weight_decay=1e-4)
        net.train()

        epoch_means = []
        for _ in range(cfg.local_ep):
            acc = 0.0
            for batch in self.train_loader:
                size = len(batch)
                if size == 3:
                    x, y, _ = batch
                elif size == 2:
                    x, y = batch
                else:
                    raise ValueError(f"Unexpected batch format with {size} elements")
                x, y = x.to(cfg.device), y.to(cfg.device)

                optimizer.zero_grad()
                loss = self.criterion(extract_logits(net(x)), y)
                loss.backward()
                if hasattr(cfg, 'clip_grad'):
                    torch.nn.utils.clip_grad_norm_(net.parameters(), cfg.clip_grad)
                optimizer.step()

                acc += loss.item()
            epoch_means.append(acc / len(self.train_loader))

        return copy.deepcopy(net), sum(epoch_means) / cfg.local_ep


class DomainClientUpdate_record_fpl_org(object):
    """FPL-style local trainer that can also record per-example statistics.

    ``train`` minimises cross-entropy plus a hierarchical prototype loss
    (InfoNCE against the global prototypes plus an MSE pull towards the mean
    prototype of the sample's class). It collects detached per-class features
    during the last local epoch. When ``args.record_forget_event`` is True, it
    refreshes ``example_stats_local`` via ``test`` after the first epoch.
    """

    def __init__(self, args, train_loader, test_loader, example_stats_local, client):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.client = client

        if self.args.unlearning_client == self.client and args.baseline == 5:
            # Baseline 5: the unlearning client trains on a copy of its loader whose
            # labels are replaced by random ids in [0, num_classes).
            print("rl")
            self.ldr_train = copy.deepcopy(train_loader)
            self.ldr_train.dataset.labels = np.random.randint(
                0, self.args.num_classes, self.ldr_train.dataset.labels.shape
            ).astype(np.int64).squeeze()
        else:
            self.ldr_train = train_loader

        self.ldr_test = test_loader
        self.example_stats_local = example_stats_local

    def hierarchical_info_loss(self, f_now, label, all_f, mean_f, all_global_protos_keys):
        """Return InfoNCE(f_now vs. all prototypes) + MSE(f_now, mean prototype of ``label``).

        ``all_f`` / ``mean_f`` are object arrays aligned with
        ``all_global_protos_keys``. ``all_f[k]`` concatenates the prototypes of
        class ``keys[k]`` and ``mean_f[k]`` is their mean (built in ``train``).
        """
        label_id = label.item()
        f_pos = all_f[all_global_protos_keys == label_id][0].to(self.args.device)
        f_neg = torch.cat(all_f[all_global_protos_keys != label_id].tolist()).to(self.args.device)
        xi_info_loss = self.calculate_infonce(f_now, f_pos, f_neg)

        mean_f_pos = mean_f[all_global_protos_keys == label_id][0].to(self.args.device).view(1, -1)
        cu_info_loss = nn.MSELoss()(f_now.to(torch.float32), mean_f_pos.to(torch.float32))

        return xi_info_loss + cu_info_loss

    def calculate_infonce(self, f_now, f_pos, f_neg):
        """InfoNCE of ``f_now`` against positives ``f_pos`` and negatives ``f_neg``.

        Cosine similarities are divided by ``args.infoNCET``. The loss is
        ``-log(sum(exp(pos)) / sum(exp(all)))`` and has shape ``(1,)``.
        """
        f_proto = torch.cat((f_pos, f_neg), dim=0)
        sim = torch.cosine_similarity(f_now, f_proto, dim=1) / self.args.infoNCET
        exp_l = torch.exp(sim).view(1, -1)
        pos_mask = torch.tensor(
            [1] * f_pos.shape[0] + [0] * f_neg.shape[0], dtype=torch.float32
        ).to(self.args.device).view(1, -1)
        sum_pos_l = (exp_l * pos_mask).sum(1)
        sum_exp_l = exp_l.sum(1)
        return -torch.log(sum_pos_l / sum_exp_l)

    def agg_func(self, protos):
        """Replace each label's list of feature tensors by their detached mean, in place.

        Returns the same dict.
        """
        for label, proto_list in protos.items():
            if len(proto_list) > 1:
                proto = 0 * proto_list[0].data
                for p in proto_list:
                    proto += p.data
                protos[label] = proto / len(proto_list)
            else:
                # Detach so single-sample prototypes do not keep the autograd graph alive.
                protos[label] = proto_list[0].detach()
        return protos

    def _prototype_loss(self, features, labels, lossCE, global_protos, all_f, mean_f, keys):
        """Batch-averaged hierarchical prototype loss, or ``0 * lossCE`` if nothing matches.

        Samples whose label has no global prototype contribute nothing but still
        count in the denominator (the batch size).
        """
        total = None
        if global_protos:
            for i, label in enumerate(labels):
                if label.item() in global_protos:
                    term = self.hierarchical_info_loss(
                        features[i].unsqueeze(0), label, all_f, mean_f, keys
                    ).to(torch.float32)
                    total = term if total is None else total + term
        if total is None:
            return 0 * lossCE
        return total / len(labels)

    def train(self, net, global_protos=None, attack_alpha=1, atk_eps=0.1, tgtmodel=None, atk_target=9):
        """Train ``net`` locally for ``args.local_ep`` epochs.

        Args:
            net: model called as ``net(x, output_feature=True)`` -> ``(logits, features)``.
            global_protos: dict ``label -> list of prototype tensors``; ``None``/empty
                disables the prototype loss.
            attack_alpha: weight of the clean loss. Values >= 1 train on the clean
                objective; values < 1 select the marker objective (see NOTE below).
            atk_eps: scale of the noise produced by ``tgtmodel``.
            tgtmodel: noise generator used by the marker objective.
            atk_target: label that every poisoned sample is mapped to.

        Returns:
            ``(net, net.state_dict(), agg_protos, mean_epoch_loss, example_stats_local)``.
        """
        if global_protos is None:
            global_protos = {}
        net.train()

        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        all_global_protos_keys = all_f = mean_f = None
        if global_protos:
            # Object arrays so boolean masks over the keys select whole per-class tensors.
            all_global_protos_keys = np.array(list(global_protos.keys()))
            all_f = np.empty(len(all_global_protos_keys), dtype=object)
            mean_f = np.empty(len(all_global_protos_keys), dtype=object)
            for k, protos_key in enumerate(all_global_protos_keys):
                temp_f = torch.cat(global_protos[protos_key], dim=0).to(self.args.device)
                all_f[k] = temp_f.cpu().detach()
                mean_f[k] = torch.mean(temp_f, dim=0).cpu().detach()

        if attack_alpha < 1:
            # NOTE(review): debugging stop left in place. It prints the stored stats
            # for ``args.backdoor_client_idx`` and terminates the process, so the
            # marker branch inside the loop below is currently unreachable.
            representative_sample = self.example_stats_local[0][self.args.backdoor_client_idx]
            print(representative_sample)
            exit()

        epoch_loss = []
        correct = 0
        correct_poison = 0
        agg_protos_label = {}
        for epoch in range(self.args.local_ep):
            net.train()
            torch.cuda.empty_cache()
            batch_loss = []
            correct = 0
            correct_poison = 0
            agg_protos_label = {}
            last_epoch = epoch == self.args.local_ep - 1
            for batch_idx, (images, labels, _idx) in enumerate(self.ldr_train):

                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()

                log_probs, feature1 = net(images, output_feature=True)
                pred = log_probs.data.max(1)[1]
                correct += pred.eq(labels.data.view_as(pred)).cpu().sum().item()
                lossCE = self.loss_func(log_probs, labels)

                loss_InfoNCE = self._prototype_loss(
                    feature1, labels, lossCE, global_protos, all_f, mean_f, all_global_protos_keys
                )

                if attack_alpha < 1:
                    tgtmodel.eval()
                    with torch.no_grad():
                        noise = tgtmodel(images) * atk_eps
                        atkdata = clip_image(images + noise)
                        atktarget = all2one_target_transform(labels, atk_target)
                        if self.args.attack_portion < 1.0:
                            n_atk = int(self.args.attack_portion * images.size(0))
                            atkdata = atkdata[:n_atk]
                            atktarget = atktarget[:n_atk]

                    atkoutput, feature2 = net(atkdata.detach(), output_feature=True)
                    poison_logits = extract_logits(atkoutput)
                    poison_pred = poison_logits.data.max(1)[1]
                    loss_poison = F.cross_entropy(poison_logits, atktarget.detach())
                    correct_poison += poison_pred.eq(atktarget.data.view_as(poison_pred)).cpu().sum().item()

                    # Pair each poisoned feature with the clean feature of the same sample
                    # (the poisoned batch may be truncated by attack_portion).
                    loss = (lossCE * attack_alpha + (1.0 - attack_alpha) * loss_poison
                            + F.cosine_similarity(feature1[:feature2.size(0)], feature2).sum()
                            + loss_InfoNCE)
                else:
                    loss = lossCE + loss_InfoNCE

                loss.backward()
                batch_loss.append(loss.item())
                optimizer.step()

                if last_epoch:
                    for label, feat in zip(labels.tolist(), feature1.detach()):
                        agg_protos_label.setdefault(label, []).append(feat)

                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))

            epoch_loss.append(sum(batch_loss) / len(batch_loss))

            if epoch == 0 and self.args.record_forget_event is True and self.example_stats_local is not None:
                _, _, self.example_stats_local[0][self.client] = test(
                    args=self.args, test_loader=self.ldr_train, net=net,
                    example_stats=self.example_stats_local[0][self.client])
                _, _, self.example_stats_local[1][self.client] = test(
                    args=self.args, test_loader=self.ldr_test, net=net,
                    example_stats=self.example_stats_local[1][self.client])

            if self.args.unlearning_client == self.client:
                print(f"client {self.client}.......Break")
                break

        agg_protos = self.agg_func(agg_protos_label)

        if attack_alpha < 1.0:
            n_train = float(len(self.ldr_train.dataset))
            print(f" Clean acc {100.0 * (float(correct) / n_train):.4f}  | "
                  f" Marker acc {100.0 * (float(correct_poison) / n_train)}")

        return net, net.state_dict(), agg_protos, sum(epoch_loss) / len(epoch_loss), self.example_stats_local


class DomainClientUpdate(object):
    """FPL-style local trainer: cross-entropy plus a hierarchical prototype loss.

    Optionally multiplies every parameter gradient by a per-parameter mask
    before each optimizer step.
    """

    def __init__(self, args, train_loader):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.ldr_train = train_loader

    def _prototype_loss(self, features, labels, lossCE, global_protos, all_f, mean_f, keys):
        """Batch-averaged hierarchical prototype loss, or ``0 * lossCE`` if nothing matches.

        Samples whose label has no global prototype contribute nothing but still
        count in the denominator (the batch size).
        """
        total = None
        if global_protos:
            for i, label in enumerate(labels):
                if label.item() in global_protos:
                    term = hierarchical_info_loss(
                        self.args, features[i].unsqueeze(0), label, all_f, mean_f, keys
                    ).to(torch.float32)
                    total = term if total is None else total + term
        if total is None:
            return 0 * lossCE
        return total / len(labels)

    def train(self, net, global_protos=None, mask=None):
        """Train ``net`` for ``args.local_ep`` epochs.

        Args:
            net: model called as ``net(x, output_feature=True)`` -> ``(logits, features)``.
            global_protos: dict ``label -> list of prototype tensors``; ``None``/empty
                disables the prototype loss.
            mask: optional dict ``param_name -> tensor``; each gradient is multiplied
                by ``mask[name]`` before the step.

        Returns:
            ``(net, net.state_dict(), agg_protos, [mean_epoch_loss])``. The per-class
            prototypes are built from detached features of the last epoch.
        """
        if global_protos is None:
            global_protos = {}
        net.train()

        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        all_global_protos_keys = all_f = mean_f = None
        if global_protos:
            # Object arrays so boolean masks over the keys select whole per-class tensors.
            all_global_protos_keys = np.array(list(global_protos.keys()))
            all_f = np.empty(len(all_global_protos_keys), dtype=object)
            mean_f = np.empty(len(all_global_protos_keys), dtype=object)
            for k, protos_key in enumerate(all_global_protos_keys):
                temp_f = torch.cat(global_protos[protos_key], dim=0).to(self.args.device)
                all_f[k] = temp_f.cpu().detach()
                mean_f[k] = torch.mean(temp_f, dim=0).cpu().detach()

        epoch_loss = []
        agg_protos_label = {}
        for epoch in range(self.args.local_ep):
            net.train()
            torch.cuda.empty_cache()
            batch_loss = []
            agg_protos_label = {}
            last_epoch = epoch == self.args.local_ep - 1
            for batch_idx, (images, labels, _idx) in enumerate(self.ldr_train):

                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()

                log_probs, feature1 = net(images, output_feature=True)
                lossCE = self.loss_func(log_probs, labels)
                loss_InfoNCE = self._prototype_loss(
                    feature1, labels, lossCE, global_protos, all_f, mean_f, all_global_protos_keys
                )
                loss = lossCE + loss_InfoNCE

                loss.backward()
                batch_loss.append(loss.item())
                if mask is not None:
                    for name, param in net.named_parameters():
                        if param.grad is not None:
                            param.grad *= mask[name]

                optimizer.step()

                if last_epoch:
                    # Detached so the stored prototypes do not keep the autograd graph alive.
                    for label, feat in zip(labels.tolist(), feature1.detach()):
                        agg_protos_label.setdefault(label, []).append(feat)

                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images), len(self.ldr_train.dataset),
                              100. * batch_idx / len(self.ldr_train), loss.item()))

            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        agg_protos = agg_func(agg_protos_label)

        result = [sum(epoch_loss) / len(epoch_loss)]

        return net, net.state_dict(), agg_protos, result
    

class DomainClientUnlearningBl3frzpBN:
    """Gradient-ascent unlearning client with optional layer freezing and BN persistence.

    ``train`` maximises the cross-entropy on the local data. After each step it
    projects the model back into a ball of (``get_distance``) radius
    ``threshold`` around ``net_ref``. It stops early once the model is farther
    than ``distance_threshold`` from ``net_unlearning_client``. BN entries can
    be loaded from and saved to disk between rounds.
    """

    def __init__(self, args, train_loader, threshold, freeze_layers=None, bn_params_path=None):
        self.args = args
        self.loss_func = torch.nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = train_loader
        self.threshold = threshold
        self.distance_threshold = 20
        self.num_updates_in_epoch = None
        self.freeze_layers = freeze_layers
        self.bn_params = None
        if bn_params_path:
            self.load_bn_params(bn_params_path)

    def load_bn_params(self, path):
        """Load saved BN entries from ``path`` (if it exists), stripping any ``module.`` prefix."""
        if os.path.exists(path):
            # NOTE: torch.load unpickles the file; only load checkpoints you produced.
            bn_params = torch.load(path)
            self.bn_params = {k.replace('module.', ''): v for k, v in bn_params.items()}

    def save_bn_params(self, net, path):
        """Save every state-dict entry whose key contains ``'bn'`` (CPU copies, prefix stripped)."""
        bn_params = {k.replace('module.', ''): v.cpu().clone() for k, v in net.state_dict().items() if 'bn' in k}
        torch.save(bn_params, path)

    def _freeze_layers(self, net):
        """Freeze parameters matching ``self.freeze_layers``; BN parameters stay trainable."""
        if not self.freeze_layers:
            return net
        for name, param in net.named_parameters():
            if any(f in name for f in self.freeze_layers):
                param.requires_grad = False
            if 'bn' in name or 'batch_norm' in name:
                param.requires_grad = True
        return net

    def _restore_bn_params(self, net):
        """Copy ``self.bn_params`` into ``net``, matching keys with or without a ``module.`` prefix.

        Raises:
            RuntimeError: if a saved BN entry has no counterpart in ``net``.
        """
        if not self.bn_params:
            return net
        current_state = net.state_dict()
        matched = set()
        for key in current_state:
            # Saved keys are prefix-free; compare on the stripped key so wrapped
            # (e.g. DataParallel) models are restored too.
            stripped = key.replace('module.', '')
            if stripped in self.bn_params:
                current_state[key] = self.bn_params[stripped]
                matched.add(stripped)
        missing = set(self.bn_params) - matched
        if missing:
            raise RuntimeError(f"BN parameters not found in model: {sorted(missing)}")
        net.load_state_dict(current_state)
        return net

    def train(self, net, net_ref, net_unlearning_client):
        """Run projected gradient ascent and return the non-BN part of the state dict."""
        net = self._freeze_layers(net)
        net = self._restore_bn_params(net)

        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, net.parameters()),
            lr=self.args.frzulr,
            momentum=0.9
        )

        net.train()
        flag = False
        # None or <= 0 disables clipping.
        clip_grad = getattr(self.args, 'clip_grad', None)

        for _epoch in range(self.args.num_local_unlearn_epochs):
            if flag:
                break

            for batch_idx, (images, labels, _idx) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = extract_logits(net(images))
                loss = self.loss_func(log_probs, labels)
                # Gradient ascent on the local loss.
                loss_joint = -loss
                loss_joint.backward()

                if clip_grad is not None and clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(net.parameters(), clip_grad)

                optimizer.step()

                with torch.no_grad():
                    distance = baseline1_utils.get_distance(net, net_ref)
                    if distance > self.threshold:
                        # Project back onto the sphere of radius sqrt(threshold) around net_ref.
                        dist_vec = nn.utils.parameters_to_vector(net.parameters()) - nn.utils.parameters_to_vector(net_ref.parameters())
                        dist_vec = dist_vec / torch.norm(dist_vec) * np.sqrt(self.threshold)
                        proj_vec = nn.utils.parameters_to_vector(net_ref.parameters()) + dist_vec
                        nn.utils.vector_to_parameters(proj_vec, net.parameters())

                distance_ref_unlearning_client = baseline1_utils.get_distance(net, net_unlearning_client)
                if distance_ref_unlearning_client > self.distance_threshold:
                    flag = True
                    break

                if self.num_updates_in_epoch is not None and batch_idx >= self.num_updates_in_epoch:
                    break

        bn_save_dir = getattr(self.args, 'bn_save_dir', None)
        if bn_save_dir:
            os.makedirs(bn_save_dir, exist_ok=True)
            save_path = os.path.join(bn_save_dir, f'client_{self.args.client_id}_bn_params.pth')
            self.save_bn_params(net, save_path)

        return {k: v for k, v in net.state_dict().items() if 'bn' not in k}

class DomainClientUpdate_Inject(object):
    """Local trainer that injects a marker while keeping clean accuracy.

    Each local epoch first trains on ``clean_loader`` (cross-entropy plus a
    prototype loss). It then trains on ``marker_loader`` with a mix of clean
    loss, loss on ``tgtmodel``-perturbed inputs relabelled to ``atk_target``,
    a clean/marked feature-similarity term, and the prototype loss.
    """

    def __init__(self, args, clean_loader, marker_loader, example_stats=None, client=-1):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.client = client

        self.ldr_clean = clean_loader
        self.ldr_marker = marker_loader
        self.example_stats = example_stats

    def _prototype_loss(self, features, labels, lossCE, global_protos, all_f, mean_f, keys):
        """Batch-averaged hierarchical prototype loss, or ``0 * lossCE`` if nothing matches.

        Samples whose label has no global prototype contribute nothing but still
        count in the denominator (the batch size).
        """
        total = None
        if global_protos:
            for i, label in enumerate(labels):
                if label.item() in global_protos:
                    term = hierarchical_info_loss(
                        self.args, features[i].unsqueeze(0), label, all_f, mean_f, keys
                    ).to(torch.float32)
                    total = term if total is None else total + term
        if total is None:
            return 0 * lossCE
        return total / len(labels)

    def train(self, net, global_protos=None, attack_alpha=1, atk_eps=0.1, tgtmodel=None, atk_target=9):
        """Train on the clean then the marker loader for ``args.local_ep`` epochs.

        Args:
            net: model called as ``net(x, output_feature=True)`` -> ``(logits, features)``.
            global_protos: dict ``label -> list of prototype tensors``; ``None``/empty
                disables the prototype loss.
            attack_alpha: weight of the clean loss on marker batches
                (``1 - attack_alpha`` weights the poisoned loss).
            atk_eps: scale of the noise produced by ``tgtmodel``.
            tgtmodel: noise generator; required because the marker pass always uses it.
            atk_target: label every marked sample is mapped to.

        Returns:
            ``(net, net.state_dict(), agg_protos, [mean_epoch_loss, clean_acc, marker_acc])``.
            The accuracies are percentages from the last epoch (0.0 for an empty dataset).
        """
        if global_protos is None:
            global_protos = {}
        net.train()

        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        all_global_protos_keys = all_f = mean_f = None
        if global_protos:
            # Object arrays so boolean masks over the keys select whole per-class tensors.
            all_global_protos_keys = np.array(list(global_protos.keys()))
            all_f = np.empty(len(all_global_protos_keys), dtype=object)
            mean_f = np.empty(len(all_global_protos_keys), dtype=object)
            for k, protos_key in enumerate(all_global_protos_keys):
                temp_f = torch.cat(global_protos[protos_key], dim=0).to(self.args.device)
                all_f[k] = temp_f.cpu().detach()
                mean_f[k] = torch.mean(temp_f, dim=0).cpu().detach()

        epoch_loss = []
        agg_protos_label = {}
        correct = 0
        correct_poison = 0
        for epoch in range(self.args.local_ep):
            net.train()
            torch.cuda.empty_cache()
            batch_loss = []
            correct = 0
            correct_poison = 0
            last_epoch = epoch == self.args.local_ep - 1

            # Clean pass.
            for batch_idx, (images, labels, _idx) in enumerate(self.ldr_clean):

                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()

                log_probs, feature1 = net(images, output_feature=True)
                pred = log_probs.data.max(1)[1]
                correct += pred.eq(labels.data.view_as(pred)).cpu().sum().item()
                lossCE = self.loss_func(log_probs, labels)
                loss_InfoNCE = self._prototype_loss(
                    feature1, labels, lossCE, global_protos, all_f, mean_f, all_global_protos_keys
                )

                loss = lossCE + loss_InfoNCE
                loss.backward()
                batch_loss.append(loss.item())
                optimizer.step()

                if last_epoch:
                    for label, feat in zip(labels.tolist(), feature1.detach()):
                        agg_protos_label.setdefault(label, []).append(feat)

                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images), len(self.ldr_clean.dataset),
                              100. * batch_idx / len(self.ldr_clean), loss.item()))

            # Marker pass.
            for batch_idx, (images, labels, _idx) in enumerate(self.ldr_marker):

                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs, feature1 = net(images, output_feature=True)
                pred = log_probs.data.max(1)[1]
                correct += pred.eq(labels.data.view_as(pred)).cpu().sum().item()
                lossCE = self.loss_func(log_probs, labels)

                tgtmodel.eval()
                with torch.no_grad():
                    noise = tgtmodel(images) * atk_eps
                    atkdata = clip_image(images + noise)
                    atktarget = all2one_target_transform(labels, atk_target)

                atkoutput, feature2 = net(atkdata.detach(), output_feature=True)
                atk_logits = extract_logits(atkoutput)
                poison_pred = atk_logits.data.max(1)[1]
                loss_poison = self.loss_func(atk_logits, atktarget)
                correct_poison += poison_pred.eq(atktarget.data.view_as(poison_pred)).cpu().sum().item()

                # Labels without a global prototype are skipped, as in the clean pass.
                loss_InfoNCE = self._prototype_loss(
                    feature1, labels, lossCE, global_protos, all_f, mean_f, all_global_protos_keys
                )

                loss = (attack_alpha * lossCE + (1.0 - attack_alpha) * loss_poison
                        - F.cosine_similarity(feature1, feature2, dim=1).mean() + loss_InfoNCE)

                loss.backward()
                batch_loss.append(loss.item())
                optimizer.step()

                if last_epoch:
                    for label, feat in zip(labels.tolist(), feature1.detach()):
                        agg_protos_label.setdefault(label, []).append(feat)

                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images), len(self.ldr_marker.dataset),
                              100. * batch_idx / len(self.ldr_marker), loss.item()))

            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        agg_protos = agg_func(agg_protos_label)

        result = [sum(epoch_loss) / len(epoch_loss)]
        n_clean = float(len(self.ldr_clean.dataset))
        n_marker = float(len(self.ldr_marker.dataset))
        n_total = n_clean + n_marker

        result.append(100.0 * (float(correct) / n_total) if n_total else 0.0)
        result.append(100.0 * (float(correct_poison) / n_marker) if n_marker else 0.0)

        return net, net.state_dict(), agg_protos, result

class DomainClientUpdate_Hesian_record(object):
    """Local trainer that optimises a client model with AdaHessian.

    AdaHessian needs gradients that are themselves differentiable, so the
    first-order gradients are computed with ``create_graph=True`` and written
    into ``param.grad`` by hand before each optimizer step.
    """

    def __init__(self, args, train_loader):
        self.args = args
        self.ldr_train = train_loader
        self.loss_func = nn.CrossEntropyLoss()

    def train(self, net):
        """Train for ``args.local_ep`` epochs.

        Returns:
            ``(net.state_dict(), mean of the per-epoch average batch losses)``.
        """
        net.train()
        hessian_opt = ada_hessain.AdaHessian(net.parameters())
        per_epoch_means = []

        for _epoch in range(self.args.local_ep):
            losses_this_epoch = []
            for inputs, targets, _ in self.ldr_train:
                inputs = inputs.to(self.args.device)
                targets = targets.to(self.args.device)
                hessian_opt.zero_grad()

                logits = extract_logits(net(inputs))
                loss = self.loss_func(logits, targets)

                params = list(net.parameters())
                first_order = torch.autograd.grad(
                    loss,
                    params,
                    create_graph=True,
                    retain_graph=True
                )
                # Keep the graph on the gradients so AdaHessian can form Hessian-vector products.
                for p, g in zip(params, first_order):
                    p.grad = g

                hessian_opt.step()
                losses_this_epoch.append(loss.item())

            per_epoch_means.append(sum(losses_this_epoch) / len(losses_this_epoch))

        return net.state_dict(), sum(per_epoch_means) / len(per_epoch_means)
    

def atk_train(net, atkmodel, adv_train_epo, atk_eps, ldr_train, device, atk_target, mask=None):
    """Train the trigger generator ``atkmodel`` against a fixed classifier ``net``.

    For ``adv_train_epo`` epochs, every batch is perturbed by
    ``atkmodel(data) * atk_eps`` (then ``clip_image``) and ``atkmodel`` is
    updated with Adam (lr 1e-4) so that ``net`` predicts ``atk_target`` on the
    perturbed batch.

    ``net`` stays in eval mode. Its parameters' ``requires_grad`` flags are
    cleared during training, so no gradients accumulate in ``net``. The flags
    are restored afterwards, even if an exception is raised.

    Args:
        mask: unused; kept for call compatibility.

    Returns:
        ``(training_avg_loss, poison_acc, atkmodel)``. ``training_avg_loss`` is the
        mean batch loss of the last epoch that had batches. ``poison_acc`` is the
        last epoch's percentage of perturbed samples predicted as ``atk_target``.
        Both are 0.0 when nothing was processed.
    """
    loss_fn = nn.CrossEntropyLoss()
    atkmodel_optimizer = optim.Adam(atkmodel.parameters(), lr=0.0001)

    training_avg_loss = 0.0
    poison_acc = 0.0

    net_params = list(net.parameters())
    saved_flags = [p.requires_grad for p in net_params]
    for p in net_params:
        p.requires_grad_(False)

    try:
        for _epoch in range(adv_train_epo):
            correct_poison = 0
            loss_list = []
            net.eval()
            atkmodel.train()

            for data, target, _idx in ldr_train:
                data, target = data.to(device), target.to(device)

                noise = atkmodel(data) * atk_eps
                atkdata = clip_image(data + noise)
                atktarget = all2one_target_transform(target, atk_target)

                atkoutput = extract_logits(net(atkdata))
                loss = loss_fn(atkoutput, atktarget)

                atkmodel_optimizer.zero_grad()
                loss.backward()
                atkmodel_optimizer.step()

                pred = atkoutput.detach().max(1)[1]
                correct_poison += pred.eq(atktarget.view_as(pred)).cpu().sum().item()
                loss_list.append(loss.item())

            if loss_list:
                training_avg_loss = sum(loss_list) / len(loss_list)
            n_samples = len(ldr_train.dataset)
            poison_acc = 100.0 * (float(correct_poison) / float(n_samples)) if n_samples else 0.0
    finally:
        for p, flag in zip(net_params, saved_flags):
            p.requires_grad_(flag)

    return training_avg_loss, poison_acc, atkmodel




def unlearn_sal(args, net, unlearn_loader, rate):
    """Build a 0/1 saliency mask from gradients of the negated loss on the forget set.

    Gradients of ``-CrossEntropy(net(x), y)`` are summed over
    ``unlearn_loader`` (with ``net`` switched to eval mode). Their absolute
    values are ranked across all parameters jointly, and the top
    ``int(total_elements * rate)`` entries are marked with 1. Prints ``rate``.

    Args:
        args: namespace providing ``lr``, ``momentum`` and ``device``.
        net: model whose parameters are scored.
        unlearn_loader: iterable of ``(inputs, targets, index)`` batches.
        rate: fraction in [0, 1] of parameter entries to select.

    Returns:
        dict ``param_name -> int64 tensor`` shaped like the parameter, holding 0/1.
    """
    optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=args.momentum)
    loss_fn = nn.CrossEntropyLoss()

    # Zero tensors (not int 0) so parameters that never receive a gradient still
    # get a correctly shaped entry instead of breaking abs_/flatten below.
    gradients = {name: torch.zeros_like(param) for name, param in net.named_parameters()}

    net.eval()

    for image, target, _idx in unlearn_loader:
        image = image.to(args.device)
        target = target.to(args.device)

        loss = -loss_fn(net(image), target)

        optimizer.zero_grad()
        loss.backward()

        with torch.no_grad():
            for name, param in net.named_parameters():
                if param.grad is not None:
                    gradients[name] += param.grad

    with torch.no_grad():
        for name in gradients:
            gradients[name] = torch.abs_(gradients[name])

    print(rate)

    # Negate so an ascending argsort puts the largest magnitudes first.
    all_elements = -torch.cat([tensor.flatten() for tensor in gradients.values()])
    threshold_index = int(len(all_elements) * rate)

    positions = torch.argsort(all_elements)
    ranks = torch.argsort(positions)  # rank of each element in the sorted order

    hard_dict = {}
    start_index = 0
    for key, tensor in gradients.items():
        num_elements = tensor.numel()
        tensor_ranks = ranks[start_index:start_index + num_elements]
        hard_dict[key] = (tensor_ranks < threshold_index).to(tensor_ranks.dtype).reshape(tensor.shape)
        start_index += num_elements
    return hard_dict


class MyUnlearn(object):
    """Fine-tune a network so that each sample is pushed toward a substitute label.

    Every label ``y`` is remapped to ``sim_label[y]``. The network is trained with
    cross-entropy toward the remapped label plus an InfoNCE term (computed by
    ``calculate_infonce``). That term uses ``mean_f[sim_label[y]]`` as the positive
    feature and ``mean_f[y]`` as the negative feature.
    """

    def __init__(self, args, train_loader, sim_label, mean_f):
        """Store the configuration, data loader and label/feature lookups.

        Args:
            args: run configuration; ``lr``, ``momentum``, ``local_ep`` and
                ``device`` are read by ``train``.
            train_loader: DataLoader yielding ``(images, labels, indices)`` batches.
            sim_label: indexed with a label tensor, and the result is moved with
                ``.to`` (presumably a 1-D long tensor of substitute labels --
                TODO confirm against caller).
            mean_f: indexed with an int class id and returns a feature tensor
                (presumably per-class mean features -- confirm against caller).
        """
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.ldr_train = train_loader

        self.sim_label = sim_label
        self.mean_f = mean_f

    def _infonce_loss(self, features, labels, unlearn_labels):
        """Return the batch mean of ``calculate_infonce`` over all samples.

        For sample ``k`` the anchor is ``features[k]``, the positive is
        ``mean_f[unlearn_labels[k]]`` and the negative is ``mean_f[labels[k]]``.
        """
        device = self.args.device
        total = None
        # One .tolist() per batch instead of a per-sample .item() (avoids a
        # device sync for every sample).
        for feat, y_orig, y_new in zip(features, labels.tolist(), unlearn_labels.tolist()):
            f_pos = self.mean_f[y_new].to(device).unsqueeze(0)
            f_neg = self.mean_f[y_orig].to(device).unsqueeze(0)
            term = calculate_infonce(self.args, feat.unsqueeze(0), f_pos, f_neg)
            # Out-of-place sum: never mutates the tensor returned by calculate_infonce.
            total = term if total is None else total + term
        return total / len(labels)

    def train(self, net, marker_model=None, mask=None):
        """Run ``args.local_ep`` epochs of unlearning updates on ``net`` in place.

        Args:
            net: model called as ``net(x, output_feature=True)`` and expected to
                return ``(logits, features)``.
            marker_model: unused; kept for call-site compatibility.
            mask: optional dict mapping parameter names to tensors that are
                multiplied into the gradients before every optimizer step.
                Zero entries therefore leave the matching weights unchanged.

        Returns:
            ``(net, net.state_dict(), [mean_loss, acc_orig, acc_unlearn])``, where:

            * ``mean_loss`` is the mean of the per-epoch average losses.
            * ``acc_orig`` and ``acc_unlearn`` are last-epoch accuracies (percent)
              against the original and the substitute labels, respectively.
            * Values that cannot be computed (no epochs or no data) are NaN.
        """
        device = self.args.device
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        epoch_loss = []
        correct_orig = correct_unlearn = seen = 0

        net.train()
        for _ in range(self.args.local_ep):
            torch.cuda.empty_cache()
            batch_loss = []
            # Accuracy counters describe the most recent epoch only.
            correct_orig = correct_unlearn = seen = 0
            for images, labels, _batch_idx in self.ldr_train:
                # Index sim_label with the loader-side labels: a CUDA tensor accepts
                # CPU indices, but a CPU tensor rejects CUDA indices.
                unlearn_labels = self.sim_label[labels].to(device)
                images, labels = images.to(device), labels.to(device)

                net.zero_grad()
                logits, features = net(images, output_feature=True)

                with torch.no_grad():
                    pred = logits.argmax(dim=1)
                    correct_orig += pred.eq(labels.view_as(pred)).sum().item()
                    correct_unlearn += pred.eq(unlearn_labels.view_as(pred)).sum().item()
                seen += labels.size(0)

                loss_ce = self.loss_func(logits, unlearn_labels)
                loss = loss_ce + self._infonce_loss(features, labels, unlearn_labels)
                loss.backward()
                batch_loss.append(loss.item())

                if mask is not None:
                    for name, param in net.named_parameters():
                        if param.grad is not None:
                            param.grad *= mask[name]

                optimizer.step()

            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        nan = float("nan")
        mean_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else nan
        # Divide by samples actually seen, not len(dataset): they differ when the
        # loader uses a subset sampler or drop_last.
        acc_orig = 100 * correct_orig / float(seen) if seen else nan
        acc_unlearn = 100 * correct_unlearn / float(seen) if seen else nan
        result = [mean_loss, acc_orig, acc_unlearn]

        return net, net.state_dict(), result
