"""Misc. optimizer implementations."""
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler

def get_schedule_fn(scheduler, num_training_steps=None, warmup_steps=None, dim_embed=None):
    """Build a scheduler factory for the schedule named by ``scheduler``.

    The returned object is a ``functools.partial`` that still expects the
    optimizer as its first positional argument, i.e. ``scheduler_fn(optimizer)``.

    Args:
        scheduler: One of ``"cosine-decay"``, ``"one-cycle"`` or ``"transformer"``.
        num_training_steps: Bound as ``T_max`` for ``"cosine-decay"`` and as
            ``num_training_steps`` for ``"one-cycle"``.
        warmup_steps: Bound as ``warmup_steps`` for ``"transformer"``.
        dim_embed: Bound as ``dim_embed`` for ``"transformer"``.

    Returns:
        A callable taking the optimizer and returning the scheduler instance.

    Raises:
        ValueError: If ``scheduler`` is not one of the names listed above.
    """
    if scheduler == "cosine-decay":
        cosine_cls = torch.optim.lr_scheduler.CosineAnnealingLR
        return partial(cosine_cls, T_max=num_training_steps, eta_min=0.0)

    if scheduler == "one-cycle":
        # Simplified one-cycle schedule, see get_one_cycle.
        return partial(get_one_cycle, num_training_steps=num_training_steps)

    if scheduler == "transformer":
        return partial(TransformerScheduler, dim_embed=dim_embed, warmup_steps=warmup_steps)

    raise ValueError(f"Invalid schedule {scheduler} given.")

def get_one_cycle(optimizer, num_training_steps):
    """Simple single-cycle scheduler. Not including paper/fastai three-phase things or asymmetry.

    The learning-rate multiplier rises linearly from 0 to 1 over the first
    half of ``num_training_steps`` and falls linearly back to 0 over the
    second half. Stepping past ``num_training_steps`` keeps the multiplier at
    0 instead of letting it become negative.

    Args:
        optimizer: Optimizer whose learning rates are scaled by the multiplier.
        num_training_steps: Total length of the cycle in scheduler steps.

    Returns:
        A ``LambdaLR`` scheduler wrapping ``optimizer``.
    """
    half_cycle = num_training_steps / 2

    def lr_lambda(current_step):
        if current_step < half_cycle:
            factor = current_step / half_cycle
        else:
            factor = 2 - current_step / half_cycle
        # Beyond the end of the cycle the linear ramp would go below zero,
        # which would hand the optimizer a negative learning rate.
        return max(0.0, float(factor))

    return LambdaLR(optimizer, lr_lambda, -1)

class TransformerScheduler(_LRScheduler):
    """Inverse-square-root learning-rate schedule with linear warmup.

    Every parameter group receives the same learning rate::

        lr = dim_embed**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    where ``step`` is the scheduler's internal ``_step_count``. The learning
    rates configured on the optimizer do not enter this formula.

    Args:
        optimizer: Wrapped optimizer.
        dim_embed: Embedding dimension used for the ``dim_embed**-0.5`` scale.
        warmup_steps: Number of warmup steps. ``None`` or a non-positive
            value disables warmup, leaving pure ``step**-0.5`` decay.
        last_epoch: Forwarded to the base scheduler.
        verbose: Forwarded to the base scheduler only when truthy.
    """

    def __init__(self, optimizer, dim_embed, warmup_steps, last_epoch=-1, verbose=False):
        # These attributes must exist before the base constructor runs,
        # because it performs an initial step that calls get_lr().
        self.dim_embed = dim_embed
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        # The base-class ``verbose`` argument is deprecated in recent PyTorch
        # releases. Forward it only when explicitly enabled, so the default
        # path neither warns nor fails, whatever the installed version.
        extra_args = (verbose,) if verbose else ()
        super().__init__(optimizer, last_epoch, *extra_args)

    def get_lr(self):
        """Return the current learning rate, repeated once per parameter group."""
        step = self._step_count
        scale = step ** (-0.5)
        if self.warmup_steps is not None and self.warmup_steps > 0:
            # Linear ramp during warmup; it meets the decay curve at
            # step == warmup_steps.
            scale = min(scale, step * self.warmup_steps ** (-1.5))
        lr = self.dim_embed ** (-0.5) * scale
        return [lr] * self.num_param_groups
