import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import os
import torch.nn.functional as F


def check_equal(t1, t2):
    """Return True when both tensors have the same shape and close values.

    The shape comparison runs first, so ``torch.allclose`` (with its default
    tolerances) is only evaluated for tensors of identical shape.
    """
    return t1.shape == t2.shape and torch.allclose(t1, t2)


def sample_and_sort(n_samples=100, range_start=0, range_end=999):
    """Draw ``n_samples`` distinct integers and return them in ascending order.

    The integers are sampled without replacement from the inclusive range
    ``[range_start, range_end]`` using the module-level ``random`` generator.
    """
    population = range(range_start, range_end + 1)
    picks = random.sample(population, n_samples)
    # random.sample returns a fresh list, so sorting it in place is safe.
    picks.sort()
    return picks


def calculate_sparsity(tensor):
    """Return the fraction of elements in ``tensor`` that are exactly zero.

    The result is a Python float computed as one minus the non-zero ratio.
    """
    element_count = tensor.numel()
    nonzero_count = tensor.ne(0).sum().item()
    return 1 - (nonzero_count / element_count)


def float_quantize(x: torch.Tensor, e_bits: int, m_bits: int) -> torch.Tensor:
    """Simulate rounding ``x`` to a low-precision float format.

    The target format has one sign bit, ``e_bits`` exponent bits (bias
    ``2**(e_bits - 1) - 1``) and ``m_bits`` mantissa bits. The result is
    returned as a float32 tensor whose values are the quantized numbers
    themselves, so it can be compared with, or used in place of, ``x``.

    Behaviour:
      * The mantissa is rounded to ``m_bits`` bits (ties round away from
        zero); a round-up that overflows the mantissa carries into the
        exponent.
      * Unbiased exponents are limited to ``[-bias, +bias]``, further limited
        to what float32 itself can hold.
      * Magnitudes below that range are flushed to (signed) zero, so exact
        zeros stay zero.
      * Magnitudes above that range, including infinities, saturate to the
        largest finite representable magnitude.
      * NaN inputs are returned as NaN.

    Args:
        x: Input tensor; converted to float32 if it has another dtype.
        e_bits: Number of exponent bits of the target format (>= 1).
        m_bits: Number of mantissa bits of the target format (>= 0). Values
            of 23 or more leave the float32 mantissa untouched.

    Returns:
        A new float32 tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``e_bits < 1`` or ``m_bits < 0``.
    """
    if e_bits < 1:
        raise ValueError(f"e_bits must be >= 1, got {e_bits}")
    if m_bits < 0:
        raise ValueError(f"m_bits must be >= 0, got {m_bits}")

    if x.dtype != torch.float32:
        x = x.float()

    # Reinterpret the float32 payload as integers to manipulate the fields.
    bits = x.view(torch.int32)

    # int32 is signed, so bit 31 is expressed as the most negative value
    # rather than 0x80000000 (which does not fit in an int32 scalar).
    sign = bits & -0x80000000
    magnitude = bits & 0x7FFFFFFF
    # Cap at the bit pattern of +inf so that adding the rounding offset below
    # cannot overflow int32 (NaN payloads are restored at the end).
    magnitude = torch.clamp(magnitude, max=0x7F800000)

    f32_bias = 127
    new_bias = (1 << (e_bits - 1)) - 1
    # Exponent limits expressed in float32's own bias-127 encoding, so the
    # returned tensor holds real values rather than re-biased bit patterns.
    min_field = max(f32_bias - new_bias, 0)
    max_field = min(f32_bias + new_bias, 254)

    kept_bits = min(m_bits, 23)
    dropped_bits = 23 - kept_bits
    if dropped_bits > 0:
        # Round on the combined exponent+mantissa integer: a mantissa overflow
        # then carries into the exponent automatically. Shifting back keeps
        # the surviving mantissa bits at the top of the 23-bit field.
        half = 1 << (dropped_bits - 1)
        magnitude = ((magnitude + half) >> dropped_bits) << dropped_bits

    # Saturate anything above the largest representable finite magnitude.
    largest = (max_field << 23) | (((1 << kept_bits) - 1) << dropped_bits)
    magnitude = torch.clamp(magnitude, max=largest)

    # Flush magnitudes below the smallest representable exponent to zero.
    smallest = min_field << 23
    magnitude = torch.where(
        magnitude < smallest, torch.zeros_like(magnitude), magnitude
    )

    quantized = (sign | magnitude).view(torch.float32)
    return torch.where(torch.isnan(x), x, quantized)


def save_BASQ_quantize_txt(save_path,model_name,split_layer,IF_list,e_bits, m_bits):
    """Quantize every entry of ``IF_list`` and dump each one to a text file.

    Files are written to
    ``<save_path>/<model_name>/BASQ_Quant<e_bits+m_bits>/SL<split_layer>/``
    as ``0.txt``, ``1.txt``, ... in list order, one flattened value per line.
    Before writing, the first entry is quantized once and handed to
    ``show_unique`` (defined outside this block).

    NOTE(review): ``IF_list`` is presumably a non-empty sequence of tensors;
    ``IF_list[0]`` is accessed unconditionally.
    """
    out_dir = os.path.join(
        save_path,
        model_name,
        f"BASQ_Quant{e_bits+m_bits}",
        f"SL{split_layer}",
    )
    os.makedirs(out_dir, exist_ok=True)

    preview = float_quantize(IF_list[0], e_bits=e_bits, m_bits=m_bits)
    show_unique(preview)

    for index, feature in enumerate(IF_list):
        quantized = float_quantize(feature, e_bits=e_bits, m_bits=m_bits)
        values = quantized.view(-1).detach().cpu().tolist()

        target = os.path.join(out_dir, f"{index}.txt")
        with open(target, 'w') as handle:
            handle.writelines(f"{value}\n" for value in values)
        

def duq_quantize(x, a, alpha, b, beta, N_lv):
    """Quantize ``x`` onto ``2 ** N_lv`` evenly spaced levels.

    Steps:
      1. Shift by ``b`` and divide by ``softplus(a)``, then clamp to [0, 1].
      2. Round to the nearest of ``2 ** N_lv`` uniform levels in [0, 1].
      3. Scale by ``softplus(alpha)`` and add ``beta``.

    Softplus keeps both scale factors positive regardless of the raw
    values of ``a`` and ``alpha``.
    """
    steps = 2 ** N_lv - 1

    in_scale = F.softplus(a)
    out_scale = F.softplus(alpha)

    normalized = torch.clamp((x - b) / in_scale, 0, 1)
    snapped = torch.round(steps * normalized) / steps
    return out_scale * snapped + beta

def save_PROFIT_quantize_txt(save_path,model_name,split_layer,IF_list,a, alpha, b, beta, Q):
    """Quantize every entry of ``IF_list`` with ``duq_quantize`` and save as text.

    Files are written to
    ``<save_path>/<model_name>/PROFIT_Quant<Q>/SL<split_layer>/`` as
    ``0.txt``, ``1.txt``, ... in list order, one flattened value per line.
    ``Q`` is forwarded to ``duq_quantize`` as its ``N_lv`` argument. Before
    writing, the first entry is quantized once and handed to ``show_unique``
    (defined outside this block).

    NOTE(review): ``IF_list`` is presumably a non-empty sequence of tensors;
    ``IF_list[0]`` is accessed unconditionally.
    """
    out_dir = os.path.join(
        save_path,
        model_name,
        f"PROFIT_Quant{Q}",
        f"SL{split_layer}",
    )
    os.makedirs(out_dir, exist_ok=True)

    preview = duq_quantize(IF_list[0], a, alpha, b, beta, Q)
    show_unique(preview)

    for index, feature in enumerate(IF_list):
        quantized = duq_quantize(feature, a, alpha, b, beta, Q)
        values = quantized.view(-1).detach().cpu().tolist()

        target = os.path.join(out_dir, f"{index}.txt")
        with open(target, 'w') as handle:
            handle.writelines(f"{value}\n" for value in values)