from __future__ import annotations
import pickle
from pathlib import Path
from typing import Any, List, Optional
from tsbench.config import Config, DatasetConfig, ModelConfig
from tsbench.experiments import aws
from tsbench.experiments.aws import Analysis, TrainingJob
from tsbench.experiments.metrics import Performance
from tsbench.forecasts.quantile import QuantileForecasts
from .info import extract_job_infos, ValidationMetric
from .results import ExperimentResults


class Tracker:
    """
    The tracker may be used to obtain the performance metrics from a set of experiments that were
    scheduled on AWS Sagemaker.

    Attributes are populated in the initializer:

    - ``infos``: the job infos returned by ``extract_job_infos``.
    - ``config_map``: a mapping from each info's ``config`` to the info itself, used for all
      per-configuration lookups.
    """

    @classmethod
    def for_experiment(cls, name: str, force_refresh: bool = False, **kwargs: Any) -> Tracker:
        """
        Loads the data associated with the experiment of the given name and caches the tracker.
        Thus, when called multiple times, it does not need to download data again.

        Args:
            name: The name of the experiment.
            force_refresh: Whether to download experiment data even if a tracker is available
                locally. This ensures that new data is fetched.
            kwargs: Additional keyword arguments forwarded to the initializer. They are also
                encoded into the cache filename such that different arguments use different
                cache files.

        Returns:
            The tracker with all the available data.
        """
        # Generate the filename including all kwargs (sorted by key for a stable name)
        kwargs_suffix = "-".join(f"{k}_{v}" for k, v in sorted(kwargs.items(), key=lambda i: i[0]))
        if kwargs_suffix:
            kwargs_suffix = f"+{kwargs_suffix}"

        # If available in cache, return
        cache = Path.home() / ".cache" / "ts-bench" / f"{name}{kwargs_suffix}.pickle"
        if cache.exists() and not force_refresh:
            # NOTE: unpickling executes arbitrary code. The cache file is expected to have been
            # written by this method only; do not point this at files from untrusted sources.
            with cache.open("rb") as f:
                return pickle.load(f)

        # Initialize connection to AWS
        analysis = aws.Analysis(name)

        # Initialize tracker. Using `cls` (rather than the hard-coded class) ensures that
        # subclasses calling this constructor obtain an instance of their own type.
        tracker = cls(analysis, **kwargs)

        # Cache tracker and return
        cache.parent.mkdir(parents=True, exist_ok=True)
        with cache.open("wb") as f:
            pickle.dump(tracker, f)
        return tracker

    def __init__(
        self,
        analysis: Analysis,
        validation_metric: Optional[ValidationMetric] = "ncrps",
        group_seeds: bool = True,
    ):
        """
        Args:
            analysis: The analysis object to use for obtaining Sagemaker training jobs.
            validation_metric: The metric that should be used to choose models from different
                checkpoints. If set to `None`, models are not loaded from checkpoints.
            group_seeds: Whether the same configuration with differing seeds should be grouped.

        Raises:
            AssertionError: If any job of the analysis does not have the status "Completed".
        """
        # Raise explicitly instead of using an `assert` statement: assertions are removed when
        # Python runs with `-O`, which would silently skip this check. The exception type is kept
        # for backward compatibility.
        if not all(job.status == "Completed" for job in analysis):
            raise AssertionError("Not all jobs have completed.")

        self.infos = extract_job_infos(analysis, validation_metric, group_seeds)
        self.config_map = {info.config: info for info in self.infos}

    def get_results(self) -> ExperimentResults:
        """
        Returns all results from the experiments.

        Returns:
            The experiment results, built from the configurations and performances of all job
            infos in matching order.
        """
        configurations = [info.config for info in self.infos]
        performances = [info.performance for info in self.infos]
        return ExperimentResults(configurations, performances)

    def unique_model_configs(self, dataset: Optional[DatasetConfig] = None) -> List[ModelConfig]:
        """
        Returns the unique model configurations that are available in the experiments managed by
        this tracker.

        Args:
            dataset: An optional dataset which limits the returned configurations to be available
                for this dataset.

        Returns:
            The list of available model configurations. As duplicates are removed via a set, the
            order of the list is unspecified.
        """
        return list(
            {c.model for c in self.config_map if dataset is None or c.dataset == dataset}
        )

    def get_training_jobs(self, config: Config) -> List[TrainingJob]:
        """
        Returns all training jobs associated with the provided configuration.

        Args:
            config: The model and dataset configuration.

        Returns:
            The list of all training jobs.

        Raises:
            KeyError: If the configuration is not managed by this tracker.
        """
        return self.config_map[config].jobs

    def get_forecasts(self, config: Config) -> List[QuantileForecasts]:
        """
        Returns the quantile forecasts of all models associated with the provided configuration,
        i.e. forecasts for the same model trained on different seeds.

        Args:
            config: The configuration to obtain forecasts for.

        Returns:
            The list of forecasts for all models.

        Raises:
            KeyError: If the configuration is not managed by this tracker.
        """
        info = self.config_map[config]
        result = []
        for i, job in enumerate(info.jobs):
            # The i-th job is paired with the i-th entry of `model_indices`, which selects the
            # model directory inside the job's artifact.
            with job.artifact as artifact:
                forecasts = QuantileForecasts.load(
                    artifact.path
                    / "predictions"
                    / f"model_{info.model_indices[i]}"
                    / "forecasts.npz"
                )
                result.append(forecasts)
        return result

    def get_performance(self, config: Config) -> Performance:
        """
        Returns the performance metrics for the jobs associated with the provided configuration.

        Args:
            config: The configuration object.

        Returns:
            The performance metrics.

        Raises:
            KeyError: If the configuration is not managed by this tracker.
        """
        return self.config_map[config].performance

    def __contains__(self, config: Config) -> bool:
        """
        Returns whether the provided configuration is managed by this tracker.
        """
        return config in self.config_map
