from copy import deepcopy
import torch

from .basic_server import SyncServerHandler
from .basic_client import SGDClientTrainer, SGDSerialClientTrainer
import pdb
from tqdm import tqdm
import torch.nn as nn

import random

import numpy as np

# Fixed seed for Python's and NumPy's global RNGs, applied at import time so
# runs are repeatable. (torch's RNG is not seeded here.)
_SEED = 886
np.random.seed(_SEED)
random.seed(_SEED)

from ...utils.loss import FocalLoss, CrossEntropy, LabelSmoothingLoss, ClassficationAndMDCA, BrierScore, DCA, MMCE, FLSD
from ...utils.logit_margin_l1 import LogitMarginL1
from ...utils.loss import addinitial_DCA, addinitial_ClassficationAndMDCA
from ...utils.loss import Inverse_DCA, Balance_DCA , Inverse_ClassficationAndMDCA, Balance_ClassficationAndMDCA
from ...utils.CKA_func import CudaCKA


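# Reference only: mapping from loss name to loss class (none of these lines
# are executed). The loss actually used during training is chosen inside
# FedProxSerialClientTrainer.train.
    # "focal_loss" : FocalLoss,
    # "cross_entropy" : CrossEntropy,
    # "LS" : LabelSmoothingLoss,
    # "NLL+MDCA" : ClassficationAndMDCA,
    # "LS+MDCA" : ClassficationAndMDCA,
    # "FL+MDCA" : ClassficationAndMDCA,
    # "brier_loss" : BrierScore,
    # "NLL+DCA" : DCA,
    # "MMCE" : MMCE,
    # "FLSD" : FLSD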


##################
#
#      Server
#
##################


class FedProxServerHandler(SyncServerHandler):
    """FedProx server handler.

    Adds nothing to :class:`SyncServerHandler`; every server-side behavior is
    inherited. The FedProx proximal term is applied by the client trainers
    defined in this module.
    """


##################
#
#      Client
#
##################

class FedProxClientTrainer(SGDClientTrainer):
    """Federated client running local SGD on the FedProx objective.

    The local loss is ``criterion(preds, target) + 0.5 * mu * ||w - w0||^2``,
    where ``w0`` is a frozen copy of the model received at the start of
    ``train``.
    """

    def setup_optim(self, epochs, batch_size, lr, mu):
        """Forward optimizer settings to the parent and store ``mu``.

        Args:
            epochs (int): Number of local epochs (presumably stored by the
                parent as ``self.epochs``).
            batch_size (int): Local batch size (presumably stored by the
                parent as ``self.batch_size``).
            lr (float): Learning rate, forwarded to the parent.
            mu (float): Weight of the FedProx proximal term.
        """
        super().setup_optim(epochs, batch_size, lr, mu)
        self.mu = mu

    def local_process(self, payload, id):
        """Train this client starting from the broadcast model.

        Args:
            payload (list): Downlink package; ``payload[0]`` holds the
                serialized global model parameters.
            id: Client id used to fetch the local data loader.
        """
        model_parameters = payload[0]
        train_loader = self.dataset.get_dataloader(id, self.batch_size)
        self.train(model_parameters, train_loader, self.mu)

    def train(self, model_parameters, train_loader, mu) -> list:
        """Client trains its local model on local dataset.

        Args:
            model_parameters (torch.Tensor): Serialized model parameters.
            train_loader (torch.utils.data.DataLoader): This client's data.
            mu (float): Weight of the FedProx proximal term.

        Returns:
            list: ``[self.model_parameters]`` after local training.
        """
        self.set_model(model_parameters)
        # Snapshot of the received model for the proximal term. Freezing it
        # keeps backward() from computing (and endlessly accumulating)
        # gradients for the snapshot's parameters.
        frz_model = deepcopy(self._model)
        frz_model.requires_grad_(False)

        for _ in range(self.epochs):
            self._model.train()
            for data, target in train_loader:
                if self.cuda:
                    data = data.cuda(self.device)
                    target = target.cuda(self.device)

                preds = self._model(data)
                l1 = self.criterion(preds, target)

                # Proximal term: squared L2 distance to the received model.
                l2 = 0.0
                for w0, w in zip(frz_model.parameters(), self._model.parameters()):
                    l2 += torch.sum(torch.pow(w - w0, 2))

                loss = l1 + 0.5 * mu * l2

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
        return [self.model_parameters]

class FedProxSerialClientTrainer(SGDSerialClientTrainer):
    """Serial FedProx trainer that simulates a list of clients in one process.

    Each client's local objective is a calibration-aware task loss
    (``Inverse_ClassficationAndMDCA``) plus the FedProx proximal term
    ``0.5 * mu * ||w - w0||^2``, where ``w0`` is a frozen copy of the model
    broadcast this round.

    Besides the uploaded parameters, ``train`` records labels, predictions and
    softmax probabilities: once for the broadcast model (before any local
    update) and once per local epoch.
    """

    def setup_optim(self, epochs, batch_size, lr, mu):
        """Forward optimizer settings to the parent and store ``mu``.

        Args:
            epochs (int): Number of local epochs (presumably stored by the
                parent as ``self.epochs``).
            batch_size (int): Local batch size (presumably stored by the
                parent as ``self.batch_size``).
            lr (float): Learning rate, forwarded to the parent.
            mu (float): Weight of the FedProx proximal term.
        """
        super().setup_optim(epochs, batch_size, lr, mu)
        self.mu = mu

    def local_process(self, payload, id_list):
        """Train every client in ``id_list`` from the broadcast model.

        Args:
            payload (list): Downlink package; ``payload[0]`` holds the
                serialized global model parameters.
            id_list (list): Ids of the clients selected this round.

        Returns:
            dict: Maps each client id to ``{'gold', 'pred', 'prob'}``, each the
            per-pass lists returned by :meth:`train`.
        """
        model_parameters = payload[0]
        client_dict_per_round = {}
        for client_id in (progress_bar := tqdm(id_list)):
            progress_bar.set_description(f"Training on client {client_id}", refresh=True)
            data_loader = self.dataset.get_dataloader(client_id, self.batch_size)
            pack, big_gold, big_pred, big_out_prob = self.train(model_parameters, data_loader, self.mu)
            self.cache.append(pack)
            client_dict_per_round[client_id] = {'gold': big_gold, 'pred': big_pred, 'prob': big_out_prob}

        return client_dict_per_round

    def train(self, model_parameters, train_loader, mu) -> tuple:
        """Run FedProx local training for one client.

        Args:
            model_parameters (torch.Tensor): Serialized global model parameters.
            train_loader (torch.utils.data.DataLoader): This client's data.
            mu (float): Weight of the FedProx proximal term.

        Returns:
            tuple: ``([self.model_parameters], gold, pred, prob)``.

            ``gold``, ``pred`` and ``prob`` each hold one list per pass over
            ``train_loader``:

            - Index 0 is the broadcast model, evaluated in eval mode.
            - Index ``k >= 1`` is local epoch ``k``. Its outputs are recorded in
              train mode, before each step's update.
        """
        self.set_model(model_parameters)

        gold, pred, prob, initial_add = self._evaluate_broadcast_model(train_loader)
        big_list_gold = [gold]
        big_list_pred = [pred]
        big_list_out_prob = [prob]

        # Frozen snapshot of the broadcast model for the proximal term; it must
        # not receive (and accumulate) gradients itself.
        frz_model = deepcopy(self._model)
        frz_model.eval()
        frz_model.requires_grad_(False)
        cuda_cka = CudaCKA(self.device)

        for _ in range(self.epochs):
            self._model.train()
            list_out_prob = []
            list_gold = []
            list_pred = []
            for data, target in train_loader:
                data, target = self._prepare_femnist_batch(data, target)

                preds = self._model(data)
                _, predicted = torch.max(preds, 1)
                list_out_prob.extend(torch.softmax(preds.detach(), dim=1).tolist())
                list_gold.extend(target.tolist())
                list_pred.extend(predicted.tolist())

                self.optimizer.zero_grad()

                score = self._gradient_cka_score(preds, target, cuda_cka)
                # Other losses tried with the same signature: Inverse_DCA,
                # Balance_DCA, Balance_ClassficationAndMDCA.
                l1 = Inverse_ClassficationAndMDCA().forward(preds, target, initial_add, score)

                # Proximal term: squared L2 distance to the broadcast model.
                l2 = 0.0
                for w0, w in zip(frz_model.parameters(), self._model.parameters()):
                    l2 += torch.sum(torch.pow(w - w0, 2))

                loss = l1 + 0.5 * mu * l2
                loss.backward()
                self.optimizer.step()

            big_list_out_prob.append(list_out_prob)
            big_list_gold.append(list_gold)
            big_list_pred.append(list_pred)

        return [self.model_parameters], big_list_gold, big_list_pred, big_list_out_prob

    def _prepare_femnist_batch(self, data, target):
        """Convert a FEMNIST batch to float images / long labels on the device."""
        # as_tensor avoids the copy-construct warning when data is already a tensor.
        data = torch.as_tensor(data, dtype=torch.float32).reshape(-1, 1, 28, 28)
        target = torch.as_tensor(target, dtype=torch.long)
        if self.cuda:
            data = data.cuda(self.device)
            target = target.cuda(self.device)
        return data, target

    def _evaluate_broadcast_model(self, train_loader):
        """Evaluate the freshly loaded global model on this client's data.

        Returns:
            tuple: ``(gold, pred, prob, initial_add)``.

            - ``gold`` and ``pred`` are flat label lists.
            - ``prob`` holds the softmax rows.
            - ``initial_add`` is the MDCA gap
              ``sum_batches sum_c |mean p_c - freq_c| / num_classes``, or
              ``None`` if the loader yields no batches.
        """
        self._model.eval()
        gold, pred, prob = [], [], []
        mdca_sum = None
        num_classes = None
        with torch.no_grad():
            for data, target in train_loader:
                data, target = self._prepare_femnist_batch(data, target)

                output = self.model(data)
                probs = torch.softmax(output, dim=1)
                prob.extend(probs.tolist())
                _, predicted = torch.max(output, 1)
                pred.extend(predicted.tolist())
                gold.extend(target.tolist())

                # MDCA uses softmax probabilities, not raw logits.
                num_classes = output.shape[1]
                for c in range(num_classes):
                    avg_count = (target == c).float().mean()
                    avg_conf = torch.mean(probs[:, c])
                    gap = torch.abs(avg_conf - avg_count).item()
                    mdca_sum = gap if mdca_sum is None else mdca_sum + gap

        # NOTE(review): the gap is summed over batches (not averaged), as in the
        # original accumulation; confirm that is the scale
        # Inverse_ClassficationAndMDCA expects for ``initial_add``.
        initial_add = None if mdca_sum is None else mdca_sum / num_classes
        return gold, pred, prob, initial_add

    def _gradient_cka_score(self, preds, target, cka):
        """Return |linear CKA| between the CE gradient of the last layer and a constant matrix.

        The gradient is taken w.r.t. ``self._model.linear_2.weight`` (FEMNIST
        model; the MNIST / CIFAR variants used ``fc1`` / ``classifier[4]``).
        ``torch.autograd.grad`` does not accumulate into ``.grad``, so this
        probe cannot leak into the optimizer step of the real loss.
        """
        probe_loss = CrossEntropy().forward(preds, target)
        weight = self._model.linear_2.weight
        (local_gradient,) = torch.autograd.grad(probe_loss, weight, retain_graph=True)
        # Constant 1e-3 reference with the gradient's shape, dtype and device.
        # NOTE(review): if CudaCKA centers its Gram matrices, a constant input
        # centers to ~0 and the score may degenerate (0/0); confirm.
        reference = torch.full_like(local_gradient, 1e-3)
        return abs(cka.linear_CKA(local_gradient, reference).item())
