import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Project root: the parent of the directory that contains this file.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
# Path of the "openml_datasets" folder directly under the project root.
DATA_DIR = os.path.join(PROJECT_ROOT, "openml_datasets")
def load_openml_processed(file_path, label_column='label'):
    """
    Read a CSV file and split it into features and labels.

    Parameters
    ----------
    file_path
        Anything accepted by ``pandas.read_csv`` (path or file-like object).
    label_column : str
        Name of the column holding the target values.

    Returns
    -------
    (features, labels)
        ``features`` is the DataFrame without ``label_column``;
        ``labels`` is that column as a Series.
    """
    frame = pd.read_csv(file_path)
    features = frame.drop(labels=[label_column], axis=1)
    labels = frame[label_column]
    return features, labels

def compute_spectrum(x, kmax=200, nk=1000):
    """
    Compute the magnitude of the (regular) DFT of a 1D signal.

    The signal is mean-centered before the transform, and only the strictly
    positive frequency bins are returned. The DC bin is dropped because
    centering makes it (numerically) zero.

    Parameters
    ----------
    x : array-like
        1D signal samples.
    kmax, nk
        Currently unused by this implementation; kept in the signature so
        existing callers that pass them keep working.

    Returns
    -------
    freqs : np.ndarray
        Frequencies in cycles per sample, ascending: 1/n, 2/n, ..., (n//2)/n.
        For even n the last bin is the Nyquist frequency, reported as +0.5.
    magnitude : np.ndarray
        |DFT| at those frequencies; same length as ``freqs`` (n // 2).
        Both arrays are empty when the signal has fewer than 2 samples.
    """
    x = np.asarray(x)
    n = len(x)
    if n < 2:
        # No positive-frequency bin exists for an empty or single-sample signal.
        return np.empty(0), np.empty(0)

    centered = x - x.mean()
    spectrum = np.fft.fft(centered, n=n)

    # In FFT output order, bins 1..n//2 are the positive frequencies, so they
    # can be taken directly. Sorting the fftfreq axis and slicing the first
    # half would select the negative frequencies instead.
    bins = np.arange(1, n // 2 + 1)
    return bins / n, np.abs(spectrum[bins])

def select_topk_var_features(X, k=10):
    """
    Rank features by variance and return the indices of the k largest.

    Variance is computed with ``X.var(axis=0)`` (one value per column for a
    2D input). Indices are ordered from highest to lowest variance; the
    slice yields fewer than k entries when there are fewer than k features.
    """
    per_feature_var = X.var(axis=0)
    # argsort is ascending, so sort the negated variances to rank descending.
    ranking = np.argsort(np.negative(per_feature_var))
    return ranking[:k]

if __name__ == "__main__":
    # Locate the processed training CSV for the chosen dataset
    # (path is relative to the current working directory).
    dataset = 'default-of-credit-card-clients_categorical'
    data_dir = f'openml_datasets/{dataset}'
    train_data_file = os.path.join(data_dir, 'train_data.csv')

    X, y = load_openml_processed(train_data_file, label_column='label')
    x_tensor = torch.tensor(X.values, dtype=torch.float32)
    print("Shape:", x_tensor.shape, "Labels:", y.shape)

    # Spectral analysis of the top-k highest-variance features.
    topk = 10
    feature_matrix = X.values
    topk_indices = select_topk_var_features(feature_matrix, k=topk)

    # One (freqs, magnitude) pair per selected column.
    per_feature = [
        compute_spectrum(feature_matrix[:, col]) for col in topk_indices
    ]
    spectra = [magnitude for _, magnitude in per_feature]
    avg_spectrum = np.mean(np.stack(spectra), axis=0)
    # Every column has the same number of rows, so the frequency axis of the
    # last column is used for all traces.
    freqs = per_feature[-1][0]

    # Plot the individual spectra (faint) and their average (bold red).
    plt.figure(figsize=(8, 5))
    for i, mag in enumerate(spectra):
        # Only the first individual trace gets a legend entry.
        if i == 0:
            trace_label = f"Feature {topk_indices[i]}"
        else:
            trace_label = None
        plt.plot(freqs, mag, alpha=0.3, label=trace_label)
    plt.plot(
        freqs,
        avg_spectrum,
        color='red',
        linewidth=2,
        label="Average Spectrum (top-k)",
    )
    plt.yscale('log')
    plt.xlabel("Frequency")
    plt.ylabel("Magnitude (log scale)")
    plt.title(f"Average DFT Spectrum for Top-{topk} Variance Features\n{dataset}")
    plt.legend()
    plt.tight_layout()

    # Write the figure to disk before showing it interactively.
    save_dir = "spectral_plots"
    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f"spectrum_{dataset}.png")
    plt.savefig(out_path, dpi=200)
    plt.show()
