from typing import Iterable
import torch
from torch.optim._multi_tensor import SGD
import torch.optim
from . import LegacyFairseqOptimizer, register_optimizer

# Public API for ``from <module> import *``; every listed name must exist in
# this module, otherwise the star-import raises AttributeError.
__all__ = ["FairseqSAMConfig", "FairseqSAM", "Adaptive_SAM"]

import logging
import math
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import Any, List
import random
import copy
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.optim
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer
from fairseq.optim.fused_adam import get_fused_adam_class
from omegaconf import II, OmegaConf


# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


@dataclass
class FairseqSAMConfig(FairseqDataclass):
    """Configuration for the ``samsgd`` optimizer (SAM wrapped around Adam).

    Only ``rho`` and the Adam-related fields (``adam_betas``, ``adam_eps``,
    ``amsgrad``, ``weight_decay``, ``lr``) are passed to the optimizer by
    ``FairseqSAM``; the remaining fields are accepted but not forwarded.
    """

    # SGD-style options; FairseqSAM builds an Adam base optimizer and does not
    # forward these.
    momentum: float = field(
        default=0,
        metadata={"help": "SGD momentum factor (not passed to the optimizer by FairseqSAM)"},
    )
    dampening: float = field(
        default=0,
        metadata={"help": "SGD momentum dampening (not passed to the optimizer by FairseqSAM)"},
    )
    adam_betas: Any = field(
        default=(0.9, 0.999), metadata={"help": "betas for Adam optimizer"}
    )
    adam_eps: float = field(
        default=1e-8, metadata={"help": "epsilon for Adam optimizer"}
    )
    # L2 radius of the SAM weight perturbation (used by Adaptive_SAM.first_step).
    rho: float = field(
        default=0.05,
        metadata={"help": "SAM neighborhood size: L2 radius of the weight perturbation"},
    )
    # The following hyper-parameters are not forwarded to the optimizer by
    # FairseqSAM; their intended meaning is not established here.
    beta: float = field(
        default=0.5,
        metadata={"help": "beta hyperparameter (not passed to the optimizer by FairseqSAM)"},
    )
    gamma: float = field(
        default=0.5,
        metadata={"help": "gamma hyperparameter (not passed to the optimizer by FairseqSAM)"},
    )
    num_samples: int = field(
        default=32,
        metadata={"help": "num_samples hyperparameter (not passed to the optimizer by FairseqSAM)"},
    )
    keep_ratio: float = field(
        default=0.1,
        metadata={"help": "keep_ratio hyperparameter (not passed to the optimizer by FairseqSAM)"},
    )
    mask_iter_e: float = field(
        default=1,
        metadata={"help": "mask_iter_e hyperparameter (not passed to the optimizer by FairseqSAM)"},
    )
    amsgrad: bool = field(
        default=False, metadata={"help": "whether to use the AMSGrad variant of Adam"}
    )
    sam_type: str = field(
        default="sam",
        metadata={"help": "SAM variant name (not passed to the optimizer by FairseqSAM)"},
    )
    weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
    nesterov: bool = field(
        default=False,
        metadata={"help": "SGD Nesterov momentum (not passed to the optimizer by FairseqSAM)"},
    )
    # TODO common vars below in parent
    tpu: bool = II("common.tpu")
    lr: List[float] = II("optimization.lr")


@register_optimizer("samsgd", dataclass=FairseqSAMConfig)
class FairseqSAM(FairseqOptimizer):
    """Sharpness-Aware Minimization (SAM) optimizer for fairseq.

    Wraps :class:`Adaptive_SAM` around ``torch.optim.Adam``. Each step
    perturbs the weights along the normalized gradient (radius ``cfg.rho``),
    recomputes the gradient there, restores the weights and lets Adam apply
    that gradient. Despite the registered name ``samsgd``, the base optimizer
    is Adam, and weight decay follows ``torch.optim.Adam`` (an L2 term added
    to the gradient), not the decoupled AdamW variant.
    """

    def __init__(self, cfg: FairseqSAMConfig, params, model=None):
        super().__init__(cfg)
        # Adam performs the actual update; optimizer_config supplies its kwargs.
        base_optimizer = torch.optim.Adam
        self._optimizer = self.create_optimizer(
            cfg, params, base_optimizer, model=model, **self.optimizer_config
        )

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.

        The same kwargs are forwarded to the base ``torch.optim.Adam``.
        """
        if isinstance(self.cfg.adam_betas, str):
            # NOTE(review): eval() executes arbitrary code from the config
            # string; this is only safe while the config comes from a trusted
            # source.
            betas = eval(self.cfg.adam_betas)
        else:
            betas = OmegaConf.to_container(self.cfg.adam_betas)
        return {
            "weight_decay": self.cfg.weight_decay,
            "lr": self.cfg.lr[0]
            if isinstance(self.cfg.lr, Collection)
            else self.cfg.lr,
            "betas": betas,
            "eps": self.cfg.adam_eps,
            "amsgrad": self.cfg.amsgrad,
        }

    def average_params(self):
        """Reduce Params is only used during BMUF distributed training.

        Averages Adam's first/second moment estimates across all workers in
        place (divide by world size, then all-reduce with SUM).
        """
        # The moments live in the wrapped base optimizer's state; the SAM
        # wrapper's own state only holds weight snapshots ("old_p"), so reading
        # "exp_avg" from it would raise KeyError.
        inner = getattr(self.optimizer, "base_optimizer", self.optimizer)
        total_gpus = float(dist.get_world_size())

        for value in inner.state.values():
            value["exp_avg"] /= total_gpus
            value["exp_avg_sq"] /= total_gpus
            dist.all_reduce(value["exp_avg"], op=dist.ReduceOp.SUM)
            dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM)

    def create_optimizer(self, args, params, base_optimizer, model=None, **kwargs):
        """Build the SAM wrapper.

        Args:
            args: the optimizer config; only ``args.rho`` is read here.
            params: parameters (or param groups) to optimize.
            base_optimizer: optimizer class performing the actual update.
            model: accepted for interface compatibility; unused.
            **kwargs: forwarded to ``Adaptive_SAM`` and on to ``base_optimizer``.
        """
        return Adaptive_SAM(params, base_optimizer=base_optimizer, rho=args.rho, **kwargs)

class Adaptive_SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    Each :meth:`step` runs two forward/backward passes. The first gives the
    gradient ``g`` at the current weights ``w``. Each parameter ``p`` with a
    gradient is then moved to ``w + e(w)`` with
    ``e = rho * g / ||g||``. When ``adaptive=True``, ``e`` is multiplied
    element-wise by ``p**2`` and the norm is taken over ``|p| * g``. The
    second pass computes the gradient at the perturbed point. The weights are
    then restored to ``w``, and the base optimizer applies that gradient.

    ``param_groups`` is shared with the base optimizer, so hyper-parameter
    changes made through this wrapper (e.g. learning-rate schedules) reach the
    base optimizer.

    NOTE(review): ``state_dict()`` only covers this wrapper's own ``state``,
    not the base optimizer's state (e.g. Adam moments), so those are not
    stored in checkpoints -- confirm whether that is intended.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        """
        Args:
            params: iterable of parameters or param-group dicts.
            base_optimizer: optimizer *class* that performs the update; it is
                instantiated on this wrapper's param groups with ``**kwargs``.
            rho: perturbation radius; must be non-negative.
            adaptive: scale the perturbation element-wise by parameter
                magnitudes (see class docstring).
            **kwargs: forwarded to ``base_optimizer`` and stored as defaults.
        """
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(Adaptive_SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share the very same group dicts so both optimizers see one set of
        # hyper-parameters.
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Move every parameter that has a gradient to ``w + e(w)``.

        The unperturbed weights are saved in ``self.state[p]["old_p"]`` so
        :meth:`second_step` can restore them.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the weights saved by :meth:`first_step`, then step the base optimizer."""
        for group in self.param_groups:
            for p in group["params"]:
                # Restore exactly the parameters perturbed in first_step, even
                # those with no gradient in the second pass. Popping the
                # snapshot frees it and prevents a stale copy from being
                # reused on a later step.
                old_p = self.state.get(p, {}).pop("old_p", None)
                if old_p is None:
                    continue
                p.data.copy_(old_p)  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @staticmethod
    def _closure_loss(output):
        """Return the loss from a closure result: either ``loss`` or ``(loss, ...)``."""
        if isinstance(output, (tuple, list)):
            return output[0]
        return output

    @torch.no_grad()
    def step(self, closure=None, loss_before=None, input_samples=None):
        """Perform one SAM update and return the (detached) loss at ``w``.

        Args:
            closure: callable running a full forward/backward pass; it may
                return the loss or a tuple whose first element is the loss.
            loss_before: if given, the caller has already run the first
                forward/backward pass at ``w`` and this is its loss; the first
                closure call is skipped.
            input_samples: accepted for interface compatibility; unused.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        if loss_before is not None:
            loss = loss_before.detach()
        else:
            loss = self._closure_loss(closure()).detach()

        self.first_step(zero_grad=True)
        closure()
        self.second_step()
        return loss

    def load_state_dict(self, state_dict):
        """Load state and re-share ``param_groups`` with the base optimizer.

        ``torch.optim.Optimizer.load_state_dict`` rebuilds ``self.param_groups``
        as new dicts. Without re-sharing, later hyper-parameter updates made
        through this wrapper would no longer reach the base optimizer.
        """
        super(Adaptive_SAM, self).load_state_dict(state_dict)
        # A checkpointed perturbation snapshot is never valid after loading.
        for param_state in self.state.values():
            param_state.pop("old_p", None)
        # Fill any base-optimizer keys missing from the loaded groups (e.g. a
        # checkpoint written by another torch version) before sharing them.
        for group in self.param_groups:
            for key, value in self.base_optimizer.defaults.items():
                group.setdefault(key, value)
        self.base_optimizer.param_groups = self.param_groups

    def _grad_norm(self):
        """Return the global L2 norm of the (optionally |p|-scaled) gradients.

        Per-parameter norms are moved to the first parameter's device before
        combining, in case of model parallelism. Returns a zero tensor when no
        parameter has a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device
        norms = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not norms:
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)