import torch
import torch.nn as nn
import torch.nn.functional as F

from quant_utils.base.quant_layer import QuantizedLinear
from quant_utils.veta.quarot_utils import random_hadamard_matrix, matmul_hadU_cuda


class VetaQuarotQuantizedLinear(QuantizedLinear):
    """Quantized linear layer with per-channel scaling and an input rotation.

    In quantized mode the input is (optionally) multiplied by a per-channel
    mask, rotated by ``kl_rotation_matrix`` (eigenvector basis of the
    calibration covariance times a random Hadamard matrix), passed through the
    activation quantizer and finally fed to ``F.linear`` with the transformed
    weight.  The weight is divided by the same mask and multiplied by the same
    rotation, so the two transforms are meant to cancel when the rotation is
    orthogonal.

    Typical PTQ sequence (as far as this class shows): ``get_channel_mask`` ->
    ``get_rotation_matrix`` -> ``update_quantized_weight_scaled`` ->
    ``update_quantized_weight_rotated``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool,
                 device: None, quant_config: dict, fp_module: torch.nn.Linear) -> None:
        super().__init__(in_features, out_features, bias, device, quant_config, fp_module)

        # Created up front so the attribute exists when quant params are loaded.
        self.kl_rotation_matrix = None
        # NOTE(review): quant_config is annotated as dict but is read by
        # attribute access, so it is presumably an attribute-style config.
        self.alpha = quant_config.kl_quarot.alpha
        self.channel_mask = None  # assigned outside, during PTQ

    def get_channel_mask(self, act_mask):
        """Compute and store the per-input-channel scaling mask.

        mask = max|W|^alpha / |act_mask|^(1 - alpha), evaluated per input
        channel.

        Args:
            act_mask: channel-wise activation magnitude (e.g. the per-channel
                max); moved to the weight's device before use.

        Raises:
            AssertionError: if the resulting mask contains NaN or Inf (for
                instance when ``act_mask`` has a zero entry).
        """
        # Channel-wise max of |W| over the output dimension -> [C_in].
        # It is already non-negative, so the fractional power cannot yield NaN.
        weight_mask = self.fp_module.weight.abs().max(dim=0)[0]
        act_mask = act_mask.to(weight_mask.device)
        # abs() on the activation side: a negative base with a fractional
        # exponent would produce NaN.
        channel_mask = (weight_mask ** self.alpha) / (act_mask.abs() ** (1 - self.alpha))
        self.channel_mask = channel_mask
        # Test element-wise first, then reduce; reducing with any() first
        # would turn the tensor into a bool and the checks could never fire.
        assert not torch.isnan(channel_mask).any(), "nan exists in channel_mask"
        assert not torch.isinf(channel_mask).any(), "inf exists in channel_mask"

    def get_rotation_matrix(self, input_tensor):
        """Build ``kl_rotation_matrix`` from calibration activations.

        Stores ``input_tensor`` on the instance (it is consumed later by
        ``update_quantized_weight_rotated`` when ``gptq=True``) and sets
        ``kl_rotation_matrix = kl_matrix @ hadamard_matrix`` in float64 on the
        "cuda" device.

        Args:
            input_tensor: calibration activations whose last dimension is the
                channel dimension.
        """
        self.input_data = input_tensor

        self.hadamard_matrix = random_hadamard_matrix(self.in_features, "cuda")
        self.kl_matrix = self.get_kl_matrix(input_tensor).to("cuda")
        # Both operands are cast to float64 so matmul never sees mixed dtypes.
        self.kl_rotation_matrix = torch.matmul(
            self.kl_matrix.double(), self.hadamard_matrix.double()
        )

    def get_kl_matrix(self, tensor):
        """Return the eigenvector matrix of the channel covariance of ``tensor``.

        Args:
            tensor: activations; everything but the last dimension is
                flattened into the sample dimension.

        Returns:
            A float32 matrix whose columns are the eigenvectors of the
            covariance (re-orthonormalised with Gram-Schmidt if the
            decomposition is not orthonormal within tolerance).
        """
        shape = tensor.shape
        # torch.cov expects variables in rows, hence the transpose.
        cov_matrix = torch.cov(tensor.reshape(-1, shape[-1]).double().T)

        # Only the eigenvectors are needed; the eigenvalues are discarded.
        _, K = torch.linalg.eigh(cov_matrix)

        # Check the whole Gram matrix against the identity, not just a couple
        # of entries, so any loss of orthonormality is detected.
        gram = K.T @ K
        identity = torch.eye(K.shape[0], dtype=K.dtype, device=K.device)
        if torch.allclose(gram, identity, atol=1e-4):
            print("input is orthogonal")
        else:
            K = self.gram_schmidt(K)
            print("input is not orthogonal")

        weight_klt = K.float()
        return weight_klt

    def gram_schmidt(self, K):
        """Orthonormalise the columns of ``K`` (modified Gram-Schmidt).

        Args:
            K: 2-D tensor whose columns are to be orthonormalised.  It is not
                modified.

        Returns:
            A tensor of the same shape/dtype with orthonormal columns.
        """
        n = K.size(1)
        Q = torch.zeros_like(K)
        for i in range(n):
            # Clone so the in-place updates below do not write through to K.
            q = K[:, i].clone()
            for j in range(i):
                # Project against the running residual (modified variant).
                q -= torch.dot(Q[:, j], q) * Q[:, j]
            Q[:, i] = q / q.norm()
        return Q

    def update_quantized_weight_scaled(self):
        """Load the FP weight into ``self.weight``, divided by the channel mask.

        When no channel mask has been set the FP weight is used unchanged.
        The weight quantizer itself is not invoked here; ``init_done`` is only
        toggled around the assignment.
        """
        fp_weight = self.fp_module.weight
        C_in = fp_weight.shape[1]
        self.w_quantizer.init_done = False  # unset the init done to overwrite quant_params

        if self.channel_mask is not None:
            # Keep the mask on the weight's device so the division cannot
            # fail when the mask was assigned from elsewhere.
            mask = self.channel_mask.reshape([1, C_in]).to(fp_weight.device)
            self.weight.data = fp_weight / mask
        else:
            self.weight.data = fp_weight

        self.w_quantizer.init_done = True

    def update_quantized_weight_rotated(self, gptq=False, batch_size=-1):
        """Rotate ``self.weight`` by ``kl_rotation_matrix`` and quantize it.

        Args:
            gptq: when True, the stored calibration input is scaled, rotated
                and handed to ``w_quantizer.gptq_add_batch`` before the weight
                is quantized.
            batch_size: only used for 2-D calibration input; number of batches
                to split the rows into.  ``-1`` means
                ``self.q_cfg.calib_batch_size``.

        Raises:
            AssertionError: 3-D input whose token dimension is not a multiple
                of ``calib_batch_size``.
            ValueError: calibration input that is neither 2-D nor 3-D.
            RuntimeError: 2-D input whose row count is not a multiple of the
                batch size (raised by ``reshape``).
        """
        self.w_quantizer.init_done = False   # unset the init done to overwrite quant_params

        if gptq:
            dtype_ = self.input_data.dtype
            n_dims = len(self.input_data.shape)
            if n_dims == 3:
                B, N_token, C = self.input_data.shape
                calib_bs = self.q_cfg.calib_batch_size
                assert N_token % calib_bs == 0, f"input shape {self.input_data.shape} is not supported"
                # Split every sample's tokens into calib_bs chunks and fold
                # the chunks into the batch dimension.
                self.input_data = self.input_data.reshape([B * calib_bs, N_token // calib_bs, C])
                B, N_token, C = self.input_data.shape
            elif n_dims == 2:
                B = self.q_cfg.calib_batch_size if batch_size == -1 else batch_size
                N_token = self.input_data.shape[0] // B
                C = self.input_data.shape[1]
                # reshape raises if the rows are not an exact multiple of B.
                self.input_data = self.input_data.reshape([B, N_token, C])
            else:
                raise ValueError(f"input shape {self.input_data.shape} is not supported")

            if self.channel_mask is not None:
                self.input_data = self.input_data * self.channel_mask.reshape([1, 1, C]).to(self.input_data.device)
            rotation = self.kl_rotation_matrix.to(device=self.input_data.device, dtype=torch.float64)
            # Rotate in float64, then return to the calibration dtype.
            self.input_data = torch.matmul(self.input_data.double(), rotation).to(dtype=dtype_)
            self.w_quantizer.gptq_add_batch(self.input_data)

        weight = self.weight.data
        # Align device/dtype explicitly: the matrix may have been restored
        # from saved quant params rather than built by get_rotation_matrix.
        rotation = self.kl_rotation_matrix.to(device=weight.device, dtype=torch.float64)
        self.weight.data = self.w_quantizer(torch.matmul(weight.double(), rotation).float())
        self.w_quantizer.init_done = True

        # Assign first so the delete also works when no calibration input was
        # ever stored; then drop the attribute and release cached GPU memory.
        self.input_data = None
        del self.input_data
        torch.cuda.empty_cache()

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the layer.

        Args:
            x: input of shape [B, N_token, C] (a 3-D input is required in
                quantized mode).

        Returns:
            The output of the FP module when ``quant_mode`` is off, otherwise
            the output of ``F.linear`` on the scaled, rotated and quantized
            input.
        """
        if not self.quant_mode:  # use the FP
            return self.fp_module(x, *args, **kwargs)

        dtype_ = x.dtype
        B, N_token, C = x.shape
        if self.channel_mask is not None:
            # first process through scale
            x = x * self.channel_mask.reshape([1, 1, C]).to(x.device)
        rotation = self.kl_rotation_matrix.to(device=x.device, dtype=torch.float64)
        x = torch.matmul(x.double(), rotation).to(dtype=dtype_)

        # The activation quantizer is fed a 2-D [B * N_token, C] view.
        x = self.a_quantizer(x.reshape([B * N_token, -1]))
        x = x.reshape([B, N_token, C])

        # NOTE(review): F.linear only takes (input, weight, bias); any extra
        # positional/keyword argument forwarded here will raise.
        return F.linear(x, self.weight.to(x.dtype), self.bias, *args, **kwargs)