#!/usr/bin/env python
# AI Summary: CLI entrypoint to run one VBMC fit reproducibly. Seeds Python/NumPy/Torch,
# writes vp.npz and metadata.json under results/sbj_xx/run_yy/. Depends on bav_model.BAVModel and pyvbmc.VBMC.

import argparse
import json
import math
import os
import random
import sys
import traceback
import warnings
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import torch
from pyvbmc import VBMC

from bav_model import BAVModel


def compute_bounds():
    """Build hard and plausible box bounds for the 7 model parameters.

    Parameter order in every returned array:
    ``[low, med, high, A, S, M, logit]``.

    Hard bounds (LB/UB) lie 2 prior SDs from the centre; plausible bounds
    (PLB/PUB) lie 1.5 SDs from it. A, S, M and logit are each centred on
    their own prior mean. The low/med/high trio is chained: ``low`` is
    centred on its mean, and each later level extends the previous level's
    interval outwards by k times that level's own SD.

    Returns:
        Tuple ``(LB, UB, PLB, PUB)`` of float arrays, each of shape (7,).
    """
    # Chained low -> med -> high parameters: (mean, sd) of the root level,
    # then the sd added at each subsequent level (med, then high).
    # NOTE(review): offsets diff_ML = 1.0 and diff_MH = 0.75 were defined in
    # the hierarchical prior but never entered the bounds; the intervals only
    # widen symmetrically around the low level. Confirm this is intended.
    chain_root = (0.0, 1.5)
    chain_step_sds = (1.0, 0.5)
    # Independent parameters A, S, M, logit as (mean, sd).
    independent = ((1.75, 0.5), (2.5, 1.0), (0.0, 0.5), (1.5, 1.5))

    def interval(k):
        """Return (lower, upper) arrays lying k prior SDs from the centres."""
        root_mean, root_sd = chain_root
        lower = [root_mean - k * root_sd]
        upper = [root_mean + k * root_sd]
        for step_sd in chain_step_sds:
            lower.append(lower[-1] - k * step_sd)
            upper.append(upper[-1] + k * step_sd)
        for mean, sd in independent:
            lower.append(mean - k * sd)
            upper.append(mean + k * sd)
        return np.array(lower, dtype=float), np.array(upper, dtype=float)

    LB, UB = interval(2)
    PLB, PUB = interval(1.5)
    return LB, UB, PLB, PUB


def seed_everything(seed: int, num_threads=1) -> np.random.Generator:
    """Seed every RNG this script relies on and return a NumPy Generator.

    Seeds Python's ``random`` module, NumPy's legacy global RNG
    (``np.random.seed``) and Torch (``torch.manual_seed``), then creates an
    independent ``np.random.default_rng(seed)`` for the caller's own draws.

    Args:
        seed: Integer seed applied to all generators.
        num_threads: Torch intra-op thread count to set. Defaults to 1 to
            respect a single-core scheduler allocation; raise it if more
            cores are reserved. Pass ``None`` to leave Torch's setting alone.

    Returns:
        A ``numpy.random.Generator`` seeded with ``seed``.
    """
    random.seed(seed)
    np.random.seed(seed)
    rng = np.random.default_rng(seed)
    try:
        torch.manual_seed(seed)
    except Exception as exc:
        # Best effort: keep running, but make the loss of Torch
        # reproducibility visible instead of silently ignoring it.
        warnings.warn(
            f"torch.manual_seed({seed}) failed: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
    if num_threads is not None:
        try:
            torch.set_num_threads(num_threads)
        except Exception as exc:
            warnings.warn(
                f"torch.set_num_threads({num_threads}) failed: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
    return rng


def _utc_timestamp():
    """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
    # datetime.utcnow() is deprecated since Python 3.12; build the same
    # string from a timezone-aware datetime instead.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    return now.isoformat(timespec="seconds") + "Z"


def _unlink_quietly(path):
    """Delete *path* if it exists; a missing file is not an error."""
    try:
        path.unlink()
    except FileNotFoundError:
        pass


def _write_text_atomic(path, text):
    """Write *text* to *path* through a sibling temp file and ``os.replace``.

    An interrupted job then leaves either the previous file or the complete
    new one, never a truncated file for the resume check to misread.
    """
    tmp = path.with_name(path.name + ".tmp")
    tmp.write_text(text)
    os.replace(tmp, path)


def _save_vp(vp, vp_path):
    """Save the variational posterior *vp* to *vp_path*, replacing any old file.

    The posterior is written to a sibling ``<stem>.partial<suffix>`` file and
    then renamed onto *vp_path*. This lets a forced rerun replace an existing
    result: pyvbmc's ``VariationalPosterior.save`` defaults to
    ``overwrite=False``. It also guarantees that a non-empty *vp_path* is a
    complete save, which the resume check relies on.

    Raises:
        RuntimeError: if ``vp.save`` returned without creating the file.
    """
    # Keep the original suffix so a saver that appends a default extension
    # to suffix-less names still writes exactly to ``tmp``.
    tmp = vp_path.with_name(f"{vp_path.stem}.partial{vp_path.suffix}")
    _unlink_quietly(tmp)  # leftover from an interrupted earlier save
    vp.save(str(tmp))
    if not tmp.is_file():
        raise RuntimeError(f"vp.save did not create {tmp}")
    os.replace(tmp, vp_path)


def _has_existing_results(vp_path, meta_path):
    """Return True if a previous run already left a usable result.

    A run counts as done when *vp_path* is a non-empty file, or when
    *meta_path* parses as a JSON object whose ``"status"`` is ``"ok"``.
    Unreadable or malformed files count as "not done", so the fit is rerun.
    """
    try:
        if vp_path.is_file() and vp_path.stat().st_size > 0:
            return True
    except OSError:
        pass
    try:
        meta = json.loads(meta_path.read_text())
    except (OSError, ValueError):  # missing/unreadable file or invalid JSON
        return False
    return isinstance(meta, dict) and meta.get("status") == "ok"


def _package_version(name):
    """Return the installed distribution version of *name*, or ``"unknown"``."""
    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:  # Python < 3.8
        return "unknown"
    try:
        return version(name)
    except PackageNotFoundError:
        return "unknown"


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when *argv* is None)."""
    p = argparse.ArgumentParser(description="Run a single VBMC fit for one subject/seed.")
    p.add_argument("--sbj", type=int, required=True, help="Subject index (0-based)")
    p.add_argument("--seed", type=int, required=True, help="Random seed for x0 and libraries")
    p.add_argument("--run-idx", type=int, default=0, help="Run index to name output folder (0..N-1)")
    p.add_argument("--data-path", type=str, default="bav_data.mat", help="Path to BAV .mat file")
    p.add_argument("--idx-path", type=str, default="trial_idx_400_split_1.json", help="Path to idx file")
    p.add_argument("--rho-a", type=float, default=(4.0 / 3.0), help="Auditory rescaling factor ρ")
    p.add_argument("--outdir", type=str, default="results", help="Base output directory")
    p.add_argument("--verbose", action="store_true", help="Print metadata to stdout")
    return p.parse_args(argv)


def main():
    """Run one VBMC fit for a single subject/seed and record the outcome.

    Outputs go to ``<outdir>/sbj_XX/run_YY/``:

    * ``vp.npz``: the variational posterior, as written by ``vp.save``.
    * ``metadata.json``: inputs, bounds, x0, status and library versions.
    * ``error.txt``: only when the fit fails; contains ``repr`` of the
      exception followed by its traceback.

    The fit is skipped when the folder already holds a non-empty ``vp.npz``
    or a ``metadata.json`` with ``"status": "ok"``. Set the environment
    variable ``FORCE_RERUN`` to ``1``/``true``/``yes`` (case-insensitive) to
    rerun anyway.

    Exits with status 1 if the fit failed; returns normally otherwise.
    """
    args = _parse_args()

    rng = seed_everything(args.seed)

    LB, UB, PLB, PUB = compute_bounds()
    x0 = rng.uniform(low=PLB, high=PUB, size=PLB.shape)

    # Prepare output directory
    outdir = Path(args.outdir) / f"sbj_{args.sbj:02d}" / f"run_{args.run_idx:02d}"
    outdir.mkdir(parents=True, exist_ok=True)
    vp_path = outdir / "vp.npz"
    meta_path = outdir / "metadata.json"
    err_path = outdir / "error.txt"

    # Inputs identifying this run; shared by the skip report and the metadata.
    provenance = {
        "sbj": args.sbj,
        "run_idx": args.run_idx,
        "seed": args.seed,
        "rho_a": args.rho_a,
        "data_path": os.path.abspath(args.data_path),
        "idx_path": os.path.abspath(args.idx_path),
        "outdir": str(outdir),
    }

    # Resume-safe: skip if results already exist (unless FORCE_RERUN=1)
    force = os.environ.get("FORCE_RERUN", "0").lower() in ("1", "true", "yes")
    if not force and _has_existing_results(vp_path, meta_path):
        if args.verbose:
            print(json.dumps({
                "timestamp_utc": _utc_timestamp(),
                **provenance,
                "skipped": True,
                "reason": "existing results",
            }, indent=2))
        return

    # An error.txt from an earlier failed attempt must not survive this run.
    _unlink_quietly(err_path)

    status = "ok"
    err_msg = ""
    try:
        # Model construction is inside the try so data/index loading
        # failures are recorded in error.txt and metadata.json too.
        model = BAVModel(sbj=args.sbj, RHO_A=args.rho_a, data_path=args.data_path, idx_path=args.idx_path, truncated=True)
        vbmc = VBMC(
            log_density=model.log_joint,
            lower_bounds=LB,
            upper_bounds=UB,
            plausible_lower_bounds=PLB,
            plausible_upper_bounds=PUB,
            x0=x0,
        )
        vp, _ = vbmc.optimize()
        _save_vp(vp, vp_path)
    except Exception as e:
        status = "error"
        err_msg = repr(e)
        err_path.write_text(err_msg + "\n" + traceback.format_exc())

    meta = {
        "timestamp_utc": _utc_timestamp(),
        **provenance,
        "files": {"vp": str(vp_path), "error": str(err_path)},
        "bounds": {
            "LB": LB.tolist(),
            "UB": UB.tolist(),
            "PLB": PLB.tolist(),
            "PUB": PUB.tolist(),
        },
        "x0": x0.tolist(),
        "status": status,
        "error": err_msg,
        "versions": {
            "python": sys.version,
            "numpy": np.__version__,
            "torch": getattr(torch, "__version__", "unknown"),
            "pyvbmc": _package_version("pyvbmc"),
        },
    }
    meta_text = json.dumps(meta, indent=2)
    _write_text_atomic(meta_path, meta_text)

    if args.verbose:
        print(meta_text)

    if status != "ok":
        sys.exit(1)


if __name__ == "__main__":
    main()
    