import math
import warnings
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
from torch.nn.parameter import Parameter

import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, Linear4bit
import wandb

from .functional_bnb import quantize_4bit
from .quant_utils import vectorwise_quant, vectorwise_dequant, create_normal_map
from lpmm.config import get_config


# Assumes layer is perfectly divisible into 256 * 256 blocks
class TQLinear(nn.Module): # TQ -> Trainable Quantization
    """Linear layer that stores its weight quantized and dequantizes it on every forward.

    Quantization settings (bit width, group size, scale/quant/round types, signedness)
    are read from ``get_config(config).QUANT.M``. ``weight`` is quantized once at
    construction with ``vectorwise_quant``; ``forward`` rebuilds a dense weight with
    ``vectorwise_dequant`` and applies ``F.linear``. Only 2/3/4-bit widths are
    supported by ``forward``.
    """
    def __init__(
        self, name, weight, bias, in_features, out_features, config, is_cuda=True, device=None, dtype=None
    ):
        super().__init__()
        # NOTE(review): the original full-precision weight stays attached to the module.
        self.weight = weight
        self.name = name
        self.step = 0  # count of completed forward calls; drives the logging cadence
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.w_shape = weight.shape

        self.qconfig = get_config(config)
        model_qmetadata = self.get_qmetadata()
        self.bits = model_qmetadata["b"]
        self.groupsize = model_qmetadata["gp_sz"]
        # Number of quantization groups over the flattened weight; a trailing
        # partial group counts as a group.
        full_groups, remainder = divmod(weight.flatten().shape[0], self.groupsize)
        self.num_groups = full_groups + 1 if remainder != 0 else full_groups

        # Normal-map code book offset per bit width (trailing comments record
        # alternative values that were tried).
        if self.bits == 2:
            qmap_offset = 0.9  # 0.93
        elif self.bits == 3:
            qmap_offset = 0.98  # 0.98
        else:
            qmap_offset = 0.99  # 0.995, 0.8, 0.9677083, 0.98
        self.qmap = create_normal_map(offset=qmap_offset, total_bits=self.bits)
        if is_cuda:
            self.qmap = self.qmap.to(device)

        # Placeholder buffer; init_q_weight overwrites it with vectorwise_quant's output.
        self.register_buffer('q_weight', torch.zeros((in_features // 32 * self.bits, out_features), dtype=torch.int32))
        if bias is not None:
            self.bias = Parameter(bias, requires_grad=True)
        else:
            self.bias = None

        self.is_cuda = is_cuda
        self.device = device

        self.q_weight.requires_grad = False

        self.q_metadata = dict()
        self.init_q_weight(weight, self.factory_kwargs)
        # Config-derived entries are applied last so they take precedence over
        # anything with the same key returned by vectorwise_quant.
        self.q_metadata.update(model_qmetadata)

    def init_q_weight(self, weight, factory_kwargs):
        """Quantize ``weight`` into ``self.q_weight`` and merge the returned metadata."""
        model_qmetadata = self.get_qmetadata()
        self.q_weight, gen = vectorwise_quant(self.name, weight, qmap=self.qmap, shape=self.w_shape, **model_qmetadata) 
        self.q_metadata.update(gen)

    def get_subqconfig(self):
        """Return the ``QUANT.M`` section of the loaded quantization config."""
        return self.qconfig.QUANT.M
        
    def get_qmetadata(self):
        """Collect the quantization keyword arguments from the config section."""
        subconfig = self.get_subqconfig()
        md = dict(
            b=subconfig.BITS,
            scale_type=subconfig.SCALE_TYPE.DEFAULT,
            quant_type=subconfig.QUANT_TYPE.DEFAULT,
            round_type=subconfig.ROUND_TYPE,
            gp_sz=subconfig.GROUP_SIZE,
            signed=subconfig.SIGNED,
        )
        return md

    def forward(self, x): # implement forward function here without Autograd4bit and compare how it is different
        """Dequantize the stored weight and return ``F.linear(x, W, self.bias)``.

        Raises:
            NotImplementedError: if ``self.bits`` is not 2, 3 or 4.
        """
        if self.bits not in [2, 3, 4]:
            raise NotImplementedError()
        if self.name == 'model.layers.0.self_attn.k_proj' and self.step % 8 == 0:
            # Debug logging for one hard-coded layer. TQLinear never defines
            # ``q_biases`` itself, so only log when something has attached it;
            # the unguarded access raised AttributeError on the first forward.
            q_biases = getattr(self, 'q_biases', None)
            if q_biases is not None:
                wandb.log({"q_bias": q_biases[0]})
        dequant_weight = vectorwise_dequant(self.q_weight, qmap=self.qmap, shape=self.w_shape, **self.q_metadata)
        out = F.linear(x, dequant_weight, self.bias)
        self.step += 1
        return out

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'



class QuantEmbedding(nn.Embedding): # TQ -> Trainable Quantization
    """Embedding table whose weight is a quantize -> dequantize round trip of ``weight``.

    Settings are read from ``get_config(config).QUANT.M``. During construction
    ``weight`` is quantized with ``vectorwise_quant`` and immediately dequantized
    with ``vectorwise_dequant``. The result becomes a frozen ``self.weight``
    Parameter, so lookups use the inherited ``nn.Embedding.forward``.
    """

    def __init__(
        self, name, weight, num_embeddings, embedding_dim, config, is_cuda=True, device=None, dtype=None
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, device=device, dtype=dtype)
        self.name = name
        self.step = 0
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.w_shape = weight.shape

        self.qconfig = get_config(config)
        base_md = self.get_qmetadata()
        self.bits = base_md["b"]
        self.groupsize = base_md["gp_sz"]
        whole, rest = divmod(weight.flatten().shape[0], self.groupsize)
        # a trailing partial group is still counted as a group
        self.num_groups = whole + 1 if rest != 0 else whole

        # normal-map offset chosen per bit width
        offset = 0.9 if self.bits == 2 else (0.98 if self.bits == 3 else 0.99)
        self.qmap = create_normal_map(offset=offset, total_bits=self.bits)
        if is_cuda:
            self.qmap = self.qmap.to(device)

        self.is_cuda = is_cuda
        self.device = device

        # config-derived metadata first; init_q_weight merges vectorwise_quant's extras on top
        self.q_metadata = dict(base_md)
        self.init_q_weight(weight, self.factory_kwargs)

    def init_q_weight(self, weight, factory_kwargs):
        """Replace ``self.weight`` with the frozen dequantized version of ``weight``."""
        packed, extra_md = vectorwise_quant(self.name, weight, qmap=self.qmap, shape=self.w_shape, **self.q_metadata)
        self.q_metadata.update(extra_md)
        restored = vectorwise_dequant(packed, qmap=self.qmap, shape=self.w_shape, **self.q_metadata)
        self.weight = Parameter(restored, requires_grad=False)

    def get_subqconfig(self):
        """Return the ``QUANT.M`` section of the loaded config."""
        return self.qconfig.QUANT.M

    def get_qmetadata(self):
        """Build the quantization keyword arguments from the config section."""
        sub = self.get_subqconfig()
        return {
            'b': sub.BITS,
            'scale_type': sub.SCALE_TYPE.DEFAULT,
            'quant_type': sub.QUANT_TYPE.DEFAULT,
            'round_type': sub.ROUND_TYPE,
            'gp_sz': sub.GROUP_SIZE,
            'signed': sub.SIGNED,
        }



class LoRALayer:
    """Mixin holding the LoRA hyper-parameters shared by the adapter layers.

    Args:
        r: adapter rank.
        lora_alpha: LoRA alpha hyper-parameter, stored as given.
        lora_dropout: dropout probability; a non-positive value means identity.
        merge_weights: stored flag; ``merged`` always starts as ``False``.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r, self.lora_alpha = r, lora_alpha
        # Optional dropout; identity function when disabled
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        # Adapter starts unmerged from the base weight
        self.merged, self.merge_weights = False, merge_weights


# NOTE: disabled earlier implementation of TQLoRALinear. It is kept only as an
# inert string literal and is never executed; the active TQLoRALinear class is
# defined later in this file.
'''
class TQLoRALinear(TQLinear, LoRALayer):
    def __init__(
        self, name, weight, bias, in_features, out_features, config, is_cuda=True, device=None, dtype=None, r=0, lora_alpha=1, lora_dropout=0., fan_in_fan_out=False, merge_weights=True, q_trainable=True
    ):
        TQLinear.__init__(self, name, weight, bias, in_features, out_features, config, is_cuda, device, dtype)
        LoRALayer.__init__(self, r, lora_alpha, lora_dropout, merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = Parameter(torch.empty(r, in_features, **self.factory_kwargs))
            self.lora_B = Parameter(torch.empty(out_features, r, **self.factory_kwargs))
            self.scaling = self.lora_alpha / self.r
            
        self.reset_lora_parameters()

    def reset_lora_parameters(self):
        if hasattr(self, 'lora_A'):
            # initialize B the same way as the default for nn.Linear and A to zero
            # this is different than what is described in the paper but should not affect performance
            init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            init.zeros_(self.lora_B)

    def forward(self, x):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.bits in [2, 3, 4]:
            # if self.name == 'model.layers.0.self_attn.k_proj':
            #     print(self.q_metadata['offsets'])
            dequant_weight = vectorwise_dequant(self.q_weight, qmap=self.qmap, shape=self.w_shape, **self.q_metadata)
            # if self.name == 'model.layers.0.self_attn.k_proj':
            #     print(dequant_weight)
        else:
            raise NotImplementedError()
        if self.r > 0 and not self.merged:
            # lora_interaction = F.linear(self.lora_dropout(x), self.lora_A) @ self.lora_B.transpose(0, 1)
            # result = F.linear(x, dequant_weight, self.bias) + lora_interaction * self.scaling
            # result = F.linear(x, T(dequant_weight), bias=self.bias) 
            result = F.linear(x, self.weight, bias=self.bias) 
            device = x.device
            lora_A, lora_B = self.lora_A.to(device), self.lora_B.to(device)         
            result += (self.lora_dropout(x) @ lora_A.transpose(0, 1) @ lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(dequant_weight), bias=self.bias)
'''


class TQLoRALinear(TQLinear, LoRALayer):
    """TQLinear with a LoRA adapter ``lora_B(lora_A(x))`` added to the base output.

    The quantized weight is dequantized once at construction and stored in
    ``self.quant`` with ``requires_grad=False``. The adapter is made of two
    bias-free ``nn.Linear`` layers created only when ``r > 0``.

    NOTE(review): ``lora_alpha`` and ``lora_dropout`` are stored by LoRALayer,
    but ``forward`` applies neither a scaling factor nor dropout. Confirm that
    this is intended.
    """
    def __init__(
        self, name, weight, bias, in_features, out_features, config, is_cuda=True, device=None, dtype=None, r=0, lora_alpha=1, lora_dropout=0., fan_in_fan_out=False, merge_weights=True, q_trainable=True
    ):
        TQLinear.__init__(self, name, weight, bias, in_features, out_features, config, is_cuda, device, dtype)
        LoRALayer.__init__(self, r, lora_alpha, lora_dropout, merge_weights)
        
        # Dense, frozen copy of the dequantized weight, computed once instead of per forward.
        self.quant = nn.Linear(in_features, out_features, bias=False)
        dequant_weight = vectorwise_dequant(self.q_weight, qmap=self.qmap, shape=self.w_shape, **self.q_metadata)
        self.quant.weight = Parameter(dequant_weight, requires_grad=False)
        self.fan_in_fan_out = fan_in_fan_out
        # TQLinear.__init__ already registered ``self.bias`` as a trainable Parameter
        # (or None). Re-wrapping it here turned ``bias=None`` into an empty tensor.
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Linear(in_features, r, bias=False)
            self.lora_B = nn.Linear(r, out_features, bias=False)
            
        self.reset_lora_parameters()

    def reset_lora_parameters(self):
        """Re-initialise the adapter: ``lora_A`` Kaiming-uniform, ``lora_B`` zeros.

        Because ``lora_B`` starts at zero, the adapter adds nothing until trained.
        Does nothing when no adapter was created (``r == 0``).
        """
        if hasattr(self, 'lora_A'):
            lora_A_weight = Parameter(self.quant.weight.new_zeros((self.r, self.in_features)),
                                      requires_grad=True)
            lora_B_weight = Parameter(self.quant.weight.new_zeros((self.out_features, self.r)),
                                      requires_grad=True)
            nn.init.kaiming_uniform_(lora_A_weight, a=math.sqrt(5))
            nn.init.zeros_(lora_B_weight)
            self.lora_A.weight = lora_A_weight
            self.lora_B.weight = lora_B_weight

    def forward(self, x):
        """Return ``quant(x) + bias + lora_B(lora_A(x))``, or the plain linear output.

        Raises:
            NotImplementedError: if ``self.bits`` is not 2, 3 or 4.
        """
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.bits not in [2, 3, 4]:
            raise NotImplementedError()
        if self.r > 0 and not self.merged: 
            result = self.quant(x)
            if self.bias is not None:
                result = result + self.bias
            device = x.device
            # Module.to moves parameters in place, so no reassignment is needed.
            self.lora_A.to(device)
            self.lora_B.to(device)
            result = result + self.lora_B(self.lora_A(x))
            return result
        # No active adapter: plain linear layer over the frozen dequantized weight.
        # (This branch previously referenced an undefined ``dequant_weight``.)
        return F.linear(x, T(self.quant.weight), bias=self.bias)


# NOTE: disabled earlier implementation of LinearQuantEmbedding. It is kept only
# as an inert string literal and is never executed; the active
# LinearQuantEmbedding class is defined later in this file.
'''
class LinearQuantEmbedding(QuantEmbedding, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        name,
        num_embeddings: int,
        embedding_dim: int,
        config, is_cuda=True, device=None, dtype=None, r=0, lora_alpha=1.0, lora_dropout=0., fan_in_fan_out=False, merge_weights=True, q_trainable=True
    ):
        QuantEmbedding.__init__(self, name, num_embeddings, embedding_dim, config, is_cuda=True, device=None, dtype=None)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0, merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
        # self.reset_parameters()

    def reset_parameters(self):
        QuantEmbedding.reset_parameters(self)
        # if hasattr(self, 'lora_A'):
        #     # initialize A the same way as the default for nn.Linear and B to zero
        #     # nn.init.zeros_(self.lora_A)
        #     # nn.init.normal_(self.lora_B)
        #     nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        #     nn.init.zeros_(self.lora_B)
        
    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = QuantEmbedding.forward(self, x)
            device = x.device
            self.lora_A, self.lora_B = self.lora_A.to(device), self.lora_B.to(device)  
            after_A = F.embedding(
                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return QuantEmbedding.forward(self, x)
'''



class LinearQuantEmbedding(QuantEmbedding, LoRALayer):
    """QuantEmbedding plus a LoRA adapter applied as a second embedding lookup.

    When ``r > 0`` and the adapter is not merged, ``forward`` returns
    ``QuantEmbedding.forward(x) + F.embedding(x, lora_A) @ lora_B``. Here
    ``lora_A`` has shape ``(num_embeddings, r)`` and ``lora_B`` has shape
    ``(r, embedding_dim)``. ``lora_B`` starts at zero, so the adapter initially
    adds nothing.

    NOTE(review): ``lora_alpha`` is stored but no ``alpha / r`` scaling is applied.
    Confirm that this is intended.
    """
    # LoRA implemented in a dense layer
    def __init__(
        self,
        name,
        weight,
        num_embeddings: int,
        embedding_dim: int,
        config, is_cuda=True, device=None, dtype=None, r=0, lora_alpha=1.0, lora_dropout=0., fan_in_fan_out=False, merge_weights=True, q_trainable=True
    ):
        # Forward the caller's placement options (they were previously replaced
        # by the defaults is_cuda=True, device=None, dtype=None).
        QuantEmbedding.__init__(self, name, weight, num_embeddings, embedding_dim, config, is_cuda=is_cuda, device=device, dtype=dtype)
        # forward() never applies dropout to the adapter, so it is disabled here
        # regardless of ``lora_dropout``.
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0, merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((num_embeddings, r)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((r, embedding_dim)))
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
        # self.reset_parameters()

    def reset_parameters(self):
        # Called from nn.Embedding.__init__ before lora_A/lora_B exist, so only the
        # base embedding is reset here.
        QuantEmbedding.reset_parameters(self)
        # if hasattr(self, 'lora_A'):
        #     # initialize A the same way as the default for nn.Linear and B to zero
        #     # nn.init.zeros_(self.lora_A)
        #     # nn.init.normal_(self.lora_B)
        #     nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        #     nn.init.zeros_(self.lora_B)
        
    def train(self, mode: bool = True):
        # nn.Module.train returns the module; propagate it so ``layer.train()`` chains.
        return QuantEmbedding.train(self, mode)

    def forward(self, x: torch.Tensor):
        """Embed index tensor ``x``, adding the LoRA term when the adapter is active."""
        if self.r > 0 and not self.merged:
            result = QuantEmbedding.forward(self, x)
            device = x.device
            # Use local copies: assigning ``Parameter.to(device)`` back onto a
            # registered parameter raises TypeError whenever the device changes.
            lora_A = self.lora_A.to(device)
            lora_B = self.lora_B.to(device)
            after_A = F.embedding(x, lora_A)
            return result + after_A @ lora_B
        return QuantEmbedding.forward(self, x)



class TQParams4bit(Params4bit):
    """``Params4bit`` variant that quantizes with this project's ``quantize_4bit``.

    It differs from bitsandbytes' ``Params4bit`` only in two methods:
    ``_quantize`` calls ``quantize_4bit`` from ``.functional_bnb``, and ``to``
    returns a ``TQParams4bit``.
    """
    def __new__(cls, data=None, requires_grad=False, quant_state=None, blocksize=64, compress_statistics=True, quant_type="fp4", quant_storage=torch.uint8, module=None, bnb_quantized=False):
        # No extra state of its own: delegate construction to Params4bit.
        return super().__new__(cls, data, requires_grad, quant_state, blocksize, compress_statistics, quant_type, quant_storage, module, bnb_quantized)

    def _quantize(self, device):
        """Quantize ``self.data`` on CUDA ``device`` in place and return ``self``.

        The resulting quant state is also pushed to the owning module (if any) so
        the module can recover it later.
        """
        w = self.data.contiguous().cuda(device)
        w_4bit, quant_state = quantize_4bit(
            w,
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
        )
        self.data = w_4bit
        self.quant_state = quant_state
        if self.module is not None:
            self.module.quant_state = quant_state
        self.bnb_quantized = True
        return self

    def to(self, *args, **kwargs):
        """Move/convert the parameter, quantizing on the first move to CUDA."""
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)

        if self.quant_state is not None:
            self.quant_state.to(device)

        # Carry every attribute across. Dropping ``bnb_quantized`` made a later
        # ``.cuda()`` re-quantize already-packed data. Dropping ``quant_storage``
        # or ``module`` silently reset them to their defaults.
        return TQParams4bit(
            super().to(device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=self.quant_state,
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
            quant_storage=self.quant_storage,
            module=self.module,
            bnb_quantized=self.bnb_quantized,
        )



class TQLinear4bit(nn.Linear):
    """``nn.Linear`` whose weight is a ``TQParams4bit``, multiplied via ``bnb.matmul_4bit``.

    Mirrors bitsandbytes' ``Linear4bit``. It additionally owns ``q_biases``, a
    Parameter with one entry per ``blocksize``-element block of the weight.
    """
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        compute_dtype=None,
        compress_statistics=True,
        quant_type="fp4",
        quant_storage=torch.uint8,
        device=None,
        dtype=None
    ):
        """
        Initialize TQLinear4bit.

        Args:
            input_features (`int`):
                Number of input features of the linear layer.
            output_features (`int`):
                Number of output features of the linear layer.
            bias (`bool`, defaults to `True`):
                Whether the linear class uses the bias term as well.
            compute_dtype (`torch.dtype`, optional):
                dtype the matmul runs in; ``set_compute_type`` may replace it on the first forward.
            compress_statistics, quant_type, quant_storage:
                Forwarded to ``TQParams4bit``.
            device, dtype:
                ``device`` is passed to ``nn.Linear``; both are used to allocate ``q_biases``.
        """
        super().__init__(input_features, output_features, bias, device)
        self.factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = TQParams4bit(
            self.weight.data,
            requires_grad=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
            quant_storage=quant_storage,
            module=self,
        )
        # self.persistent_buffers = []  # TODO consider as way to save quant state
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = False
        self.quant_state = None
        self.quant_storage = quant_storage

        # Number of ``blocksize``-element weight blocks, counting a trailing partial
        # block. (The attribute is ``blocksize``; the former ``block_size`` lookup
        # raised AttributeError, so construction always failed.)
        full_blocks, remainder = divmod(self.weight.numel(), self.weight.blocksize)
        self.num_groups = full_blocks + 1 if remainder != 0 else full_blocks
        # Zero-initialised: ``torch.empty`` left this Parameter holding uninitialised memory.
        self.q_biases = Parameter(torch.zeros(self.num_groups, **self.factory_kwargs))

    def set_compute_type(self, x):
        """Choose ``compute_dtype`` from the first input and warn about slow fp16/fp32 combos."""
        if x.dtype in [torch.float32, torch.bfloat16]:
            # the input is in a dtype that is safe to compute in, we switch
            # to this type for speed and stability
            self.compute_dtype = x.dtype
        elif x.dtype == torch.float16:
            # we take the compute dtype passed into the layer
            if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
                # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                # warn the user about this
                warnings.warn(
                    "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
                )
                warnings.filterwarnings("ignore", message=".*inference.")
            if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
                warnings.warn(
                    "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
                )
                warnings.filterwarnings("ignore", message=".*inference or training")

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """
        save weight and bias,
        then fill state_dict with components of quant_state
        """
        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias

        if getattr(self.weight, "quant_state", None) is not None:
            for k, v in self.weight.quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()

    def forward(self, x: torch.Tensor):
        """Compute ``x @ W.T + bias`` with the 4-bit weight, returning ``x``'s original dtype."""
        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            if getattr(self, "quant_state", None) is not None:
                # the quant state got lost when the parameter got converted. This happens for example for fsdp
                # since we registered the module, we can recover the state here
                assert self.weight.shape[1] == 1
                if not isinstance(self.weight, TQParams4bit):
                    self.weight = TQParams4bit(self.weight, quant_storage=self.quant_storage)
                self.weight.quant_state = self.quant_state
            else:
                print(
                    "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
                )
        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)

        out = out.to(inp_dtype)

        return out