#!/usr/bin/env python3
"""
Generate Monte Carlo datasets for covariance simulations.

- n in SAMPLE_SIZES (currently {500, 1250, 2500})
- seeds 1..NUM_REP (currently 200) for each n
- Runs jobs in parallel with a ProcessPoolExecutor, preferring the "fork" start method
"""

import os
import itertools
import numpy as np
import torch
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from generate_lap import generate_lap_dataset


# ----------------- GLOBAL SETTINGS -----------------
# q and m are forwarded unchanged to generate_lap_dataset (see that function
# for their meaning).
q = 10
m = 1000
SAMPLE_SIZES = [500, 1250, 2500]  # values of n; one output subdirectory per n
NUM_REP = 200  # seeds 1..NUM_REP are generated for every n

OUT_DIR = "covariance_data"
os.makedirs(OUT_DIR, exist_ok=True)

# Leave one core free for the parent process. os.cpu_count() may return None,
# and on a single-core machine "cpu_count() - 1" would be 0, which
# ProcessPoolExecutor rejects (max_workers must be > 0) -- so clamp to >= 1.
MAX_WORKERS = max(1, (os.cpu_count() or 1) - 1)  # adjust if you want fewer workers


# ----------------- WORKER FUNCTION -----------------
def _atomic_save(path: str, arr: np.ndarray) -> None:
    """Save *arr* to *path* in .npy format via a temp file plus atomic rename.

    An interrupted write therefore never leaves a truncated file at *path*.
    This matters because the skip-if-exists check in generate_and_save_dataset
    would otherwise treat a half-written file as a finished dataset on rerun.
    """
    tmp_path = f"{path}.tmp{os.getpid()}"
    try:
        # Pass an open file handle (not a filename) so np.save does not
        # append an extra ".npy" suffix to the temp name.
        with open(tmp_path, "wb") as fh:
            np.save(fh, arr)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file, then let the error propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def generate_and_save_dataset(n: int, seed: int):
    """
    Generate one dataset for the given (n, seed) and save it to disk.

    Output files are OUT_DIR/n{n}/seed{seed:03d}_X.npy, _M.npy and _C.npy.
    If all three already exist the job is skipped, so the script can be rerun
    safely; each file is written atomically so a crash never leaves a
    truncated file that a rerun would mistake for a finished one.

    Returns (n, seed, 'ok') if the dataset was generated and saved, or
    (n, seed, 'exists') if it was skipped.
    Raises on error (exceptions from generation or file I/O propagate).
    """
    n_dir = os.path.join(OUT_DIR, f"n{n}")
    os.makedirs(n_dir, exist_ok=True)

    prefix = os.path.join(n_dir, f"seed{seed:03d}")
    x_path = prefix + "_X.npy"
    m_path = prefix + "_M.npy"
    c_path = prefix + "_C.npy"

    # Skip if already exists (so you can safely rerun)
    if all(os.path.exists(p) for p in (x_path, m_path, c_path)):
        print(f"[n={n}] seed={seed:03d} already exists, skipping")
        return (n, seed, "exists")

    # ---- generate data ----
    device = torch.device('cpu')
    X, M_list, cond_means = generate_lap_dataset(
        n=n,
        m=m,
        q=q,
        noise_level=0.02,
        edge_prob=0.3,
        random_seed=seed,
        device=device,
        dtype=torch.float32,
    )

    # detach & move to CPU numpy; per-element arrays are stacked along a new
    # leading axis 0.
    X_np = X.detach().cpu().numpy()
    M_np = np.stack([M.detach().cpu().numpy() for M in M_list], axis=0)
    C_np = np.stack([c.detach().cpu().numpy() for c in cond_means], axis=0)

    # ---- save (atomically, see _atomic_save) ----
    _atomic_save(x_path, X_np)
    _atomic_save(m_path, M_np)
    _atomic_save(c_path, C_np)

    print(f"[n={n}] seed={seed:03d} done")
    return (n, seed, "ok")


# ----------------- MAIN DRIVER -----------------
def main():
    """
    Generate every (n, seed) dataset in parallel and print a summary.

    One job is submitted per combination of n in SAMPLE_SIZES and seed in
    1..NUM_REP. A failing job does not abort the others: its exception is
    caught, reported immediately, and listed again in the final summary
    together with counts of generated and skipped datasets.
    """
    print(f"Using up to {MAX_WORKERS} parallel workers")
    print(f"Sample sizes: {SAMPLE_SIZES}")
    print(f"Replications per n: {NUM_REP}")
    print(f"Output dir: {OUT_DIR}")

    # build all (n, seed) jobs
    jobs = [(n, seed) for n in SAMPLE_SIZES for seed in range(1, NUM_REP + 1)]
    print(f"Total jobs: {len(jobs)}")

    errors = []
    status_counts = {}  # status string returned by the worker -> count

    with ProcessPoolExecutor(max_workers=MAX_WORKERS) as ex:
        future_to_job = {
            ex.submit(generate_and_save_dataset, n, seed): (n, seed)
            for (n, seed) in jobs
        }

        for fut in as_completed(future_to_job):
            n, seed = future_to_job[fut]
            try:
                _, _, status = fut.result()
            except Exception as e:
                # Include the exception type: str(e) is empty for e.g. a bare
                # AssertionError, which would otherwise report a blank reason.
                reason = f"{type(e).__name__}: {e}"
                print(f"!! Error in job n={n}, seed={seed:03d}: {reason}")
                errors.append((n, seed, reason))
            else:
                status_counts[status] = status_counts.get(status, 0) + 1

    print("\n=== Summary ===")
    print(
        f"Generated: {status_counts.get('ok', 0)}, "
        f"skipped (already existed): {status_counts.get('exists', 0)}"
    )
    if errors:
        # as_completed yields in completion order; sort for a stable report.
        errors.sort()
        print(f"Errors in {len(errors)} jobs:")
        # print only first few to avoid giant output
        for (n, seed, msg) in errors[:10]:
            print(f"  n={n}, seed={seed:03d}: {msg}")
        if len(errors) > 10:
            print(f"  ... and {len(errors) - 10} more")
    else:
        print("All datasets generated successfully!")

    print("Done.")


if __name__ == "__main__":
    # Prefer the "fork" start method so worker processes inherit the parent's
    # state instead of re-importing this script (avoids spawn/<stdin> issues).
    # force=True overrides any start method that was already set, so
    # set_start_method does not raise RuntimeError here.
    try:
        mp.set_start_method("fork", force=True)
    except ValueError:
        # "fork" is not available on this platform (e.g. Windows); fall back
        # to the platform's default start method instead of crashing.
        pass

    main()
