import copy
import os.path

from .fedbase import BasicServer, BasicClient
import numpy as np
from utils import fmodule
import utils.system_simulator as ss
import torch
import collections
import json

class Server(BasicServer):
    """FedIS server that aggregates client models with layer-wise importance weights.

    For every parameter tensor ("layer"), a selected client's weight is the L2
    norm of its pseudo-gradient ``(global - local) / lr`` for that tensor,
    divided by the sum of those norms over all selected clients. The new global
    model is the sum of the client models after each tensor has been scaled by
    its weight. Per-round selections, importance matrices and per-client
    metrics are appended to ``FedIS_layer_Log.json``.
    """

    # Upper bound on how many clients (ids 0..N-1) ``sample`` selects each round.
    num_selected_clients = 20

    def __init__(self, option, model, clients, test_data = None):
        super(Server, self).__init__(option, model, clients, test_data)
        # self.init_algo_para({'feda': 0.1, 'T': 0.5})
        # Run identifier built from seed/task/algorithm/learning rate
        # (NOTE(review): presumably consumed by the base class when saving — confirm).
        self.save_name = str(option['seed']) + '_' + option['task'] + '_model_' + option['algorithm'] + '_lr_' + str(option['learning_rate'])

    def sample(self):
        """Deterministically select the first ``num_selected_clients`` clients.

        The count is capped at the number of available clients so a smaller
        federation never yields out-of-range client ids.

        Returns:
            list[int]: the selected client ids (also printed).
        """
        selected_clients = list(range(min(self.num_selected_clients, len(self.clients))))
        print(selected_clients)
        return selected_clients

    def run(self):
        """Run all communication rounds and append a per-round record to the log file.

        Each round trains and aggregates via ``iterate``, evaluates on every client
        and on the server test set, reports through ``outFunc``, saves the stream
        log, and steps the global learning-rate scheduler.

        ``FedIS_layer_Log.json`` is deleted at the start of the run. Despite its
        suffix it is a line-oriented text log; each record consists of a
        ``Round_<r>`` line, the JSON list of selected client ids, the importance
        matrix as JSON (one client row per line), the JSON metrics of the
        selected clients, and a blank separator line.
        """
        filename = 'FedIS_layer_Log.json'
        if os.path.exists(filename):
            os.remove(filename)
        for rnd in range(1, self.num_rounds + 1):
            self.current_round = rnd
            importance = self.iterate()
            test_metric, save_metric = self.test_on_clients()
            global_acc = float(self.test()['accuracy'])
            accuracy = test_metric['accuracy']
            loss = test_metric['loss']
            self.outFunc(rnd, global_acc, accuracy, loss)
            self.save_log(self.stream_log)
            # decay learning rate
            self.global_lr_scheduler(rnd)
            selected = [int(i) for i in self.selected_clients]
            # Put each client's row on its own line while keeping the closing ']'
            # of every row, so the matrix remains parseable JSON.
            importance_json = json.dumps(importance.tolist()).replace('],', '],\n')
            with open(filename, 'a') as f:
                f.write('Round_{}'.format(self.current_round))
                f.write('\n')
                json.dump(selected, f)
                f.write('\n')
                f.write(importance_json)
                f.write('\n')
                json.dump(save_metric, f)
                f.write('\n')
                f.write('\n')
        return

    @staticmethod
    def _layer_importance(norm_grads):
        """Turn per-client, per-layer norms into per-layer weights over clients.

        Each column of ``norm_grads`` is divided by its sum, so the weights for
        a layer sum to 1 across clients. A layer whose norms are all zero (every
        client left it equal to the global model) would give 0/0 = NaN and
        poison the aggregated model; it receives uniform weights instead.

        Args:
            norm_grads: tensor of shape (num_clients, num_layers).

        Returns:
            Tensor of the same shape holding the importance weights.
        """
        sum_norm = torch.sum(norm_grads, dim=0, keepdim=True)
        uniform = torch.full_like(norm_grads, 1.0 / norm_grads.shape[0])
        return torch.where(sum_norm > 0, norm_grads / sum_norm, uniform)

    def iterate(self):
        """Run one round of layer-wise importance-weighted aggregation.

        Sets ``self.selected_clients``, collects the locally trained models,
        scales their parameters in place by the importance weights, and
        replaces ``self.model`` with the aggregate.

        Returns:
            Tensor of shape (num_selected_clients, num_parameter_tensors) with
            the importance weights; each column sums to 1.
        """
        # sample clients
        self.selected_clients = self.sample()
        # training
        res = self.communicate(self.selected_clients)
        models = res['model']
        # Pure bookkeeping / in-place scaling: no autograd graph is needed.
        with torch.no_grad():
            # Pseudo-gradient of each client relative to the current global model.
            grads = [(self.model - model) / self.lr for model in models]
            # norm_grads[i][j]: L2 norm of client i's pseudo-gradient for parameter tensor j.
            norm_grads = torch.stack([
                torch.stack([torch.norm(torch.flatten(param)) for _, param in grad.named_parameters()])
                for grad in grads
            ])
            importance = self._layer_importance(norm_grads)
            # Scale each client's tensors by its per-layer weight so that the plain
            # sum in aggregate() becomes a per-layer weighted average.
            for model, weights in zip(models, importance):
                for (_, param), weight in zip(model.named_parameters(), weights):
                    param *= weight
        self.model = self.aggregate(models)
        return importance

    def test_on_clients(self, dataflag='valid'):
        """Evaluate the global model on every client's ``dataflag`` split.

        Returns:
            tuple: ``(all_metrics, save_metrics)``, both mapping metric name to a
            list of per-client values. ``all_metrics`` covers every client in
            order; ``save_metrics`` only clients in ``self.selected_clients``.
        """
        all_metrics = collections.defaultdict(list)
        save_metrics = collections.defaultdict(list)
        selected = set(self.selected_clients)  # built once: O(1) membership per client
        for cid, c in enumerate(self.clients):
            client_metrics = c.test(self.model, dataflag)
            for met_name, met_val in client_metrics.items():
                all_metrics[met_name].append(met_val)
                if cid in selected:
                    save_metrics[met_name].append(met_val)
        return all_metrics, save_metrics

    def aggregate(self, models):
        """Sum the client models; importance weighting is already applied in ``iterate``."""
        return fmodule._model_sum(models)

    # def aggregate(self, models: list, new_weight):
    #     if len(models) == 0: return self.model
    #     p = new_weight
    #     return fmodule._model_sum([model_k * pk for model_k, pk in zip(models, p)])
    ###############################################################################
    # (1-alpha) * w_reweight + alpha * w_ideal
    ###############################################################################

    # def aggregate(self, lasts, models: list, new_weight):
    #     if len(models) == 0: return self.model
    #     models_copy_1 = copy.deepcopy(models)
    #     w_ideal = copy.deepcopy(lasts[0])
    #     p = new_weight
    #     w_reweight = fmodule._model_sum([model_k * pk for model_k, pk in zip(models_copy_1, p)])
    #     new_model = (1.0-self.alpha)*w_reweight + self.alpha*w_ideal
    #     return new_model

class Client(BasicClient):
    """Client that returns its trained model plus the loss measured before training."""

    def reply(self, svr_pkg):
        """Handle a server package: evaluate, train locally, and build the reply.

        The reported loss is computed on the local 'train' split *before*
        ``self.train`` runs, i.e. for the model as received from the server.
        NOTE(review): the same model object is packed after training, which
        assumes ``self.train`` updates it in place — confirm in BasicClient.
        """
        received = self.unpack(svr_pkg)
        loss_before_training = self.test(received, 'train')['loss']
        self.train(received)
        return self.pack(received, loss_before_training)

    def pack(self, model, loss):
        """Bundle the model and loss into the reply dict read by the server."""
        package = dict(model=model, loss=loss)
        return package
