import math
import random
import os
import yaml

import numpy as np
import torch
import torch.nn as nn

from fast_hadamard_transform import hadamard_transform

import nsn_tools


def load_yaml_as_dict(file_path):
    """Parse a YAML file with ``yaml.safe_load`` and return the resulting object.

    Args:
        file_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document: a ``dict`` for a
        mapping-rooted config, or ``None`` for an empty file.
    """
    # YAML documents are UTF-8 by spec; don't depend on the platform's locale
    # encoding (e.g. cp1252 on Windows would mis-decode non-ASCII text).
    with open(file_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    return data


def set_all_seeds(seed):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducible runs.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    # Hash randomization of the *current* interpreter is fixed at startup; this
    # only takes effect for subprocesses launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning selects algorithms by timing them, so the choice can differ
    # between runs; it must be off for the run to be reproducible.
    torch.backends.cudnn.benchmark = False


def hist_torch(input, bins):
    """Batched histogram: per row of ``input``, count the values falling in each bin.

    Workaround for the lack of a batched ``torch.histc``; see
    https://github.com/pytorch/pytorch/issues/99719

    Args:
        input: 2-D tensor ``(rows, n)``; each row is histogrammed independently.
        bins: 1-D increasing tensor of ``N`` bin edges.

    Returns:
        Tensor of shape ``(rows, N - 1)`` with dtype ``input.dtype``. Column ``j``
        counts values ``x`` with ``bins[j] <= x < bins[j + 1]``. Values below
        ``bins[0]`` or at/above ``bins[-1]`` are not counted.
    """
    n_edges = bins.shape[0]
    # With right=True: index 0 -> x < bins[0]; index i in [1, N-1] ->
    # bins[i-1] <= x < bins[i]; index N -> x >= bins[-1]. Both out-of-range
    # slots are dropped below, so under- and over-range values are treated alike.
    bin_idx = torch.bucketize(input, bins, right=True)
    # Accumulate counts in int64: summing ones in a low-precision float dtype
    # saturates (bfloat16 stops at 256, float16 at 2048).
    counts = torch.zeros(
        size=(input.shape[0], n_edges + 1), device=input.device, dtype=torch.int64
    ).scatter_reduce(
        dim=1, index=bin_idx, src=torch.ones_like(bin_idx),
        reduce="sum", include_self=False)
    return counts[:, 1:n_edges].to(input.dtype)


def calculate_kl_divergence(data):
    """Return the row-averaged KL divergence KL(P_row || N(0, 1)) as a Python float.

    Each row of ``data`` is histogrammed with ``hist_torch`` on 100 equally spaced
    edges over [-5, 5] (99 bins) and normalized to a distribution ``P_row``. The
    reference ``Q`` is the standard-normal mass in each bin. Bins where either
    distribution is zero are skipped. The total is divided by the number of rows.
    """
    edges = torch.linspace(-5, 5, 100, device=data.device)  # 100 edges -> 99 bins
    counts = hist_torch(data, bins=edges)
    p = counts / counts.sum(dim=-1, keepdim=True)
    q = torch.diff(torch.distributions.Normal(0, 1).cdf(edges))
    q = q.unsqueeze(0).expand(p.shape[0], q.shape[-1])
    # Keep only bins with non-zero mass on both sides to avoid log(0) / division by 0.
    keep = (p > 0) & (q > 0)
    p_kept = p[keep]
    q_kept = q[keep]
    total = (p_kept * (p_kept / q_kept).log()).sum()
    return (total / data.shape[0]).item()


def pack_idx_signs(idx: torch.IntTensor,
                   sign_mask: torch.BoolTensor):
    """Pack index pairs and sign flags into one int32 word per 16 elements.

    ``sign_mask`` is ``(B, H, L, D)``; ``idx`` must hold ``B*H*L*(D//16)*2`` values.
    For word ``k`` along the last axis:
      * bits 0-7   <- idx pair element 0
      * bits 8-15  <- idx pair element 1
      * bit 16+i   <- sign_mask[..., 16*k + i]

    NOTE(review): assumes each idx value fits in 8 bits; larger values spill into
    the neighbouring fields because the parts are combined with addition.

    Returns:
        int32 tensor of shape ``(B, H, L, D // 16)``.
    """
    B, H, L, D = sign_mask.shape
    n_words = D // 16

    signs = sign_mask.to(torch.int32).reshape(B, H, L, n_words, 16)
    pairs = idx.to(torch.int32).reshape(B, H, L, n_words, 2)

    packed = pairs[..., 0] + (pairs[..., 1] << 8)
    for bit in range(16):
        packed = packed + (signs[..., bit] << (16 + bit))
    return packed


@torch.no_grad()
def rotate_v_proj(v_proj_layer: torch.nn.Linear,
                  head_dim: int):
    """Apply a normalized Hadamard transform to each head's outputs of ``v_proj_layer``, in place.

    The output features (rows of ``weight``) are split into consecutive chunks of
    ``head_dim``. Each chunk is transformed with ``hadamard_transform`` scaled by
    ``1/sqrt(head_dim)``. The bias, if present, is transformed the same way.

    Args:
        v_proj_layer: Linear layer whose ``weight`` (and optional ``bias``) are
            overwritten in place.
        head_dim: Per-head size. Presumably it must be a size supported by
            ``fast_hadamard_transform`` — TODO confirm.
    """
    w = v_proj_layer.weight
    # Transpose so the output features lie along the last axis, which is the axis
    # hadamard_transform operates on.
    w_t = w.transpose(0, 1)
    w_t_shape = w_t.shape
    w_t = w_t.reshape(-1, head_dim)
    w_t = hadamard_transform(w_t, 1/math.sqrt(w_t.shape[-1]))
    w_t = w_t.reshape(*w_t_shape)
    w = w_t.transpose(0, 1).contiguous()
    v_proj_layer.weight.copy_(w)

    # Compare to None explicitly: `if bias:` calls bool() on the tensor, which raises
    # for multi-element biases and would skip a single zero-valued bias.
    if v_proj_layer.bias is not None:
        b = v_proj_layer.bias
        b_shape = b.shape
        b = b.reshape(-1, head_dim)
        b = hadamard_transform(b, 1/math.sqrt(b.shape[-1]))
        b = b.reshape(*b_shape).contiguous()
        v_proj_layer.bias.copy_(b)
    return


@torch.no_grad()
def rotate_o_proj(o_proj_layer: torch.nn.Linear,
                  head_dim: int):
    """Apply a normalized Hadamard transform to each head's input columns of ``o_proj_layer``, in place.

    The weight is viewed as rows of ``head_dim`` consecutive input features. Each
    row is passed through ``hadamard_transform`` with scale ``1/sqrt(head_dim)``,
    and the result is copied back into ``weight``. The bias is left untouched.
    """
    weight = o_proj_layer.weight
    original_shape = weight.shape
    per_head = weight.reshape(-1, head_dim)
    rotated = hadamard_transform(per_head, 1 / math.sqrt(per_head.shape[-1]))
    o_proj_layer.weight.copy_(rotated.reshape(*original_shape).contiguous())


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. When ``n_rep == 1``
    the input tensor itself is returned.
    """
    bsz, n_kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(bsz, n_kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, n_kv_heads * n_rep, seq_len, dim)


def rotate_half(x):
    """Rotate half the hidden dims: split the last axis as ``[a, b]`` and return ``[-b, a]``.

    ``a`` is the first ``n // 2`` entries; ``b`` is the remainder.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat([second.neg(), first], dim=-1)


def apply_rotary_pos_emb_single(v, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to a single tensor.

    Computes ``v * cos + rotate_half(v) * sin`` after unsqueezing ``cos`` and ``sin``
    at ``unsqueeze_dim`` so they broadcast against ``v``.

    Args:
        v (`torch.Tensor`): The tensor to rotate (e.g. queries or keys).
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which ``cos`` and ``sin`` are unsqueezed before
            broadcasting. For example, if ``cos``/``sin`` are
            ``[batch_size, seq_len, head_dim]`` and ``v`` is
            ``[batch_size, heads, seq_len, head_dim]``, use ``unsqueeze_dim=1``.
            If ``v`` is ``[batch_size, seq_len, heads, head_dim]``, use
            ``unsqueeze_dim=2``.

    Returns:
        `torch.Tensor`: ``v`` rotated by the rotary position embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    v_embed = (v * cos) + (rotate_half(v) * sin)
    return v_embed


def quantize_tensor_4bit(w: torch.Tensor, group_size: int) -> tuple:
    """Asymmetric round-to-nearest 4-bit quantization with 8 values packed per int32.

    Args:
        w: Tensor to quantize. Its last dimension must be divisible by 8. ``w`` is
            not modified.
        group_size: Number of consecutive last-dim elements sharing one scale and
            offset. Must be a positive multiple of 8 dividing ``w.shape[-1]``. Any
            value ``<= 0`` (e.g. ``-1``) uses one group per last-dim row.

    Returns:
        ``(packed, scale, w_min)``:
          * ``packed``: int32 tensor ``(*w.shape[:-1], w.shape[-1] // 8)``. Element
            ``i`` of each group of 8 occupies bits ``4*i .. 4*i+3``.
          * ``scale``: ``(*w.shape[:-1], n_groups)`` step size, clamped to >= 1e-5.
          * ``w_min``: ``(*w.shape[:-1], n_groups)`` per-group minimum (the offset).
        Dequantize with ``q * scale + w_min``.
    """
    saved_shape = w.shape
    assert saved_shape[-1] % 8 == 0, "the last dimension should be divisible by 8 (for packing)"

    n_bits = 4
    q_max = 2 ** n_bits - 1
    if group_size > 0:
        assert group_size % 8 == 0, "group sizes of multiples of 8 are only supported"
        assert w.shape[-1] % group_size == 0
        w = w.reshape(-1, group_size)  # row-major: consecutive last-dim elements form a group

    w_max = w.amax(dim=-1, keepdim=True)
    w_min = w.amin(dim=-1, keepdim=True)
    # Clamp keeps constant groups (max == min) from dividing by zero.
    scale = torch.clamp((w_max - w_min) / q_max, min=1e-5)
    q = (w - w_min) / scale  # new tensor: the caller's `w` is left untouched
    q = q.clamp_(0, q_max).round_().int()
    q = q.reshape(*saved_shape[:-1], saved_shape[-1] // 8, 8)

    packed = torch.zeros((*saved_shape[:-1], saved_shape[-1] // 8), dtype=torch.int32, device=q.device)
    for i in range(8):
        # The shift by 28 sets the int32 sign bit; the bit pattern is still correct
        # and is recovered by masking with 15 after the right shift.
        packed += q[..., i] << (4 * i)

    scale = scale.reshape(*saved_shape[:-1], -1)
    w_min = w_min.reshape(*saved_shape[:-1], -1)
    return packed, scale, w_min


def dequantize_tensor_4bit(x: torch.tensor,
                      scale: torch.tensor,
                      offset: torch.tensor):
    """Unpack 8 four-bit values per int word and map them back to real values.

    Nibble ``i`` of each word (bits ``4*i .. 4*i+3``) becomes element ``i`` of that
    word's group of 8. The unpacked values are grouped to match ``scale.shape`` and
    mapped as ``q * scale + offset``. The result has ``x``'s leading dimensions, a
    last dimension of ``x.shape[-1] * 8``, and ``scale``'s dtype and device.
    """
    nibble_shifts = torch.arange(0, 32, 4, dtype=x.dtype, device=x.device)
    # The mask keeps only the low 4 bits, so sign extension from the arithmetic
    # right shift of negative words is harmless.
    nibbles = (x.unsqueeze(-1) >> nibble_shifts) & 15
    values = nibbles.to(device=scale.device, dtype=scale.dtype)

    unpacked_shape = values.shape
    grouped = values.reshape(*scale.shape, -1)
    grouped = grouped * scale.unsqueeze(-1) + offset.unsqueeze(-1)
    return grouped.reshape(*unpacked_shape[:-2], -1)


def pseudo_quantize_tensor_4bit(x, group_size):
    """Round-trip ``x`` through 4-bit quantization and return the dequantized tensor."""
    packed, scale, offset = quantize_tensor_4bit(x, group_size)
    return dequantize_tensor_4bit(packed, scale, offset)


def get_scale_adjustment_function(hidden_dim, n_bits):
    """Look up the ``nsn_tools`` scale-adjustment kernel for ``(hidden_dim, n_bits)``.

    Raises:
        KeyError: if no kernel is registered for the requested configuration.
    """
    scale_adjustment_function_dict = {
        (128, 2): nsn_tools.adjust_scale_dq_d128,
        (128, 1): nsn_tools.adjust_scale_dq_1bit_d128,
    }
    key = (hidden_dim, n_bits)
    try:
        return scale_adjustment_function_dict[key]
    except KeyError:
        raise KeyError(
            f"no scale-adjustment kernel for (hidden_dim, n_bits)={key}; "
            f"supported: {sorted(scale_adjustment_function_dict)}"
        ) from None


def get_dist_argmin_batched_function(hidden_dim):
    """Look up the ``nsn_tools`` batched distance-argmin kernel for ``hidden_dim``.

    Raises:
        KeyError: if no kernel is registered for the requested dimension.
    """
    dist_argmin_batched_function_dict = {
        4: nsn_tools.dist_argmin_half_batched_d4,
        8: nsn_tools.dist_argmin_half_batched_d8,
        9: nsn_tools.dist_argmin_half_batched_d9,
        10: nsn_tools.dist_argmin_half_batched_d10,
    }
    try:
        return dist_argmin_batched_function_dict[hidden_dim]
    except KeyError:
        raise KeyError(
            f"no dist-argmin kernel for hidden_dim={hidden_dim}; "
            f"supported: {sorted(dist_argmin_batched_function_dict)}"
        ) from None

def get_restore_function(window_size, hidden_dim, n_bits):
    """Look up the ``nsn_tools`` restore kernel for ``(window_size, hidden_dim, n_bits)``.

    Raises:
        KeyError: if no kernel is registered for the requested configuration.
    """
    restore_function_dict = {
        (64, 128, 2): nsn_tools.restore_quantized_dq_ws64_d128,
        (64, 128, 1): nsn_tools.restore_quantized_dq_1bit_ws64_d128,
        (32, 128, 2): nsn_tools.restore_quantized_dq_ws32_d128,
        (32, 128, 1): nsn_tools.restore_quantized_dq_1bit_ws32_d128,
        (128, 128, 2): nsn_tools.restore_quantized_dq_ws128_d128,
        (128, 128, 1): nsn_tools.restore_quantized_dq_1bit_ws128_d128,
    }
    key = (window_size, hidden_dim, n_bits)
    try:
        return restore_function_dict[key]
    except KeyError:
        raise KeyError(
            f"no restore kernel for (window_size, hidden_dim, n_bits)={key}; "
            f"supported: {sorted(restore_function_dict)}"
        ) from None


def get_dot_product_fused_function(window_size, hidden_dim, n_bits):
    """Look up the ``nsn_tools`` fused quantized dot-product kernel for a configuration.

    Raises:
        KeyError: if no kernel is registered for ``(window_size, hidden_dim, n_bits)``.
    """
    dot_product_fused_function_dict = {
        (64, 128, 2): nsn_tools.quantized_dot_product_fused_dq_ws64_d128,
        (64, 128, 1): nsn_tools.quantized_dot_product_fused_dq_1bit_ws64_d128,
        (32, 128, 2): nsn_tools.quantized_dot_product_fused_dq_ws32_d128,
        (32, 128, 1): nsn_tools.quantized_dot_product_fused_dq_1bit_ws32_d128,
        (128, 128, 2): nsn_tools.quantized_dot_product_fused_dq_ws128_d128,
        (128, 128, 1): nsn_tools.quantized_dot_product_fused_dq_1bit_ws128_d128,
    }
    key = (window_size, hidden_dim, n_bits)
    try:
        return dot_product_fused_function_dict[key]
    except KeyError:
        raise KeyError(
            f"no fused dot-product kernel for (window_size, hidden_dim, n_bits)={key}; "
            f"supported: {sorted(dot_product_fused_function_dict)}"
        ) from None


def get_dot_product_fused_residual_function(window_size, hidden_dim, n_bits):
    """Look up the ``nsn_tools`` fused residual quantized dot-product kernel for a configuration.

    Raises:
        KeyError: if no kernel is registered for ``(window_size, hidden_dim, n_bits)``.
    """
    dot_product_fused_residual_function_dict = {
        (64, 128, 2): nsn_tools.quantized_dot_product_fused_residual_dq_ws64_d128,
        (64, 128, 1): nsn_tools.quantized_dot_product_fused_residual_dq_1bit_ws64_d128,
        (32, 128, 2): nsn_tools.quantized_dot_product_fused_residual_dq_ws32_d128,
        (32, 128, 1): nsn_tools.quantized_dot_product_fused_residual_dq_1bit_ws32_d128,
        (128, 128, 2): nsn_tools.quantized_dot_product_fused_residual_dq_ws128_d128,
        (128, 128, 1): nsn_tools.quantized_dot_product_fused_residual_dq_1bit_ws128_d128,
    }
    key = (window_size, hidden_dim, n_bits)
    try:
        return dot_product_fused_residual_function_dict[key]
    except KeyError:
        raise KeyError(
            f"no fused residual dot-product kernel for (window_size, hidden_dim, n_bits)={key}; "
            f"supported: {sorted(dot_product_fused_residual_function_dict)}"
        ) from None


def get_weighted_sum_function(window_size, hidden_dim, n_bits):
    """Look up the ``nsn_tools`` quantized weighted-sum kernel for a configuration.

    Raises:
        KeyError: if no kernel is registered for ``(window_size, hidden_dim, n_bits)``.
    """
    weighted_sum_function_dict = {
        (64, 128, 2): nsn_tools.quantized_weighted_sum_dq_ws64_d128,
        (64, 128, 1): nsn_tools.quantized_weighted_sum_dq_1bit_ws64_d128,
        (32, 128, 2): nsn_tools.quantized_weighted_sum_dq_ws32_d128,
        (32, 128, 1): nsn_tools.quantized_weighted_sum_dq_1bit_ws32_d128,
        (128, 128, 2): nsn_tools.quantized_weighted_sum_dq_ws128_d128,
        (128, 128, 1): nsn_tools.quantized_weighted_sum_dq_1bit_ws128_d128,
    }
    key = (window_size, hidden_dim, n_bits)
    try:
        return weighted_sum_function_dict[key]
    except KeyError:
        raise KeyError(
            f"no weighted-sum kernel for (window_size, hidden_dim, n_bits)={key}; "
            f"supported: {sorted(weighted_sum_function_dict)}"
        ) from None


def get_weighted_sum_residual_function(window_size, hidden_dim, n_bits):
    """Look up the ``nsn_tools`` residual quantized weighted-sum kernel for a configuration.

    Raises:
        KeyError: if no kernel is registered for ``(window_size, hidden_dim, n_bits)``.
    """
    weighted_sum_residual_function_dict = {
        (64, 128, 2): nsn_tools.quantized_weighted_sum_residual_dq_ws64_d128,
        (64, 128, 1): nsn_tools.quantized_weighted_sum_residual_dq_1bit_ws64_d128,
        (32, 128, 2): nsn_tools.quantized_weighted_sum_residual_dq_ws32_d128,
        (32, 128, 1): nsn_tools.quantized_weighted_sum_residual_dq_1bit_ws32_d128,
        (128, 128, 2): nsn_tools.quantized_weighted_sum_residual_dq_ws128_d128,
        (128, 128, 1): nsn_tools.quantized_weighted_sum_residual_dq_1bit_ws128_d128,
    }
    key = (window_size, hidden_dim, n_bits)
    try:
        return weighted_sum_residual_function_dict[key]
    except KeyError:
        raise KeyError(
            f"no residual weighted-sum kernel for (window_size, hidden_dim, n_bits)={key}; "
            f"supported: {sorted(weighted_sum_residual_function_dict)}"
        ) from None
