import math
from collections import deque
from typing import List, SupportsFloat, Union

import torch
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer

class ReduceLRWhenUnstable(LRScheduler):
    """
    Learning rate scheduler that reduces the learning rate when the loss oscillates.

    Each call to :meth:`step` records whether the loss went up (or stayed equal)
    or went down relative to the previous call. Once ``patience`` such
    observations have been collected, the scheduler counts how often consecutive
    observations changed direction ("sign flips"). If that count exceeds
    ``min_fraction * patience``, the following happens:

    - every parameter group's learning rate is multiplied by ``factor``, never
      going below that group's ``min_lr``;
    - the observation history is discarded;
    - the scheduler waits ``cooldown`` steps before recording observations again.

    Args:
        optimizer: Wrapped optimizer.
        factor: Multiplier applied to the learning rate on reduction. Must be < 1.0.
        patience: Size of the sliding window of loss-direction observations.
            Must be >= 2, since a sign flip needs two observations.
        min_fraction: Fraction of ``patience`` that the number of sign flips
            must exceed to trigger a reduction.
        cooldown: Number of ``step`` calls after a reduction during which no
            observations are recorded.
        min_lr: Lower bound on the learning rate. Either a single value for all
            param groups, or a list/tuple with one value per param group.
        eps: A reduction is applied only if it lowers the lr by more than this.

    Raises:
        ValueError: If ``factor >= 1.0``, ``patience < 2``, or ``min_lr`` is a
            sequence whose length differs from the number of param groups.
        TypeError: If ``optimizer`` is not an ``Optimizer``.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        factor=0.1,
        patience=100,
        min_fraction=0.25,
        cooldown=0,
        min_lr: Union[List[float], float] = 0,
        eps=1e-8,
    ):
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        # With fewer than two observations in the window no sign flip can ever
        # be recorded, so the scheduler would silently never act.
        if patience < 2:
            raise ValueError("Patience should be >= 2.")
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)) and len(min_lr) != len(optimizer.param_groups):
            raise ValueError(
                f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}"
            )
        self.min_lr = min_lr

        self.patience = patience
        self.cooldown = cooldown
        self.cooldown_counter = 0

        # +inf marks "no previous loss yet"; see step().
        self.last_loss = torch.inf
        # True = loss increased or stayed equal, False = loss decreased.
        self.loss_increases = deque(maxlen=patience)
        # True where two consecutive entries of loss_increases differ.
        self.sign_flips = deque(maxlen=patience)
        self.min_fraction = min_fraction

        self.eps = eps
        self.last_epoch = 0
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def step(self, metrics: SupportsFloat):
        """Record the latest loss and reduce the learning rate if it oscillates.

        Args:
            metrics: Current loss; any value accepted by ``float()``.
        """
        current_loss = float(metrics)
        previous_loss = self.last_loss
        self.last_loss = current_loss
        self.last_epoch += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
        elif math.isfinite(previous_loss):
            # A non-finite previous loss is ignored: on the first call it is the
            # +inf placeholder, and the comparison would fabricate a "decrease".
            self.loss_increases.append(current_loss - previous_loss >= 0)
            if len(self.loss_increases) >= 2:
                self.sign_flips.append(self.loss_increases[-1] != self.loss_increases[-2])

            window_full = len(self.loss_increases) == self.patience
            if window_full and sum(self.sign_flips) > self.min_fraction * self.patience:
                self._reduce_lr()
                self.cooldown_counter = self.cooldown
                # Both histories describe the pre-reduction regime; start fresh.
                self.loss_increases.clear()
                self.sign_flips.clear()

        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def _reduce_lr(self):
        """Multiply each param group's lr by ``factor``, clamped to that group's min_lr."""
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group["lr"])
            new_lr = max(old_lr * self.factor, self._group_min_lr(i))
            # Skip negligible changes, e.g. when the lr already sits at min_lr.
            if old_lr - new_lr > self.eps:
                param_group["lr"] = new_lr

    def _group_min_lr(self, index: int) -> float:
        """Return the lower learning-rate bound for param group ``index``."""
        if isinstance(self.min_lr, (list, tuple)):
            return float(self.min_lr[index])
        return float(self.min_lr)

    @property
    def in_cooldown(self):
        """Whether the scheduler is still waiting after a recent reduction."""
        return self.cooldown_counter > 0