from enum import Enum
from typing import Any

import torch
from pydantic import BaseModel, model_validator

from src.trainer.sst.utils import EMA
from src.utils.logging_utils import get_logger

# Module-level logger shared by every class in this file. It is requested at
# DEBUG level; the classes below emit debug messages on value/bound updates.
logger = get_logger(__name__, level="DEBUG")


class ValueStrategy(BaseModel):
    """A named strategy, with an optional argument, describing how a value is computed.

    Attributes:
        name: Which strategy to use (see ``Name``).
        arg: Strategy-specific argument. An integer percentile in ``[0, 100]``
            for the percentile strategies, a non-negative number for
            ``CONST_VAL`` (stored as ``float``), and ``None`` for every other
            strategy.
    """

    class Name(str, Enum):
        """Identifiers of the supported strategies."""

        TRAIN_LOSS = "train_loss"
        EX_LOSS_MEAN = "ex_loss_mean"
        EX_LOSS_MEDIAN = "ex_loss_median"
        EX_LOSS_MAX = "ex_loss_max"
        EX_LOSS_MIN = "ex_loss_min"
        EX_LOSS_P = "ex_loss_p"  # percentile
        PPL_LOSS_MEAN = "ppl_loss_mean"
        PPL_LOSS_MEDIAN = "ppl_loss_median"
        PPL_LOSS_MAX = "ppl_loss_max"
        PPL_LOSS_MIN = "ppl_loss_min"
        PPL_LOSS_P = "ppl_loss_p"  # percentile
        INITIAL_VAL = "initial_val"
        REFERENCE = "reference"
        CONST_VAL = "const_val"
        NO = "no"

    name: Name
    arg: Any | None = None

    @model_validator(mode="after")
    def validate(self):
        """Check that ``arg`` is consistent with the selected strategy.

        Explicit ``raise`` statements are used instead of ``assert`` so the
        checks still run when Python is started with ``-O``.

        Returns:
            The validated model itself.

        Raises:
            ValueError: If ``arg`` is missing, has the wrong type, is out of
                range, or is given to a strategy that takes no argument.
        """
        # EX_LOSS_P always needs its percentile. PPL_LOSS_P is also a
        # percentile strategy: its argument is validated when provided, while
        # the argument-less form stays accepted for backward compatibility.
        is_percentile = self.name == self.Name.EX_LOSS_P or (
            self.name == self.Name.PPL_LOSS_P and self.arg is not None
        )

        if is_percentile:
            if not isinstance(self.arg, int):
                raise ValueError(f"{self.name} requires an integer percentile, got: {self.arg!r}")
            if not 0 <= self.arg <= 100:
                raise ValueError(f"Invalid percentile value: {self.arg}")
        elif self.name == self.Name.CONST_VAL:
            if not isinstance(self.arg, (int, float)):
                raise ValueError(f"Invalid constant value: {self.arg}")
            self.arg = float(self.arg)
            # Written as `not >=` so that NaN is rejected as well.
            if not self.arg >= 0:
                raise ValueError(f"Invalid constant value: {self.arg}")
        elif self.arg is not None:
            raise ValueError(f"{self.name} takes no argument")

        return self


class ValueWHistory(BaseModel):
    """A tracked scalar together with its initial value and its bounds.

    The public fields (``initial``, ``current``, ``min``, ``max``) start as
    ``None`` and are filled in by ``initialize``. The private ``_name`` is used
    in log messages and dump keys; the optional private ``_ema`` is applied to
    every new current value.
    """

    _name: str
    initial: torch.Tensor | None = None
    current: torch.Tensor | None = None
    min: torch.Tensor | None = None
    max: torch.Tensor | None = None
    _ema: EMA | None = None

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, *, name: str, ema: EMA | None = None, **kwargs):
        """Build the model from ``kwargs`` and attach the private name / EMA."""
        super().__init__(**kwargs)
        self._ema = ema
        self._name = name

    @property
    def is_set(self):
        """``True`` only when none of the dumped fields is ``None``."""
        return all(field is not None for field in self.model_dump().values())

    def _update_current(self, new_value: torch.Tensor):
        """Replace ``current`` with ``new_value`` (passed through the EMA, if any)."""
        assert self.is_set

        candidate = new_value.clone()

        ema = self._ema
        if ema is not None:
            # Feed the raw value to the EMA and use the smoothed one instead.
            ema.update(candidate)
            candidate = ema.get_ema()

        previous = self.current.item()
        self.current = candidate.clone()
        logger.debug(f"Updated {self._name} current value: {previous}->{self.current}")

    def initialize(self, *, init_value: torch.Tensor, min_value: torch.Tensor, max_value: torch.Tensor):
        """Set all fields for the first time; must not already be set."""
        assert not self.is_set

        # Each field gets its own copy so that later updates never alias.
        self.initial = init_value.clone()
        self.current = init_value.clone()
        self.min, self.max = min_value.clone(), max_value.clone()
        logger.debug(f"Initialized {self._name} with value: {self.current} (min: {self.min}, max: {self.max})")

    def update_bounds(
        self,
        *,
        min_value: torch.Tensor,
        max_value: torch.Tensor,
    ):
        """Overwrite ``min`` and ``max`` with copies of the given tensors."""
        self.min, self.max = min_value.clone(), max_value.clone()

        logger.debug(f"Updated {self._name} values: (min: {self.min}, max: {self.max})")

    def dump(self):
        """Return the fields as Python scalars keyed by ``sst_<name>_<field>``."""
        flat = {}
        for field_name, tensor in self.model_dump().items():
            flat[f"sst_{self._name}_{field_name}"] = tensor.item()
        return flat


class ThresholdUpdater:
    class Config(BaseModel):
        reference_update_strategy: ValueStrategy
        """The strategy to use to compute the reference value"""

        reference_max_strategy: ValueStrategy
        """The strategy to use to compute the reference max value"""

        reference_min_strategy: ValueStrategy
        """The strategy to use to compute the reference min value."""

        threshold_init_strategy: ValueStrategy
        """The strategy to use to compute the threshold initial value."""

        threshold_max_startegy: ValueStrategy
        """The strategy to use to compute the threshold max value."""

        threshold_min_strategy: ValueStrategy
        """The strategy to use to compute the threshold min value."""

        threshold_update_strategy: ValueStrategy
        """The strategy when reference = no"""

        use_ema_for_ref: bool = False
        """Whether to use an EMA for the reference value"""

        @model_validator(mode="after")
        def validate(self):
            return self

    def __init__(self, **kwargs):
        self.config = self.Config(**kwargs)
        self._step_count = 0
        self._threshold = ValueWHistory(name="threshold")
        self._reference = ValueWHistory(name="reference", ema=EMA() if self.config.use_ema_for_ref else None)

    @property
    def threshold(self):
        return self._threshold.current

    def step(
        self,
        *,
        curr_train_loss,
        ex_losses,
        **kwargs,
    ):
        if self._step_count == 0:
            # First step, just initializes values
            self._init_values(curr_train_loss=curr_train_loss, ex_losses=ex_losses)
            self._step_count += 1
            return False

        self._update_bounds(curr_train_loss=curr_train_loss, ex_losses=ex_losses)

        self._threshold.is_set
        self._step_count += 1

        if self.config.threshold_update_strategy.name == ValueStrategy.Name.REFERENCE:
            # The threshold follows the reference value
            assert self._reference.is_set
            # Update reference value
            ref_value = self._get_reference_value(curr_train_loss=curr_train_loss, ex_losses=ex_losses)
            self._reference._update_current(new_value=ref_value)

            # Compute how much the reference value has changed wrt initial value
            change_ratio = (self._reference.initial - self._reference.current) / self._reference.initial  # type: ignore
            assert -1 <= change_ratio <= 1

            # Compute the new threshold value
            new_threshold_value = self._threshold.initial + (
                (self._threshold.max - self._threshold.initial) * change_ratio
            )  # type: ignore

        elif self.config.threshold_update_strategy.name == ValueStrategy.Name.EX_LOSS_P:
            # The threshold is a percentile of the ex_losses
            new_threshold_value = torch.quantile(ex_losses, self.config.threshold_update_strategy.arg / 100)  # type: ignore
        else:
            raise NotImplementedError(f"Invalid threshold value strategy: {self.config.threshold_update_strategy}")

        if new_threshold_value < self._threshold.min:
            new_threshold_value = self._threshold.min
            logger.warning("Threshold value is below the minimum. Clipping to min value")
        elif new_threshold_value > self._threshold.max:
            new_threshold_value = self._threshold.max
            logger.warning("Threshold value is above the maximum. Clipping to max value")

        # Check if the threshold value has changed, to avoid tiny changes
        if torch.isclose(self._threshold.current, new_threshold_value):  # type: ignore
            logger.debug("Threshold did not change. Skipping update")
            return False

        # Update the threshold with the same ratio
        self._threshold._update_current(new_value=new_threshold_value)

        return True

    def _init_values(self, *, ex_losses, curr_train_loss):
        assert not self._reference.is_set and not self._threshold.is_set

        if self.config.threshold_update_strategy.name == ValueStrategy.Name.REFERENCE:
            # Initialize reference value
            value = self._get_reference_value(curr_train_loss=curr_train_loss, ex_losses=ex_losses)
            min_ref_value, max_ref_value = self._get_reference_bounds(ex_losses=ex_losses)
            self._reference.initialize(
                init_value=value,
                min_value=min_ref_value,
                max_value=max_ref_value,
            )

        # Initialize threshold value
        match self.config.threshold_init_strategy.name:
            case ValueStrategy.Name.REFERENCE:
                assert self._reference.is_set
                threshold_init_value = self._reference.current.clone()
            case ValueStrategy.Name.EX_LOSS_P:
                assert self.config.threshold_update_strategy.arg is not None
                threshold_init_value = torch.quantile(ex_losses, self.config.threshold_update_strategy.arg / 100)
            case ValueStrategy.Name.EX_LOSS_MIN:
                threshold_init_value = ex_losses.min()
            case _:
                raise NotImplementedError(f"Invalid threshold init strategy: {self.config.threshold_update_strategy}")

        min_threshold, max_threshold = self._get_threshold_bounds(ex_losses=ex_losses)
        self._threshold.initialize(
            init_value=threshold_init_value,  # type: ignore
            min_value=min_threshold,
            max_value=max_threshold,
        )

    def _update_bounds(self, *, ex_losses, curr_train_loss):
        assert self._threshold.is_set
        if self.config.threshold_update_strategy.name == ValueStrategy.Name.REFERENCE:
            assert self._reference.is_set
            # The min or max might have changed, so we need to update them
            min_ref_value, max_ref_value = self._get_reference_bounds(ex_losses=ex_losses)

            self._reference.update_bounds(
                min_value=min_ref_value,
                max_value=max_ref_value,
            )

        # Update the threshold bounds
        min_threshold, max_threshold = self._get_threshold_bounds(ex_losses=ex_losses)
        self._threshold.update_bounds(
            min_value=min_threshold,
            max_value=max_threshold,
        )

    def _get_reference_value(self, *, curr_train_loss, ex_losses):
        if self.config.reference_update_strategy.name == ValueStrategy.Name.EX_LOSS_MEAN:
            return ex_losses.mean()
        if self.config.reference_update_strategy.name == ValueStrategy.Name.EX_LOSS_MEDIAN:
            return ex_losses.median()
        if self.config.reference_update_strategy.name == ValueStrategy.Name.TRAIN_LOSS:
            return curr_train_loss
        raise ValueError(f"Invalid threshold reference: {self.config.reference_update_strategy}")

    def _get_reference_bounds(self, *, ex_losses):
        return self._get_bounds(
            name="reference",
            ex_losses=ex_losses,
            value=self._reference,
            max_strategy=self.config.reference_max_strategy,
            min_strategy=self.config.reference_min_strategy,
        )

    def _get_threshold_bounds(self, *, ex_losses):
        return self._get_bounds(
            name="threshold",
            ex_losses=ex_losses,
            value=self._threshold,
            max_strategy=self.config.threshold_max_startegy,
            min_strategy=self.config.threshold_min_strategy,
        )

    def _get_bounds(
        self,
        *,
        name: str,
        ex_losses,
        value: ValueWHistory,
        max_strategy: ValueStrategy,
        min_strategy: ValueStrategy,
    ):
        # Max
        match max_strategy.name:
            case ValueStrategy.Name.EX_LOSS_P:
                max_value = torch.quantile(ex_losses, max_strategy.arg / 100)
            case ValueStrategy.Name.EX_LOSS_MAX:
                max_value = ex_losses.max()
            case _:
                raise ValueError(f"Invalid {name} max strategy: {max_strategy}")

        # Min
        match min_strategy.name:
            case ValueStrategy.Name.CONST_VAL:
                min_value = torch.tensor(min_strategy.arg, dtype=ex_losses.dtype)
            case ValueStrategy.Name.INITIAL_VAL:
                min_value = value.initial.clone()
            case ValueStrategy.Name.EX_LOSS_MIN:
                min_value = ex_losses.min()
            case ValueStrategy.Name.REFERENCE:
                assert self._reference.is_set
                min_value = self._reference.current.clone()
            case _:
                raise ValueError(f"Invalid {name} min strategy: {min_strategy}")

        assert min_value < max_value

        logger.info(f"Computed {name} bounds: min={min_value} ({min_strategy}), max={max_value} ({max_strategy})")

        return min_value, max_value

    def _get_state(self):
        state = {}
        if self._reference.is_set:
            state.update(self._reference.dump())

        if self._threshold.is_set:
            state.update(self._threshold.dump())

        return state
