import torch

from .basic_server import SyncServerHandler
from .basic_client import SGDClientTrainer, SGDSerialClientTrainer
from ...utils import Aggregators
from tqdm import tqdm
import torch.nn as nn
import pdb
import numpy as np
np.random.seed(886) 
import random
random.seed(886)
from ...utils.loss import FocalLoss, CrossEntropy, LabelSmoothingLoss, ClassficationAndMDCA, BrierScore, DCA, MMCE, FLSD
from ...utils.logit_margin_l1 import LogitMarginL1
from ...utils.loss import addinitial_DCA, addinitial_ClassficationAndMDCA
from ...utils.loss import Inverse_DCA, Balance_DCA , Inverse_ClassficationAndMDCA, Balance_ClassficationAndMDCA
from ...utils.CKA_func import CudaCKA

##################
#
#      Server
#
##################


class ScaffoldServerHandler(SyncServerHandler):
    """SCAFFOLD server handler.

    Broadcasts the global model together with the server control variate
    ``global_c``. Each round it aggregates the clients' model deltas ``dy``
    and control-variate deltas ``dc``.
    """

    @property
    def downlink_package(self):
        """Return ``[model_parameters, global_c]`` for the sampled clients."""
        return [self.model_parameters, self.global_c]

    def setup_optim(self, lr):
        """Set the global step size ``lr`` and reset ``global_c`` to zeros.

        ``global_c`` is shaped like the serialized model parameters.
        """
        self.lr = lr
        self.global_c = torch.zeros_like(self.model_parameters)

    def global_update(self, buffer):
        """Apply one SCAFFOLD server update from the clients' uploads.

        Args:
            buffer: list of client packages ``[dy, dc]``. ``dy`` is a client's
                model delta; ``dc`` is its control-variate delta.
        """
        dys = [pack[0] for pack in buffer]
        dcs = [pack[1] for pack in buffer]

        dx = Aggregators.fedavg_aggregate(dys)
        dc = Aggregators.fedavg_aggregate(dcs)

        # x <- x + lr * mean(dy)
        self.set_model(self.model_parameters + self.lr * dx)

        # c <- c + |S| / N * mean(dc): only the sampled fraction of clients
        # contributes to the server control variate.
        self.global_c += len(dcs) / self.num_clients * dc


##################
#
#      Client
#
##################


class ScaffoldSerialClientTrainer(SGDSerialClientTrainer):
    """SCAFFOLD client trainer that simulates several clients serially.

    Before local training, each client evaluates the broadcast global model on
    its own training data. It then trains with ``Inverse_ClassficationAndMDCA``,
    which additionally receives a DCA-style calibration term measured on the
    broadcast model (``initial_add``) and a CKA score (``score``) computed from
    a probe gradient.
    """

    # Per-sample shape that raw loader batches are reshaped to.
    # (1, 28, 28) is the FEMNIST layout; override in a subclass for other data.
    input_shape = (1, 28, 28)

    def setup_optim(self, epochs, batch_size, lr):
        """Configure local SGD and reset every client's control variate.

        Each entry of ``self.cs`` starts as ``None``; :meth:`train` lazily
        replaces it with zeros shaped like the model parameters.
        """
        super().setup_optim(epochs, batch_size, lr)
        self.cs = [None for _ in range(self.num_clients)]

    def local_process(self, payload, id_list):
        """Train every client in ``id_list`` on the broadcast ``payload``.

        Args:
            payload: ``[model_parameters, global_c]`` received from the server.
            id_list: ids of the clients to train this round.

        Returns:
            dict mapping client id to ``{'gold', 'pred', 'prob'}`` lists, as
            produced by :meth:`train`. Each client's ``[dy, dc]`` package is
            appended to ``self.cache``.
        """
        model_parameters, global_c = payload[0], payload[1]
        client_dict_per_round = {}
        for cid in tqdm(id_list):
            data_loader = self.dataset.get_dataloader(cid, self.batch_size)
            pack, big_gold, big_pred, big_out_prob = self.train(
                cid, model_parameters, global_c, data_loader
            )
            client_dict_per_round[cid] = {
                'gold': big_gold,
                'pred': big_pred,
                'prob': big_out_prob,
            }
            self.cache.append(pack)
        return client_dict_per_round

    def _prepare_batch(self, data, target):
        """Convert one raw loader batch into model-ready tensors on the device."""
        # as_tensor avoids the copy-construct warning torch.tensor() emits
        # when the loader already yields tensors.
        data = torch.as_tensor(data, dtype=torch.float32).reshape(-1, *self.input_shape)
        target = torch.as_tensor(target, dtype=torch.long)
        if self.cuda:
            data = data.cuda(self.device)
            target = target.cuda(self.device)
        return data, target

    def _evaluate_broadcast(self, train_loader):
        """Evaluate the freshly loaded global model on the client's training data.

        Returns:
            ``(probs, gold, pred, initial_add)``. ``initial_add`` is the sum over
            batches of ``|mean max-softmax confidence - batch accuracy|``, or
            ``None`` when the loader yields no batches.
        """
        self._model.eval()
        probs_all, gold_all, pred_all = [], [], []
        calib_terms = []
        with torch.no_grad():
            for data, target in train_loader:
                data, target = self._prepare_batch(data, target)
                output = self.model(data)
                probs = torch.softmax(output, dim=1)
                _, predicted = torch.max(output, 1)

                probs_all.extend(probs.tolist())
                pred_all.extend(predicted.tolist())
                gold_all.extend(target.tolist())

                # DCA term: gap between mean confidence and accuracy on this batch.
                conf, prob_pred = torch.max(probs, dim=1)
                gap = torch.abs(conf.mean() - (prob_pred == target).float().mean())
                calib_terms.append(gap.item())
        # Summed (not averaged) over batches, as the calibration losses expect.
        initial_add = sum(calib_terms) if calib_terms else None
        return probs_all, gold_all, pred_all, initial_add

    def _cka_score(self, cuda_cka):
        """Return |linear CKA| between the probe layer's gradient and a constant reference.

        The probe layer is ``linear_2`` (the FEMNIST model). Earlier experiments
        used ``fc1`` for MNIST and ``classifier[4]`` for CIFAR.
        """
        local_gradient = self._model.linear_2.weight.grad
        # Build the constant reference on the gradient's own device so this
        # also works when training on CPU.
        reference = torch.full_like(local_gradient, 1e-3)
        return abs(cuda_cka.linear_CKA(local_gradient, reference).item())

    def _apply_scaffold_correction(self, id, global_c):
        """Overwrite each parameter gradient with the SCAFFOLD step ``g - c_i + c``.

        Assumes ``self.model_grads`` is laid out like ``self.cs[id]``, i.e. in
        ``state_dict`` order. The subtraction below requires equal lengths.
        """
        grad = self.model_grads
        grad = grad - self.cs[id] + global_c.to(grad.device)
        idx = 0
        # keep_vars=True is required: the default state_dict() returns detached
        # tensors whose .grad is always None, so nothing would be written back.
        for tensor in self._model.state_dict(keep_vars=True).values():
            size = tensor.numel()
            if tensor.grad is not None:  # buffers (e.g. BatchNorm stats) have no grad
                tensor.grad.data.copy_(grad[idx:idx + size].view_as(tensor.grad))
            idx += size

    def train(self, id, model_parameters, global_c, train_loader):
        """Run SCAFFOLD local training for client ``id``.

        Args:
            id: client index into ``self.cs``.
            model_parameters: serialized global model from the server.
            global_c: serialized server control variate ``c``.
            train_loader: this client's training data loader.

        Returns:
            ``([dy, dc], gold, pred, prob)``. ``dy`` is the model delta and
            ``dc`` the client control-variate delta. Each of ``gold``, ``pred``
            and ``prob`` is a list of per-pass lists: index 0 holds the
            broadcast model evaluated before training, and indices
            ``1..epochs`` hold the predictions recorded during each epoch.
        """
        self.set_model(model_parameters)

        probs, gold, pred, initial_add = self._evaluate_broadcast(train_loader)
        big_list_out_prob = [probs]
        big_list_gold = [gold]
        big_list_pred = [pred]

        frz_model = model_parameters
        if self.cs[id] is None:
            self.cs[id] = torch.zeros_like(model_parameters)

        cuda_cka = CudaCKA(self.device)
        # _evaluate_broadcast left the model in eval mode; restore training mode.
        self._model.train()
        for _ in range(self.epochs):
            list_out_prob, list_gold, list_pred = [], [], []
            for data, target in train_loader:
                data, target = self._prepare_batch(data, target)
                output = self.model(data)

                _, predicted = torch.max(output, 1)
                list_out_prob.extend(torch.softmax(output.detach(), dim=1).tolist())
                list_gold.extend(target.tolist())
                list_pred.extend(predicted.tolist())

                self.optimizer.zero_grad()

                # Probe: plain CE gradient of the probe layer, used only to
                # compute the CKA score. The graph is retained for the real loss.
                probe_loss = CrossEntropy().forward(output, target)
                probe_loss.backward(retain_graph=True)
                score = self._cka_score(cuda_cka)

                # Alternatives with the same signature: Balance_ClassficationAndMDCA,
                # Inverse_DCA, Balance_DCA.
                loss = Inverse_ClassficationAndMDCA().forward(output, target, initial_add, score)
                # NOTE(review): there is no zero_grad() between the probe backward
                # and this one, so the CE probe gradient is also accumulated into
                # the update (effectively an extra CE term). Preserved from the
                # original; confirm whether this is intended.
                loss.backward()

                self._apply_scaffold_correction(id, global_c)
                self.optimizer.step()
            big_list_out_prob.append(list_out_prob)
            big_list_gold.append(list_gold)
            big_list_pred.append(list_pred)

        dy = self.model_parameters - frz_model
        # SCAFFOLD option II: c_i+ - c_i = -c + (x - y) / (K * lr), K = local steps.
        dc = -1.0 / (self.epochs * len(train_loader) * self.lr) * dy - global_c
        self.cs[id] += dc

        return [dy, dc], big_list_gold, big_list_pred, big_list_out_prob
        
