 
 

from cProfile import label
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import numpy as np
from util.util import shot_split
import copy
from util.losses import *

import sklearn.metrics as metrics
from util.losses import FocalLoss, Ratio_Cross_Entropy
from util.etf_methods import *
import matplotlib.pyplot as plt



 
class DatasetSplit(Dataset):
    """Subset view of a dataset, addressed through a fixed list of indices.

    Position ``k`` of this dataset maps to ``dataset[idxs[k]]``.
    """

    def __init__(self, dataset, idxs):
        """Wrap ``dataset`` and remember which of its indices are visible.

        Args:
            dataset: Indexable object whose items unpack into two values.
            idxs: Iterable of indices into ``dataset``; copied into a list so
                it can be addressed by position.
        """
        self.dataset = dataset
        self.idxs = [*idxs]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at subset position ``item``."""
        source_index = self.idxs[item]
        image, label = self.dataset[source_index]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset, idxs):
        """Store the run configuration and build this client's data loaders.

        Args:
            args: Configuration object; kept as ``self.args``.
            dataset: Dataset forwarded to ``train_test``.
            idxs: Iterable of sample indices for this client; converted to a
                list before being forwarded to ``train_test``.
        """
        self.args = args
        client_indices = list(idxs)
        loaders = self.train_test(dataset, client_indices)
        # ``train_test`` yields exactly two loaders: (train, test).
        self.ldr_train, self.ldr_test = loaders

    def get_loss(self):
        if self.args.loss_type == 'CE':
            return nn.CrossEntropyLoss()
        elif self.args.loss_type == 'focal':
            return FocalLoss(gamma=1).cuda(self.args.gpu)

    def train_test(self, dataset, idxs):
        """Create the training and evaluation loaders for this client.

        Args:
            dataset: Full dataset object.
            idxs: Indices of ``dataset`` that form the client's training set.

        Returns:
            ``(train, test)``: ``train`` iterates the ``idxs`` subset in
            shuffled batches of ``self.args.local_bs``; ``test`` iterates the
            whole ``dataset`` (not only ``idxs``) in batches of 128.
        """
        local_subset = DatasetSplit(dataset, idxs)
        train_loader = DataLoader(
            local_subset, batch_size=self.args.local_bs, shuffle=True)
        eval_loader = DataLoader(dataset, batch_size=128)
        return train_loader, eval_loader

    def update_weights(self, net, seed, net_glob, epoch, mu=1, lr=None):
        net_glob = net_glob

        net.train()
         
        if lr is None:
            optimizer = torch.optim.SGD(
                net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        else:
            optimizer = torch.optim.SGD(
                net.parameters(), lr=lr, momentum=self.args.momentum)

        epoch_loss = []
        for iter in range(epoch):
            batch_loss = []
             
             
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(
                    self.args.device), labels.to(self.args.device)
                 
                     
                labels = labels.long()
                net.zero_grad()
                logits = net(images)
                criterion = self.get_loss()
                loss = criterion(logits, labels)

                if self.args.beta > 0:
                    if batch_idx > 0:
                        w_diff = torch.tensor(0.).to(self.args.device)
                        for w, w_t in zip(net_glob.parameters(), net.parameters()):
                            w_diff += torch.pow(torch.norm(w - w_t), 2)
                        w_diff = torch.sqrt(w_diff)
                        loss += self.args.beta * mu * w_diff

                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
             
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)


    def update_weights_ditto(self, net, seed, net_glob, epoch, mu=0.01, lr=None):
        net_glob = net_glob

        net.train()
         
        if lr is None:
            optimizer = torch.optim.SGD(
                net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        else:
            optimizer = torch.optim.SGD(
                net.parameters(), lr=lr, momentum=self.args.momentum)

        epoch_loss = []
        for iter in range(epoch):
            batch_loss = []
             
             
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(
                    self.args.device), labels.to(self.args.device)
                 
                     
                labels = labels.long()
                net.zero_grad()
                logits = net(images)
                criterion = self.get_loss()
                loss = criterion(logits, labels)
                 
                if batch_idx > 0:
                    w_diff = torch.tensor(0.).to(self.args.device)
                    for w, w_t in zip(net_glob.parameters(), net.parameters()):
                        w_diff += torch.pow(torch.norm(w - w_t), 2)
                    w_diff = torch.sqrt(w_diff)
                    loss += mu * w_diff
                 
                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
             
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

    
    
    def update_weights_ratio_loss(self, net, seed, net_glob, epoch, mu=1, lr=None):
        """Train ``net`` locally with ``Ratio_Cross_Entropy`` as the criterion.

        When ``self.args.beta > 0`` the term ``beta * mu * ||w_glob - w||_2``
        (L2 norm over all parameters) is added on every batch except the
        first batch of each epoch.

        Args:
            net: Model to train; updated in place.
            seed: Unused; kept for interface compatibility.
            net_glob: Reference model for the proximal term. It is treated as
                a constant: no gradients are accumulated on it.
            epoch: Number of passes over ``self.ldr_train``.
            mu: Extra multiplier applied to the proximal term.
            lr: Learning rate; ``self.args.lr`` is used when ``None``.

        Returns:
            ``(net.state_dict(), mean_loss)`` where ``mean_loss`` is the
            average over epochs of the mean batch loss.

        Raises:
            ZeroDivisionError: If ``epoch`` is 0 or the loader yields no
                batches.
        """
        net.train()
        step_size = self.args.lr if lr is None else lr
        optimizer = torch.optim.SGD(
            net.parameters(), lr=step_size, momentum=self.args.momentum)

        epoch_loss = []
        for _ in range(epoch):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()
                net.zero_grad()
                logits = net(images)
                # NOTE(review): the criterion is rebuilt for every batch and
                # class_num is fixed at 100. Left as is because it is not
                # visible here whether Ratio_Cross_Entropy keeps per-instance
                # state - confirm before hoisting.
                criterion = Ratio_Cross_Entropy(
                    args=self.args, class_num=100, alpha=None, size_average=True)
                loss = criterion(logits, labels)

                # NOTE(review): the first batch of every epoch is skipped,
                # presumably because sqrt(0) has an undefined gradient while
                # net still equals net_glob - confirm the intent.
                if self.args.beta > 0 and batch_idx > 0:
                    sq_dist = torch.tensor(0.).to(self.args.device)
                    for w, w_t in zip(net_glob.parameters(), net.parameters()):
                        # detach(): the global weights are a fixed anchor.
                        # Without it backward() also accumulates gradients on
                        # net_glob that nothing ever uses or clears.
                        sq_dist += torch.pow(torch.norm(w.detach() - w_t), 2)
                    loss += self.args.beta * mu * torch.sqrt(sq_dist)

                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def update_weights_backbone_only(self, net, seed, epoch, criterion=None, mu=1, lr=None):
         
        count = 0
        for p in net.parameters():
            if count >= 108:         
                p.requires_grad = True
            else:
                p.requires_grad = False
            count += 1

        filter(lambda p: p.requires_grad, net.parameters())
         
        net.train()
        if criterion is None:
            criterion = nn.CrossEntropyLoss()
         
        if lr is None:
            optimizer = torch.optim.SGD(
                filter(lambda p: p.requires_grad, net.parameters()), lr=self.args.lr, momentum=self.args.momentum)
        else:
            optimizer = torch.optim.SGD(
                filter(lambda p: p.requires_grad, net.parameters()), lr=lr, momentum=self.args.momentum)

        epoch_loss = []

         
        linear_weights = net.linear.weight


         
         

         
         
        spar_mask = torch.zeros_like(linear_weights)
         
         
        spar_mask[:25] = (linear_weights[:25] == 0).float()   


        for iter in range(epoch):
            batch_loss = []
             
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(
                    self.args.device), labels.to(self.args.device)
                labels = labels.long()
                net.zero_grad()
                logits = net(images)
                loss = criterion(logits, labels)

                loss.backward()


                 
                net.linear.weight.grad *= spar_mask
                 

                 
                 
                 

                optimizer.step()

                batch_loss.append(loss.item())
                 
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
    

    def update_weights_fedrep(self, net, seed, net_glob, epoch, mu=1, lr=None):
        net_glob = net_glob

        net.train()
         
        if lr is None:
            optimizer = torch.optim.SGD(
                net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        else:
            optimizer = torch.optim.SGD(
                net.parameters(), lr=lr, momentum=self.args.momentum)

        epoch_loss = []

         
        count = 0
        for p in net.parameters():
            if count >= 108:         
                p.requires_grad = True
            else:
                p.requires_grad = False
            count += 1

        filter(lambda p: p.requires_grad, net.parameters())



        for iter in range(15):
            if iter == 10:
                 
                count = 0
                for p in net.parameters():
                    if count >= 108:         
                        p.requires_grad = False
                    else:
                        p.requires_grad = True
                    count += 1

                filter(lambda p: p.requires_grad, net.parameters())

            batch_loss = []
             
             
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(
                    self.args.device), labels.to(self.args.device)
                 
                     
                labels = labels.long()
                net.zero_grad()
                logits = net(images)
                criterion = self.get_loss()
                loss = criterion(logits, labels)

                if self.args.beta > 0:
                    if batch_idx > 0:
                        w_diff = torch.tensor(0.).to(self.args.device)
                        for w, w_t in zip(net_glob.parameters(), net.parameters()):
                            w_diff += torch.pow(torch.norm(w - w_t), 2)
                        w_diff = torch.sqrt(w_diff)
                        loss += self.args.beta * mu * w_diff

                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
             
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)


    def update_weights_gaux(self, net, g_head, g_aux, l_head, epoch, mu=1, lr=None, loss_switch=None):
        net.train()
         
        optimizer_g_backbone = torch.optim.SGD(list(net.parameters()), lr=self.args.lr, momentum=self.args.momentum)
        optimizer_g_aux = torch.optim.SGD(g_aux.parameters(), lr=self.args.lr, momentum=self.args.momentum)
         
         
         
         
        optimizer_l_head = torch.optim.SGD(l_head.parameters(), lr=self.args.lr, momentum=self.args.momentum)
         
         
         
         
         

        criterion_l = nn.CrossEntropyLoss()
        criterion_g = nn.CrossEntropyLoss()

        def adaptive_angle_loss(features, labels):
            similarity_matrix = F.cosine_similarity(features.unsqueeze(1), features.unsqueeze(0), dim=-1)
            diff_mask = labels.unsqueeze(1) != labels.unsqueeze(0)
            same_mask = labels.unsqueeze(1) == labels.unsqueeze(0)
            loss_diff = (similarity_matrix * diff_mask.float()).sum() / diff_mask.float().sum()
            loss_same = ((1 - similarity_matrix) * same_mask.float()).sum() / same_mask.float().sum()
            return loss_diff + loss_same

        def normalized_feature_loss(features):
             
            norms = torch.norm(features, dim=1)
             
            mean_norm = torch.mean(norms)
             
            variance = torch.mean((norms - mean_norm) ** 2)
             
            return variance

        def get_mma_loss(features, labels):
             
            weight_ = F.normalize(features, p=2, dim=1)
            cosine = torch.matmul(weight_, weight_.t())  
            same_mask = labels.unsqueeze(1) == labels.unsqueeze(0)   

             
            cosine = cosine - 2. * torch.diag(torch.diag(cosine))  
            cosine[same_mask] = -1

             
            loss = -torch.acos(cosine.max(dim=1)[0].clamp(-0.99999, 0.99999)).mean()

            return loss

        
        def balanced_softmax_loss(labels, logits, sample_per_class=None):
            """Compute the Balanced Softmax Loss between `logits` and the ground truth `labels`.
            Args:
            labels: A int tensor of size [batch].
            logits: A float tensor of size [batch, no_of_classes].
            sample_per_class: A int tensor of size [no of classes].
            reduction: string. One of "none", "mean", "sum"
            Returns:
            loss: A float tensor. Balanced Softmax Loss.
            """
            if sample_per_class is None:
                sample_per_class = [500, 477, 455, 434, 415, 396, 378, 361, 344, 328, 314, 299, 286, 273, 260, 248, 237, 226, 216, 206, 197, 188, 179, 171, 163, 156, 149, 142, 135, 129, 123, 118, 112, 107, 102, 98, 93, 89, 85, 81, 77, 74, 70, 67, 64, 61, 58, 56, 53, 51, 48, 46, 44, 42, 40, 38, 36, 35, 33, 32, 30, 29, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 15, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5]
            sample_per_class = torch.tensor(sample_per_class)
            spc = sample_per_class.type_as(logits)
            spc = spc.unsqueeze(0).expand(logits.shape[0], -1)
            logits = logits + spc.log()
            loss = F.cross_entropy(input=logits, target=labels, reduction='mean')
            return loss

        
        if loss_switch == "focus_loss":
            criterion_l = focus_loss(num_classes=100)

        epoch_loss = []
        for iter in range(epoch):
            batch_loss = []
             
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(
                    self.args.device), labels.to(self.args.device)

                labels = labels.long()
                optimizer_g_backbone.zero_grad()
                optimizer_g_aux.zero_grad()
                optimizer_l_head.zero_grad()
                 

                 
                features = net(images, latent_output=True)

                 


                output_g_backbone = g_head(features)
            

                
                loss_g_backbone = criterion_g(output_g_backbone, labels)
                 
                 
                loss_g_backbone.backward()
                 
                 
                optimizer_g_backbone.step()
                
                 
                output_g_aux = g_aux(features.detach())
                loss_g_aux = criterion_l(output_g_aux, labels)
                loss_g_aux.backward()
                optimizer_g_aux.step()

                 
                output_l_head = l_head(features.detach())
                loss_l_head = criterion_l(output_l_head, labels)
                loss_l_head.backward()
                optimizer_l_head.step()

                loss = loss_g_backbone + loss_g_aux + loss_l_head
                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), g_aux, l_head, sum(epoch_loss) / len(epoch_loss)





    def update_weights_class_mean(self, net, g_head, g_aux, l_head, epoch, class_means, mu=1, lr=None, loss_switch=None):
        net.train()
         
        optimizer_g_backbone = torch.optim.SGD(list(net.parameters()), lr=self.args.lr, momentum=self.args.momentum)
        optimizer_g_aux = torch.optim.SGD(g_aux.parameters(), lr=self.args.lr, momentum=self.args.momentum)
         
         
         
         
        optimizer_l_head = torch.optim.SGD(l_head.parameters(), lr=self.args.lr, momentum=self.args.momentum)
         
         
         
         
         

        criterion_l = nn.CrossEntropyLoss()
        criterion_g = nn.CrossEntropyLoss()

        def inner_min(features, labels, class_means):
            valid_class_means = [class_mean for class_mean in class_means.values() if class_mean is not None]
            if valid_class_means:
                global_mean = torch.stack(valid_class_means).mean(dim=0)
            del valid_class_means
             
             
            class_means = {class_idx: (class_mean - global_mean) if class_mean is not None else None for class_idx, class_mean in class_means.items()}
            features = features - global_mean

            loss = 0
            for i in range(len(features)):
                loss += F.cosine_similarity(features[i].unsqueeze(0), class_means[labels[i].item()].unsqueeze(0))
            return loss


        def inter_max(features, labels, class_means):
            valid_class_means = [class_mean for class_mean in class_means.values() if class_mean is not None]
            if valid_class_means:
                global_mean = torch.stack(valid_class_means).mean(dim=0)
            del valid_class_means
            class_means = {class_idx: (class_mean - global_mean) if class_mean is not None else None for class_idx, class_mean in class_means.items()}
            features = features - global_mean

             
             
            filtered_class_means = {k: v for k, v in class_means.items() if v is not None}
            all_class_means = torch.stack(list(filtered_class_means.values()))
             
            max_similarities = torch.zeros(len(features), device=features.device)


            for i in range(len(features)):
                current_label = labels[i].item()
                current_class_mean = class_means[current_label]

                 
                other_class_mean_indices = [idx for idx, k in enumerate(filtered_class_means.keys()) if k != current_label]

                 
                similarities = F.cosine_similarity(current_class_mean.unsqueeze(0), all_class_means[other_class_mean_indices], dim=1)

                 
                max_similarity, max_id = similarities.max(dim=0)

                 
                most_similar_class_mean = all_class_means[other_class_mean_indices[max_id]]
                
                 
                feature_similarity = F.cosine_similarity(features[i].unsqueeze(0), most_similar_class_mean.unsqueeze(0), dim=1)

                 
                max_similarities[i] = feature_similarity

             
            total_loss = max_similarities.sum()
            return total_loss



        def inter_max_feat_classmean(features, labels, class_means):
             
            max_similarities = torch.zeros(len(features), device=features.device)

             
            valid_class_means = [class_mean for class_mean in class_means.values() if class_mean is not None]
            if valid_class_means:
                global_mean = torch.stack(valid_class_means).mean(dim=0)
            del valid_class_means

             
            features = features - global_mean
            class_means = {class_idx: (class_mean - global_mean) if class_mean is not None else None for class_idx, class_mean in class_means.items()}

             
            filtered_class_means = {k: v for k, v in class_means.items() if v is not None}
            all_class_means = torch.stack(list(filtered_class_means.values()))

            for i in range(len(features)):
                current_label = labels[i].item()

                 
                other_class_mean_indices = [idx for idx, k in enumerate(filtered_class_means.keys()) if k != current_label]

                 
                similarities = F.cosine_similarity(features[i].unsqueeze(0), all_class_means[other_class_mean_indices], dim=1)

                 
                max_similarity, max_id = similarities.max(dim=0)

                 
                most_similar_class_mean = all_class_means[other_class_mean_indices[max_id]]

                 
                feature_similarity = F.cosine_similarity(features[i].unsqueeze(0), most_similar_class_mean.unsqueeze(0), dim=1)

                 
                max_similarities[i] = feature_similarity

             
            total_loss = max_similarities.sum()
            return total_loss


        def balanced_softmax_loss(labels, logits, sample_per_class=None):
            """Compute the Balanced Softmax Loss between `logits` and the ground truth `labels`.
            Args:
            labels: A int tensor of size [batch].
            logits: A float tensor of size [batch, no_of_classes].
            sample_per_class: A int tensor of size [no of classes].
            reduction: string. One of "none", "mean", "sum"
            Returns:
            loss: A float tensor. Balanced Softmax Loss.
            """
            if sample_per_class is None:
                sample_per_class = [500, 477, 455, 434, 415, 396, 378, 361, 344, 328, 314, 299, 286, 273, 260, 248, 237, 226, 216, 206, 197, 188, 179, 171, 163, 156, 149, 142, 135, 129, 123, 118, 112, 107, 102, 98, 93, 89, 85, 81, 77, 74, 70, 67, 64, 61, 58, 56, 53, 51, 48, 46, 44, 42, 40, 38, 36, 35, 33, 32, 30, 29, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 15, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5]
            sample_per_class = torch.tensor(sample_per_class)
            spc = sample_per_class.type_as(logits)
            spc = spc.unsqueeze(0).expand(logits.shape[0], -1)
            logits = logits + spc.log()
            loss = F.cross_entropy(input=logits, target=labels, reduction='mean')
            return loss

        
        if loss_switch == "focus_loss":
            criterion_l = focus_loss(num_classes=100)

        epoch_loss = []
        momentum = 0.9   
        for iter in range(epoch):
            batch_loss = []
             
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(
                    self.args.device), labels.to(self.args.device)

                labels = labels.long()
                optimizer_g_backbone.zero_grad()
                optimizer_g_aux.zero_grad()
                optimizer_l_head.zero_grad()
                 

                 
                features = net(images, latent_output=True)
                
                 
                for i in range(len(features)):
                    class_idx = labels[i].item()

                    if class_means[class_idx] is None:
                        class_means[class_idx] = features[i]
                    else:
                        class_means[class_idx] = momentum * class_means[class_idx] + (1 - momentum) * features[i]
                for key in class_means:
                    if class_means[key] != None:
                        class_means[key] = class_means[key].detach()

                 
                 

                 

                output_g_backbone = g_head(features)
            
                 
                 

                 

                loss_g_backbone = balanced_softmax_loss(labels, output_g_backbone) 
                loss_g_backbone.backward()


                 
                 
                optimizer_g_backbone.step()
                
                 
                output_g_aux = g_aux(features.detach())
                loss_g_aux = criterion_l(output_g_aux, labels)
                loss_g_aux.backward()
                optimizer_g_aux.step()

                 
                output_l_head = l_head(features.detach())
                loss_l_head = criterion_l(output_l_head, labels)
                loss_l_head.backward()
                optimizer_l_head.step()

                loss = loss_g_backbone + loss_g_aux + loss_l_head
                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), g_aux, l_head, sum(epoch_loss) / len(epoch_loss), class_means
    

    def update_weights_etf(self, net, g_head, g_aux, l_head, epoch, mu=1, lr=None, loss_switch=None):
        """Train ``net`` and ``g_head`` with ``dot_loss`` and fit two heads on detached features.

        Each batch performs three independent optimisation steps:

        1. ``net`` and ``g_head`` are updated together from
           ``dot_loss(..., 'reg_dot_loss', ..., reg_lam=0)``. The loss is
           computed on ``g_head(net(images, latent_output=True))`` against
           ``produce_Ew(labels, 10) * g_head.ori_M``.
        2. ``g_aux`` is updated with cross-entropy on the detached backbone
           output.
        3. ``l_head`` is updated with cross-entropy on the detached backbone
           output.

        NOTE(review): the class count passed to ``produce_Ew`` is fixed at
        10 here - confirm it matches the dataset in use.

        Args:
            net: Model called as ``net(images, latent_output=True)``; updated
                in place.
            g_head: Head exposing ``ori_M``; updated in place.
            g_aux: Head trained on detached features; updated in place.
            l_head: Head trained on detached features; updated in place.
            epoch: Number of passes over ``self.ldr_train``.
            mu: Unused.
            lr: Unused; every optimizer uses ``self.args.lr``.
            loss_switch: Unused.

        Returns:
            ``(net.state_dict(), g_aux, l_head, mean_loss)`` where
            ``mean_loss`` averages, over epochs, the mean per-batch sum of
            the three losses.

        Raises:
            ZeroDivisionError: If ``epoch`` is 0 or the loader yields no
                batches.
        """
        net.train()

        optimizer_g_backbone = torch.optim.SGD(
            [{"params": net.parameters()}, {"params": g_head.parameters()}],
            lr=self.args.lr, momentum=self.args.momentum)
        optimizer_g_aux = torch.optim.SGD(
            g_aux.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        optimizer_l_head = torch.optim.SGD(
            l_head.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        criterion_l = nn.CrossEntropyLoss()

        num_classes = 10
        epoch_loss = []
        for _ in range(epoch):
            batch_loss = []
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                optimizer_g_backbone.zero_grad()
                optimizer_g_aux.zero_grad()
                optimizer_l_head.zero_grad()

                learned_norm = produce_Ew(labels, num_classes)
                cur_M = learned_norm * g_head.ori_M

                feat = net(images, latent_output=True)
                features = g_head(feat)

                # Per-sample feature norm, floored at 1e-8 and computed
                # without gradient. (An unused features @ cur_M product was
                # computed here on every batch; it has been removed.)
                with torch.no_grad():
                    feat_nograd = features.detach()
                    H_length = torch.clamp(
                        torch.sqrt(torch.sum(feat_nograd ** 2, dim=1, keepdims=False)), 1e-8)

                loss_g_backbone = dot_loss(
                    features, labels, cur_M, g_head, 'reg_dot_loss', H_length, reg_lam=0)
                loss_g_backbone.backward()
                optimizer_g_backbone.step()

                # g_aux sees the detached backbone output, so it cannot move net.
                output_g_aux = g_aux(feat.detach())
                loss_g_aux = criterion_l(output_g_aux, labels)
                loss_g_aux.backward()
                optimizer_g_aux.step()

                # Same for the local head.
                output_l_head = l_head(feat.detach())
                loss_l_head = criterion_l(output_l_head, labels)
                loss_l_head.backward()
                optimizer_l_head.step()

                loss = loss_g_backbone + loss_g_aux + loss_l_head
                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return net.state_dict(), g_aux, l_head, sum(epoch_loss) / len(epoch_loss)
    

    def update_weights_auto_selective_ghead(self, net, g_head, g_aux, l_head, epoch, mu=1, lr=None, loss_switch=None):
        net.train()
         
        optimizer_g_backbone = torch.optim.SGD(list(net.parameters()) + [g_head.weights], lr=self.args.lr, momentum=self.args.momentum)
        optimizer_g_aux = torch.optim.SGD(g_aux.parameters(), lr=self.args.lr, momentum=self.args.momentum)
         
         
         
         
        optimizer_l_head = torch.optim.SGD(l_head.parameters(), lr=self.args.lr, momentum=self.args.momentum)
         
         
         
         
         

        criterion_l = nn.CrossEntropyLoss()
        criterion_g = nn.CrossEntropyLoss()
        if loss_switch == "focus_loss":
            criterion_l = focus_loss(num_classes=100)

        epoch_loss = []
        for iter in range(epoch):
            batch_loss = []
             
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(
                    self.args.device), labels.to(self.args.device)

                labels = labels.long()
                optimizer_g_backbone.zero_grad()
                optimizer_g_aux.zero_grad()
                optimizer_l_head.zero_grad()
                 

                 
                features = net(images, latent_output=True)
                output_g_backbone = g_head(features)
                loss_g_backbone = criterion_g(output_g_backbone, labels)
                loss_g_backbone.backward()
                 
                 
                optimizer_g_backbone.step()
                
                 
                output_g_aux = g_aux(features.detach())
                loss_g_aux = criterion_l(output_g_aux, labels)
                loss_g_aux.backward()
                optimizer_g_aux.step()

                 
                output_l_head = l_head(features.detach())
                loss_l_head = criterion_l(output_l_head, labels)
                loss_l_head.backward()
                optimizer_l_head.step()

                loss = loss_g_backbone + loss_g_aux + loss_l_head
                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), g_aux, g_head, l_head, sum(epoch_loss) / len(epoch_loss)
    
    def update_weights_fedrod(self, net, g_head, g_aux, l_head, epoch, mu=1, lr=None):
        net.train()
         
         
        optimizer_g_aux = torch.optim.SGD(list(net.parameters()) + list(g_aux.parameters()), lr=self.args.lr, momentum=self.args.momentum)
         
         
         
         
        optimizer_l_head = torch.optim.SGD(l_head.parameters(), lr=self.args.lr, momentum=self.args.momentum)
         
         
         
         
         

        criterion_l = nn.CrossEntropyLoss()
        criterion_g = nn.CrossEntropyLoss()

        epoch_loss = []
        for iter in range(epoch):
            batch_loss = []
             
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(
                    self.args.device), labels.to(self.args.device)

                labels = labels.long()
                optimizer_g_aux.zero_grad()
                optimizer_l_head.zero_grad()

                 
                features = net(images, latent_output=True)
                output = g_aux(features)
                loss_g = criterion_g(output, labels)
                loss_g.backward()
                optimizer_g_aux.step()

                 
                output_l_head = l_head(features.detach())
                loss_l_head = criterion_l(output_l_head, labels)
                loss_l_head.backward()
                optimizer_l_head.step()

                loss = loss_g + loss_l_head
                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), g_aux, l_head, sum(epoch_loss) / len(epoch_loss)
    
    def update_weights_balsoft_backup(self, net, g_head, g_aux, l_head, seed, net_glob, epoch, mu=1, lr=None):
        """Train the backbone through ``g_head`` and fit two heads on its features.

        Every mini-batch takes three independent SGD steps:

        1. ``net`` is updated with the cross-entropy of ``g_head(features)``.
           ``g_head`` is not handed to any optimizer here, so its weights are
           not changed by this method.
        2. ``g_aux`` is updated with cross-entropy on ``features.detach()``.
        3. ``l_head`` is updated with cross-entropy on ``features.detach()``.

        Args:
            net: Backbone, called as ``net(images, latent_output=True)``.
                Trained in place.
            g_head: Classifier the backbone is trained through.
            g_aux: Auxiliary head, trained in place on detached features.
            l_head: Second head, trained in place on detached features.
            seed: Unused.
            net_glob: Unused.
            epoch: Number of passes over ``self.ldr_train``.
            mu: Unused.
            lr: Learning rate for all three optimizers. ``None`` (default)
                falls back to ``self.args.lr``.

        Returns:
            Tuple ``(net.state_dict(), g_aux, l_head, mean_loss)`` where
            ``mean_loss`` is the mean over epochs of the per-epoch mean of
            ``loss_g + loss_g_aux + loss_l``. ``g_aux`` and ``l_head`` are the
            very objects passed in, not copies.

        Raises:
            ZeroDivisionError: If ``epoch`` is 0 or ``self.ldr_train`` yields
                no batch.
        """
        net.train()

        # Honour an explicit ``lr`` the same way ``update_weights`` does.
        step_size = self.args.lr if lr is None else lr
        momentum = self.args.momentum
        optimizer_g = torch.optim.SGD(net.parameters(), lr=step_size, momentum=momentum)
        optimizer_l = torch.optim.SGD(l_head.parameters(), lr=step_size, momentum=momentum)
        optimizer_g_aux = torch.optim.SGD(g_aux.parameters(), lr=step_size, momentum=momentum)

        criterion = nn.CrossEntropyLoss()

        epoch_loss = []
        for _ in range(epoch):
            batch_loss = []
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                optimizer_g.zero_grad()
                optimizer_l.zero_grad()
                optimizer_g_aux.zero_grad()

                # Step 1: the backbone learns through g_head.
                features = net(images, latent_output=True)
                loss_g = criterion(g_head(features), labels)
                loss_g.backward()
                optimizer_g.step()

                # Steps 2 and 3: the heads see the features as constants, so
                # their losses never reach the backbone.
                frozen_features = features.detach()

                loss_g_aux = criterion(g_aux(frozen_features), labels)
                loss_g_aux.backward()
                optimizer_g_aux.step()

                loss_l = criterion(l_head(frozen_features), labels)
                loss_l.backward()
                optimizer_l.step()

                batch_loss.append((loss_g + loss_g_aux + loss_l).item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), g_aux, l_head, sum(epoch_loss) / len(epoch_loss)
    

    def update_weights_phead_backup(self, net, g_aux, g_head, l_head, seed, net_glob, epoch, mu=1, lr=None):
        """Fit only ``l_head`` on features produced by ``net``.

        The backbone runs in train mode but is treated as a fixed feature
        extractor: no loss is propagated into it, and neither ``net`` nor
        ``g_aux`` is stepped. ``l_head`` takes one SGD step per mini-batch on
        the cross-entropy of ``l_head(features)``.

        Args:
            net: Backbone, called as ``net(images, latent_output=True)``.
            g_aux: Not trained; only its stale gradients are cleared. Returned
                as-is.
            g_head: Unused.
            l_head: Head trained in place.
            seed: Unused.
            net_glob: Unused.
            epoch: Number of passes over ``self.ldr_train``.
            mu: Unused.
            lr: Learning rate for ``l_head``. ``None`` (default) falls back to
                ``self.args.lr``.

        Returns:
            Tuple ``(net.state_dict(), g_aux, l_head, mean_loss)`` where
            ``mean_loss`` is the mean over epochs of the per-epoch mean
            cross-entropy. ``g_aux`` and ``l_head`` are the objects passed in.

        Raises:
            ZeroDivisionError: If ``epoch`` is 0 or ``self.ldr_train`` yields
                no batch.
        """
        net.train()

        optimizer_l = torch.optim.SGD(
            l_head.parameters(),
            lr=self.args.lr if lr is None else lr,
            momentum=self.args.momentum,
        )
        criterion_l = nn.CrossEntropyLoss()

        epoch_loss = []
        for _ in range(epoch):
            batch_loss = []
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                # Nothing below produces gradients for net or g_aux; clearing
                # them only discards whatever an earlier phase left behind.
                net.zero_grad()
                g_aux.zero_grad()
                optimizer_l.zero_grad()

                # The features are only ever used as constants, so skip
                # building the backbone's autograd graph altogether.
                with torch.no_grad():
                    features = net(images, latent_output=True)

                loss_l = criterion_l(l_head(features), labels)
                loss_l.backward()
                optimizer_l.step()

                batch_loss.append(loss_l.item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), g_aux, l_head, sum(epoch_loss) / len(epoch_loss)
    

    def update_weights_unlearning(self, net, g_aux, g_head, l_head, seed, net_glob, epoch, mu=1, lr=None):
        """Move ``g_aux``'s predictions toward ``l_head``'s via a KL loss.

        For each mini-batch the logits of both heads on the same backbone
        features are L2-normalised along dim 1, and ``g_aux`` takes one SGD
        step on ``KLDivLoss(log_softmax(g_aux_logits), softmax(l_head_logits))``
        with ``reduction='batchmean'``. Neither ``net`` nor ``l_head`` is
        stepped.

        Args:
            net: Backbone, called as ``net(images, latent_output=True)``; put
                in train mode but not updated.
            g_aux: Head trained in place.
            g_head: Unused.
            l_head: Head whose normalised softmax output is the KL target.
            seed: Unused.
            net_glob: Unused.
            epoch: Number of passes over ``self.ldr_train``.
            mu: Unused.
            lr: Learning rate for ``g_aux``. ``None`` (default) falls back to
                ``self.args.lr``.

        Returns:
            Tuple ``(net.state_dict(), copy.deepcopy(g_aux),
            copy.deepcopy(l_head), mean_loss)`` where ``mean_loss`` is the mean
            over epochs of the per-epoch mean KL loss.

        Raises:
            ZeroDivisionError: If ``epoch`` is 0 or ``self.ldr_train`` yields
                no batch.
        """
        net.train()

        optimizer_g_aux = torch.optim.SGD(
            g_aux.parameters(),
            lr=self.args.lr if lr is None else lr,
            momentum=self.args.momentum,
        )
        criterion_kl = nn.KLDivLoss(reduction='batchmean')

        epoch_loss = []
        for _ in range(epoch):
            batch_loss = []
            for images, labels in self.ldr_train:
                # The labels are moved and cast like everywhere else in this
                # class, but the KL objective below does not use them.
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                optimizer_g_aux.zero_grad()

                # The backbone only supplies constant features here.
                with torch.no_grad():
                    features = net(images, latent_output=True)

                outputs_g_aux_normalized = F.normalize(g_aux(features), dim=1)
                outputs_l_head_normalized = F.normalize(l_head(features), dim=1)

                # The target branch is not detached, so backward() also leaves
                # gradients on l_head; no optimizer in this method applies or
                # clears them.
                loss = criterion_kl(
                    F.log_softmax(outputs_g_aux_normalized, dim=1),
                    F.softmax(outputs_l_head_normalized, dim=1),
                )
                loss.backward()
                optimizer_g_aux.step()

                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), copy.deepcopy(g_aux), copy.deepcopy(l_head), sum(epoch_loss) / len(epoch_loss)
    
 
    def update_weights_norm_init(self, net, g_aux, g_head, l_head, seed, net_glob, epoch, mu=1, lr=None):
        """Rescale ``g_aux`` by ``l_head``'s row norms, then train ``g_aux``.

        Before any training, row ``i`` of ``g_aux.weight`` is multiplied by the
        L2 norm of row ``i`` of ``l_head.weight`` and re-wrapped in a fresh
        ``nn.Parameter``. This mutates the ``g_aux`` passed in and happens on
        every call, even with ``epoch=0``. ``g_aux`` is then trained with
        cross-entropy on backbone features; ``net`` and ``l_head`` are not
        stepped.

        Args:
            net: Backbone, called as ``net(images, latent_output=True)``; put
                in train mode but not updated.
            g_aux: Head that is rescaled and trained in place.
            g_head: Unused.
            l_head: Supplies the per-row norms; otherwise untouched.
            seed: Unused.
            net_glob: Unused.
            epoch: Number of passes over ``self.ldr_train``.
            mu: Unused.
            lr: Learning rate for ``g_aux``. ``None`` (default) falls back to
                ``self.args.lr``.

        Returns:
            Tuple ``(net.state_dict(), copy.deepcopy(g_aux),
            copy.deepcopy(l_head), mean_loss)``. ``mean_loss`` is
            ``sum(epoch_means) / (n_epochs + 1e-10)``, so it is ``0.0`` when
            ``epoch`` is 0.

        Raises:
            ZeroDivisionError: If ``epoch`` > 0 and ``self.ldr_train`` yields
                no batch.
        """
        net.train()

        norm = torch.norm(l_head.weight, p=2, dim=1)
        g_aux.weight = nn.Parameter(g_aux.weight * norm.unsqueeze(1))

        # Built after the rescale so the optimizer tracks the new Parameter.
        optimizer_g_aux = torch.optim.SGD(
            g_aux.parameters(),
            lr=self.args.lr if lr is None else lr,
            momentum=self.args.momentum,
        )
        criterion_ce = nn.CrossEntropyLoss()

        epoch_loss = []
        for _ in range(epoch):
            batch_loss = []
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                optimizer_g_aux.zero_grad()

                # The backbone only supplies constant features here.
                with torch.no_grad():
                    features = net(images, latent_output=True)

                loss = criterion_ce(g_aux(features), labels)
                loss.backward()
                optimizer_g_aux.step()

                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), copy.deepcopy(g_aux), copy.deepcopy(l_head), sum(epoch_loss) / (len(epoch_loss) + 1e-10)
    
    def pfedme_update_weights(self, net, seed, net_glob, epoch, mu=1, lr=None):
        """Run plain SGD on ``net`` over ``self.ldr_train`` for ``epoch`` passes.

        The loss function is obtained from ``self.get_loss()`` for every
        mini-batch. ``seed``, ``net_glob`` and ``mu`` are accepted but not used.

        Args:
            net: Model trained in place; called as ``net(images)``.
            seed: Unused.
            net_glob: Unused.
            epoch: Number of passes over ``self.ldr_train``.
            mu: Unused.
            lr: Learning rate; ``None`` (default) means ``self.args.lr``.

        Returns:
            Tuple ``(net.state_dict(), mean_loss)`` where ``mean_loss`` is the
            mean over epochs of the per-epoch mean batch loss.

        Raises:
            ZeroDivisionError: If ``epoch`` is 0 or ``self.ldr_train`` yields
                no batch.
        """
        net.train()

        step_size = self.args.lr if lr is None else lr
        optimizer = torch.optim.SGD(
            net.parameters(), lr=step_size, momentum=self.args.momentum)

        losses_by_epoch = []
        for _ in range(epoch):
            losses = []
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                net.zero_grad()
                logits = net(images)
                loss = self.get_loss()(logits, labels)
                loss.backward()
                optimizer.step()

                losses.append(loss.item())

            losses_by_epoch.append(sum(losses) / len(losses))

        return net.state_dict(), sum(losses_by_epoch) / len(losses_by_epoch)

     
    def update_weights_gasp_grad(self, net, seed, net_glob, client_id, epoch, gradBag, mu=1, lr=None):
        """Train ``net`` with SGD while recording the gradient at its logits.

        The hook object and analyser for ``client_id`` are fetched from
        ``gradBag``. For every mini-batch the hook is attached to the logits,
        the loss is back-propagated, the captured gradient (if any) is passed
        to the analyser together with the batch labels, and the hook is
        removed again. When ``self.args.beta > 0`` a proximal term
        ``beta * mu * ||net_glob - net||_2`` is added to the loss on every
        batch except the first of each epoch.

        Args:
            net: Model trained in place; called as ``net(images)``.
            seed: Unused.
            net_glob: Reference model for the proximal term.
            client_id: Key used with ``gradBag.get_client`` / ``load_in``.
            epoch: Number of passes over ``self.ldr_train``.
            gradBag: Provides ``get_client(client_id)`` and
                ``load_in(client_id, hook, analyser)``.
            mu: Multiplier on the proximal term.
            lr: Learning rate; ``None`` (default) means ``self.args.lr``.

        Returns:
            Tuple ``(net.state_dict(), mean_loss, gradBag)``; ``gradBag`` is
            the object passed in, after ``load_in`` has been called on it.

        Raises:
            ZeroDivisionError: If ``epoch`` is 0 or ``self.ldr_train`` yields
                no batch.
        """
        hookObj, gradAnalysor = gradBag.get_client(client_id)
        net.train()

        step_size = self.args.lr if lr is None else lr
        optimizer = torch.optim.SGD(
            net.parameters(), lr=step_size, momentum=self.args.momentum)

        losses_by_epoch = []
        for _ in range(epoch):
            losses = []
            for step, (images, labels) in enumerate(self.ldr_train):
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                net.zero_grad()
                logits = net(images)
                loss = self.get_loss()(logits, labels)
                # Attached before backward() so the hook sees d(loss)/d(logits).
                handle = logits.register_hook(hookObj.hook_func_tensor)

                # NOTE(review): the first batch of every epoch skips the
                # proximal term - presumably to avoid sqrt at zero drift right
                # after the local model was copied; confirm with the author.
                if self.args.beta > 0 and step > 0:
                    drift_sq = torch.tensor(0.).to(self.args.device)
                    for w_ref, w_cur in zip(net_glob.parameters(), net.parameters()):
                        drift_sq += torch.norm(w_ref - w_cur) ** 2
                    loss += self.args.beta * mu * torch.sqrt(drift_sq)

                loss.backward()

                if hookObj.has_gradient():
                    gradAnalysor.update(
                        hookObj.get_gradient(), labels.cpu().numpy().tolist())

                optimizer.step()
                handle.remove()
                losses.append(loss.item())
                gradAnalysor.print_for_debug()

            losses_by_epoch.append(sum(losses) / len(losses))

        gradBag.load_in(client_id, hookObj, gradAnalysor)
        return net.state_dict(), sum(losses_by_epoch) / len(losses_by_epoch), gradBag
 

    def update_weights_GBA_Loss(self, net, seed, epoch, pidloss, mu=1, lr=None):
        """Train ``net`` with SGD using the caller-supplied loss ``pidloss``.

        Args:
            net: Model trained in place; called as ``net(images)``.
            seed: Unused.
            epoch: Number of passes over ``self.ldr_train``.
            pidloss: Callable ``pidloss(logits, labels)`` returning the loss.
            mu: Unused.
            lr: Learning rate; ``None`` (default) means ``self.args.lr``.

        Returns:
            Tuple ``(net.state_dict(), mean_loss)`` where ``mean_loss`` is the
            mean over epochs of the per-epoch mean batch loss.

        Raises:
            ZeroDivisionError: If ``epoch`` is 0 or ``self.ldr_train`` yields
                no batch.
        """
        net.train()

        step_size = self.args.lr if lr is None else lr
        optimizer = torch.optim.SGD(
            net.parameters(), lr=step_size, momentum=self.args.momentum)

        losses_by_epoch = []
        for _ in range(epoch):
            losses = []
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                net.zero_grad()
                loss = pidloss(net(images), labels)
                loss.backward()
                optimizer.step()

                losses.append(loss.item())

            losses_by_epoch.append(sum(losses) / len(losses))
        return net.state_dict(), sum(losses_by_epoch) / len(losses_by_epoch)


    def update_weights_GBA_Finetune(self, net, seed, epoch, pidloss, mu=1, lr=None, num_frozen=105):
        """Freeze the leading parameters of ``net`` and fine-tune the rest.

        The first ``num_frozen`` tensors yielded by ``net.parameters()`` get
        ``requires_grad = False``. This is a lasting side effect on ``net``;
        it is not undone on return. The remaining parameters are trained with
        SGD on ``pidloss``.

        Args:
            net: Model fine-tuned in place; called as ``net(images)``.
            seed: Unused.
            epoch: Number of passes over ``self.ldr_train``.
            pidloss: Callable ``pidloss(logits, labels)`` returning the loss.
            mu: Unused.
            lr: Learning rate; ``None`` (default) means ``self.args.lr``.
            num_frozen: How many leading parameter tensors to freeze. The
                default, 105, is the value that used to be hard-coded.

        Returns:
            Tuple ``(net.state_dict(), mean_loss)`` where ``mean_loss`` is the
            mean over epochs of the per-epoch mean batch loss.

        Raises:
            ValueError: From ``torch.optim.SGD`` when no parameter is left
                trainable (``net`` has at most ``num_frozen`` parameter
                tensors, or the rest were already frozen).
            ZeroDivisionError: If ``epoch`` is 0 or ``self.ldr_train`` yields
                no batch.
        """
        for index, param in enumerate(net.parameters()):
            if index >= num_frozen:
                break
            param.requires_grad = False

        net.train()

        trainable = [p for p in net.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(
            trainable,
            lr=self.args.lr if lr is None else lr,
            momentum=self.args.momentum,
        )

        epoch_loss = []
        for _ in range(epoch):
            batch_loss = []
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                net.zero_grad()
                loss = pidloss(net(images), labels)
                loss.backward()
                optimizer.step()

                batch_loss.append(loss.item())

            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

     
    def update_weights_GBA_Layer(self, net, seed, epoch, GBA_Loss, GBA_Layer, mu=1, lr=None):
        """Train ``net`` through ``GBA_Layer`` using ``GBA_Loss``.

        The loss is ``GBA_Loss(GBA_Layer(net(images)), labels)``. Only
        ``net.parameters()`` are handed to the optimizer and only ``net``'s
        gradients are cleared each step; ``GBA_Layer`` is switched to train
        mode but no optimizer is created for it here.

        Args:
            net: Backbone trained in place; called as ``net(images)``.
            seed: Unused.
            epoch: Number of passes over ``self.ldr_train``.
            GBA_Loss: Callable ``GBA_Loss(logits, labels)`` returning the loss.
            GBA_Layer: Module applied to the backbone output.
            mu: Unused.
            lr: Learning rate; ``None`` (default) means ``self.args.lr``.

        Returns:
            Tuple ``(net.state_dict(), GBA_Layer.state_dict(), mean_loss)``.

        Raises:
            ZeroDivisionError: If ``epoch`` is 0 or ``self.ldr_train`` yields
                no batch.
        """
        net.train()
        GBA_Layer.train()

        step_size = self.args.lr if lr is None else lr
        backbone_optimizer = torch.optim.SGD(
            net.parameters(), lr=step_size, momentum=self.args.momentum)

        losses_by_epoch = []
        for _ in range(epoch):
            losses = []
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device).long()

                net.zero_grad()
                logits = GBA_Layer(net(images))
                loss = GBA_Loss(logits, labels)
                loss.backward()
                backbone_optimizer.step()

                losses.append(loss.item())

            losses_by_epoch.append(sum(losses) / len(losses))

        return net.state_dict(), GBA_Layer.state_dict(), sum(losses_by_epoch) / len(losses_by_epoch)
 
def globaltest_villina(net, test_dataset, args, dataset_class=None):
    """Evaluate ``net`` on ``test_dataset``: overall and head/middle/tail accuracy.

    Classes are grouped by ``shot_split(dataset_class.global_test_distribution,
    threshold_3shot=[75, 95])``; the first element of its result is indexed
    with ``"head"``, ``"middle"`` and ``"tail"`` and used for membership tests
    on class ids (presumably lists of ids - confirm in ``util.util``).

    Args:
        net: Model, called as ``net(images)``; switched to eval mode.
        test_dataset: Dataset yielding ``(image, label)`` pairs.
        args: Needs ``args.device``.
        dataset_class: Must expose ``global_test_distribution``. The ``None``
            default is not usable and raises ``AttributeError``.

    Returns:
        Tuple ``(acc, acc_3shot_global)``. ``acc`` is ``correct / total``;
        ``acc_3shot_global`` maps each of ``"head"``, ``"middle"``, ``"tail"``
        to ``correct / (total + 1e-10)`` for that group, so an empty group
        scores 0.0.

    Raises:
        ZeroDivisionError: If ``test_dataset`` is empty.
    """
    three_shot_dict, _ = shot_split(
        dataset_class.global_test_distribution, threshold_3shot=[75, 95])
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}

    net.eval()
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=100, shuffle=False)

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # One host transfer per batch; membership tests then run on plain
            # Python numbers instead of one device-tensor comparison per item.
            for truth, guess in zip(labels.tolist(), predicted.tolist()):
                if truth in three_shot_dict["head"]:
                    shot = "head"
                elif truth in three_shot_dict["middle"]:
                    shot = "middle"
                else:
                    shot = "tail"
                total_3shot[shot] += 1
                # A label found in none of the groups still counts toward the
                # tail total (as before) but never as a tail hit.
                if guess == truth and truth in three_shot_dict[shot]:
                    correct_3shot[shot] += 1

    acc = correct / total
    acc_3shot_global = {
        shot: correct_3shot[shot] / (total_3shot[shot] + 1e-10)
        for shot in ("head", "middle", "tail")
    }
    return acc, acc_3shot_global


def globaltest(net, g_head, test_dataset, args, dataset_class=None, head_switch=True):
    """Evaluate ``net`` (optionally followed by ``g_head``) on ``test_dataset``.

    Predictions are the argmax of ``g_head(features)`` when ``head_switch`` is
    truthy, otherwise of the backbone features themselves. Classes are grouped
    by ``shot_split(dataset_class.global_test_distribution,
    threshold_3shot=[75, 95])``; the first element of its result is indexed
    with ``"head"``, ``"middle"`` and ``"tail"`` and used for membership tests
    on class ids (presumably lists of ids - confirm in ``util.util``).

    The earlier per-class accuracy list was never returned and raised
    ``ZeroDivisionError`` whenever a class had no test sample; it is no
    longer computed.

    Args:
        net: Backbone, called as ``net(images, latent_output=True)``; switched
            to eval mode.
        g_head: Classifier applied to the features when ``head_switch`` is
            truthy. Its train/eval mode is left as the caller set it.
        test_dataset: Dataset yielding ``(image, label)`` pairs.
        args: Needs ``args.device``.
        dataset_class: Must expose ``global_test_distribution``. The ``None``
            default is not usable and raises ``AttributeError``.
        head_switch: Whether to apply ``g_head``.

    Returns:
        Tuple ``(acc, acc_3shot_global)``. ``acc`` is ``correct / total``;
        ``acc_3shot_global`` maps each of ``"head"``, ``"middle"``, ``"tail"``
        to ``correct / (total + 1e-10)`` for that group, so an empty group
        scores 0.0.

    Raises:
        ZeroDivisionError: If ``test_dataset`` is empty.
    """
    three_shot_dict, _ = shot_split(
        dataset_class.global_test_distribution, threshold_3shot=[75, 95])
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}

    net.eval()
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=100, shuffle=False)

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            features = net(images, latent_output=True)
            outputs = g_head(features) if head_switch else features
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # One host transfer per batch; membership tests then run on plain
            # Python numbers instead of one device-tensor comparison per item.
            for truth, guess in zip(labels.tolist(), predicted.tolist()):
                if truth in three_shot_dict["head"]:
                    shot = "head"
                elif truth in three_shot_dict["middle"]:
                    shot = "middle"
                else:
                    shot = "tail"
                total_3shot[shot] += 1
                # A label found in none of the groups still counts toward the
                # tail total (as before) but never as a tail hit.
                if guess == truth and truth in three_shot_dict[shot]:
                    correct_3shot[shot] += 1

    acc = correct / total
    acc_3shot_global = {
        shot: correct_3shot[shot] / (total_3shot[shot] + 1e-10)
        for shot in ("head", "middle", "tail")
    }
    return acc, acc_3shot_global


def globaltest_calibra(net, g_aux, test_dataset, args, dataset_class=None, head_switch=True, target_norm=1.7):
    """Rescale ``g_aux`` to a common row norm, then evaluate ``net`` + ``g_aux``.

    Before evaluation every row of ``g_aux.weight`` is multiplied by
    ``target_norm / ||row||_2`` and the result is stored back as a new
    ``nn.Parameter``, so each row ends up with L2 norm ``target_norm``. This
    permanently modifies the ``g_aux`` passed in, and it happens even when
    ``head_switch`` is falsy. A row whose norm is zero divides by zero and
    yields non-finite weights.

    Classes are grouped by ``shot_split(dataset_class.global_test_distribution,
    threshold_3shot=[75, 95])``; the first element of its result is indexed
    with ``"head"``, ``"middle"`` and ``"tail"`` and used for membership tests
    on class ids (presumably lists of ids - confirm in ``util.util``).

    The earlier per-class accuracy list was never returned and raised
    ``ZeroDivisionError`` whenever a class had no test sample; it is no
    longer computed.

    Args:
        net: Backbone, called as ``net(images, latent_output=True)``; switched
            to eval mode.
        g_aux: Classifier exposing a 2-D ``weight``; rescaled in place and
            applied to the features when ``head_switch`` is truthy.
        test_dataset: Dataset yielding ``(image, label)`` pairs.
        args: Needs ``args.device``.
        dataset_class: Must expose ``global_test_distribution``. The ``None``
            default is not usable and raises ``AttributeError``.
        head_switch: Whether to apply ``g_aux``; otherwise the features are
            used as logits.
        target_norm: Row norm imposed on ``g_aux.weight``. The default, 1.7,
            is the value that used to be hard-coded.

    Returns:
        Tuple ``(acc, acc_3shot_global)``. ``acc`` is ``correct / total``;
        ``acc_3shot_global`` maps each of ``"head"``, ``"middle"``, ``"tail"``
        to ``correct / (total + 1e-10)`` for that group, so an empty group
        scores 0.0.

    Raises:
        ZeroDivisionError: If ``test_dataset`` is empty.
    """
    three_shot_dict, _ = shot_split(
        dataset_class.global_test_distribution, threshold_3shot=[75, 95])
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}

    net.eval()
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=100, shuffle=False)

    # Calibration needs no autograd graph; the fresh Parameter is a leaf anyway.
    with torch.no_grad():
        row_norms = torch.norm(g_aux.weight, dim=1)
        row_scale = (target_norm / row_norms).view(-1, 1)
        g_aux.weight = torch.nn.Parameter(g_aux.weight * row_scale)

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            features = net(images, latent_output=True)
            outputs = g_aux(features) if head_switch else features
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # One host transfer per batch; membership tests then run on plain
            # Python numbers instead of one device-tensor comparison per item.
            for truth, guess in zip(labels.tolist(), predicted.tolist()):
                if truth in three_shot_dict["head"]:
                    shot = "head"
                elif truth in three_shot_dict["middle"]:
                    shot = "middle"
                else:
                    shot = "tail"
                total_3shot[shot] += 1
                # A label found in none of the groups still counts toward the
                # tail total (as before) but never as a tail hit.
                if guess == truth and truth in three_shot_dict[shot]:
                    correct_3shot[shot] += 1

    acc = correct / total
    acc_3shot_global = {
        shot: correct_3shot[shot] / (total_3shot[shot] + 1e-10)
        for shot in ("head", "middle", "tail")
    }
    return acc, acc_3shot_global

def globaltest_classmean(net, g_head, test_dataset, args, dataset_class=None, head_switch=True):
    """Evaluate ``net`` (+ ``g_head``) and plot class-mean / classifier geometry.

    Besides computing overall and head/middle/tail accuracy, this function
    accumulates the mean backbone feature of every class seen in
    ``test_dataset`` and writes several figures into the current working
    directory:

    * ``t-SNE_Visualization_of_Class_Means.png`` / ``.pdf``
    * ``angle_matrix.png`` and ``class_norms.png`` (from the class means)
    * ``cls_class_norms.png`` and ``cls_angle_matrix.png`` (from the rows of
      ``g_head.weight``)

    It ends with ``plt.show()``. None of the figures is closed here.

    Args:
        net: Backbone, called as ``net(images, latent_output=True)``; switched
            to eval mode.
        g_head: Classifier with a ``weight`` attribute. Applied to the features
            when ``head_switch == True``; its weight is plotted regardless.
        test_dataset: Dataset yielding ``(image, label)`` pairs.
        args: Needs ``args.device`` and ``args.num_classes``.
        dataset_class: Must expose ``global_test_distribution``; the ``None``
            default is not usable.
        head_switch: When not ``True``, the features themselves are used as
            logits.

    Returns:
        Tuple ``(acc, acc_3shot_global)``; ``acc_3shot_global`` maps
        ``"head"``, ``"middle"``, ``"tail"`` to ``correct / (total + 1e-10)``.

    Raises:
        ZeroDivisionError: If ``test_dataset`` is empty, or if any class in
            ``range(args.num_classes)`` has no test sample (see the
            ``acc_class_wise`` line).
    """
    global_test_distribution = dataset_class.global_test_distribution
    three_shot_dict, _ = shot_split(global_test_distribution, threshold_3shot=[75, 95])
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}
    acc_3shot_global = {"head": None, "middle": None, "tail": None}
    net.eval()
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=100, shuffle=False)

    # Per-class sample and hit counters, indexed by class id.
    total_class_label = [0 for i in range(args.num_classes)]
    predict_true_class = [0 for i in range(args.num_classes)]

    # Running per-class feature sums and sample counts, keyed by the label as
    # a Python number.
    class_sums = {}
    class_counts = {}

    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            features = net(images, latent_output=True)

            if head_switch == True:
                outputs = g_head(features)
            else:
                outputs = features
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Accumulate the feature sum of every class.
            for i in range(images.size(0)):

                label = labels[i].item()
                feature = features[i]

                # NOTE(review): the first sample of a class stores a view into
                # this batch's ``features`` tensor, and the later ``+=`` writes
                # into that view in place. The rows involved are not read again
                # afterwards, but the whole batch tensor stays referenced.
                if label not in class_sums:
                    class_sums[label] = feature
                    class_counts[label] = 1
                else:
                    class_sums[label] += feature
                    class_counts[label] += 1

            # Per-class totals and hits.
            for i in range(len(labels)):
                total_class_label[int(labels[i])] += 1
                if predicted[i] == labels[i]:
                    predict_true_class[int(labels[i])] += 1

            # Group totals: anything not in "head" or "middle" counts as tail.
            for label in labels:
                if label in three_shot_dict["head"]:
                    total_3shot["head"] += 1
                elif label in three_shot_dict["middle"]:
                    total_3shot["middle"] += 1
                else:
                    total_3shot["tail"] += 1
            # Group hits: a correct prediction is credited to the group that
            # lists its label.
            for i in range(len(predicted)):
                if predicted[i] == labels[i] and labels[i] in three_shot_dict["head"]:
                    correct_3shot["head"] += 1

                elif predicted[i] == labels[i] and labels[i] in three_shot_dict["middle"]:
                    correct_3shot["middle"] += 1

                elif predicted[i] == labels[i] and labels[i] in three_shot_dict["tail"]:
                    correct_3shot["tail"] += 1

    # Class means and their L2 norms.
    class_means = {label: class_sum / class_counts[label] for label, class_sum in class_sums.items()}
    class_norms = {label: torch.norm(mean, p=2) for label, mean in class_means.items()}
    # NOTE(review): never used below, and it raises ZeroDivisionError as soon
    # as one class has no test sample.
    acc_class_wise = [predict_true_class[i] / total_class_label[i] for i in range(args.num_classes)]
    # ``compute_angle`` is not defined in this block (presumably it comes from
    # the ``util.etf_methods`` star import - TODO confirm). This first result
    # is overwritten in the loop further down without being read.
    angle = compute_angle(0, 1, class_means)

    # Hard-coded on: the t-SNE plot of the class means is always produced.
    tsne_switch = True
    if tsne_switch:
        import numpy as np
        import matplotlib.pyplot as plt
        from sklearn.manifold import TSNE
        from sklearn.preprocessing import StandardScaler

        # One row per class mean, in dict iteration order.
        features = np.array([value.cpu().numpy() for value in class_means.values()])

        # Standardise every feature dimension before embedding.
        features = StandardScaler().fit_transform(features)

        # 2-D t-SNE embedding with a fixed random_state.
        tsne = TSNE(n_components=2, random_state=42, learning_rate=200)
        features_2d = tsne.fit_transform(features)

        # One scatter point per class, labelled with its class id.
        plt.figure(figsize=(10, 8))

        for i, class_label in enumerate(class_means.keys()):
            plt.scatter(features_2d[i, 0], features_2d[i, 1], label=str(class_label))

        plt.legend(loc='best')
        plt.xlabel('t-SNE component 1')
        plt.ylabel('t-SNE component 2')
        plt.title('t-SNE Visualization of Class Means')

        plt.savefig('t-SNE_Visualization_of_Class_Means.png', dpi=400)
        plt.savefig('t-SNE_Visualization_of_Class_Means.pdf', format='pdf')

    # Symmetric matrix of pairwise ``compute_angle`` values between class
    # means; the diagonal stays 0.
    angle_matrix = torch.zeros(args.num_classes, args.num_classes)

    # NOTE(review): ``class_means`` only has keys for classes that occurred in
    # the test set, while this loop covers every id below args.num_classes.
    for i in range(args.num_classes):
        for j in range(i+1, args.num_classes):
            angle = compute_angle(i, j, class_means)
            angle_matrix[i, j] = angle
            angle_matrix[j, i] = angle

    # Heat map of the class-mean angle matrix.
    plt.figure(figsize=(10, 10))
    plt.imshow(angle_matrix, cmap='hot', interpolation='nearest')

    plt.colorbar()

    plt.title('Angle Matrix')
    plt.xlabel('Class')
    plt.ylabel('Class')

    plt.savefig('angle_matrix.png')


    # Move the class-mean norms to the CPU for plotting.
    class_norms = {label: norm.cpu() for label, norm in class_norms.items()}

    # ``labels`` is rebound here: from now on it holds the class ids, not the
    # last batch's label tensor.
    labels = list(class_norms.keys())
    norms = list(class_norms.values())

    # Bar chart of the class-mean norms.
    plt.figure(figsize=(10, 5))
    plt.bar(labels, norms)

    plt.title('L2 Norms of Class Means')
    plt.xlabel('Class')
    plt.ylabel('L2 Norm')

    plt.savefig('class_norms.png')




    # Same two plots again, this time for the rows of the classifier weight.
    class_norms = torch.norm(g_head.weight, dim=1)

    # Pairwise ``compute_angle`` values between the rows of ``g_head.weight``.
    num_classes = g_head.weight.size(0)
    angle_matrix = torch.zeros(num_classes, num_classes)
    for i in range(num_classes):
        for j in range(i+1, num_classes):
            angle = compute_angle(i, j, g_head.weight)
            angle_matrix[i, j] = angle
            angle_matrix[j, i] = angle

    # Convert to NumPy for matplotlib.
    class_norms = class_norms.detach().cpu().numpy()
    angle_matrix = angle_matrix.detach().cpu().numpy()

    # Bar chart of the classifier row norms.
    plt.figure(figsize=(10, 5))
    plt.bar(range(num_classes), class_norms)
    plt.title('L2 Norms of Class Vectors')
    plt.xlabel('Class')
    plt.ylabel('L2 Norm')
    plt.savefig('cls_class_norms.png')

    # Heat map of the classifier angle matrix.
    plt.figure(figsize=(10, 10))
    plt.imshow(angle_matrix, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title('Angle Matrix')
    plt.xlabel('Class')
    plt.ylabel('Class')
    plt.savefig('cls_angle_matrix.png')
    plt.show()


    # NOTE(review): recomputed here but not used before the function returns.
    class_norms = {label: torch.norm(mean, p=2) for label, mean in class_means.items()}



    acc = correct / total
    acc_3shot_global["head"] = correct_3shot["head"] / \
        (total_3shot["head"] + 1e-10)
    acc_3shot_global["middle"] = correct_3shot["middle"] / \
        (total_3shot["middle"] + 1e-10)
    acc_3shot_global["tail"] = correct_3shot["tail"] / \
        (total_3shot["tail"] + 1e-10)
    return acc, acc_3shot_global




def _feat_collapse_evaluate(net, head, loader, args, three_shot_dict, head_switch,
                            class_sums=None, class_counts=None):
    """Run one inference pass over ``loader`` and return ``(acc, acc_3shot)``.

    All counters are created inside this function, so two consecutive passes
    can never leak head/middle/tail tallies into each other.

    Args:
        net: backbone, called as ``net(images, latent_output=True)``.
        head: classifier applied to the backbone output when ``head_switch``
            is truthy; otherwise the backbone output itself is scored.
        loader: iterable of ``(images, labels)`` batches.
        args: namespace providing ``device``.
        three_shot_dict: mapping with "head", "middle" and "tail" class
            collections, as returned by ``shot_split``.
        head_switch: whether to apply ``head``.
        class_sums: optional dict that receives per-class sums of the backbone
            output (keyed by integer label).
        class_counts: optional dict that receives per-class sample counts;
            must be given together with ``class_sums``.

    Returns:
        ``(acc, acc_3shot)`` where ``acc_3shot`` has "head", "middle", "tail".
    """
    correct = 0
    total = 0
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            features = net(images, latent_output=True)
            outputs = head(features) if head_switch else features
            _, predicted = torch.max(outputs, 1)
            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()

            if class_sums is not None:
                # Summing per class yields fresh tensors, so nothing stored in
                # class_sums aliases (or keeps alive) the batch tensor.
                for cls in torch.unique(labels).tolist():
                    rows = features[labels == cls]
                    batch_sum = rows.sum(dim=0)
                    if cls in class_sums:
                        class_sums[cls] += batch_sum
                        class_counts[cls] += rows.size(0)
                    else:
                        class_sums[cls] = batch_sum
                        class_counts[cls] = rows.size(0)

            # Work on Python scalars to avoid one device sync per sample.
            for cls, hit in zip(labels.tolist(), hits.tolist()):
                if cls in three_shot_dict["head"]:
                    group = "head"
                elif cls in three_shot_dict["middle"]:
                    group = "middle"
                else:
                    group = "tail"
                total_3shot[group] += 1
                # A class listed in no group counts towards the tail total but
                # is only credited as correct when it really is a tail class.
                if hit and (group != "tail" or cls in three_shot_dict["tail"]):
                    correct_3shot[group] += 1

    acc = correct / total
    acc_3shot = {
        "head": correct_3shot["head"] / (total_3shot["head"] + 1e-10),
        "middle": correct_3shot["middle"] / (total_3shot["middle"] + 1e-10),
        "tail": correct_3shot["tail"] / (total_3shot["tail"] + 1e-10),
    }
    return acc, acc_3shot


def _feat_collapse_prune_means(class_means, fraction):
    """Zero the smallest entries of each class mean in place and return the masks.

    For every tensor the threshold is the value at sorted position
    ``int(len * fraction)``; entries strictly below it are set to 0.

    Returns:
        dict mapping class label to a boolean mask of the zeroed positions.
    """
    masks = {}
    for cls, mean in class_means.items():
        ordered, _ = torch.sort(mean)
        # Clamp so that fraction == 1.0 cannot index past the end.
        cut = min(int(len(ordered) * fraction), len(ordered) - 1)
        below = mean < ordered[cut]
        masks[cls] = below
        mean[below] = 0
    return masks


def _feat_collapse_pruned_head(g_head, zero_positions, device):
    """Return an ``nn.Linear`` copy of ``g_head`` with masked weights set to 0.

    Row ``cls`` of the weight matrix is zeroed wherever ``zero_positions[cls]``
    is True. ``g_head`` itself is left untouched. The layer size is taken from
    ``g_head.weight`` instead of being hard-coded.
    """
    weight = g_head.weight.detach().clone()
    for cls, mask in zero_positions.items():
        weight[cls, mask.to(weight.device)] = 0

    bias = getattr(g_head, "bias", None)
    out_features, in_features = weight.shape
    head = nn.Linear(in_features=in_features, out_features=out_features,
                     bias=bias is not None)
    head.weight.data = weight
    if bias is not None:
        head.bias.data = bias.detach().clone()
    return head.to(device)


def _feat_collapse_relative_variance(net, loader, args, class_means, zero_positions):
    """Compare feature spread at masked vs. unmasked positions.

    For each sample the absolute deviation from its class mean is split by
    ``zero_positions[label]``. Per group, the mean squared deviation is divided
    by the squared mean absolute deviation.

    Returns:
        ``(relative_variance_masked, relative_variance_unmasked)``; a group
        with no elements or zero mean deviation yields 0.0.
    """
    sq_zero = sq_keep = 0.0
    abs_zero = abs_keep = 0.0
    n_zero = n_keep = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(args.device)
            features = net(images, latent_output=True)
            label_list = labels.tolist()
            means = torch.stack([class_means[cls] for cls in label_list])
            masks = torch.stack([zero_positions[cls] for cls in label_list])
            diff = torch.abs(features - means)
            low, high = diff[masks], diff[~masks]
            sq_zero += torch.sum(low ** 2).item()
            sq_keep += torch.sum(high ** 2).item()
            abs_zero += torch.sum(low).item()
            abs_keep += torch.sum(high).item()
            n_zero += low.numel()
            n_keep += high.numel()

    def _relative(sq_sum, abs_sum, count):
        avg_sq = sq_sum / count if count > 0 else 0.0
        avg_abs = abs_sum / count if count > 0 else 0.0
        return avg_sq / (avg_abs ** 2) if avg_abs != 0 else 0.0

    return _relative(sq_zero, abs_zero, n_zero), _relative(sq_keep, abs_keep, n_keep)


def _feat_collapse_plot_class_profile(net, g_head, loader, args, class_label, out_path):
    """Plot sorted per-dimension mean and relative variance for one class.

    Dimensions are ordered by descending mean. Positions where row
    ``class_label`` of ``g_head.weight`` is exactly 0 are shaded grey. The
    figure is written to ``out_path`` and then closed.
    """
    chunks = []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            features = net(images, latent_output=True)
            chunks.append(features[labels == class_label].cpu().detach().numpy())

    all_features = np.concatenate(chunks, axis=0)
    mean_features = np.mean(all_features, axis=0)
    variance_features = np.var(all_features, axis=0)

    relative_variance = variance_features / (mean_features ** 2)
    relative_variance[np.isnan(relative_variance)] = 0

    order = np.argsort(mean_features)[::-1]
    sorted_mean_features = mean_features[order]
    sorted_relative_variance = relative_variance[order]

    # NOTE(review): this reads the head that was passed in, not the pruned
    # copy, so shading only appears if g_head already contains exact zeros.
    weight_is_zero = g_head.weight.cpu().detach().numpy()[class_label] == 0
    weight_is_zero = weight_is_zero[order]

    fig, ax1 = plt.subplots()
    ax1.set_xlabel('Sorted Feature Index')
    ax1.set_ylabel('Sorted Mean Features', color='tab:blue')
    ax1.plot(sorted_mean_features, color='tab:blue')
    ax1.tick_params(axis='y', labelcolor='tab:blue')

    for pos in np.flatnonzero(weight_is_zero):
        ax1.axvspan(pos - 0.5, pos + 0.5, facecolor='gray', alpha=0.5)

    ax2 = ax1.twinx()
    ax2.set_ylabel('Sorted Relative Variance', color='tab:red')
    ax2.plot(sorted_relative_variance, color='tab:red')
    ax2.tick_params(axis='y', labelcolor='tab:red')

    fig.tight_layout()
    # The title names the class that is actually plotted.
    plt.title(f'Sorted Mean Features and Relative Variance for Class {class_label}')
    fig.savefig(out_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def globaltest_feat_collapse(net, g_head, test_dataset, args, dataset_class=None, head_switch=True):
    """Evaluate the global model, then re-evaluate it with a class-mean-pruned head.

    Steps:
      1. Baseline pass with ``g_head``; per-class means of the backbone output
         are collected and the accuracies are printed.
      2. For each class the smallest entries of its mean (threshold at sorted
         position ``int(dim * (1 - 0.9))``) are located and the matching
         weights of that class's row are zeroed in a copy of ``g_head``.
      3. Second pass with the pruned head; accuracies are printed. These are
         the values returned.
      4. A relative-variance comparison between masked and unmasked positions
         is printed, using a second mask (threshold at ``int(dim * 0.4)``).
      5. A per-dimension profile for class 8 is saved to
         ``sorted_mean_and_relative_variance.png``.

    Args:
        net: backbone called as ``net(images, latent_output=True)``
            (presumably returns latent features -- confirm against the model).
        g_head: linear classifier exposing ``weight`` (and optionally
            ``bias``); it is not modified.
        test_dataset: dataset yielding ``(image, label)`` pairs.
        args: namespace providing ``device``.
        dataset_class: object exposing ``global_test_distribution``. Required
            in practice despite the ``None`` default.
        head_switch: when falsy, the backbone output is scored directly.

    Returns:
        ``(acc, acc_3shot_global)`` of the pruned-head pass. Both values come
        from that pass only; the 3-shot counters are no longer carried over
        from the baseline pass.
    """
    three_shot_dict, _ = shot_split(
        dataset_class.global_test_distribution, threshold_3shot=[75, 95])
    net.eval()
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=100, shuffle=False)

    # Pass 1: baseline accuracy plus per-class feature statistics.
    class_sums = {}
    class_counts = {}
    acc, acc_3shot_global = _feat_collapse_evaluate(
        net, g_head, test_loader, args, three_shot_dict, head_switch,
        class_sums=class_sums, class_counts=class_counts)
    print("默认inference性能:")
    print(acc, acc_3shot_global)

    class_means = {cls: total / class_counts[cls] for cls, total in class_sums.items()}

    # Prune on a copy so the clean means stay available for the analysis.
    prune_beta = 0.9
    pruned_means = {cls: mean.clone() for cls, mean in class_means.items()}
    zero_positions = _feat_collapse_prune_means(pruned_means, 1 - prune_beta)
    new_g_head = _feat_collapse_pruned_head(g_head, zero_positions, args.device)

    # Pass 2: accuracy with the pruned head, counted from scratch.
    acc, acc_3shot_global = _feat_collapse_evaluate(
        net, new_g_head, test_loader, args, three_shot_dict, head_switch)
    print("裁剪后inference性能:")
    print(acc, acc_3shot_global)

    # NOTE(review): as before, the analysis mask is derived from the means
    # *after* the first pruning step rather than from the clean means.
    analysis_beta = 0.4
    analysis_masks = _feat_collapse_prune_means(pruned_means, analysis_beta)
    relative_variance_zero_pos, relative_variance_non_zero_pos = \
        _feat_collapse_relative_variance(
            net, test_loader, args, class_means, analysis_masks)
    print(f"Relative variance at zero_positions: {relative_variance_zero_pos}")
    print(f"Relative variance at non-zero_positions: {relative_variance_non_zero_pos}")

    first_class_label = 8
    _feat_collapse_plot_class_profile(
        net, g_head, test_loader, args, first_class_label,
        'sorted_mean_and_relative_variance.png')

    return acc, acc_3shot_global



 
def globaltest_class_mean_filter(net, g_head, test_dataset, class_means, args,
                                 dataset_class=None, head_switch=True, beta=0.001):
    """Evaluate the global model with a head pruned according to class means.

    For every ``(class_label, mean)`` pair in ``class_means`` the entries of
    ``mean`` strictly below the value at sorted position ``int(len * beta)``
    are treated as "zero positions". The matching weights in row
    ``class_label`` of a copy of ``g_head`` are set to 0 and the copy is used
    for inference.

    Changes relative to the previous implementation:
      * The mask is no longer inverted. Previously the row was multiplied by
        the zero-position mask itself, which kept only those positions and
        wiped everything else; ``globaltest_feat_collapse`` multiplies by the
        complement. NOTE(review): confirm this matches the intended experiment.
      * ``class_means`` is no longer modified in place.
      * The replacement layer takes its size from ``g_head.weight`` rather
        than assuming 512 x 100, and a head without bias is supported.

    Args:
        net: backbone called as ``net(images, latent_output=True)``.
        g_head: linear classifier exposing ``weight`` (and optionally
            ``bias``); it is not modified.
        test_dataset: dataset yielding ``(image, label)`` pairs.
        class_means: dict mapping class index to a 1-D tensor.
        args: namespace providing ``device``.
        dataset_class: object exposing ``global_test_distribution``. Required
            in practice despite the ``None`` default.
        head_switch: when falsy, the backbone output is scored directly.
        beta: fraction that positions the pruning threshold in the sorted
            mean (default keeps the previous hard-coded 0.001).

    Returns:
        ``(acc, acc_3shot_global)`` where the second item is a dict with
        "head", "middle" and "tail" accuracies.
    """
    three_shot_dict, _ = shot_split(
        dataset_class.global_test_distribution, threshold_3shot=[75, 95])
    net.eval()
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=100, shuffle=False)

    # Build the pruned weight matrix on a detached copy of the head.
    pruned_weight = g_head.weight.detach().clone()
    for class_label, mean in class_means.items():
        ordered, _ = torch.sort(mean)
        # Clamp so that beta == 1.0 cannot index past the end.
        cut = min(int(len(ordered) * beta), len(ordered) - 1)
        low = (mean < ordered[cut]).to(pruned_weight.device)
        pruned_weight[class_label, low] = 0

    bias = getattr(g_head, "bias", None)
    out_features, in_features = pruned_weight.shape
    new_g_head = nn.Linear(in_features=in_features, out_features=out_features,
                           bias=bias is not None)
    new_g_head.weight.data = pruned_weight
    if bias is not None:
        new_g_head.bias.data = bias.detach().clone()
    new_g_head = new_g_head.to(args.device)

    correct = 0
    total = 0
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            features = net(images, latent_output=True)
            outputs = new_g_head(features) if head_switch else features
            _, predicted = torch.max(outputs, 1)
            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()

            # Work on Python scalars to avoid one device sync per sample.
            for cls, hit in zip(labels.tolist(), hits.tolist()):
                if cls in three_shot_dict["head"]:
                    group = "head"
                elif cls in three_shot_dict["middle"]:
                    group = "middle"
                else:
                    group = "tail"
                total_3shot[group] += 1
                # Unlisted classes count towards the tail total only.
                if hit and (group != "tail" or cls in three_shot_dict["tail"]):
                    correct_3shot[group] += 1

    acc = correct / total
    acc_3shot_global = {
        "head": correct_3shot["head"] / (total_3shot["head"] + 1e-10),
        "middle": correct_3shot["middle"] / (total_3shot["middle"] + 1e-10),
        "tail": correct_3shot["tail"] / (total_3shot["tail"] + 1e-10),
    }
    return acc, acc_3shot_global


def globaltest_etf(net, g_head, test_dataset, args, dataset_class=None, head_switch=True):
    """Evaluate the global model with an ETF-style head on the global test set.

    The backbone output is (optionally) passed through ``g_head`` and then
    multiplied by ``g_head.ori_M``; the arg-max of the product is the
    prediction. The multiplication is applied even when ``head_switch`` is
    falsy.

    The previous version also built an unused per-class accuracy list by
    dividing by the raw per-class sample count, which raised
    ``ZeroDivisionError`` whenever a class had no test samples. That dead
    computation has been removed.

    Args:
        net: backbone called as ``net(images, latent_output=True)``.
        g_head: module exposing ``ori_M`` and callable on the backbone output.
        test_dataset: dataset yielding ``(image, label)`` pairs.
        args: namespace providing ``device``.
        dataset_class: object exposing ``global_test_distribution``. Required
            in practice despite the ``None`` default.
        head_switch: when falsy, ``g_head`` is skipped before the projection.

    Returns:
        ``(acc, acc_3shot_global)`` where the second item is a dict with
        "head", "middle" and "tail" accuracies.
    """
    three_shot_dict, _ = shot_split(
        dataset_class.global_test_distribution, threshold_3shot=[75, 95])
    net.eval()
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=100, shuffle=False)
    cur_M = g_head.ori_M

    correct = 0
    total = 0
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            features = net(images, latent_output=True)
            outputs = g_head(features) if head_switch else features
            outputs = torch.matmul(outputs, cur_M)
            _, predicted = torch.max(outputs, 1)
            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()

            # Work on Python scalars to avoid one device sync per sample.
            for cls, hit in zip(labels.tolist(), hits.tolist()):
                if cls in three_shot_dict["head"]:
                    group = "head"
                elif cls in three_shot_dict["middle"]:
                    group = "middle"
                else:
                    group = "tail"
                total_3shot[group] += 1
                # Unlisted classes count towards the tail total only.
                if hit and (group != "tail" or cls in three_shot_dict["tail"]):
                    correct_3shot[group] += 1

    acc = correct / total
    acc_3shot_global = {
        "head": correct_3shot["head"] / (total_3shot["head"] + 1e-10),
        "middle": correct_3shot["middle"] / (total_3shot["middle"] + 1e-10),
        "tail": correct_3shot["tail"] / (total_3shot["tail"] + 1e-10),
    }
    return acc, acc_3shot_global


 
def globaltest_GBA_Layer(backbone, classifier, test_dataset, args, dataset_class = None):
    """Evaluate a backbone + classifier pair on the global test set.

    Predictions are the arg-max of ``classifier(backbone(images))``.

    The previous version filled an unused 100-slot label histogram, which
    raised ``IndexError`` for any label id >= 100. The histogram has been
    removed, so the function no longer depends on the number of classes.

    Args:
        backbone: feature extractor, called as ``backbone(images)``.
        classifier: module applied to the backbone output.
        test_dataset: dataset yielding ``(image, label)`` pairs.
        args: namespace providing ``device``.
        dataset_class: object exposing ``global_test_distribution``. Required
            in practice despite the ``None`` default.

    Returns:
        ``(acc, acc_3shot_global)`` where the second item is a dict with
        "head", "middle" and "tail" accuracies.
    """
    three_shot_dict, _ = shot_split(
        dataset_class.global_test_distribution, threshold_3shot=[75, 95])
    backbone.eval()
    classifier.eval()
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=100, shuffle=False)

    correct = 0
    total = 0
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            outputs = classifier(backbone(images))
            _, predicted = torch.max(outputs, 1)
            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()

            # Work on Python scalars to avoid one device sync per sample.
            for cls, hit in zip(labels.tolist(), hits.tolist()):
                if cls in three_shot_dict["head"]:
                    group = "head"
                elif cls in three_shot_dict["middle"]:
                    group = "middle"
                else:
                    group = "tail"
                total_3shot[group] += 1
                # Unlisted classes count towards the tail total only.
                if hit and (group != "tail" or cls in three_shot_dict["tail"]):
                    correct_3shot[group] += 1

    acc = correct / total
    acc_3shot_global = {
        "head": correct_3shot["head"] / (total_3shot["head"] + 1e-10),
        "middle": correct_3shot["middle"] / (total_3shot["middle"] + 1e-10),
        "tail": correct_3shot["tail"] / (total_3shot["tail"] + 1e-10),
    }
    return acc, acc_3shot_global

def localtest_villina(net, test_dataset, dataset_class, idxs, user_id):
    """Evaluate ``net`` on one client's slice of the test set.

    Predictions are the arg-max of ``net(images)``. Classes are grouped into
    head / middle / tail with ``shot_split`` applied to the client's own test
    distribution.

    Args:
        net: model called as ``net(images)``.
        test_dataset: full test dataset.
        dataset_class: object providing ``get_args()`` and
            ``local_test_distribution``.
        idxs: indices of ``test_dataset`` that belong to this client.
        user_id: key into ``dataset_class.local_test_distribution``.

    Returns:
        ``(acc, f1_macro, f1_weighted, acc_3shot_local)``. Each entry of
        ``acc_3shot_local`` is ``[accuracy, True]``, or ``[0, False]`` when
        the client has no sample in that group.
    """
    from sklearn.metrics import f1_score

    args = dataset_class.get_args()
    net.eval()
    loader = torch.utils.data.DataLoader(
        DatasetSplit(test_dataset, idxs),
        batch_size=args.local_bs,
        shuffle=False,
    )

    client_distribution = dataset_class.local_test_distribution[user_id]
    three_shot_dict, _ = shot_split(client_distribution, threshold_3shot=[75, 95])

    groups = ("head", "middle", "tail")
    hit_3shot = {name: 0 for name in groups}
    seen_3shot = {name: 0 for name in groups}
    hit_by_class = [0] * args.num_classes
    seen_by_class = [0] * args.num_classes
    n_hit = 0
    n_seen = 0
    ypred = []
    ytrue = []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            _, predicted = torch.max(net(images).data, 1)

            n_seen += labels.size(0)
            n_hit += (predicted == labels).sum().item()

            batch_pred = predicted.tolist()
            batch_true = labels.tolist()
            ypred.extend(batch_pred)
            ytrue.extend(batch_true)

            for guess, truth in zip(batch_pred, batch_true):
                seen_by_class[truth] += 1
                if truth in three_shot_dict["head"]:
                    bucket = "head"
                elif truth in three_shot_dict["middle"]:
                    bucket = "middle"
                else:
                    bucket = "tail"
                seen_3shot[bucket] += 1

                if guess != truth:
                    continue
                hit_by_class[truth] += 1
                # The tail hit counter requires explicit tail membership,
                # whereas the tail total is the catch-all bucket.
                if bucket != "tail" or truth in three_shot_dict["tail"]:
                    hit_3shot[bucket] += 1

    f1_macro = f1_score(y_true=ytrue, y_pred=ypred, average='macro')
    f1_weighted = f1_score(y_true=ytrue, y_pred=ypred, average='weighted')
    acc = n_hit / n_seen

    acc_3shot_local = {}
    for name in groups:
        if seen_3shot[name] == 0:
            acc_3shot_local[name] = [0, False]
        else:
            acc_3shot_local[name] = [hit_3shot[name] / seen_3shot[name], True]

    # Per-class accuracy is still computed but is not part of the return value.
    acc_classwise = [
        hit / (seen + 1e-10) for hit, seen in zip(hit_by_class, seen_by_class)
    ]

    return acc, f1_macro, f1_weighted, acc_3shot_local


def localtest(net, g_head, l_head, test_dataset, dataset_class, idxs, user_id):
    """Evaluate a client's personalised head on its slice of the test set.

    Predictions are the arg-max of ``l_head(net(images, latent_output=True))``
    restricted to the classes that occur in the client's test distribution.
    Logits of classes whose count is 0 are set to ``-inf``.

    Changes relative to the previous implementation:
      * Absent classes used to be suppressed by overwriting the matching rows
        of ``g_head.weight`` and ``l_head.weight`` with ``-1e10`` in place.
        That permanently altered the modules passed in by the caller. The
        logits are now masked instead and neither head is modified.
      * The alternative ``p_mode`` branches (2-5 and 8) could never run
        because ``p_mode`` was hard-wired to 1; they have been removed.
      * Unused per-class bookkeeping has been removed.

    Args:
        net: backbone called as ``net(images, latent_output=True)``.
        g_head: accepted for interface compatibility; not used here.
        l_head: classifier applied to the backbone output.
        test_dataset: full test dataset.
        dataset_class: object providing ``get_args()`` and
            ``local_test_distribution``.
        idxs: indices of ``test_dataset`` that belong to this client.
        user_id: key into ``dataset_class.local_test_distribution``.

    Returns:
        ``(acc, f1_macro, f1_weighted, acc_3shot_local)``. Each entry of
        ``acc_3shot_local`` is ``[accuracy, True]``, or ``[0, False]`` when
        the client has no sample in that group.
    """
    from sklearn.metrics import f1_score

    args = dataset_class.get_args()
    net.eval()
    test_loader = torch.utils.data.DataLoader(
        DatasetSplit(test_dataset, idxs), batch_size=args.local_bs, shuffle=False)

    class_distribution = dataset_class.local_test_distribution[user_id]
    # Indices of the classes this client never sees at test time.
    absent_classes = torch.as_tensor(
        np.where(np.asarray(class_distribution) == 0)[0], dtype=torch.long)

    three_shot_dict, _ = shot_split(
        class_distribution, threshold_3shot=[75, 95])

    ypred = []
    ytrue = []
    correct = 0
    total = 0
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            outputs = l_head(net(images, latent_output=True))
            if absent_classes.numel() > 0:
                # Mask the logits rather than the weights: exact and free of
                # side effects on the caller's modules.
                outputs[:, absent_classes.to(outputs.device)] = float("-inf")
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            batch_pred = predicted.tolist()
            batch_true = labels.tolist()
            ypred.extend(batch_pred)
            ytrue.extend(batch_true)

            for guess, truth in zip(batch_pred, batch_true):
                if truth in three_shot_dict["head"]:
                    group = "head"
                elif truth in three_shot_dict["middle"]:
                    group = "middle"
                else:
                    group = "tail"
                total_3shot[group] += 1
                # Unlisted classes count towards the tail total only.
                if guess == truth and (
                        group != "tail" or truth in three_shot_dict["tail"]):
                    correct_3shot[group] += 1

    f1_macro = f1_score(y_true=ytrue, y_pred=ypred, average='macro')
    f1_weighted = f1_score(y_true=ytrue, y_pred=ypred, average='weighted')
    acc = correct / total

    acc_3shot_local = {}
    for group in ("head", "middle", "tail"):
        if total_3shot[group] == 0:
            acc_3shot_local[group] = [0, False]
        else:
            acc_3shot_local[group] = [
                correct_3shot[group] / total_3shot[group], True]

    return acc, f1_macro, f1_weighted, acc_3shot_local


def localtest_etf(net, g_head, l_head, test_dataset, dataset_class, idxs, user_id):
    """Evaluate an ETF-style head on one client's slice of the test set.

    Predictions are the arg-max of
    ``matmul(g_head(net(images, latent_output=True)), g_head.ori_M)``.

    Changes relative to the previous implementation:
      * ``g_head`` used to be applied twice in a row
        (``g_head(g_head(features))``). It is now applied once, which is what
        ``globaltest_etf`` does.
      * The alternative ``p_mode`` branches (2-5 and 8) could never run
        because ``p_mode`` was hard-wired to 1; they have been removed.
      * Unused per-class bookkeeping has been removed.

    Args:
        net: backbone called as ``net(images, latent_output=True)``.
        g_head: module exposing ``ori_M`` and callable on the backbone output.
        l_head: accepted for interface compatibility; not used here.
        test_dataset: full test dataset.
        dataset_class: object providing ``get_args()`` and
            ``local_test_distribution``.
        idxs: indices of ``test_dataset`` that belong to this client.
        user_id: key into ``dataset_class.local_test_distribution``.

    Returns:
        ``(acc, f1_macro, f1_weighted, acc_3shot_local)``. Each entry of
        ``acc_3shot_local`` is ``[accuracy, True]``, or ``[0, False]`` when
        the client has no sample in that group.
    """
    from sklearn.metrics import f1_score

    args = dataset_class.get_args()
    net.eval()
    test_loader = torch.utils.data.DataLoader(
        DatasetSplit(test_dataset, idxs), batch_size=args.local_bs, shuffle=False)

    class_distribution = dataset_class.local_test_distribution[user_id]
    cur_M = g_head.ori_M

    three_shot_dict, _ = shot_split(
        class_distribution, threshold_3shot=[75, 95])

    ypred = []
    ytrue = []
    correct = 0
    total = 0
    correct_3shot = {"head": 0, "middle": 0, "tail": 0}
    total_3shot = {"head": 0, "middle": 0, "tail": 0}

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            features = net(images, latent_output=True)
            # Apply the head exactly once before projecting onto ori_M.
            outputs = torch.matmul(g_head(features), cur_M)
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            batch_pred = predicted.tolist()
            batch_true = labels.tolist()
            ypred.extend(batch_pred)
            ytrue.extend(batch_true)

            for guess, truth in zip(batch_pred, batch_true):
                if truth in three_shot_dict["head"]:
                    group = "head"
                elif truth in three_shot_dict["middle"]:
                    group = "middle"
                else:
                    group = "tail"
                total_3shot[group] += 1
                # Unlisted classes count towards the tail total only.
                if guess == truth and (
                        group != "tail" or truth in three_shot_dict["tail"]):
                    correct_3shot[group] += 1

    f1_macro = f1_score(y_true=ytrue, y_pred=ypred, average='macro')
    f1_weighted = f1_score(y_true=ytrue, y_pred=ypred, average='weighted')
    acc = correct / total

    acc_3shot_local = {}
    for group in ("head", "middle", "tail"):
        if total_3shot[group] == 0:
            acc_3shot_local[group] = [0, False]
        else:
            acc_3shot_local[group] = [
                correct_3shot[group] / total_3shot[group], True]

    return acc, f1_macro, f1_weighted, acc_3shot_local


def localtest_vallina(net, test_dataset, dataset_class, idxs, user_id):
    """Evaluate ``net`` on the samples of ``test_dataset`` selected by ``idxs``.

    Args:
        net: Model to evaluate. It is switched to eval mode and called on
            each batch of images.
        test_dataset: Dataset that is wrapped in ``DatasetSplit`` together
            with ``idxs``.
        dataset_class: Object whose ``get_args()`` result supplies
            ``local_bs`` (batch size) and ``device``.
        idxs: Indices into ``test_dataset`` to evaluate on.
        user_id: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(acc, f1_macro, f1_weighted, acc_3shot_local)`` where ``acc``
        is the fraction of correct predictions, the two F1 values come from
        ``sklearn.metrics.f1_score`` with ``average='macro'`` and
        ``average='weighted'``, and ``acc_3shot_local`` is a dict with keys
        ``"head"``, ``"middle"`` and ``"tail"`` that are all ``None`` (this
        variant computes no per-shot accuracy).

    Note:
        An empty ``idxs`` is not handled: ``total`` stays 0, so the accuracy
        division raises ``ZeroDivisionError`` unless ``f1_score`` rejects the
        empty input first.
    """
    from sklearn.metrics import f1_score

    args = dataset_class.get_args()
    net.eval()
    test_loader = torch.utils.data.DataLoader(
        DatasetSplit(test_dataset, idxs),
        batch_size=args.local_bs, shuffle=False)

    ypred = []
    ytrue = []
    correct = 0
    total = 0
    # Placeholder only: the per-shot accuracies are never filled in here.
    acc_3shot_local = {"head": None, "middle": None, "tail": None}

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Extend the flat lists batch by batch; flattening afterwards
            # with sum(list_of_lists, []) re-copies the result every batch.
            ypred.extend(predicted.tolist())
            ytrue.extend(labels.tolist())

    f1_macro = f1_score(y_true=ytrue, y_pred=ypred, average='macro')
    f1_weighted = f1_score(y_true=ytrue, y_pred=ypred, average='weighted')
    acc = correct / total
    return acc, f1_macro, f1_weighted, acc_3shot_local




def calculate_metrics(pred_np, seg_np):
    """Return the macro-averaged F1 score of every prediction/target pair.

    Args:
        pred_np: Indexable collection of predictions; its length decides how
            many pairs are scored.
        seg_np: Indexable collection of ground-truth values, read at the same
            indices as ``pred_np``.

    Returns:
        A list with one ``metrics.f1_score(seg_np[i], pred_np[i],
        average='macro')`` value per index of ``pred_np``; empty when
        ``pred_np`` is empty.
    """
    # Indexing (rather than zip) keeps the original behaviour when seg_np is
    # shorter than pred_np or is not iterated the same way it is indexed.
    return [
        metrics.f1_score(seg_np[idx], pred_np[idx], average='macro')
        for idx in range(len(pred_np))
    ]




def compute_angle(label1, label2, class_means):
    """Return the angle, in degrees, between the mean vectors of two classes.

    Args:
        label1: Key or index of the first class in ``class_means``.
        label2: Key or index of the second class in ``class_means``.
        class_means: Container indexable by label whose entries are 1-D
            tensors (``torch.dot`` accepts nothing else).

    Returns:
        Scalar tensor holding the angle in degrees. The value is NaN when
        either mean has zero norm, because the cosine is then 0/0.
    """
    mean1 = class_means[label1]
    mean2 = class_means[label2]

    cos_theta = torch.dot(mean1, mean2) / (torch.norm(mean1) * torch.norm(mean2))
    # Floating-point rounding can leave the cosine marginally outside
    # [-1, 1] for (anti-)parallel vectors, and acos returns NaN there.
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
    return torch.rad2deg(torch.acos(cos_theta))