import math
import transformers
import torch
import utils
import hadamard_utils
import fast_hadamard_transform
import torch.nn.functional as F
import torch.nn as nn
import os

# Global switch for spike-based quantization, read by `ssym_quant_dequant` at call time.
#   '1': after integer quantization, codes are re-encoded by `spike_fake_quant` with a
#        ternary LIF neuron. This path exists only to validate result correctness;
#        running it efficiently requires hardware optimized for spike operations.
#   '0': plain integer fake quantization (asym_quant followed by asym_dequant).
# NOTE: the assignment below runs at import time and overwrites any SPIKE_ON value
# already present in the environment. To enable spikes, edit it here (or set
# os.environ['SPIKE_ON'] after importing this module).
# os.environ['SPIKE_ON'] = '1'

os.environ['SPIKE_ON'] = '0'

def get_minq_maxq(bits, sym):
    """Return the ``(minq, maxq)`` integer code range of a ``bits``-bit quantizer.

    Symmetric: signed range ``[-2**(bits-1), 2**(bits-1) - 1]``; both bounds are 0-d tensors.
    Asymmetric: unsigned range ``[0, 2**bits - 1]``; ``minq`` is the Python int 0 and
    ``maxq`` a 0-d tensor.
    """
    if not sym:
        return 0, torch.tensor(2**bits - 1)
    upper = torch.tensor(2**(bits - 1) - 1)
    return -upper - 1, upper

def asym_quant(x, scale, zero, maxq):
    """Quantize ``x`` asymmetrically to integer codes in ``[0, maxq]``.

    Computes ``clamp(round(x / scale) + zero, 0, maxq)``. ``scale`` and ``zero`` are
    moved to ``x``'s device first, and those device copies are returned alongside the codes.

    Returns:
        ``(q, scale, zero)``; ``q`` keeps a floating dtype (codes are not cast to int).
    """
    scale, zero = scale.to(x.device), zero.to(x.device)
    codes = (x / scale).round() + zero
    return codes.clamp(0, maxq), scale, zero

def asym_dequant(q, scale, zero):
    """Map asymmetric integer codes back to real values: ``scale * (q - zero)``."""
    centered = q - zero
    return centered * scale

def asym_quant_dequant(x, scale, zero, maxq):
    """Fake-quantize ``x`` asymmetrically: `asym_quant` followed by `asym_dequant`."""
    q, scale_dev, zero_dev = asym_quant(x, scale, zero, maxq)
    return asym_dequant(q, scale_dev, zero_dev)

from Int2Spike.neuron import spike_fake_quant, SpikeCountTernaryLIFNode

def ssym_quant_dequant(x, scale, zero, maxq):
    """
    Perform spike quantization and dequantization, converting floating-point values to ternary spikes (-1/0/1).

    ``x`` is first quantized with `asym_quant` (codes clamped to ``[0, maxq]``). When the
    environment variable ``SPIKE_ON`` equals ``'1'``, the codes are re-encoded by
    ``spike_fake_quant`` using a fresh ``SpikeCountTernaryLIFNode``. The zero point is
    adjusted inside that spike fake-quantization step, so symmetric quantization and
    bidirectional encoding work properly. The result is mapped back with `asym_dequant`.

    Args:
        x: input tensor.
        scale, zero: quantization parameters broadcastable to ``x`` (moved to its device).
        maxq: upper bound of the integer code range.

    Returns:
        ``scale * (q - zero)`` where ``q`` are the (optionally spike-encoded) codes.
    """
    q, scale, zero = asym_quant(x, scale, zero, maxq)
    # A missing SPIKE_ON variable is treated as "disabled" instead of raising KeyError.
    if os.environ.get('SPIKE_ON', '0') == '1':
        lif_quantizer = SpikeCountTernaryLIFNode()
        q = spike_fake_quant(x=q, lif_quantizer=lif_quantizer, x_zero=zero)
        # Debug: print(f'firing_rate: {lif_quantizer.firing_rate()}')

    return asym_dequant(q, scale, zero)

def ssym_quant(x, scale, zero, maxq):
    """Return the integer codes used for spike quantization as ``(q, scale, zero)``.

    This is plain `asym_quant`; no spike encoding happens here. When spike quantization
    is enabled, the zero point is adjusted later, inside the spike fake-quantization step.
    """
    q, scale_dev, zero_dev = asym_quant(x, scale, zero, maxq)
    return q, scale_dev, zero_dev

def sym_quant(x, scale, maxq):
    """Quantize ``x`` symmetrically to integer codes in ``[-(maxq + 1), maxq]``.

    ``scale`` is moved to ``x``'s device. Returns ``(q, scale)`` with the device copy of
    ``scale``; ``q`` keeps a floating dtype.
    """
    scale = scale.to(x.device)
    lower = -(maxq + 1)
    return (x / scale).round().clamp(lower, maxq), scale

def sym_dequant(q, scale):
    """Map symmetric integer codes back to real values: ``scale * q``."""
    return q * scale

def sym_quant_dequant(x, scale, maxq):
    """Fake-quantize ``x`` symmetrically: `sym_quant` followed by `sym_dequant`."""
    q, scale_dev = sym_quant(x, scale, maxq)
    return sym_dequant(q, scale_dev)


def two_compl(x, bits: int):
    """Map negative entries of ``x`` to their ``bits``-bit two's-complement code.

    Negative values get ``2**bits`` added (e.g. -1 -> 15 for 4 bits); other values pass through.
    """
    offset = 2 ** bits
    return torch.where(x < 0, x + offset, x)

def pack_i4(q):
    """Pack signed int4 values into a uint8 tensor, two values per byte.

    Adjacent pairs along the last dimension share one byte. Element ``2k`` goes to the
    low nibble and element ``2k + 1`` to the high nibble, which is the layout that
    `unpack_i4` reverses. Leading dimensions are preserved, so any tensor of rank >= 1
    is accepted (previously only 2-D tensors were).

    Args:
        q: signed tensor with values in [-8, 7] and an even-sized last dimension.

    Returns:
        uint8 tensor with the same leading shape and last dimension ``q.shape[-1] // 2``.
    """
    # torch.is_signed is also True for floating dtypes, so pre-rounded float tensors pass.
    assert torch.is_signed(q), 'The tensor to be packed should be signed int'
    assert q.dim() >= 1 and q.shape[-1] % 2 == 0, \
        'The last dimension of the tensor to be packed should be even'
    minq, maxq = get_minq_maxq(4, True)
    assert torch.all(torch.logical_and(q >= minq, q <= maxq))

    # Negative values become their 4-bit two's-complement code (e.g. -1 -> 15) before the byte cast.
    q_u8 = two_compl(q.to(dtype=torch.int8), 4).to(torch.uint8)
    # Ellipsis indexing pairs elements along the last dim for any rank.
    return q_u8[..., 0::2] | (q_u8[..., 1::2] << 4)

def unpack_i4(x: torch.Tensor):
    """Unpack int4 values stored two-per-byte in a uint8 tensor into an int32 tensor.

    For each byte, the low nibble becomes element ``2k`` and the high nibble element
    ``2k + 1`` of the last dimension. Each nibble is read as a 4-bit two's-complement
    number, giving values in [-8, 7]. The output keeps ``x``'s leading shape, and its
    last dimension is twice as long.
    """
    assert x.dtype == torch.uint8, 'The tensor to be unpacked should be stored in uint8'

    def _signed_rows(nibbles):
        # Reinterpret 0..15 as two's complement (8..15 -> -8..-1) and flatten leading dims.
        signed = nibbles.to(torch.int8)
        signed[signed >= 8] -= 16
        return signed.view(-1, signed.shape[-1])

    low = _signed_rows(x & 0x0f)
    high = _signed_rows((x & 0xf0) >> 4)

    unpacked_shape = list(x.shape)
    unpacked_shape[-1] = unpacked_shape[-1] * 2

    rows = torch.empty(unpacked_shape, device=x.device, dtype=torch.int32)
    rows = rows.view(-1, rows.shape[-1])
    # Interleave: low nibbles at even positions, high nibbles at odd positions.
    rows[:, 0::2] = low
    rows[:, 1::2] = high
    return rows.view(unpacked_shape)

class SpikeQuantizer(torch.nn.Module):
    '''
        Per-token activation quantizer used for spike encoding.

        `find_params(x)` derives `scale` / `zero` from `x` itself (per token, or per
        `groupsize`-wide slice of each token) and expands them to `x.shape`. `forward`
        then fake-quantizes through `ssym_quant_dequant`, which re-encodes the integer
        codes as ternary spikes when SPIKE_ON == '1'. `bits == 16` disables quantization.

        NOTE(review): with sym=True the zero point is 0, and `asym_quant` (used by
        `ssym_quant_dequant`) clamps codes to [0, maxq] before any spike encoding, so
        negative inputs quantize to 0 — confirm this is intended.
    '''
    def __init__(self, name=''):
        super(SpikeQuantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(1))
        self.register_buffer('zero', torch.zeros(1))
        self.bits = 16  # 16 = pass-through until `configure` is called
        self.name = name

    def free(self):
        """Drop the per-input `scale` / `zero`; `find_params` must run again before the next use."""
        self.zero = None
        self.scale = None

    def forward(self, x):
        """Fake-quantize `x` (returned unchanged when bits == 16), preserving its dtype."""
        x_dtype = x.dtype
        if self.bits == 16:
            return x
        return ssym_quant_dequant(x, self.scale, self.zero, self.maxq).to(x_dtype)

    def quantize(self, x):
        """Unlike `forward`, return the integer codes: ``(q, scale, zero)`` from `ssym_quant`."""
        return ssym_quant(x, self.scale, self.zero, self.maxq)

    def configure(self, bits, groupsize=-1, sym=False, clip_ratio=1.0):
        """Set the quantization hyper-parameters.

        Args:
            bits: bit-width; 16 disables quantization.
            groupsize: if > 0, each token's last dimension is split into groups of this
                size, each with its own scale/zero (it must divide the last dimension).
            sym: take `maxq` from the signed range of `get_minq_maxq` and use zero = 0.
            clip_ratio: factor in (0, 1] applied to the observed min/max.
        """
        _, self.maxq = get_minq_maxq(bits, sym)
        self.bits = bits
        self.groupsize = groupsize
        self.sym = sym
        self.clip_ratio = clip_ratio
        assert self.clip_ratio <= 1 and self.clip_ratio > 0, 'Clip ratio should be in (0, 1]'

    def find_params_per_token_groupwise(self, x):
        """Compute group-wise `scale` / `zero` for `x` and expand them to `x.shape`.

        The last dimension of `x` is split into groups of `groupsize` elements; each group
        gets its own range from its (clipped) min and max.
        """
        init_shape = x.shape
        # (-1, x.shape[-2], n_groups, groupsize); requires x.shape[-1] % groupsize == 0
        reshaped_x = x.reshape(-1, x.shape[-2], x.shape[-1] // self.groupsize, self.groupsize)

        xmax = torch.amax(reshaped_x, dim=3, keepdim=True) * self.clip_ratio
        xmin = torch.amin(reshaped_x, dim=3, keepdim=True) * self.clip_ratio
        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmax == 0
            self.scale = xmax / self.maxq
            self.scale[tmp] = 1
            self.zero = torch.zeros_like(self.scale)
        else:
            # Unlike `find_params`, group ranges are not forced to contain 0, so a constant
            # group (xmin == xmax != 0) would get scale 0 and yield inf/NaN. Widen such groups
            # to include 0; all-zero groups still fall through to the [-1, 1] fallback below.
            flat = xmin == xmax
            xmin = torch.where(flat, torch.clamp(xmin, max=0), xmin)
            xmax = torch.where(flat, torch.clamp(xmax, min=0), xmax)
            tmp = (xmin == 0) & (xmax == 0)
            xmin[tmp] = -1
            xmax[tmp] = +1
            self.scale = (xmax - xmin) / self.maxq
            self.zero = torch.round(-xmin / self.scale)

        self.scale = self.scale.repeat(1, 1, 1, self.groupsize).reshape(init_shape)
        self.zero = self.zero.repeat(1, 1, 1, self.groupsize).reshape(init_shape)

    def find_params(self, x):
        """Compute `scale` / `zero` for `x` (per token, or group-wise when groupsize > 0).

        Per-token ranges are taken over the last dimension and always include 0. The
        resulting tensors have the same shape as `x`. No-op when bits == 16.
        """
        if self.bits == 16:
            return

        dev = x.device
        self.maxq = self.maxq.to(dev)

        init_shape = x.shape

        if self.groupsize > 0:
            # group-wise per-token quantization
            self.find_params_per_token_groupwise(x)
            utils.cleanup_memory(verbos=False)
            return

        # One row per token: flatten all leading dimensions.
        reshaped_x = x.reshape((-1, x.shape[-1]))

        tmp = torch.zeros(reshaped_x.shape[0], device=dev)
        xmin = torch.minimum(reshaped_x.min(1)[0], tmp) * self.clip_ratio
        xmax = torch.maximum(reshaped_x.max(1)[0], tmp) * self.clip_ratio
        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmax == 0
            self.scale = (xmax / self.maxq).unsqueeze(1).repeat(1, reshaped_x.shape[-1])
            self.scale[tmp] = 1  # all-zero tokens: avoid a zero scale
            self.scale = self.scale.reshape(init_shape)
            self.zero = torch.zeros_like(self.scale)
        else:
            tmp = (xmin == 0) & (xmax == 0)
            xmin[tmp] = -1
            xmax[tmp] = +1
            self.scale = (xmax - xmin) / self.maxq
            self.zero = torch.round(-xmin / self.scale)

            self.scale = self.scale.unsqueeze(1).repeat(1, reshaped_x.shape[-1]).reshape(init_shape)
            self.zero = self.zero.unsqueeze(1).repeat(1, reshaped_x.shape[-1]).reshape(init_shape)

class ActQuantWrapper(torch.nn.Module):
    '''
        Wrapper around an nn.Linear that can transform its input and output.

        Input side (in `forward`, before the matmul), selected by `rotate_mode`:
            'Q'       -> x @ Q (skipped for LLaMA q/k/v/up/gate projections), then the
                         optional `sparse_func`; in train mode additionally @ Q.T.
            'H'       -> hadamard_utils.matmul_hadU_cuda(x, had_K, K); in train mode
                         additionally hadamard_utils.matmul_hadUt_cuda.
            'Head_in' -> online Hadamard over the whole last dim, optional `sparse_func`,
                         and in train mode the head-wise transform that `rotate_weight`
                         would otherwise fold into the weight.
        followed by optional activation quantization with `quantizer`.

        Output side (after the matmul): 'Head_out' applies a Hadamard transform to each
        `had_dim`-sized slice, then optional `out_quantizer`.
    '''
    def __init__(self, module:torch.nn.Linear):
        super(ActQuantWrapper, self).__init__()
        assert isinstance(module, torch.nn.Linear)
        self.module = module
        # Same Parameter objects as the wrapped layer (shared, not copied).
        self.weight = module.weight
        self.bias = module.bias
        # Reserve the 'had_K' buffer name but leave it unset until a matrix is assigned.
        self.register_buffer('had_K', torch.tensor(0))
        self._buffers['had_K'] = None
        self.K = 1
        self.online_full_had = False
        self.online_partial_had = False
        self.had_dim = 0
        self.fp32_had = False
        self.rotate_mode = None        # input transform: None, 'Q', 'H' or 'Head_in'
        self.out_rotate_mode = None    # output transform: None or 'Head_out'
        self.register_buffer('Q', torch.tensor(0))
        self.train_mode = False
        self.sparse_mode = False
        self.quantizer = None
        self.out_quantizer = None
        self.name = None
        self.sparse_func = None
        self.model_type='llama'

    def init_quantizer(self):
        """Attach fresh input/output `SpikeQuantizer`s (pass-through until configured)."""
        self.quantizer = SpikeQuantizer(self.name)
        self.out_quantizer = SpikeQuantizer(self.name)

    def rotate_weight(self):
        """Fold the input transform selected by `rotate_mode` into `self.weight` in place.

        'Q': W @ Q.  'H': hadamard_utils.matmul_hadU_cuda(W, had_K, K).
        'Head_in': the input dim is viewed as (n_chunks, had_dim) and a Hadamard transform
        is applied across the n_chunks axis (via fast_hadamard_transform when K == 1,
        otherwise by multiplying with `had_K`), scaled by 1/sqrt(n_chunks).
        Any other mode is a no-op. Only valid outside train mode.
        """
        assert not self.train_mode
        rotate_mode = self.rotate_mode

        if rotate_mode == 'Q':
            self.weight.data = self.weight.data @ self.Q.to(self.weight.dtype)

        elif rotate_mode == 'H':
            self.weight.data = hadamard_utils.matmul_hadU_cuda(self.weight.data, self.had_K, self.K)

        elif rotate_mode == 'Head_in':
            init_shape = self.weight.shape
            n_chunks = init_shape[-1] // self.had_dim
            chunked = self.weight.data.reshape(-1, n_chunks, self.had_dim)

            if self.K == 1:
                # Transpose so the transform runs over the n_chunks axis (last dim).
                self.weight.data = fast_hadamard_transform.hadamard_transform(
                    chunked.transpose(1, 2), scale=1/math.sqrt(n_chunks)).transpose(1, 2)
            else:
                self.weight.data = (self.had_K.to(self.weight.data.dtype).to(self.weight.data.device) @ chunked) / math.sqrt(n_chunks)

            self.weight.data = self.weight.data.reshape(init_shape)

    def forward(self, x):
        """Apply the configured input transform/quantization, the linear layer, then the output transform/quantization.

        The returned tensor has the same dtype as the input `x`.
        """
        x_dtype = x.dtype

        if self.rotate_mode == 'Q':
            # NOTE(review): for LLaMA, q/k/v/up/gate inputs are not rotated here —
            # presumably the preceding norm wrapper already applied Q; confirm.
            skip_rotation = self.model_type == 'llama' and any(
                proj in self.name for proj in ('q_proj', 'k_proj', 'v_proj', 'up_proj', 'gate_proj'))
            if not skip_rotation:
                x = x @ self.Q.to(x.dtype)

                if self.sparse_mode:
                    x = self.sparse_func(x)

                if self.train_mode:
                    x = x @ self.Q.T.to(x.dtype)

        elif self.rotate_mode == 'H':
            x = hadamard_utils.matmul_hadU_cuda(x, self.had_K, self.K)
            if self.train_mode:
                x = hadamard_utils.matmul_hadUt_cuda(x, self.had_K, self.K)

        elif self.rotate_mode == 'Head_in':
            init_shape = x.shape

            # Online Hadamard over the whole last dimension; the factor comes from its size.
            had_K, K = hadamard_utils.get_hadK(init_shape[-1])
            x = hadamard_utils.matmul_hadU_cuda(x, had_K, K)

            if self.sparse_mode:
                x = self.sparse_func(x)

            if self.train_mode:
                # The weight is not rotated in train mode (rotate_weight asserts this), so
                # the head-wise transform from `rotate_weight` is mirrored on the activation.
                n_chunks = init_shape[-1] // self.had_dim
                chunked = x.reshape(-1, n_chunks, self.had_dim)
                if self.K == 1:
                    x = fast_hadamard_transform.hadamard_transform(
                        chunked.transpose(1, 2), scale=1/math.sqrt(n_chunks)).transpose(1, 2)
                else:
                    x = (self.had_K.to(x.dtype).T @ chunked) / math.sqrt(n_chunks)

            x = x.reshape(init_shape)

        if self.quantizer is not None and self.quantizer.bits < 16:  # Quantize, if needed
            assert not self.train_mode
            self.quantizer.find_params(x)
            x = self.quantizer(x).to(x_dtype)
            self.quantizer.free()

        x = F.linear(x, self.weight, self.bias).to(x_dtype)

        if self.out_rotate_mode == 'Head_out':
            init_shape = x.shape
            x = fast_hadamard_transform.hadamard_transform(
                x.reshape(-1, init_shape[-1]//self.had_dim, self.had_dim),
                scale=1/math.sqrt(self.had_dim)).reshape(init_shape)

        if self.out_quantizer is not None and self.out_quantizer.bits < 16:  # Quantize the output, if needed
            self.out_quantizer.find_params(x)
            x = self.out_quantizer(x).to(x_dtype)
            self.out_quantizer.free()

        return x
    
class WeightQuantizer(torch.nn.Module):
    '''From GPTQ Repo

    Min/max weight quantizer. `configure` selects the bit-width and mode, `find_params`
    derives `scale` / `zero` from a weight tensor, and `quantize` fake-quantizes a
    tensor with them.
    '''

    def __init__(self, shape=1):
        super(WeightQuantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        for buf_name in ('scale', 'zero'):
            self.register_buffer(buf_name, torch.zeros(shape))

    def configure(
        self,
        bits, perchannel=False, sym=True,
        mse=False, norm=2.4, grid=100, maxshrink=.8,
    ):
        """Store the quantization settings and set `maxq` (2**(bits-1)-1 if sym, else 2**bits-1)."""
        self.bits = bits
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        levels = 2**(bits-1)-1 if sym else 2**bits - 1
        self.maxq = torch.tensor(levels)

    def find_params(self, x):
        """Compute `scale` / `zero` from the weight tensor `x`.

        With `perchannel`, one range per slice along dim 0 (rows of ``x.flatten(1)``);
        otherwise a single range over the whole tensor, replicated ``x.shape[0]`` times.
        Ranges always include 0. With `mse`, shrunken ranges ``p * [xmin, xmax]`` for
        ``p = 1, 1 - 1/grid, ...`` (``int(maxshrink * grid)`` candidates) are tried and,
        per row, the one with the smallest sum of ``|q - x| ** norm`` is kept. Results are
        reshaped to ``(-1, 1, ..., 1)`` to broadcast against `x`. No-op when bits == 16.
        """
        if self.bits == 16:
            return
        dev = x.device
        self.maxq = self.maxq.to(dev)

        orig_shape = x.shape
        rows = x.flatten(1) if self.perchannel else x.flatten().unsqueeze(0)
        n_rows = rows.shape[0]

        origin = torch.zeros(n_rows, device=dev)
        xmin = torch.minimum(rows.min(1)[0], origin)
        xmax = torch.maximum(rows.max(1)[0], origin)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax).clamp(min=1e-5)
            self.scale = xmax / self.maxq
            self.zero = torch.zeros_like(self.scale)
        else:
            degenerate = (xmin == 0) & (xmax == 0)
            xmin[degenerate] = -1
            xmax[degenerate] = +1
            self.scale = (xmax - xmin).clamp(min=1e-5) / self.maxq
            self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best_err = torch.full([n_rows], float('inf'), device=dev)
            for step in range(int(self.maxshrink * self.grid)):
                shrink = 1 - step / self.grid
                lo = shrink * xmin
                hi = shrink * xmax

                if self.sym:
                    cand_scale = hi / self.maxq
                    cand_zero = torch.zeros_like(cand_scale)
                    recon = sym_quant_dequant(rows, cand_scale.unsqueeze(1), self.maxq)
                else:
                    cand_scale = (hi - lo) / self.maxq
                    cand_zero = torch.round(-lo / cand_scale)
                    recon = asym_quant_dequant(rows, cand_scale.unsqueeze(1), cand_zero.unsqueeze(1), self.maxq)

                err = (recon - rows).abs().pow(self.norm).sum(1)
                improved = err < best_err
                if torch.any(improved):
                    best_err[improved] = err[improved]
                    self.scale[improved] = cand_scale[improved]
                    self.zero[improved] = cand_zero[improved]

        if not self.perchannel:
            # One shared range: replicate it for every slice along dim 0.
            self.scale = self.scale.repeat(orig_shape[0])
            self.zero = self.zero.repeat(orig_shape[0])

        bcast_shape = [-1] + [1] * (len(orig_shape) - 1)
        self.scale = self.scale.reshape(bcast_shape)
        self.zero = self.zero.reshape(bcast_shape)

    # TODO: This should be better refactored into `forward`, which applies quantize and dequantize. A new method `quantize` should be added (if needed) to return the quantized integers and scales, like in ActQuantizer.
    def quantize(self, x):
        """Fake-quantize `x` with the stored parameters; return `x` unchanged if not ready or bits >= 16."""
        x_dtype = x.dtype
        if not (self.ready() and self.bits < 16):
            return x
        if self.sym:
            return sym_quant_dequant(x, self.scale, self.maxq).to(x_dtype)
        return asym_quant_dequant(x, self.scale, self.zero, self.maxq).to(x_dtype)

    def enabled(self):
        """Tensor-valued check that `maxq` is positive."""
        return self.maxq > 0

    def ready(self):
        """Tensor-valued check that every stored scale is non-zero."""
        return torch.all(self.scale != 0)

def add_actquant(module, name='', layers=None):
    """Recursively replace matching linear layers under `module` with `ActQuantWrapper`.

    Attributes of `module` whose exact type is in `layers` are wrapped. Matching direct
    children of `nn.Sequential` / `nn.ModuleList` attributes are wrapped too, and those
    containers are rebuilt as new `nn.Sequential` / `nn.ModuleList` objects. The function
    then recurses into every child module. Existing `ActQuantWrapper` instances are
    neither re-wrapped nor descended into, so calling this twice on a model is safe
    (previously it hit the nn.Linear assertion in `ActQuantWrapper.__init__`).

    Args:
        module: root module, modified in place.
        name: dotted name of `module`; used to build child names during recursion.
        layers: exact layer types to wrap. Defaults to `nn.Linear`, `ActQuantWrapper`
            and Falcon's `FalconLinear`.
    """
    if layers is None:
        layers = [torch.nn.Linear,
                  ActQuantWrapper,
                  transformers.models.falcon.modeling_falcon.FalconLinear]
    if isinstance(module, ActQuantWrapper):
        return

    def _should_wrap(candidate):
        # Exact-type match (subclasses are not wrapped) and not already a wrapper.
        return type(candidate) in layers and not isinstance(candidate, ActQuantWrapper)

    def _wrap_or_keep(child):
        return ActQuantWrapper(child) if _should_wrap(child) else child

    for attr in dir(module):
        tmp = getattr(module, attr)
        if _should_wrap(tmp):
            setattr(module, attr, ActQuantWrapper(tmp))
        if type(tmp) == torch.nn.Sequential:
            setattr(module, attr, torch.nn.Sequential(*[_wrap_or_keep(c) for c in tmp.children()]))
        if type(tmp) == torch.nn.ModuleList:
            setattr(module, attr, torch.nn.ModuleList([_wrap_or_keep(c) for c in tmp.children()]))
    for name1, child in module.named_children():
        add_actquant(child, name + '.' + name1 if name != '' else name1, layers)

def find_qlayers(module, layers=None, name=''):
    """Return ``{dotted_name: submodule}`` for every submodule whose exact type is in `layers`.

    The search is depth-first over `named_children`; a matching module is returned
    without descending into it. `layers` defaults to ``[nn.Linear, ActQuantWrapper]``.
    The default is resolved at call time rather than stored as a shared mutable default.
    """
    if layers is None:
        layers = [torch.nn.Linear, ActQuantWrapper]
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        child_name = name + '.' + name1 if name != '' else name1
        res.update(find_qlayers(child, layers=layers, name=child_name))
    return res

from transformers.models.llama.modeling_llama import LlamaRMSNorm
class ActNormWrapper(torch.nn.Module):
    '''
        Wraps a LlamaRMSNorm and optionally rotates its output.

        With rotate_mode == 'Q' the normalized output is multiplied by `Q`, then passed
        through `sparse_func` if `sparse_mode` is set, and in train mode multiplied by
        `Q.T` afterwards. Any other rotate_mode returns the norm output unchanged.
    '''
    def __init__(self, module:LlamaRMSNorm):
        super(ActNormWrapper, self).__init__()
        assert isinstance(module, LlamaRMSNorm)
        self.module = module
        # Shared with the wrapped norm (same Parameter object).
        self.weight = module.weight
        self.variance_epsilon = module.variance_epsilon

        # Reserve the 'had_K' buffer name but keep it empty.
        self.register_buffer('had_K', torch.tensor(0))
        self._buffers['had_K'] = None
        self.K = 1
        self.online_full_had = False
        self.online_partial_had = False
        self.had_dim = 0
        self.fp32_had = False
        self.rotate_mode = None
        self.register_buffer('Q', torch.tensor(0))
        self.train_mode = False
        self.sparse_mode = False
        self.sparse_func = None

    def forward(self, x):
        """Normalize `x` (cast back to its dtype) and apply the optional 'Q' rotation."""
        in_dtype = x.dtype
        y = self.module(x).to(in_dtype)
        if self.rotate_mode != 'Q':
            return y

        y = y @ self.Q.to(y.dtype)
        if self.sparse_mode:
            y = self.sparse_func(y)
        if self.train_mode:
            y = y @ self.Q.T.to(y.dtype)
        return y

def add_norm_wrapper(module, name='', layers=None):
    """Recursively replace matching norm layers under `module` with `ActNormWrapper`.

    Attributes of `module` whose exact type is in `layers` are wrapped. Matching direct
    children of `nn.Sequential` / `nn.ModuleList` attributes are wrapped too, and those
    containers are rebuilt as new `nn.Sequential` / `nn.ModuleList` objects. The function
    then recurses into every child module. Existing `ActNormWrapper` instances are
    neither re-wrapped nor descended into, so calling this twice on a model is safe
    (previously it hit the LlamaRMSNorm assertion in `ActNormWrapper.__init__`).

    Args:
        module: root module, modified in place.
        name: dotted name of `module`; used to build child names during recursion.
        layers: exact layer types to wrap. Defaults to `LlamaRMSNorm` and `ActNormWrapper`.
    """
    if layers is None:
        layers = [LlamaRMSNorm, ActNormWrapper]
    if isinstance(module, ActNormWrapper):
        return

    def _should_wrap(candidate):
        # Exact-type match (subclasses are not wrapped) and not already a wrapper.
        return type(candidate) in layers and not isinstance(candidate, ActNormWrapper)

    def _wrap_or_keep(child):
        return ActNormWrapper(child) if _should_wrap(child) else child

    for attr in dir(module):
        tmp = getattr(module, attr)
        if _should_wrap(tmp):
            setattr(module, attr, ActNormWrapper(tmp))
        if type(tmp) == torch.nn.Sequential:
            setattr(module, attr, torch.nn.Sequential(*[_wrap_or_keep(c) for c in tmp.children()]))
        if type(tmp) == torch.nn.ModuleList:
            setattr(module, attr, torch.nn.ModuleList([_wrap_or_keep(c) for c in tmp.children()]))
    for name1, child in module.named_children():
        add_norm_wrapper(child, name + '.' + name1 if name != '' else name1, layers)


def find_norm(module, layers=None, name=''):
    """Return ``{dotted_name: submodule}`` for every submodule whose exact type is in `layers`.

    The search is depth-first over `named_children`; a matching module is returned
    without descending into it. `layers` defaults to ``[LlamaRMSNorm, ActNormWrapper]``.
    The default is resolved at call time rather than stored as a shared mutable default.
    """
    if layers is None:
        layers = [LlamaRMSNorm, ActNormWrapper]
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        child_name = name + '.' + name1 if name != '' else name1
        res.update(find_norm(child, layers=layers, name=child_name))
    return res

import torch
class QuantileShiftedReLU(torch.nn.Module):
    """ReLU shifted by a data-dependent threshold: ``relu(x - t)``.

    On the first forward call (while ``init_mode`` is True), ``t`` is the ``q``-quantile
    of all elements of that input, computed in float32 and cast to the input dtype. It is
    stored in the ``x_quantile`` buffer and reused on every later call.
    """
    def __init__(self, q = 0.5, init_mode=True):
        super(QuantileShiftedReLU, self).__init__()
        self.q = q
        self.init_mode = init_mode
        self.register_buffer('x_quantile', torch.zeros(1))

    def forward(self, x):
        x_dtype = x.dtype
        if not self.init_mode:
            # Hard-coded fallback: with q == 0.6 and a still-zero threshold, use 0.0161.
            # NOTE(review): the origin of this constant is not visible here — confirm.
            use_fallback = self.q == 0.6 and self.x_quantile == torch.zeros(1).to(self.x_quantile.device)
            if use_fallback:
                self.x_quantile = torch.tensor([0.0161], dtype=x_dtype).to(x.device)
            return torch.relu(x - self.x_quantile.to(x.device))

        # Calibration pass: quantile over the flattened, detached input.
        flat = x.detach().reshape(-1).float()
        threshold = torch.quantile(flat, q=self.q, dim=0, keepdim=True).to(x_dtype)
        self.x_quantile.data = threshold
        self.init_mode = False
        return torch.relu(x - threshold)
