import torch
import torch.nn as nn
import torch.nn.functional as F

from nvfp4_triton import rtn_1x16s_fp4_autograd, rtn_16x16s_fp4_autograd

class BaseQuantizer(nn.Module):
    """Common parent of the quantizer modules in this file.

    Stores the requested bit-width together with the matching number of
    representable levels (``2 ** bits``). Subclasses provide ``forward``.
    """

    def __init__(self, bits=4):
        """Record ``bits`` and derive ``n_levels`` from it."""
        super().__init__()
        self.bits = bits
        # Derived from the stored attribute; same value as 2 ** bits.
        self.n_levels = 2 ** self.bits

    def re_randomize(self):
        """Hook that does nothing in the base class; always returns ``None``."""
        return None

    def forward(self, x):
        """Abstract: the base class has no quantization rule of its own.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()


class NoQuantizer(BaseQuantizer):
    """Pass-through quantizer: ``forward`` hands its input back unchanged."""

    def __init__(self, **kwargs):
        """Initialise with a bit-width of 16.

        Any keyword arguments are accepted and ignored.
        """
        super().__init__(bits=16)

    def forward(self, x):
        """Return ``x`` itself (the same object, no copy)."""
        return x


# Lookup table of scale constants keyed by bit-width (one key, 1.585, is
# fractional). Nothing in the visible part of this file reads it.
# NOTE(review): judging by the name these are presumably the optimal
# quantizer scales for standard-normal data at each bit-width, with 1.585
# standing for log2(3), i.e. a three-level grid -- confirm against the code
# that generated the numbers.
OPTIMAL_GAUSSIAN_SCALES = {
    1: 0.7978845587140913,
    1.585: 1.2240089519030855,
    2: 1.4935346200015913,
    3: 2.051068354131873,
    4: 2.513930578568423,
    5: 2.9160938834961225,
    6: 3.276597282593217,
    7: 3.6010497188221655,
    8: 3.884938678807525,
}


def rtn_fp4(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Round every element of ``x`` to the nearest value in ``grid``.

    Args:
        x: Tensor of values to round (any shape).
        grid: Non-empty 1-D tensor of candidate values sorted in ascending
            order, as required by ``torch.bucketize``. Duplicate entries
            are allowed.

    Returns:
        A tensor shaped like ``x`` whose elements are all taken from
        ``grid``. Inputs below ``grid[0]`` map to ``grid[0]`` and inputs
        above ``grid[-1]`` map to ``grid[-1]``. An input exactly half-way
        between two neighbouring grid values is rounded to the larger one.
    """
    # The last valid index is derived from the grid itself instead of the
    # previously hard-coded 15, so grids of any length are handled.
    last = grid.numel() - 1

    # bucketize returns i such that grid[i - 1] < x <= grid[i]; i can be
    # grid.numel() for values above the top entry, hence the clamping.
    inds = torch.bucketize(x, grid)
    lo = torch.clamp(inds - 1, min=0, max=last)
    hi = torch.clamp(inds, min=0, max=last)

    low = grid[lo]
    high = grid[hi]

    # "<=" sends exact ties to the upper neighbour.
    return torch.where((high - x) <= (x - low), high, low)


class Nvfp4Quantizer(BaseQuantizer):
    """Fake-quantizer that rounds tensors onto a 16-entry FP4 grid.

    With ``hadamard_dim == 1`` the work is delegated to the Triton autograd
    functions ``rtn_16x16s_fp4_autograd`` (``square=True``) or
    ``rtn_1x16s_fp4_autograd`` (``square=False``). Otherwise the input is
    first multiplied block-wise by a normalised Hadamard matrix and then
    quantized in PyTorch with one scale per 16x16 tile (``square=True``) or
    per run of 16 consecutive elements (``square=False``).

    Args:
        hadamard_dim: Size of the Hadamard blocks; ``1`` disables the
            transform. Other values must be powers of two.
        square: Share scales over 16x16 tiles instead of 1x16 groups.
        scale_override: Forwarded to the Triton kernels only; the PyTorch
            path does not read it.
        four_over_six: Forwarded to the Triton kernels only; the PyTorch
            path refuses to run when it is set.
    """

    # Signed FP4 code points; zero appears twice (once per sign half).
    # CUDA is still preferred when present, but the CPU fallback keeps the
    # module importable on machines without a GPU. forward() moves the grid
    # to the input's device and dtype in any case.
    grid = torch.tensor(
        [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0,
         0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    def __init__(self, hadamard_dim=1, square: bool=True, scale_override: float=1.0, four_over_six: bool=False):
        super().__init__(4)

        self.hadamard_dim = hadamard_dim
        if self.hadamard_dim != 1:
            # Built with plain torch ops: the helper previously called here
            # (hadamard_transform) is neither imported nor defined in this
            # file, so this branch used to fail with NameError.
            self.hadamard_matrix = self._sylvester_hadamard(
                hadamard_dim,
                device="cuda" if torch.cuda.is_available() else "cpu",
            )
        self.square = square
        self.scale_override = scale_override
        self.four_over_six = four_over_six

    @staticmethod
    def _sylvester_hadamard(dim, device):
        """Return the ``dim x dim`` Sylvester Hadamard matrix times ``dim**-0.5``.

        Intended to equal ``hadamard_transform(torch.eye(dim),
        scale=dim**-0.5)`` from the fast-hadamard-transform package for
        power-of-two ``dim`` (see NOTE(review) below).

        Raises:
            ValueError: if ``dim`` is not a positive power of two.
        """
        n = int(dim)
        if n != dim or n < 1 or n & (n - 1):
            raise ValueError(
                f"hadamard_dim must be a positive power of two, got {dim!r}"
            )
        # NOTE(review): assumes the external transform uses the natural
        # (Sylvester) ordering -- confirm against its documentation.
        mat = torch.ones(1, 1, dtype=torch.float32, device=device)
        while mat.shape[0] < n:
            # Doubling step: H_2k = [[H_k, H_k], [H_k, -H_k]].
            mat = torch.cat(
                [torch.cat([mat, mat], dim=1), torch.cat([mat, -mat], dim=1)],
                dim=0,
            )
        return mat * n**-0.5

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"hadamard_dim={self.hadamard_dim}, "
            f"square={self.square}, "
            f"scale_override={self.scale_override}, "
            f"four_over_six={self.four_over_six})"
        )

    def round_scales(self, scales):
        """Turn per-group absolute maxima into the scales actually applied.

        Args:
            scales: Non-empty tensor of per-group ``abs().max()`` values.

        Returns:
            A float32 tensor shaped like ``scales``: each group's ``max / 6``
            expressed as a ``float8_e4m3fn``-rounded multiple of one global
            factor. Entries that would be zero are replaced so that callers
            can divide by the result.
        """
        # Global factor chosen so the largest per-group value lands at
        # 447.99, just under the float8_e4m3fn maximum of 448.
        s_dec = scales.max() / (447.99 * 6.0)
        s_dec = torch.where(s_dec == 0, torch.ones_like(s_dec), s_dec)

        # Per-group scale mapping the group's maximum onto the largest grid
        # magnitude (6.0), then rounded through float8_e4m3fn.
        s_dec_b = scales / 6.0
        s_dec_b_e4m3 = (s_dec_b / s_dec).to(torch.float8_e4m3fn).float()
        # A scale that rounds to zero would cause a division by zero later.
        s_dec_b_e4m3 = torch.where(
            s_dec_b_e4m3 == 0, torch.ones_like(s_dec_b_e4m3), s_dec_b_e4m3
        )
        return s_dec_b_e4m3 * s_dec

    def forward(self, x):
        """Quantize ``x`` and return a tensor of the same shape.

        On the PyTorch path the result carries straight-through gradients:
        its value is the quantized tensor, while autograd sees the
        Hadamard-transformed input. No inverse transform is applied, so the
        output stays in the transformed domain.

        The PyTorch path with ``square=True`` views ``x`` as 16x16 tiles
        over its first two dimensions, so ``x`` must be 2-D with both sizes
        divisible by 16; the number of elements must also be divisible by
        ``hadamard_dim``.
        """
        # Cached tensors follow the device and dtype of the latest input.
        if hasattr(self, "hadamard_matrix"):
            self.hadamard_matrix = self.hadamard_matrix.to(x.device).to(x.dtype)
        self.grid = self.grid.to(x.device).to(x.dtype)

        if self.hadamard_dim == 1:
            # Without a Hadamard transform the fused Triton kernels do all
            # the work, for both the 16x16 and the 1x16 scale layouts.
            kernel = rtn_16x16s_fp4_autograd if self.square else rtn_1x16s_fp4_autograd
            return kernel.apply(x, self.scale_override, 16, self.four_over_six)

        assert not self.four_over_six, "four_over_six only triton"

        # Apply the Hadamard matrix to each run of hadamard_dim elements.
        x_had = F.linear(x.view(-1, self.hadamard_dim), self.hadamard_matrix).view_as(x)

        with torch.no_grad():
            if self.square:
                # (R, C) -> (R/16, 16, C/16, 16) -> one row per 16x16 tile.
                x_grouped = (
                    x_had.view(x.shape[0] // 16, 16, x.shape[1] // 16, 16)
                    .permute(0, 2, 1, 3)
                    .reshape(-1, 16 * 16)
                )
            else:
                x_grouped = x_had.view(-1, 16)

            scales = x_grouped.abs().max(dim=-1, keepdim=True)[0]
            s_enc_b_inv = self.round_scales(scales)
            x_fp4 = rtn_fp4(x_grouped / s_enc_b_inv, self.grid) * s_enc_b_inv

            if self.square:
                # Undo the tiling permutation applied above.
                x_fp4 = (
                    x_fp4.reshape(x.shape[0] // 16, x.shape[1] // 16, 16, 16)
                    .permute(0, 2, 1, 3)
                    .reshape_as(x)
                )
            else:
                x_fp4 = x_fp4.view_as(x)

        # Straight-through estimator: forward value is x_fp4, gradient
        # flows through x_had.
        return x_had + (x_fp4 - x_had).detach()
