
from .fedavg import FedAvgSever, FedAvgEvaluation
from .fl_base import FLDeviceBase
import torch
import copy
import logging

class HeteroFLDevice(FLDeviceBase):
    """Federated-learning device that rebuilds its model every round before training.

    The rebuilt model is constructed from the device's stored model kwargs plus a
    ``keep_factor`` that the model class derives from ``self.resources``; the
    previous weights are then loaded into it (non-strictly) and local training runs.
    """

    #!overrides
    def _round(self):
        """Run one local round on a model rebuilt for this device's resources.

        Steps, in order: verify the device can train, build a fresh model with a
        resource-dependent ``keep_factor``, transfer the previous weights into it
        with ``strict=False``, and train.
        """
        self._check_trainable()

        # Work on a copy so the stored constructor kwargs are never modified.
        model_kwargs = copy.deepcopy(self._model_kwargs)
        model_kwargs['keep_factor'] = self._model_class.get_keep_factor(self.resources)

        logging.info(f'[HETEROFL]: Resources: {self.resources}')
        logging.info(f"[HETEROFL]: keep_factor {model_kwargs['keep_factor']}")

        # Snapshot the current weights, then swap in a freshly constructed model.
        previous_weights = self._model.state_dict()
        self._model = self._model_class(**model_kwargs)
        # strict=False presumably because the rebuilt model's parameter set/shapes
        # may differ from the snapshot -- confirm against set_model_state_dict.
        self.set_model_state_dict(previous_weights, strict=False)

        self._train()


class HeteroFLServer(FedAvgSever):
    """FedAvg-style server whose averaging tolerates differently sized client tensors.

    Each client tensor is treated as the leading slice (along every dimension) of
    the corresponding full-size tensor; every element is averaged only over the
    clients whose tensor covers it.
    """
    _device_class = HeteroFLDevice
    _device_evaluation_class = FedAvgEvaluation
    # When True, pre_round excludes devices whose memory constraint is at or
    # below ``drop_weakest_memory_threshold``.
    drop_weakest = False
    # Memory cut-off used by pre_round when drop_weakest is enabled. The default
    # of 0.4 is specific to the IMDB/STransformer combination.
    drop_weakest_memory_threshold = 0.4

    #!overrides
    @staticmethod
    def model_averaging(list_of_state_dicts, eval_device_dict=None):
        """Average client state dicts whose tensors may be leading slices of the full ones.

        Args:
            list_of_state_dicts: client state dicts. Every key of the template must
                be present in each of them.
            eval_device_dict: optional template state dict defining the keys, full
                shapes, dtypes and devices of the result. Defaults to the first
                client state dict.

        Returns:
            A deep copy of the template in which every element covered by at least
            one client tensor holds the mean of the covering client values. Elements
            that no client covers keep the template's value.

        Raises:
            IndexError: if ``eval_device_dict`` is None and no client dicts are given.
            KeyError: if a client state dict lacks one of the template's keys.
        """
        # Substrings of keys whose tensors should be left untouched (none by default).
        averaging_exceptions = []

        template = eval_device_dict if eval_device_dict is not None else list_of_state_dicts[0]
        averaged_dict = copy.deepcopy(template)

        for key in averaged_dict:
            if any(module_name in key for module_name in averaging_exceptions):
                continue

            target = averaged_dict[key]
            # Accumulate in at least the default float dtype (so integer buffers can
            # be averaged) and on the template's device (so CUDA tensors work).
            acc_dtype = torch.promote_types(target.dtype, torch.get_default_dtype())
            total = torch.zeros(target.shape, dtype=acc_dtype, device=target.device)
            count = torch.zeros(target.shape, dtype=acc_dtype, device=target.device)

            for state_dict in list_of_state_dicts:
                weight = state_dict[key]
                # The client tensor covers the leading slice of every dimension; an
                # empty tuple (0-dim tensor) addresses the whole scalar.
                region = tuple(slice(0, size) for size in weight.shape)
                total[region] += weight.to(device=target.device, dtype=acc_dtype)
                count[region] += 1

            covered = count > 0
            # clamp avoids 0/0 on uncovered elements; torch.where discards them anyway.
            mean = total / count.clamp(min=1)
            # copy_ writes in place and casts back to the template dtype.
            target.copy_(torch.where(covered, mean, target.to(acc_dtype)))

        return averaged_dict

    def pre_round(self, round_n, rng):
        """Choose the device indices that participate in round ``round_n``.

        Without ``drop_weakest`` this defers to the parent implementation. With it,
        ``random_device_selection`` is asked for all ``n_devices`` indices (presumably
        a random ordering -- confirm). Devices whose memory constraint is not above
        ``drop_weakest_memory_threshold`` are discarded, and the first
        ``n_active_devices`` survivors are returned.

        Raises:
            AssertionError: if fewer than ``n_active_devices`` devices survive.
        """
        if not self.drop_weakest:
            return super().pre_round(round_n, rng)

        candidates = self.random_device_selection(self.n_devices, self.n_devices, rng)
        threshold = self.drop_weakest_memory_threshold
        selection_idxs = [idx for idx in candidates
                          if self._device_constraints[idx].get_memory() > threshold]
        if len(selection_idxs) < self.n_active_devices:
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # the exception type is kept for compatibility.
            raise AssertionError("Error: Cant run since too many devices have to be dropped")
        return selection_idxs[0:self.n_active_devices]