from re import L
import numpy as np
from pyparsing import line
import torch
from utils.binary import high_order_residual
from utils.mask import generate_structural_mask


def error_computing(origin_matrix, quantized_matrix):
    """Return the mean squared error between two tensors.

    Args:
        origin_matrix: Reference tensor.
        quantized_matrix: Tensor compared against the reference; it is
            subtracted from ``origin_matrix`` elementwise.

    Returns:
        A 0-dim tensor holding the mean of the squared elementwise difference.
    """
    residual = origin_matrix - quantized_matrix
    return torch.mean(residual.pow(2))


def calculate_percentage_and_vasince_original(weights, abs_weights, bin_edges):
    """Compute per-bin share, variance and cumulative share of the weights.

    Consecutive entries of ``bin_edges`` define half-open bins
    ``[lower, upper)``. Bin membership is decided on ``abs_weights``; the
    selected values are taken from ``weights``.

    Args:
        weights: Tensor the binned values are read from.
        abs_weights: Tensor used for the bin test, indexed in lockstep with
            ``weights``.
        bin_edges: Sequence of bin boundaries.

    Returns:
        A tuple ``(percentages, vasinces, accum_percentages)``:
        ``percentages[i]`` is the share (0-100) of all elements that fall in
        bin ``i``; ``vasinces[i]`` is ``torch.var`` of the values in bin
        ``i``; ``accum_percentages`` is the running sum of the percentages,
        starting with a leading ``0``.
    """
    element_count = abs_weights.numel()
    running_total = 0
    shares = []
    spreads = []
    cumulative = [running_total]

    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (abs_weights >= lower) & (abs_weights < upper)
        selected = weights[in_bin]

        share = selected.numel() / element_count * 100
        running_total = running_total + share

        shares.append(share)
        cumulative.append(running_total)
        spreads.append(torch.var(selected))

    return shares, spreads, cumulative


"""
Include main method to search the rate for 2-bit salient data columns and the optimal split for 1-bit data
"""


def _salient_column_mask(origin_matrix, columns):
    """Return a boolean mask over ``origin_matrix`` that is True on ``columns``."""
    mask = torch.zeros(
        (origin_matrix.shape[0], origin_matrix.shape[1]),
        dtype=torch.bool,
        device=origin_matrix.device,
    )
    mask[:, columns] = True
    return mask


def structural_searching(origin_matrix, up_lim=30):
    """Search the salient-column count and the two magnitude split points.

    Stage 1 ranks columns by the sum of their absolute values and, for every
    ``i`` in ``1 .. up_lim - 1``, marks the top ``i`` columns as salient. Both
    the salient part and its complement are passed through
    ``high_order_residual(..., order=2)`` and the ``i`` with the lowest
    reconstruction MSE is kept.

    Stage 2 fixes that salient mask and scans 160 quantiles (10% .. 90%) of
    the absolute values of the remaining entries as candidate
    ``split_value_1``, with ``split_value_2 = 2 * split_value_1``. Candidates
    whose ``split_value_2`` exceeds the 90% quantile are skipped. The pair with
    the lowest reconstruction MSE is returned.

    Args:
        origin_matrix: 2-D tensor to analyse.
        up_lim: Exclusive upper bound on the number of salient columns tried.
            It is clamped to the number of columns so that ``torch.topk``
            does not fail on narrow matrices.

    Returns:
        ``(optimal_split_1, optimal_split_2, mask4)`` where ``mask4`` is the
        boolean mask of the selected salient columns.

    Raises:
        ValueError: If fewer than one salient-column candidate can be tried,
            or if either search stage yields no candidate with a comparable
            (non-NaN) error.
    """
    column_scores = origin_matrix.abs().sum(dim=0)

    # topk raises when k is larger than the number of columns.
    search_limit = min(up_lim, origin_matrix.shape[1])
    if search_limit < 2:
        raise ValueError(
            "structural_searching needs up_lim >= 2 and at least 2 columns, "
            f"got up_lim={up_lim} and {origin_matrix.shape[1]} column(s)"
        )

    # --- Stage 1: how many top columns get the order-2 treatment ---------
    _, ranked_columns = torch.topk(column_scores, search_limit)

    minimal_value_0 = float("inf")
    optimal_split_0 = None
    for count in range(1, search_limit):
        candidate_mask = _salient_column_mask(origin_matrix, ranked_columns[:count])
        group4 = high_order_residual(origin_matrix, candidate_mask, order=2)
        group1_2_3 = high_order_residual(origin_matrix, ~candidate_mask, order=2)

        quantize_error_0 = error_computing(origin_matrix, group1_2_3 + group4)
        if quantize_error_0 < minimal_value_0:
            minimal_value_0 = quantize_error_0
            optimal_split_0 = count

    if optimal_split_0 is None:
        raise ValueError(
            "structural_searching found no salient-column count with a "
            "comparable error (all candidate errors were NaN?)"
        )

    # Reuse the exact column set that was evaluated as best instead of
    # issuing a second topk call.
    mask4 = _salient_column_mask(origin_matrix, ranked_columns[:optimal_split_0])
    group4 = high_order_residual(origin_matrix, mask4, order=2)

    # --- Stage 2: magnitude split points for the non-salient entries -----
    search_matrix = origin_matrix * (~mask4)

    flat_abs_tensor = torch.abs(search_matrix).view(-1)
    percentiles = torch.linspace(0.10, 0.90, 160).to(origin_matrix.device)
    percentile_values = torch.tensor(
        np.quantile(
            flat_abs_tensor.detach().cpu().numpy(),
            q=percentiles.cpu().numpy(),
            axis=None,
            keepdims=False,
        )
    ).to(origin_matrix.device)

    minimal_value = float("inf")
    optimal_split_1 = None
    optimal_split_2 = None
    largest_candidate = percentile_values[-1]
    for split_value_1 in percentile_values:
        split_value_2 = 2 * split_value_1
        if split_value_2 > largest_candidate:
            continue

        mask1, mask2, mask3 = generate_structural_mask(
            search_matrix, mask4, split_value_1, split_value_2
        )
        group1 = high_order_residual(origin_matrix, mask1, order=1)
        group2 = high_order_residual(origin_matrix, mask2, order=1)
        group3 = high_order_residual(origin_matrix, mask3, order=1)

        quantize_error = error_computing(
            origin_matrix, group1 + group2 + group3 + group4
        )
        if quantize_error < minimal_value:
            minimal_value = quantize_error
            optimal_split_1 = split_value_1
            optimal_split_2 = split_value_2

    if optimal_split_1 is None:
        raise ValueError(
            "structural_searching found no valid split: every candidate was "
            "skipped (2 * split exceeded the 90% quantile) or had a NaN error"
        )

    return optimal_split_1, optimal_split_2, mask4


def find_optimal_split(group_max, origin_matrix, border):
    """Scan for the magnitude threshold that minimises reconstruction error.

    Entries with ``|x| > border`` are always handled by
    ``high_order_residual(..., order=2)``. The remaining entries are divided
    at a candidate ``split_value`` into ``split_value < |x| <= border`` and
    ``|x| <= split_value``, each handled with ``order=1``. Candidates are
    ``group_max`` scaled by ``0.10, 0.11, ...`` up to (but excluding) ``0.80``.

    Args:
        group_max: Scale applied to the candidate fractions.
        origin_matrix: Tensor to approximate.
        border: Magnitude above which entries belong to the order-2 group.

    Returns:
        ``(optimal_split, minimal_value)``: the best candidate threshold and
        its MSE. If no candidate produces an error below infinity (e.g. all
        errors are NaN), returns ``(None, float("inf"))``.
    """
    # Loop-invariant pieces, computed once instead of for every candidate.
    # NOTE(review): assumes high_order_residual does not modify origin_matrix
    # in place - confirm against utils.binary.
    magnitude = torch.abs(origin_matrix)
    within_border = magnitude <= border

    searching_steps = torch.arange(0.1, 0.8, 0.01)
    searching_steps = searching_steps * group_max

    group3 = high_order_residual(origin_matrix, magnitude > border, order=2)

    optimal_split = None
    minimal_value = float("inf")
    for split_value in searching_steps:
        group1 = high_order_residual(
            origin_matrix,
            (magnitude > split_value) & within_border,
            order=1,
        )
        group2 = high_order_residual(
            origin_matrix, magnitude <= split_value, order=1
        )

        quantize_error = error_computing(origin_matrix, group1 + group2 + group3)
        if quantize_error < minimal_value:
            minimal_value = quantize_error
            optimal_split = split_value

    return optimal_split, minimal_value
