import torch


class AdamDecouple(torch.optim.Optimizer):
    """Adam variant that can keep separate moment estimates per update mode.

    Every call to :meth:`step` runs in one of two modes, ``"forget"`` or
    ``"retain"``. When ``decouple_m`` / ``decouple_v`` are true, the first
    (``m1``/``m2``) and second (``v1``/``v2``) moment buffers are kept per
    mode; otherwise a single shared buffer (``m`` / ``v``) is used.

    Weight decay is applied as an L2 term added to the gradient before the
    moment updates.
    """

    def __init__(
        self,
        params,
        lr=0.001,
        beta1=0.9,
        beta2=0.999,
        weight_decay=0,
        eps=1e-8,
        decouple_m=True,
        decouple_v=True,
        lr_forget_ratio=1.0,
        lr_retain_ratio=1.0,
    ):
        """
        lr is uniform so that scheduler can be applied
        we adjust the learning rate for forget and retain by the ratio

        Args:
            params: iterable of parameters or param-group dicts.
            lr: base learning rate shared by both modes.
            beta1: decay rate of the first-moment estimate, in ``[0, 1)``.
            beta2: decay rate of the second-moment estimate, in ``[0, 1)``.
            weight_decay: L2 coefficient added to the gradient.
            eps: constant added to the denominator.
            decouple_m: keep one first-moment buffer per mode.
            decouple_v: keep one second-moment buffer per mode.
            lr_forget_ratio: multiplier on ``lr`` for ``"forget"`` steps.
            lr_retain_ratio: multiplier on ``lr`` for all other steps.

        Raises:
            ValueError: if any numeric hyper-parameter is out of range.
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= beta1 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {beta1}")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {beta2}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= lr_forget_ratio:
            raise ValueError(f"Invalid lr_forget_ratio value: {lr_forget_ratio}")
        if not 0.0 <= lr_retain_ratio:
            raise ValueError(f"Invalid lr_retain_ratio value: {lr_retain_ratio}")
        defaults = dict(
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            decouple_m=decouple_m,
            decouple_v=decouple_v,
            lr_forget_ratio=lr_forget_ratio,
            lr_retain_ratio=lr_retain_ratio,
        )
        super().__init__(params, defaults)

        for group in self.param_groups:
            for p in group["params"]:
                self._init_state(group, p)

        # Mode used by step() when no explicit mode is passed.
        self.mode = "forget"

    def _init_state(self, group, p):
        """Create zeroed moment buffers and step counters for ``p``."""
        state = self.state[p]
        if group["decouple_m"]:
            state["m1"] = torch.zeros_like(p.data)
            state["m2"] = torch.zeros_like(p.data)
        else:
            state["m"] = torch.zeros_like(p.data)
        if group["decouple_v"]:
            state["v1"] = torch.zeros_like(p.data)
            state["v2"] = torch.zeros_like(p.data)
        else:
            state["v"] = torch.zeros_like(p.data)
        state["t"] = 0  # total number of updates
        state["t1"] = 0  # number of "forget" updates
        state["t2"] = 0  # number of non-"forget" updates
        return state

    def set_mode(self, mode):
        """Select the mode used by subsequent :meth:`step` calls.

        Raises:
            ValueError: if ``mode`` is neither ``"forget"`` nor ``"retain"``.
        """
        if mode not in ("forget", "retain"):
            raise ValueError(f"Invalid mode: {mode}")
        self.mode = mode

    @torch.no_grad()
    def step(self, closure=None, mode=None):
        """
        switch mode to forget or retain

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.
            mode: mode for this call only. ``None`` (the default) uses
                ``self.mode``. Any value other than ``"forget"`` is treated
                as the retain branch.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        if mode is None:
            mode = self.mode
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        forget = mode == "forget"
        t_key = "t1" if forget else "t2"

        for group in self.param_groups:
            lr = group["lr"]
            beta1 = group["beta1"]
            beta2 = group["beta2"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            decouple_m = group["decouple_m"]
            decouple_v = group["decouple_v"]

            ratio = group["lr_forget_ratio"] if forget else group["lr_retain_ratio"]
            m_key = ("m1" if forget else "m2") if decouple_m else "m"
            v_key = ("v1" if forget else "v2") if decouple_v else "v"

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if weight_decay != 0:
                    # Out-of-place so the caller's p.grad is left untouched.
                    grad = grad.add(p, alpha=weight_decay)

                state = self.state[p]
                if "t" not in state:
                    # Parameter gained state-less entry after construction.
                    state = self._init_state(group, p)

                m = state[m_key]
                v = state[v_key]

                state["t"] += 1
                state[t_key] += 1
                # Per-mode buffers are corrected with the per-mode count,
                # shared buffers with the total count.
                t_m = state[t_key] if decouple_m else state["t"]
                t_v = state[t_key] if decouple_v else state["t"]

                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** t_m
                bias_correction2 = 1 - beta2 ** t_v

                step_size = lr / bias_correction1
                denom = (v.sqrt() / (bias_correction2 ** 0.5)).add_(eps)

                p.addcdiv_(m, denom, value=-ratio * step_size)
        return loss


class AdamWDecouple(torch.optim.Optimizer):
    """AdamW variant that can keep separate moment estimates per update mode.

    :meth:`step` alternates between the ``"forget"`` and ``"retain"`` modes
    based on an internal call counter ``self.t`` (even calls are
    ``"forget"``, odd calls are ``"retain"``). When ``decouple_m`` /
    ``decouple_v`` are true, the first (``m1``/``m2``) and second
    (``v1``/``v2``) moment buffers are kept per mode; otherwise a single
    shared buffer (``m`` / ``v``) is used.

    Weight decay is decoupled: the parameter is scaled by
    ``1 - ratio * lr * weight_decay`` instead of adding an L2 term to the
    gradient.
    """

    def __init__(
        self,
        params,
        lr=0.001,
        beta1=0.9,
        beta2=0.999,
        weight_decay=0,
        eps=1e-8,
        decouple_m=True,
        decouple_v=True,
        lr_forget_ratio=1.0,
        lr_retain_ratio=1.0,
    ):
        """
        lr is uniform so that scheduler can be applied
        we adjust the learning rate for forget and retain by the ratio

        Args:
            params: iterable of parameters or param-group dicts.
            lr: base learning rate shared by both modes.
            beta1: decay rate of the first-moment estimate, in ``[0, 1)``.
            beta2: decay rate of the second-moment estimate, in ``[0, 1)``.
            weight_decay: decoupled weight-decay coefficient.
            eps: constant added to the denominator.
            decouple_m: keep one first-moment buffer per mode.
            decouple_v: keep one second-moment buffer per mode.
            lr_forget_ratio: multiplier on ``lr`` for ``"forget"`` steps.
            lr_retain_ratio: multiplier on ``lr`` for ``"retain"`` steps.

        Raises:
            ValueError: if any numeric hyper-parameter is out of range.
        """
        # Number of step() calls so far; its parity selects the mode.
        self.t = 0
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= beta1 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {beta1}")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {beta2}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= lr_forget_ratio:
            raise ValueError(f"Invalid lr_forget_ratio value: {lr_forget_ratio}")
        if not 0.0 <= lr_retain_ratio:
            raise ValueError(f"Invalid lr_retain_ratio value: {lr_retain_ratio}")
        defaults = dict(
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            decouple_m=decouple_m,
            decouple_v=decouple_v,
            lr_forget_ratio=lr_forget_ratio,
            lr_retain_ratio=lr_retain_ratio,
        )
        super().__init__(params, defaults)

        for group in self.param_groups:
            for p in group["params"]:
                self._init_state(group, p)

        self.mode = "forget"

    def _init_state(self, group, p):
        """Create zeroed moment buffers and step counters for ``p``."""
        state = self.state[p]
        if group["decouple_m"]:
            state["m1"] = torch.zeros_like(p.data)
            state["m2"] = torch.zeros_like(p.data)
        else:
            state["m"] = torch.zeros_like(p.data)
        if group["decouple_v"]:
            state["v1"] = torch.zeros_like(p.data)
            state["v2"] = torch.zeros_like(p.data)
        else:
            state["v"] = torch.zeros_like(p.data)
        state["t"] = 0  # total number of updates
        state["t1"] = 0  # number of "forget" updates
        state["t2"] = 0  # number of "retain" updates
        return state

    def set_mode(self, mode):
        """Store ``mode`` on the optimizer.

        NOTE: :meth:`step` recomputes ``self.mode`` from the parity of
        ``self.t`` on every call, so the value stored here is overwritten
        by the next ``step()``.
        """
        assert mode in ["forget", "retain"]
        self.mode = mode

    @torch.no_grad()
    def step(self, closure=None):
        """
        switch mode to forget or retain

        The mode is derived from the call counter: even-numbered calls
        (starting at 0) run in ``"forget"`` mode, odd-numbered calls in
        ``"retain"`` mode.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        self.mode = "forget" if self.t % 2 == 0 else "retain"  # Set mode based on t
        mode = self.mode
        self.t += 1  # Increment the counter

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        forget = mode == "forget"
        t_key = "t1" if forget else "t2"

        for group in self.param_groups:
            lr = group["lr"]
            beta1 = group["beta1"]
            beta2 = group["beta2"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            decouple_m = group["decouple_m"]
            decouple_v = group["decouple_v"]

            ratio = group["lr_forget_ratio"] if forget else group["lr_retain_ratio"]
            m_key = ("m1" if forget else "m2") if decouple_m else "m"
            v_key = ("v1" if forget else "v2") if decouple_v else "v"

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                # Only create state when it is missing. Re-creating it on
                # every call would zero the moments and counters, so nothing
                # would ever accumulate across steps.
                if "t" not in state:
                    state = self._init_state(group, p)

                m = state[m_key]
                v = state[v_key]

                state["t"] += 1
                state[t_key] += 1
                # Per-mode buffers are corrected with the per-mode count,
                # shared buffers with the total count.
                t_m = state[t_key] if decouple_m else state["t"]
                t_v = state[t_key] if decouple_v else state["t"]

                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** t_m
                bias_correction2 = 1 - beta2 ** t_v

                step_size = lr / bias_correction1
                denom = (v.sqrt() / (bias_correction2 ** 0.5)).add_(eps)

                if weight_decay != 0:
                    # Decoupled weight decay, scaled by the mode's lr ratio.
                    p.mul_(1 - ratio * lr * weight_decay)

                p.addcdiv_(m, denom, value=-ratio * step_size)

        return loss


class OptimizedAdamWDecouple(torch.optim.Optimizer):
    """AdamW-style optimizer with one shared moment pair and two lr ratios.

    :meth:`step` alternates between the ``"forget"`` and ``"retain"`` modes
    based on an internal call counter ``self.t`` (even calls are
    ``"forget"``, odd calls are ``"retain"``). Both modes update the same
    ``m`` / ``v`` buffers; only the learning-rate ratio differs per mode.

    Weight decay is decoupled: the parameter is scaled by
    ``1 - ratio * lr * weight_decay`` instead of adding an L2 term to the
    gradient.
    """

    def __init__(
        self,
        params,
        lr=0.001,
        beta1=0.9,
        beta2=0.999,
        weight_decay=0,
        eps=1e-8,
        lr_forget_ratio=1.0,
        lr_retain_ratio=1.0,
    ):
        """
        A memory-efficient version of AdamWDecouple.
        Instead of maintaining separate moment estimates for forget/retain modes,
        we use a single set of moments and switch modes dynamically.

        Args:
            params: iterable of parameters or param-group dicts.
            lr: base learning rate shared by both modes.
            beta1: decay rate of the first-moment estimate, in ``[0, 1)``.
            beta2: decay rate of the second-moment estimate, in ``[0, 1)``.
            weight_decay: decoupled weight-decay coefficient.
            eps: constant added to the denominator.
            lr_forget_ratio: multiplier on ``lr`` for ``"forget"`` steps.
            lr_retain_ratio: multiplier on ``lr`` for ``"retain"`` steps.

        Raises:
            ValueError: if any numeric hyper-parameter is out of range.
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= beta1 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {beta1}")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {beta2}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= lr_forget_ratio:
            raise ValueError(f"Invalid lr_forget_ratio value: {lr_forget_ratio}")
        if not 0.0 <= lr_retain_ratio:
            raise ValueError(f"Invalid lr_retain_ratio value: {lr_retain_ratio}")

        defaults = dict(
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            lr_forget_ratio=lr_forget_ratio,
            lr_retain_ratio=lr_retain_ratio,
        )
        super().__init__(params, defaults)

        self.t = 0  # Number of step() calls so far; its parity selects the mode
        self.mode = "forget"

        for group in self.param_groups:
            for p in group["params"]:
                if p.requires_grad:
                    self._init_state(p)

    def _init_state(self, p):
        """Create the zeroed shared moment buffers and counters for ``p``."""
        state = self.state[p]
        state["m"] = torch.zeros_like(p.data)  # Single momentum buffer
        state["v"] = torch.zeros_like(p.data)  # Single variance buffer
        state["t_forget"] = 0  # Number of "forget" updates
        state["t_retain"] = 0  # Number of "retain" updates
        return state

    def set_mode(self, mode):
        """Store ``mode`` on the optimizer.

        NOTE: :meth:`step` recomputes ``self.mode`` from the parity of
        ``self.t`` on every call, so the value stored here is overwritten
        by the next ``step()``.
        """
        assert mode in ["forget", "retain"]
        self.mode = mode

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform an optimization step, alternating between forget and retain.

        Even-numbered calls (starting at 0) run in ``"forget"`` mode,
        odd-numbered calls in ``"retain"`` mode; ``self.mode`` is updated to
        the mode that was used.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        # Read the counter before incrementing so the first call is "forget",
        # matching the initial value of self.mode.
        mode = "forget" if self.t % 2 == 0 else "retain"
        self.mode = mode
        self.t += 1

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        forget = mode == "forget"
        t_key = "t_forget" if forget else "t_retain"

        for group in self.param_groups:
            lr = group["lr"]
            beta1 = group["beta1"]
            beta2 = group["beta2"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            ratio = group["lr_forget_ratio"] if forget else group["lr_retain_ratio"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                if "m" not in state:
                    # Parameter had no state at construction (for example it
                    # did not require grad then); create it on first use.
                    state = self._init_state(p)
                m, v = state["m"], state["v"]

                state[t_key] += 1
                # m and v are shared by both modes and updated on every step,
                # so bias correction uses the total number of updates.
                t = state["t_forget"] + state["t_retain"]

                # Update biased moment estimates
                m.mul_(beta1).add_(grad, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction
                bias_correction1 = 1 - beta1 ** t
                bias_correction2 = (1 - beta2 ** t) ** 0.5

                # Compute step size
                step_size = (lr / bias_correction1) * ratio
                denom = v.sqrt().div(bias_correction2).add_(eps)

                # Decoupled weight decay, scaled by the mode's lr ratio
                if weight_decay != 0:
                    p.mul_(1 - ratio * lr * weight_decay)

                # Parameter update
                p.addcdiv_(m, denom, value=-step_size)

        return loss
