import os
from typing import List
import torch
from scipy.stats import norm
from torch.distributions.normal import Normal
import itertools
import matplotlib.pyplot as plt



# Module-wide torch.Generator handle; stays None until init_lpmm_generator()
# creates and seeds it.
lpmm_generator = None

# Integer lookup table for keys 2..8; every entry equals ceil(key / 2),
# i.e. {2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4}.
# NOTE(review): the name suggests "total bit-width -> exponent bits" for
# low-precision float formats -- confirm against the code that reads it.
FP_EXPONENT_BIS_MAP = {width: (width + 1) // 2 for width in range(2, 9)}

def init_lpmm_generator(gpu, seed):
    """Lazily create and seed the module-level ``lpmm_generator``.

    The generator is built at most once; later calls are no-ops, even when
    they pass a different device or seed.

    Args:
        gpu: device handed to ``torch.Generator(device=...)``.
        seed: integer seed, or ``None`` to fall back to 3407.
    """
    global lpmm_generator
    if lpmm_generator is not None:
        # Already initialised: keep the existing generator and its state.
        return
    lpmm_generator = torch.Generator(device=gpu)
    lpmm_generator.manual_seed(3407 if seed is None else seed)


def vectorwise_quant(name, x, q_scales, q_biases, **kwargs):
    '''Quantize ``x`` group by group and return the indices plus metadata.

    Pipeline: flatten ``x`` into groups (``make_group_and_get_max``), search
    the best normal-map offset for every group
    (``get_optimal_offset_for_groups``), build the per-group quantization map
    from those offsets (``change_qmap_with_offsets``) and map every element to
    an index of that map (``nonlinear_quant_grouped``).

    Args:
        name: layer name; only used to switch on debug printing/plotting for
            the hard-coded layer ``'model.layers.0.self_attn.k_proj'``.
        x: tensor to quantize.
        q_scales: per-group scales (moved to the device of the grouped tensor).
        q_biases: per-group biases (moved to the device of the grouped tensor).
        **kwargs: must contain ``gp_sz``, ``qmap``, ``b``, ``round_type`` and
            ``scale_type``.

    Returns:
        ``(qx, generated_metadata)`` where ``qx`` holds the quantization
        indices and the metadata dict holds ``dtype``, ``stride``, ``max1``,
        ``scaled_shape`` and ``offsets``.
        NOTE(review): ``vectorwise_dequant`` also reads ``kwargs['shape']``,
        which is not set here -- presumably the caller adds it; confirm.
    '''
    qx = x.detach() # x.detach() -> keep the reference of original tensor // should change this when we use changeable quantized values

    # Record what is needed to restore the original dtype / memory layout.
    generated_metadata = {}
    generated_metadata['dtype'] = x.dtype
    generated_metadata['stride'] = x.stride()

    # Given an ill-conditioned/quantization-unfriendly tensor, how to normalize and/or avoid outliers?
    # Group the tensor and collect the per-group absolute maximum ('max1').
    qx, md = make_group_and_get_max(qx, **kwargs)
    generated_metadata.update(md)

    # Copy of the grouped values, used only by the debug code below.
    original_qx = qx.clone()
    row_index = None
    if name == 'model.layers.0.self_attn.k_proj':
        # Debug only: pick the group (row) that contains a near-maximum element.
        # The string literal below is disabled earlier code kept for reference.
        '''
        # Find the row index of the maximum element
        max_row_index = torch.argmax(original_qx, dim=0)[0]
        max_row_index = torch.argmax(original_qx.view(-1))

        # Convert the index of the flattened array back to 2D indices
        row_index, _ = divmod(max_row_index.item(), original_qx.size(1))
        '''
        # Percentage (not fraction) of elements allowed above the selected one.
        k_percent = 0.00005

        # Flatten the tensor
        qx_flat = original_qx.view(-1)
        qx_flat = qx_flat.float()

        # Rank (counted from the smallest) of the element to select
        k = int((1 - k_percent / 100) * qx_flat.numel())
        k_value, k_index = torch.kthvalue(qx_flat, k + 1)  # k+1 because kthvalue is 1-indexed

        # Convert the index of the flattened array back to 2D indices
        row_index, _ = divmod(k_index.item(), original_qx.size(1))

    qmap, b = kwargs['qmap'], kwargs['b']
    device = qx.device  # Assuming qx is already on the desired device
    # NOTE: the qmap taken from kwargs is only moved to the device here; it is
    # replaced by the offset-derived map below before it is ever used.
    qmap, q_scales, q_biases = qmap.to(device), q_scales.to(device), q_biases.to(device)  # Move qmap to the same device as qx
    # Grid-search one offset per group, then rebuild the per-group map from it.
    offsets = get_optimal_offset_for_groups(name, qx, b, q_scales, q_biases, generated_metadata['max1'], round_type=kwargs['round_type'], row_index=row_index)
    generated_metadata['offsets'] = offsets
    qmap = change_qmap_with_offsets(offsets, b, q_scales, q_biases, generated_metadata['max1'])
    # From here on qx holds quantization indices, not values.
    qx, _ = nonlinear_quant_grouped(name, qx, qmap, b, round_type=kwargs['round_type'], scale_type=kwargs['scale_type'])

    if name == 'model.layers.0.self_attn.k_proj':
        # Debug only: plot the weights together with a reference map (built
        # with scaled=False) and the map actually used for the selected row.
        qmap_ref = create_normal_map(offset=offsets[row_index], total_bits=b, scaled=False)
        qmap_ref = change_qmap(qmap_ref, q_scales, q_biases, generated_metadata['max1'])
        plot_weight_distribution(original_qx, qmap_ref[row_index], qmap[row_index])

    return qx, generated_metadata


def vectorwise_dequant(name, qx, q_scales, q_biases, denormalized=True, **kwargs):
    '''Turn quantization indices back into an approximation of the tensor.

    Args:
        name: layer name, forwarded to ``nonlinear_dequant``.
        qx: tensor of quantization indices.
        q_scales: per-group scales (moved to the device of ``qx``).
        q_biases: per-group biases (moved to the device of ``qx``).
        denormalized: when False, return the grouped dequantized values
            without restoring shape, stride or dtype (debug aid).
        **kwargs: must contain ``dtype``, ``stride``, ``qmap``, ``b``,
            ``offsets``, ``max1``, ``scaled_shape``, ``round_type`` and
            ``scale_type``; ``shape`` is also required unless
            ``denormalized`` is False.

    Returns:
        The dequantized tensor, with shape ``kwargs['shape']``, dtype
        ``kwargs['dtype']`` and stride ``kwargs['stride']`` when
        ``denormalized`` is True.
    '''
    target_dtype = kwargs['dtype']
    target_stride = kwargs['stride']

    qmap, b = kwargs['qmap'], kwargs['b']
    dev = qx.device
    qmap, q_scales, q_biases = qmap.to(dev), q_scales.to(dev), q_biases.to(dev)
    # The map actually used is rebuilt from the stored per-group offsets.
    qmap = change_qmap_with_offsets(kwargs['offsets'], b, q_scales, q_biases, kwargs['max1'])
    values = nonlinear_dequant(
        name,
        qx,
        qmap,
        b,
        shape=kwargs['scaled_shape'],
        round_type=kwargs['round_type'],
        scale_type=kwargs['scale_type'],
    )

    if not denormalized:
        # only for debug
        return values

    # Drop the padding and go back to the original shape.
    values = recon_grouped_tensor(values, kwargs['shape'])

    if values.stride() == target_stride:
        return values.to(dtype=target_dtype)

    # Different memory format: allocate with the recorded stride and copy.
    restored = torch.empty_strided(values.shape, target_stride, dtype=target_dtype,
                                   layout=torch.strided, device=values.device)
    restored.copy_(values)
    return restored


def make_group_and_get_max(qx, **kwargs):
    '''Group ``qx`` and collect the per-group absolute maximum.

    Args:
        qx: tensor to group.
        **kwargs: must contain ``gp_sz`` (group size).

    Returns:
        ``(grouped, metadata)`` where ``grouped`` has shape (num_gp, gp_sz)
        and ``metadata`` holds ``max1`` (max of ``|grouped|`` reduced over
        every dimension except 0, dimensions kept) and ``scaled_shape``
        (the shape of ``grouped``). The values are not rescaled.
    '''
    grouped = group_tensor(qx, kwargs['gp_sz'])  # (num_gp, gp_sz)
    metadata = {
        'max1': _max_reduce_except_dim(grouped.abs(), 0),
        'scaled_shape': grouped.shape,
    }
    return grouped, metadata


def get_optimal_offset_for_groups(name, qx, b, q_scales, q_biases, max1, round_type='nearest', grid_num=10, grid_start=0.9, grid_end=1.0, row_index=None):
    '''Grid-search the normal-map offset that minimises each group's error.

    ``grid_num`` candidate offsets are tried, evenly spaced from
    ``grid_start`` (inclusive) towards ``grid_end`` (exclusive). For every
    candidate a map is built with ``create_normal_map`` / ``change_qmap`` and
    the per-group error reported by ``nonlinear_quant_grouped`` is recorded.

    Args:
        name: layer name; forwarded to the quantizer and used to enable the
            debug output for ``'model.layers.0.self_attn.k_proj'``.
        qx: grouped tensor, one group per row.
        b: number of bits of the quantization map.
        q_scales, q_biases, max1: forwarded to ``change_qmap``.
        round_type: rounding mode forwarded to ``nonlinear_quant_grouped``.
        grid_num: number of candidate offsets.
        grid_start, grid_end: bounds of the candidate grid.
        row_index: group whose error curve is printed/plotted in debug mode;
            the debug output is skipped when it is ``None``.

    Returns:
        1-D CPU tensor with the selected offset for every group.
    '''
    num_groups = len(qx)
    grid_step = (grid_end - grid_start) / grid_num
    candidate_offsets = [grid_end + grid_step * (i - grid_num) for i in range(grid_num)]

    quantization_errors = torch.zeros(num_groups, grid_num)
    for idx, offset in enumerate(candidate_offsets):
        qmap = create_normal_map(offset=offset, total_bits=b)
        qmap = change_qmap(qmap, q_scales, q_biases, max1)
        _, quantization_error = nonlinear_quant_grouped(name, qx, qmap, b, round_type=round_type)
        quantization_errors[:, idx] = quantization_error

    # Select the winning candidate per group with one tensor lookup instead of
    # a Python loop over every group.
    best_offset_indices = quantization_errors.argmin(dim=1)
    offsets = torch.tensor(candidate_offsets)[best_offset_indices]

    # Debug output needs a concrete row: indexing with None would add a
    # dimension instead of selecting one group's error curve.
    if name == 'model.layers.0.self_attn.k_proj' and row_index is not None:
        print(quantization_errors[row_index])
        plot_offset_quantization_error(list(candidate_offsets), quantization_errors[row_index])

    return offsets


def change_qmap(qmap, q_scales, q_biases, max1):
    '''Expand one shared quantization map into a scaled, shifted map per group.

    Each row of the result is ``qmap * (max1 * q_scales) + q_biases`` for one
    group. Rows whose scale is negative are reversed, which turns the
    descending values produced by a negative scale back into ascending order
    for an ascending ``qmap``.

    Args:
        qmap: 1-D tensor of quantization levels shared by all groups.
        q_scales: 1-D tensor with one scale per group.
        q_biases: 1-D tensor with one bias per group.
        max1: per-group factor that broadcasts against ``q_scales.unsqueeze(1)``.

    Returns:
        Tensor with one row of levels per group, on the device of ``q_scales``.
    '''
    shared_levels = qmap.to(q_scales.device).unsqueeze(0)
    per_group_scale = max1 * q_scales.unsqueeze(1)
    per_group_levels = shared_levels * per_group_scale + q_biases.unsqueeze(1)

    reversed_rows = q_scales < 0
    per_group_levels[reversed_rows] = per_group_levels[reversed_rows].flip(1)
    return per_group_levels


def change_qmap_with_offsets(offsets, b, q_scales, q_biases, max1):
    '''Build the per-group quantization map from per-group offsets.

    A normal map is generated for every offset with
    ``vectorized_create_normal_map``; each row is then multiplied by
    ``max1 * q_scales`` and shifted by ``q_biases``. Rows whose scale is
    negative are reversed along the level dimension.

    Args:
        offsets: 1-D tensor with one offset per group.
        b: number of bits of the map.
        q_scales: 1-D tensor with one scale per group.
        q_biases: 1-D tensor with one bias per group.
        max1: per-group factor that broadcasts against ``q_scales.unsqueeze(1)``.

    Returns:
        Tensor with one row of levels per group.
    '''
    group_offsets = offsets.to(q_scales.device)
    base_levels = vectorized_create_normal_map(offsets=group_offsets, total_bits=b)
    per_group_scale = max1 * q_scales.unsqueeze(1)
    per_group_levels = base_levels * per_group_scale + q_biases.unsqueeze(1)

    reversed_rows = q_scales < 0
    per_group_levels[reversed_rows] = per_group_levels[reversed_rows].flip(1)
    return per_group_levels


def group_tensor(input: torch.Tensor, gp_sz: int):
    r"""Flatten ``input`` and split it into rows of ``gp_sz`` elements.

    When the number of elements is not a multiple of ``gp_sz`` the flattened
    tensor is zero-padded at the end so that the last row is complete.

    Returns:
        Tensor of shape (num_groups, gp_sz).

    Raises:
        ValueError: if ``gp_sz`` is not positive.
    """
    if not gp_sz > 0:
        raise ValueError(f"group size need to be a positive integer, but found {gp_sz}")

    flat = input.flatten()
    leftover = flat.shape[0] % gp_sz
    if leftover != 0:
        # Zero-pad the tail up to the next multiple of gp_sz.
        padding = torch.zeros([gp_sz - leftover], dtype=input.dtype, device=input.device)
        flat = torch.cat([flat, padding], dim=0)

    return flat.view(-1, gp_sz)  # num_groups, group_size


def recon_grouped_tensor(grouped_tensor: torch.Tensor, shape) -> torch.Tensor :
    r"""Reshape a grouped tensor back to the original (or a specific) shape.

    The grouped tensor is flattened, the trailing elements beyond the number
    required by ``shape`` (the padding added when grouping) are dropped, and
    the remainder is viewed as ``shape``.

    Args:
        grouped_tensor: tensor holding at least as many elements as ``shape``.
        shape: target shape; a ``torch.Size`` or any sequence of ints.

    Returns:
        Tensor of shape ``shape``.
    """
    # Normalising to torch.Size lets plain tuples/lists be passed as well.
    target_shape = torch.Size(shape)
    recon_flatten = grouped_tensor.flatten()[:target_shape.numel()]
    return recon_flatten.view(target_shape)


def _max_reduce_except_dim(tensor, dim):
    # Computes max along all dimensions except the given dim.
    # If tensor is a scalar, it returns tensor.
    rank = len(tensor.shape)
    result = tensor
    if rank > 0:
        assert dim < rank
        for d in range(rank):
            if d != dim:
                result = result.max(dim=d, keepdim=True).values
    return result

'''
def nonlinear_quant(qx, qmap, b, round_type='sr', scale_type='group'):
    qmaplen = len(qmap)
    qx.clamp_(qmap[0], qmap[-1])
    floor_idx = ((qx.unsqueeze(-1) >= qmap).sum(dim=-1) - 1).clamp_(0, qmaplen - 1)
    next_idx = (floor_idx + 1).clamp_max_(qmaplen - 1)
    Z = qmap[next_idx] - qmap[floor_idx]
    Z[Z <= 0] = 1.
    proba = (qx - qmap[floor_idx]) / Z
    proba = torch.bernoulli(proba, generator=lpmm_generator)
    idx = (floor_idx + proba).round_().to(torch.int)

    return idx
'''


def _index_dtype_for(qmaplen):
    # Smallest integer dtype able to hold every index in [0, qmaplen - 1];
    # int8 is kept for maps of up to 128 levels.
    if qmaplen <= 128:
        return torch.int8
    if qmaplen <= 256:
        return torch.uint8
    if qmaplen <= 32768:
        return torch.int16
    return torch.int32


def nonlinear_quant_grouped(name, qx, qmap, b, round_type='nearest', scale_type='group'):
    '''Map every element of each group to an index of that group's map.

    Args:
        name: layer name; enables debug printing for
            ``'model.layers.0.self_attn.k_proj'``.
        qx: grouped values, shape (num_groups, group_size). Not modified.
        qmap: per-group quantization levels, shape (num_groups, qmaplen);
            each row is expected to be in ascending order.
        b: unused; kept for interface compatibility.
        round_type: ``'nearest'`` or ``'sr'`` (stochastic rounding).
        scale_type: unused; kept for interface compatibility.

    Returns:
        ``(idx, quantization_error)``: ``idx`` has the shape of ``qx`` and
        holds indices into ``qmap`` (``int8`` for up to 128 levels, a wider
        integer dtype otherwise); ``quantization_error`` is the per-group
        3-norm of ``qx - dequantized``.

    Raises:
        ValueError: if ``round_type`` is not supported.
    '''
    if round_type not in ('sr', 'nearest'):
        raise ValueError("round_type must be 'sr' or 'nearest', but found {!r}".format(round_type))

    _, qmaplen = qmap.shape
    # clamp() below is out-of-place, so the caller's values can be referenced
    # directly for the error computation; no copy is needed.
    original_qx = qx
    # Ensure qx is within the range of each group's qmap
    qx = qx.clamp(qmap[:, 0].unsqueeze(-1), qmap[:, -1].unsqueeze(-1))
    # Index of the largest level that is <= the value, for each element
    floor_idx = ((qx.unsqueeze(-1) >= qmap.unsqueeze(1)).sum(dim=-1) - 1).clamp_(0, qmaplen - 1)
    # Index of the next level, capped at the last level
    next_idx = (floor_idx + 1).clamp_max_(qmaplen - 1)
    floor_values = torch.gather(qmap, 1, floor_idx)
    # Gap between the two neighbouring levels
    Z = torch.gather(qmap, 1, next_idx) - floor_values
    Z[Z <= 0] = 1.  # Ensure no division by zero
    # Relative position of the value between its two neighbouring levels
    prop_dist = (qx - floor_values) / Z

    if round_type == 'sr':
        # Stochastic rounding: step up with probability prop_dist.
        # NOTE(review): torch.bernoulli uses the default RNG here, not the
        # module-level lpmm_generator -- confirm this is intended.
        step_up = torch.bernoulli(prop_dist).to(floor_idx.dtype)
    else:
        # Nearest rounding: step up when at least halfway to the next level.
        step_up = (prop_dist >= 0.5).to(floor_idx.dtype)
    full_idx = floor_idx + step_up

    # Gather the quantized values back
    quantized_values = torch.gather(qmap, 1, full_idx)

    # 3-norm of the quantization error along the group_size dimension
    quantization_error = torch.norm(original_qx - quantized_values, p=3, dim=1)
    if name == 'model.layers.0.self_attn.k_proj':
        print(quantization_error)
        print(torch.norm(quantization_error, p=3))

    # int8 cannot represent indices above 127, so the storage dtype follows
    # the map length instead of being fixed.
    idx = full_idx.to(_index_dtype_for(qmaplen))
    return idx, quantization_error


def nonlinear_dequant(name, qx, qmap, b, shape, round_type='sr', scale_type='group'):
    '''Look up the quantization level for every index, group by group.

    Args:
        name: unused; kept for interface compatibility.
        qx: index tensor with one row per group.
        qmap: per-group levels, shape (num_groups, qmaplen).
        b, shape, round_type, scale_type: unused; kept for interface
            compatibility.

    Returns:
        Tensor shaped like ``qx`` holding ``qmap[g, qx[g, i]]``; indices
        outside ``[0, qmaplen - 1]`` are clamped into that range first.
    '''
    _, level_count = qmap.shape
    safe_indices = qx.clamp(0, level_count - 1).to(torch.int64)
    return torch.gather(qmap, 1, safe_indices)


def create_normal_map(offset=0.9677083, use_extra_value=True, use_adaptive_map=False, total_bits=4, mu=0.0, sigma=1.0, min_weight=None, max_weight=None, scaled=True): 
    '''Create ``2 ** total_bits`` quantization levels from normal quantiles.

    Default path (``use_adaptive_map=False``): the levels are the quantiles
    of N(mu, sigma) at probabilities evenly spaced in ``[1 - offset, offset]``.

    Adaptive path (``use_adaptive_map=True``):
      * with both ``min_weight`` and ``max_weight`` given, probabilities are
        evenly spaced between the CDF values of the two bounds;
      * otherwise they are evenly spaced between ``offset`` and
        ``1 - offset``; when the standard-normal check on ``offset`` fails,
        one point fewer is generated and an explicit 0 level is appended.
      The resulting levels are sorted in ascending order.

    Args:
        offset: outermost probability of the quantile grid.
        use_extra_value: unused; kept for interface compatibility.
        use_adaptive_map: select the adaptive path described above.
        total_bits: number of bits; the map has ``2 ** total_bits`` levels.
        mu, sigma: mean and standard deviation of the normal distribution.
        min_weight, max_weight: optional value bounds for the adaptive path.
        scaled: when True divide by the 0.995 quantile of N(mu, sigma);
            otherwise divide by the largest absolute level.

    Returns:
        1-D tensor with the levels, detached from any graph.
    '''
    num_levels = 2 ** total_bits
    normal_dist = Normal(loc=mu, scale=sigma)

    if use_adaptive_map:
        # `is None` rather than truthiness, so that a bound of 0.0 counts as
        # a provided value.
        if min_weight is None or max_weight is None:
            if mu >= 0.0:
                use_full_grid = norm.ppf(1 - offset) >= 0.0
            else:
                use_full_grid = norm.ppf(offset) <= 0.0
            num_points = num_levels if use_full_grid else num_levels - 1
            probabilities = torch.linspace(offset, 1 - offset, num_points).numpy()
            v = norm.ppf(probabilities, loc=mu, scale=sigma).tolist()
            if not use_full_grid:
                # One level is reserved for an exact zero.
                v = v + [0]
        else:
            min_cdf = float(norm.cdf(min_weight, loc=mu, scale=sigma))
            max_cdf = float(norm.cdf(max_weight, loc=mu, scale=sigma))
            probabilities = torch.linspace(min_cdf, max_cdf, num_levels).numpy()
            v = norm.ppf(probabilities, loc=mu, scale=sigma).tolist()
        # The adaptive grids run from high to low probability; sort so the
        # levels are ascending like the default path.
        values = torch.tensor(v, dtype=torch.float32).sort().values
    else:
        quantiles = torch.linspace(1 - offset, offset, num_levels)
        values = normal_dist.icdf(quantiles)

    if scaled:
        v_995 = normal_dist.icdf(torch.tensor(0.995))
        values /= v_995
    else:
        values /= torch.abs(values).max()

    return values.detach()


def manual_linspace(start, end, steps):
    # Row-wise linspace: row i holds `steps` evenly spaced values running from
    # start[i] to end[i]. `start` and `end` are 1-D tensors of equal length;
    # the result has shape (len(start), steps).
    positions = torch.arange(steps).to(start.device)
    increment = ((end - start) / (steps - 1)).unsqueeze(1)
    return start.unsqueeze(1) + increment * positions


def vectorized_create_normal_map(offsets, total_bits, mu=0.0, sigma=1.0):
    '''Create one normal-quantile map per offset in a single batched call.

    For every entry ``o`` of ``offsets`` the row holds the quantiles of
    N(mu, sigma) at ``2 ** total_bits`` probabilities evenly spaced in
    ``[1 - o, o]``, divided by the 0.995 quantile of the same distribution.

    Args:
        offsets: 1-D tensor of offsets, one per group.
        total_bits: number of bits; each row has ``2 ** total_bits`` levels.
        mu, sigma: mean and standard deviation of the normal distribution.

    Returns:
        Tensor of shape (len(offsets), 2 ** total_bits), detached.
    '''
    row_count = offsets.size(0)
    level_count = 2 ** total_bits
    distribution = Normal(mu, sigma)

    # Probabilities for all rows at once: (row_count, level_count)
    probabilities = manual_linspace(1.0 - offsets, offsets, level_count)

    # icdf is applied on the flattened grid, then shaped back into rows.
    qmap = distribution.icdf(probabilities.flatten()).view(row_count, level_count)
    qmap /= distribution.icdf(torch.tensor(0.995))

    return qmap.detach()


def plot_weight_distribution(p_data_tensor, quantized_values, quantized_values_2, dir_name='plots'):
    '''Save a log-scale histogram of the weights with two level sets marked.

    The figure is written to ``<cwd>/<dir_name>/weight_distribution_nf2_l3.png``;
    the directory is created when missing.

    Args:
        p_data_tensor: tensor of weights; flattened for the histogram.
        quantized_values: 1-D tensor of levels drawn in red.
        quantized_values_2: 1-D tensor of levels drawn in blue.
        dir_name: output directory, relative to the current working directory.
    '''
    output_dir = os.path.join(os.getcwd(), dir_name)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    p_data_flattened = p_data_tensor.to(torch.float32).detach().cpu().numpy().flatten()
    quantized_values = quantized_values.to(torch.float32).detach().cpu().numpy()
    quantized_values_2 = quantized_values_2.to(torch.float32).detach().cpu().numpy()

    try:
        plt.hist(p_data_flattened, bins=500, color='green', log=True, density=True)
        plt.title('Distribution of Weights', fontsize=16, fontweight='bold')
        plt.xlabel('Weight Value', fontsize=16, fontweight='bold')
        plt.ylabel('Density', fontsize=16, fontweight='bold')
        plt.xticks(fontsize=11)
        plt.yticks(fontsize=11)

        # Height of the markers: 1e-4 of the upper y-limit (the axis is
        # logarithmic, so this is still well inside the visible range).
        ymin, ymax = plt.ylim()
        y_coord_for_points = ymax * 1e-4

        # Mark every level with a dot and a dashed line down to the axis.
        for levels, colour in ((quantized_values, 'red'), (quantized_values_2, 'blue')):
            for q_value in levels:
                plt.scatter([q_value], [y_coord_for_points], color=colour, s=10)
                plt.vlines(q_value, ymin, y_coord_for_points, colors=colour, linestyles='--', linewidth=1.5)

        plt.savefig(os.path.join(output_dir, 'weight_distribution_nf2_l3.png'))
    finally:
        # Always release the figure, even when plotting or saving fails, so
        # the next plot does not draw on top of a stale one.
        plt.close()


def plot_offset_quantization_error(offsets, quantization_error, dir_name='plots'):
    '''Save a line plot of quantization error against CDF offset.

    The figure is written to
    ``<cwd>/<dir_name>/offset_quantization_error_l3.png``; the directory is
    created when missing. Points are drawn in ascending offset order and each
    error stays paired with its offset. ``offsets`` is not modified.

    Args:
        offsets: list of offsets (x values).
        quantization_error: 1-D tensor with one error per offset (y values).
        dir_name: output directory, relative to the current working directory.
    '''
    output_dir = os.path.join(os.getcwd(), dir_name)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    errors = quantization_error.to(torch.float32).detach().cpu().numpy()

    # Sort a permutation rather than the caller's list, and apply it to both
    # axes so every error stays attached to the offset that produced it.
    order = sorted(range(len(offsets)), key=offsets.__getitem__)
    sorted_offsets = [offsets[i] for i in order]
    sorted_errors = errors[order]

    try:
        plt.plot(sorted_offsets, sorted_errors, marker='o', linestyle='-', color='b')
        plt.xlabel('CDF Offsets', fontsize=16, fontweight='bold')
        plt.ylabel('Quantization Error', fontsize=16, fontweight='bold')
        plt.title('L3 norm', fontsize=16, fontweight='bold')
        plt.xticks(fontsize=11)
        plt.yticks(fontsize=11)
        plt.savefig(os.path.join(output_dir, 'offset_quantization_error_l3.png'))
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close()