from .fl_base import FLServerBase, FLDeviceBase

import torch
import copy
import logging

class FedAvgDevice(FLDeviceBase):
    """Client-side participant in FedAvg.

    On the client a FedAvg round consists only of local training; the
    aggregation of client models is done by the server (see
    ``FedAvgSever.post_round``).
    """

    def _round(self):
        """Run this device's part of one round: precondition check, then train."""
        # The base-class check runs first so training is only attempted on a
        # device that passed it.
        self._check_trainable()
        self._train()
        return None


class FedAvgEvaluation(FLDeviceBase):
    """Device used by the server solely to evaluate the aggregated model.

    It is never scheduled for a training round; ``_round`` raises.
    """

    def test(self):
        """Evaluate the currently loaded model, then move it back to the CPU.

        Returns:
            self, allowing chained calls.
        """
        self._check_testable()
        try:
            self._test()
        finally:
            # Move the model back to the CPU even if evaluation raises, so a
            # failed evaluation does not leave it resident on whatever device
            # _test() placed it on (presumably an accelerator — confirm in
            # FLDeviceBase._test).
            self._model.to('cpu')
        return self

    def _round(self):
        """Evaluation devices do not train; calling this is a programming error."""
        raise NotImplementedError("Evaluation Device does not have a round function.")


class FedAvgSever(FLServerBase):
    """Federated Averaging (FedAvg) server with uniform client weighting.

    Each round a random subset of devices is selected (``pre_round``); after
    local training their state dicts are averaged, the result is evaluated on
    a dedicated evaluation device and stored as the new global model
    (``post_round``).

    Note: the class name keeps its historical "Sever" spelling so existing
    importers keep working.
    """

    # Device classes used by the base server; presumably instantiated by
    # FLServerBase when building the evaluation device and client devices —
    # confirm in fl_base.
    _device_evaluation_class = FedAvgEvaluation
    _device_class = FedAvgDevice

    @staticmethod
    def model_averaging(list_of_state_dicts, eval_device_dict=None):
        """Average client state dicts entry-by-entry with equal weights.

        Args:
            list_of_state_dicts: Non-empty sequence of state dicts that share
                the key set of the first one.
            eval_device_dict: Unused; accepted for interface compatibility.

        Returns:
            A new dict with one entry per key of the first state dict, except
            keys containing 'num_batches_tracked' or 'frozen', which are
            omitted. Floating-point/complex tensors are replaced by their
            element-wise mean over all clients. Other values (e.g. integer or
            bool buffers, which ``torch.mean`` cannot average) are copied from
            the first state dict.

        Raises:
            ValueError: If ``list_of_state_dicts`` is empty.
        """
        if not list_of_state_dicts:
            raise ValueError("model_averaging requires at least one state dict")

        averaging_exceptions = ('num_batches_tracked', 'frozen')
        reference = list_of_state_dicts[0]

        averaged_dict = {}
        for key, first_value in reference.items():
            # Excluded entries are dropped entirely rather than averaged.
            if any(name in key for name in averaging_exceptions):
                continue
            averageable = torch.is_tensor(first_value) and (
                first_value.is_floating_point() or first_value.is_complex())
            if averageable:
                stacked = torch.stack([state_dict[key] for state_dict in list_of_state_dicts])
                averaged_dict[key] = torch.mean(stacked, dim=0)
            else:
                # torch.mean rejects non-floating dtypes; fall back to an
                # independent copy of the first client's value.
                averaged_dict[key] = copy.deepcopy(first_value)
        return averaged_dict

    def set_model(self, model_list, kwargs_list):
        """Install the client models after checking they are all the same.

        Raises:
            ValueError: If the entries of ``model_list`` are not all equal.
        """
        # An explicit raise instead of `assert`, which is stripped under -O.
        if not all(model == model_list[0] for model in model_list):
            raise ValueError("FedAvg requires all NN models to have the same type")
        return super().set_model(model_list, kwargs_list)

    def pre_round(self, round_n, rng):
        """Return indices of ``n_active_devices`` devices drawn with ``rng``."""
        return self.random_device_selection(self.n_devices, self.n_active_devices, rng)

    def post_round(self, round_n, idxs):
        """Aggregate the selected clients, evaluate, and update the global model.

        Records the upload byte count and the test accuracy for ``round_n``.
        The accuracy is also stored as the main performance metric.
        """
        used_devices = [self._devices_list[i] for i in idxs]
        # Fetch each client's state dict once; it serves both the upload
        # accounting and the averaging below.
        client_state_dicts = [device.get_model_state_dict() for device in used_devices]

        byte_count = sum(self.count_data_footprint(state_dict) for state_dict in client_state_dicts)
        self._measurements_dict['data_upload'].append([byte_count, round_n])

        averaged_model = self.model_averaging(client_state_dicts,
                                              eval_device_dict=self._global_model)

        # Hand the evaluation device its own copy so it never shares tensors
        # with the dict kept as the global model.
        self._evaluation_device.set_model_state_dict(copy.deepcopy(averaged_model), strict=False)
        self._evaluation_device.test()
        acc = round(float(self._evaluation_device._accuracy_test), 4)
        logging.info(f"[FEDAVG]: Round: {round_n} Test accuracy: {acc}")

        self._measurements_dict['accuracy'].append([acc, round_n])
        self._main_performance_metric = acc

        self._global_model = averaged_model
