import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, List, Literal, Optional
import numpy as np
from tqdm.auto import tqdm
from tsbench.config import Config, MODEL_REGISTRY, TrainConfig
from tsbench.config.dataset import get_dataset_config
from tsbench.config.model import get_model_config
from tsbench.experiments.aws import Analysis, TrainingJob
from tsbench.experiments.metrics import Metric, Performance

# Name of the validation curve used to pick a model per job in `extract_job_infos`:
# "loss" selects by "val_loss", "ncrps" selects by "val_mean_weighted_quantile_loss".
ValidationMetric = Literal["loss", "ncrps"]


@dataclass
class JobInfo:
    """
    Aggregated information about a single configuration (model hyperparameters and dataset).

    Besides the configuration and its performance, it provides the underlying training jobs and,
    for each of these jobs, the index of the model chosen from that job to extract forecasts.
    """

    # The model and dataset configuration this entry describes.
    config: Config
    # Performance of the configuration; presumably aggregated over ``jobs`` (this is what
    # ``extract_job_infos`` does in this module).
    performance: Performance
    # The training jobs that share this configuration (e.g. runs that differ only in their seed).
    jobs: List[TrainingJob]
    # For each entry of ``jobs`` (same order), the index of the model chosen from that job.
    model_indices: List[int]


# -------------------------------------------------------------------------------------------------


# Fractions of the dataset's maximum training time at which trainable models report a model. Every
# trainable job is expected to provide exactly one model per fraction (11 in total).
_TRAINING_FRACTIONS = [1 / 81, 1 / 27] + [i / 9 for i in range(1, 10)]

# Indices into ``_TRAINING_FRACTIONS`` that are turned into configurations. They correspond to the
# Hyperband training times 1/81, 1/27, 1/9, 3/9 and 9/9.
_HYPERBAND_INDICES = [0, 1, 2, 4, 10]

# Maps a validation metric name to the key of the per-model validation curve in ``job.metrics``.
_VALIDATION_METRIC_KEYS = {"loss": "val_loss", "ncrps": "val_mean_weighted_quantile_loss"}


def extract_job_infos(
    analysis: Analysis, validation_metric: Optional[ValidationMetric], group_seeds: bool
) -> List[JobInfo]:
    """
    Returns a list of the job information objects available for all training runs.

    Jobs with identical hyperparameters (ignoring the seed if ``group_seeds`` is set) are grouped,
    and their metrics are aggregated into a single performance. Models that require training yield
    one info object per Hyperband training time. All other models yield a single info object.

    Args:
        analysis: The analysis whose training jobs are processed.
        validation_metric: If given, each job contributes the model with the lowest value of this
            validation metric among all models recorded up to the current training time. If
            ``None``, the model recorded at the current training time is used.
        group_seeds: Whether jobs that differ only in their seed are aggregated into one entry.

    Returns:
        One job info per (hyperparameter group, training time).
    """
    # Metrics read from the individual jobs. The training time is instead derived from the
    # training fraction below.
    reported_metrics = [m for m in Performance.metrics() if m != "training_time"]

    # We group the jobs by hyperparameters (excluding the seed if requested)
    grouped_jobs = defaultdict(list)
    for job in analysis:
        hypers = {**job.hyperparameters}
        if group_seeds:
            hypers.pop("seed", None)
        grouped_jobs[tuple(sorted(hypers.items()))].append(job)

    # Then, we can instantiate the info objects by iterating over groups of jobs
    runs = []
    for jobs in tqdm(grouped_jobs.values()):
        ref_job = jobs[0]
        model_name = ref_job.hyperparameters["model"]
        dataset_name = ref_job.hyperparameters["dataset"]

        base_hyperparams = {**ref_job.hyperparameters}
        del base_hyperparams["model"]
        del base_hyperparams["dataset"]
        base_hyperparams.pop("training_time", None)  # for old experiments
        base_hyperparams = _strip_model_prefix(base_hyperparams, model_name)

        # Trainable models report one model per training fraction. All others report a single
        # model "at time 0".
        trainable = issubclass(MODEL_REGISTRY[model_name], TrainConfig)
        if trainable:
            training_fractions = _TRAINING_FRACTIONS
            fraction_indices = _HYPERBAND_INDICES
        else:
            training_fractions = [0]
            fraction_indices = [0]

        assert all(
            len(job.metrics["training_time"]) == len(training_fractions) for job in jobs
        ), "Not all jobs provide sufficiently many models."

        # Iterate over the Hyperband training times, construct the configuration and collect the
        # performance metrics
        for i in fraction_indices:
            hyperparams = {**base_hyperparams, "training_fraction": training_fractions[i]}
            model_config = get_model_config(model_name, **hyperparams)
            config = Config(model_config, get_dataset_config(dataset_name))

            # Get, for each job, the index of the model used to derive the performance
            if validation_metric is None or not trainable:
                # If the model does not require training, or we don't look at the validation
                # performance, we just choose the current index
                choices = [i] * len(jobs)
            else:
                # Otherwise, we choose the model with the best validation metric up to this point
                curve_key = _VALIDATION_METRIC_KEYS[validation_metric]
                choices = [_best_model_index(job.metrics[curve_key], i + 1) for job in jobs]

            # Aggregate each metric over the models chosen from the individual jobs
            metrics = {
                m: _accumulate_values(
                    [_extract_metric(job.metrics[m], idx) for job, idx in zip(jobs, choices)]
                )
                for m in reported_metrics
            }

            performance = Performance(
                **metrics,
                training_time=Metric(training_fractions[i] * config.dataset.max_training_time, 0),
            )
            runs.append(JobInfo(config, performance, jobs, choices))

    return runs


def _strip_model_prefix(hyperparameters: dict, model_name: str) -> dict:
    """
    Returns a copy of ``hyperparameters`` in which every key starting with ``model_name`` loses
    that prefix plus the single character following it. Other keys are kept unchanged.
    """
    # NOTE(review): assumes keys look like "<model_name><sep><name>" with a one-character
    # separator (e.g. "_") — confirm against how the jobs were launched.
    prefix_length = len(model_name) + 1
    return {
        (key[prefix_length:] if key.startswith(model_name) else key): value
        for key, value in hyperparameters.items()
    }


def _best_model_index(curve: Any, count: int) -> int:
    """
    Returns the index of the lowest value among the first ``count`` entries of ``curve``.

    NaN entries are skipped. Plain ``np.argmin`` would instead return the first NaN and so select a
    model whose validation metric is undefined. If every entry is NaN, index 0 is returned (the
    same result ``np.argmin`` gives). An empty window raises ``ValueError`` from ``np.argmin``.
    """
    window = np.asarray(curve[:count], dtype=float)
    if window.size > 0 and not np.isnan(window).all():
        return int(np.nanargmin(window))
    return int(np.argmin(window))


# -------------------------------------------------------------------------------------------------


def _accumulate_values(values: List[float]) -> Metric:
    """
    Summarizes per-job values as a metric built from their mean and their (population,
    ``ddof=0``) standard deviation.
    """
    samples = np.asarray(values)
    return Metric(samples.mean(), samples.std())


def _extract_metric(values: np.ndarray, index: int) -> Any:
    """
    Returns ``values[index]`` converted to a native Python scalar, or NaN if ``index`` is not
    smaller than the number of values.
    """
    if index >= len(values):
        return math.nan
    return values[index].item()
