import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
from torch.nn.parameter import Parameter

from .quant_utils import vectorwise_quant, vectorwise_dequant, create_normal_map
from lpmm.config import get_config


# Assumes layer is perfectly divisible into 256 * 256 blocks
class TQLinear(nn.Module):  # TQ -> Trainable Quantization
    """Linear layer with a frozen quantized weight and trainable per-group quantization parameters.

    At construction, ``weight`` is quantized once via ``vectorwise_quant`` into the
    ``q_weight`` buffer. Only ``q_scales``, ``q_biases_neg`` and ``q_biases_pos``
    (one entry per group) are ``Parameter``s. Every ``forward`` call dequantizes
    ``q_weight`` with them via ``vectorwise_dequant`` and then applies ``F.linear``.

    Args:
        weight: Full-precision weight tensor to quantize. Its ``numel()`` is split
            into ``ceil(numel / GROUP_SIZE)`` groups.
        bias: Optional bias tensor, stored as a buffer with ``requires_grad=False``.
        in_features: Input feature count (used for the placeholder buffer and repr).
        out_features: Output feature count (used for the placeholder buffer and repr).
        config: Passed to ``lpmm.config.get_config``. The ``QUANT.M`` section of
            the result supplies bits, group size, and scale/quant/round types.
        is_cuda: If true, move the quantization map to ``device`` right away.
        device: Factory kwarg for the trainable quantization parameters.
        dtype: Factory kwarg for the trainable quantization parameters.

    NOTE(review): an older header comment says the layer is assumed to divide
    into 256 * 256 blocks, yet the group count is rounded up. Whether
    ``vectorwise_quant`` supports a partial last group is not visible here.
    """

    def __init__(
        self, weight, bias, in_features, out_features, config, is_cuda=True, device=None, dtype=None
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.w_shape = weight.shape

        self.qconfig = get_config(config)
        model_qmetadata = self.get_qmetadata()
        self.bits = model_qmetadata["b"]
        self.groupsize = model_qmetadata["gp_sz"]
        # One group per `groupsize` elements of the flattened weight; a trailing
        # partial group still gets its own scale/bias entries.
        self.num_groups = self._num_groups(weight.numel(), self.groupsize)

        # Quantization codebook. Registered as a non-persistent buffer so that
        # `module.to(...)` / `.cuda()` move it together with the parameters (a
        # plain attribute would be left behind), without adding it to state_dict.
        self.register_buffer(
            'qmap', create_normal_map(offset=0.995, total_bits=self.bits), persistent=False
        )
        if is_cuda:
            self.qmap = self.qmap.to(device)

        # Per-group trainable quantization parameters (initialised in reset_parameters).
        self.q_scales = Parameter(torch.empty(self.num_groups, **factory_kwargs))
        self.q_biases_neg = Parameter(torch.empty(self.num_groups, **factory_kwargs))
        self.q_biases_pos = Parameter(torch.empty(self.num_groups, **factory_kwargs))

        # Placeholder only: init_q_weight replaces it with vectorwise_quant's output.
        self.register_buffer(
            'q_weight',
            torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32),
        )
        # TODO: decide whether the bias should be quantized as well.
        if bias is not None:
            self.register_buffer('bias', bias)
        else:
            self.bias = None

        self.is_cuda = is_cuda
        self.device = device

        self.reset_parameters(factory_kwargs)
        if bias is not None:
            self.bias.requires_grad = False

        self.q_metadata = dict()
        self.init_q_weight(weight, factory_kwargs)
        # Applied after the quantizer's metadata, so config-derived keys win on collision.
        self.q_metadata.update(model_qmetadata)

    @staticmethod
    def _num_groups(numel, groupsize):
        """Return ``ceil(numel / groupsize)`` using exact integer arithmetic."""
        return -(-numel // groupsize)

    def reset_parameters(self, factory_kwargs):
        """Reset scales to one and both bias vectors to zero.

        ``factory_kwargs`` is accepted for interface compatibility and is unused.
        """
        init.ones_(self.q_scales)
        init.zeros_(self.q_biases_neg)
        init.zeros_(self.q_biases_pos)

    def init_q_weight(self, weight, factory_kwargs):
        """Quantize ``weight`` into ``self.q_weight`` and merge the returned metadata.

        Quantization uses neutral parameters (scales of one, zero biases) rather
        than the trainable ones. The second value returned by ``vectorwise_quant``
        is merged into ``self.q_metadata``.
        """
        model_qmetadata = self.get_qmetadata()
        # q_weight is a frozen buffer. If the quantizer's output were tracked by
        # autograd, the buffer would keep a graph (and the source weight) alive.
        with torch.no_grad():
            self.q_weight, gen = vectorwise_quant(
                weight,
                torch.ones(self.num_groups, **factory_kwargs),
                torch.zeros(self.num_groups, **factory_kwargs),
                torch.zeros(self.num_groups, **factory_kwargs),
                qmap=self.qmap,
                shape=self.w_shape,
                **model_qmetadata,
            )
        self.q_metadata.update(gen)

    def get_subqconfig(self):
        """Return the ``QUANT.M`` section of the loaded quantization config."""
        return self.qconfig.QUANT.M

    def get_qmetadata(self):
        """Collect the quantizer keyword arguments from the config's ``QUANT.M`` section."""
        subconfig = self.get_subqconfig()
        md = dict(
            b=subconfig.BITS,
            scale_type=subconfig.SCALE_TYPE.DEFAULT,
            quant_type=subconfig.QUANT_TYPE.DEFAULT,
            round_type=subconfig.ROUND_TYPE,
            gp_sz=subconfig.GROUP_SIZE,
            signed=subconfig.SIGNED,
        )
        return md

    def forward(self, x):
        """Dequantize ``q_weight`` with the trainable parameters and apply ``F.linear``.

        Raises:
            NotImplementedError: if the configured bit width is not 2, 3 or 4.
        """
        # TODO: compare against an Autograd4bit-style implementation.
        if self.bits not in (2, 3, 4):
            raise NotImplementedError(
                f"TQLinear.forward supports 2, 3 or 4 bits, got {self.bits}."
            )
        dequant_weight = vectorwise_dequant(
            self.q_weight,
            self.q_scales,
            self.q_biases_neg,
            self.q_biases_pos,
            qmap=self.qmap,
            shape=self.w_shape,
            **self.q_metadata,
        )
        return F.linear(x, dequant_weight, self.bias)

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'