import math
import torch
from torch.optim import Optimizer

import time
class AdamW(Optimizer):
    """AdamW with support for CompAct (compressed-gradient) parameter groups.

    Parameter groups are handled in one of two ways:

    * Ordinary groups run an AdamW update on ``p.grad`` followed by
      decoupled weight decay.
    * Groups with ``"compact": True`` must also carry ``"module_name"`` (the
      qualified name of a submodule of ``model``, as reported by
      ``model.named_modules()``) and ``"alpha"`` (a scale factor). For those,
      the Adam moments are tracked on the owning module's ``hat_G`` tensor
      instead of ``p.grad``. The normalized step is mapped back through
      ``module.P`` as ``alpha * P @ rho`` and applied, transposed, to the
      parameter. ``module.hat_G`` is set to ``None`` once consumed; a group
      whose module has ``hat_G is None`` is skipped for that step.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        betas: Decay rates for the running averages of the gradient and of
            its elementwise square.
        eps: Term added to the denominator for numerical stability.
        weight_decay: Decoupled weight-decay coefficient (applied as
            ``p *= 1 - lr * weight_decay`` after the gradient step).
        correct_bias: Whether to scale the step size by the Adam bias
            correction ``sqrt(1 - beta2**t) / (1 - beta1**t)``.
        model: Model owning the CompAct modules; required.

    Raises:
        ValueError: If ``model`` is missing or a hyperparameter is out of range.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
        correct_bias=True,
        model=None,
    ):
        if model is None:
            raise ValueError("model must be provided for CompAct optimizer")
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        # beta == 1 would make the bias correction divide by zero.
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            correct_bias=correct_bias,
        )
        super().__init__(params, defaults)
        self.model = model

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step over every parameter group.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is run with autograd enabled so it may call
                ``backward()``.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.

        Raises:
            RuntimeError: For sparse gradients, or for a compact group whose
                module name, module, ``hat_G`` attribute, or ``P`` is missing.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure normally needs
            # autograd to compute gradients.
            with torch.enable_grad():
                loss = closure()

        # Name -> module map, built at most once per step and only if some
        # compact group actually needs it (avoids a linear scan per parameter).
        module_map = None

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                state.setdefault("step", 0)

                if group.get("compact", False):
                    if module_map is None:
                        module_map = dict(self.model.named_modules())
                    self._compact_step(p, group, state, module_map)
                else:
                    self._dense_step(p, group, state)
        return loss

    def _compact_step(self, p, group, state, module_map):
        """Apply one CompAct update to ``p`` using its module's ``hat_G`` and ``P``."""
        module_name = group.get("module_name")
        if module_name is None:
            raise RuntimeError("Compact group missing module_name!")
        module = module_map.get(module_name)
        if module is None:
            raise RuntimeError(f"Cannot find module {module_name} in model!")
        if not hasattr(module, 'hat_G'):
            raise RuntimeError(f"Module {module_name} doesn't have hat_G!")

        hat_G = module.hat_G
        if hat_G is None:
            # No compressed gradient this step: skip without advancing the
            # step counter, so bias correction is not skewed by skipped steps.
            return

        P = getattr(module, "P", None)
        if P is None:
            raise RuntimeError(f"Module {module_name} doesn't have P!")

        state["step"] += 1
        exp_avg, denom = self._update_moments(state, hat_G, group)
        rho = exp_avg / denom
        step_size = self._step_size(group, state["step"])

        # NOTE(review): assumes torch.mm(P, rho) has the shape of p.t().
        update_full = group["alpha"] * torch.mm(P, rho)
        p.add_(update_full.t(), alpha=-step_size)
        self._apply_weight_decay(p, group)

        # The compressed gradient has been consumed; drop the reference.
        module.hat_G = None

    def _dense_step(self, p, group, state):
        """Apply one standard AdamW update to ``p`` from ``p.grad``."""
        grad = p.grad
        if grad.is_sparse:
            raise RuntimeError("AdamW does not support sparse gradients")

        state["step"] += 1
        exp_avg, denom = self._update_moments(state, grad, group)
        step_size = self._step_size(group, state["step"])
        p.addcdiv_(exp_avg, denom, value=-step_size)
        self._apply_weight_decay(p, group)

    @staticmethod
    def _update_moments(state, grad, group):
        """Fold ``grad`` into the running moments in ``state``.

        Returns:
            ``(exp_avg, denom)`` where ``denom = sqrt(exp_avg_sq) + eps``.
        """
        # Allocate lazily: building zeros on every call (e.g. via
        # dict.setdefault) would waste a full-size tensor per step.
        if "exp_avg" not in state:
            state["exp_avg"] = torch.zeros_like(grad)
            state["exp_avg_sq"] = torch.zeros_like(grad)
        exp_avg = state["exp_avg"]
        exp_avg_sq = state["exp_avg_sq"]

        beta1, beta2 = group["betas"]
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])
        return exp_avg, denom

    @staticmethod
    def _step_size(group, step):
        """Return the learning rate, bias-corrected when ``correct_bias`` is set."""
        step_size = group["lr"]
        if group["correct_bias"]:
            beta1, beta2 = group["betas"]
            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
        return step_size

    @staticmethod
    def _apply_weight_decay(p, group):
        """Decoupled weight decay: ``p <- p * (1 - lr * weight_decay)``."""
        if group["weight_decay"] > 0.0:
            p.add_(p, alpha=-group["lr"] * group["weight_decay"])