import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from typing import Optional, List

import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, Linear4bit

from .functional_bnb import quantize_4bit
from .quant_utils import vectorwise_quant, vectorwise_dequant, create_normal_map
from lpmm.config import get_config

import wandb


# Assumes layer is perfectly divisible into 256 * 256 blocks
class TQLinear(nn.Module):  # TQ -> Trainable Quantization
    """Linear layer with a quantized weight and trainable per-group scale/bias.

    The given ``weight`` is quantized once at construction time through
    ``vectorwise_quant``. Every forward pass rebuilds a dense weight with
    ``vectorwise_dequant`` from the stored ``q_weight`` and the trainable
    ``q_scales`` / ``q_biases`` (one entry per quantization group), then
    applies ``F.linear``.

    Args:
        name: Identifier of this layer. Forwarded to ``vectorwise_quant`` and
            compared against the one layer name whose first ``q_biases``
            entry is logged to wandb.
        weight: Tensor to quantize; its shape is remembered as ``w_shape``.
        bias: Optional bias tensor, registered as a non-trainable buffer.
        in_features: Used for the placeholder ``q_weight`` buffer shape and
            for ``extra_repr``.
        out_features: Used for the placeholder ``q_weight`` buffer shape and
            for ``extra_repr``.
        config: Passed to ``get_config``; the resulting object must expose
            ``QUANT.M`` with the fields read in ``get_qmetadata``.
        is_cuda: When true, the quantization map is moved to ``device``.
        device: Device for the ``q_scales`` / ``q_biases`` parameters.
        dtype: Dtype for the ``q_scales`` / ``q_biases`` parameters.
    """

    # Bit widths for which forward() is implemented.
    _SUPPORTED_BITS = (2, 3, 4)
    # Only this layer reports its first q_bias entry to wandb, every
    # _LOG_EVERY forward calls.
    _LOGGED_LAYER = 'model.layers.0.self_attn.k_proj'
    _LOG_EVERY = 8

    def __init__(
        self, name, weight, bias, in_features, out_features, config, is_cuda=True, device=None, dtype=None
    ):
        super().__init__()
        self.name = name
        self.step = 0  # number of completed forward calls
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.w_shape = weight.shape

        self.qconfig = get_config(config)
        model_qmetadata = self.get_qmetadata()
        self.bits = model_qmetadata["b"]
        self.groupsize = model_qmetadata["gp_sz"]
        # Ceil division: a trailing partial group still gets its own
        # scale/bias entry.
        self.num_groups = (weight.numel() + self.groupsize - 1) // self.groupsize

        # The 2- and 3-bit maps use hand-tuned offsets; other widths use the
        # default of create_normal_map.
        if self.bits == 2:
            self.qmap = create_normal_map(offset=0.95, total_bits=self.bits)
        elif self.bits == 3:
            self.qmap = create_normal_map(offset=0.98, total_bits=self.bits)
        else:
            self.qmap = create_normal_map(total_bits=self.bits)

        # Allocated uninitialised here; reset_parameters() below fills them
        # with ones (scales) and zeros (biases).
        self.q_scales = Parameter(torch.empty(self.num_groups, **self.factory_kwargs))
        self.q_biases = Parameter(torch.empty(self.num_groups, **self.factory_kwargs))
        if is_cuda:
            self.qmap = self.qmap.to(device)

        # Placeholder buffer so that 'q_weight' is a registered buffer name;
        # init_q_weight() replaces it with the actual quantized tensor.
        # TODO: quantize bias as well?
        self.register_buffer(
            'q_weight',
            torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32),
        )
        if bias is not None:
            self.register_buffer('bias', bias)
        else:
            self.bias = None

        self.is_cuda = is_cuda
        self.device = device

        self.reset_parameters()
        self.q_weight.requires_grad = False
        if bias is not None:
            self.bias.requires_grad = False

        # Values produced by vectorwise_quant first, then the static config
        # values on top (so config wins on a key clash).
        self.q_metadata = dict()
        self.init_q_weight(weight, self.factory_kwargs)
        self.q_metadata.update(model_qmetadata)

    def reset_parameters(self):
        """Set every group scale to one and every group bias to zero."""
        init.ones_(self.q_scales)
        init.zeros_(self.q_biases)

    def init_q_weight(self, weight, factory_kwargs):
        """Quantize ``weight`` with unit scales / zero biases and store the result.

        ``vectorwise_quant`` is expected to return the quantized tensor plus
        a mapping of generated metadata, which is merged into
        ``self.q_metadata``.
        """
        model_qmetadata = self.get_qmetadata()
        unit_scales = torch.ones(self.num_groups, **factory_kwargs)
        zero_biases = torch.zeros(self.num_groups, **factory_kwargs)
        self.q_weight, gen = vectorwise_quant(
            self.name,
            weight,
            unit_scales,
            zero_biases,
            qmap=self.qmap,
            shape=self.w_shape,
            **model_qmetadata,
        )
        self.q_metadata.update(gen)

    def get_subqconfig(self):
        """Return the ``QUANT.M`` section of the loaded configuration."""
        return self.qconfig.QUANT.M

    def get_qmetadata(self):
        """Collect the quantization settings as the keyword dict used by the quant helpers."""
        subconfig = self.get_subqconfig()
        md = dict(
            b=subconfig.BITS,
            scale_type=subconfig.SCALE_TYPE.DEFAULT,
            quant_type=subconfig.QUANT_TYPE.DEFAULT,
            round_type=subconfig.ROUND_TYPE,
            gp_sz=subconfig.GROUP_SIZE,
            signed=subconfig.SIGNED,
        )
        return md

    def forward(self, x):
        """Dequantize the weight and apply the linear map to ``x``.

        Raises:
            NotImplementedError: If the configured bit width is not 2, 3 or 4.
        """
        if self.bits not in self._SUPPORTED_BITS:
            raise NotImplementedError()

        # Debug logging for one hard-coded layer. Skipped when no wandb run
        # is active, so the layer stays usable without wandb.init() (for
        # example during plain inference).
        should_log = self.name == self._LOGGED_LAYER and self.step % self._LOG_EVERY == 0
        if should_log and wandb.run is not None:
            wandb.log({"q_bias": self.q_biases[0]})

        dequant_weight = vectorwise_dequant(
            self.q_weight,
            self.q_scales,
            self.q_biases,
            qmap=self.qmap,
            shape=self.w_shape,
            **self.q_metadata,
        )
        out = F.linear(x, dequant_weight, self.bias)
        self.step += 1
        return out

    def extra_repr(self) -> str:
        """Return the feature sizes and bias flag shown in the module repr."""
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'



class LoRALayer():
    """Mixin holding the LoRA hyper-parameters shared by LoRA-enabled layers.

    Args:
        r: Rank of the low-rank update; stored as ``self.r``.
        lora_alpha: Scaling numerator; stored as ``self.lora_alpha``.
        lora_dropout: Dropout probability applied to the LoRA branch input.
            A value of 0 (or less) disables dropout.
        merge_weights: Stored as ``self.merge_weights``; this class does not
            act on it.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout. The pass-through case uses nn.Identity rather
        # than a local lambda so that instances stay picklable; it returns
        # its input unchanged.
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights



class TQLoRALinear(TQLinear, LoRALayer):
    """TQLinear with an additive low-rank (LoRA) branch.

    The output is the dequantized linear map plus, when ``r > 0`` and the
    layer is not marked as merged, ``dropout(x) @ A^T @ B^T`` scaled by
    ``lora_alpha / r``.

    Args:
        name, weight, bias, in_features, out_features, config, is_cuda,
        device, dtype: Forwarded unchanged to ``TQLinear.__init__``.
        r, lora_alpha, lora_dropout, merge_weights: Forwarded unchanged to
            ``LoRALayer.__init__``.
        fan_in_fan_out: When true, the dequantized weight is transposed
            before it is handed to ``F.linear``.
        q_trainable: When false, ``q_scales`` and ``q_biases`` are frozen.
    """

    def __init__(
        self, name, weight, bias, in_features, out_features, config, is_cuda=True, device=None, dtype=None, r=0, lora_alpha=1, lora_dropout=0., fan_in_fan_out=False, merge_weights=True, q_trainable=True
    ):
        TQLinear.__init__(self, name, weight, bias, in_features, out_features, config, is_cuda, device, dtype)
        LoRALayer.__init__(self, r, lora_alpha, lora_dropout, merge_weights)

        self.fan_in_fan_out = fan_in_fan_out

        # The low-rank factors only exist for a positive rank.
        if r > 0:
            self.lora_A = Parameter(torch.empty(r, in_features, **self.factory_kwargs))
            self.lora_B = Parameter(torch.empty(out_features, r, **self.factory_kwargs))
            self.scaling = self.lora_alpha / self.r

        self.reset_lora_parameters()

        if not q_trainable:
            # Freeze the quantization parameters.
            self.q_scales.requires_grad = False
            self.q_biases.requires_grad = False

    def reset_lora_parameters(self):
        """Initialise the LoRA factors, if present.

        ``lora_A`` gets the Kaiming-uniform initialisation and ``lora_B`` is
        zeroed, so the LoRA branch contributes nothing right after
        initialisation.
        """
        if not hasattr(self, 'lora_A'):
            return
        init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        init.zeros_(self.lora_B)

    def forward(self, x):
        """Apply the dequantized linear map and add the LoRA branch when active.

        Raises:
            NotImplementedError: If the configured bit width is not 2, 3 or 4.
        """
        if self.bits not in [2, 3, 4]:
            raise NotImplementedError()

        dequant_weight = vectorwise_dequant(
            self.q_weight,
            self.q_scales,
            self.q_biases,
            qmap=self.qmap,
            shape=self.w_shape,
            **self.q_metadata,
        )
        if self.fan_in_fan_out:
            base_weight = dequant_weight.transpose(0, 1)
        else:
            base_weight = dequant_weight

        result = F.linear(x, base_weight, bias=self.bias)

        # No LoRA contribution for a non-positive rank or a merged layer.
        if not self.r > 0 or self.merged:
            return result

        # Bring the low-rank factors to the device of the input.
        target = x.device
        a_matrix = self.lora_A.to(target)
        b_matrix = self.lora_B.to(target)
        low_rank = self.lora_dropout(x) @ a_matrix.transpose(0, 1) @ b_matrix.transpose(0, 1)
        result += low_rank * self.scaling
        return result



class TQParams4bit(Params4bit):
    """``Params4bit`` variant that quantizes with this package's ``quantize_4bit``."""

    def __new__(cls, data=None, requires_grad=False, quant_state=None, blocksize=64, compress_statistics=True, quant_type="fp4", quant_storage=torch.uint8, module=None, bnb_quantized=False):
        # All arguments are forwarded unchanged to Params4bit.__new__.
        self = super().__new__(cls, data, requires_grad, quant_state, blocksize, compress_statistics, quant_type, quant_storage, module, bnb_quantized)
        return self

    def _quantize(self, device):
        """Quantize the held data on CUDA ``device`` in place and return ``self``.

        The quantization state is stored on the parameter and, when an owning
        module was given, mirrored onto ``module.quant_state``.
        """
        w = self.data.contiguous().cuda(device)
        w_4bit, quant_state = quantize_4bit(
            w,
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
        )
        self.data = w_4bit
        self.quant_state = quant_state
        if self.module is not None:
            self.module.quant_state = quant_state
        self.bnb_quantized = True
        return self

    def to(self, *args, **kwargs):
        """Move/cast the parameter, quantizing on the first move to a CUDA device.

        For any other move a new ``TQParams4bit`` is returned that carries
        over the complete quantization metadata of this one.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)

        if self.quant_state is not None:
            self.quant_state.to(device)

        # quant_storage, module and bnb_quantized must be carried over too.
        # Without bnb_quantized the copy would look unquantized and a later
        # move to CUDA would quantize the already packed data a second time.
        new_param = TQParams4bit(
            super().to(device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=self.quant_state,
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            module=self.module,
            bnb_quantized=self.bnb_quantized,
        )

        return new_param



class TQLinear4bit(nn.Linear):
    """4-bit linear layer backed by ``TQParams4bit`` with an extra per-block bias parameter."""

    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        compute_dtype=None,
        compress_statistics=True,
        quant_type="fp4",
        quant_storage=torch.uint8,
        device=None,
        dtype=None
    ):
        """
        Initialize TQLinear4bit class.

        Args:
            input_features (`int`):
                Number of input features of the linear layer.
            output_features (`int`):
                Number of output features of the linear layer.
            bias (`bool`, defaults to `True`):
                Whether the linear class uses the bias term as well.
            compute_dtype (`torch.dtype`, *optional*):
                Dtype the input is cast to before the 4-bit matmul.
            compress_statistics (`bool`, defaults to `True`):
                Forwarded to `TQParams4bit`.
            quant_type (`str`, defaults to `"fp4"`):
                Forwarded to `TQParams4bit`.
            quant_storage (`torch.dtype`, defaults to `torch.uint8`):
                Forwarded to `TQParams4bit`.
            device:
                Device for the layer and for `q_biases`.
            dtype:
                Dtype for `q_biases`.
        """
        super().__init__(input_features, output_features, bias, device)
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = TQParams4bit(
            self.weight.data,
            requires_grad=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
            quant_storage=quant_storage,
            module=self,
        )
        # self.persistent_buffers = []  # TODO consider as way to save quant state
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = False
        self.quant_state = None
        self.quant_storage = quant_storage

        # One bias entry per quantization block; ceil division so a trailing
        # partial block is counted as well.
        blocksize = self.weight.blocksize
        self.num_groups = (self.weight.numel() + blocksize - 1) // blocksize
        # Zero-initialised so the parameter never holds uninitialised memory.
        self.q_biases = Parameter(torch.zeros(self.num_groups, **self.factory_kwargs))

    def set_compute_type(self, x):
        """Pick the compute dtype from the dtype of input ``x``.

        float32/bfloat16 inputs become the compute dtype. For float16 inputs
        the configured compute dtype is kept and a warning is issued when it
        is float32.
        """
        # Function-scope import: the module does not import warnings at the top.
        import warnings

        if x.dtype in [torch.float32, torch.bfloat16]:
            # the input is in a dtype that is safe to compute in, we switch
            # to this type for speed and stability
            self.compute_dtype = x.dtype
        elif x.dtype == torch.float16:
            # we take the compute dtype passed into the layer
            if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
                # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                # warn the user about this
                warnings.warn(
                    "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
                )
                warnings.filterwarnings("ignore", message=".*inference.")
            if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
                warnings.warn(
                    "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
                )
                warnings.filterwarnings("ignore", message=".*inference or training")

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """
        save weight and bias,
        then fill state_dict with components of quant_state
        """
        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias

        if getattr(self.weight, "quant_state", None) is not None:
            for k, v in self.weight.quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()

    def forward(self, x: torch.Tensor):
        """Run the 4-bit matmul on ``x`` and return the result in the dtype of ``x``."""
        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            if getattr(self, "quant_state", None) is not None:
                # the quant state got lost when the parameter got converted. This happens for example for fsdp
                # since we registered the module, we can recover the state here
                assert self.weight.shape[1] == 1
                if not isinstance(self.weight, TQParams4bit):
                    self.weight = TQParams4bit(self.weight, quant_storage=self.quant_storage)
                self.weight.quant_state = self.quant_state
            else:
                print(
                    "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
                )
        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)

        out = out.to(inp_dtype)

        return out