import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from support_based_bound import support_boundary_calc
from information_theoretic_bound import Mi_bound
from tools.simulator import PyRoomSimulator, RIRGenSimulator, ANSTISIGNAL_ERROR, NOISE_ERROR
import torchaudio
from utils import get_device, process_nmse, delay_fct
import torch
import numpy as np


if __name__ == "__main__":
    device = get_device()
    
    reverberation_times = [0.15, 0.175, 0.2, 0.225, 0.25]
    t60 = 0.2  # Example T60 value
    sr = 16000 # sampling rate

    # Example usage

    # Data extraction
    __path_to_NoiseX_data = "path_to_your_data" # To fill    
    data_file_name =  __path_to_NoiseX_data + ".wav" 

    
    # Load the signal
    signal = torchaudio.load(data_file_name)[0].to(device)

    # Define the signal segment
    start = 0
    end = start + 3 # 3 seconds
    step = 3
    signal_segments = [signal[:, sr * (start + step * i):sr * (end + step * i)] for i in range(int(signal.shape[1] / sr / step))]
    
    # As an example, we consider only 50 samples of 3 second. Replicate it to more samples to get a better estimation of the bound
    signal_segments = signal_segments[:50]

    # Define the simulator
    simulator_rir = RIRGenSimulator(sr, reverberation_times, device, hp_filter=True, v=3)

    # Load ANC model 
    model = "model_loaded" # To fill


    pt_list = []
    y_list = []
    nmse_list = []
    sup_bound_list = []
    mi_bound_list = []

    
    for signal_segment in signal_segments:
        # print(signal_segment.shape)        

        # Primary path
        pt = simulator_rir.simulate(signal_segment, t60, NOISE_ERROR).to(device)
        pt_list.append(pt.cpu().numpy()[0])
    
        eta = 'inf'

        # Model prediction 
        y, anti_signal = "model prediction" # To fill 
        y, anti_signal = y.to(torch.float32), anti_signal.to(torch.float32)


        # Secondary path calculation
        st = simulator_rir.simulate(anti_signal, t60, ANSTISIGNAL_ERROR).to(device)

        ###NMSE calculation

        avg_nmse, nmse_dB = process_nmse(pt, st, anti_signal, t60, simulator_rir)
        # print(f"NMSE: {avg_nmse}, NMSE (dB): {nmse_dB}")
        
        # Store the NMSE
        nmse_list.append(nmse_dB)
        

        ### Support Bound calculation

        # Call the support_boundary_calc function
        bound, bound_dB = support_boundary_calc(t60, simulator_rir)
        # print(f"Bound: {bound}, Bound (dB): {bound_dB}")
        sup_bound_list.append(bound_dB)

        # Compute the delay between pt and y
        pt, y, delay = delay_fct(pt.cpu().numpy()[0], y.cpu().numpy()[0]) 

        y_list.append(y)



    print("NMSE (dB):", np.mean(nmse_list))

    print("Support Bound (dB):", np.mean(sup_bound_list))

    ### Information Theoretic Bound calculation
    mi_bound, mi_bound_dB = Mi_bound(pt_list, y_list, t60 )
    print(f"Mutual Information Bound: {mi_bound}, Mutual Information Bound (dB): {mi_bound_dB}")

    ### Unified Bound calculation
    print("Unified Bound (dB):", np.max([mi_bound_dB, np.mean(sup_bound_list)]))



