from __future__ import annotations

import json
import random
import time
from functools import wraps
from typing import TYPE_CHECKING, Any

import torch
from jaxtyping import Float
from scipy import linalg
from torch import Tensor

if TYPE_CHECKING:
    from collections.abc import Callable


class EarlyStopping:
    """Keep track of the best value seen for one monitored metric.

    The object stores a single reference value for the metric named ``monitor``.
    ``judge`` compares a newly reported value against that reference and
    ``update`` overwrites the reference with a newly reported value. The class
    does not call ``update`` on its own and never assigns ``best_state`` after
    construction; both are left to the caller.

    Parameters:
        monitor : str, optional
            Key of the metric to look up in the dictionaries passed to ``judge``
            and ``update`` (default is "loss").
        direction : str, optional
            "min" if smaller values are better, "max" if larger values are
            better (default is "min").

    Attributes:
        monitor : str
            Key of the tracked metric.
        direction : str
            Direction of improvement, "min" or "max".
        best_state : Any
            Slot for the caller to keep the best model state; initialised to None.
        monitor_values : dict[str, float]
            Single-entry dictionary holding the reference value. It starts at
            +inf for "min" and -inf for "max", so any finite first value
            compares as an improvement.

    Methods:
        judge(values: dict[str, float]) -> bool
            Tell whether the reported value is strictly better than the stored one.
        update(values: dict[str, float]) -> None
            Replace the stored value with the reported one.
        get_value() -> float
            Return the stored value.

    Examples:
        >>> early_stopping = EarlyStopping(monitor="loss", direction="min")
        >>> early_stopping.judge({"loss": 0.5})
        True
        >>> early_stopping.update({"loss": 0.3})
        >>> early_stopping.get_value()
        0.3
    """

    def __init__(self, monitor: str = "loss", direction: str = "min") -> None:
        """Set up the tracker with the worst possible reference value.

        Args:
            monitor (str): Key of the monitored metric. Defaults to "loss".
            direction (str): "min" or "max". Defaults to "min".

        Raises:
            ValueError: If ``direction`` is neither "min" nor "max".
        """
        self.monitor = monitor
        self.direction = direction
        self.best_state = None

        if direction not in ("min", "max"):
            msg = "args: [direction] must be min or max"
            raise ValueError(msg)

        # Start from the worst value for the chosen direction so that the first
        # finite measurement is always judged as an improvement.
        worst_value = float("inf") if direction == "min" else -float("inf")
        self.monitor_values = {self.monitor: worst_value}

    def judge(self, values: dict[str, float]) -> bool:
        """Tell whether the reported value improves on the stored one.

        The comparison is strict: a value equal to the stored one is not an
        improvement. The stored value is not modified here.

        Args:
            values (dict[str, float]): Reported metrics; must contain the
                monitored key.

        Returns:
            bool: True if ``values[monitor]`` is strictly better than the stored
            value for the configured direction, False otherwise.
        """
        if self.direction not in ("min", "max"):
            return False

        stored = self.monitor_values[self.monitor]
        reported = values[self.monitor]
        if self.direction == "min":
            return stored > reported
        return stored < reported

    def update(self, values: dict[str, float]) -> None:
        """Overwrite the stored value with the reported one.

        No comparison is made; the caller decides when to call this (for
        example after ``judge`` returned True).

        Args:
            values (dict[str, float]): Reported metrics; must contain the
                monitored key.

        Returns:
            None
        """
        reported = values[self.monitor]
        self.monitor_values[self.monitor] = reported

    def get_value(self) -> float:
        """Return the value currently stored for the monitored metric.

        Returns:
            float: The stored reference value.
        """
        stored = self.monitor_values[self.monitor]
        return stored


def write_json(file_name: str, body: dict) -> None:
    """Write ``body`` to ``file_name`` as JSON indented by four spaces.

    An existing file is overwritten. The file is opened as UTF-8 explicitly so
    the result does not depend on the platform's default text encoding.

    Args:
        file_name (str): Path of the file to create or overwrite.
        body (dict): JSON-serialisable content.

    Raises:
        TypeError: If ``body`` contains objects ``json`` cannot serialise.
        OSError: If the file cannot be opened for writing.
    """
    with open(file_name, mode="w", encoding="utf-8") as f:
        json.dump(body, f, indent=4)


def load_json(file_name: str) -> dict:
    """Read and parse the JSON document stored in ``file_name``.

    The file is decoded as UTF-8 explicitly; relying on the platform default
    would mis-decode non-ASCII JSON on systems whose locale is not UTF-8.

    Args:
        file_name (str): Path of the JSON file to read.

    Returns:
        dict: The parsed document. The value is whatever ``json.load`` yields
        for the top-level JSON value, so a file whose top level is not an
        object produces the corresponding Python type instead.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_name, encoding="utf-8") as f:
        return json.load(f)


def stop_watch(func: Callable) -> Any:  # noqa: ANN401
    """Decorate ``func`` so that every call prints how long it took.

    The elapsed time is measured with ``time.perf_counter``, a monotonic
    high-resolution clock, so the reading is not disturbed by system clock
    adjustments. The timing line is printed only when ``func`` returns
    normally; if it raises, the exception propagates and nothing is printed.

    Args:
        func (Callable): The function to time.

    Returns:
        Any: A wrapper with the metadata of ``func`` (via ``functools.wraps``)
        that forwards all arguments and returns the result of ``func``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs) -> Any:  # noqa: ANN002, ANN003, ANN401
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start
        print(f"{func.__name__} took {elapsed_time} seconds.")
        return result

    return wrapper


def sqrtm(M: Float[Tensor, "feature feature"]) -> Float[Tensor, "feature feature"]:
    """Compute the matrix square root of ``M`` with ``scipy.linalg.sqrtm``.

    The tensor is copied to the CPU and converted to a NumPy array, SciPy
    computes the square root, and the result is wrapped in a new tensor that is
    placed on the device ``M`` lives on. The returned tensor is created from
    the NumPy result and is therefore not connected to any autograd graph.

    Args:
        M: Square matrix.

    Returns:
        Tensor on ``M.device`` holding SciPy's result.
        NOTE(review): the dtype is whatever SciPy returns (it may be complex
        for some inputs) -- confirm callers handle that.
    """
    host_matrix = M.cpu().numpy()
    root = linalg.sqrtm(host_matrix)
    root_tensor = torch.tensor(root)
    return root_tensor.to(device=M.device)


def sample_from_inverse_cdf(inverse_cdf: Callable[[float], float]) -> float:
    """Draw one sample by inverse-transform sampling.

    A number is drawn uniformly from [0, 1) with ``random.random`` and passed
    through ``inverse_cdf``.

    Args:
        inverse_cdf: Callable taking the uniform draw and returning the sample.

    Returns:
        float: Whatever ``inverse_cdf`` returns for the drawn number.
    """
    return inverse_cdf(random.random())


def samples_from_inverse_cdf(
    inverse_cdf: Callable[[Float[Tensor, "batch_size"]], Float[Tensor, "batch_size"]], batch_size: int
) -> Float[Tensor, "batch_size"]:
    """Draw a batch of samples by inverse-transform sampling.

    ``batch_size`` numbers are drawn uniformly from [0, 1) with ``torch.rand``
    (default dtype and device) and passed to ``inverse_cdf`` in a single call.

    Args:
        inverse_cdf: Callable applied to the 1-D tensor of uniform draws.
        batch_size: Number of uniform draws.

    Returns:
        Whatever ``inverse_cdf`` returns for the tensor of uniform draws.
    """
    uniform_draws = torch.rand((batch_size,))
    samples = inverse_cdf(uniform_draws)
    return samples


def simulation(
    x: Float[Tensor, "batch_drift feature"],
    tau: float,
    n_steps: int,
    epsilon: float,
    drift_func: Callable[
        [Float[Tensor, "batch_drift feature"], Float[Tensor, "batch_drift"]], Float[Tensor, "batch_drift feature"]
    ],
) -> Float[Tensor, "batch_drift n_steps + 1 feature"]:
    """Simulate trajectories with a fixed-step stochastic update, without gradients.

    Starting from ``x`` at time 0, each of the ``n_steps`` steps of size
    ``dt = tau / n_steps`` applies

        x <- x + dt * drift_func(x, t) + sqrt(epsilon * dt) * noise

    with ``noise`` drawn from a standard normal of the same shape as ``x``,
    and then advances ``t`` by ``dt``. The whole loop runs under
    ``torch.no_grad()``.

    Args:
        x: Initial states, one row per trajectory.
        tau: Total simulated time.
        n_steps: Number of update steps. If it is not positive, no step is
            taken and only the initial state is returned.
        epsilon: Noise intensity; the noise term is scaled by
            ``sqrt(epsilon * dt)``.
        drift_func: Callable receiving the current states and a 1-D tensor of
            current times (one entry per row of ``x``, created on ``x.device``)
            and returning the drift for each state. The time tensor is advanced
            in place between steps, so ``drift_func`` should not keep a
            reference to it.

    Returns:
        The initial state followed by the state after every step, stacked
        along a new leading dimension, i.e. a tensor of shape
        ``(n_steps + 1, batch_drift, feature)``.
        NOTE(review): the return annotation lists the batch dimension first,
        while the code stacks along dim 0 -- confirm which layout callers
        expect before changing either.
    """
    with torch.no_grad():
        t = torch.full(size=(x.shape[0],), fill_value=0.0, device=x.device)
        # Guard the division: with no steps to take the step size is irrelevant,
        # and the function should just return the initial state.
        dt = tau / n_steps if n_steps > 0 else 0.0
        # Loop-invariant scale of the noise term.
        noise_scale = (epsilon * dt) ** 0.5

        trajectory: list[Float[Tensor, "batch_drift feature"]] = [x]

        for _ in range(n_steps):
            # randn_like already matches the dtype and device of x.
            noise = torch.randn_like(x)
            x = x + dt * drift_func(x, t) + noise_scale * noise
            t += dt
            trajectory.append(x)

        return torch.stack(trajectory, dim=0)
