import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import os
import torch.nn.functional as F


def check_equal(t1, t2):
    """Return True when ``t1`` and ``t2`` share a shape and are ``torch.allclose``.

    The shape test runs first, so tensors of different shapes give False
    without ever reaching ``allclose``.
    """
    same_shape = t1.shape == t2.shape
    return same_shape and torch.allclose(t1, t2)


def show_value(tensor):
    """Show a 50-bin histogram of every element of ``tensor``.

    The tensor is detached and moved to CPU, flattened, and drawn with
    matplotlib; the window is displayed with ``plt.show()``.
    """
    flat_values = tensor.cpu().detach().numpy().ravel()
    plt.hist(flat_values, bins=50)

    labels = (
        (plt.title, "Weight Distribution of nn.Conv2d Layer"),
        (plt.xlabel, "Weight values"),
        (plt.ylabel, "Frequency"),
    )
    for set_label, text in labels:
        set_label(text)

    plt.show()


def show_unique(tensor):
    """Print the distinct values of ``tensor``, their counts, and how many there are.

    Values are taken from a detached CPU copy and counted with ``np.unique``.
    """
    values, frequencies = np.unique(
        tensor.cpu().detach().numpy(), return_counts=True
    )
    report = (
        ("(Unique values):", values),
        ("(Counts):", frequencies),
        ("(Total unique count):", len(values)),
    )
    for label, item in report:
        print(label, item)


def sample_and_sort(n_samples=100, range_start=0, range_end=999):
    """Draw ``n_samples`` distinct integers from ``[range_start, range_end]`` in ascending order.

    Both ends of the range are inclusive. Sampling uses the global ``random``
    generator without replacement, so ``random.sample`` raises ValueError when
    ``n_samples`` exceeds the range size.
    """
    population = range(range_start, range_end + 1)
    picks = random.sample(population, n_samples)
    picks.sort()
    return picks


def calculate_sparsity(tensor):
    """Return the fraction of elements of ``tensor`` that are exactly zero.

    Args:
        tensor: any tensor supporting ``numel()`` and ``!= 0``.

    Returns:
        A Python float in [0, 1]. An empty tensor has no zero elements, so
        0.0 is returned instead of dividing by zero.
    """
    total_elements = tensor.numel()
    if total_elements == 0:
        # Previously raised ZeroDivisionError.
        return 0.0
    non_zero_elements = (tensor != 0).sum().item()
    return 1 - non_zero_elements / total_elements


def _exact_pow2(k: torch.Tensor) -> torch.Tensor:
    """Return ``2.0 ** k`` as float32, built bit-exactly from int32 exponents ``k``.

    ``k`` must lie in [-149, 127]; exponents below -126 yield float32
    subnormals. Building the bit pattern directly avoids depending on the
    accuracy of ``pow``/``exp2`` on each backend.
    """
    k = k.to(torch.int32)
    # Normal range: biased exponent field, zero mantissa.
    normal_bits = torch.clamp(k + 127, min=0) << 23
    # Subnormal range: a single mantissa bit, worth 2**(bit_index - 149).
    subnormal_bits = torch.ones_like(k) << torch.clamp(k + 149, min=0, max=22)
    bits = torch.where(k >= -126, normal_bits, subnormal_bits)
    return bits.view(torch.float32)


def float_quantize(x: torch.Tensor, e_bits: int, m_bits: int) -> torch.Tensor:
    """Round ``x`` to a low-precision IEEE-754-style float format, returned as float32.

    Target format:
    - 1 sign bit, ``e_bits`` exponent bits with bias ``2**(e_bits - 1) - 1``,
      and ``m_bits`` explicit mantissa bits.
    - As in IEEE 754, exponent field 0 encodes zero/subnormals.
    - The largest finite exponent equals the bias, so the all-ones field is
      never produced.
    - Ranges beyond float32's own are clipped to float32's.

    Rounding rules:
    - The magnitude is rounded to the nearest representable value, with
      ties going up (away from zero). A carry out of the mantissa correctly
      moves the value to the next exponent.
    - Magnitudes below the smallest normal are rounded on the subnormal
      grid, so zero and tiny values stay small.
    - Magnitudes above the largest finite value (including inf) saturate to
      it. NaN propagates.
    - The sign, including that of -0.0, is preserved.

    With ``e_bits=8, m_bits=23`` the function is the identity on float32 input.

    Args:
        x: input tensor; converted to float32 first if it is not float32.
        e_bits: number of exponent bits (>= 2).
        m_bits: number of explicit mantissa bits (>= 0).

    Returns:
        A float32 tensor of the same shape as ``x``.

    Raises:
        ValueError: if ``e_bits < 2`` or ``m_bits < 0``.
    """
    if e_bits < 2 or m_bits < 0:
        raise ValueError(
            "float_quantize needs e_bits >= 2 and m_bits >= 0, "
            f"got e_bits={e_bits}, m_bits={m_bits}"
        )
    if x.dtype != torch.float32:
        x = x.float()

    bias = (1 << (e_bits - 1)) - 1
    # Normal exponent range of the target format, clipped to float32's.
    min_exp = max(1 - bias, -126)
    max_exp = min(bias, 127)
    # Largest finite magnitude (all-ones mantissa). m_bits is capped at 23
    # because the result must itself be a float32.
    max_val = 2.0 ** max_exp * (2.0 - 2.0 ** -min(m_bits, 23))

    mag = x.abs()
    # frexp gives mag = f * 2**e with f in [0.5, 1), so floor(log2(mag)) == e - 1.
    _, frexp_exp = torch.frexp(mag)
    true_exp = frexp_exp - 1
    # Step exponent: m_bits below the exponent, with the exponent clamped up to
    # min_exp so everything below the normal range shares the subnormal step.
    # Never finer than mag's own float32 ulp (true_exp - 23), which keeps
    # mag / step below 2**24 and the arithmetic below exact.
    step_exp = torch.maximum(
        torch.clamp(true_exp, min=min_exp) - m_bits, true_exp - 23
    )
    step = _exact_pow2(torch.clamp(step_exp, min=-149, max=127))

    scaled = mag / step  # exact: division by a power of two
    rounded = torch.floor(scaled)
    # Round half up on the magnitude. scaled - floor(scaled) is exact, so
    # there is no double rounding (unlike floor(scaled + 0.5)).
    rounded = rounded + (scaled - rounded >= 0.5).to(rounded.dtype)

    # Saturate overflow (and +/-inf) to the largest finite value; NaN passes
    # through clamp unchanged.
    mag_q = torch.clamp(rounded * step, max=max_val)
    return torch.copysign(mag_q, x)


def save_BASQ_quantize_txt(save_path, model_name, split_layer, IF_list, e_bits, m_bits):
    """Quantize each tensor in ``IF_list`` with ``float_quantize`` and save it as text.

    Files are written to
    ``<save_path>/<model_name>/BASQ_Quant<e_bits + m_bits>/SL<split_layer>/<i>.txt``.
    ``i`` is the tensor's position in ``IF_list``. Each file holds the
    flattened quantized values, one per line. The unique values and counts
    of the first quantized tensor are printed via ``show_unique`` as a
    sanity check.

    Args:
        save_path: root output directory.
        model_name: sub-directory name for the model.
        split_layer: used in the ``SL<split_layer>`` directory name.
        IF_list: iterable of tensors to quantize and save; may be empty.
        e_bits: exponent bits passed to ``float_quantize``.
        m_bits: mantissa bits passed to ``float_quantize``.
    """
    save_dir = os.path.join(
        save_path, model_name, f"BASQ_Quant{e_bits + m_bits}", f"SL{split_layer}"
    )
    os.makedirs(save_dir, exist_ok=True)

    for idx, IF in enumerate(IF_list):
        IF_quant = float_quantize(IF, e_bits=e_bits, m_bits=m_bits)
        if idx == 0:
            # Report only the first tensor. Reusing this result avoids
            # quantizing it twice and does not fail on an empty list.
            show_unique(IF_quant)
        # reshape (not view) also works when the tensor is non-contiguous.
        values = IF_quant.reshape(-1).detach().cpu().tolist()

        file_name = os.path.join(save_dir, f"{idx}.txt")
        with open(file_name, "w") as f:
            f.writelines(f"{val}\n" for val in values)


def duq_quantize(x, a, alpha, b, beta, N_lv):
    """Uniformly quantize ``x`` onto ``2 ** N_lv`` levels with softplus-scaled ranges.

    Steps:
    1. ``x`` is shifted by ``b`` and divided by ``softplus(a)``.
    2. The result is clipped to [0, 1].
    3. It is snapped to the nearest of ``2 ** N_lv`` evenly spaced points.
    4. It is mapped back via ``softplus(alpha) * level + beta``.

    ``a`` and ``alpha`` must be tensors because they go through ``F.softplus``.
    """
    levels = 2**N_lv
    step_count = levels - 1

    in_scale = F.softplus(a)
    out_scale = F.softplus(alpha)

    normalized = torch.clamp((x - b) / in_scale, min=0, max=1)
    snapped = torch.round(step_count * normalized) / step_count
    return out_scale * snapped + beta


def save_PROFIT_quantize_txt(
    save_path, model_name, split_layer, IF_list, a, alpha, b, beta, Q
):
    """Quantize each tensor in ``IF_list`` with ``duq_quantize`` and save it as text.

    Files are written to
    ``<save_path>/<model_name>/PROFIT_Quant<Q>/SL<split_layer>/<i>.txt``.
    ``i`` is the tensor's position in ``IF_list``. Each file holds the
    flattened quantized values, one per line. The unique values and counts
    of the first quantized tensor are printed via ``show_unique`` as a
    sanity check.

    Args:
        save_path: root output directory.
        model_name: sub-directory name for the model.
        split_layer: used in the ``SL<split_layer>`` directory name.
        IF_list: iterable of tensors to quantize and save; may be empty.
        a, alpha, b, beta: quantizer parameters forwarded to ``duq_quantize``.
        Q: forwarded to ``duq_quantize`` as ``N_lv`` and used in the
            directory name.
    """
    save_dir = os.path.join(
        save_path, model_name, f"PROFIT_Quant{Q}", f"SL{split_layer}"
    )
    os.makedirs(save_dir, exist_ok=True)

    for idx, IF in enumerate(IF_list):
        IF_quant = duq_quantize(IF, a, alpha, b, beta, Q)
        if idx == 0:
            # Report only the first tensor. Reusing this result avoids
            # quantizing it twice and does not fail on an empty list.
            show_unique(IF_quant)
        # reshape (not view) also works when the tensor is non-contiguous.
        values = IF_quant.reshape(-1).detach().cpu().tolist()

        file_name = os.path.join(save_dir, f"{idx}.txt")
        with open(file_name, "w") as f:
            f.writelines(f"{val}\n" for val in values)
