
from .fedavg import SubsetEvaluation
from .caldas import SubsetDevice, CaldasServer
import torch
import copy
import logging
import numpy as np

import json

from visualization.state_dict import display_state_dict, state_dict_dif


class SLTServer(CaldasServer):
    """CaldasServer subclass that trains a staged sequence of sub-models.

    ``initialize`` reads the list of stages from a JSON config and derives the
    rounds at which the server switches from one stage to the next.
    ``pre_round`` then selects the stage for the current round and extracts
    the corresponding sub-model for every selected device.
    """
    # Class used for the training devices.
    # NOTE(review): not read in this class body; presumably consumed by the base server - confirm.
    _device_class = SubsetDevice
    # Class used for the evaluation device.
    # NOTE(review): not read in this class body; presumably consumed by the base server - confirm.
    _device_evaluation_class = SubsetEvaluation
    # NOTE(review): not referenced by any method in this class body - confirm it is still needed.
    last_indices=None

    #!overrides
    def initialize(self):
        """Load the stage configuration and compute the switching points.

        Runs the base-class initialisation, registers the ``'level'``
        measurement series, and loads the JSON file at ``self.configs_path``.
        Only the entry keyed by the smallest scale factor is kept in
        ``self.configs``.

        For every stage in ``self.configs['values']`` the data footprint of
        the extracted model (with the entries reported as frozen removed) is
        measured. The normalised cumulative footprint decides how many of the
        non-pretraining rounds each stage receives.

        Sets:
            self.configs: config entry for the smallest scale factor.
            self.switching_idx: index of the active stage, starts at 0.
            self._pretraining_rounds: ``int(values[0][0] * self.n_rounds)``.
            self.ranges: list of int switching points, counted from the end
                of pretraining.
        """
        super().initialize()
        self._measurements_dict['level'] = []

        with open(self.configs_path, 'r') as fd:
            loaded_configs = json.load(fd)
        self.configs = loaded_configs[str(min(self.scale_factor_list))]
        self.switching_idx = 0

        stages = self.configs['values']

        def stage_footprint(stage):
            # Footprint of the part of the stage's model that is not frozen.
            model, _, freeze_dict = self.extract_fnc(stage[0], stage[1], self._global_model)
            for frozen_key in freeze_dict:
                model.pop(frozen_key)
            return self.count_data_footprint(model)

        # Calculate switching points based on total rounds and params
        footprints = np.array([stage_footprint(stage) for stage in stages])
        relative_footprints = footprints/np.max(footprints)

        cumulative_share = np.cumsum(relative_footprints)
        cumulative_share = cumulative_share/np.max(cumulative_share)

        # Pretraining rounds: the first value of the first stage is used as
        # the fraction of all rounds spent on pretraining.
        self._pretraining_rounds = int(stages[0][0]*self.n_rounds)
        logging.info(f'[Flexible]: Pretraining rounds: {self._pretraining_rounds}')

        # The remaining rounds are split according to the cumulative share;
        # the integer cast truncates towards zero.
        remaining_rounds = self.n_rounds - self._pretraining_rounds
        self.ranges = np.array(cumulative_share*remaining_rounds, dtype=int).tolist()

        logging.info(f'[Flexible]: Calculated Switching Points {self.ranges}')


    #!overrides
    def pre_round(self, round_n, rng):
        """Select this round's devices and extract the sub-model for each.

        The stage list is ``self.configs['values']`` preceded by a
        pretraining stage that reuses the first value of the first configured
        stage and passes ``0.0`` as the second value. The matching round
        boundaries are ``self.ranges`` shifted by the number of pretraining
        rounds, preceded by the pretraining boundary itself.

        Args:
            round_n: Current round number, compared against the boundaries.
            rng: Random generator handed to ``random_device_selection``.

        Returns:
            Tuple ``(rand_selection_idxs, device_model_list)``: the selected
            device indices and, in the same order, the extracted model for
            each of them. Every extracted model additionally carries a
            ``'frozen'`` entry.

        Raises:
            ValueError: If there are more stages than round boundaries.

        Side effects:
            Updates ``self.values``, ``self.switching_idx``,
            ``self._list_of_indices_dict`` and ``self._frozen_list``.
        """
        rand_selection_idxs = self.random_device_selection(self.n_devices, self.n_active_devices, rng)

        # Stage 0 is the pretraining stage.
        configured_stages = self.configs['values']
        val = [[configured_stages[0][0], 0.0]] + configured_stages

        # self.ranges is counted from the end of pretraining; convert it to
        # absolute rounds and prepend the pretraining boundary.
        ranges = [self._pretraining_rounds] + [
            point + self._pretraining_rounds for point in self.ranges
        ]

        if len(val) > len(ranges):
            raise ValueError(f"ranges are of length {len(ranges)}, but values are {len(val)}")

        # Determine current configuration: the first stage whose boundary has
        # not been reached yet, or the last stage once every boundary is
        # behind us.
        ranges_idx = len(ranges) - 1
        for idx, upper_bound in enumerate(ranges):
            if round_n < upper_bound:
                ranges_idx = idx
                break
        # Clamp after the search so the index is valid for `val` even when
        # there are fewer stages than boundaries. (Previously the clamp sat
        # inside the loop and never applied to the index chosen on break.)
        ranges_idx = min(ranges_idx, len(val) - 1)

        # Publish the active stage before looping over devices so that
        # measure_accuracy() sees it even if no device was selected.
        current_values = val[ranges_idx]
        self.values = current_values

        # Extraction of trainable model out of the full model
        smallest_scale = min(self.scale_factor_list)
        device_model_list = []
        list_of_indices_dict = []
        frozen_list = []

        for dev_index in rand_selection_idxs:
            # Determine training depth: devices with the smallest scale
            # factor, and every device during pretraining, use depth 1; all
            # others look it up per stage in the config.
            scale_factor = self.scale_factor_list[dev_index]
            if ranges_idx == 0 or scale_factor == smallest_scale:
                training_depth = 1
            else:
                training_depth = self.configs[str(scale_factor)]['freeze_values'][ranges_idx - 1]

            device_model, indices_dict, frozen = self.extract_fnc(
                current_values[0], current_values[1],
                self._global_model, training_depth=training_depth)
            # The device gets its own copy so it cannot alter the server's
            # bookkeeping in frozen_list.
            device_model.update({'frozen': copy.deepcopy(frozen)})

            device_model_list.append(device_model)
            list_of_indices_dict.append(indices_dict)
            frozen_list.append(frozen)

        if ranges_idx != self.switching_idx:
            logging.info(f"[SLT]: Switching model at round {round_n} [{val[ranges_idx][0]} {val[ranges_idx][1]}]")

        self.switching_idx = ranges_idx
        self._list_of_indices_dict = list_of_indices_dict
        self._frozen_list = frozen_list
        return rand_selection_idxs, device_model_list
    
    #!overrides
    def measure_data_upload(self, round_n, idxs):
        """Record how many bytes the given devices upload in this round.

        For every device in ``idxs`` the tensors of its state dict are
        counted at 4 bytes per element. Entries whose key starts with one of
        the prefixes stored for that device in ``self._frozen_list`` are
        skipped, as are entries that are not tensors.

        Args:
            round_n: Current round number, stored next to the byte count.
            idxs: Indices into ``self._devices_list``; position ``k`` in
                ``idxs`` is matched with ``self._frozen_list[k]``.

        Side effects:
            Appends ``[byte_count, round_n]`` to
            ``self._measurements_dict['data_upload']``.
        """
        used_devices = [self._devices_list[i] for i in idxs]

        # Counting device upload
        byte_count = 0
        for idx, device in enumerate(used_devices):
            # str.startswith accepts a tuple of prefixes; an empty tuple
            # matches nothing, same as any() over an empty list.
            frozen_prefixes = tuple(self._frozen_list[idx])
            for key, param in device.get_model_state_dict().items():
                if key.startswith(frozen_prefixes):
                    continue
                # Only tensors contribute. Previously a non-tensor entry
                # re-added the size of the preceding tensor (or raised if it
                # came first) because the accumulation was outside this check.
                if not isinstance(param, torch.Tensor):
                    continue
                # 4 bytes per element, regardless of the tensor's dtype.
                byte_count += 4 * param.numel()
        self._measurements_dict['data_upload'].append([byte_count, round_n])

    #!overrides
    def measure_accuracy(self, round_n):
        """Evaluate the current global model and record accuracy and stage.

        The evaluation device is re-initialised and loaded (non-strictly)
        with the model extracted from a deep copy of the global model, using
        the two values of the active stage in ``self.values``. It is then run
        on the test set.

        Args:
            round_n: Current round number, stored next to each measurement.

        Side effects:
            Appends ``[accuracy, round_n]`` to
            ``self._measurements_dict['accuracy']`` and
            ``[self.switching_idx, round_n]`` to
            ``self._measurements_dict['level']``.
        """
        evaluator = self._evaluation_device

        # Evaluation of averaged global model
        evaluator.init_model()
        # Extraction works on a copy so the global model itself is untouched.
        global_snapshot = copy.deepcopy(self._global_model)
        extraction = self.extract_fnc(self.values[0], self.values[1], global_snapshot)
        evaluator.set_model_state_dict(extraction[0], strict=False)

        evaluator.test()
        acc = round(float(evaluator._accuracy_test), 4)
        # NOTE(review): the tag says FedRolex although this is the SLT server;
        # left unchanged in case logs are parsed downstream.
        logging.info(f"[FedRolex]: Round: {round_n} Test accuracy: {acc}")

        measurements = self._measurements_dict
        measurements['accuracy'].append([acc, round_n])
        # add Switching level
        measurements['level'].append([self.switching_idx, round_n])

    #!overrides
    def post_round(self, round_n, idxs):
        """Merge the trained device models into a new global model.

        Each participating device's state dict is embedded back into the
        layout of the global model using the indices stored for it in
        ``self._list_of_indices_dict``. The embedded models and their masks
        are then averaged and the result becomes ``self._global_model``.

        Args:
            round_n: Current round number (not used by this method).
            idxs: Indices into ``self._devices_list``; position ``k`` in
                ``idxs`` is matched with ``self._list_of_indices_dict[k]``.

        Side effects:
            Replaces ``self._global_model`` and writes a visualisation of the
            difference between the new and the previous global model to
            ``state_dict_viz.png`` under ``self._storage_path``.
        """
        participants = [self._devices_list[device_idx] for device_idx in idxs]

        # Reference Model for averaging
        reference_model = self._global_model

        # DEBUG code: keep the pre-averaging state for the diff below
        previous_global = copy.deepcopy(reference_model)

        # Extract individual Device models with stored indices
        embedded = [
            self.embedd_submodel(self._list_of_indices_dict[position],
                                 device.get_model_state_dict(),
                                 reference_model)
            for position, device in enumerate(participants)
        ]
        device_models = [model for model, _ in embedded]
        device_masks = [mask for _, mask in embedded]

        # Model Averaging based on extracted local models, then setting the
        # new global model
        averaged_model = self.model_averaging(device_models, device_masks,
                                              eval_device_dict=reference_model,
                                              storage_path=self._storage_path)
        self._global_model = averaged_model

        # DEBUG code: runs every round and works on copies of both models
        model_difference = state_dict_dif(copy.deepcopy(averaged_model),
                                          copy.deepcopy(previous_global))
        display_state_dict(model_difference, self._storage_path + '/state_dict_viz.png')