import math

import numpy as np
import torch
from torch import optim


def get_optimizer(parameters, lr, args, state=None):
    if args.optimizer == "sgd":
        optimizer = get_sgd_optimizer(
            parameters,
            lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=False,
            bn_weight_decay=args.bn_weight_decay,
        )
    elif args.optimizer == "rmsprop":
        optimizer = get_rmsprop_optimizer(
            parameters,
            lr,
            alpha=args.rmsprop_alpha,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            eps=args.rmsprop_eps,
            bn_weight_decay=args.bn_weight_decay,
        )
    if not state is None:
        optimizer.load_state_dict(state)

    return optimizer


def get_sgd_optimizer(
    parameters, lr, momentum, weight_decay, nesterov=False, bn_weight_decay=False
):
    if bn_weight_decay:
        print(" ! Weight decay applied to BN parameters ")
        params = [v for n, v in parameters]
    else:
        print(" ! Weight decay NOT applied to BN parameters ")
        bn_params = [v for n, v in parameters if "bn" in n]
        rest_params = [v for n, v in parameters if not "bn" in n]
        print(len(bn_params))
        print(len(rest_params))

        params = [
            {"params": bn_params, "weight_decay": 0},
            {"params": rest_params, "weight_decay": weight_decay},
        ]

    optimizer = torch.optim.SGD(
        params, lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov
    )

    return optimizer


def get_rmsprop_optimizer(
    parameters, lr, alpha, weight_decay, momentum, eps, bn_weight_decay=False
):
    bn_params = [v for n, v in parameters if "bn" in n]
    rest_params = [v for n, v in parameters if not "bn" in n]

    params = [
        {"params": bn_params, "weight_decay": weight_decay if bn_weight_decay else 0},
        {"params": rest_params, "weight_decay": weight_decay},
    ]

    optimizer = torch.optim.RMSprop(
        params,
        lr=lr,
        alpha=alpha,
        weight_decay=weight_decay,
        momentum=momentum,
        eps=eps,
    )

    return optimizer


def lr_policy(lr_fn):
    def _alr(optimizer, iteration, epoch):
        lr = lr_fn(iteration, epoch)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        return lr

    return _alr


def lr_step_policy(base_lr, steps, decay_factor, warmup_length):
    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            lr = base_lr
            for s in steps:
                if epoch >= s:
                    lr *= decay_factor
        return lr

    return lr_policy(_lr_fn)


def lr_linear_policy(base_lr, warmup_length, epochs):
    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            es = epochs - warmup_length
            lr = base_lr * (1 - (e / es))
        return lr

    return lr_policy(_lr_fn)


def lr_cosine_policy(base_lr, warmup_length, epochs, end_lr=0):
    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            es = epochs - warmup_length
            lr = end_lr + (0.5 * (1 + np.cos(np.pi * e / es)) * (base_lr - end_lr))
        return lr

    return lr_policy(_lr_fn)


def lr_exponential_policy(
    base_lr,
    warmup_length,
    epochs,
    final_multiplier=0.001,
    decay_factor=None,
    decay_step=1,
    logger=None,
):
    """Exponential lr policy. Setting decay factor parameter overrides final_multiplier"""
    es = epochs - warmup_length

    if decay_factor is not None:
        epoch_decay = decay_factor
    else:
        epoch_decay = np.power(
            2, np.log2(final_multiplier) / math.floor(es / decay_step)
        )

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            lr = base_lr * (epoch_decay ** math.floor(e / decay_step))
        return lr

    return lr_policy(_lr_fn, logger=logger)
