#!/usr/bin/env python3
"""
Generate normalized EEG batches (200 epochs by default) with a varying context size per batch.

Every batch uses a single context size nc, cycled from a fixed set
(main() uses 8, 16, 32, 64, 128, 192). The number of targets per trial is
nt = 256 - nc - nb, where nb is the buffer size (function default 8; main() uses 32).

Normalization uses global per-channel mean/std from data/eeg_normalization_stats.json
(compute with compute_eeg_stats.py). Raw EEG must be downloaded first
(download_eeg_data.py), which stores data under data/eeg/full/ and builds a cache.

Output structure (one directory per split under the chosen output root):
  <output_dir>/{train,val}_<nb>buffer/
    - batch_000000.pt, ...
    - metadata.json
main() writes to data/eeg_dataset/ with nb=32.
"""

import json
from pathlib import Path
from typing import List

import numpy as np
import torch
from tqdm import tqdm

from src.data.eeg_sampler import EEGSampler


def _build_batch(
    batch_trials,
    nc: int,
    nb: int,
    nt: int,
    channel_means: np.ndarray,
    channel_stds: np.ndarray,
    generator: torch.Generator,
    total_points: int = 256,
) -> dict:
    """Normalize, permute and split every trial of one batch into context/buffer/target tensors.

    Returns a dict with stacked ``xc, yc, xb, yb, xt, yt`` tensors followed by the
    ``nc``, ``nb``, ``nt`` ints and a placeholder ``mask`` tensor.
    """
    parts = {key: [] for key in ("xc", "yc", "xb", "yb", "xt", "yt")}
    for trial in batch_trials:
        # Per-channel normalization; assumes trial["data"] is a (time, channel)
        # table, e.g. (256, 7) -- TODO confirm against EEGSampler.
        data_norm = (trial["data"].values - channel_means) / channel_stds

        # Use the actual time stamps as inputs.
        x = torch.tensor(trial["time"], dtype=torch.float32).unsqueeze(1)  # (T, 1)
        y = torch.tensor(data_norm, dtype=torch.float32)  # (T, C)

        # Random permutation, split into context | buffer | target points.
        perm = torch.randperm(total_points, generator=generator)
        idx_c, idx_b, idx_t = perm[:nc], perm[nc : nc + nb], perm[nc + nb :]

        parts["xc"].append(x[idx_c])
        parts["yc"].append(y[idx_c])
        parts["xb"].append(x[idx_b])
        parts["yb"].append(y[idx_b])
        parts["xt"].append(x[idx_t])
        parts["yt"].append(y[idx_t])

    batch = {key: torch.stack(values) for key, values in parts.items()}
    # Optional metadata for consumers
    batch.update({"nc": int(nc), "nb": int(nb), "nt": int(nt), "mask": torch.zeros(1)})
    return batch


def create_normalized_batches_varnc(
    subset: str,
    nc_values: List[int],
    num_epochs: int = 200,
    batch_size: int = 32,
    num_buffer: int = 8,
    output_dir: str = "data/eeg/ncvar_200e",
    seed: int = 42,
):
    """Create normalized batches for multiple epochs with varying context sizes.

    Each trial of 256 points is processed in three steps:
      1. It is normalized per channel with the global statistics in
         ``data/eeg_normalization_stats.json``.
      2. Its points are randomly permuted.
      3. The permuted points are split into ``nc`` context, ``nb`` buffer and
         ``nt = 256 - nc - nb`` target points.

    All trials in one batch share the same ``nc``. ``nc`` cycles through
    ``nc_values`` with a per-epoch offset, so the context sizes are used evenly.

    Batches are saved as ``batch_XXXXXX.pt`` under
    ``<output_dir>/<subset>_<num_buffer>buffer/``, next to a ``metadata.json``.

    Args:
        subset: 'train' or 'val' (val uses the 'cv' split internally).
        nc_values: Context sizes to cycle through, one per batch. A value with
            ``nc + num_buffer >= 256`` leaves no targets, so batches assigned
            to it are skipped (a warning is printed).
        num_epochs: Number of epochs to generate (default 200).
        batch_size: Trials per batch (expected to be 32 for the EEG datasets
            here). Trials that do not fill a whole batch are dropped each epoch.
        num_buffer: Number of buffer points ``nb`` per trial (default 8).
        output_dir: Output directory root.
        seed: Seeds the sampler, the per-epoch trial shuffle and the per-trial
            point permutations, so the generated dataset is reproducible.

    Returns:
        The metadata dict that is also written to ``metadata.json``.

    Raises:
        FileNotFoundError: If the normalization statistics file is missing.
        ValueError: If ``nc_values`` is empty.
    """
    if not nc_values:
        raise ValueError("nc_values must contain at least one context size.")

    total_points = 256
    nb = num_buffer

    # Load normalization statistics
    stats_path = Path("data/eeg_normalization_stats.json")
    if not stats_path.exists():
        raise FileNotFoundError(
            f"Normalization stats not found at {stats_path}. Run compute_eeg_stats.py first."
        )
    with open(stats_path, "r", encoding="utf-8") as f:
        stats = json.load(f)
    channel_means = np.array(stats["channel_means"])
    channel_stds = np.array(stats["channel_stds"])
    num_channels = int(channel_means.shape[0])

    # Prepare output
    sampler_split = "cv" if subset == "val" else subset
    out_path = Path(output_dir) / f"{subset}_{nb}buffer"
    out_path.mkdir(parents=True, exist_ok=True)

    # Initialize sampler for trial harvesting (uses raw EEG and caching)
    sampler = EEGSampler(
        data_path="data/eeg",
        subset=sampler_split,
        mode="interpolation",
        batch_size=batch_size,
        num_tasks=1000,  # placeholder, not used in this offline generation
        total_points=total_points,
        device="cpu",
        dtype=torch.float32,
        seed=seed,
    )

    num_trials = len(sampler.trials)
    batches_per_epoch = num_trials // batch_size

    print(f"\n{subset.upper()} subset:")
    print(f"  Total trials: {num_trials}")
    print(f"  Batches per epoch: {batches_per_epoch}")
    print(f"  Trials used per epoch: {batches_per_epoch * batch_size}")
    print(f"  Trials dropped per epoch: {num_trials % batch_size}")
    print(f"  Generating {num_epochs} epochs with nc in {nc_values}...")

    invalid_nc = [nc for nc in nc_values if total_points - nc - nb <= 0]
    if invalid_nc:
        print(
            f"  WARNING: nc values {invalid_nc} leave no target points with nb={nb}; "
            "batches assigned to them will be skipped."
        )

    # Deterministic RNGs: numpy for the per-epoch trial shuffle, and a dedicated
    # torch generator for point permutations. A bare torch.randperm would draw
    # from the unseeded global torch RNG and break reproducibility.
    rng = np.random.RandomState(seed)
    perm_gen = torch.Generator().manual_seed(seed)

    num_saved = 0
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Shuffle trials for this epoch
        epoch_trials = sampler.trials.copy()
        rng.shuffle(epoch_trials)

        for b in tqdm(range(batches_per_epoch), desc="Creating batches"):
            # Cycle through the nc list with an epoch offset for balance
            nc = nc_values[(b + epoch) % len(nc_values)]
            nt = total_points - nc - nb
            if nt <= 0:
                # Impossible configuration (warned about above)
                continue

            batch_trials = epoch_trials[b * batch_size : (b + 1) * batch_size]
            batch = _build_batch(
                batch_trials,
                nc,
                nb,
                nt,
                channel_means,
                channel_stds,
                perm_gen,
                total_points,
            )
            torch.save(batch, out_path / f"batch_{num_saved:06d}.pt")
            num_saved += 1

    # Save metadata for this split
    metadata = {
        "num_epochs": num_epochs,
        "batches_per_epoch": batches_per_epoch,
        "num_batches": num_saved,
        "batch_size": batch_size,
        "nc_values": list(nc_values),
        "nb": nb,
        "total_points": total_points,
        "dim_x": 1,
        "dim_y": num_channels,
        "dtype": str(torch.float32),
        "keys": ["xc", "yc", "xb", "yb", "xt", "yt", "mask", "nc", "nb", "nt"],
        "subset": subset,
        "mode": "interpolation",
        "normalized": True,
        "normalization_stats": stats,
        "note": f"Context size varies per batch; targets = {total_points} - nc - {nb}.",
    }
    with open(out_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print(f"Saved metadata to {out_path}/metadata.json")
    print(f"Dataset complete: {out_path}/")
    return metadata


def main():
    """Generate the normalized variable-context dataset for the train and val splits."""
    # Default nc set and epochs to mirror the fixed-nc dataset scale
    nc_values = [8, 16, 32, 64, 128, 192]
    num_epochs = 200
    batch_size = 32
    num_buffer = 32
    output_root = "data/eeg_dataset"
    # Distinct seeds per split so train and val use different shuffles/permutations.
    split_seeds = {"train": 42, "val": 43}

    print("=" * 60)
    print(f"Generating NORMALIZED variable-context dataset ({num_epochs} epochs)")
    print("Contexts per batch from:", nc_values)
    # Report the directory that is actually written to.
    print("Output root:", output_root)
    print("=" * 60)

    for subset, seed in split_seeds.items():
        create_normalized_batches_varnc(
            subset=subset,
            nc_values=nc_values,
            num_epochs=num_epochs,
            batch_size=batch_size,
            num_buffer=num_buffer,
            output_dir=output_root,
            seed=seed,
        )

    print("\n" + "=" * 60)
    print("Variable-context normalized dataset generated successfully!")
    print("Data has mean≈0 and std≈1 per channel; nc varies per batch.")
    print("=" * 60)


if __name__ == "__main__":
    # Entry point: generate the dataset only when run as a script, not on import.
    main()

