from .basic_server import SyncServerHandler
from .basic_client import SGDClientTrainer, SGDSerialClientTrainer
from ...utils.aggregator import Aggregators
from ...utils.serialization import SerializationTool
import torch
import torch.nn as nn

import pdb
import numpy as np
np.random.seed(886) 
import random
random.seed(886)

from ...utils.loss import FocalLoss, CrossEntropy, LabelSmoothingLoss, ClassficationAndMDCA, BrierScore, DCA, MMCE, FLSD
from ...utils.loss import addinitial_DCA, addinitial_ClassficationAndMDCA
from ...utils.logit_margin_l1 import LogitMarginL1
from torch_cka import CKA
import copy
from torch.utils.data import DataLoader, TensorDataset
from ...utils.loss import Inverse_DCA, Balance_DCA , Inverse_ClassficationAndMDCA, Balance_ClassficationAndMDCA
from ...utils.CKA_func import CudaCKA
 
    # "focal_loss" : FocalLoss,
    # "cross_entropy" : CrossEntropy,
    # "LS" : LabelSmoothingLoss,
    # "NLL+MDCA" : ClassficationAndMDCA,
    # "LS+MDCA" : ClassficationAndMDCA,
    # "FL+MDCA" : ClassficationAndMDCA,
    # "brier_loss" : BrierScore,
    # "NLL+DCA" : DCA,
    # "MMCE" : MMCE,
    # "FLSD" : FLSD

# Module-level store filled by hooks from get_activation_pre, keyed by hook name.
activation_pre = {}


def get_activation_pre(name):
    """Return a forward hook that saves a module's output into ``activation_pre``.

    Each time the hooked module runs forward, ``activation_pre[name]`` is
    overwritten with ``output.detach()`` (a tensor cut from the autograd
    graph). The hook itself returns ``None``, so the module output is left
    unchanged.
    """
    def _record(_module, _inputs, result):
        activation_pre[name] = result.detach()

    return _record

# Module-level store filled by hooks from get_activation, keyed by hook name.
activation = {}


def get_activation(name):
    """Return a forward hook that saves a module's output into ``activation``.

    On every forward pass of the hooked module, ``activation[name]`` is replaced
    by ``output.detach()``. Returning ``None`` from the hook keeps the module's
    output as-is.
    """
    def _record(_module, _inputs, result):
        activation[name] = result.detach()

    return _record

##################
#
#      Server
#
##################


class FedAvgServerHandler(SyncServerHandler):
    """FedAvg server handler: weighted-averages client uploads into the global model."""

    def global_update(self, buffer):
        """Aggregate the uploads in ``buffer`` and load the result into ``self._model``.

        Args:
            buffer: sequence of client uploads. ``upload[0]`` is passed to
                ``Aggregators.fedavg_aggregate`` as the parameters and
                ``upload[1]`` as the aggregation weight (presumably the client's
                sample count — confirm against the client trainer in use).
        """
        uploaded_params = [upload[0] for upload in buffer]
        uploaded_weights = [upload[1] for upload in buffer]
        merged_params = Aggregators.fedavg_aggregate(uploaded_params, uploaded_weights)
        SerializationTool.deserialize_model(self._model, merged_params)


##################
#
#      Client
#
##################


class FedAvgClientTrainer(SGDClientTrainer):
    """Federated client with local SGD solver.

    NOTE(review): ``global_update`` below performs FedAvg aggregation, which is
    normally a server-side step; it is not visible here whether any caller
    invokes it on a client — confirm before relying on it.
    """

    def global_update(self, buffer):
        """FedAvg-aggregate the uploads in ``buffer`` into ``self._model``.

        Args:
            buffer: sequence of uploads, read as ``upload[0]`` (parameters) and
                ``upload[1]`` (aggregation weight).
        """
        collected_params = [upload[0] for upload in buffer]
        collected_weights = [upload[1] for upload in buffer]
        averaged = Aggregators.fedavg_aggregate(collected_params, collected_weights)
        SerializationTool.deserialize_model(self._model, averaged)


class FedAvgSerialClientTrainer(SGDSerialClientTrainer):
    """Serial FedAvg client with local SGD and a CKA-weighted calibration loss.

    ``train`` first evaluates the broadcast global model on the client's
    training data to record its predictions and a per-batch calibration gap
    (``|mean confidence - accuracy|``). It then trains locally with
    ``Inverse_ClassficationAndMDCA``. That loss receives the summed
    calibration gap and a score: the absolute kernel CKA between the
    cross-entropy gradient of ``linear_2.weight`` and a constant reference
    matrix.
    """

    # Per-sample shape (C, H, W) that flat FEMNIST batches are reshaped to.
    sample_shape = (1, 28, 28)

    def train(self, model_parameters, train_loader):
        """Evaluate the broadcast model, then run ``self.epochs`` of local training.

        Args:
            model_parameters: global parameters, loaded with ``self.set_model``.
            train_loader: iterable of ``(data, target)`` batches.

        Returns:
            tuple: ``([self.model_parameters, data_size], gold, pred, out_prob)``.
            ``data_size`` is the number of samples processed summed over all
            local epochs. ``gold``, ``pred`` and ``out_prob`` each wrap a single
            list holding, per sample, the label, the broadcast model's argmax
            prediction and its softmax probabilities (all taken before local
            training).
        """
        self.set_model(model_parameters)

        gold, pred, out_prob, initial_add = self._evaluate_broadcast_model(train_loader)

        self._model.train()
        # Hoisted: the CKA helper only depends on the device, not on the batch.
        cka = CudaCKA(self.device)
        data_size = 0
        for _ in range(self.epochs):
            for data, target in train_loader:
                data, target = self._prepare_femnist_batch(data, target)
                output = self.model(data)

                score = self._gradient_cka_score(cka, output, target)
                loss = Inverse_ClassficationAndMDCA().forward(output, target, initial_add, score)

                # Only the training loss's gradient reaches the optimizer; the
                # CKA probe gradient above never touches ``.grad``.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                data_size += len(target)

        return [self.model_parameters, data_size], [gold], [pred], [out_prob]

    def _prepare_femnist_batch(self, data, target):
        """Cast a raw batch to float images of ``sample_shape`` and long labels, moved to the device if CUDA is on."""
        # as_tensor avoids the copy + warning torch.tensor() emits for tensor input.
        data = torch.as_tensor(data, dtype=torch.float32).reshape(-1, *self.sample_shape)
        target = torch.as_tensor(target, dtype=torch.long)
        if self.cuda:
            data = data.cuda(self.device)
            target = target.cuda(self.device)
        return data, target

    def _evaluate_broadcast_model(self, train_loader):
        """Run the freshly loaded global model over ``train_loader`` without gradients.

        Returns:
            tuple: ``(gold, pred, out_prob, initial_add)``. These are the
            per-sample labels, argmax-of-logits predictions, softmax
            probabilities, and the sum over batches of
            ``|mean max-softmax confidence - accuracy|``. ``initial_add`` is
            ``None`` when the loader yields no batches.
        """
        self._model.eval()
        gold, pred, out_prob = [], [], []
        initial_add = None
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for data, target in train_loader:
                data, target = self._prepare_femnist_batch(data, target)
                output = self.model(data)
                probs = torch.softmax(output, dim=1)

                _, predicted = torch.max(output, 1)
                out_prob.extend(probs.tolist())
                pred.extend(predicted.tolist())
                gold.extend(target.tolist())

                # DCA-style gap between average confidence and batch accuracy.
                conf, calib_pred = torch.max(probs, dim=1)
                gap = torch.abs(conf.mean() - (calib_pred == target).float().mean()).item()
                initial_add = gap if initial_add is None else initial_add + gap
        return gold, pred, out_prob, initial_add

    def _gradient_cka_score(self, cka, output, target):
        """Return ``|kernel_CKA|`` between the CE gradient of ``linear_2.weight`` and a constant matrix.

        The gradient comes from ``torch.autograd.grad`` rather than
        ``backward()``, so it is not accumulated into ``.grad`` and cannot leak
        into the optimizer step. ``retain_graph=True`` keeps the forward graph
        alive for the training-loss backward pass.
        """
        weight = self._model.linear_2.weight
        probe_loss = CrossEntropy().forward(output, target)
        (local_grad,) = torch.autograd.grad(probe_loss, weight, retain_graph=True)
        # Constant 1e-3 reference, created on the gradient's own device and dtype
        # (works on CPU as well as CUDA).
        reference = torch.full_like(local_grad, 1e-3)
        # NOTE(review): every entry of the reference is equal; confirm
        # CudaCKA.kernel_CKA returns a finite value for it rather than NaN.
        return abs(cka.kernel_CKA(local_grad, reference).item())
