#!/usr/bin/env python3
"""
Compute global statistics for EEG data normalization.
Computes mean and std per channel across all trials.
"""

import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from src.data.eeg_sampler import EEGSampler

def _accumulate_channel_moments(arrays, n_channels):
    """Merge per-trial blocks into running per-channel (count, mean, M2).

    Uses the pairwise update of Chan et al. so the variance comes from
    centred sums of squares rather than E[x^2] - E[x]^2, which loses
    precision (and can even go negative) when |mean| is large relative to
    the spread. All arithmetic is done in float64.

    Args:
        arrays: iterable of 2-D array-likes shaped (n_samples, n_channels).
        n_channels: number of columns every array must have.

    Returns:
        (count, mean, m2): total number of rows seen, per-channel mean, and
        per-channel sum of squared deviations from that mean.

    Raises:
        ValueError: if an array is not 2-D with exactly ``n_channels`` columns.
    """
    count = 0
    mean = np.zeros(n_channels, dtype=np.float64)
    m2 = np.zeros(n_channels, dtype=np.float64)
    for arr in arrays:
        block = np.asarray(arr, dtype=np.float64)
        if block.ndim != 2 or block.shape[1] != n_channels:
            raise ValueError(
                f"Expected trial data of shape (n_samples, {n_channels}), "
                f"got {block.shape}"
            )
        n_new = block.shape[0]
        if n_new == 0:
            continue
        block_mean = block.mean(axis=0)
        block_m2 = ((block - block_mean) ** 2).sum(axis=0)
        total = count + n_new
        delta = block_mean - mean
        # Combine the running moments with this block's moments.
        mean = mean + delta * (n_new / total)
        m2 = m2 + block_m2 + delta ** 2 * (count * n_new / total)
        count = total
    return count, mean, m2


def compute_global_stats(
    data_path="data/eeg",
    output_path="data/eeg_normalization_stats.json",
):
    """Compute the mean and std of each channel over the EEG training trials.

    Statistics are taken over every time point of every trial in the
    ``train`` subset. They are printed and then written to ``output_path``
    as JSON.

    Args:
        data_path: data directory passed to ``EEGSampler``.
        output_path: destination of the JSON stats file. Missing parent
            directories are created.

    Returns:
        (channel_means, channel_stds): float64 numpy arrays with one entry per
        channel. ``channel_stds`` is the population (ddof=0) standard deviation.

    Raises:
        ValueError: if the trials contain no samples, or if a trial's channel
            count does not match the channel list below.
    """
    import json

    channels = ["FZ", "F1", "F2", "F3", "F4", "F5", "F6"]

    # We'll collect statistics from the training set only
    sampler = EEGSampler(
        data_path=data_path,
        subset="train",
        mode="interpolation",
        batch_size=32,
        num_tasks=1000,
        total_points=256,
        device="cpu",
        dtype=torch.float32,
        seed=42
    )

    print(f"Computing statistics from {len(sampler.trials)} training trials...")

    count, channel_means, m2 = _accumulate_channel_moments(
        (
            trial['data'].values
            for trial in tqdm(sampler.trials, desc="Computing statistics")
        ),
        n_channels=len(channels),
    )
    if count == 0:
        # Without samples the statistics would be NaN; fail loudly instead.
        raise ValueError(
            "No samples found in the training trials; cannot compute statistics."
        )

    channel_vars = m2 / count
    channel_stds = np.sqrt(channel_vars)

    print("\n=== Global Statistics (per channel) ===")
    for ch, mu, sd in zip(channels, channel_means, channel_stds):
        print(f"{ch}: mean={mu:.4f}, std={sd:.4f}")

    overall_mean = channel_means.mean()
    # Every channel has the same number of points, so the pooled variance is
    # the mean within-channel variance plus the spread of the channel means.
    overall_std = np.sqrt(
        channel_vars.mean() + ((channel_means - overall_mean) ** 2).mean()
    )
    print(f"\nOverall mean: {overall_mean:.4f}")
    print(f"Overall std: {overall_std:.4f}")

    flat_channels = [ch for ch, sd in zip(channels, channel_stds) if sd == 0]
    if flat_channels:
        # A zero std would make later normalization divide by zero.
        print(f"Warning: zero std for channel(s): {', '.join(flat_channels)}")

    # Save statistics
    stats = {
        "channel_means": channel_means.tolist(),
        "channel_stds": channel_stds.tolist(),
        "channels": channels,
        "num_trials": len(sampler.trials),
        "subset": "train"
    }

    out_file = Path(output_path)
    out_file.parent.mkdir(parents=True, exist_ok=True)
    with out_file.open("w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print(f"\nStatistics saved to {output_path}")
    return channel_means, channel_stds

if __name__ == "__main__":
    compute_global_stats()