
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

def get_pseudolabels(x, D, is_mixed, label_type='coherence', sampling_rate=128, temperature=1.0):
    """
    Build coherence pseudo-labels by dispatching to one of the coherence helpers.

    Args:
        x: Input signal, forwarded unchanged to the coherence helpers.
        D: Divisor forwarded unchanged to the coherence helpers.
        is_mixed: When True, return the mean of three band-limited coherence
            labels (8-13, 13-30 and 30-100, in units of ``sampling_rate``).
            Only ``label_type='coherence'`` is supported in this mode.
        label_type: ``'coherence'`` for band-limited coherence (the band helper's
            default band when ``is_mixed`` is False), or ``'all_coherence'`` for
            coherence over every FFT bin (non-mixed mode only).
        sampling_rate: Forwarded to the band-limited helper.
        temperature: Softmax temperature forwarded to every helper.

    Returns:
        Whatever the selected helper returns (averaged over bands in mixed mode).

    Raises:
        NotImplementedError: for any unsupported ``label_type``/``is_mixed`` combination.
    """
    if not is_mixed:
        if label_type == 'coherence':
            return coherence_pseudolabels_band(x, D, sampling_rate=sampling_rate, temperature=temperature)
        if label_type == 'all_coherence':
            return coherence_pseudolabels(x, D, temperature=temperature)
        raise NotImplementedError

    if label_type != 'coherence':
        raise NotImplementedError

    # Accumulate band labels left-to-right, then average over the three bands.
    summed = None
    for band_low, band_high in ((8, 13), (13, 30), (30, 100)):
        band_label = coherence_pseudolabels_band(
            x, D,
            low_freq=band_low,
            high_freq=band_high,
            sampling_rate=sampling_rate,
            temperature=temperature,
        )
        summed = band_label if summed is None else summed + band_label
    return summed / 3.
    
def coherence_pseudolabels(x, D, temperature=1.0):
    """
    Calculate coherence across EEG channels using every FFT bin.

    For each channel pair (i, j) the cross-spectrum X_i * conj(X_j) is averaged
    over all FFT bins, giving S_ij. The magnitude-squared coherence
    |S_ij|^2 / (S_ii * S_jj) is then turned into a per-row distribution over
    channels with a temperature-scaled softmax.

    Note: because the average runs over *all* bins, by Parseval's theorem S_ij
    equals the time-domain inner product of channels i and j. The FFT form is
    kept to mirror the band-limited variant.

    Args:
        x: Tensor of shape [B, C, N] (assumed batch, channels, time samples).
        D: Divisor of N; the label is repeated N // D times along dim 1.
        temperature: Softmax temperature applied to the coherence values.

    Returns:
        Tensor of shape [B * C, N // D, C]. Row b * C + i holds the softmax over
        channels j of the coherence between channels i and j of batch item b.
    """
    B, C, n_samples = x.shape
    n_steps = n_samples // D

    x_fft = torch.fft.fft(x, dim=2)  # [B, C, n_samples]
    n_bins = x_fft.shape[2]

    # Bin-averaged cross-spectral density S = X @ X^H / n_bins  -> [B, C, C].
    # One batched matmul avoids materialising a [B, C, C, n_samples] tensor.
    cross_spectral_avg = x_fft @ x_fft.conj().transpose(1, 2) / n_bins

    # The diagonal is the power spectral density of each channel. It is real
    # and non-negative, so only the real part is kept.
    power_spectral = torch.diagonal(cross_spectral_avg, dim1=1, dim2=2).real  # [B, C]

    # Coherence(i, j) = |S_ij|^2 / (S_ii * S_jj); the 1e-8 guards silent channels.
    power_outer = power_spectral.unsqueeze(-1) * power_spectral.unsqueeze(1)  # [B, C, C]
    coherence_matrices = torch.abs(cross_spectral_avg) ** 2 / (torch.abs(power_outer) + 1e-8)

    idx = F.softmax(coherence_matrices / temperature, dim=-1)

    return idx.reshape(B * C, 1, C).repeat(1, n_steps, 1)

def coherence_pseudolabels_band(x, D, temperature=1.0, low_freq=8, high_freq=100, sampling_rate=128):
    """
    Calculate band-limited coherence across EEG channels.

    For each channel pair (i, j) the cross-spectrum X_i * conj(X_j) is averaged
    over the FFT bins whose absolute frequency lies in [low_freq, high_freq]
    (both ends inclusive), giving S_ij. The magnitude-squared coherence
    |S_ij|^2 / (S_ii * S_jj) is then turned into a per-row distribution over
    channels with a temperature-scaled softmax.

    Args:
        x: Tensor of shape [B, C, N] (assumed batch, channels, time samples).
            N need not be a multiple of D.
        D: Divisor of N; the label is repeated N // D times along dim 1.
        temperature: Softmax temperature applied to the coherence values.
        low_freq: Lower band edge, in the same units as ``sampling_rate``.
        high_freq: Upper band edge, in the same units as ``sampling_rate``.
        sampling_rate: Sampling rate of ``x``; maps FFT bins to frequencies.

    Returns:
        Tensor of shape [B * C, N // D, C]. Row b * C + i holds the softmax over
        channels j of the band coherence between channels i and j of batch item b.

    Raises:
        ValueError: if no FFT bin falls inside [low_freq, high_freq].
    """
    B, C, n_samples = x.shape
    n_steps = n_samples // D

    x_fft = torch.fft.fft(x, dim=2)  # [B, C, n_samples]

    # The frequency axis must match the actual FFT length (n_samples). Using
    # (n_samples // D) * D would give a mask of the wrong length whenever
    # n_samples is not a multiple of D.
    freqs = torch.fft.fftfreq(n_samples, d=1 / sampling_rate, device=x.device)
    abs_freqs = torch.abs(freqs)
    # NOTE(review): the mask keeps both +f and -f bins. For real-valued x,
    # X(-f) = conj(X(f)), so the band average of the cross-spectrum is real:
    # only its in-phase (real) part contributes to the coherence below.
    freq_mask = (abs_freqs >= low_freq) & (abs_freqs <= high_freq)

    x_fft_band = x_fft[:, :, freq_mask]  # [B, C, n_freq]
    n_freq = x_fft_band.shape[2]

    if n_freq == 0:
        # Raise explicitly: an `assert` is stripped under `python -O`, and the
        # mean over zero bins would then silently produce NaNs.
        raise ValueError(
            f"No FFT bins in band [{low_freq}, {high_freq}] for "
            f"n_samples={n_samples}, sampling_rate={sampling_rate}"
        )

    # Band-averaged cross-spectral density S = X_band @ X_band^H / n_freq -> [B, C, C].
    # One batched matmul avoids materialising a [B, C, C, n_freq] tensor.
    cross_spectral_avg = x_fft_band @ x_fft_band.conj().transpose(1, 2) / n_freq

    # The diagonal is the per-channel power spectral density. It is real and
    # non-negative, so only the real part is kept.
    power_spectral = torch.diagonal(cross_spectral_avg, dim1=1, dim2=2).real  # [B, C]

    # Coherence(i, j) = |S_ij|^2 / (S_ii * S_jj); the 1e-8 guards silent channels.
    power_outer = power_spectral.unsqueeze(-1) * power_spectral.unsqueeze(1)  # [B, C, C]
    coherence_matrices = torch.abs(cross_spectral_avg) ** 2 / (torch.abs(power_outer) + 1e-8)

    idx = F.softmax(coherence_matrices / temperature, dim=-1)

    return idx.reshape(B * C, 1, C).repeat(1, n_steps, 1)