from typing import List
import torch
from scipy.stats import norm
import itertools



# Module-wide random generator; stays None until init_lpmm_generator() builds it.
lpmm_generator = None

# Lookup table keyed by the integers 2..8, each mapped to ceil(key / 2):
#   2->1, 3->2, 4->2, 5->3, 6->3, 7->4, 8->4
# NOTE(review): the name suggests "total bit-width -> exponent bits" of a
# floating-point format ("BIS" presumably meaning "BITS"); nothing in the
# visible code reads this table, so confirm against its users.
FP_EXPONENT_BIS_MAP = {
    total_bits: (total_bits + 1) // 2 for total_bits in range(2, 9)
}

def init_lpmm_generator(gpu, seed):
    '''Create and seed the module-level ``lpmm_generator`` on first call.

    Later calls are no-ops: once a generator exists, ``gpu`` and ``seed``
    are ignored.

    Args:
        gpu: device passed to ``torch.Generator(device=...)``.
        seed: integer seed; ``None`` selects the default seed 3407.
    '''
    global lpmm_generator
    if lpmm_generator is not None:
        # Already initialised; keep the existing generator and its state.
        return
    lpmm_generator = torch.Generator(device=gpu)
    lpmm_generator.manual_seed(3407 if seed is None else seed)


def vectorwise_quant(name, x, q_scales, q_biases, **kwargs):
    '''Quantize ``x`` group-wise and return the indices plus the metadata needed to undo it.

    Args:
        name: identifier forwarded to ``nonlinear_quant_grouped``.
        x: tensor to quantize; it is detached, not copied, before grouping.
        q_scales, q_biases: per-group scale / bias tensors forwarded to
            ``change_qmap`` (moved to the device of the grouped tensor first).
        **kwargs: must contain ``gp_sz``, ``qmap``, ``b``, ``round_type`` and
            ``scale_type``.

    Returns:
        ``(indices, metadata)`` where ``metadata`` holds ``dtype`` and
        ``stride`` of ``x`` plus ``max1`` and ``scaled_shape`` produced by
        ``make_group_and_get_max``.
    '''
    # Remember how the input was laid out so dequantization can restore it.
    generated_metadata = {'dtype': x.dtype, 'stride': x.stride()}

    # Split the (detached) tensor into groups and record each group's abs-max.
    grouped, group_md = make_group_and_get_max(x.detach(), **kwargs)
    generated_metadata.update(group_md)

    qmap, b = kwargs['qmap'], kwargs['b']
    target = grouped.device
    # Build one quantization map per group on the same device as the data.
    group_qmap = change_qmap(
        qmap.to(target),
        q_scales.to(target),
        q_biases.to(target),
        generated_metadata['max1'],
    )
    indices = nonlinear_quant_grouped(
        name,
        grouped,
        group_qmap,
        b,
        round_type=kwargs['round_type'],
        scale_type=kwargs['scale_type'],
    )
    return indices, generated_metadata


def vectorwise_dequant(qx, q_scales, q_biases, denormalized=True, **kwargs):
    '''Map quantization indices back to values and restore the original layout.

    Args:
        qx: index tensor produced by the quantizer.
        q_scales, q_biases: per-group scale / bias tensors forwarded to
            ``change_qmap`` (moved to the device of ``qx`` first).
        denormalized: when False, return the grouped values straight after
            the table lookup (debug aid) without reshaping or casting.
        **kwargs: must contain ``dtype``, ``stride``, ``qmap``, ``b``,
            ``max1``, ``scaled_shape``, ``round_type`` and ``scale_type``;
            ``shape`` is also required unless ``denormalized`` is False.

    Returns:
        The dequantized tensor, reshaped to ``kwargs['shape']`` with the
        requested ``dtype`` and ``stride`` when ``denormalized`` is True.
    '''
    dtype, stride = kwargs['dtype'], kwargs['stride']
    qmap, b = kwargs['qmap'], kwargs['b']

    target = qx.device
    # Rebuild the per-group quantization maps used at quantization time.
    group_qmap = change_qmap(
        qmap.to(target),
        q_scales.to(target),
        q_biases.to(target),
        kwargs['max1'],
    )
    values = nonlinear_dequant(
        qx,
        group_qmap,
        b,
        shape=kwargs['scaled_shape'],
        round_type=kwargs['round_type'],
        scale_type=kwargs['scale_type'],
    )

    # Debug path: hand back the grouped values untouched.
    if not denormalized:
        return values

    # Drop the padding and go back to the original shape.
    values = recon_grouped_tensor(values, kwargs['shape'])

    if values.stride() == stride:
        return values.to(dtype=dtype)

    # Strides differ: materialise a tensor with the original memory format
    # (copy_ also performs the dtype conversion).
    restored = torch.empty_strided(
        values.shape, stride, dtype=dtype, layout=torch.strided, device=values.device
    )
    restored.copy_(values)
    return restored


def make_group_and_get_max(qx, **kwargs):
    '''Group ``qx`` into rows of ``kwargs['gp_sz']`` elements and record per-group abs-max.

    The values themselves are not rescaled here; only the grouping is applied.

    Returns:
        ``(grouped, metadata)`` where ``grouped`` is the 2-D result of
        ``group_tensor`` and ``metadata`` contains ``max1`` (abs-max of each
        row, reduced over every dim except 0 with dims kept) and
        ``scaled_shape`` (the shape of ``grouped``).
    '''
    grouped = group_tensor(qx, kwargs['gp_sz'])  # (num_gp, gp_sz)
    metadata = {
        'max1': _max_reduce_except_dim(grouped.abs(), 0),
        'scaled_shape': grouped.shape,
    }
    return grouped, metadata

'''
def get_optimal_offset_for_groups(qx, b, grid_num=10, round_type='nearest'):
    num_groups = len(qx)
    offset_list = [0.9+0.01*i for i in range(grid_num)]
    offsets = torch.zeros(num_groups)
    for offset in offset_list:
        create_normal_map(offset=offset, total_bits=b)
'''

def change_qmap(qmap, q_scales, q_biases, max1):
    '''Expand a shared quantization map into one affine-transformed map per group.

    Each output row is ``qmap * (max1 * q_scale) + q_bias`` for that group.
    Rows whose ``q_scales`` entry is negative are reversed, which restores
    ascending order when the incoming ``qmap`` is ascending.

    Args:
        qmap: 1-D tensor of quantization levels (it gets a leading dim added).
        q_scales, q_biases: 1-D per-group tensors (each gets a trailing dim added).
        max1: per-group magnitude; moved to ``qmap``'s device. Must broadcast
            against ``q_scales.unsqueeze(1)`` — presumably ``(num_groups, 1)``.

    Returns:
        A new 2-D tensor of per-group levels; ``qmap`` itself is not modified.
    '''
    per_group_scale = max1.to(qmap.device) * q_scales.unsqueeze(1)
    per_group_bias = q_biases.unsqueeze(1)
    group_maps = qmap.unsqueeze(0) * per_group_scale + per_group_bias

    # A negative scale mirrors the levels; flip those rows back.
    mirrored_rows = q_scales < 0
    group_maps[mirrored_rows] = group_maps[mirrored_rows].flip(1)
    return group_maps


def group_tensor(input: torch.Tensor, gp_sz: int):
    r"""Flatten ``input`` and reshape it into rows of ``gp_sz`` elements.

    When the element count is not a multiple of ``gp_sz`` the flattened
    tensor is zero-padded at the end so the last row is complete.

    Returns:
        A ``(num_groups, gp_sz)`` tensor.

    Raises:
        ValueError: if ``gp_sz`` is not positive.
    """
    if not (gp_sz > 0):
        raise ValueError("group size need to be a positive integer, but found {}".format(gp_sz))

    flat = input.flatten()
    leftover = flat.shape[0] % gp_sz

    if leftover:
        # Zero-fill the tail so the last group is full.
        filler = flat.new_zeros(gp_sz - leftover)
        flat = torch.cat((flat, filler), dim=0)

    return flat.view(-1, gp_sz)  # num_groups, group_size


def recon_grouped_tensor(grouped_tensor: torch.Tensor, shape) -> torch.Tensor :
    r"""Reconstruct a tensor of the original (or a specific) shape from its grouped form.

    The grouped tensor is flattened, any trailing elements beyond the target
    element count are dropped, and the remainder is viewed as ``shape``.

    Args:
        grouped_tensor: tensor holding at least ``prod(shape)`` elements.
        shape: target shape; a ``torch.Size`` or any sequence of ints
            (tuple / list) accepted by ``torch.Size``.

    Returns:
        A tensor of shape ``shape`` containing the leading elements of
        ``grouped_tensor``.
    """
    # Normalise so that plain tuples/lists work too; they lack ``.numel()``.
    shape = torch.Size(shape)
    numel = shape.numel()
    recon_flatten = grouped_tensor.flatten()[:numel]
    return recon_flatten.view(shape)


def _max_reduce_except_dim(tensor, dim):
    # Computes max along all dimensions except the given dim.
    # If tensor is a scalar, it returns tensor.
    rank = len(tensor.shape)
    result = tensor
    if rank > 0:
        assert dim < rank
        for d in range(rank):
            if d != dim:
                result = result.max(dim=d, keepdim=True).values
    return result

'''
def nonlinear_quant(qx, qmap, b, round_type='sr', scale_type='group'):
    qmaplen = len(qmap)
    qx.clamp_(qmap[0], qmap[-1])
    floor_idx = ((qx.unsqueeze(-1) >= qmap).sum(dim=-1) - 1).clamp_(0, qmaplen - 1)
    next_idx = (floor_idx + 1).clamp_max_(qmaplen - 1)
    Z = qmap[next_idx] - qmap[floor_idx]
    Z[Z <= 0] = 1.
    proba = (qx - qmap[floor_idx]) / Z
    proba = torch.bernoulli(proba, generator=lpmm_generator)
    idx = (floor_idx + proba).round_().to(torch.int)

    return idx
'''


def nonlinear_quant_grouped(name, qx, qmap, b, round_type='nearest', scale_type='group'):
    """Quantize each row of ``qx`` to indices into the matching row of ``qmap``.

    Values are clamped to the first/last level of their row, the enclosing
    pair of levels is located, and the value is rounded to one of the two.

    Args:
        name: identifier of the tensor being quantized; not used by the
            computation.
        qx: 2-D tensor of values, one row per group.
        qmap: 2-D tensor ``(num_groups, qmaplen)`` of levels; each row is
            expected to be in ascending order.
        b: not used by the computation.
        round_type: ``'nearest'`` rounds to the closer level (ties go up);
            ``'sr'`` rounds up with probability equal to the relative
            distance from the lower level (stochastic rounding).
        scale_type: not used by the computation.

    Returns:
        Integer index tensor with the shape of ``qx``. The dtype is
        ``torch.int8`` for maps of up to 128 levels (as before) and widens to
        ``torch.uint8`` / ``torch.int16`` / ``torch.int32`` for larger maps so
        that the highest index cannot overflow.

    Raises:
        ValueError: if ``round_type`` is neither ``'sr'`` nor ``'nearest'``.
    """
    if round_type not in ('sr', 'nearest'):
        raise ValueError(
            "round_type must be 'sr' or 'nearest', but found {!r}".format(round_type)
        )

    _, qmaplen = qmap.shape

    # Smallest index dtype that can hold qmaplen - 1. int8 tops out at 127,
    # so e.g. a 256-level (8-bit) map needs uint8.
    if qmaplen <= 128:
        idx_dtype = torch.int8
    elif qmaplen <= 256:
        idx_dtype = torch.uint8
    elif qmaplen <= 2 ** 15:
        idx_dtype = torch.int16
    else:
        idx_dtype = torch.int32

    # Keep every value inside the range covered by its group's map.
    qx = qx.clamp(qmap[:, 0].unsqueeze(-1), qmap[:, -1].unsqueeze(-1))

    # Index of the largest level <= value (count of levels not above it, minus one).
    floor_idx = ((qx.unsqueeze(-1) >= qmap.unsqueeze(1)).sum(dim=-1) - 1).clamp_(0, qmaplen - 1)
    # Neighbouring level above, saturating at the last level.
    next_idx = (floor_idx + 1).clamp_max_(qmaplen - 1)

    lower = torch.gather(qmap, 1, floor_idx)
    gap = torch.gather(qmap, 1, next_idx) - lower
    gap[gap <= 0] = 1.  # avoid division by zero at the top level / duplicate levels

    # Relative position of the value between its two neighbouring levels, in [0, 1].
    prop_dist = (qx - lower) / gap

    if round_type == 'sr':
        # NOTE: draws from torch's default RNG, not the module-level lpmm_generator.
        step_up = torch.bernoulli(prop_dist)
        idx = (floor_idx + step_up).round().to(idx_dtype)
    else:
        step_up = (prop_dist >= 0.5).long()
        idx = (floor_idx + step_up).to(idx_dtype)

    return idx


def nonlinear_dequant(qx, qmap, b, shape, round_type='sr', scale_type='group'):
    """Look up the level each index in ``qx`` refers to in its group's map.

    Args:
        qx: 2-D integer index tensor, one row per group.
        qmap: 2-D tensor ``(num_groups, qmaplen)`` of per-group levels.
        b, shape, round_type, scale_type: accepted but not used here.

    Returns:
        Tensor shaped like ``qx`` with ``qmap[g, qx[g, j]]`` at ``[g, j]``;
        indices are first clamped into ``[0, qmaplen - 1]``.
    """
    _, level_count = qmap.shape
    # gather needs int64 indices; clamp first so out-of-range codes stay valid.
    safe_idx = qx.clamp(0, level_count - 1).long()
    return torch.gather(qmap, 1, safe_idx)


def create_normal_map(offset=0.9677083, use_extra_value=True, use_adaptive_map=False, total_bits=4, mu=0.0, sigma=1.0, min_weight=None, max_weight=None): 
    """Build a sorted quantization map from quantiles of a normal distribution.

    The map has ``2 ** total_bits`` entries and is divided by its largest
    magnitude, so the returned values lie in ``[-1, 1]``.

    Args:
        offset: outermost quantile; levels are taken between ``offset`` and
            ``1 - offset`` (adaptive) or between those and the median.
        use_extra_value: non-adaptive mode only. True gives one more positive
            level than negative levels plus a single zero; False gives equally
            many positive and negative levels plus two zeros.
        use_adaptive_map: when True, use quantiles of ``N(mu, sigma)`` instead
            of the standard-normal construction above.
        total_bits: bit-width; the map holds ``2 ** total_bits`` levels.
        mu, sigma: mean / std of the normal used in adaptive mode only.
        min_weight, max_weight: adaptive mode only. When both are given
            (``0.0`` is a valid bound) the levels are the quantiles evenly
            spaced in CDF between these two values; otherwise ``offset``
            defines the range.

    Returns:
        1-D ``torch.Tensor`` of ascending levels normalised by the max magnitude.
    """
    num_levels = 2 ** total_bits
    half_levels = 2 ** (total_bits - 1)

    if use_adaptive_map:
        # Compare against None explicitly: a bound of 0.0 is falsy but valid.
        if min_weight is None or max_weight is None:
            # Both sign branches of mu use the same construction; only the
            # test selecting the grid differs.
            if mu >= 0.0:
                use_full_grid = norm.ppf(1 - offset) >= 0.0
            else:
                use_full_grid = norm.ppf(offset) <= 0.0

            if use_full_grid:
                v = norm.ppf(torch.linspace(offset, 1 - offset, num_levels), loc=mu, scale=sigma).tolist()
            else:
                # One quantile fewer, with an explicit zero level appended.
                v = norm.ppf(torch.linspace(offset, 1 - offset, num_levels - 1), loc=mu, scale=sigma).tolist()
                v = v + [0]
        else:
            # Evenly spaced in CDF between the two bounds.
            min_cdf = norm.cdf(min_weight, loc=mu, scale=sigma)
            max_cdf = norm.cdf(max_weight, loc=mu, scale=sigma)
            v = norm.ppf(torch.linspace(min_cdf, max_cdf, num_levels), loc=mu, scale=sigma).tolist()
    else:
        if use_extra_value:
            # Asymmetric type: one more positive value and a single zero.
            positive_points = half_levels + 1
            zeros = [0]
        else:
            # Symmetric type: two zeros fill the map up to 2 ** total_bits.
            positive_points = half_levels
            zeros = [0] * 2

        # [:-1] drops the median (quantile 0.5); zero is added explicitly.
        v1 = norm.ppf(torch.linspace(offset, 0.5, positive_points)[:-1]).tolist()
        v3 = norm.ppf(torch.linspace(1 - offset, 0.5, half_levels)[:-1]).tolist()
        v = v1 + zeros + v3

    values = torch.Tensor(v)
    values = values.sort().values

    # Normalise so the largest magnitude is exactly 1.
    values /= torch.abs(values).max()

    return values
