import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from typing import Optional, List

from .quant_utils import vectorwise_quant, vectorwise_dequant, create_normal_map
from lpmm.config import get_config


# Assumes layer is perfectly divisible into 256 * 256 blocks
class TQLinear(nn.Module): # TQ -> Trainable Quantization
    """Linear layer with a quantized, frozen weight and trainable per-group scales/biases.

    The full-precision ``weight`` given to the constructor is quantized once with
    ``vectorwise_quant`` and only the quantized tensor is kept (buffer ``q_weight``).
    Every forward pass dequantizes it with the current trainable ``q_scales`` /
    ``q_biases`` and applies it with ``F.linear``.

    Args:
        weight: full-precision weight to quantize; its shape is kept as ``w_shape``
            and passed to (de)quantization.
        bias: optional bias tensor, stored as a frozen buffer, or ``None``.
        in_features, out_features: layer dimensions (reported by ``extra_repr``).
        config: passed to ``get_config``; settings are read from ``QUANT.M``.
        is_cuda: if true, the quantization map ``qmap`` is moved to ``device``.
        device, dtype: factory kwargs for ``q_scales`` / ``q_biases``.
    """

    def __init__(
        self, weight, bias, in_features, out_features, config, is_cuda=True, device=None, dtype=None
    ):
        super().__init__()
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.w_shape = weight.shape

        self.qconfig = get_config(config)
        model_qmetadata = self.get_qmetadata()
        self.bits = model_qmetadata["b"]
        self.groupsize = model_qmetadata["gp_sz"]
        # One scale/bias per group; a trailing partial group gets its own entry.
        self.num_groups = self._num_groups_for(weight.numel(), self.groupsize)

        # Empirically chosen normal-map offsets
        # (other values tried: 2-bit 0.995/0.8/0.9/0.92/0.89, others 0.995/0.8/0.9677083).
        qmap_offset = 0.91 if self.bits == 2 else 0.98
        self.qmap = create_normal_map(offset=qmap_offset, total_bits=self.bits)
        self.q_scales = Parameter(torch.empty(self.num_groups, **self.factory_kwargs))  # ones after reset_parameters
        self.q_biases = Parameter(torch.empty(self.num_groups, **self.factory_kwargs))  # zeros after reset_parameters
        if is_cuda:
            self.qmap = self.qmap.to(device)

        # Placeholder that registers `q_weight` as a buffer (so it is part of the
        # state_dict); init_q_weight() below replaces it with the quantized tensor,
        # so there is no point allocating a full-size zero tensor here.
        # TODO: quantize bias as well?
        self.register_buffer('q_weight', torch.empty(0, dtype=torch.int32))
        if bias is not None:
            self.register_buffer('bias', bias)
        else:
            self.bias = None

        self.is_cuda = is_cuda
        self.device = device

        self.reset_parameters()
        if bias is not None:
            self.bias.requires_grad = False

        self.q_metadata = dict()
        self.init_q_weight(weight, self.factory_kwargs)
        # Config-derived metadata overrides any same-named keys returned by vectorwise_quant.
        self.q_metadata.update(model_qmetadata)

    @staticmethod
    def _num_groups_for(numel, group_size):
        """Return ``ceil(numel / group_size)``, the number of quantization groups."""
        return -(-numel // group_size)

    def reset_parameters(self):
        """Reset the trainable quantization scales to 1 and biases to 0."""
        init.ones_(self.q_scales)
        init.zeros_(self.q_biases)

    def init_q_weight(self, weight, factory_kwargs):
        """Quantize ``weight`` into the ``q_weight`` buffer.

        Quantization uses unit scales and zero biases; the trainable ``q_scales`` /
        ``q_biases`` are only applied at dequantization time in ``forward``. The
        extra metadata returned by ``vectorwise_quant`` is merged into
        ``self.q_metadata``.
        """
        model_qmetadata = self.get_qmetadata()
        q_weight, gen = vectorwise_quant(
            weight,
            torch.ones(self.num_groups, **factory_kwargs),
            torch.zeros(self.num_groups, **factory_kwargs),
            qmap=self.qmap,
            shape=self.w_shape,
            **model_qmetadata,
        )
        # The quantized weight is frozen. Detaching guarantees it does not require
        # grad and does not keep an autograd graph back to `weight` alive.
        self.q_weight = q_weight.detach()
        self.q_metadata.update(gen)

    def get_subqconfig(self):
        """Return the ``QUANT.M`` section of the loaded quantization config."""
        return self.qconfig.QUANT.M

    def get_qmetadata(self):
        """Collect the quantization settings from the config into a flat dict.

        Keys: ``b`` (bits), ``scale_type``, ``quant_type``, ``round_type``,
        ``gp_sz`` (group size) and ``signed``.
        """
        subconfig = self.get_subqconfig()
        md = dict(
            b=subconfig.BITS,
            scale_type=subconfig.SCALE_TYPE.DEFAULT,
            quant_type=subconfig.QUANT_TYPE.DEFAULT,
            round_type=subconfig.ROUND_TYPE,
            gp_sz=subconfig.GROUP_SIZE,
            signed=subconfig.SIGNED,
        )
        return md

    def forward(self, x):
        """Dequantize the weight with the current scales/biases and apply ``F.linear``.

        Raises:
            NotImplementedError: if ``self.bits`` is not 2, 3 or 4.
        """
        # NOTE: original author's TODO — compare against an Autograd4bit-style forward.
        if self.bits in [2, 3, 4]:
            dequant_weight = vectorwise_dequant(self.q_weight, self.q_scales, self.q_biases, qmap=self.qmap, shape=self.w_shape, **self.q_metadata)
            out = F.linear(x, dequant_weight, self.bias)
        else:
            raise NotImplementedError()
        return out

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'



class LoRALayer():
    """Mixin that stores the LoRA hyper-parameters shared by LoRA-adapted layers.

    Args:
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator (subclasses may use ``lora_alpha / r``).
        lora_dropout: dropout probability applied to the adapter input; ``0`` disables it.
        merge_weights: stored flag for whether adapter weights may be merged.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout. nn.Identity (rather than a lambda) returns its input
        # unchanged while keeping instances picklable, e.g. for torch.save(model).
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights



class TQLoRALinear(TQLinear, LoRALayer):
    """Quantized linear layer (``TQLinear``) plus a trainable low-rank LoRA update.

    The output is ``F.linear(x, W, bias)``, where ``W`` is the dequantized weight
    (transposed when ``fan_in_fan_out`` is set). When ``r > 0`` and the layer is
    not merged, ``dropout(x) @ lora_A.T @ lora_B.T * (lora_alpha / r)`` is added.

    Extra args beyond ``TQLinear``:
        r, lora_alpha, lora_dropout, merge_weights: forwarded to ``LoRALayer``.
        fan_in_fan_out: transpose the dequantized weight before ``F.linear``.
        q_trainable: if false, ``q_scales`` and ``q_biases`` get ``requires_grad=False``.
    """

    def __init__(
        self, weight, bias, in_features, out_features, config, is_cuda=True, device=None, dtype=None, r=0, lora_alpha=1, lora_dropout=0., fan_in_fan_out=False, merge_weights=True, q_trainable=True
    ):
        # Both bases are initialised explicitly: the quantized layer first, then LoRA state.
        TQLinear.__init__(self, weight, bias, in_features, out_features, config, is_cuda, device, dtype)
        LoRALayer.__init__(self, r, lora_alpha, lora_dropout, merge_weights)
        self.fan_in_fan_out = fan_in_fan_out

        if r > 0:
            # Trainable low-rank factors: A is (r, in_features), B is (out_features, r).
            kw = self.factory_kwargs
            self.lora_A = Parameter(torch.empty(r, in_features, **kw))
            self.lora_B = Parameter(torch.empty(out_features, r, **kw))
            self.scaling = self.lora_alpha / self.r
        self.reset_lora_parameters()

        if not q_trainable:
            # Freeze the trainable quantization scales/biases.
            for frozen in (self.q_scales, self.q_biases):
                frozen.requires_grad = False

    def reset_lora_parameters(self):
        """Initialise the LoRA factors, if they exist.

        ``lora_A`` gets Kaiming-uniform init (``a=sqrt(5)``, as ``nn.Linear`` uses
        for its weight) and ``lora_B`` is zeroed. The LoRA update ``B @ A`` is
        therefore zero right after initialisation.
        """
        if not hasattr(self, 'lora_A'):
            return
        init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        init.zeros_(self.lora_B)

    def forward(self, x):
        """Apply the dequantized base weight and, if active, the LoRA update.

        Raises:
            NotImplementedError: if ``self.bits`` is not 2, 3 or 4.
        """
        if self.bits not in (2, 3, 4):
            raise NotImplementedError()
        dequant_weight = vectorwise_dequant(self.q_weight, self.q_scales, self.q_biases, qmap=self.qmap, shape=self.w_shape, **self.q_metadata)
        base_weight = dequant_weight.transpose(0, 1) if self.fan_in_fan_out else dequant_weight
        result = F.linear(x, base_weight, bias=self.bias)

        if self.r > 0 and not self.merged:
            # Move the adapter factors to the input's device before use.
            target = x.device
            a_mat = self.lora_A.to(target)
            b_mat = self.lora_B.to(target)
            result += (self.lora_dropout(x) @ a_mat.transpose(0, 1) @ b_mat.transpose(0, 1)) * self.scaling
        return result