import math

import pytorch_lightning as pl
from pytorch_lightning import Callback


class WeightDecayScheduler(Callback):
    """Callback that rescales the ``weight_decay`` of optimizer param groups.

    Before every optimizer step, each param group that has a ``weight_decay``
    entry is set to ``base_value * factor(step)``, where ``base_value`` is the
    value the group had when fitting started and ``factor`` comes from
    :meth:`get_scheduler`. The resulting values are sent to every logger
    attached to the trainer.

    Args:
        schedule_weight_decay: When false, the callback leaves the optimizers
            untouched and logs nothing.
        schedule_type: ``'linear'``, ``'cosine'`` or ``'const'``.
        scale: Multiplier reached at the end of the ``'linear'`` and
            ``'cosine'`` schedules (stored as ``self.decay``).

    NOTE(review): the step counter lives only on the instance; nothing in this
    class saves or restores it, so it starts from 0 for every new instance.
    """

    def __init__(self, schedule_weight_decay: bool, schedule_type: str, scale: float):
        super().__init__()
        self.schedule_weight_decay = schedule_weight_decay

        self.schedule_type = schedule_type

        self.decay = scale

        # Number of times on_before_optimizer_step has applied the schedule.
        self._step_count = 0

        # id(optimizer) -> {param group index -> unscaled weight decay}.
        self._base_weight_decay = {}

    @staticmethod
    def get_scheduler(schedule_type, num_warmup_steps, decay_factor, num_training_steps):
        """Build a function mapping a step index to a weight-decay multiplier.

        Args:
            schedule_type: ``'linear'``, ``'cosine'`` or ``'const'``.
            num_warmup_steps: Steps during which the multiplier ramps linearly
                from 0 to 1, whatever the schedule type.
            decay_factor: Multiplier reached once ``num_training_steps`` steps
                have elapsed (ignored by ``'const'``).
            num_training_steps: Total number of steps, warmup included.

        Returns:
            A callable ``fn(current_step) -> float``.

        Raises:
            ValueError: When the returned callable is invoked after warmup
                with a ``schedule_type`` that is not one of the three above.
        """

        def fn_scheduler(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            if schedule_type == 'const':
                return 1.0
            if schedule_type not in ('linear', 'cosine'):
                # Previously this fell through and returned None, which only
                # surfaced later as a TypeError in the multiplication.
                raise ValueError(
                    f"Unknown schedule_type {schedule_type!r}; "
                    "expected 'linear', 'cosine' or 'const'."
                )

            decay_steps = float(max(1, num_training_steps - num_warmup_steps))
            if schedule_type == 'linear':
                # Fraction of the post-warmup phase still ahead: 1 right after
                # warmup, 0 at num_training_steps. Warmup must not be
                # subtracted from the numerator a second time.
                remaining = max(0.0, float(num_training_steps - current_step) / decay_steps)
            else:
                # Clamp the elapsed steps so the cosine stays at its minimum
                # past the horizon instead of swinging back up.
                elapsed = min(float(current_step - num_warmup_steps), decay_steps)
                remaining = max(0.0, (1 + math.cos(math.pi * elapsed / decay_steps)) / 2)
            return decay_factor + (1 - decay_factor) * remaining

        return fn_scheduler

    def _resolve_num_training_steps(self, trainer: "pl.Trainer"):
        """Return the step horizon for the schedule.

        ``trainer.max_steps`` is used when it is non-negative. Otherwise, and
        only if a horizon-dependent schedule is enabled, the trainer's
        ``estimated_stepping_batches`` is used instead.
        """
        max_steps = trainer.max_steps
        horizon_known = max_steps is not None and max_steps >= 0
        needs_horizon = self.schedule_weight_decay and self.schedule_type in ('linear', 'cosine')
        if horizon_known or not needs_horizon:
            return max_steps

        # NOTE(review): reading this property may make the trainer load the
        # train dataloader earlier than it otherwise would - confirm that is
        # acceptable for the data modules in use.
        estimated = getattr(trainer, 'estimated_stepping_batches', None)
        if estimated is not None and 0 < estimated < float('inf'):
            return int(estimated)

        import warnings

        warnings.warn(
            "WeightDecayScheduler could not determine the number of training steps "
            f"(max_steps={max_steps!r}); the {self.schedule_type!r} schedule needs a "
            "finite horizon to behave as intended."
        )
        return max_steps

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """Record the unscaled weight decay of every param group and build the schedule."""
        self.num_training_steps = self._resolve_num_training_steps(trainer)

        # Flat list kept for backward compatibility with code that reads it.
        self.weight_decay = []
        # Per-optimizer, per-group lookup actually used when rescaling, so
        # groups without 'weight_decay' or extra optimizers cannot shift indices.
        self._base_weight_decay = {}
        for optim in trainer.optimizers:
            bases = self._base_weight_decay.setdefault(id(optim), {})
            for group_idx, group in enumerate(optim.param_groups):
                if 'weight_decay' in group:
                    self.weight_decay.append(group['weight_decay'])
                    bases[group_idx] = group['weight_decay']

        num_warmup_steps = 0

        self.scheduler = self.get_scheduler(self.schedule_type, num_warmup_steps, self.decay, self.num_training_steps)

    def on_before_optimizer_step(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer):
        """Apply the scheduled weight decay to ``optimizer`` and log the new values."""
        if not self.schedule_weight_decay:
            return

        # An optimizer or group not seen in on_fit_start is registered lazily;
        # its current value is still unscaled because only this callback
        # rewrites it.
        bases = self._base_weight_decay.setdefault(id(optimizer), {})
        stats = {}
        for group_idx, group in enumerate(optimizer.param_groups):
            if 'weight_decay' not in group:
                continue
            base = bases.setdefault(group_idx, group['weight_decay'])
            group['weight_decay'] = base * self.scheduler(self._step_count)
            stats[f"weight_decay/rank_{trainer.local_rank}/group_{group_idx}"] = group['weight_decay']

        if trainer.loggers is not None:
            for logger in trainer.loggers:
                logger.log_metrics(stats, step=trainer.global_step)
        self._step_count += 1
