import math
import torch
from torch.optim import Optimizer


def anneal_function(fun: str, step: int, k: float, t0: float, w: float):
    """
    Calculate the annealing coefficient lambda(t); for t >= 0 it lies in [0, w].

    - sigmoid : w * sigmoid(k * (t - t0))
    - linear  : min(w * t / t0, w)
    - cos     : w * (1 - cos(pi * min(1, t / t0))) / 2
    - constant: w

    For ``linear`` and ``cos`` the divisor t0 is clamped to at least 1 so that
    t0 = 0 does not divide by zero.

    Args:
        fun: one of "sigmoid", "linear", "cos", "constant".
        step: the current step t.
        k: steepness of the sigmoid (only used by "sigmoid").
        t0: sigmoid midpoint, or ramp length for "linear"/"cos".
        w: upper bound of lambda; any w <= 0 short-circuits to 0.0.

    Returns:
        lambda(t) as a Python float.

    Raises:
        ValueError: if ``fun`` is not a recognized name (only checked when w > 0).
    """
    if w <= 0:
        return 0.0
    if fun == "sigmoid":
        z = k * (step - t0)
        # Numerically stable logistic: math.exp(-z) raises OverflowError once -z
        # exceeds ~709 (e.g. k=0.1, t0=10000 at step 1), so branch on the sign of z.
        if z >= 0:
            return float(w / (1.0 + math.exp(-z)))
        e = math.exp(z)  # underflows harmlessly to 0.0 for very negative z
        return float(w * e / (1.0 + e))
    elif fun == "linear":
        return float(min(w * step / max(t0, 1.0), w))
    elif fun == "cos":
        ratio = min(step / max(t0, 1.0), 1.0)
        return float(w * (1.0 - math.cos(math.pi * ratio)) / 2.0)
    elif fun == "constant":
        return float(w)
    else:
        raise ValueError(f"Unknown anneal_fun: {fun}")


class RecAdam(Optimizer):
    """
    RecAdam: Adam with a quadratic "recall" penalty that pulls parameters back toward
    their pretrained values, balanced against the target task by an annealing
    coefficient lambda(t).

    Key Concept:
    - Loss = lambda(t) * L_target + (w - lambda(t)) * (gamma/2) * sum_i (theta_i - theta_i^*)^2
      where w = anneal_w and gamma = pretrain_cof (with the default w = 1.0 this is
      the usual lambda / (1 - lambda) split).

    Required Inputs:
    - params: Training parameters (iterable of tensors, or param-group dicts)
    - pretrain_params: Pretraining snapshots matching ``params`` one-to-one and in the
      same order (iterable of tensors). It may be given globally here or per param group.

    Additional Hyperparameters:
    - pretrain_cof: Gamma, the quadratic penalty coefficient (default 5000)
    - anneal_*: Annealing function and shape (fun/k/t0/w), see ``anneal_function``
    - correct_bias: apply Adam bias correction to the target-task step
    - weight_decay: decoupled weight decay, applied after the other updates

    Note: the pull-back is a plain gradient step with the raw ``lr`` (not Adam-normalized).
    Its part of each update scales (theta - theta^*) by 1 - lr * (w - lambda) * gamma, which
    only contracts when lr * (w - lambda) * gamma lies in (0, 2). For example, lr=1e-3 with
    gamma=5000 and lambda ~ 0 gives 5, which overshoots. Choose lr and pretrain_cof together.
    """

    def __init__(
        self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
        weight_decay=0.0, correct_bias=True,
        anneal_fun='sigmoid', anneal_k=0.1, anneal_t0=500, anneal_w=1.0,
        pretrain_cof=5000.0, pretrain_params=None
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        # Materialize once up front. Otherwise a generator would be exhausted by the
        # first step(), and through `defaults` it would be shared by every param group.
        if pretrain_params is not None:
            pretrain_params = self._as_tensor_list(pretrain_params)

        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias,
            anneal_fun=anneal_fun, anneal_k=anneal_k, anneal_t0=anneal_t0, anneal_w=anneal_w,
            pretrain_cof=pretrain_cof, pretrain_params=pretrain_params
        )
        super().__init__(params, defaults)

    @staticmethod
    def _as_tensor_list(pretrain_params):
        """Return ``pretrain_params`` as a list, wrapping a lone tensor instead of iterating it."""
        if isinstance(pretrain_params, torch.Tensor):
            return [pretrain_params]
        return list(pretrain_params)

    def step(self, closure=None):
        """
        Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            ValueError: if a param group has no ``pretrain_params``, or if their count
                differs from the group's ``params``.
            RuntimeError: if a gradient is sparse.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            params = group["params"]
            pretrain_params = group.get("pretrain_params")
            # Stay consistent with the official implementation: `group["params"]` must
            # correspond one-to-one with `group["pretrain_params"]`. These are explicit
            # checks rather than `assert`, so they still run under `python -O`.
            if pretrain_params is None:
                raise ValueError(
                    "RecAdam requires pretrain_params for each param group.")
            if not isinstance(pretrain_params, list):
                # Per-group values may be generators or tuples; normalize once and store back.
                pretrain_params = self._as_tensor_list(pretrain_params)
                group["pretrain_params"] = pretrain_params
            if len(pretrain_params) != len(params):
                # zip() would silently truncate and leave the trailing params un-updated.
                raise ValueError(
                    f"RecAdam: param group has {len(params)} params but "
                    f"{len(pretrain_params)} pretrain_params; they must match one-to-one.")

            beta1, beta2 = group["betas"]
            logged = False  # print lambda at most once per group per step() call

            for p, pp in zip(params, pretrain_params):
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "RecAdam/Adam does not support sparse gradients.")

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                state["step"] += 1

                # First and second moment estimates (m, v).
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    # The floor guards against a vanishing denominator when beta1 is extremely close to 1.
                    step_size = step_size * \
                        math.sqrt(bias_correction2) / \
                        max(bias_correction1, 1e-12)

                if group["anneal_w"] > 0.0:
                    anneal_lambda = anneal_function(
                        group["anneal_fun"], state["step"], group["anneal_k"],
                        group["anneal_t0"], group["anneal_w"]
                    )
                    # state["step"] is per-parameter, so without the flag this would
                    # print once for every parameter in the group.
                    if not logged and state["step"] % 100 == 0:
                        print(
                            f"[RecAdam] step={state['step']} lambda={anneal_lambda:.4f}")
                        logged = True
                    # Target-task Adam step, scaled by lambda(t).
                    p.data.addcdiv_(exp_avg, denom, value=-
                                    step_size * anneal_lambda)
                    # Pull-back: a plain gradient step on (gamma/2) * ||theta - theta*||^2,
                    # weighted by (w - lambda) and using the raw lr (no bias correction).
                    # pp may be on a different device (e.g. CPU); move it to p's device.
                    p.data.add_(p.data - pp.data.to(p.data.device),
                                alpha=-group["lr"] * (group["anneal_w"] - anneal_lambda) * group["pretrain_cof"])
                else:
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Decoupled weight decay, applied directly to the weights rather than through the gradient.
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=-
                                group["lr"] * group["weight_decay"])

        return loss
