import pathlib
import logging
import getpass
import json
from typing import Dict, Generator, Tuple, List, Optional, Any, Union, Sequence

import coloredlogs
import torch.nn
from torch.utils.data import DataLoader
from nvsmpy import CudaCluster

from path_learning.utils.log import get_timestring, get_logger
from path_learning.utils.result import ExperimentSetResult, ExperimentResult, TaskResult
from path_learning.dataloaders.supervised_dataloader import pick_dataloader, SupervisedDataloader

# Type of the generators returned by the ``analyze_model`` implementations
# below: every yielded item is a ``(result_name, result_values)`` pair.
ResultGeneratorType = Generator[Tuple[str, Any], None, None]


# Module-level logger. It is used by the static dataloader helpers, which have
# no access to the per-instance logger created in ``_AnalysisMethod``.
logger = get_logger("analysis_method")


class _AnalysisMethod:
    """Abstract base class for analyses run on one experiment of an experiment set.

    Subclasses set ``name`` and implement ``run``. The base class selects the
    experiment matching the requested seed, creates a logger that also writes
    to ``<logdir>/run.log`` and stores every analysis result as a JSON file
    below ``logdir``.
    """

    # Identifier of the analysis; used in log messages and in the saved results.
    name = None

    def __init__(self, experiment_set: ExperimentSetResult, seed: int, logdir: pathlib.Path):
        """
        :param experiment_set: experiment set containing the experiment to analyze
        :param seed: seed identifying the experiment inside ``experiment_set``
        :param logdir: directory receiving ``run.log`` and the result files;
                       it is not created here
        :raises RuntimeError: if several experiments share ``seed``
        :raises IndexError: if no experiment has ``seed``
        """

        self.user: str = getpass.getuser()

        self.experiment_set: ExperimentSetResult = experiment_set
        self.experiment: ExperimentResult = self.get_experiment(seed)

        self.logdir: pathlib.Path = logdir
        self.logger = self.create_logger()

        # Not used by the base class itself.
        self.results = {}

        self.logger.info(f"running analysis {self.name} on experiment set "
                         f"{self.experiment_set.dir.stem} with seed {self.experiment.dir.stem}")

    def get_experiment(self, seed: int) -> ExperimentResult:
        """Return the single experiment of the experiment set that has the given seed.

        :param seed: seed of the wanted experiment
        :return: the matching experiment
        :raises RuntimeError: if more than one experiment has this seed
        :raises IndexError: if no experiment has this seed
        """
        matched_experiments = [exp for exp in self.experiment_set.experiments if exp.seed == seed]
        if len(matched_experiments) > 1:
            raise RuntimeError(f"Found more than one experiment with seed {seed} in "
                               f"experiment set {self.experiment_set.dir.stem}! Please remove one of the experiments.")
        if not matched_experiments:
            # Indexing the empty match list used to fail with an opaque
            # "tuple index out of range"; the exception type is kept so that
            # existing callers are unaffected, only the message is improved.
            raise IndexError(f"Found no experiment with seed {seed} in "
                             f"experiment set {self.experiment_set.dir.stem}!")
        return matched_experiments[0]

    def run(self):
        """Run the analysis; must be implemented by subclasses."""
        raise NotImplementedError("_AnalysisMethod class is abstract")

    def save_results(self, task_result: Sequence[TaskResult], analysis_results: ResultGeneratorType) -> None:
        """Store every ``(result_name, result_values)`` pair produced by an analysis.

        :param task_result: the task(s) the results belong to
        :param analysis_results: iterable of ``(result_name, result_values)`` pairs;
                                 it is consumed lazily, one result file per pair
        """
        for result_name, result_values in analysis_results:
            self.save_single_result(result_name, result_values, task_result)

    def create_logger(self) -> logging.Logger:
        """Create the logger of this analysis.

        The logger has level INFO, is set up with ``coloredlogs`` and gets a
        file handler (level DEBUG) writing to ``<logdir>/run.log``.

        :return: the configured logger
        """
        method_logger = logging.Logger(f"{self.name}_{self.experiment_set.name}")
        method_logger.setLevel(logging.INFO)

        coloredlogs.install(logger=method_logger)

        fh = logging.FileHandler(str(self.logdir / "run.log"))
        fmt_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        formatter = logging.Formatter(fmt_str)
        fh.setFormatter(formatter)
        fh.setLevel(logging.DEBUG)

        method_logger.addHandler(fh)
        return method_logger

    def save_single_result(self, result_name: str, result_values: Any,
                           task_results: Sequence[TaskResult]) -> None:
        """Write one analysis result, together with its provenance, to a JSON file.

        The file is ``<logdir>/task_<uids>/result_<result_name>_task_<uids>.json``
        where ``<uids>`` are the task uids joined by ``_``.

        :param result_name: name of the result; part of the file name
        :param result_values: JSON-serializable result data
        :param task_results: the task(s) the result was computed for
        :raises AssertionError: if the output file already exists
        :raises TypeError: if the result is not JSON serializable; no file is
                           written in that case
        """
        info = {
            "result_name": result_name,
            "result_values": result_values,
            "method_name": self.name,
            "user": self.user,
            "experiment_set": {
                "dir": str(self.experiment_set.dir),
                "name": self.experiment_set.name,
                "user": self.experiment_set.user,
            },
            "experiment": {
                "dir": str(self.experiment.dir),
                "seed": self.experiment.seed
            },
            "task": [
                {"uid": task_result.uid, "dataset": task_result.dataset, "dir": str(task_result.dir)}
                for task_result in task_results
            ]
        }

        task_uid: str = "task_" + "_".join(str(task_result.uid) for task_result in task_results)

        # parents=False: the log directory itself is expected to exist already.
        task_dir = self.logdir / task_uid
        task_dir.mkdir(exist_ok=True, parents=False)

        out_fname = task_dir / f"result_{result_name}_{task_uid}.json"
        if out_fname.is_file():
            # Raised explicitly (instead of via ``assert``) so the overwrite
            # protection is still active when Python runs with -O.
            raise AssertionError("Output file for analysis already existed!")

        # Serialize before opening the file: a non-serializable result must not
        # leave a truncated file behind that would block the next run.
        try:
            serialized = json.dumps(info, indent=4)
        except TypeError:
            self.logger.warning(f"Could not save analysis results: {info}")
            self.logger.warning("ALL RESULTS MUST BE JSON SERIALIZABLE. "
                                "HENCE THEY CANNOT BE NUMPY ARRAYS OR DTYPES. "
                                "THEY MUST BE ANY COMBINATION OF PYTHON BUILTIN DATATYPES")
            raise

        with open(str(out_fname), "w") as fp:
            fp.write(serialized)


class _SingleAnalysisMethod(_AnalysisMethod):
    """Analysis that is applied separately to the model of every task of the experiment."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Device every loaded model is moved to before it is analyzed.
        self.device = 'cuda'

    def run(self):
        """Analyze the model of each task in turn and store the results.

        The whole loop runs inside the ``available_devices`` context of a
        ``CudaCluster`` (requested with ``max_n_processes=1``).
        """
        gpu_cluster = CudaCluster()
        with gpu_cluster.available_devices(max_n_processes=1):
            torch.cuda.empty_cache()
            for current_task in self.experiment.tasks:
                self.logger.info(f"config: {current_task.config}")
                model: torch.nn.Module = self.experiment.load_model(current_task.uid).to(self.device)
                self.save_results([current_task], self.analyze_model(current_task, model))
                # Drop the reference so the model can be freed before the next one is loaded.
                del model

    @staticmethod
    def generate_dataloader(task_result: TaskResult, purpose: str = "train") -> DataLoader:
        """Return the ``purpose`` dataloader for the task's ``target`` domain.

        :param task_result: task whose config names the domain
        :param purpose: key of the wanted subset in the result of ``pick_dataloader``
        """
        logger.info(f"Generate dataloader with subset: {purpose}")
        target_domain = task_result.config["domains"]["target"]
        loaders = pick_dataloader(target_domain)
        return loaders[purpose]

    @staticmethod
    def generate_second_dataloader(task_result: TaskResult, purpose: str = "train") -> DataLoader:
        """Return the ``purpose`` dataloader for the task's ``target2`` domain.

        :param task_result: task whose config names the domain
        :param purpose: key of the wanted subset in the result of ``pick_dataloader``
        """
        logger.info(f"Generate dataloader with subset: {purpose}")
        second_domain = task_result.config["domains"]["target2"]
        loaders = pick_dataloader(second_domain)
        return loaders[purpose]

    def analyze_model(self, task_result: TaskResult, model: torch.nn.Module) -> ResultGeneratorType:
        """Yield ``(result_name, result_values)`` pairs for one task; implemented by subclasses."""
        raise NotImplementedError("_SingleAnalysisMethod class is abstract")


class _StepAnalysisMethod(_AnalysisMethod):
    """Analysis that evaluates one stored model on one selected task of the experiment."""

    def __init__(self, *args, **kwargs):
        """
        Evaluate the model trained on task 'model_task_id' on the task at position 'task_id'.
        If 'model_task_id' is -1, ``None`` is passed to ``load_model``; this is
        meant to select a randomly initialized model (the behavior of
        ``load_model`` is not visible in this module).
        :param args: positional arguments forwarded to ``_AnalysisMethod``
        :param kwargs: keyword arguments forwarded to ``_AnalysisMethod`` after
                       'model_task_id' (default 0) and 'task_id' (default 0)
                       have been removed
        """
        # Both keys must be popped before calling the base constructor, which
        # does not accept them.
        self.model_task: Optional[int] = kwargs.pop("model_task_id", 0)
        if self.model_task == -1:
            self.model_task = None
        self.task_to_evaluate: int = kwargs.pop("task_id", 0)
        super().__init__(*args, **kwargs)
        self.logger.info(f"Evaluating model trained on task {self.model_task} on task {self.task_to_evaluate}.")

    def run(self):
        """Analyze the selected model on the task at position ``task_to_evaluate``.

        ``task_to_evaluate`` is compared with the position of the task in
        ``self.experiment.tasks``, not with the task uid. A warning is logged
        when no task has that position.
        """
        evaluated = False
        for count, task_result in enumerate(self.experiment.tasks):
            self.logger.info(f"task count {count}")
            # Load model "model_task" for task "task_to_evaluate"
            if count == self.task_to_evaluate:
                model: torch.nn.Module = self.experiment.load_model(self.model_task)
                analysis_results: ResultGeneratorType = self.analyze_model(task_result, model)
                self.save_results([task_result], analysis_results)
                evaluated = True
        if not evaluated:
            # Previously an out-of-range position ended the run silently.
            self.logger.warning(f"No task at position {self.task_to_evaluate} found; nothing was analyzed.")

    @staticmethod
    def generate_dataloader(task_result: TaskResult, purpose: str = "test") -> DataLoader:
        """Return the ``purpose`` dataloader for the task's ``target`` domain.

        :param task_result: task whose config names the domain
        :param purpose: key of the wanted subset in the result of ``pick_dataloader``
        """
        logger.info(f"Generate dataloader with subset: {purpose}")
        return pick_dataloader(task_result.config["domains"]["target"])[purpose]

    def analyze_model(self, task_result: TaskResult, model: torch.nn.Module) -> ResultGeneratorType:
        """Yield ``(result_name, result_values)`` pairs for the task; implemented by subclasses."""
        raise NotImplementedError("_StepAnalysisMethod class is abstract")


class _PairwiseAnalysisMethod(_AnalysisMethod):
    """Analysis that compares the models of two tasks of the experiment at a time."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def run(self):
        """Analyze the first task against each later task and store the results.

        For every pair both models are loaded (first task's model first) and
        handed to ``analyze_model`` together with the two task results.
        """
        tasks = self.experiment.tasks
        # The anchor index only takes the value 0, so the pairs are
        # (0, 1), (0, 2), ..., (0, n - 1).
        index_pairs = [(anchor, anchor + offset)
                       for anchor in range(1)
                       for offset in range(1, len(tasks) - anchor)]
        self.logger.info(f"Task result pairs: {index_pairs}")
        for first_idx, second_idx in index_pairs:
            task_pair = (tasks[first_idx], tasks[second_idx])
            model_pair = tuple(self.experiment.load_model(task.uid) for task in task_pair)
            self.save_results(task_pair, self.analyze_model(task_pair, model_pair))

    @staticmethod
    def generate_dataloader(task0_result: TaskResult, task1_result: TaskResult, purpose: str = "test"):
        """Return the ``purpose`` dataloader built for the ``target`` domains of both tasks.

        :param task0_result: first task; its target domain comes first in the pair
        :param task1_result: second task
        :param purpose: key of the wanted subset in the result of ``pick_dataloader``
        """
        domain_pair = tuple(result.config["domains"]["target"]
                            for result in (task0_result, task1_result))
        return pick_dataloader(domain_pair)[purpose]

    def analyze_model(self, task_results: Tuple[TaskResult, TaskResult],
                      models: Tuple[torch.nn.Module, torch.nn.Module]) -> ResultGeneratorType:
        """Yield ``(result_name, result_values)`` pairs for a task pair; implemented by subclasses."""
        raise NotImplementedError("_PairwiseAnalysisMethod class is abstract")
