import torch
import torch.nn as nn
from sklearn.decomposition import FastICA, NMF
import librosa
import numpy as np

class FastICA_NMF_Separation(nn.Module):
    """Unsupervised single-channel source separation with ICA or NMF.

    The mixture is transformed with an STFT. Its magnitude is factorised
    into ``speaker_num`` components, and each component's magnitude is
    resynthesised with the mixture's phase through an inverse STFT.
    """

    def __init__(
        self, speaker_num, separation_type, n_fft=256, hop_length=64
    ):
        """
        Args:
            speaker_num: number of components (sources) to extract.
            separation_type: ``"ICA"`` or ``"NMF"``. Any other value raises
                ``NotImplementedError`` in ``forward``.
            n_fft: STFT frame size used by ``forward``.
            hop_length: STFT hop size used by ``forward`` and its inverse.
        """
        super().__init__()
        self.speaker_num = speaker_num
        self.separation_type = separation_type
        self.n_fft = n_fft
        self.hop_length = hop_length

    @staticmethod
    def _ica_component_magnitudes(sources, mixing):
        """Build one (freq, time) magnitude spectrogram per ICA component.

        Args:
            sources: ``(time, n_components)`` output of ``fit_transform``.
            mixing: ``(freq, n_components)`` mixing matrix (``ica.mixing_``).

        Component ``k`` contributes ``outer(mixing[:, k], sources[:, k])``
        to the centred magnitude. ICA signs are arbitrary, so the absolute
        value is taken to obtain a valid (non-negative) magnitude.
        The per-frequency mean removed by ICA is not redistributed.
        """
        return [
            np.abs(np.outer(mixing[:, k], sources[:, k]))
            for k in range(sources.shape[1])
        ]

    @staticmethod
    def _nmf_component_magnitudes(W, H):
        """Build one (freq, time) magnitude per NMF component (W[:, k] x H[k])."""
        return [np.outer(W[:, k], H[k, :]) for k in range(W.shape[1])]

    @staticmethod
    def _resynthesize(magnitudes, spectrogram, length=None, hop_length=None):
        """Combine each magnitude with the mixture phase and invert the STFT.

        ``hop_length=None`` keeps librosa's default (``n_fft // 4``, with
        ``n_fft`` inferred from the spectrogram). ``length`` pads or trims
        each output to exactly that many samples.
        Returns a tensor of shape ``(n_components, samples)``.
        """
        # The phase term is identical for every component, so compute it once.
        phase = np.exp(1j * np.angle(spectrogram))
        waveforms = [
            librosa.istft(mag * phase, hop_length=hop_length, length=length)
            for mag in magnitudes
        ]
        return torch.from_numpy(np.stack(waveforms))

    def separate_ICA(self, spectrogram, length=None, hop_length=None):
        """Separate a (freq, time) complex STFT into components with FastICA."""
        magnitude = np.abs(spectrogram)
        ica = FastICA(n_components=self.speaker_num, random_state=42)
        # sklearn expects (n_samples, n_features): time frames are the samples.
        sources = ica.fit_transform(magnitude.T)  # (time, n_components)
        magnitudes = self._ica_component_magnitudes(sources, ica.mixing_)
        return self._resynthesize(magnitudes, spectrogram, length, hop_length)

    def separate_NMF(self, spectrogram, length=None, hop_length=None):
        """Separate a (freq, time) complex STFT into components with NMF."""
        magnitude = np.abs(spectrogram)
        model = NMF(n_components=self.speaker_num, max_iter=400, random_state=42)
        W = model.fit_transform(magnitude)  # (freq, n_components) spectral bases
        H = model.components_  # (n_components, time) activations
        magnitudes = self._nmf_component_magnitudes(W, H)
        return self._resynthesize(magnitudes, spectrogram, length, hop_length)

    def forward(
        self,
        mixture: torch.Tensor,
    ) -> torch.Tensor:
        """Separate a single-channel mixture.

        Args:
            mixture: waveform tensor. After squeezing singleton dims it must
                be 1-D.

        Returns:
            CPU tensor of shape ``(speaker_num, len(mixture))``.

        Raises:
            NotImplementedError: unsupported ``separation_type``.
            ValueError: the squeezed mixture is not 1-D.
        """
        separators = {"ICA": self.separate_ICA, "NMF": self.separate_NMF}
        if self.separation_type not in separators:
            raise NotImplementedError(self.separation_type)

        # detach(): Tensor.numpy() refuses tensors that require grad.
        signal = mixture.detach().cpu().numpy().squeeze()
        if signal.ndim != 1:
            raise ValueError(
                f"expected a single-channel mixture, got shape {tuple(mixture.shape)}"
            )

        spectrogram = librosa.stft(
            signal, n_fft=self.n_fft, hop_length=self.hop_length
        )
        # Pass the original length so the output is not truncated to a hop multiple.
        return separators[self.separation_type](
            spectrogram, length=signal.shape[-1], hop_length=self.hop_length
        )

