"""Interface with optimizers."""

import re
import torch

from .additional_optimizers import FISTA, LARS, SGD_AGC, LBFGS, SAM, GradualWarmupScheduler
from .additional_optimizers import RestartingLineSearch, NonMonotoneLinesearch, WolfeGradientDescent
from .additional_optimizers import AdaptiveGradientClipping


def _parameter_groups(model, cfg_hyp):
    """Return the parameter iterable that is handed to the optimizer constructor.

    If ``cfg_hyp.only_linear_layers_weight_decay`` is truthy and the optimizer is not "GD-AGC" (that branch
    adjusts weight decay on its own named groups), every parameter is placed in its own group and groups
    whose parameter name matches the bias/gain regex get ``weight_decay=0.0``. Otherwise the plain
    ``model.parameters()`` iterator is returned.
    """
    if not cfg_hyp.only_linear_layers_weight_decay or cfg_hyp.optim.name == "GD-AGC":
        return model.parameters()

    groups = []
    for name, param in model.named_parameters():
        # this regex snippet is modified from https://github.com/benjs/nfnets_pytorch/blob/master/train.py
        if re.search("(bias|gain)|skip_gain", name):
            groups.append({"params": [param], "weight_decay": 0.0})
        else:
            # No explicit entry: the group inherits the optimizer-wide weight decay.
            groups.append({"params": [param]})
    return groups


def _build_optimizer(model, cfg_hyp, parameter_iterable, optim_params):
    """Instantiate the base optimizer selected by ``cfg_hyp.optim.name``.

    ``optim_params`` are the keyword arguments forwarded to the optimizer constructor (all entries of
    ``cfg_hyp.optim`` except "name"). Raises ``ValueError`` for an unknown optimizer or line search name.
    """
    # NOTE: keep these as if/elif chains rather than a dispatch table, so that an optimizer class is only
    # looked up when its branch is actually selected.
    name = cfg_hyp.optim.name
    if name == "Gradient Descent":
        # "line_search" selects the optimizer class and is not a constructor argument.
        optim_params = {k: v for k, v in optim_params.items() if k != "line_search"}
        line_search = cfg_hyp.optim.line_search
        if line_search == "none":
            return torch.optim.SGD(parameter_iterable, **optim_params)
        if line_search == "wolfe":
            return WolfeGradientDescent(parameter_iterable, **optim_params)
        if line_search == "non-monotone":
            return NonMonotoneLinesearch(parameter_iterable, **optim_params)
        if line_search == "restarting":
            return RestartingLineSearch(parameter_iterable, **optim_params)
        raise ValueError(f"Invalid linesearch {cfg_hyp.optim.line_search} defined.")
    if name == "Adaptive Gradient Descent":
        return AdaptiveGradientClipping(parameter_iterable, **optim_params)
    if name == "Adam":
        # Note that the "Adam" setting maps to AdamW.
        return torch.optim.AdamW(parameter_iterable, **optim_params)
    if name == "L-BFGS":
        return LBFGS(parameter_iterable, **optim_params)
    if name == "FISTA":
        return FISTA(parameter_iterable, **optim_params)
    if name == "GD-AGC":
        # SGD_AGC receives named parameters; its param groups are then patched by name below.
        optimizer = SGD_AGC(model.named_parameters(), **optim_params)
        for group in optimizer.param_groups:
            if group["name"].startswith("linear"):
                group["clipping"] = None
            if cfg_hyp.only_linear_layers_weight_decay:
                if re.search("stem.*(bias|gain)|conv.*(bias|gain)|skip_gain", group["name"]):
                    group["weight_decay"] = 0
        return optimizer
    raise ValueError(f"Invalid optimizer {cfg_hyp.optim.name} provided.")


def _apply_modification(optimizer, cfg_hyp):
    """Optionally wrap ``optimizer`` as requested by ``cfg_hyp.optim_modification``.

    Returns ``(optimizer, optimizer_to_schedule)``. Without a modification both entries are the same object;
    with one, the first is the wrapper and the second is the wrapper's ``.optim`` attribute, which is what
    the learning rate scheduler gets attached to. Raises ``ValueError`` for an unknown modification name.
    """
    modification = cfg_hyp.optim_modification
    if modification.name == "none":
        return optimizer, optimizer

    if modification.name == "LARS":
        optimizer = LARS(
            optimizer,
            trust_coefficient=modification.trust_coefficient,
            clip=False,
            eps=modification.eps,
        )
    elif modification.name == "LARC":
        # LARC is the same wrapper class as LARS, constructed with clip=True.
        optimizer = LARS(
            optimizer,
            trust_coefficient=modification.trust_coefficient,
            clip=True,
            eps=modification.eps,
        )
    elif modification.name == "SAM":
        optimizer = SAM(optimizer, rho=modification.rho)
    else:
        # Fail loudly instead of falling through to ``optimizer.optim`` on an unwrapped optimizer.
        raise ValueError(f"Invalid optimizer modification {modification.name} provided.")
    return optimizer, optimizer.optim


def _build_scheduler(optimizer_to_schedule, cfg_hyp):
    """Instantiate the learning rate scheduler selected by ``cfg_hyp.scheduler``.

    An empty string, a single space or ``None`` selects a constant learning rate. If ``cfg_hyp.warmup`` is
    positive, the scheduler is wrapped in a ``GradualWarmupScheduler``. Raises ``ValueError`` for an unknown
    scheduler name.
    """
    lr_scheduler = torch.optim.lr_scheduler
    name = cfg_hyp.scheduler
    if name == "linear":
        # Drop the learning rate by 10x at roughly 3/8, 5/8 and 7/8 of the total number of steps.
        # Floor division by a float yields a float, so cast to get integer step indices.
        milestones = [int(cfg_hyp.steps // divisor) for divisor in (2.667, 1.6, 1.142)]
        scheduler = lr_scheduler.MultiStepLR(optimizer_to_schedule, milestones=milestones, gamma=0.1)
    elif name == "exponential":
        scheduler = lr_scheduler.ExponentialLR(optimizer_to_schedule, gamma=0.99)
    elif name == "cosine-decay-floored":
        # Cosine decay that bottoms out at 1/25 of the configured learning rate.
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer_to_schedule, cfg_hyp.steps, eta_min=cfg_hyp.optim.lr / 25
        )
    elif name == "cosine-decay":
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer_to_schedule, cfg_hyp.steps, eta_min=0.0)
    elif name == "cosine-4000":
        # Cosine decay, hardcoded to 4000 steps
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer_to_schedule, 4000, eta_min=0.0)
    elif name in ["", " ", None]:
        # No milestones and gamma=1: the learning rate is never changed.
        scheduler = lr_scheduler.MultiStepLR(optimizer_to_schedule, milestones=[], gamma=1)
    else:
        # Report the configured name; no scheduler object exists on this path.
        raise ValueError(f"Invalid scheduler {cfg_hyp.scheduler} provided.")

    if cfg_hyp.warmup > 0:
        scheduler = GradualWarmupScheduler(
            optimizer_to_schedule, multiplier=1.0, total_epoch=cfg_hyp.warmup, after_scheduler=scheduler
        )
    return scheduler


def optim_interface(model, cfg_hyp):
    """Construct optimizer and scheduler objects.

    Args:
        model: Module providing ``parameters()`` and ``named_parameters()``.
        cfg_hyp: Hyperparameter configuration with attribute access. The fields read here are
            ``optim`` (a mapping with a ``name`` entry; all other entries are forwarded as keyword arguments
            to the optimizer, except ``line_search`` for "Gradient Descent"),
            ``only_linear_layers_weight_decay``, ``optim_modification`` (``name`` plus the arguments of the
            chosen wrapper), ``scheduler``, ``steps`` and ``warmup``.

    Returns:
        Tuple ``(optimizer, scheduler)``. If an optimizer modification ("LARS", "LARC" or "SAM") is
        configured, ``optimizer`` is the wrapper, while the scheduler is attached to the wrapper's ``.optim``.

    Raises:
        ValueError: If the optimizer, line search, optimizer modification or scheduler name is not recognized.
    """
    optim_params = {k: v for k, v in cfg_hyp.optim.items() if k != "name"}
    parameter_iterable = _parameter_groups(model, cfg_hyp)

    optimizer = _build_optimizer(model, cfg_hyp, parameter_iterable, optim_params)
    optimizer, optimizer_to_schedule = _apply_modification(optimizer, cfg_hyp)
    scheduler = _build_scheduler(optimizer_to_schedule, cfg_hyp)

    return optimizer, scheduler
