import torch
import torch.nn as nn
import torch.nn.functional as F
from qdiff.base.quant_layer import QuantizedLinear
from qdiff.quarot.quarot_utils import random_hadamard_matrix


class OursQuantizedLinear(QuantizedLinear):
    """
    Quantized linear layer with input-channel rotation, scaling and reordering.

    Built on the base quantized linear layer (static weight quantization,
    dynamic activation quantization). Before the matmul, the activation ``x``
    and the weight ``W`` are transformed along the input-channel axis:

    1. rotation:  ``x @ R`` and ``W @ R``  (``R = self.rotation_matrix``)
    2. scaling:   ``x * m`` and ``W / m``  (``m = self.channel_mask``)
    3. reorder:   both are permuted with ``self.reorder_index`` (optional)

    Rotation and scaling of the activation are fused into one matmul with
    ``self.rotate_scale_matrix = R * m`` (``m`` broadcast over the columns of
    ``R``), since ``x @ (R * m) == (x @ R) * m``.

    Workflow as implemented by the methods below (the attributes are driven
    from outside this class):

    * ``get_rotation_matrix()`` creates ``R``.
    * ``update_weights_rotation()`` rotates the weight in place.
    * With ``is_calib = True`` and ``scaled_flag = False``, ``forward``
      rotates ``x`` and tracks its per-channel absolute maximum in
      ``x_channel_mask``.
    * ``update_weights_scale()`` derives ``channel_mask`` from the weight and
      ``x_channel_mask``, rescales the weight and sets ``scaled_flag``.
    * ``update_quantized_weight_rotated_scaled_reorder()`` applies rotation,
      scaling, reordering and weight quantization in one go from already
      available ``rotation_matrix`` / ``channel_mask`` / ``reorder_index``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        device: None,
        quant_config: dict,
        fp_module: torch.nn.Linear,
    ) -> None:
        """
        Build the layer and initialise all rotation/scaling state to "unset".

        ``quant_config`` is read with attribute access
        (``quant_config.ours.alpha``), so it must support that form of lookup.
        """
        super().__init__(in_features, out_features, bias, device, quant_config, fp_module)

        # exponent balancing weight vs. activation range in the channel mask
        self.alpha = quant_config.ours.alpha
        self.channel_mask = None  # assigned outside, during PTQ
        self.rotation_matrix = None  # initialised here so it can be loaded with the quant params

        # when True, forward() rotates (and, once scaled, also scales) the input
        self.is_calib = False
        # set once channel_mask / rotate_scale_matrix are available
        self.scaled_flag = False
        # running per-channel abs-max of the rotated activation, filled by forward()
        self.x_channel_mask = None
        # fused ``rotation_matrix * channel_mask`` applied to the activation
        self.rotate_scale_matrix = None
        # permutation of the input channels; assigned outside this class
        self.reorder_index = None
        # set by update_quantized_weight_rotated_scaled_reorder()
        self.init_weights = False

    def _require(self, *names: str) -> None:
        """Raise a clear error if any of the named attributes is still ``None``."""
        missing = [name for name in names if getattr(self, name) is None]
        if missing:
            raise RuntimeError(
                f"{type(self).__name__}: {', '.join(missing)} must be set before this call"
            )

    def _build_rotate_scale_matrix(self) -> torch.Tensor:
        """Return the fused rotate-then-scale matrix ``R * m`` in float16."""
        scale = self.channel_mask.reshape([1, self.rotation_matrix.shape[1]])
        # NOTE(review): the float16 cast is kept from the original code; forward()
        # casts the activation to this dtype before the matmul.
        return (self.rotation_matrix * scale).to(torch.float16)

    def get_rotation_matrix(self):
        """Create a random Hadamard rotation of size ``in_features``."""
        # NOTE(review): the device is hard-coded to "cuda"; kept as-is because the
        # matrix is a plain attribute and would not follow a later ``module.to(...)``.
        self.rotation_matrix = random_hadamard_matrix(self.in_features, "cuda")

    def update_weights_rotation(self):
        """Rotate the weight in place: ``W <- W.float() @ R``."""
        self._require("rotation_matrix")
        self.weight.data = torch.matmul(self.weight.data.float(), self.rotation_matrix)

    def update_weights_scale(self):
        """
        Compute the channel mask and rescale the weight in place.

        ``channel_mask = max|W|**alpha / max|x|**(1 - alpha)`` per input channel,
        then ``W <- W / channel_mask``. Also builds ``rotate_scale_matrix`` and
        sets ``scaled_flag``. Requires ``x_channel_mask`` (collected by
        ``forward`` during calibration) and ``rotation_matrix``.
        """
        self._require("x_channel_mask", "rotation_matrix")

        weight_mask = self.weight.abs().max(dim=0)[0]  # [C_in]
        # abs() keeps the bases non-negative: a negative base with a fractional
        # exponent would produce nan
        self.channel_mask = (weight_mask.abs() ** self.alpha) / (
            self.x_channel_mask.abs() ** (1 - self.alpha)
        )
        self.weight.data = self.weight.data / self.channel_mask.reshape(
            [1, self.weight.data.shape[1]]
        )

        self.rotate_scale_matrix = self._build_rotate_scale_matrix()

        self.scaled_flag = True

    def update_quantized_weight_rotated_scaled_reorder(self):
        """
        Rotate, scale, reorder and quantize the weight in place.

        Uses the already available ``rotation_matrix``, ``channel_mask`` and
        ``reorder_index``. Afterwards the layer is flagged so that ``forward``
        applies the matching transform to the activation
        (``is_calib``, ``scaled_flag`` and ``init_weights`` are set to True).
        """
        self._require("rotation_matrix", "channel_mask", "reorder_index")

        weight = torch.matmul(self.weight.data.float(), self.rotation_matrix)
        weight = weight / self.channel_mask.reshape([1, weight.shape[1]])
        weight = torch.index_select(weight, 1, self.reorder_index)
        self.weight.data = self.w_quantizer(weight)

        self.rotate_scale_matrix = self._build_rotate_scale_matrix()
        self.is_calib = True
        self.scaled_flag = True

        self.init_weights = True

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Apply the layer.

        input shape: [..., C]  (typically [B, N_token, C])

        In FP mode the wrapped ``fp_module`` is called directly. In quant mode
        the activation is optionally rotated/scaled (``is_calib``), optionally
        reordered and quantized, and then multiplied with ``self.weight``.
        The output keeps the dtype of ``x``.
        """
        if not self.quant_mode:  # use the FP
            return self.fp_module(x, *args, **kwargs)

        dtype_ = x.dtype
        orig_shape = x.shape

        if self.is_calib:
            if not self.scaled_flag:
                # calibration pass: rotate only, and record the per-channel range
                x = torch.matmul(
                    x.to(self.rotation_matrix.dtype), self.rotation_matrix
                ).to(dtype=dtype_)
                x_scale = x.reshape(-1, x.shape[-1]).abs().max(dim=0)[0].clamp_(min=1e-4)
                if self.x_channel_mask is None:
                    self.x_channel_mask = x_scale
                else:
                    self.x_channel_mask = torch.max(x_scale, self.x_channel_mask)
            else:
                # fused rotate + scale; cast to the matrix dtype and back so that
                # inputs that are not float16 do not hit a dtype mismatch
                x = torch.matmul(
                    x.to(self.rotate_scale_matrix.dtype), self.rotate_scale_matrix
                ).to(dtype=dtype_)

        # quantize activation
        # NOTE(review): as in the original code, the activation is only quantized
        # when a reorder index is set - confirm this is intended.
        if self.scaled_flag and self.reorder_index is not None:
            x = x.reshape(-1, orig_shape[-1])  # [n_tokens, C]
            x = torch.index_select(x, 1, self.reorder_index)
            x = self.a_quantizer(x)
            x = x.reshape(orig_shape)

        # forward with dequantized weight and activation; weight and bias are
        # both brought to the activation dtype
        bias = self.bias if self.bias is None else self.bias.to(dtype=dtype_)
        y = F.linear(x, self.weight.to(dtype=dtype_), bias, *args, **kwargs)

        return y
