import numpy as np
import tqdm as tqdm
import torch
from scipy.stats import gaussian_kde
from simulator import RIRGenSimulator, ANSTISIGNAL_ERROR, NOISE_ERROR
from utils import get_device



# Device returned by the project helper `get_device` (selection logic lives in utils; not visible here).
device = get_device()
# Reverberation times handed to the simulator; Mi_bound uses a t60 value as part of an RIR lookup key.
# Units presumably seconds — TODO confirm against simulator.RIRGenSimulator.
reverberation_times = [0.15, 0.175, 0.2, 0.225, 0.25]
# Sample rate passed to the simulator (presumably Hz — TODO confirm).
sr = 16000

# Shared simulator instance built once at import time. Mi_bound reads its `rirs` mapping
# with keys of the form (t60, NOISE_ERROR). The meaning of hp_filter / v is defined in
# simulator.RIRGenSimulator and is not visible here.
simulator_rir = RIRGenSimulator(sr, reverberation_times, device, hp_filter=True, v=3)

# Number of grid points per axis that Mi_bound requests from PDF_calc.
b_count = 7
# Small constant added to denominators (PDF normalization) and to the log10 argument in Mi_bound.
epsilon = 1e-10

def _squeezed_numpy(seq):
    """Return *seq* (numpy array, sequence, or torch tensor) as a squeezed numpy array."""
    if isinstance(seq, torch.Tensor):
        # detach()/cpu() so tensors that require grad or live on an accelerator convert cleanly.
        seq = seq.detach().cpu().numpy()
    return np.squeeze(np.asarray(seq))


def PDF_calc(d_list, y_list, bin_count=7, grid_range=(-0.1, 0.1)):
    """
    Estimate averaged marginal and joint PDFs of paired sequences with Gaussian KDE.

    For every pair ``(d_list[i], y_list[i])`` the method fits a 1-D KDE to each
    sequence and a 2-D KDE to the pair. It evaluates the three densities on a
    fixed ``bin_count``-point grid spanning ``grid_range``, renormalizes each to
    integrate to ~1 on that grid (trapezoidal rule), and averages the results
    over all processed pairs.

    Args:
        d_list (list): Sequences (numpy arrays, lists, or torch tensors) for the
            first variable. Singleton dimensions are squeezed away.
        y_list (list): Sequences for the second variable, paired element-wise with
            ``d_list``. Extra entries in the longer list are ignored.
        bin_count (int): Number of grid points per axis. ``1 / bin_count`` is also
            the scalar KDE bandwidth factor.
        grid_range (tuple): ``(low, high)`` bounds of the evaluation grid, used for
            both axes.
    Returns:
        pdf1 (numpy.ndarray): Averaged marginal PDF of the first variable, shape (bin_count,).
        pdf2 (numpy.ndarray): Averaged marginal PDF of the second variable, shape (bin_count,).
        pdf (numpy.ndarray): Averaged joint PDF, shape (bin_count, bin_count). Axis 0
            follows the second variable's grid and axis 1 the first's
            (``np.meshgrid`` 'xy' layout).
    Raises:
        ValueError: If there is no (d, y) pair to process.
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0 (the old name is deprecated).
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    # The evaluation grids are the same for every pair, so build them once.
    low, high = grid_range
    x_grid = np.linspace(low, high, bin_count)
    y_grid = np.linspace(low, high, bin_count)
    # 'xy' indexing: X_grid / Y_grid have shape (len(y_grid), len(x_grid)),
    # so axis 0 runs along y and axis 1 runs along x.
    X_grid, Y_grid = np.meshgrid(x_grid, y_grid)
    joint_points = np.vstack([X_grid.ravel(), Y_grid.ravel()])

    # A scalar bw_method is used directly as the KDE bandwidth factor.
    bw = 1 / bin_count

    pdf = np.zeros((bin_count, bin_count))
    pdf1 = np.zeros(bin_count)
    pdf2 = np.zeros(bin_count)
    n_pairs = 0

    for subsequence1, subsequence2 in zip(d_list, y_list):
        # Squeeze before aligning lengths; otherwise a (1, N) input would be truncated to 1 sample.
        seq1 = _squeezed_numpy(subsequence1)
        seq2 = _squeezed_numpy(subsequence2)
        # The joint KDE needs paired samples, so cut both to the shorter length.
        n = min(seq1.shape[0], seq2.shape[0])
        seq1, seq2 = seq1[:n], seq2[:n]

        kde = gaussian_kde(np.vstack([seq1, seq2]), bw_method=bw)
        kde1 = gaussian_kde(seq1, bw_method=bw)
        kde2 = gaussian_kde(seq2, bw_method=bw)

        marginal_pdf_1 = kde1.evaluate(x_grid)
        marginal_pdf_2 = kde2.evaluate(y_grid)
        joint_pdf_values = kde.evaluate(joint_points).reshape(X_grid.shape)

        # Renormalize on the finite grid. epsilon guards against a zero integral.
        marginal_pdf_1 /= trapezoid(marginal_pdf_1, x_grid) + epsilon
        marginal_pdf_2 /= trapezoid(marginal_pdf_2, y_grid) + epsilon
        # Integrate along x (axis 1) first, then along y.
        joint_pdf_values /= trapezoid(trapezoid(joint_pdf_values, x_grid, axis=1), y_grid) + epsilon

        pdf1 += marginal_pdf_1
        pdf2 += marginal_pdf_2
        pdf += joint_pdf_values
        n_pairs += 1

    if n_pairs == 0:
        raise ValueError("PDF_calc needs at least one (d, y) sequence pair")

    # Average over the pairs actually processed (zip stops at the shorter list).
    pdf /= n_pairs
    pdf1 /= n_pairs
    pdf2 /= n_pairs

    return pdf1, pdf2, pdf


def Mi_bound(pt_list, y_list, t60):
    """
    Compute a mutual-information-scaled energy value from paired sequences and a simulated RIR.

    ``PDF_calc`` (with ``b_count`` grid points) turns the sequences into three
    PDFs, and each is rescaled to sum to one. The function then takes their
    discrete Shannon entropies (natural log) and forms the mutual information
    ``I = H(X) + H(Y) - H(X, Y)``. The coefficient ``1 - I / H(X)`` multiplies
    the energy (sum of squared magnitudes) of the RIR stored in
    ``simulator_rir.rirs[(t60, NOISE_ERROR)]``.

    Args:
        pt_list (list): Sequences for the first variable X.
        y_list (list): Sequences for the second variable Y, paired with ``pt_list``.
        t60: Reverberation time used in the RIR lookup key (presumably one of
            the module's ``reverberation_times`` — TODO confirm).
    Returns:
        tuple: ``(mi_bound, mi_bound_dB)``, the linear value and
        ``10 * log10(mi_bound + epsilon)``.

    NOTE(review): the coefficient is not clamped. If the KDE-derived I exceeds
    H(X), the value turns negative and its dB form becomes NaN.
    """
    log_floor = np.finfo(float).eps

    def entropy(dist):
        # Discrete Shannon entropy. log_floor keeps log() finite at empty bins.
        return -np.sum(dist * np.log(dist + log_floor))

    raw_x, raw_y, raw_xy = PDF_calc(pt_list, y_list, bin_count=b_count)
    # Turn each grid-sampled density into a discrete distribution.
    p_x, p_y, p_xy = (raw / np.sum(raw) for raw in (raw_x, raw_y, raw_xy))

    h_x = entropy(p_x)
    mutual_information = h_x + entropy(p_y) - entropy(p_xy)
    coeff = 1 - mutual_information / h_x

    rir = simulator_rir.rirs[(t60, NOISE_ERROR)]
    primary_path = rir.view(-1).cpu().numpy()
    energy = np.sum(np.abs(primary_path) ** 2)

    bound = coeff * energy
    return bound, 10 * np.log10(bound + epsilon)