import os 
import numpy as np
import torch
import torchaudio
# import gpuRIR
import pyroomacoustics as pra
import rir_generator
from scipy.signal import deconvolve, group_delay

# Absolute path of the directory containing this module.
file_dir = os.path.split(os.path.abspath(__file__))[0]
# Shared FFT-based convolution; mode="same" keeps the output as long as the
# first (signal) argument.
fftconvolve = torchaudio.transforms.FFTConvolve(mode="same")


# Integer identifiers for acoustic paths; the simulators key their RIRs by
# (t60, path_id). ANSTISIGNAL_ERROR keeps its historical spelling for
# compatibility.
NOISE_REF, NOISE_ERROR, ANSTISIGNAL_ERROR = range(3)

def inverse_filter_v1(h, eps=1e-6):
    """Return a naive frequency-domain inverse of the real filter ``h``.

    Spectral bins whose magnitude is not above ``eps`` are replaced by ``eps``
    before taking the reciprocal, to avoid division by zero.

    Args:
        h: Real filter tensor; the last dimension is time.
        eps: Magnitude threshold / replacement value for near-zero bins.

    Returns:
        Real tensor with the same last-dimension length as ``h``. Circularly
        convolving it with ``h`` gives (approximately) a unit impulse.
    """
    filter_fft = torch.fft.rfft(h)
    safe_fft = torch.where(
        filter_fft.abs() > eps, filter_fft, torch.full_like(filter_fft, eps)
    )
    # Pass n explicitly: irfft otherwise assumes 2*(bins-1) samples, which
    # drops one sample whenever h has odd length.
    return torch.fft.irfft(1 / safe_fft, n=h.size(-1))

def wiener_filter(signal, rir, K=0.001):
    """Deconvolve ``signal`` by ``rir`` with a regularised inverse filter.

    The RIR is normalised to unit sum, then the spectrum of ``signal`` is
    multiplied by ``conj(H) / (|H|^2 + K)``, where ``H`` is the RIR spectrum.

    Args:
        signal: Tensor whose last dimension is time.
        rir: Impulse response; the last dimension is time. It is NOT modified.
        K: Regularisation constant added to ``|H|^2``. Larger values limit
            amplification at frequencies where the RIR is weak.

    Returns:
        Real tensor whose last dimension has
        ``n = signal.size(-1) - rir.size(-1) + 1`` samples. Only the first
        ``n`` samples of ``signal`` enter the FFT (``torch.fft.fft`` trims to
        ``n``).
    """
    # Normalise out-of-place so the caller's tensor (e.g. a cached RIR) is
    # not rescaled as a side effect of every call.
    rir = rir / torch.sum(rir)
    n = signal.size(-1) - rir.size(-1) + 1
    signal_fft = torch.fft.fft(signal, n=n, dim=-1)
    rir_fft = torch.fft.fft(rir, n=n, dim=-1)
    _filter = torch.conj(rir_fft) / (torch.abs(rir_fft) ** 2 + K)
    return torch.fft.ifft(signal_fft * _filter, dim=-1).real

def weiner_wiki(signal, rir, SNR=0.00005):
    """Deconvolve ``signal`` by ``rir`` with the textbook Wiener filter.

    Implements ``G = (1/H) * |H|^2 / (|H|^2 + SNR)`` in the algebraically
    identical form ``G = conj(H) / (|H|^2 + SNR)``. That form stays finite
    where ``H`` has exact zeros, whereas ``1/H`` would produce NaN/inf there.

    Note: despite its name, ``SNR`` sits where the *noise-to-signal* power
    ratio belongs, so larger values mean stronger regularisation.

    Args:
        signal: Tensor whose last dimension is time.
        rir: Impulse response; the last dimension is time.
        SNR: Regularisation constant added to ``|H|^2``.

    Returns:
        Real tensor whose last dimension has
        ``n = signal.size(-1) - rir.size(-1) + 1`` samples. Only the first
        ``n`` samples of ``signal`` enter the FFT.
    """
    n = signal.size(-1) - rir.size(-1) + 1
    H = torch.fft.fft(rir, n=n, dim=-1)
    signal_fft = torch.fft.fft(signal, n=n, dim=-1)
    G = torch.conj(H) / (H.abs() ** 2 + SNR)
    return torch.fft.ifft(signal_fft * G, dim=-1).real

def _simulate(signal_batch, rir, device, padding="same"):
    signal_batch = signal_batch.to(device).unsqueeze(1)
    # signal_batch = torch.tensor(signal_batch).to(device).unsqueeze(1)

    processed_signals = torch.nn.functional.conv1d(signal_batch, rir, padding=padding)
    processed_signals = processed_signals.squeeze(1)
    return processed_signals

def _simulate_v2(signal_batch, rir, device, padding="same"):
    signal_batch = signal_batch.to(device).unsqueeze(1)
    
    # Apply the filter in the forward direction
    processed_signals = torch.nn.functional.conv1d(signal_batch, rir, padding=padding)
    
    # Reverse the filtered signal
    processed_signals = torch.flip(processed_signals, [2])
    
    # Apply the filter again in the forward direction
    processed_signals = torch.nn.functional.conv1d(processed_signals, rir, padding=padding)
    
    # Reverse the signal back to its original order
    processed_signals = torch.flip(processed_signals, [2])
    
    processed_signals = processed_signals.squeeze(1)
    return processed_signals

class RIRGenSimulator:
    """Filter signals through RIRs of a fixed shoebox room from ``rir_generator``.

    Geometry: room 3 x 4 x 2, reference mic at (1.5, 1, 1), loudspeaker at
    (1.5, 2.5, 1), error mic at (1.5, 3, 1). One RIR is precomputed per
    ``(t60, rir_type)`` pair. ``NOISE_ERROR`` is the path from the reference
    mic to the error mic. ``ANSTISIGNAL_ERROR`` is the path from the
    loudspeaker to the error mic.
    """

    def __init__(self, sr, reverbation_times, device, rir_samples=512, hp_filter=False, c=343,v=1):
        """
        Args:
            sr: Sampling rate passed to ``rir_generator`` as ``fs``.
            reverbation_times: Iterable of T60 values to precompute RIRs for.
            device: Device the RIRs are stored on; v=1/v=2 filtering runs there.
            rir_samples: Number of RIR samples (``nsample``).
            hp_filter: Forwarded to ``rir_generator.generate``.
            c: Speed of sound forwarded to ``rir_generator.generate``.
            v: Filtering variant for ``simulate``. 1 selects ``_simulate``,
                2 selects ``_simulate_v2``, anything else the FFT convolution.
        """
        self.sr = sr
        self.device = device
        self.room_dim = [3, 4, 2]
        self.ref_mic = [1.5, 1, 1]
        self.ls_source = [1.5, 2.5, 1]
        self.error_mic = [1.5, 3, 1]
        self.reverbation_times = reverbation_times
        self.rir_length = rir_samples
        self.hp_filter = hp_filter
        self.c = c
        self.v = v
        self.rirs = self.get_rirs()

    def get_rirs(self):
        """Generate all RIRs.

        Returns:
            Dict mapping ``(t60, rir_type)`` to a float32 tensor reshaped to
            (1, 1, taps) and placed on ``self.device``.
        """
        paths = {
            NOISE_ERROR: (self.ref_mic, self.error_mic),
            ANSTISIGNAL_ERROR: (self.ls_source, self.error_mic),
        }
        rirs = {}
        for t60 in self.reverbation_times:
            for rir_type, (pos_src, pos_rcv) in paths.items():
                rir = rir_generator.generate(  # consider hp_filter = False
                    c=self.c,
                    fs=self.sr,
                    s=pos_src,
                    r=[pos_rcv],
                    L=self.room_dim,
                    reverberation_time=t60,
                    nsample=self.rir_length,
                    hp_filter=self.hp_filter,
                )
                rirs[(t60, rir_type)] = (
                    torch.from_numpy(np.squeeze(rir)).to(self.device).view(1, 1, -1).float()
                )
        return rirs

    def simulate(self, signal_batch, t60, signal_type, padding="same"):
        """Filter ``signal_batch`` with the precomputed RIR for ``(t60, signal_type)``.

        Args:
            signal_batch: Batch of signals; presumably (batch, time) — TODO confirm.
            t60: One of the reverberation times given at construction.
            signal_type: ``NOISE_ERROR`` or ``ANSTISIGNAL_ERROR``.
            padding: Forwarded to ``conv1d`` for v=1/v=2; the FFT path ignores it.

        Returns:
            Filtered signals. For v=1/v=2 the result lives on ``self.device``.
            The FFT path computes on, and returns, a CPU tensor.

        Raises:
            KeyError: If no RIR was precomputed for ``(t60, signal_type)``.
        """
        rir = self.rirs[(t60, signal_type)]
        # The RIR already lives on self.device and the helpers move the signal
        # there too, so signal and kernel always share a device.
        if self.v == 1:
            return _simulate(signal_batch, rir, self.device, padding)
        if self.v == 2:
            return _simulate_v2(signal_batch, rir, self.device, padding)
        # FFT path kept on CPU as before; (1, 1, taps) -> (1, taps) broadcasts
        # against the signal batch.
        return fftconvolve(signal_batch.cpu(), rir.cpu().squeeze(0))


class PyRoomSimulator:
    """Filter signals through RIRs of a fixed shoebox room from pyroomacoustics.

    Uses the same geometry and the same ``(t60, rir_type)`` keys as the other
    simulators in this module. RIRs are stored as (1, taps) tensors and
    applied with the module-level ``fftconvolve``.
    """

    def __init__(self, sr, reverbation_times, device, rir_samples=512):
        """
        Args:
            sr: Sampling rate of the simulated room.
            reverbation_times: Iterable of T60 values to precompute RIRs for.
            device: Device the RIR tensors are stored on.
            rir_samples: Number of RIR samples kept per RIR.
        """
        self.sr = sr
        self.device = device
        self.room_dim = np.array([3, 4, 2])
        self.ref_mic = np.array([1.5, 1, 1])
        self.ls_source = np.array([1.5, 2.5, 1])
        self.error_mic = np.array([1.5, 3, 1])
        self.reverbation_times = reverbation_times
        self.rir_length = rir_samples
        self.rirs = self.get_rirs()

    def get_rirs(self):
        """Simulate all RIRs with the image-source method.

        Returns:
            Dict mapping ``(t60, rir_type)`` to a float32 tensor of shape
            (1, taps) on ``self.device``. ``taps`` is at most ``rir_length``
            and is shorter if the simulated RIR is shorter.
        """
        paths = {
            NOISE_ERROR: (self.ref_mic, self.error_mic),
            ANSTISIGNAL_ERROR: (self.ls_source, self.error_mic),
        }
        rirs = {}
        for t60 in self.reverbation_times:
            e_absorption, max_order = pra.inverse_sabine(t60, self.room_dim)
            for rir_type, (pos_src, pos_rcv) in paths.items():
                # A fresh room per path, so each has exactly one source and one mic.
                room = pra.ShoeBox(self.room_dim, fs=self.sr, materials=pra.Material(e_absorption), max_order=max_order)
                room.add_source(pos_src)
                room.add_microphone_array(pra.MicrophoneArray(pos_rcv.reshape((-1, 1)), self.sr))
                room.compute_rir()
                rir = room.rir[0][0]  # mic 0, source 0
                # RIR adjustments for the ISM model, per the pyroomacoustics
                # maintainer (https://github.com/DavidDiazGuerra/gpuRIR/issues/61).
                # The 40-sample skip presumably removes the fractional-delay
                # filter offset; the 1/(4*pi) factor is the spherical spreading
                # gain. Both are taken from that issue — verify there.
                rir_ism = rir[40:40 + self.rir_length] * (1 / (torch.pi * 4))
                rirs[(t60, rir_type)] = torch.from_numpy(np.squeeze(rir_ism)).to(self.device).view(1, -1).float()
        return rirs

    def simulate(self, signal_batch, t60, signal_type):
        """Convolve ``signal_batch`` with the RIR for ``(t60, signal_type)``.

        The signal is moved to the RIR's device first, so the FFT convolution
        never mixes devices. The result lives on that device.

        Raises:
            KeyError: If no RIR was precomputed for ``(t60, signal_type)``.
        """
        rir = self.rirs[(t60, signal_type)]
        return fftconvolve(signal_batch.to(rir.device), rir)

class GPUSimulator:
    """Filter signals through RIRs of a fixed shoebox room simulated with gpuRIR.

    Uses the same geometry and the same ``(t60, rir_type)`` keys as the other
    simulators in this module. ``gpuRIR`` is imported lazily so the module
    stays importable without it; constructing this class requires it.
    """

    def __init__(self, sr, reverbation_times, device, rir_samples=512):
        """
        Args:
            sr: Sampling rate of the simulated RIRs.
            reverbation_times: Iterable of T60 values to precompute RIRs for.
            device: Device the RIRs are stored on and filtering runs on.
            rir_samples: Desired RIR length in samples.

        Raises:
            ImportError: If the optional ``gpuRIR`` package is not installed.
        """
        gpurir = self._import_gpurir()
        self.sr = sr
        self.reverbation_times = reverbation_times
        self.device = device
        self.room_dim = np.array([3, 4, 2])
        # Positions are 2-D arrays with one row per point, as passed to gpuRIR.
        self.ref_mic = np.array([[1.5, 1, 1]])
        self.ls_source = np.array([[1.5, 2.5, 1]])
        self.error_mic = np.array([[1.5, 3, 1]])
        # NOTE: unlike the other simulators, rir_length is in seconds here
        # (it is used as gpuRIR's Tmax).
        self.rir_length = rir_samples / self.sr
        self.nb_img = gpurir.t2n(T=self.rir_length, rooms_sz=self.room_dim)
        self.rirs = self.get_rirs()

    @staticmethod
    def _import_gpurir():
        """Import and return the optional ``gpuRIR`` module."""
        try:
            import gpuRIR
        except ImportError as err:
            raise ImportError("GPUSimulator requires the optional 'gpuRIR' package") from err
        return gpuRIR

    def get_rirs(self):
        """Simulate all RIRs.

        Returns:
            Dict mapping ``(t60, rir_type)`` to a float32 tensor reshaped to
            (1, 1, taps) on ``self.device``.
        """
        gpurir = self._import_gpurir()
        paths = {
            NOISE_ERROR: (self.ref_mic, self.error_mic),
            ANSTISIGNAL_ERROR: (self.ls_source, self.error_mic),
        }
        rirs = {}
        for t60 in self.reverbation_times:
            beta = gpurir.beta_SabineEstimation(self.room_dim, t60)
            for rir_type, (pos_src, pos_rcv) in paths.items():
                rir = gpurir.simulateRIR(self.room_dim, beta, pos_src, pos_rcv, self.nb_img, Tmax=self.rir_length, fs=self.sr)
                rirs[(t60, rir_type)] = torch.from_numpy(np.squeeze(rir)).to(self.device).view(1, 1, -1).float()
        return rirs

    def simulate(self, signal_batch, t60, signal_type, padding="same"):
        """Filter ``signal_batch`` with the RIR for ``(t60, signal_type)`` via ``_simulate``.

        Args:
            signal_batch: Batch of signals; presumably (batch, time) — TODO confirm.
            t60: One of the reverberation times given at construction.
            signal_type: ``NOISE_ERROR`` or ``ANSTISIGNAL_ERROR``.
            padding: Forwarded to ``conv1d`` (default matches the previous behaviour).

        Raises:
            KeyError: If no RIR was precomputed for ``(t60, signal_type)``.
        """
        rir = self.rirs[(t60, signal_type)]
        return _simulate(signal_batch, rir, self.device, padding)