import torch

from .basic_server import SyncServerHandler
from .basic_client import SGDClientTrainer, SGDSerialClientTrainer
from ...utils import Aggregators
from tqdm import tqdm
import torch.nn as nn

import numpy
# Import-time side effect: fix NumPy's global RNG seed so that anything drawing
# from ``numpy.random`` after this module is imported is repeatable.
# NOTE(review): no PyTorch seed is set in the visible part of this file.
numpy.random.seed(886) 
import random
# Same fixed seed for Python's built-in ``random`` module.
random.seed(886)

from ...utils.loss import FocalLoss, CrossEntropy, LabelSmoothingLoss, ClassficationAndMDCA, BrierScore, DCA, MMCE, FLSD
from ...utils.logit_margin_l1 import LogitMarginL1
import pdb
from ...utils.loss import addinitial_DCA, addinitial_ClassficationAndMDCA
from ...utils.loss import Inverse_DCA, Balance_DCA , Inverse_ClassficationAndMDCA, Balance_ClassficationAndMDCA


##################
#
#      Server
#
##################


class FedDynServerHandler(SyncServerHandler):
    """FedDyn server handler.

    Keeps a server-side correction tensor ``h`` next to the global model and
    folds it into every aggregation step.
    """

    def setup_optim(self, alpha):
        """Store the FedDyn coefficient and reset the correction tensor.

        Args:
            alpha: Coefficient that scales the update of ``h``; its reciprocal
                scales the correction subtracted from the averaged model, so
                it must be non-zero.
        """
        self.alpha = alpha
        # ``h`` has the same shape/dtype as the serialized global model.
        self.h = torch.zeros_like(self.model_parameters)

    def global_update(self, buffer):
        """Aggregate one round of client uploads into the global model.

        Args:
            buffer: Iterable of client packages; element ``0`` of each package
                is that client's serialized model parameters.
        """
        client_params = [package[0] for package in buffer]
        current_global = self.model_parameters

        # Summed drift of the uploaded models relative to the current global one.
        drift_sum = sum(uploaded - current_global for uploaded in client_params)
        self.h = self.h - self.alpha * (1.0 / self.num_clients) * drift_sum

        # Plain average of the uploads, shifted by the scaled correction tensor.
        averaged = Aggregators.fedavg_aggregate(client_params)
        corrected = averaged - 1.0 / self.alpha * self.h
        self.set_model(corrected)


##################
#
#      Client
#
##################


class FedDynSerialClientTrainer(SGDSerialClientTrainer):
    """Serial FedDyn client trainer that also records calibration statistics.

    For every selected client the trainer first scores the broadcast global
    model on that client's data, then runs FedDyn local training with a
    calibration-aware classification loss, collecting labels, predictions and
    softmax probabilities along the way.
    """

    def __init__(self, model, num_clients, cuda=False, device=None, logger=None, personal=False) -> None:
        """Create the trainer; all arguments are forwarded to the parent class."""
        super().__init__(model, num_clients, cuda, device, logger, personal)

        # One slot per client for the FedDyn linear-term state; each slot is
        # filled with a zero tensor on that client's first call to ``train``.
        self.L = [None] * num_clients

    def setup_dataset(self, dataset):
        """Delegate dataset registration to the parent class."""
        return super().setup_dataset(dataset)

    def setup_optim(self, epochs, batch_size, lr, alpha):
        """Store the FedDyn coefficient and configure the parent's optimizer.

        Args:
            epochs: Forwarded to the parent ``setup_optim``.
            batch_size: Forwarded to the parent ``setup_optim``.
            lr: Forwarded to the parent ``setup_optim``.
            alpha: FedDyn coefficient used in the local objective and in the
                update of the per-client state ``self.L``.
        """
        self.alpha = alpha
        super().setup_optim(epochs, batch_size, lr)

    def local_process(self, payload, id_list):
        """Train every client in ``id_list`` and collect their statistics.

        Each client's upload package is appended to ``self.cache``.

        Args:
            payload: Sequence whose element ``0`` is the broadcast model
                parameters.
            id_list: Client ids to process, in order.

        Returns:
            Dict mapping each client id to ``{'gold', 'pred', 'prob'}``, the
            per-phase lists returned by :meth:`train`.
        """
        model_parameters = payload[0]
        client_dict_per_round = {}
        for client_id in tqdm(id_list):
            data_loader = self.dataset.get_dataloader(client_id, self.batch_size)
            pack, big_gold, big_pred, big_out_prob = self.train(client_id, model_parameters, data_loader)
            client_dict_per_round[client_id] = {'gold': big_gold, 'pred': big_pred, 'prob': big_out_prob}
            self.cache.append(pack)
        return client_dict_per_round

    def _to_device(self, data, target):
        """Move one batch to ``self.device`` when CUDA is enabled."""
        if self.cuda:
            return data.cuda(self.device), target.cuda(self.device)
        return data, target

    def _flat_parameters(self):
        """Return the model's parameters as one detached 1-D tensor.

        ``torch.cat`` allocates a new tensor, so the result is a snapshot that
        later in-place optimizer steps do not modify.
        """
        with torch.no_grad():
            return torch.cat([p.reshape(-1) for p in self._model.parameters()])

    def _evaluate_loaded_model(self, train_loader):
        """Score the currently loaded model on ``train_loader`` without training.

        Returns:
            Tuple ``(gold, pred, prob, initial_add)``: flat lists of labels,
            arg-max predictions and softmax rows over all samples, plus the
            MDCA-style calibration gap of the loaded model.
        """
        gold, pred, prob = [], [], []
        initial_add = 0.0
        num_classes = 0

        self._model.eval()
        # Pure evaluation: no autograd graph is needed for any of this.
        with torch.no_grad():
            for data, target in train_loader:
                data, target = self._to_device(data, target)

                output = self.model(data)
                probs = torch.softmax(output, dim=1)
                _, predicted = torch.max(output, 1)

                prob.extend(probs.tolist())
                pred.extend(predicted.tolist())
                gold.extend(target.tolist())

                # MDCA gap: per class, |mean predicted probability - label
                # frequency|. The probabilities (not the raw logits) must be
                # averaged for the two terms to be comparable.
                num_classes = output.shape[1]
                for c in range(num_classes):
                    avg_count = (target == c).float().mean()
                    avg_conf = probs[:, c].mean()
                    initial_add += torch.abs(avg_conf - avg_count).item()

        # NOTE(review): the gap is summed over batches and normalised by the
        # class count only (not by the number of batches) -- confirm intended.
        # An empty loader leaves the gap at 0.0 instead of crashing.
        if num_classes:
            initial_add /= num_classes
        return gold, pred, prob, initial_add

    def train(self, id, model_parameters, train_loader):
        """Evaluate the broadcast model, then run FedDyn local training.

        Args:
            id: Index of the client; selects the slot in ``self.L``.
            model_parameters: Broadcast global parameters as a 1-D tensor.
            train_loader: Iterable of ``(data, target)`` batches for the client.

        Returns:
            Tuple ``([parameters], gold, pred, prob)``. ``parameters`` is the
            serialized model after local training. Each of ``gold``, ``pred``
            and ``prob`` is a list of lists: entry 0 comes from the broadcast
            model's evaluation pass, followed by one entry per local epoch
            (labels, arg-max predictions and softmax rows respectively).
        """
        # Load the broadcast weights first, so that both the evaluation pass
        # and the cosine-similarity reference describe the global model rather
        # than whatever the previously processed client left in ``self._model``.
        self.set_model(model_parameters)
        params_init = self._flat_parameters()

        gold, pred, prob, initial_add = self._evaluate_loaded_model(train_loader)
        big_list_gold = [gold]
        big_list_pred = [pred]
        big_list_out_prob = [prob]

        if self.L[id] is None:
            self.L[id] = torch.zeros_like(model_parameters)

        L_t = self.L[id]
        frz_parameters = model_parameters

        self._model.train()
        cos = nn.CosineSimilarity(dim=0)

        for _ in range(self.epochs):
            list_out_prob = []
            list_gold = []
            list_pred = []
            for data, target in train_loader:
                data, target = self._to_device(data, target)
                output = self.model(data)

                # Bookkeeping only; detach so it adds nothing to the graph.
                logits = output.detach()
                _, predicted = torch.max(logits, 1)
                list_out_prob.extend(torch.softmax(logits, dim=1).tolist())
                list_gold.extend(target.tolist())
                list_pred.extend(predicted.tolist())

                # |cosine similarity| between the current local weights and
                # the broadcast weights; passed to the loss as a plain float.
                with torch.no_grad():
                    score = abs(cos(self._flat_parameters(), params_init).item())

                # Classification / calibration term. The other loss classes
                # imported at the top of this module are alternatives here:
                # the ``addinitial_*`` variants take (output, target,
                # initial_add), the ``Inverse_*`` / ``Balance_*`` variants
                # additionally take ``score``.
                l1 = Balance_ClassficationAndMDCA().forward(output, target, initial_add, score)

                # FedDyn linear and proximal terms.
                # NOTE(review): these only contribute gradients if
                # ``self.model_parameters`` is attached to the autograd graph
                # -- confirm against the serialization helper.
                l2 = torch.dot(L_t, self.model_parameters)
                l3 = torch.sum(torch.pow(self.model_parameters - frz_parameters, 2))

                loss = l1 - l2 + 0.5 * self.alpha * l3

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            big_list_out_prob.append(list_out_prob)
            big_list_gold.append(list_gold)
            big_list_pred.append(list_pred)

        # Update this client's FedDyn state with the drift from the broadcast model.
        self.L[id] = L_t - self.alpha * (self.model_parameters - frz_parameters)

        return [self.model_parameters], big_list_gold, big_list_pred, big_list_out_prob
