from .fl_base import FLDeviceBase
from .fedavg import FedAvgEvaluation
from .heterofl import HeteroFLServer

import torch
import random
import copy
import logging

class FjordEvaluationDevice(FedAvgEvaluation):
    """Evaluation device whose state dict also carries per-level BN entries.

    On top of whatever the parent class returns, the dict produced by
    :meth:`get_model_state_dict` contains a ``'fjord_bn'`` entry: a list of
    ``(p, bn_state)`` tuples, one for every value in ``fjord_p``. The entry is
    stripped again before a state dict is handed to the parent class.
    """

    def __init__(self, device_id, storage_path):
        super().__init__(device_id, storage_path)
        # Iterable of keep factors; must be set via set_fjord_p_values()
        # before get_model_state_dict() is called for the first time.
        self.fjord_p = None
        # Lazily built list of (p, bn_state) tuples, see get_model_state_dict().
        self.fjord_bn_dict = None

    def set_fjord_p_values(self, p):
        """Store the keep factors ``p`` (kept by reference, not copied)."""
        self.fjord_p = p

    def _initial_bn_state(self, p):
        """Return the 'bn' entries of a freshly built model with keep factor ``p``."""
        kwargs = copy.deepcopy(self._model_kwargs)
        kwargs.update({'keep_factor': p})
        model = self._model_class(**kwargs)
        # Entries are selected purely by the substring 'bn' in the key name.
        return {key: value for key, value in model.state_dict().items()
                if 'bn' in key}

    #!overrides
    def get_model_state_dict(self):
        """Return the parent's state dict extended by the ``'fjord_bn'`` list.

        On the first call one model per value in ``fjord_p`` is instantiated
        to obtain its initial 'bn' entries; the result is cached in
        ``self.fjord_bn_dict`` and reused (by reference) on later calls.
        """
        if self.fjord_bn_dict is None:
            # Assign only after the whole list was built so that a failure
            # half-way does not leave a truncated list cached on the instance.
            self.fjord_bn_dict = [(p, self._initial_bn_state(p))
                                  for p in self.fjord_p]

        model_state_dict = super().get_model_state_dict()
        model_state_dict.update({'fjord_bn': self.fjord_bn_dict})
        return model_state_dict

    #!overrides
    def set_model_state_dict(self, model_dict, strict=True):
        """Forward ``model_dict`` without its ``'fjord_bn'`` entry to the parent.

        Only the unpruned model is used for verification, so the per-level BN
        list is dropped. The caller's dict is left untouched: the entry is
        removed from a shallow copy (tensors are not duplicated).
        """
        if 'fjord_bn' in model_dict:
            model_dict = copy.copy(model_dict)
            model_dict.pop('fjord_bn')
        return super().set_model_state_dict(model_dict, strict)


class FjordDevice(FLDeviceBase):
    """Training device that trains one randomly chosen sub-model per batch.

    ``fjord_p`` holds the candidate keep factors and ``fjord_bn_dict`` a list
    of ``(p, bn_state)`` tuples with one set of 'bn' entries per keep factor.
    The list travels inside the model state dict under the key ``'fjord_bn'``.
    """

    def __init__(self, device_id, storage_path):
        super().__init__(device_id, storage_path)
        # Candidate keep factors, set through set_fjord_p_values().
        self.fjord_p = None
        # List of (p, bn_state) tuples, filled by set_model_state_dict().
        self.fjord_bn_dict = None

    def set_fjord_p_values(self, p):
        """Store a deep copy of the keep factors ``p``."""
        self.fjord_p = copy.deepcopy(p)

    #!overrides
    def set_model_state_dict(self, model_dict, strict=True):
        """Load ``model_dict``; a contained ``'fjord_bn'`` list is kept aside.

        The incoming dict is deep-copied first, so neither the caller's dict
        nor its tensors are shared with this device.
        """
        model_dict = copy.deepcopy(model_dict)
        if 'fjord_bn' in model_dict:
            self.fjord_bn_dict = model_dict.pop('fjord_bn')
        super().set_model_state_dict(model_dict, strict=strict)

    #!overrides
    def get_model_state_dict(self):
        """Return the parent's state dict extended by the ``'fjord_bn'`` list."""
        model_state_dict = super().get_model_state_dict()
        model_state_dict.update({'fjord_bn': self.fjord_bn_dict})
        return model_state_dict

    def _bn_level_index(self, p):
        """Return the position of keep factor ``p`` inside ``self.fjord_bn_dict``.

        Levels are matched by value instead of relying on ``fjord_bn_dict``
        being ordered like the list of trainable keep factors.

        Raises:
            ValueError: if no entry for ``p`` exists.
        """
        for position, (level_p, _) in enumerate(self.fjord_bn_dict):
            if level_p == p:
                return position
        raise ValueError("no fjord_bn entry for keep factor " + repr(p))

    #!overrides
    def _round(self):
        """Run one local pass over the training data.

        For every batch a keep factor is drawn from the trainable ones, a
        model with that keep factor is built from the current weights plus the
        matching 'bn' entries, trained for a single optimizer step, and fused
        back into ``self._model``. Afterwards ``self.fjord_bn_dict`` only
        contains the levels that were trained, with all tensors on the CPU.

        Raises:
            ValueError: if no value of ``fjord_p`` is within the keep factor
                allowed by this device's resources, or if a required
                'fjord_bn' level is missing.
        """
        self._check_trainable()
        kwargs = copy.deepcopy(self._model_kwargs)

        maximum_keep_factor = self._model_class.get_keep_factor(self.resources)
        trainable_fjord_p = [p for p in self.fjord_p if p <= maximum_keep_factor]
        logging.info("[FJORD]: %r", trainable_fjord_p)
        if not trainable_fjord_p:
            raise ValueError(
                "no FjORD keep factor is trainable with maximum keep factor "
                + repr(maximum_keep_factor))
        kwargs.update({'keep_factor': max(trainable_fjord_p)})

        #reinitialize model and restore state_dict
        state_dict = self._model.state_dict()
        self._model = self._model_class(**kwargs)
        self.set_model_state_dict(state_dict, strict=False)

        loss_function = self._loss_F()

        trainloader = torch.utils.data.DataLoader(self._train_data, batch_size=self._batch_size_train,
                                                shuffle=True, pin_memory=True)

        trained_p_levels = []
        # A single local epoch: every batch is used for exactly one step.
        for inputs, labels in trainloader:
            # Draw the keep factor for this batch. The shuffle-based draw is
            # kept as is so the consumption of the `random` stream is unchanged.
            val = list(range(len(trainable_fjord_p)))
            random.shuffle(val)
            p = trainable_fjord_p[val[0]]
            kwargs.update({'keep_factor': p})
            trained_p_levels.append(p)
            level_index = self._bn_level_index(p)

            # NOTE(review): loading the state dict of self._model into a model
            # built with a smaller keep factor presumably relies on the model
            # class accepting larger tensors; confirm against the model code.
            model = self._model_class(**kwargs)
            model.load_state_dict(self._model.state_dict())
            model.load_state_dict(self.fjord_bn_dict[level_index][1], strict=False)
            model.to(self._torch_device)
            model.train()

            # The model is rebuilt for every batch, hence a fresh optimizer.
            optimizer = self._optimizer(model.parameters(), lr=self.lr, **self._optimizer_args)

            inputs, labels = inputs.to(self._torch_device), labels.to(self._torch_device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_function(output, labels)
            self.assert_if_nan(loss)
            loss.backward()
            optimizer.step()

            #fuse model: keep the 'bn' entries per level, write the remaining
            #entries back into the corresponding slice of self._model
            current_state_dict = model.state_dict()
            bn_state = {key: value for key, value in current_state_dict.items()
                        if 'bn' in key}
            self.fjord_bn_dict[level_index] = (p, copy.deepcopy(bn_state))
            self._model.load_state_dict(self.fuse_state_dicts(
                self._model.state_dict(), current_state_dict))

        if not trained_p_levels:
            logging.warning("[FJORD]: no batch was trained in this round")

        #update the evaluation model in the case it got trained
        if maximum_keep_factor == 1.0:
            self._model.load_state_dict(
                self.fjord_bn_dict[self._bn_level_index(1.0)][1], strict=False)

        #drop fjord_bn levels in case that they are not updated
        trained = set(trained_p_levels)
        self.fjord_bn_dict = [entry for entry in self.fjord_bn_dict
                              if entry[0] in trained]

        #to cpu
        for _, bn_state in self.fjord_bn_dict:
            for key in bn_state:
                bn_state[key] = bn_state[key].cpu()

        return

    @staticmethod
    def fuse_state_dicts(max_dict, small_dict):
        """Copy every non-'bn' tensor of ``small_dict`` into ``max_dict``.

        Each tensor is written into the leading slice ``[0:size]`` of every
        dimension of the tensor stored under the same key in ``max_dict``.
        Keys containing 'bn' are skipped. ``max_dict`` is modified in place
        and returned. Tensors of any rank are supported.
        """
        for key, item in small_dict.items():
            if 'bn' in key:
                continue
            region = tuple(slice(0, size) for size in item.shape)
            max_dict[key][region] = item
        return max_dict

    
class FjordServer(HeteroFLServer):
    """Server that pairs FjordDevice clients with a FjordEvaluationDevice.

    Regular entries are averaged over the overlapping leading slices of the
    device tensors; the ``'fjord_bn'`` entries are averaged separately for
    every keep factor ``p``.
    """

    _device_class = FjordDevice
    _device_evaluation_class = FjordEvaluationDevice

    def set_fjord_p_values(self, p):
        """Store the keep factors that initialize() hands to all devices."""
        self._p_values = p

    #!overrides
    def initialize(self):
        """Initialize the parent, then distribute the keep factors.

        ``set_fjord_p_values`` has to be called beforehand, otherwise
        ``self._p_values`` does not exist.
        """
        super().initialize()
        for dev in self._devices_list:
            dev.set_fjord_p_values(self._p_values)

        #reinitialize the fjord evaluation device
        self._evaluation_device.set_model(self._model[0], **self._model_args[0])
        self._evaluation_device.init_model()
        self._evaluation_device.set_fjord_p_values(self._p_values)
        self._global_model = self._evaluation_device.get_model_state_dict()
        return

    #!overrides
    @staticmethod
    def model_averaging(list_of_state_dicts, eval_device_dict=None):
        """Average device state dicts into one global state dict.

        Args:
            list_of_state_dicts: state dicts of the devices; each must contain
                a ``'fjord_bn'`` list of ``(p, bn_state)`` tuples.
            eval_device_dict: optional template for the result. If ``None``,
                the first device state dict is used as template.

        Returns:
            A deep copy of the template in which
            * every entry whose key does not contain ``'fjord_bn'`` is
              replaced, wherever at least one device covers a position, by the
              mean over the devices covering that position,
            * every ``'fjord_bn'`` level is replaced by the mean over all
              devices that delivered this level (levels delivered by no
              device keep the template values),
            * the top-level keys of the level with the largest ``p`` are
              overwritten with that level's averaged values.
        """
        averaging_exceptions = ['fjord_bn']

        if eval_device_dict is not None:
            averaged_dict = copy.deepcopy(eval_device_dict)
        else:
            averaged_dict = copy.deepcopy(list_of_state_dicts[0])

        #HeteroFL based averaging of normal parameters
        for key in averaged_dict:
            if any(module_name in key for module_name in averaging_exceptions):
                continue

            target = averaged_dict[key]
            out = torch.zeros(target.shape)
            div_mask = torch.zeros(target.shape)

            for state_dict in list_of_state_dicts:
                weight = state_dict[key]
                # Leading slice in every dimension; also valid for 0-dim
                # tensors, where the index is the empty tuple.
                region = tuple(slice(0, size) for size in weight.shape)
                out[region] += weight
                div_mask[region] += 1

            covered = div_mask > 0
            out[covered] = out[covered] / div_mask[covered]
            # Cast back so that non-float32 entries can be written as well.
            target[covered] = out[covered].to(target.dtype)

        #FJORD averaging of batchnorm (complexity-level wise)
        avg_fjord_bn = copy.deepcopy(averaged_dict['fjord_bn'])

        for p, level in avg_fjord_bn:
            filtered_levels = [
                bn_state
                for state_dict in list_of_state_dicts
                for level_p, bn_state in state_dict['fjord_bn']
                if level_p == p
            ]
            if not filtered_levels:
                continue

            for key in level:
                stacked = torch.stack([item[key] for item in filtered_levels])
                if stacked.is_floating_point() or stacked.is_complex():
                    level[key] = torch.mean(stacked, dim=0)
                else:
                    # torch.mean rejects integer tensors: average in floating
                    # point and cast the result back to the original dtype.
                    level[key] = torch.mean(stacked.double(), dim=0).to(stacked.dtype)

        averaged_dict.update({'fjord_bn': avg_fjord_bn})

        #make sure that evaluation device uses the BN values from p=1.0 configuration
        #(selected as the level with the largest p instead of by list position)
        if avg_fjord_bn:
            _, full_level = max(avg_fjord_bn, key=lambda entry: entry[0])
            for key in full_level:
                averaged_dict[key] = full_level[key]

        return averaged_dict