from .fl_base import FLServerBase, FLDeviceBase

import torch
import copy
import logging
from visualization.state_dict import display_state_dict, state_dict_dif


class FedAvgDevice(FLDeviceBase):
    """Training device used by the FedAvg server.

    All behaviour other than ``_round`` comes from ``FLDeviceBase``
    (defined in ``.fl_base``, not visible in this file).
    """

    def _round(self):
        """Run one device round: check trainability, then train.

        ``_check_trainable`` is called first; presumably it raises when the
        device is not ready to train (TODO confirm in ``FLDeviceBase``).
        ``_train`` is then invoked. Nothing is returned.
        """
        self._check_trainable()
        self._train()


class FedAvgEvaluation(FLDeviceBase):
    """Evaluation-only device: it can be tested but has no training round."""

    def test(self):
        """Evaluate the device's model and return the device itself.

        Calls ``_check_testable`` and ``_test`` (both inherited, not visible
        in this file), then calls ``.to('cpu')`` on ``self._model``.

        Returns:
            This instance, so the call can be chained or collected.
        """
        self._check_testable()
        self._test()
        # Presumably frees accelerator memory once testing is done -- confirm.
        evaluated_model = self._model
        evaluated_model.to('cpu')
        return self

    def _round(self):
        """Always raise: evaluation devices do not take part in rounds.

        Raises:
            NotImplementedError: on every call.
        """
        raise NotImplementedError("Evaluation Device does not have a round function.")

class SubsetEvaluation(FLDeviceBase):
    """Evaluation-only device carrying a class-level ``scale_factor``.

    The test/round behaviour defined here is the same as an evaluation
    device without training support: ``test`` evaluates and returns the
    instance, ``_round`` always raises.
    """

    # Not read anywhere in this class; the meaning of -1.0 is not visible
    # here (looks like a "not set" placeholder -- TODO confirm).
    scale_factor = -1.0

    def test(self):
        """Evaluate the device's model and return the device itself.

        Calls ``_check_testable`` and ``_test`` (both inherited, not visible
        in this file), then calls ``.to('cpu')`` on ``self._model``.

        Returns:
            This instance.
        """
        self._check_testable()
        self._test()
        # Presumably frees accelerator memory once testing is done -- confirm.
        evaluated_model = self._model
        evaluated_model.to('cpu')
        return self

    def _round(self):
        """Always raise: evaluation devices do not take part in rounds.

        Raises:
            NotImplementedError: on every call.
        """
        raise NotImplementedError("Evaluation Device does not have a round function.")


class FedAvgSever(FLServerBase):
    """Server side of federated averaging (FedAvg).

    Before a round it selects devices and hands each of them the current
    global model; after a round it replaces the global model with the
    element-wise mean of the participating devices' state dicts.

    NOTE: the class name is spelled ``FedAvgSever`` (sic). It is left
    unchanged because code outside this file may refer to it by this name.
    """

    # Device classes; presumably consumed by ``FLServerBase`` when it builds
    # the evaluation and training devices (base class not visible here).
    _device_evaluation_class = SubsetEvaluation
    _device_class = FedAvgDevice
    # Not read anywhere in this class; the meaning of -1.0 is not visible here.
    scale_factor = -1.0

    @staticmethod
    def model_averaging(list_of_state_dicts, eval_device_dict=None):
        """Return the element-wise mean of several state dicts.

        Args:
            list_of_state_dicts: Non-empty sequence of mappings from
                parameter name to tensor. Every mapping must contain the
                averaged keys of the first one, with stackable tensors.
            eval_device_dict: Accepted for interface compatibility; it is
                not used by this implementation.

        Returns:
            A plain ``dict`` holding, for every key of the first state dict
            whose name does not contain ``'num_batches_tracked'``, the mean
            over all given state dicts. Keys containing
            ``'num_batches_tracked'`` are omitted from the result.

        Raises:
            IndexError: if ``list_of_state_dicts`` is empty.
        """
        averaging_exceptions = ['num_batches_tracked']

        def is_averaged(key):
            # A key is averaged only if it matches none of the exceptions.
            return all(module_name not in key for module_name in averaging_exceptions)

        # The first state dict only provides the key set and order; every
        # value is freshly computed, so no deep copy of it is required.
        reference = list_of_state_dicts[0]
        return {
            key: torch.mean(
                torch.stack([state_dict[key] for state_dict in list_of_state_dicts]),
                dim=0,
            )
            for key in reference
            if is_averaged(key)
        }

    def set_model(self, model_list, kwargs_list):
        """Validate that all models are equal, then delegate to the base class.

        Args:
            model_list: Models (per the error message, model types) to use;
                every entry must compare equal to the first.
            kwargs_list: Passed through unchanged to the base implementation.

        Returns:
            Whatever ``FLServerBase.set_model`` returns.

        Raises:
            AssertionError: if the entries of ``model_list`` are not all equal.
        """
        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with -O; the exception type is unchanged.
        if not all(model == model_list[0] for model in model_list):
            raise AssertionError("FedAvg requires all NN models to have the same type")
        return super().set_model(model_list, kwargs_list)

    def pre_round(self, round_n, rng):
        """Select the devices for a round and the model each one receives.

        Args:
            round_n: Round number; not used by this implementation.
            rng: Random source forwarded to ``random_device_selection``.

        Returns:
            A tuple ``(selected_indices, models)`` where ``models`` contains
            ``n_active_devices`` references to the same global model object.
        """
        rand_selection_idxs = self.random_device_selection(self.n_devices, self.n_active_devices, rng)
        return rand_selection_idxs, [self._global_model for _ in range(self.n_active_devices)]

    def post_round(self, round_n, idxs):
        """Aggregate the devices' models into the new global model.

        Also passes the result of ``state_dict_dif`` on the new and previous
        global model to ``display_state_dict`` together with the path
        ``<storage_path>/state_dict_viz.png`` (debug visualisation; the exact
        behaviour of those helpers is not visible in this file).

        Args:
            round_n: Round number; not used by this implementation.
            idxs: Indices into ``self._devices_list`` of the devices that
                took part in the round.
        """
        used_devices = [self._devices_list[i] for i in idxs]

        # DEBUG: snapshot the global model before it is replaced.
        previous_model = copy.deepcopy(self._global_model)

        averaged_model = self.model_averaging(
            [dev.get_model_state_dict() for dev in used_devices],
            eval_device_dict=self._global_model,
        )
        self._global_model = averaged_model

        # The new global model is copied so the visualisation helpers cannot
        # alter it; the snapshot is already a private copy and is passed as is.
        display_state_dict(
            state_dict_dif(copy.deepcopy(averaged_model), previous_model),
            self._storage_path + '/state_dict_viz.png',
        )
