import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt

"""
Adam with Constrained Parameter Regularization (CPR)
current version: 11.09.23

default usage:

adam_csd_config = {             # options
    "mode": "std_constrain",    # std_constrain, var_constrain, l2_constrain
    "kappa": 0.2,               # maximum constraint value; its meaning depends on the mode (defaults: 0.2 for std, 0.05 for l2)
    "kappa_factor": 2,          # factor multiplied with the initial constraint value to obtain the constraint (is below kappa) (1 is default, for large models up to 10)
    "decay": 0,
    "lagmul_rate": 1            # rate of the Lagrange multiplier update, something between 0.01 and 10 (1 is default, for large models 0.1 or 0.01)
}

parameters = group_testbed_parameters(model, adam_csd_config, avoid_keywords=['bias', 'norm'])

optimizer = AdamCPR(parameters, lr=self.learning_rate, betas=(0.9, 0.999), 
        kappa=adam_csd_config["kappa"],
        kappa_factor=adam_csd_config["kappa_factor"],
        decay=adam_csd_config["decay"],
        mode=adam_csd_config["mode"],
        apply_decay=True, 
        lagmul_rate=adam_csd_config["lagmul_rate"])

"""


class AdamAWD(Optimizer):
    """Adam with an adaptive, gradient-norm-scaled decoupled weight decay.

    For every parameter tensor ``p`` (gradient ``g``) in a group whose
    ``apply_decay`` flag is set, each step computes::

        lambda_curr = decay * ||g||_2 / ||p||_2      (0 when ||p||_2 == 0)
        lambda_dash = 0.1 * lambda_dash + 0.9 * lambda_curr   # kept in state
        p <- p * (1 - lr * lambda_dash)

    and then applies the bias-corrected Adam update (AMSGrad if enabled).

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate (must be >= 0).
        betas: coefficients of the running averages of the gradient and of
            its square; each must lie in [0, 1).
        apply_decay: default value of the per-group ``apply_decay`` flag.
        weight_decay: decay strength, stored per group under the key ``"decay"``.
        eps: term added to the denominator for numerical stability (>= 0).
        amsgrad: use the AMSGrad variant (running max of the second moment).
        maximize: maximize the objective instead of minimizing it.

    Raises:
        ValueError: if ``lr``, ``eps`` or ``betas`` are out of range.
    """

    def __init__(
            self,
            params,
            lr,  # =1e-3,
            betas,  # =(0.9, 0.999),
            apply_decay,  # =True,
            weight_decay,
            eps=1e-8,
            amsgrad=False,
            *,
            maximize: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            apply_decay=apply_decay,
            decay=weight_decay,
            amsgrad=amsgrad,
            maximize=maximize,
            differentiable=False,
            # The base-class CUDA-graph health check (called at the top of
            # step()) looks up ``capturable`` on CUDA-enabled builds. This
            # optimizer reads step counts on the host, so capture is unsupported.
            capturable=False,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled state and back-fill group keys older pickles may lack.

        A legacy Python-number ``step`` entry is converted to a tensor, because
        ``step()`` requires tensor step counters.
        """
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("capturable", False)
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(float(s["step"]))

    def _init_group(
            self,
            group,
            params_with_grad,
            grads,
            amsgrad,
            lambda_dash,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
    ):
        """Gather the parameters of ``group`` that have gradients, plus their state.

        Per-parameter state is lazily created on first use. The output lists
        passed in are appended to in place, all in the same parameter order.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamAWD does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = torch.tensor(0.0)
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Moving average of the adaptive decay coefficient; created on
                # the parameter's device so the per-step update stays local.
                state["lambda_dash"] = torch.tensor(0.0, device=p.device)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state["max_exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

            lambda_dash.append(state["lambda_dash"])
            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])

            if amsgrad:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])

            state_steps.append(state["step"])

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step over every parameter group.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            lambda_dashs = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            apply_decay = group["apply_decay"]
            beta1, beta2 = group["betas"]

            self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                lambda_dashs,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            lr = group["lr"]
            decay = group["decay"]
            eps = group["eps"]
            maximize = group["maximize"]
            grad_scale = getattr(self, "grad_scale", None)
            found_inf = getattr(self, "found_inf", None)

            if not all(isinstance(t, torch.Tensor) for t in state_steps):
                raise RuntimeError(
                    "API has changed, `state_steps` argument must contain a list of singleton tensors"
                )

            # A group without gradients this step is skipped; the remaining
            # groups must still be updated and the loss still returned.
            if not params_with_grad:
                continue

            # AMP gradient scaling / inf tracking is not supported.
            assert grad_scale is None and found_inf is None

            for i, param in enumerate(params_with_grad):
                grad = grads[i] if not maximize else -grads[i]
                lambda_dash = lambda_dashs[i]
                exp_avg = exp_avgs[i]
                exp_avg_sq = exp_avg_sqs[i]
                max_exp_avg_sq = max_exp_avg_sqs[i] if amsgrad else None
                step_t = state_steps[i]

                if torch.is_complex(param):
                    # Treat complex tensors as real tensors with a trailing
                    # (real, imag) dimension; every buffer must be viewed alike.
                    grad = torch.view_as_real(grad)
                    exp_avg = torch.view_as_real(exp_avg)
                    exp_avg_sq = torch.view_as_real(exp_avg_sq)
                    if amsgrad:
                        max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)
                    param = torch.view_as_real(param)

                # update step (in place on the state tensor)
                step_t += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                step = _get_value(step_t)

                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                step_size = lr / bias_correction1

                bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
                else:
                    denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

                if apply_decay:
                    param_norm = torch.norm(param, p=2)
                    grad_norm = torch.norm(grad, p=2)
                    # Guard an all-zero parameter: the ratio would be inf/NaN
                    # and poison the weights; nothing needs decaying anyway.
                    lambda_curr = torch.where(
                        param_norm > 0,
                        grad_norm * decay / param_norm,
                        torch.zeros_like(param_norm),
                    )
                    # Update the persistent moving average in place so it
                    # carries over to the next step.
                    lambda_dash.copy_(0.1 * lambda_dash + 0.9 * lambda_curr)
                    # Decoupled, uniform multiplicative decay of the tensor.
                    param.mul_(1 - lr * lambda_dash)

                param.addcdiv_(exp_avg, denom, value=-step_size)  # param = param + exp_avg / denom * (-step_size)

        return loss


def group_testbed_parameters(model, optim_hps, avoid_keywords,
                             bias_regularization=False,
                             normalization_regularization=False):  # TODO I SHOULD USE THIS AS WELL
    """Split ``model``'s trainable parameters into optimizer parameter groups.

    Every (module, parameter) pair from ``model.named_modules()`` whose
    parameter is trainable is classified by the first matching rule:

    * the parameter has an ``_optim`` attribute -> "special";
    * its name ends with ``bias`` and ``bias_regularization`` is False -> no decay;
    * its full name contains any of ``avoid_keywords`` -> no decay;
    * it is the ``weight`` of an ``nn.Linear`` / ``nn.Conv2d`` -> decay;
    * its module is an ``nn.Embedding`` or, unless ``normalization_regularization``
      is True, a normalization layer -> no decay.

    Any remaining trainable parameter is put in the decay set.

    Args:
        model: the ``nn.Module`` whose parameters are grouped.
        optim_hps: hyperparameters copied into the decay / no-decay groups.
        avoid_keywords: substrings of full parameter names to exclude from
            decay; ``None`` or empty means none.
        bias_regularization: if True, biases are not excluded automatically.
        normalization_regularization: if True, normalization-layer parameters
            are not excluded automatically.

    Returns:
        A list of parameter-group dicts. If nothing is excluded from decay, a
        single group holds all non-special parameters. Otherwise the first
        group has ``apply_decay=True`` and the second has ``apply_decay=False``
        and ``weight_decay=0.0``. These groups also carry a sorted ``names``
        list. One extra group per distinct ``_optim`` dict follows, in a
        deterministic order, as ``{"params": [...], **_optim}``.
    """
    if not avoid_keywords:
        avoid_keywords = []

    apply_decay = set()
    apply_no_decay = set()
    special = set()
    whitelist_weight_modules = (nn.Linear, nn.Conv2d)
    blacklist_weight_modules = (nn.Embedding,)
    if not normalization_regularization:
        blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                                     nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,
                                     nn.GroupNorm, nn.SyncBatchNorm,
                                     nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                                     nn.LayerNorm, nn.LocalResponseNorm)

    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn  # full param name
            # In case of parameter sharing, some parameters show up here but are not in
            # param_dict.keys()
            if not p.requires_grad or fpn not in param_dict:
                continue  # frozen weights
            if hasattr(p, '_optim'):
                special.add(fpn)
            elif not bias_regularization and pn.endswith('bias'):
                apply_no_decay.add(fpn)
            elif any(keyword in fpn for keyword in avoid_keywords):
                apply_no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                apply_decay.add(fpn)
            elif isinstance(m, blacklist_weight_modules):
                apply_no_decay.add(fpn)

    apply_decay |= (param_dict.keys() - apply_no_decay - special)

    # validate that we considered every parameter
    inter_params = apply_decay & apply_no_decay
    union_params = apply_decay | apply_no_decay
    unassigned = param_dict.keys() - special - union_params
    assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both apply_decay/apply_no_decay sets!"
    assert len(unassigned) == 0, f"parameters {str(unassigned)}  were not separated into either apply_decay/apply_no_decay set!"

    decay_names = sorted(apply_decay)
    no_decay_names = sorted(apply_no_decay)
    if not apply_no_decay:
        param_groups = [{"params": [param_dict[pn] for pn in decay_names],
                         "names": decay_names, **optim_hps}]
    else:
        no_wd_optim_hps = {**optim_hps, "weight_decay": 0.0}

        param_groups = [
            {"params": [param_dict[pn] for pn in decay_names],
             "names": decay_names, "apply_decay": True, **optim_hps},
            # apply_decay goes last so a global ``apply_decay`` in optim_hps
            # cannot re-enable decay on the excluded parameters.
            {"params": [param_dict[pn] for pn in no_decay_names],
             "names": no_decay_names, **no_wd_optim_hps, "apply_decay": False},
        ]

    # Parameters with special hyperparameters: one group per distinct _optim
    # dict, ordered by first appearance in sorted-name order. Iterating a set
    # here would make group order hash-seed dependent, and optimizer
    # state_dict reloading matches groups by position.
    special_names = sorted(special)
    unique_hps = []
    for pn in special_names:
        hp = param_dict[pn]._optim
        if hp not in unique_hps:
            unique_hps.append(hp)
    for hp in unique_hps:
        params = [param_dict[pn] for pn in special_names if param_dict[pn]._optim == hp]
        param_groups.append({"params": params, **hp})

    return param_groups
