import math

import torch
from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt
import torch.nn as nn

"""
Adam with Constrained Parameter Regularization (CPR)

default usage:

adam_csd_config = {             # options
    "mode": "std_constrain",    # std_constrain, l2_constrain
    "kappa": 0.2,               # upper bound of the constraint; its scale depends on the mode (defaults: 0.2 for std, 0.05 for l2)
    "kappa_init_dependent": 0,  # factor multiplied with the initial constraint value to obtain the bound (capped at kappa); 0 disables it (1 is default, for large models up to 10)
    "kappa_init_warm_start": 0, # number of steps after which the bound is set to the current constraint value (capped at kappa); 0 disables it (1 is default, for large models up to 10)
    "lagmul_rate": 1            # update rate of the Lagrange multiplier, somewhere between 0.01 and 10 (1 is default, for large models 0.1 or 0.01)
}

"""

class AdamCPR(Optimizer):
    """Adam with Constrained Parameter Regularization (CPR).

    Every regularized parameter tensor keeps three extra state entries next
    to the usual Adam moments:

    * ``lagmul``     -- a scalar Lagrange multiplier, clipped at zero from below,
    * ``kappa``      -- the scalar bound the tensor statistic is compared with,
    * ``adapt_flag`` -- a boolean scalar recording that ``lagmul`` has been
      positive at least once (only used when ``kappa_adapt`` is enabled).

    The tensor statistic depends on ``mode``: the mean of squares for
    ``"l2_constrain"`` and the standard deviation for ``"std_constrain"``.
    On each step the multiplier is moved by the constraint violation and the
    parameter is pushed along the negative constraint gradient scaled by the
    multiplier, before the plain Adam update is applied.
    """

    def __init__(
            self,
            params,
            lr,
            betas,
            mode,
            apply_decay,
            lagmul_rate,
            kappa=100,
            kappa_adapt=False,
            kappa_init_dependent=False,
            kappa_init_warm_start=False,
            eps=1e-8,
            amsgrad=False,
            *,
            maximize: bool = False,
    ):
        """Validate the hyperparameters and register the parameter groups.

        Args:
            params: Iterable of parameters or of parameter-group dicts.
            lr: Learning rate, must be non-negative.
            betas: Pair of coefficients for the running averages of the
                gradient and its square, each in ``[0, 1)``.
            mode: ``"l2_constrain"`` or ``"std_constrain"``.
            apply_decay: Group default that switches the constraint on/off.
            lagmul_rate: Update rate of the Lagrange multipliers.
            kappa: Upper bound used for (or to cap) the per-tensor bound.
            kappa_adapt: If true, bounds may be tightened during training
                (see ``step``).
            kappa_init_dependent: If truthy, factor applied to the initial
                tensor statistic to obtain the bound (capped at ``kappa``).
            kappa_init_warm_start: If truthy, step number at which the bound
                is clamped to the tensor statistic of that moment. Mutually
                exclusive with ``kappa_init_dependent``.
            eps: Term added to the Adam denominator, must be non-negative.
            amsgrad: Whether to use the AMSGrad variant.
            maximize: Maximize the objective instead of minimizing it.

        Raises:
            ValueError: If ``lr``, ``eps`` or ``betas`` are out of range.
            AssertionError: If ``mode`` / ``kappa_adapt`` are invalid or both
                kappa initialisation schemes are requested.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        assert kappa_adapt in [True, False]
        assert mode in ["l2_constrain", "std_constrain"]
        assert not ((kappa_init_dependent and kappa_init_warm_start) or (kappa_init_dependent != 0 and kappa_init_warm_start != 0))
        self.mode = mode
        self.kappa_adapt = kappa_adapt

        # A falsy warm start (False / 0 / None) means "no warm start".
        self.kappa_init_after_steps = kappa_init_warm_start if kappa_init_warm_start else 0

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            apply_decay=apply_decay,
            kappa=kappa,
            kappa_factor=kappa_init_dependent,
            lagmul_rate=lagmul_rate,
            amsgrad=amsgrad,
            maximize=maximize,
            differentiable=False,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore the optimizer and back-fill entries older states may lack."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
        state_values = list(self.state.values())
        # Only the first entry is inspected; when it holds a plain number every
        # step counter is converted to a tensor.
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(float(s["step"]))

    def _constraint_statistic(self, param):
        """Return the scalar statistic of ``param`` that ``kappa`` bounds."""
        if self.mode == "l2_constrain":
            return param.square().mean()
        # NOTE: torch.std of a single-element tensor is NaN.
        return param.std()

    def _initial_kappa(self, group, p):
        """Build the initial per-tensor bound for parameter ``p``."""
        upper = torch.tensor(group["kappa"], dtype=torch.float, device=p.device)
        factor = group["kappa_factor"]
        # With a warm start, or without an init-dependent factor, start from
        # the configured upper bound itself.
        if self.kappa_init_after_steps > 0 or not factor:
            return upper
        if self.mode == "std_constrain":
            return torch.min(factor * torch.std(p), upper)
        return torch.min(factor * p.square().mean(), upper)

    def _init_group(
            self,
            group,
            params_with_grad,
            grads,
            amsgrad,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            lagmuls,
            kappas,
            adapt_flags,
            state_steps,
    ):
        """Collect the tensors of one group, creating state lazily.

        Parameters without a gradient are skipped. All list arguments are
        output buffers that are appended to in parameter order; the
        constraint buffers are only filled when the group has
        ``apply_decay`` enabled.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        constrained = "constrain" in self.mode and group["apply_decay"]

        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamCPR does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # Lazy state initialization on the first step that sees a gradient.
            if len(state) == 0:
                state["step"] = torch.tensor(0.0)
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                if amsgrad:
                    # Running maximum of the squared-gradient average
                    state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

                if constrained:
                    state["lagmul"] = torch.tensor(0, dtype=torch.float, device=p.device)
                    state["adapt_flag"] = torch.tensor(False, dtype=torch.bool, device=p.device)
                    state["kappa"] = self._initial_kappa(group, p)

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])

            if amsgrad:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])

            if constrained:
                lagmuls.append(state["lagmul"])
                kappas.append(state["kappa"])
                adapt_flags.append(state["adapt_flag"])

            state_steps.append(state["step"])

    def _apply_constraint(self, param, lagmul, kappa, lagmul_rate):
        """Update ``lagmul`` in place and move ``param`` along the constraint gradient."""
        n = float(param.numel())

        if self.mode == "l2_constrain":
            # The constraint is expressed on the sum of squares, so both the
            # bound and the multiplier rate are rescaled by the element count.
            sum_of_squares = param.square().sum()
            param_specific_lagmul_rate = lagmul_rate / n
            param_specific_kappa = kappa * n

            constraint_value = sum_of_squares - param_specific_kappa
            grad_c = 2 * param

            lagmul.add_(param_specific_lagmul_rate * constraint_value).clip_(min=0.)
            param.add_(-grad_c * lagmul)

        elif self.mode == "std_constrain":
            std_dev = param.std()
            constraint_value = std_dev - kappa

            # Gradient of the standard deviation w.r.t. the entries, built in
            # place on a centred copy of the parameter.
            centred = param.sub(param.mean())
            grad_c = centred.mul_(2).sub_(2 * centred.mean()).div_(n - 1)
            grad_c.div_(std_dev.mul_(2))

            lagmul.add_(lagmul_rate * constraint_value).clip_(min=0.)
            param.add_(-grad_c * lagmul)

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            lagmuls = []
            kappas = []
            adapt_flags = []
            state_steps = []
            amsgrad = group["amsgrad"]
            apply_decay = group["apply_decay"]
            beta1, beta2 = group["betas"]

            self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                lagmuls,
                kappas,
                adapt_flags,
                state_steps,
            )

            lr = group["lr"]
            lagmul_rate = group["lagmul_rate"]
            eps = group["eps"]
            maximize = group["maximize"]
            grad_scale = getattr(self, "grad_scale", None)
            found_inf = getattr(self, "found_inf", None)

            if not all(isinstance(t, torch.Tensor) for t in state_steps):
                raise RuntimeError(
                    "API has changed, `state_steps` argument must contain a list of singleton tensors"
                )

            if not params_with_grad:
                # Nothing to update in this group. Keep going so that the
                # remaining groups are still stepped and the loss is returned.
                continue

            assert grad_scale is None and found_inf is None

            for i, param in enumerate(params_with_grad):
                grad = -grads[i] if maximize else grads[i]
                exp_avg = exp_avgs[i]
                exp_avg_sq = exp_avg_sqs[i]
                max_exp_avg_sq = max_exp_avg_sqs[i] if amsgrad else None
                step_t = state_steps[i]

                if torch.is_complex(param):
                    # Work on real views; they share storage with the state.
                    grad = torch.view_as_real(grad)
                    exp_avg = torch.view_as_real(exp_avg)
                    exp_avg_sq = torch.view_as_real(exp_avg_sq)
                    if amsgrad:
                        # Must match the real view of exp_avg_sq below.
                        max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)
                    param = torch.view_as_real(param)

                # update step
                step_t += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                step = _get_value(step_t)

                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                step_size = lr / bias_correction1

                # Tensor/float sqrt dispatch, inlined so this method does not
                # depend on a private torch helper for it.
                if isinstance(bias_correction2, torch.Tensor):
                    bias_correction2_sqrt = bias_correction2.sqrt()
                else:
                    bias_correction2_sqrt = math.sqrt(bias_correction2)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
                else:
                    denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

                if apply_decay:
                    lagmul = lagmuls[i]
                    kappa = kappas[i]

                    self._apply_constraint(param, lagmul, kappa, lagmul_rate)

                    if self.kappa_adapt and step > self.kappa_init_after_steps:
                        adapt_flag = adapt_flags[i]

                        # Once the multiplier has been active and relaxed back
                        # to zero, tighten the bound to the current statistic.
                        if adapt_flag and lagmul == 0:
                            kappa.clamp_max_(self._constraint_statistic(param))

                        if lagmul > 0 and not adapt_flag:
                            adapt_flag.fill_(True)

                    if self.kappa_init_after_steps == step and self.kappa_init_after_steps != 0:
                        # Warm start: fix the bound from the statistic reached
                        # at this step (never above the configured kappa).
                        kappa.clamp_max_(self._constraint_statistic(param))

                param.addcdiv_(exp_avg, denom, value=-step_size)  # actual adam update

        return loss


def group_cpr_parameters(model, optim_hps, avoid_keywords, bias_regularization=False,
                         normalization_regularization=False):
    """Split the trainable parameters of ``model`` into optimizer groups.

    A parameter ends up in the non-regularized group when
    * its name ends with ``bias`` and ``bias_regularization`` is false,
    * its full name contains one of ``avoid_keywords``, or
    * it belongs to an embedding module, or to a normalization module while
      ``normalization_regularization`` is false.
    Parameters carrying an ``_optim`` attribute are put into separate groups,
    one per distinct ``_optim`` dict. Everything else is regularized.

    Args:
        model: Module whose parameters are grouped.
        optim_hps: Dict of hyperparameters merged into the regular groups.
        avoid_keywords: Name fragments that exclude a parameter from
            regularization; a falsy value means no keywords.
        bias_regularization: Regularize bias parameters as well.
        normalization_regularization: Regularize normalization layers as well.

    Returns:
        A list of parameter-group dicts. When nothing is excluded a single
        group without an ``apply_decay`` key is returned; otherwise two groups
        with ``apply_decay`` set to ``True`` and ``False``. Groups built from
        ``_optim`` follow and contain only ``params`` plus those entries.
    """
    # TODO (carried over from the original code): "I SHOULD USE THIS AS WELL"
    keywords = avoid_keywords or []

    regularized = set()
    exempt = set()
    special = set()
    whitelist_weight_modules = (nn.Linear, nn.Conv2d)
    blacklist_weight_modules = (nn.Embedding,)
    if not normalization_regularization:
        blacklist_weight_modules = blacklist_weight_modules + (
            nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
            nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,
            nn.GroupNorm, nn.SyncBatchNorm,
            nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
            nn.LayerNorm, nn.LocalResponseNorm,
        )

    trainable = {name: param for name, param in model.named_parameters() if param.requires_grad}

    for module_name, module in model.named_modules():
        for local_name, param in module.named_parameters():
            full_name = '%s.%s' % (module_name, local_name) if module_name else local_name
            # Frozen weights are skipped; with parameter sharing some names
            # reachable here are absent from ``trainable`` and skipped too.
            if not param.requires_grad or full_name not in trainable:
                continue
            if hasattr(param, '_optim'):
                special.add(full_name)
                continue
            if not bias_regularization and local_name.endswith('bias'):
                exempt.add(full_name)
                continue
            if any(keyword in full_name for keyword in keywords):
                exempt.add(full_name)
                continue
            if local_name.endswith('weight') and isinstance(module, whitelist_weight_modules):
                regularized.add(full_name)
                continue
            if isinstance(module, blacklist_weight_modules):
                exempt.add(full_name)

    # Whatever was not explicitly excluded or marked special gets regularized.
    regularized |= (trainable.keys() - exempt - special)

    # validate that we considered every parameter
    overlap = regularized & exempt
    covered = regularized | exempt
    assert len(overlap) == 0, f"Parameters {str(overlap)} made it into both apply_decay/apply_no_decay sets!"
    assert len(
        trainable.keys() - special - covered) == 0, f"parameters {str(trainable.keys() - covered)}  were not separated into either apply_decay/apply_no_decay set!"

    if exempt:
        decay_names = sorted(regularized)
        no_decay_names = sorted(exempt)
        param_groups = [
            {"params": [trainable[name] for name in decay_names],
             "names": list(decay_names), "apply_decay": True, **optim_hps},
            {"params": [trainable[name] for name in no_decay_names],
             "names": list(no_decay_names), "apply_decay": False, **optim_hps},
        ]
    else:
        all_names = sorted(exempt | regularized)
        param_groups = [{"params": [trainable[name] for name in all_names],
                         "names": list(all_names), **optim_hps}]

    # One extra group per distinct ``_optim`` hyperparameter dict.
    distinct_hps = set(frozenset(trainable[name]._optim.items()) for name in special)
    special_names = sorted(special)
    for frozen_hp in distinct_hps:
        hp = dict(frozen_hp)
        members = [trainable[name] for name in special_names if trainable[name]._optim == hp]
        param_groups.append({"params": members, **hp})

    return param_groups
