import logging
from typing import Any, Callable, Dict

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch import Tensor
import math
from typing_extensions import override

log = logging.getLogger(__name__)


class ReduceHParamsOnPlateau(Callback):
    """Scale an attribute of the LightningModule when a monitored metric plateaus.

    Each time the plateau check runs, the monitored metric is compared with the
    best value seen so far. When it has failed to improve for more than
    ``patience`` consecutive checks, the attribute ``param_name`` is multiplied
    by ``decay`` and clamped to ``[min_value, max_value]``.

    Args:
        param_name: Name of the attribute on the LightningModule to adjust.
        monitor: Key in ``trainer.callback_metrics`` to watch.
        decay: Multiplicative factor applied to the attribute on a plateau.
        min_value: Lower bound for the adjusted value. The attribute is only
            adjusted while it is strictly greater than this bound.
        max_value: Upper bound for the adjusted value.
        patience: Number of non-improving checks tolerated; the adjustment
            fires on the check after that (``wait_count > patience``).
        mode: ``"min"`` if lower metric values are better, ``"max"`` otherwise.
        strict: If ``True``, raise ``RuntimeError`` when the monitored metric
            is missing from the logged metrics; otherwise skip the check.
        check_on_train_epoch_end: With ``run_every='epoch'``, run the check at
            the end of the training epoch instead of at the end of validation.
        run_every: ``'epoch'`` to check once per epoch, ``'step'`` to check
            after every training batch.

    Raises:
        MisconfigurationException: If ``mode`` or ``run_every`` is not one of
            the accepted values, or if ``min_value`` exceeds ``max_value``.
    """

    mode_dict = {"min": torch.lt, "max": torch.gt}
    order_dict = {"min": "<", "max": ">"}
    # Accepted values for the ``run_every`` constructor argument.
    run_every_options = ("epoch", "step")

    def __init__(
            self,
            param_name: str,
            monitor: str,
            decay: float = 1.,
            min_value: float = -math.inf,
            max_value: float = math.inf,
            patience: int = 3,
            mode: str = "min",
            strict: bool = True,
            check_on_train_epoch_end: bool = False,
            run_every: str = 'epoch'
    ):
        super().__init__()
        self.monitor = monitor
        self.param_name = param_name
        self.decay = decay
        self.min_value = min_value
        self.max_value = max_value
        self.patience = patience
        self.mode = mode
        self.strict = strict
        self.wait_count = 0
        self.stopped_epoch = 0
        self.run_every = run_every

        if self.mode not in self.mode_dict:
            raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")

        # An unknown value would otherwise make every hook a silent no-op.
        if self.run_every not in self.run_every_options:
            raise MisconfigurationException(
                f"`run_every` can be {', '.join(self.run_every_options)}, got {self.run_every}"
            )

        if self.min_value > self.max_value:
            raise MisconfigurationException(
                f"`min_value` ({self.min_value}) must not be greater than `max_value` ({self.max_value})"
            )

        # Start from the worst possible score so the first observed value always counts as an improvement.
        torch_inf = torch.tensor(torch.inf)
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
        self._check_on_train_epoch_end = check_on_train_epoch_end

    @property
    @override
    def state_key(self) -> str:
        """Unique identifier for the callback state."""
        # NOTE(review): the key is built from `monitor` and `mode` only, so two instances that differ
        # only in `param_name` share a key. Left unchanged so existing checkpoints keep loading;
        # confirm whether `param_name` should be part of the key.
        return self._generate_state_key(monitor=self.monitor, mode=self.mode)

    def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool:
        """Check if the monitored metric is present in the logs.

        Returns:
            ``True`` if ``self.monitor`` is in ``logs``; ``False`` if it is
            missing and ``strict`` is off.

        Raises:
            RuntimeError: If the metric is missing and ``strict`` is on.
        """
        if logs.get(self.monitor) is not None:
            return True

        if self.strict:
            raise RuntimeError(
                f"Change hyperparams on plateau conditioned on metric `{self.monitor}` which is not available."
                f" Pass in or modify your `{type(self).__name__}` callback to use any of the following:"
                f' `{"`, `".join(list(logs.keys()))}`'
            )

        return False

    @property
    def monitor_op(self) -> Callable:
        """Get the comparison operation based on the mode."""
        return self.mode_dict[self.mode]

    @override
    def state_dict(self) -> Dict[str, Any]:
        """Returns the state of the callback for checkpointing."""
        return {
            "param_name": self.param_name,
            "decay": self.decay,
            "min_value": self.min_value,
            "max_value": self.max_value,
            "wait_count": self.wait_count,
            "stopped_epoch": self.stopped_epoch,
            "best_score": self.best_score,
            "patience": self.patience,
        }

    @override
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restores the callback state from a checkpoint."""
        self.param_name = state_dict["param_name"]
        self.decay = state_dict["decay"]
        self.min_value = state_dict["min_value"]
        self.max_value = state_dict["max_value"]
        self.wait_count = state_dict["wait_count"]
        self.stopped_epoch = state_dict["stopped_epoch"]
        self.best_score = state_dict["best_score"]
        self.patience = state_dict["patience"]

    def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
        """Determine if the plateau check should be skipped.

        The check is skipped outside of fitting and during the sanity check.
        """
        from lightning.pytorch.trainer.states import TrainerFn

        return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

    @override
    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        """Check for plateau after every training batch when ``run_every='step'``."""
        if self.run_every == 'step':
            self._run_plateau_check(trainer)

    @override
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Check for plateau at the end of the training epoch."""
        if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        if self.run_every == 'epoch':
            self._run_plateau_check(trainer)

    @override
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Check for plateau at the end of the validation epoch."""
        if self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        if self.run_every == 'epoch':
            self._run_plateau_check(trainer)

    def _clamp(self, value: Any) -> Any:
        """Restrict ``value`` to the closed interval ``[min_value, max_value]``."""
        if isinstance(value, Tensor):
            return torch.clamp(value, min=self.min_value, max=self.max_value)
        return min(max(value, self.min_value), self.max_value)

    def _step(self, trainer: "pl.Trainer") -> None:
        """Multiply the target attribute by ``decay`` and clamp it to the configured bounds."""
        # `trainer.model` may be a strategy wrapper (e.g. DDP) that does not expose the
        # module's attributes, so always operate on the underlying LightningModule.
        module = trainer.lightning_module
        old = getattr(module, self.param_name)
        if not old > self.min_value:
            return

        if isinstance(old, torch.nn.Parameter):
            # Compute from a detached view so the update is not recorded by autograd, and
            # assign through `.data` so the Parameter object itself stays registered.
            old.data = self._clamp(self.decay * old.detach())
        else:
            setattr(module, self.param_name, self._clamp(self.decay * old))

    def _run_plateau_check(self, trainer: "pl.Trainer") -> None:
        """Evaluate if the condition to adjust the parameter is met."""
        logs = trainer.callback_metrics

        if trainer.fast_dev_run or not self._validate_condition_metric(logs):
            return

        current = logs[self.monitor].squeeze()
        should_step = self._evaluate_plateau_criteria(current)

        # Apply the adjustment across all DDP processes
        should_step = trainer.strategy.reduce_boolean_decision(should_step, all=False)
        if should_step:
            self._step(trainer)

    def _evaluate_plateau_criteria(self, current: Tensor) -> bool:
        """Determine if the plateau condition is met based on the current metric.

        Returns:
            ``True`` when the metric has not improved for more than
            ``patience`` consecutive checks; the wait counter is then reset.
        """
        if self.monitor_op(current, self.best_score.to(current.device)):
            # Improvement: remember the new best and restart the count.
            self.best_score = current
            self.wait_count = 0
            return False

        self.wait_count += 1
        if self.wait_count > self.patience:
            self.wait_count = 0
            return True
        return False
