import math
from typing import Callable, Optional

import attr
import torch
import torch.nn as nn
import torch.nn.functional as F

# Signature of a tensor -> tensor transform; used as the type of Affine.bias_filter_fn,
# which is applied to the bias on every forward call.
FilterFn = Callable[[torch.Tensor], torch.Tensor]


class ZeroKeyBiasGrad(torch.autograd.Function):
    """Identity in the forward pass; zeroes the middle third of the gradient in backward.

    The input's first dimension is split into three equal parts. Gradients for the first
    and last thirds pass through unchanged, and the gradient for the middle third is
    set to zero.

    NOTE(review): the name suggests the input is a packed query/key/value bias, with the
    middle third being the key bias. Confirm against callers.

    Raises:
        ValueError: if the input is 0-dimensional or its first dimension is not a
            multiple of 3. In that case ``chunk(3)`` would not yield three equal thirds,
            and the wrong slice, or none at all, would be zeroed.
    """

    @staticmethod
    def forward(ctx, x):
        # Validate here, where the error points at the offending call site, rather than
        # letting backward silently zero the wrong slice (e.g. length 4 -> chunks [2, 2]).
        if x.dim() == 0 or x.shape[0] % 3 != 0:
            raise ValueError(
                "ZeroKeyBiasGrad expects dim 0 to be a multiple of 3, "
                f"got shape {tuple(x.shape)}"
            )
        return x

    @staticmethod
    def backward(ctx, output_grad):
        # Clone so the incoming gradient buffer is never mutated in place.
        output_grad = output_grad.clone()
        # chunk(3) along dim 0; index 1 is the middle third (a view, so zero_ writes through).
        output_grad.chunk(3)[1].zero_()
        return output_grad


def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor:
    """Run ``x`` through ``ZeroKeyBiasGrad``.

    The forward value is ``x`` unchanged. During backward, the gradient's middle third
    along dim 0 is zeroed.
    """
    apply_fn = ZeroKeyBiasGrad.apply
    return apply_fn(x)


@attr.s(eq=False, repr=False)
class LayerNorm(nn.Module):
    """Layer normalization over a trailing dimension of size ``n_state``, computed in float32.

    The gain ``g`` (initialized to ones) and bias ``b`` (initialized to zeros) are
    float32 parameters of shape ``(n_state,)``. Both are tagged with
    ``weight_decay_level = "disable"``. NOTE(review): presumably an optimizer-setup
    routine elsewhere reads this tag; confirm.

    The input is cast to float32 before normalizing, so the output is float32
    regardless of the input dtype.

    The default ``device`` is CUDA, so CPU-only use requires passing ``device``.
    """

    n_state: int = attr.ib()
    eps: float = attr.ib(default=1e-6)
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()
        param_shape = (self.n_state,)
        self.g = nn.Parameter(torch.ones(param_shape, dtype=torch.float32, device=self.device))
        self.b = nn.Parameter(torch.zeros(param_shape, dtype=torch.float32, device=self.device))
        # Exempt both affine parameters from weight decay (g first, then b).
        for param in (self.g, self.b):
            param.weight_decay_level = "disable"  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Always normalize in full precision.
        x_fp32 = x.type(torch.float32)
        return F.layer_norm(
            x_fp32,
            (self.n_state,),
            weight=self.g,
            bias=self.b,
            eps=self.eps,
        )


@attr.s(eq=False, repr=False)
class Affine(nn.Module):
    """Fully connected layer ``F.linear(x, w, b)`` whose parameters follow the input dtype.

    Parameters are stored in float32. ``w`` has shape ``(n_out, n_in)``. When
    ``use_bias`` is set, ``b`` has shape ``(n_out,)``. On each forward call they are
    converted to ``x.dtype`` (via ``.to``, which returns a new tensor and leaves the
    stored parameters untouched). ``bias_filter_fn`` is then applied to the converted
    bias.

    Initialization modes:
      * ``use_admnet_init=False`` (default): ``b`` starts at zero. ``self.std`` is
        resolved to ``std``, or ``sqrt(2 / (n_in + n_out))`` when ``std`` is None,
        then multiplied by ``extra_init_scale`` when that is given.
      * ``use_admnet_init=True``: ``w`` and ``b`` are both left uninitialized. Passing
        ``extra_init_scale`` raises ``ValueError``.

    NOTE(review): in both modes ``w`` is allocated with ``torch.empty`` and never filled
    in this class. The resolved ``self.std`` is not applied here. Presumably weights
    are loaded from a checkpoint or initialized by external code that reads
    ``self.std``; confirm before training from scratch.

    The default ``device`` is CUDA, so CPU-only use requires passing ``device``.
    """

    n_in: int = attr.ib()
    n_out: int = attr.ib()
    use_bias: bool = attr.ib(default=True)
    use_admnet_init: bool = attr.ib(default=False)
    std: Optional[float] = attr.ib(default=None)
    extra_init_scale: Optional[float] = attr.ib(default=None)
    bias_filter_fn: FilterFn = attr.ib(default=lambda x: x)
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        if self.use_admnet_init:
            if self.extra_init_scale is not None:
                raise ValueError("extra_init_scale incompatible with admnet init")
        else:
            # Resolve the init scale; the default is the Glorot/Xavier-normal formula.
            resolved_std = self.std
            if resolved_std is None:
                resolved_std = math.sqrt(2 / (self.n_in + self.n_out))
            if self.extra_init_scale is not None:
                resolved_std = resolved_std * self.extra_init_scale
            self.std = resolved_std

        tensor_opts = dict(dtype=torch.float32, device=self.device)
        # Weight is registered before bias in both modes, preserving parameter order.
        self.w = nn.Parameter(torch.empty((self.n_out, self.n_in), **tensor_opts))

        if self.use_bias:
            bias_factory = torch.empty if self.use_admnet_init else torch.zeros
            self.b = nn.Parameter(bias_factory((self.n_out,), **tensor_opts))
            self.b.weight_decay_level = "disable"  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        target_dtype = x.dtype
        weight = self.w if self.w.dtype == target_dtype else self.w.to(target_dtype)

        bias = None
        if self.use_bias:
            cast_bias = self.b if self.b.dtype == target_dtype else self.b.to(target_dtype)
            bias = self.bias_filter_fn(cast_bias)

        return F.linear(x, weight, bias)
