import warnings

import torch
from beartype import beartype
from torch.autograd import Function
from torch import nn, Tensor
import torch.nn.functional as F

# Type aliases used as beartype annotations in this file. Each one is the union
# of the CPU and the CUDA variant of a legacy typed-tensor class.
# NOTE(review): presumably `isinstance` against these classes depends on both
# the dtype and the device of the tensor -- confirm against the PyTorch docs.
BF16 = torch.BFloat16Tensor | torch.cuda.BFloat16Tensor
# `Int16` is not referenced by any annotation in the visible part of this file.
Int16 = torch.ShortTensor | torch.cuda.ShortTensor
FP32 = torch.FloatTensor | torch.cuda.FloatTensor


@beartype
@torch.compiler.disable
def bit_mask_bf16(x: BF16, bit_mask: BF16) -> BF16:
    """AND the raw 16-bit patterns of two bfloat16 tensors.

    Both tensors are reinterpreted (not converted) as int16, combined with a
    bitwise AND, and the result is reinterpreted as bfloat16 again.

    Args:
        x: bfloat16 tensor whose bits are to be masked.
        bit_mask: bfloat16 tensor whose raw bits act as the mask.

    Returns:
        A bfloat16 tensor holding ``bits(x) & bits(bit_mask)``.
    """
    # NOTE(review): excluded from torch.compile via the decorator above;
    # presumably because of the dtype re-interpreting views -- confirm.
    raw_bits = x.view(torch.int16)
    raw_mask = bit_mask.view(torch.int16)
    return (raw_bits & raw_mask).view(torch.bfloat16)


@beartype
def apply_mask_bf16(
        x: Tensor,
        outer_clamp_val: float | None,
        inner_clamp_val: float | None,
        bit_mask: BF16 | None,
) -> BF16:
    """Cast ``x`` to bfloat16 and apply the optional range and bit restrictions.

    The steps run in this order; each is skipped when its argument is ``None``:

    1. Saturate to ``[-outer_clamp_val, outer_clamp_val]``.
    2. Zero every element that is left unchanged by clamping to
       ``[-inner_clamp_val, inner_clamp_val]``, i.e. every element whose
       magnitude is less than or equal to ``inner_clamp_val``.
    3. AND the raw bits with ``bit_mask`` via ``bit_mask_bf16``.

    Args:
        x: Input tensor of any dtype convertible to bfloat16.
        outer_clamp_val: Saturation magnitude, or ``None``.
        inner_clamp_val: Flush-to-zero magnitude (inclusive), or ``None``.
        bit_mask: bfloat16 tensor whose raw bits are the mask, or ``None``.

    Returns:
        The restricted bfloat16 tensor.
    """
    # Only out-of-place operations are used: `.to` may hand back the caller's
    # own tensor when it is already bfloat16, and that tensor must stay intact.
    result = x.to(torch.bfloat16)

    if outer_clamp_val is not None:
        result = result.clamp(min=-outer_clamp_val, max=outer_clamp_val)

    if inner_clamp_val is not None:
        # True exactly where the element lies outside the inner interval.
        outside_inner = result.clamp(min=-inner_clamp_val, max=inner_clamp_val).ne(result)
        # Multiplying by the boolean mask zeroes the small elements.
        result = result * outside_inner

    if bit_mask is not None:
        result = bit_mask_bf16(x=result, bit_mask=bit_mask)

    return result


@beartype
class _MaskedLinearBF16(Function):
    """Autograd function for a linear layer with restricted bfloat16 operands.

    In the forward pass the input, the weight and the matmul result are each
    passed through ``apply_mask_bf16`` with the *forward* settings. In the
    backward pass the saved input and weight, the incoming gradient and the
    produced input/weight gradients are passed through ``apply_mask_bf16``
    with the *backward* settings. The bias and its gradient are never masked.
    """

    @staticmethod
    def forward(  # noqa
            ctx,  # ctx is always the first argument to forward
            input: BF16 | FP32,
            weight: BF16 | FP32,
            bias: Tensor | None = None,  # Not expected to be used.
            clamp_outer_forward: float | None = None,
            clamp_outer_backward: float | None = None,
            clamp_inner_forward: float | None = None,
            clamp_inner_backward: float | None = None,
            bit_mask_forward: BF16 | None = None,
            bit_mask_backward: BF16 | None = None,
    ) -> BF16:
        """Compute ``linear(input, weight, bias)`` on masked bfloat16 operands.

        Args:
            ctx: Autograd context.
            input: Activations of shape ``(..., in_features)``.
            weight: Weight of shape ``(out_features, in_features)``.
            bias: Optional bias; it is cast to bfloat16 but not masked.
            clamp_outer_forward: Saturation magnitude for the forward pass.
            clamp_outer_backward: Saturation magnitude for the backward pass.
            clamp_inner_forward: Flush-to-zero magnitude for the forward pass.
            clamp_inner_backward: Flush-to-zero magnitude for the backward pass.
            bit_mask_forward: Bit mask for the forward pass.
            bit_mask_backward: Bit mask for the backward pass.

        Returns:
            The masked bfloat16 output of shape ``(..., out_features)``.
        """
        # The state must be preserved for backprop without modification.
        ctx.save_for_backward(input, weight, bias, bit_mask_backward)
        ctx.clamp_outer_backward = clamp_outer_backward
        ctx.clamp_inner_backward = clamp_inner_backward

        def restrict(t: Tensor) -> Tensor:
            # Single place that wires up the forward settings, so the three
            # call sites below cannot drift apart.
            return apply_mask_bf16(
                t,
                outer_clamp_val=clamp_outer_forward,
                inner_clamp_val=clamp_inner_forward,
                bit_mask=bit_mask_forward,
            )

        if bias is not None:
            # The operands below are bfloat16; a bias of another dtype would
            # make `F.linear` fail on mixed dtypes outside of autocast.
            bias = bias.to(torch.bfloat16)

        out = F.linear(input=restrict(input), weight=restrict(weight), bias=bias)
        # The output is saturated with the *outer* clamp value. Using the inner
        # value here would clamp everything into the interval that the inner
        # step then flushes to zero.
        return restrict(out)

    @staticmethod
    def backward(ctx, grad_output: Tensor):  # noqa
        """Return gradients for ``input``, ``weight`` and ``bias``.

        The six non-tensor/mask arguments of ``forward`` receive ``None``.
        """
        input, weight, bias, bit_mask_backward = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        def restrict(t: Tensor) -> Tensor:
            return apply_mask_bf16(
                t,
                outer_clamp_val=ctx.clamp_outer_backward,
                inner_clamp_val=ctx.clamp_inner_backward,
                bit_mask=bit_mask_backward,
            )

        grad_output = restrict(grad_output)
        # Collapse all leading dimensions so that the weight/bias gradients are
        # reduced over every sample in one step. This yields `(out, in)` and
        # `(out,)` directly for inputs of any rank (including 1-D) instead of
        # a per-batch stack of gradients.
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])

        if ctx.needs_input_grad[0]:
            grad_input = restrict(grad_output @ restrict(weight))
        if ctx.needs_input_grad[1]:
            input_2d = restrict(input).reshape(-1, input.shape[-1])
            grad_weight = restrict(grad_output_2d.t() @ input_2d)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output_2d.sum(dim=0)

        return grad_input, grad_weight, grad_bias, *([None] * 6)


@beartype
class MaskedLinearBF16(nn.Module):  # For layers that are not autocast, simply do not use this layer.
    """Linear layer that restricts bfloat16 values to fewer exponent/significand bits.

    The restriction is carried out by ``_MaskedLinearBF16`` using three
    settings derived here from ``exp_bits`` and ``sig_bits``:

    * ``outer_clamp_val``: largest magnitude kept; larger values saturate.
    * ``inner_clamp_val``: magnitudes less than or equal to it are zeroed.
    * ``mask``: bfloat16 buffer whose raw bits clear the low significand bits.

    The same settings are used for the forward and the backward pass.

    Note:
        ``weight`` (and ``bias``) are created as zeros. ``mask_linear_layers``
        replaces them with the parameters of an existing ``nn.Linear``; a
        directly constructed layer has to be initialised by the caller.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        bias: Whether to create a bias parameter.
        exp_bits: Number of exponent bits to emulate, ``1..8``.
        sig_bits: Number of stored significand bits to emulate, ``1..7``.
        disable_inner_clamp: Do not zero small magnitudes.
        disable_outer_clamp: Do not saturate large magnitudes.
        little_endian: Byte order assumed when building the bit mask; only
            ``True`` is implemented.

    Raises:
        AssertionError: If ``exp_bits`` or ``sig_bits`` is out of range.
        NotImplementedError: If ``sig_bits < 7`` and ``little_endian`` is false.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = False,
            exp_bits: int = 8,
            sig_bits: int = 7,
            disable_inner_clamp: bool = False,
            disable_outer_clamp: bool = False,
            little_endian: bool = True,
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros((out_features, in_features)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

        assert (0 < exp_bits <= 8) and (0 < sig_bits <= 7)
        if (exp_bits == 8) and (sig_bits == 7):
            msg = "Either `exp_bits` or `sig_bits` must be reduced!"
            # raise ValueError(msg)
            warnings.warn(msg)

        if exp_bits < 8:
            exp_min = -(2 ** (exp_bits - 1)) + 2
            exp_max = (2 ** (exp_bits - 1)) - 1
            # Largest magnitude: significand 1.1...1 (implicit leading one plus
            # `sig_bits` stored bits) at the largest exponent.
            self.outer_clamp_val = (2 - (2 ** -sig_bits)) * (2 ** exp_max)
            # Values with magnitude <= `inner_clamp_val` are zeroed, so the
            # threshold must lie strictly below the smallest normal value
            # `2 ** exp_min` for that value to survive. Subtracting a tiny
            # constant does not work: it is rounded away and the threshold
            # stays exactly `2 ** exp_min`. Use the largest bfloat16 value
            # below `2 ** exp_min` instead, which is exactly representable.
            self.inner_clamp_val = (2.0 ** exp_min) * (1 - torch.finfo(torch.bfloat16).eps / 2)
        else:
            self.outer_clamp_val = self.inner_clamp_val = None

        if sig_bits < 7:
            # The low byte of a bfloat16 holds the lowest exponent bit followed
            # by the 7 stored significand bits.
            # The alignment bit at the front must be kept.
            mask_uint8 = 255 - ((2 ** (7 - sig_bits)) - 1)
            mask_uint8 = torch.tensor(mask_uint8, dtype=torch.uint8)
            # Start from all bits set; only the low byte is overwritten below.
            mask = torch.tensor([-1], dtype=torch.int16)
            # Indexing is due to the little endianness.
            if little_endian:
                mask_uint8_idx = -2
            else:
                raise NotImplementedError("Are you really using big endian?")
                # mask_uint8_idx = 1
            mask.view(torch.uint8)[..., mask_uint8_idx] = mask_uint8
            # Buffer saved in bf16 because NCCL cannot handle int16 data type.
            mask = mask.squeeze().view(torch.bfloat16)
        else:
            mask = None

        self.register_buffer('mask', mask)
        if disable_inner_clamp:
            self.inner_clamp_val = None
            warnings.warn("Disabling inner clamp values.")
        if disable_outer_clamp:
            self.outer_clamp_val = None
            warnings.warn("Disabling outer clamp values.")

    def forward(self, x: Tensor):
        """Apply the masked linear transform to ``x``."""
        return _MaskedLinearBF16.apply(
            x,  # input
            self.weight,  # weight
            self.bias,  # bias
            self.outer_clamp_val,  # clamp_outer_forward
            self.outer_clamp_val,  # clamp_outer_backward
            self.inner_clamp_val,  # clamp_inner_forward
            self.inner_clamp_val,  # clamp_inner_backward
            self.mask,  # bit_mask_forward
            self.mask,  # bit_mask_backward
        )

    @classmethod
    def mask_linear_layers(
            cls,
            module: nn.Module,
            exp_bits: int = 8,
            sig_bits: int = 7,
            disable_inner_clamp: bool = False,
            disable_outer_clamp: bool = False,
            little_endian: bool = True,
    ):
        """Recursively replace every ``nn.Linear`` in ``module``.

        An ``nn.Linear`` is replaced by a new ``MaskedLinearBF16`` that reuses
        the very same ``weight`` and ``bias`` parameter objects. Any other
        module is kept and modified in place: each of its children is
        re-registered with the converted child.

        Args:
            module: Root module to convert.
            exp_bits: Forwarded to ``MaskedLinearBF16``.
            sig_bits: Forwarded to ``MaskedLinearBF16``.
            disable_inner_clamp: Forwarded to ``MaskedLinearBF16``.
            disable_outer_clamp: Forwarded to ``MaskedLinearBF16``.
            little_endian: Forwarded to ``MaskedLinearBF16``.

        Returns:
            The converted module: a new ``MaskedLinearBF16`` if ``module`` is
            an ``nn.Linear``, otherwise ``module`` itself.
        """
        module_output = module
        if isinstance(module, nn.Linear):
            module_output = MaskedLinearBF16(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
                exp_bits=exp_bits,
                sig_bits=sig_bits,
                disable_inner_clamp=disable_inner_clamp,
                disable_outer_clamp=disable_outer_clamp,
                little_endian=little_endian,
            )
            # Re-bind the original parameters so optimizers and checkpoints
            # that reference them keep working.
            module_output.weight = module.weight
            module_output.bias = module.bias
        for name, child in module.named_children():
            module_output.add_module(
                name,
                cls.mask_linear_layers(
                    child,
                    exp_bits=exp_bits,
                    sig_bits=sig_bits,
                    disable_inner_clamp=disable_inner_clamp,
                    disable_outer_clamp=disable_outer_clamp,
                    little_endian=little_endian,
                )
            )
        return module_output
