import copy
from typing import List, Tuple, Dict, Optional
from abc import ABC, abstractmethod

from omegaconf import DictConfig
from torch.optim import *
from torch.utils.data import DataLoader
from src.algo.fed_clients.base_client import Client
from src.models import Model
from src.utils import MeasureMeter, TrainAnalyzer, ChainedAnalyzer
import logging
log = logging.getLogger(__name__)


class CenterServer(ABC):
    """
    Base (abstract) class for a center server in any FL algorithm.

    The server owns the global model, a data loader, a MeasureMeter sized on the model's number of classes and an
    optimizer over the model parameters built from a config. Subclasses implement ``aggregation`` and ``validation``.
    Subclasses may extend ``_state_dict_keys_required`` / ``_state_dict_keys_optional`` to control which keys
    ``load_state_dict`` checks for.
    """

    def __init__(self, model: Model, dataloader: DataLoader, device: str,
                 optim: DictConfig, analyzer: Optional[TrainAnalyzer] = None):
        """
        Parameters
        ----------
        model
            the global model; it is moved to ``device``
        dataloader
            the data loader held by the server (stored as ``self._dataloader`` for use by subclasses)
        device
            the device the model is placed on
        optim
            optimizer config: ``classname`` is the name of an optimizer class visible in this module (e.g. any class
            exported by ``torch.optim``), ``args`` holds the keyword arguments for its constructor (may be null)
        analyzer
            optional analyzer; when not given, ``ChainedAnalyzer.empty()`` is used

        Raises
        ------
        ValueError
            if ``optim.classname`` does not name an object visible in this module
        """
        self._model = model.to(device)
        self._dataloader = dataloader
        self._device = device
        self._analyzer = analyzer or ChainedAnalyzer.empty()
        self._measure_meter = MeasureMeter(model.num_classes)
        optimizer_class = self._resolve_optimizer_class(optim.classname)
        # a null ``args`` entry in the config means "no extra constructor arguments"
        optimizer_args = optim.args or {}
        self._opt: Optimizer = optimizer_class(self.model.parameters(), **optimizer_args)
        # set of center server state dict keys that are required (Error if missing)
        self._state_dict_keys_required = {'model'}
        # set of center server state dict keys that are optional (Warning if missing)
        self._state_dict_keys_optional = {'opt_state'}

    @staticmethod
    def _resolve_optimizer_class(classname: str):
        """
        Resolve an optimizer class from its name without evaluating arbitrary code.

        The first component of the name is looked up among this module's global names (which include the names
        exported by ``torch.optim`` through the star import); any further dotted components are resolved as
        attributes. This replaces a former ``eval`` of the raw config string, which would execute any expression.

        Parameters
        ----------
        classname
            the optimizer class name, e.g. ``"SGD"``

        Returns
        -------
        the resolved object (expected to be an optimizer class)

        Raises
        ------
        ValueError
            if the name, or one of its dotted attributes, cannot be resolved
        """
        # ``eval`` tolerated surrounding blanks, keep accepting them
        head, *attrs = str(classname).strip().split(".")
        try:
            resolved = globals()[head]
            for attr in attrs:
                resolved = getattr(resolved, attr)
        except (KeyError, AttributeError) as err:
            raise ValueError(f"Unknown optimizer class: {classname!r}") from err
        return resolved

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device: str):
        # NOTE(review): only the model is moved; tensors already held in the optimizer state (e.g. momentum
        # buffers) stay on the previous device. Confirm whether callers change device after optimizer steps.
        self._model.to(device)
        self._device = device

    @property
    def model(self):
        return self._model

    @abstractmethod
    def aggregation(self, clients: List[Client], aggregation_weights: List[float], s_round: int):
        """
        Aggregate the clients' data according to their weights

        Parameters
        ----------
        clients
            the clients whose data have to be aggregated
        aggregation_weights
            the weights corresponding to clients
        s_round
            the current round of the server
        """
        pass

    def send_data(self) -> dict:
        """
        Sends out the current data of the central server. To be used to send current round data to FL clients.
        For any specific FL algorithm, (CenterServer) send_data must output the data needed by the specific
        client receive_data

        Returns
        -------
        a dictionary containing the current data of the center server; the model is deep-copied so that
        receivers cannot modify the server's own model
        """
        # This is the right place where to put code to keep track of the amount of exchanged data,
        # for example using the proper analyzer listening to the proper event, like
        # self._analyzer('send_data', data=data, sender=self)
        return {"model": copy.deepcopy(self._model)}

    @abstractmethod
    def validation(self, loss_fn) -> Tuple[float, MeasureMeter]:
        """
        Validates the center server model

        Parameters
        ----------
        loss_fn
            the loss function to be used

        Returns
        -------
        a tuple containing the value of the loss function and a reference to the center server MeasureMeter object
        """
        pass

    def state_dict(self) -> dict:
        """
        Saves the state of the center server in order to restore it when reloading a checkpoint

        Returns
        -------
        a dict with the model state dict under "model" and the optimizer state dict under "opt_state"
        """
        return {"model": self._model.state_dict(),
                "opt_state": self._opt.state_dict()}

    def load_state_dict(self, state: dict) -> None:
        """
        Loads a previously saved state for the center server

        Parameters
        ----------
        state
            a dictionary containing key-value pairs corresponding to the parameter name and its value

        Raises
        ------
        RuntimeError
            if any key in ``_state_dict_keys_required`` is missing from ``state`` (an explicit raise rather than
            an ``assert``, so the check survives ``python -O``; same exception type that the strict
            ``Module.load_state_dict`` call below uses for mismatched keys)
        """
        # sorted so that error/warning messages are deterministic (set iteration order is not)
        missing_required_keys = sorted((k for k in self._state_dict_keys_required if k not in state), key=str)
        if missing_required_keys:
            raise RuntimeError(f"Missing params for center server: {missing_required_keys}")

        missing_optional_keys = sorted((k for k in self._state_dict_keys_optional if k not in state), key=str)
        if missing_optional_keys:
            log.warning("Missing optional keys: %s", missing_optional_keys)

        self._model.load_state_dict(state["model"], strict=True)
        if 'opt_state' in state:
            self._opt.load_state_dict(state["opt_state"])
