from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any

import numpy as np
from matplotlib import pyplot as plt


class History:
    """
    Class representing the history of a training process. This class stores
    losses and other metrics, allowing for different logging frequencies
    by storing both x and y values for each metric. It also offers methods
    to visualize the history using matplotlib with customizable labels.

    Each metric is stored as a list of ``(x_value, y_value)`` tuples. Where
    x-values are not explicitly provided, the next x-value is one greater
    than the last recorded x-value (starting at 0 for an empty metric).
    Arbitrary values without an x-value can additionally be stored with
    ``append_misc``; such series are skipped when plotting.
    """

    def __init__(
        self,
        history: dict[str, list[float] | list[tuple[float, float]]] | None = None,
        metrics: list[str] | None = None,
    ) -> None:
        """
        Initializes the History object.

        Args:
            history (dict[str, list[float] | list[tuple[float, float]]] | None):
                An optional dictionary used to initialize the history. Every
                key becomes a metric and its values are added via ``extend``.
            metrics (list[str] | None): Optional metric names to register
                with an empty history.

        Raises:
            ValueError: If a value list in ``history`` is neither all numbers
                nor all ``(x, y)`` tuples.
        """
        self.history: dict[str, list[tuple[float, float]]] = {}
        for metric in metrics or []:
            self.history[metric] = []
        if history is not None:
            # extend() only fills keys that already exist, so register them first.
            for key in history:
                self.history[key] = []
            self.extend(history)

    def __getitem__(self, key: str) -> list[tuple[float, float]]:
        """
        Retrieves the history for a given metric key.

        Args:
            key (str): The name of the metric.

        Returns:
            list[tuple[float, float]]: A list of (x_value, y_value) tuples for
            the metric.

        Raises:
            KeyError: If the metric is not present in this history.
        """
        return self.history[key]

    @staticmethod
    def _is_number(value: object) -> bool:
        """Return True for real scalars (Python or numpy), excluding bools."""
        # bool is a subclass of int, but a flag is not a metric value.
        return isinstance(
            value, (int, float, np.integer, np.floating)
        ) and not isinstance(value, bool)

    @staticmethod
    def _next_x(series: list, empty_start: float) -> float:
        """Return the implicit x-value for the next point appended to *series*."""
        if not series:
            return empty_start
        last = series[-1]
        if isinstance(last, tuple) and len(last) == 2:
            return last[0] + 1
        # The last entry came from append_misc and has no x-value; fall back
        # to the number of stored entries.
        return len(series)

    def extend(
        self,
        history: History | dict[str, list[float] | list[tuple[float, float]]],
    ) -> None:
        """
        Extend the history with another history object or dictionary.
        This only appends values for keys that are already present in this history.

        A list of plain numbers gets implicit x-values continuing from the
        last recorded x-value of that metric; a list of ``(x, y)`` tuples is
        appended as-is.

        Args:
            history (History | dict[str, list[float] | list[tuple[Any, float]]]):
                The history to add to this history.

        Raises:
            ValueError: If a value list is neither all numbers nor all
                ``(x, y)`` tuples.
        """
        history_data = history.history if isinstance(history, History) else history

        for key, values in history_data.items():
            # len() rather than truthiness so numpy arrays are accepted too.
            if key not in self.history or len(values) == 0:
                continue

            series = self.history[key]
            if all(self._is_number(v) for v in values):
                start = self._next_x(series, 0)
                # Materialize before extending: `values` may be `series` itself.
                series.extend([(start + i, v) for i, v in enumerate(values)])  # type: ignore
            elif all(isinstance(v, tuple) and len(v) == 2 for v in values):
                series.extend(values)  # type: ignore
            else:
                raise ValueError(
                    f"Invalid values format for key '{key}' during extend. "
                    "Expected list[float] or list[tuple[Any, float]]."
                )

    def append(self, key: str, value: float, x_value: float | None = None) -> None:
        """
        Append a single data point to a specific metric's history. If `x_value`
        is None, the last recorded x-value plus one is used (0.0 for a new or
        empty metric).

        Args:
            key (str): The name of the metric. Created if it does not exist.
            value (float): The y-axis value (e.g., loss, accuracy).
            x_value (float | None): The x-axis value (e.g., epoch, iteration
                number). If None, the next sequential x-value is used.
        """
        series = self.history.setdefault(key, [])
        if x_value is None:
            x_value = self._next_x(series, 0.0)
        series.append((x_value, value))

    def append_misc(self, key: str, value: Any) -> None:  # noqa: ANN401
        """
        Append a single value to a specific metric's history without an x-value.
        This is useful for metrics that do not require an x-axis value.

        Args:
            key (str): The name of the metric. Created if it does not exist.
            value (Any): The value to append, stored unchanged.
        """
        self.history.setdefault(key, []).append(value)  # type: ignore

    def append_from_dict(
        self, history_dict: dict[str, float | tuple[float, float]]
    ) -> None:
        """
        Append a history-like dictionary with only one value per key to this
        history. This only keeps the values of the keys that are already in
        this history. Plain numbers get the next sequential x-value; tuples
        are interpreted as ``(x_value, y_value)``. Values of any other type
        are ignored.

        Args:
            history_dict (dict[str, float | tuple[float, float]]): The
                history-like dictionary to add.
        """
        for key, value in history_dict.items():
            if key not in self.history:
                continue
            if self._is_number(value):
                self.append(key, value)  # type: ignore[arg-type]
            elif isinstance(value, tuple):
                self.append(key, value[1], value[0])

    def save(self, filepath: str | Path) -> None:
        """
        Save the history to a pickle file.

        Args:
            filepath (str | Path): The path to the file where the history
                                   will be saved.
        """
        with open(filepath, "wb") as f:
            pickle.dump(self.history, f)

    @staticmethod
    def load(filepath: str | Path) -> History:
        """
        Load a history from a pickle file written by ``save``.

        Series of numbers or ``(x, y)`` tuples are restored through
        ``extend``; any other series (e.g. written via ``append_misc``) is
        restored unchanged instead of failing.

        Warning:
            This uses ``pickle.load``, which can execute arbitrary code.
            Only load files from trusted sources.

        Args:
            filepath (str | Path): The path to the file from which the history
                                   will be loaded.

        Returns:
            History: The loaded History object.
        """
        with open(filepath, "rb") as f:
            history_data = pickle.load(f)  # noqa: S301 - trusted files only

        loaded = History()
        for key, values in history_data.items():
            loaded.history[key] = []
            try:
                loaded.extend({key: values})
            except ValueError:
                # Not a plottable series (e.g. append_misc values): keep as-is.
                loaded.history[key] = list(values)
        return loaded

    def visualize(
        self,
        metrics: list[str] | None = None,
        x_label: str = "X-axis (Epoch/Iteration)",
        y_label: str = "Value",
        title: str = "Training History",
        ax: plt.Axes | None = None,
    ) -> plt.Figure:
        """
        Plot the specified or all metrics of this history on a single axis.

        Empty metrics and metrics that are not lists of ``(x, y)`` pairs
        (e.g. values stored via ``append_misc``) are skipped. A legend is
        only drawn if at least one metric was plotted.

        Args:
            metrics (list[str] | None): A list of metric names to plot. If None,
                all metrics will be plotted.
            x_label (str): Label for the x-axis.
            y_label (str): Label for the y-axis.
            title (str): Title of the plot.
            ax (plt.Axes | None): Axis to plot on. If None, a new figure+axis is created.

        Returns:
            plt.Figure: The figure with the plots.
        """
        if ax is None:
            fig, ax = plt.subplots(figsize=(10, 6))
        else:
            fig = ax.figure

        selected = None if metrics is None else set(metrics)
        plotted = False
        for key, values in self.history.items():
            if selected is not None and key not in selected:
                continue
            if not values:
                continue

            vals = np.array(values)
            # Only (n, 2) arrays are (x, y) series; anything else cannot be
            # drawn as a line.
            if vals.ndim != 2 or vals.shape[1] != 2:
                continue

            ax.plot(vals[:, 0], vals[:, 1], label=key)
            plotted = True

        ax.set_title(title)
        ax.set_ylabel(y_label)
        ax.set_xlabel(x_label)
        if plotted:
            # Avoids matplotlib's "No artists with labels" warning.
            ax.legend()
        ax.grid(True, linestyle="--", alpha=0.7)

        return fig
