import math
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

class NullScheduler(_LRScheduler):
    """
    A learning rate scheduler that does nothing.

    Every call to ``step()`` re-applies the initial learning rate of each
    parameter group, so the learning rate stays constant throughout training.
    """

    def __init__(self, optimizer, last_epoch=-1, verbose=False):
        """
        Initializes the scheduler.

        Args:
            optimizer (Optimizer): The optimizer to wrap.
            last_epoch (int): The index of the last epoch. Defaults to -1.
            verbose (bool): If True, asks the base class to print a message
                for each update. It is forwarded only when the installed
                PyTorch base class still accepts it; otherwise it is ignored.
        """
        # Local import keeps this block self-contained.
        import inspect

        # `verbose` is deprecated in recent PyTorch releases and absent from
        # the base-class signature in others. Forwarding it unconditionally
        # either warns on every construction or fails with TypeError, so pass
        # it along only when it was requested AND is supported.
        base_params = inspect.signature(super().__init__).parameters
        if verbose and "verbose" in base_params:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """
        Return the learning rates to apply on this step.

        Returns:
            list: A copy of ``self.base_lrs`` (the initial learning rate of
            each parameter group), i.e. the learning rates never change.
        """
        # Return a copy so callers cannot mutate the scheduler's own
        # `base_lrs` list through the returned value.
        return list(self.base_lrs)

class CosineAnnealingWarmupRestarts(_LRScheduler):
    """
    Cosine annealing with warmup + restarts.

    Floors:
      - warmup_floor:
          * default = min_lr
          * if last active cycle AND cycle > 0 => final_fixed_lr
      - cosine_floor:
          * last active cycle => final_fixed_lr
          * otherwise => min_lr

    After max_cycles complete, LR is frozen at final_fixed_lr.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 0.1,
        min_lr: float = 0.001,
        warmup_steps: int = 0,
        gamma: float = 1.0,
        last_epoch: int = -1,
        max_cycles: int | None = None,
        final_fixed_lr: float | None = None,
    ):
        assert warmup_steps < first_cycle_steps, "warmup_steps must be < first_cycle_steps"

        self.first_cycle_steps = int(first_cycle_steps)
        self.cycle_mult = float(cycle_mult)
        self.base_max_lr = float(max_lr)
        self.max_lr = float(max_lr)
        self.min_lr = float(min_lr)
        self.warmup_steps = int(warmup_steps)
        self.gamma = float(gamma)

        # cycle bookkeeping
        self.cur_cycle_steps = int(first_cycle_steps)
        self.cycle = 0
        self.step_in_cycle = last_epoch

        # restart cap / freezing
        self.max_cycles = None if max_cycles is None else int(max_cycles)
        self._frozen = False

        # final LR after freezing (and cosine floor in last active cycle)
        self.final_fixed_lr = float(final_fixed_lr) if final_fixed_lr is not None else float(min_lr)

        super().__init__(optimizer, last_epoch)

        # Initialize to min_lr (first cycle warmup starts from min_lr even if it's also the last)
        for pg in self.optimizer.param_groups:
            pg["lr"] = self.min_lr

    # ---------- helpers ----------

    def is_batch_based(self):
        return True

    def _in_last_active_cycle(self) -> bool:
        return (self.max_cycles is not None) and (self.cycle >= self.max_cycles - 1) and (not self._frozen)

    def _floors(self):
        """
        Returns (warmup_floor, cosine_floor) for the current cycle.
        - Special case: if this is the last active cycle *and* cycle==0 (i.e., max_cycles==1),
          warmup starts from min_lr (not final_fixed_lr), but cosine decays to final_fixed_lr.
        """
        last_active = self._in_last_active_cycle()

        # warmup floor
        if last_active and self.cycle > 0:
            warmup_floor = self.final_fixed_lr
        else:
            warmup_floor = self.min_lr

        # cosine floor
        cosine_floor = self.final_fixed_lr if last_active else self.min_lr
        return warmup_floor, cosine_floor

    def _frozen_lrs(self):
        return [self.final_fixed_lr for _ in self.optimizer.param_groups]

    # ---------- LR computation ----------

    def get_lr(self):
        if self._frozen:
            return self._frozen_lrs()

        warmup_floor, cosine_floor = self._floors()
        peak = self.max_lr

        if self.step_in_cycle == -1:
            # before first step of a cycle: sit at the warmup floor
            return [warmup_floor for _ in self.optimizer.param_groups]

        if self.warmup_steps > 0 and self.step_in_cycle < self.warmup_steps:
            # linear warmup: warmup_floor -> peak
            alpha = self.step_in_cycle / self.warmup_steps
            lr = warmup_floor + (peak - warmup_floor) * alpha
            return [lr for _ in self.optimizer.param_groups]

        # cosine: peak -> cosine_floor
        denom = max(1, self.cur_cycle_steps - max(self.warmup_steps, 0))
        progress = (self.step_in_cycle - self.warmup_steps) / denom
        lr = cosine_floor + (peak - cosine_floor) * (1 + math.cos(math.pi * progress)) / 2.0
        lr = max(lr, cosine_floor)
        return [lr for _ in self.optimizer.param_groups]

    # ---------- step / cycle logic ----------

    def step(self, epoch: int | None = None):
        if self._frozen:
            self.last_epoch = (self.last_epoch + 1) if epoch is None else math.floor(epoch)
            for pg, lr in zip(self.optimizer.param_groups, self._frozen_lrs()):
                pg["lr"] = lr
            return

        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle += 1

            if self.step_in_cycle >= self.cur_cycle_steps:
                # finished a cycle
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps

                if self.max_cycles is not None and self.cycle >= self.max_cycles:
                    self._frozen = True
                else:
                    # next cycle length (if mult>1)
                    if self.cycle > 0 and self.cycle_mult != 1.0:
                        body = max(self.cur_cycle_steps - self.warmup_steps, 1)
                        self.cur_cycle_steps = int(body * self.cycle_mult) + self.warmup_steps
                    # new peak for the new cycle
                    self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)
                    # snap LR to the new cycle's warmup floor
                    warmup_floor, _ = self._floors()
                    for pg in self.optimizer.param_groups:
                        pg["lr"] = warmup_floor
        else:
            # epoch-based jump
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.0:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                    self.cur_cycle_steps = self.first_cycle_steps
                else:
                    n = int(math.log((epoch / self.first_cycle_steps) * (self.cycle_mult - 1.0) + 1.0,
                                     self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(
                        self.first_cycle_steps * (self.cycle_mult**n - 1) / (self.cycle_mult - 1)
                    )
                    self.cur_cycle_steps = int(self.first_cycle_steps * (self.cycle_mult**n))
            else:
                self.cycle = 0
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

            if self.max_cycles is not None and self.cycle >= self.max_cycles:
                self._frozen = True

            if not self._frozen:
                self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)

        self.last_epoch = math.floor(epoch)

        # apply LR for this step
        lrs = self.get_lr()
        for pg, lr in zip(self.optimizer.param_groups, lrs):
            pg["lr"] = lr