import copy
from abc import ABC, abstractmethod
from typing import Dict, Optional, List
import torch.nn
from torch.utils.data import DataLoader
import tensorboardX
import logging

log = logging.getLogger(__name__)


class TrainAnalyzer(ABC):
    """ Base (abstract) for any analyzer

    An analyzer is invoked as ``analyzer(event, **kwargs)``: the keyword arguments are checked against the
    mandatory ``call_args`` given at construction and, if the analyzer listens to ``event``, forwarded to
    ``_analyze``. Subclasses accumulate their output in ``self._result``, which is exposed through ``result``
    and (de)serialized by ``state_dict``/``load_state_dict``.
    """

    def __init__(self, call_args: Dict[str, type], event: str, verbose: bool = False,
                 writer: Optional[tensorboardX.writer.SummaryWriter] = None):
        """

        Parameters
        ----------
        call_args mandatory arguments for a call to analyze(), mapping each argument name to its expected type
        event describes what request the analyzer responds to
        verbose outputs additional information about the analyzer
        writer tensorboardX SummaryWriter instance to log results to
        """
        self._event = event
        self.__args = call_args
        self._verbose = verbose
        self._writer = writer
        self._result = {}

    @property
    def result(self) -> dict:
        """Results accumulated so far (the live dict, not a copy)."""
        return self._result

    def reset(self) -> None:
        """Discard all accumulated results."""
        # Clear in place instead of rebinding: containers such as ChainedAnalyzer keep references to this dict.
        self._result.clear()

    def __call__(self, event: str, *args, **kwargs) -> None:
        """
        Validate the keyword arguments and, if this analyzer listens to `event`, run the analysis.

        Arguments are validated even when the event is ignored. Validation relies on ``assert`` and is therefore
        skipped when Python runs with ``-O``.

        Parameters
        ----------
        event name of the request made to the analyzer
        kwargs named arguments; must contain every mandatory argument declared at construction

        Returns
        -------
        whatever ``_analyze`` returns, or None when the event is not listened to
        """
        assert len(args) == 0, "Only named parameters accepted"
        self.__verify_args(kwargs)
        if self.listen_to_event(event):
            return self._analyze(event, **kwargs)

    def __verify_args(self, kwargs):
        """Assert that every mandatory argument is present in `kwargs` and has the declared type."""
        missing = [p for p in self.__args if p not in kwargs]
        assert not missing, \
            f"{self.__class__.__name__}: missing parameters {missing}: given {list(kwargs.keys())}, " \
            f"required {list(self.__args.keys())}"
        for arg_name, arg_type in self.__args.items():
            value = kwargs[arg_name]
            # Report the actual type of the offending value (not its repr, which may be huge, e.g. a tensor).
            assert isinstance(value, arg_type), f"Parameter {arg_name} expected to be of type {arg_type}, " \
                                                f"instead is of type {type(value)}"

    @abstractmethod
    def _analyze(self, event, **kwargs) -> None:
        """
        Perform analysis on the given keyword arguments in response to an event

        Parameters
        ----------
        event describe the request made to the analyzer
        kwargs arguments to use to perform analysis
        """
        pass

    def state_dict(self) -> dict:
        """Return ``{'classname': ..., 'result': ...}``; 'result' is the live result dict, not a copy."""
        return {'classname': self.__class__.__name__, 'result': self._result}

    def load_state_dict(self, state: dict) -> None:
        """
        Restore results from a dict produced by ``state_dict``.

        The results are deep-copied, so later changes to `state` do not affect this analyzer. A state saved by a
        different analyzer class is still loaded, with a warning.
        """
        assert 'classname' in state and 'result' in state, "Incomplete state for analyzer"
        if state['classname'] != self.__class__.__name__:
            log.warning(f'Reloading results from different analyzer class, expected {self.__class__.__name__}, '
                        f'given {state["classname"]}')
        self._result = copy.deepcopy(state['result'])

    def listen_to_event(self, event) -> bool:
        """Return True if this analyzer responds to `event`."""
        return event == self._event


class EmptyAnalyzer(TrainAnalyzer):
    """ Dummy analyzer that does nothing

    It only responds to the empty event name "", and then merely logs a message when verbose.
    """

    def __init__(self, verbose: bool = False, writer: Optional[tensorboardX.writer.SummaryWriter] = None,
                 **kwargs):
        """
        Parameters
        ----------
        verbose log a message whenever the analyzer is triggered
        writer tensorboardX SummaryWriter (unused, accepted for interface uniformity)
        kwargs any other configuration arguments; accepted and ignored
        """
        # Forward verbose/writer (ChainedAnalyzer passes both) so the verbose branch in _analyze is reachable.
        super().__init__({}, "", verbose, writer)

    def _analyze(self, event, **kwargs):
        if self._verbose:
            log.info("Empty analyzer called")


class AnalyzerContainer(ABC):
    """
    Base (abstract) class for a container of analyzers

    Holds a name -> analyzer mapping and (de)serializes it as a dict of per-analyzer states. States found in a
    loaded checkpoint for analyzers that are not part of this container are kept aside and emitted again by
    ``state_dict``, so they survive save/load cycles.
    """

    def __init__(self, analyzers):
        """
        Parameters
        ----------
        analyzers mapping from analyzer name to analyzer instance (stored as-is, not copied)
        """
        self._analyzers: Dict[str, TrainAnalyzer] = analyzers
        # checkpoint states of analyzers that this container does not instantiate
        self._old_state_dict = {}

    def add_analyzer(self, analyzer: TrainAnalyzer):
        """Register `analyzer` under its class name, replacing any analyzer stored under that name."""
        key = analyzer.__class__.__name__
        self._analyzers[key] = analyzer

    def contains_analyzer(self, classname) -> bool:
        """Return True if an analyzer is registered under `classname`."""
        return classname in self._analyzers

    def state_dict(self) -> dict:
        """Return the states of all contained analyzers, plus any stashed states of unknown analyzers."""
        snapshot = copy.deepcopy(self._old_state_dict)
        snapshot.update({key: member.state_dict() for key, member in self._analyzers.items()})
        return snapshot

    def load_state_dict(self, state: dict) -> None:
        """
        Load each contained analyzer from its entry in `state`.

        Entries for analyzers not in this container are stashed, and missing entries only trigger a warning.
        """
        if any(key not in state for key in self._analyzers):
            log.warning("Missing states for some analyzers")
        unknown = [key for key in state if key not in self._analyzers]
        if unknown:
            log.warning("Found analyzers in previous run not instantiated for this run")
        for key, member_state in state.items():
            if key in self._analyzers:
                self._analyzers[key].load_state_dict(member_state)
            else:
                self._old_state_dict[key] = member_state

    def reset(self) -> None:
        """Reset every contained analyzer."""
        for key in self._analyzers:
            self._analyzers[key].reset()


class ChainedAnalyzer(AnalyzerContainer, TrainAnalyzer):
    """
    Container of analyzers that applies each analyzer sequentially

    ``result`` maps a child analyzer's name to that child's result dict. A child appears there once it has
    handled an event, or once ``load_state_dict`` restored a non-empty result for it.
    """

    def __init__(self, analyzers: List[dict], verbose: bool = False,
                 writer: Optional[tensorboardX.writer.SummaryWriter] = None):
        """
        Parameters
        ----------
        analyzers list of specs ``{'classname': <analyzer class name>, 'args': {<constructor kwargs>}}``;
            children are keyed by classname, so a repeated classname replaces the earlier instance
        verbose forwarded to every child analyzer and to this container
        writer tensorboardX SummaryWriter shared by every child analyzer
        """
        # check format
        assert all('classname' in a and 'args' in a for a in analyzers), "Error in format for analyzers"
        # NOTE(review): eval() resolves the class from configuration text and would execute any expression it
        # is given, so the analyzer configuration must come from a trusted source.
        instances: Dict[str, TrainAnalyzer] = {a['classname']: eval(a['classname'])(**a['args'], writer=writer,
                                                                                   verbose=verbose)
                                               for a in analyzers}
        AnalyzerContainer.__init__(self, instances)
        TrainAnalyzer.__init__(self, {}, "", verbose, writer)

    @staticmethod
    def empty():
        """Return a ChainedAnalyzer holding only an EmptyAnalyzer."""
        return ChainedAnalyzer([{'classname': 'EmptyAnalyzer', 'args': {}}])

    def _analyze(self, event, **kwargs) -> None:
        """Forward the event to every child that listens to it and expose each child's result by name."""
        for name, analyzer in self._analyzers.items():
            if analyzer.listen_to_event(event):
                analyzer(event, **kwargs)
                self._result.update({name: analyzer.result})

    def load_state_dict(self, state: dict) -> None:
        """
        Restore every child analyzer from `state` and re-expose the restored results in ``result``.

        Loading rebinds each child's result dict. Without re-linking, ``result`` would keep pointing at the
        pre-load dicts (or miss the child entirely) until every child handled a new event.
        """
        AnalyzerContainer.load_state_dict(self, state)
        for name, analyzer in self._analyzers.items():
            # refresh children that were loaded and either carry restored data or were already exposed
            if name in state and (analyzer.result or name in self._result):
                self._result[name] = analyzer.result

    def reset(self) -> None:
        """Reset every child analyzer and this container's own result."""
        AnalyzerContainer.reset(self)
        TrainAnalyzer.reset(self)

    def listen_to_event(self, event) -> bool:
        """Return True if at least one child analyzer listens to `event`."""
        listeners = [a.listen_to_event(event) for a in self._analyzers.values()]
        return any(listeners)


class AnalyzerController(AnalyzerContainer):
    """
    Simple container that associates an analyzer container to each module

    Each module name maps to a ChainedAnalyzer built from that module's list of analyzer specs.
    """

    def __init__(self, analyzers: dict, writer: Optional[tensorboardX.writer.SummaryWriter] = None):
        """
        Parameters
        ----------
        analyzers configuration with a 'modules' entry (module name -> list of analyzer specs accepted by
            ChainedAnalyzer) and an optional 'verbose' flag (False when absent)
        writer tensorboardX SummaryWriter shared by all analyzers
        """
        verbose = analyzers.get('verbose', False)
        modules = {component_name: ChainedAnalyzer(component_analyzer, verbose, writer)
                   for component_name, component_analyzer in analyzers['modules'].items()}
        super().__init__(modules)
        if verbose:
            log.info(f'Components of AnalyzerContainer:\n{analyzers["modules"]}')

    @property
    def result(self) -> dict:
        """Return ``{module: results}`` with a shallow copy of each module's ChainedAnalyzer result."""
        return {module: dict(chained_analyzer.result) for module, chained_analyzer in self._analyzers.items()}

    def module_analyzer(self, module: str) -> ChainedAnalyzer:
        """
        Return the ChainedAnalyzer registered for `module`.

        If nothing is registered for `module`, return an empty ChainedAnalyzer that performs no analysis.
        """
        analyzer = self._analyzers.get(module)
        if analyzer is None:
            # Build the fallback only when needed; ChainedAnalyzer.empty() matches the declared return type.
            return ChainedAnalyzer.empty()
        return analyzer

    def add_analyzer(self, analyzer: TrainAnalyzer, module: Optional[str] = None):
        """
        Register a ChainedAnalyzer for a module.

        Parameters
        ----------
        analyzer ChainedAnalyzer to register
        module name of the module it analyzes. When omitted, the analyzer is stored under its class name
            ('ChainedAnalyzer'), as the base container does, so calls without `module` overwrite each other.
        """
        assert isinstance(analyzer, ChainedAnalyzer), "Modules must have a ChainedAnalyzer, not plain analyzer"
        if module is None:
            super().add_analyzer(analyzer)
        else:
            self._analyzers[module] = analyzer


class ServerAnalyzer(TrainAnalyzer):
    """
    Analyzer for a center server

    On selected rounds it validates the server model and records loss, accuracy, per-class accuracy, parameter
    norm and any extra scalars in ``result[s_round]``.
    """

    def __init__(self, val_period: int, val_always_last_rounds: int, total_rounds: int, *args, **kwargs):
        """
        Parameters
        ----------
        val_period validate on rounds where ``s_round % val_period == 0``
        val_always_last_rounds also validate on every round with ``s_round > total_rounds - val_always_last_rounds``
        total_rounds total number of training rounds
        args, kwargs forwarded to TrainAnalyzer (event, verbose, writer)
        """
        from src.algo.center_server import CenterServer
        super().__init__({'server': CenterServer, 'loss_fn': torch.nn.Module, 's_round': int}, *args, **kwargs)
        self._val_period = val_period
        self._val_always_last_rounds = val_always_last_rounds
        self._total_rounds = total_rounds

    def _should_validate(self, s_round: int) -> bool:
        """True on periodic validation rounds and on the final rounds selected by val_always_last_rounds."""
        return s_round % self._val_period == 0 or s_round > self._total_rounds - self._val_always_last_rounds

    def _analyze(self, event, server, loss_fn, s_round, other_scalars: Optional[dict] = None, **kwargs) -> None:
        """
        Validate the server model on selected rounds and store the metrics in ``result[s_round]``.

        Parameters
        ----------
        event request made to the analyzer (unused)
        server server whose model is validated
        loss_fn type-checked by the base class but not used here (see note below)
        s_round current round
        other_scalars extra scalars to log and store with the metrics; not modified
        """
        from torch.nn import CrossEntropyLoss
        if not self._should_validate(s_round):
            return
        # NOTE(review): validation always uses CrossEntropyLoss, not the caller's loss_fn; kept to preserve the
        # reported metric — confirm this is intended.
        loss, mt = server.validation(CrossEntropyLoss())
        p_norm = server.model.param_norm(2)
        # None default avoids a shared mutable default; copy so the caller's dict is never modified.
        scalars = {} if other_scalars is None else other_scalars.copy()
        scalars.update({'param_norm': p_norm})
        self._log(s_round, loss, mt, scalars)
        data = {'loss': loss, 'accuracy': mt.accuracy_overall, 'accuracy_class': mt.accuracy_per_class}
        data.update(scalars)
        self._result.update({s_round: data})

    def _log(self, s_round, loss, mt, other_scalars: dict):
        """Log the validation metrics, and also to tensorboard when a writer is configured."""
        log.info(
                f"[Round: {s_round: 05}] Test set: Average loss: {loss:.4f}, Accuracy: {mt.accuracy_overall:.2f}%"
        )
        if self._writer is not None:
            self._writer.add_scalar(f'{self.__class__.__name__}/val/loss', loss, s_round)
            self._writer.add_scalar(f'{self.__class__.__name__}/val/accuracy', mt.accuracy_overall, s_round)
            for tag, scalar in other_scalars.items():
                self._writer.add_scalar(f'{self.__class__.__name__}/{tag}', scalar, s_round)


class ForgettingAnalyzer(TrainAnalyzer):
    """
    Measures forgetting inside the monitored superclients.

    When a client of a monitored superclient is analyzed, it is evaluated (``client_evaluate`` with a
    cross-entropy loss) on the data of each client in `clients`, in order, up to and including itself.
    Results are stored as
    ``{superclient_id: {s_round: {client_id: [{'id': ..., 'loss': ..., 'accuracy': ...}, ...]}}}``.
    """

    def __init__(self, superclient_ids: List[int], check_period: int, *args, **kwargs):
        """
        Parameters
        ----------
        superclient_ids ids of the superclients to monitor; others are ignored
        check_period a new measurement starts only if at least this many rounds passed since the last round in
            which every client of the superclient was measured
        args, kwargs forwarded to TrainAnalyzer (event, verbose, writer)
        """
        from src.algo.fed_clients import Client
        mandatory = {"current_client": Client, "clients": list, 'superclient_id': int, 's_round': int}
        super().__init__(mandatory, *args, **kwargs)
        self.superclient_ids = superclient_ids
        self.check_period = check_period

    @staticmethod
    def _test_loader(client) -> DataLoader:
        """Unshuffled DataLoader over ``client.dataloader.dataset.get_copy(False)`` with the client's batching."""
        source = client.dataloader
        return DataLoader(source.dataset.get_copy(False), source.batch_size, False,
                          drop_last=source.drop_last, num_workers=source.num_workers)

    def _calculate_forgetting(self, current_client, clients) -> List[dict]:
        """
        Evaluate `current_client` on each client of `clients`, stopping after the one sharing its id.

        Returns one ``{'id', 'loss', 'accuracy'}`` dict per evaluated client, in `clients` order.
        """
        from torch.nn import CrossEntropyLoss
        measurements: List[dict] = []
        for other in clients:
            loader = self._test_loader(other)
            loss, meter = current_client.client_evaluate(CrossEntropyLoss(), loader)
            measurements.append({"id": other.client_id, "loss": loss, "accuracy": meter.accuracy_overall})
            if other.client_id == current_client.client_id:
                break
        return measurements

    def _not_to_be_analyzed(self, clients: list, superclient_id: int, s_round: int) -> bool:
        """Return True if the superclient is not monitored or was fully measured too recently."""
        if superclient_id not in self.superclient_ids:
            return True
        history = self._result.get(superclient_id, {})
        # rounds in which every client of the superclient has been measured
        complete_rounds = [rnd for rnd, per_client in history.items() if len(per_client) == len(clients)]
        last_complete = max(complete_rounds, default=-self.check_period)
        return s_round - last_complete < self.check_period

    def _analyze(self, event, current_client, clients, superclient_id, s_round, **kwargs) -> None:
        if self._not_to_be_analyzed(clients, superclient_id, s_round):
            return
        forgetting = self._calculate_forgetting(current_client, clients)
        round_stats = self._result.setdefault(superclient_id, {}).setdefault(s_round, {})
        round_stats[current_client.client_id] = forgetting
