#!/usr/bin/env python3
"""
Generate normalized EEG batches with no trial repetition within an epoch.
Uses global statistics to normalize each channel independently.
"""

import os
import json
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from src.data.eeg_sampler import EEGSampler

# Keys of the per-point-set tensors stored in every batch (in save order).
_SPLIT_KEYS = ("xc", "yc", "xb", "yb", "xt", "yt")


def _load_normalization_stats(stats_path):
    """Read per-channel normalization statistics from a JSON file.

    Args:
        stats_path: Path to a JSON file containing "channel_means",
            "channel_stds" and "num_trials" entries.

    Returns:
        Tuple ``(stats, channel_means, channel_stds)``: the parsed JSON dict
        and the means/stds as numpy arrays.

    Raises:
        ValueError: If means/stds are not 1-D arrays of equal length, or if
            any std is not strictly positive (normalizing would give inf/NaN).
    """
    with open(stats_path, "r", encoding="utf-8") as f:
        stats = json.load(f)

    channel_means = np.array(stats["channel_means"])
    channel_stds = np.array(stats["channel_stds"])

    if channel_means.ndim != 1 or channel_means.shape != channel_stds.shape:
        raise ValueError(
            "channel_means and channel_stds must be 1-D arrays of equal length, "
            f"got shapes {channel_means.shape} and {channel_stds.shape}"
        )
    # `not all(> 0)` (rather than `any(<= 0)`) also rejects NaN stds.
    if not np.all(channel_stds > 0):
        raise ValueError(f"All channel_stds must be > 0, got {channel_stds.tolist()}")
    return stats, channel_means, channel_stds


def _split_trial(x, y, nc, nb, mode, generator):
    """Split one trial's points into context / buffer / target sets.

    Args:
        x: Time coordinates, shape (T, 1).
        y: Normalized channel values, shape (C, T).
        nc: Number of context points.
        nb: Number of buffer points; the remaining T - nc - nb are targets.
        mode: "interpolation" draws a random permutation of the T points;
            "forecasting" orders points by time, so context = earliest nc
            points, buffer = next nb, targets = the rest.
        generator: torch.Generator driving the interpolation permutation.

    Returns:
        Tuple ``(xc, yc, xb, yb, xt, yt)`` with shapes (nc, 1), (nc, C),
        (nb, 1), (nb, C), (nt, 1), (nt, C).
    """
    if mode == "interpolation":
        perm = torch.randperm(x.shape[0], generator=generator)
    else:  # "forecasting" (mode is validated by the caller)
        perm = torch.argsort(x.squeeze(1))

    context_indices = perm[:nc]
    buffer_indices = perm[nc:nc + nb]
    target_indices = perm[nc + nb:]

    points = y.T  # (T, C): one row of channel values per time point
    return (
        x[context_indices], points[context_indices],
        x[buffer_indices], points[buffer_indices],
        x[target_indices], points[target_indices],
    )


def _build_batch(batch_trials, channel_means, channel_stds, nc, nb, mode,
                 generator, total_points):
    """Normalize, split and stack a group of trials into one batch dict.

    Args:
        batch_trials: Sequence of trials; each provides a "time" entry and a
            "data" entry exposing ``.values`` (rows = time points, columns =
            channels), as read from ``EEGSampler.trials``.
        channel_means: Per-channel means subtracted from the data.
        channel_stds: Per-channel stds the data is divided by.
        nc: Number of context points.
        nb: Number of buffer points.
        mode: "interpolation" or "forecasting" (see ``_split_trial``).
        generator: torch.Generator for interpolation permutations.
        total_points: Number of time points every trial must contain.

    Returns:
        Dict with keys xc, yc, xb, yb, xt, yt (stacked along a new leading
        batch dimension) followed by a placeholder "mask" tensor.

    Raises:
        ValueError: If a trial does not have exactly ``total_points`` samples.
    """
    pieces = {key: [] for key in _SPLIT_KEYS}
    for trial in batch_trials:
        data = trial["data"].values
        # NORMALIZE: subtract mean and divide by std for each channel.
        data_normalized = (data - channel_means) / channel_stds

        # Use the actual time values (not normalized) as the x coordinate.
        x = torch.tensor(trial["time"], dtype=torch.float32).unsqueeze(1)  # (T, 1)
        y = torch.tensor(data_normalized, dtype=torch.float32).T  # (C, T)
        if x.shape[0] != total_points or y.shape[1] != total_points:
            raise ValueError(
                f"Expected {total_points} time points per trial, got "
                f"{x.shape[0]} time values and {y.shape[1]} data rows"
            )

        for key, tensor in zip(_SPLIT_KEYS, _split_trial(x, y, nc, nb, mode, generator)):
            pieces[key].append(tensor)

    batch = {key: torch.stack(tensors) for key, tensors in pieces.items()}
    # Mask placeholder (computed during training).
    batch["mask"] = torch.zeros(1)
    return batch


def create_normalized_batches(
    subset: str,
    nc_value: int,
    nb_value: int = 8,
    num_epochs: int = 50,
    batch_size: int = 32,
    mode: str = "interpolation", # interpolation | forecasting
    data_root_dir: str = None,
    output_dir: str = None,
    seed: int = 42
) -> dict:
    """
    Create normalized batches for multiple epochs with no trial repetition.

    Each epoch shuffles all trials and slices them into
    ``num_trials // batch_size`` batches, so no trial appears twice within an
    epoch; the leftover ``num_trials % batch_size`` trials are dropped for
    that epoch. Each trial is normalized per channel using
    ``<data_root_dir>/eeg_normalization_stats.json`` and split into context /
    buffer / target points. Batches are written to disk as they are built
    (``batch_000000.pt``, ...), followed by ``metadata.json``.

    Args:
        subset: Subset name passed to ``EEGSampler`` ('train', 'cv', or 'eval').
        nc_value: Fixed number of context points (e.g., 192 for easy).
        nb_value: Number of buffer points (default: 8).
        num_epochs: Number of epochs to generate.
        batch_size: Number of trials per batch.
        mode: "interpolation" (random point split) or "forecasting"
            (time-ordered split).
        data_root_dir: Directory holding the stats JSON and the "eeg" data
            folder (default: "data").
        output_dir: Parent output directory
            (default: "data/eeg_batches_normalized").
        seed: Seed for the per-epoch trial shuffle and for the
            interpolation permutations (a dedicated torch.Generator, so the
            output does not depend on torch's global RNG state).

    Returns:
        The metadata dict that is also written to ``metadata.json``.

    Raises:
        ValueError: On an invalid mode, negative nc/nb, nc + nb >= 256
            (no target points), non-positive batch_size, malformed or
            zero-std statistics, or a trial without exactly 256 points.
    """
    total_points = 256  # every trial is expected to contain 256 time samples

    # Validate arguments before touching the filesystem.
    if mode not in ("interpolation", "forecasting"):
        raise ValueError("Mode must be 'interpolation' or 'forecasting'")
    if nc_value < 0 or nb_value < 0:
        raise ValueError(
            f"nc_value and nb_value must be non-negative, got {nc_value} and {nb_value}"
        )
    nb = nb_value  # Buffer size
    nt = total_points - nc_value - nb  # Target points
    if nt <= 0:
        raise ValueError(
            f"nc_value + nb_value must be < {total_points} so that targets remain, "
            f"got {nc_value} + {nb_value}"
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Load normalization statistics
    if data_root_dir is None:
        data_root_dir = "data"
    dt_rp = Path(data_root_dir)
    stats, channel_means, channel_stds = _load_normalization_stats(
        dt_rp / "eeg_normalization_stats.json"
    )

    print(f"Using normalization statistics from {stats['num_trials']} trials")
    print(f"Channel means: {channel_means.mean():.4f} ± {channel_means.std():.4f}")
    print(f"Channel stds: {channel_stds.mean():.4f} ± {channel_stds.std():.4f}")

    if output_dir is None:
        output_dir = "data/eeg_batches_normalized"

    # Create output directory
    output_path = Path(output_dir) / f"{nc_value}con_{nt}tar_{mode}_{subset}"
    output_path.mkdir(parents=True, exist_ok=True)

    # Initialize sampler to get trials
    sampler = EEGSampler(
        data_path=dt_rp/"eeg",
        subset=subset,
        mode=mode,
        batch_size=batch_size,
        num_tasks=1000,  # Not used, but required
        total_points=total_points,  # Full trial length
        device="cpu",
        dtype=torch.float32,
        seed=seed
    )

    num_trials = len(sampler.trials)
    batches_per_epoch = num_trials // batch_size
    trials_dropped = num_trials % batch_size

    print(f"\n{subset.upper()} subset:")
    print(f"  Total trials: {num_trials}")
    print(f"  Batches per epoch: {batches_per_epoch}")
    print(f"  Trials used per epoch: {batches_per_epoch * batch_size}")
    print(f"  Trials dropped per epoch: {trials_dropped}")
    print(f"  Generating {num_epochs} epochs...")

    # Separate, explicitly seeded RNGs: numpy for the trial shuffle, torch
    # for interpolation permutations (torch's global RNG is never relied on).
    rng = np.random.RandomState(seed)
    perm_generator = torch.Generator().manual_seed(seed)

    # Batches are saved immediately so memory stays bounded by one batch.
    num_batches = 0
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Shuffle trials for this epoch
        epoch_trials = sampler.trials.copy()
        rng.shuffle(epoch_trials)

        for batch_idx in tqdm(range(batches_per_epoch), desc="Creating batches"):
            start = batch_idx * batch_size
            batch = _build_batch(
                epoch_trials[start:start + batch_size],
                channel_means,
                channel_stds,
                nc_value,
                nb,
                mode,
                perm_generator,
                total_points,
            )
            torch.save(batch, output_path / f"batch_{num_batches:06d}.pt")
            num_batches += 1

    print(f"\nSaved {num_batches} total batches to {output_path}")

    # Save metadata
    metadata = {
        "num_epochs": num_epochs,
        "batches_per_epoch": batches_per_epoch,
        "num_batches": num_batches,
        "batch_size": batch_size,
        "nc": nc_value,
        "nb": nb,
        "nt": nt,
        "total_points": total_points,
        "dim_x": 1,
        "dim_y": int(channel_means.shape[0]),
        "dtype": str(torch.float32),
        # Fixed key list, so metadata is valid even when no batch was produced.
        "keys": list(_SPLIT_KEYS) + ["mask"],
        "subset": subset,
        "mode": mode,
        "nc_fixed": nc_value,
        "num_trials": num_trials,
        "trials_per_epoch": batches_per_epoch * batch_size,
        "trials_dropped_per_epoch": trials_dropped,
        "normalized": True,
        "normalization_stats": {
            "channel_means": channel_means.tolist(),
            "channel_stds": channel_stds.tolist()
        }
    }

    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print(f"Saved metadata to {output_path}/metadata.json")
    print(f"Dataset complete: {output_path}/")
    return metadata


def main():
    """Generate the normalized EASY EEG dataset for every configured subset.

    EASY means a fixed 192-point context and no buffer, leaving
    256 - 192 = 64 target points per trial. Output goes under
    ``data/eeg_dataset`` and the statistics are read from ``data``.
    """
    context_points, buffer_points = 192, 0
    split_mode = "forecasting"  # interpolation | forecasting
    epoch_count = 1  # number of shuffled passes over the trials

    banner = "=" * 60
    print(banner)
    print(f"Generating NORMALIZED EASY dataset (nc={context_points}, {epoch_count} epochs)")
    print(banner)

    subsets = ["eval"]
    for name in subsets:
        # The sampler calls the validation split "cv"; "val" is accepted as an alias.
        sampler_subset = {"val": "cv"}.get(name, name)
        # The training split is seeded with 42, every other split with 43.
        split_seed = 42 if name == "train" else 43
        create_normalized_batches(
            subset=sampler_subset,
            nc_value=context_points,
            nb_value=buffer_points,
            num_epochs=epoch_count,
            mode=split_mode,
            data_root_dir="data",
            output_dir="data/eeg_dataset",
            seed=split_seed,
        )

    print("\n" + banner)
    print("Normalized dataset generated successfully!")
    print("Data has mean≈0 and std≈1 per channel")
    print(banner)


if __name__ == "__main__":
    # Generate the dataset only when run as a script, not when imported.
    main()