import torch 
import numpy as np
from scipy.signal import correlate


def get_device(id=None):
    """Build a CUDA device string.

    Args:
        id: Optional GPU index. When ``None``, the generic ``"cuda"`` device
            string is returned.

    Returns:
        str: ``"cuda"`` if ``id`` is None, otherwise ``"cuda:<id>"``.
    """
    if id is None:
        return "cuda"
    return f"cuda:{id}"

def process_nmse(pt, st, anti_signal, t60, simulator_rir):
    """Compute the normalized mean squared error between ``pt`` and ``st``.

    Each input is flattened to ``(batch, -1)``. If the flattened lengths
    differ, both are truncated to the shorter length. A per-row NMSE
    ``sum((pt - st)**2) / sum(pt**2)`` is computed and averaged over the
    batch.

    Args:
        pt (torch.Tensor): Reference tensor; its first dimension is treated
            as the batch dimension.
        st (torch.Tensor): Estimate tensor with the same batch size as ``pt``.
        anti_signal: Unused; kept for interface compatibility.
        t60: Unused; kept for interface compatibility.
        simulator_rir: Unused; kept for interface compatibility.

    Returns:
        tuple[float, float]: ``(avg_nmse, nmse_dB)``, where
        ``nmse_dB = 10 * log10(avg_nmse)``. A row of ``pt`` with zero energy
        produces a non-finite NMSE (division by zero).
    """
    # reshape (unlike view) also works on non-contiguous tensors,
    # e.g. results of transpose()/permute(); values are identical.
    pt = pt.reshape(pt.size(0), -1)
    st = st.reshape(st.size(0), -1)

    # Truncate both signals to the common (shorter) length.
    n = min(pt.size(1), st.size(1))
    pt = pt[:, :n]
    st = st[:, :n]

    nmse_ = ((pt - st) ** 2).sum(dim=1) / (pt ** 2).sum(dim=1)
    avg_nmse = nmse_.mean().item()

    nmse_dB = 10 * np.log10(avg_nmse)

    return avg_nmse, nmse_dB


def delay_fct(signal1, signal2):
    """
    Estimate the delay between two 1-D signals using cross-correlation and
    circularly shift the second signal to align it with the first.

    Both signals are first normalized to unit L2 norm. A small epsilon guards
    against division by zero for all-zero input.

    Args:
        signal1 (numpy.ndarray): Reference signal (any 1-D array-like).
        signal2 (numpy.ndarray): Signal to align to ``signal1`` (any 1-D
            array-like). It may have a different length than ``signal1``.

    Returns:
        signal1 (numpy.ndarray): ``signal1`` normalized to unit energy.
        signal2_delayed (numpy.ndarray): ``signal2`` normalized to unit energy
            and circularly shifted (``np.roll``, so samples wrap around) by
            ``delay`` samples.
        delay (int): Lag, in samples, at which the absolute cross-correlation
            peaks. A positive value means ``signal1`` lags ``signal2``.
    """
    ### Normalize the signals
    epsilon = 1e-10

    signal1 = np.asarray(signal1)
    signal2 = np.asarray(signal2)
    signal1 = signal1 / (np.sqrt(np.sum(signal1 ** 2)) + epsilon)
    signal2 = signal2 / (np.sqrt(np.sum(signal2 ** 2)) + epsilon)

    ### Compute the cross-correlation between the two signals
    correlation = correlate(signal1, signal2, mode='full')
    # In 'full' mode, output index k corresponds to lag k - (len(signal2) - 1)
    # (cf. scipy.signal.correlation_lags). Zero lag therefore depends on the
    # length of the SECOND input; using len(signal1) is wrong when the
    # lengths differ. abs() also accepts a polarity-inverted (negative) peak.
    delay = int(np.argmax(np.abs(correlation))) - (len(signal2) - 1)

    # Apply the delay to the second signal by rolling it (circular shift).
    signal2_delayed = np.roll(signal2, delay)

    return signal1, signal2_delayed, delay