#!/usr/bin/env python3
"""
One-Round Per-Direction Synthetic Retraining
- Generate real data -> OLS beta0
- Take right singular vectors V of X0 as orthonormal directions
- For each direction j, generate synthetic y along v_j with mean v_j^T beta0 + noise
- Filtered pipeline keeps samples |y - v_j^T beta_prime| <= gamma*||v_j|| + sqrt(2/pi)*sigma
- Compare three estimators by squared L2 loss vs beta_star:
    0) real_only (beta0)
    1) synth_filter (per-direction mean after filtering)
    2) synth_nofilter (per-direction mean without filtering)

Outputs:
- joblib .pkl holding the three loss grids and the per-cell argmin index (which pipeline wins)
- the file is written under the directory given by --out-root (default: ./outputs_sim)
"""

import os
import argparse
from pathlib import Path

import numpy as np
import joblib
from joblib import Parallel, delayed
from tqdm import tqdm

# ----------------------------
# Core simulator: one-round, per-direction retraining
# ----------------------------
def simulate_one_round_per_direction(
    a: float,
    gamma: float,
    n1: int = 500,
    n2: int = 500,
    d: int = 8,
    sigma: float = 1.0,
    raw_batch_start: int = 2000,
    raw_growth: float = 2.0,
    max_batches: int = 100,
    beta_star: np.ndarray | None = None,
    rng: np.random.Generator | None = None,
):
    """
    One Monte Carlo replicate.

    Verifier belief center: beta_prime = beta_star + a * 1_d
    Filter keep rule per direction v_j:
        |y - v_j^T beta_prime| <= gamma * ||v_j|| + sqrt(2/pi) * sigma
    """
    if rng is None:
        rng = np.random.default_rng()
    if beta_star is None:
        beta_star = np.ones(d, dtype=float)

    # ----- verifier center (belief) -----
    beta_prime = beta_star + np.full(d, a, dtype=float)

    # ----- real data -> beta0 -----
    X0 = rng.normal(size=(n1, d))
    eps0 = rng.normal(scale=sigma, size=n1)
    y0 = X0 @ beta_star + eps0
    beta0, *_ = np.linalg.lstsq(X0, y0, rcond=None)  # (d,)

    # ----- SVD directions (RIGHT singular vectors V) -----
    # X0 = U Σ V^T; directions are columns of V (orthonormal)
    Vt = np.linalg.svd(X0, full_matrices=False)[2]  # shape (d, d)
    V = Vt.T                                        # columns v_j

    # ----- Estimator 0: real-only -----
    beta_real_only = beta0.copy()

    # ----- Estimator 1: per-direction synthetic with filtering (keep exactly n2 per direction) -----
    a_coords_fil = np.zeros(d, dtype=float)
    for j in range(d):
        vj = V[:, j]
        vj_norm = float(np.sqrt(vj @ vj))  # ~1
        center = float(vj @ beta_prime)
        mean_along = float(vj @ beta0)

        kept = []
        batch = int(max(1, raw_batch_start))
        batches_used = 0
        thresh = gamma * vj_norm + np.sqrt(2.0/np.pi) * sigma

        while len(kept) < n2:
            y_raw = mean_along + rng.normal(scale=sigma, size=batch)
            mask = np.abs(y_raw - center) <= thresh
            if np.any(mask):
                kept.extend(y_raw[mask].tolist())
            if len(kept) < n2:
                batch = int(np.ceil(batch * raw_growth))
                batches_used += 1
                if batches_used > max_batches:
                    # Acceptance too low -> fallback
                    a_coords_fil[j] = float(np.mean(kept)) if kept else mean_along
                    break

        if len(kept) >= n2:
            y_kept = np.array(kept[:n2], dtype=float)
            a_coords_fil[j] = float(np.mean(y_kept))

    beta_synth_filter = V @ a_coords_fil  # recombine

    # ----- Estimator 2: per-direction synthetic without filtering (exact n2 per direction) -----
    a_coords_nof = np.zeros(d, dtype=float)
    for j in range(d):
        vj = V[:, j]
        mean_along = float(vj @ beta0)
        y_raw = mean_along + rng.normal(scale=sigma, size=n2)
        a_coords_nof[j] = float(np.mean(y_raw))
    beta_synth_nofilter = V @ a_coords_nof

    # ----- losses -----
    loss_real_only      = float(np.linalg.norm(beta_real_only      - beta_star)**2)
    loss_synth_filter   = float(np.linalg.norm(beta_synth_filter   - beta_star)**2)
    loss_synth_nofilter = float(np.linalg.norm(beta_synth_nofilter - beta_star)**2)

    return loss_real_only, loss_synth_filter, loss_synth_nofilter


# ----------------------------
# Monte Carlo average for a (bias), gamma (width)
# ----------------------------
def compute_loss_entry(
    i: int, j: int,
    bias: float, width: float,
    n1: int, n2: int, d: int, sigma: float, sim_n: int,
    raw_batch_start: int, raw_growth: float, max_batches: int,
    seed_base: int = 0,
    beta_star: np.ndarray | None = None,
):
    """
    Return averaged losses over sim_n replicates for (bias, width).
    Only three variants:
      0: real_only
      1: synth_filter
      2: synth_nofilter
    """
    if width <= 0:
        return i, j, np.nan, np.nan, np.nan, -1

    losses = np.zeros(3, dtype=float)
    for t in range(sim_n):
        rng = np.random.default_rng(seed_base + 97*i + 131*j + 17*t)
        l0, l1, l2 = simulate_one_round_per_direction(
            a=bias, gamma=width, n1=n1, n2=n2, d=d, sigma=sigma,
            raw_batch_start=raw_batch_start, raw_growth=raw_growth, max_batches=max_batches,
            beta_star=beta_star, rng=rng
        )
        losses[0] += l0
        losses[1] += l1
        losses[2] += l2

    losses /= sim_n
    best_idx = int(np.nanargmin(losses))  # 0,1,2
    return i, j, losses[0], losses[1], losses[2], best_idx


def main():
    """Command-line entry point.

    Parses the options, evaluates ``compute_loss_entry`` for every cell of the
    (bias, width) grid through joblib, and dumps the loss grids, the per-cell
    winner index, the grid axes and the run metadata to a .pkl file under
    ``--out-root``.
    """
    cli = argparse.ArgumentParser(
        description="One-round per-direction retraining (filtered vs unfiltered vs real-only)"
    )
    # Positional argument
    cli.add_argument("n2", type=int, help="Per-direction kept size after filtering (and no-filter)")
    # Model / data options
    cli.add_argument("--d", type=int, default=8, help="Parameter dimension")
    cli.add_argument("--n1", type=int, default=100, help="Real sample size for initial OLS")
    cli.add_argument("--sigma", type=float, default=1.0, help="Noise std")
    # Monte Carlo budget and grid resolution
    cli.add_argument("--sim-n", type=int, default=100, help="MC replicates per grid cell")
    cli.add_argument("--n-bias", type=int, default=300)
    cli.add_argument("--n-width", type=int, default=150)
    cli.add_argument("--bias-max", type=float, default=0.25, help="Bias range: [0, bias_max]")
    cli.add_argument("--width-min", type=float, default=0.01, help="Gamma range: [width_min, width_max]")
    cli.add_argument("--width-max", type=float, default=0.5)
    # Rejection-sampler controls
    cli.add_argument("--raw-batch-start", type=int, default=2000)
    cli.add_argument("--raw-growth", type=float, default=2.0)
    cli.add_argument("--max-batches", type=int, default=100)
    # Infrastructure
    cli.add_argument("--seed", type=int, default=0)
    cli.add_argument("--n-jobs", type=int, default=-1, help="joblib parallelism")
    cli.add_argument("--out-root", type=str, default="./outputs_sim", help="Output root directory")
    cli.add_argument(
        "--basename",
        type=str,
        default="simulation_results_lr_one_round_per_dir",
        help="Base filename (without extension)",
    )

    opts = cli.parse_args()

    # Seeds NumPy's legacy global generator; each replicate builds its own
    # Generator from seed_base inside compute_loss_entry.
    np.random.seed(opts.seed)

    # Grid axes: rows index bias, columns index width (gamma).
    grid_shape = (opts.n_bias, opts.n_width)
    n_cells = opts.n_bias * opts.n_width
    bias_vals = np.linspace(0.0, opts.bias_max, opts.n_bias)
    width_vals = np.linspace(opts.width_min, opts.width_max, opts.n_width)

    # True parameter: all ones, length d.
    beta_star = np.ones(opts.d, dtype=float)

    # Result grids, pre-filled with "not computed" markers (-1 / NaN).
    which_min = np.full(grid_shape, -1, dtype=int)
    loss_real_only_all, loss_synth_fil_all, loss_synth_nofil_all = (
        np.full(grid_shape, np.nan) for _ in range(3)
    )

    # One lazily-generated job per grid cell, in row-major order.
    run_cell = delayed(compute_loss_entry)
    jobs = (
        run_cell(
            row, col, bias_vals[row], width_vals[col],
            opts.n1, opts.n2, opts.d, opts.sigma, opts.sim_n,
            opts.raw_batch_start, opts.raw_growth, opts.max_batches,
            seed_base=opts.seed, beta_star=beta_star,
        )
        for row in range(opts.n_bias)
        for col in range(opts.n_width)
    )

    # The progress bar advances as jobs are pulled from the generator.
    progress = tqdm(
        jobs,
        total=n_cells,
        desc="Simulating (one-round per-direction)",
        dynamic_ncols=True,
    )
    results = Parallel(n_jobs=opts.n_jobs)(job for job in progress)

    # Scatter each cell's result back into the grids.
    for row, col, l_real, l_fil, l_nofil, winner in results:
        loss_real_only_all[row, col] = l_real
        loss_synth_fil_all[row, col] = l_fil
        loss_synth_nofil_all[row, col] = l_nofil
        which_min[row, col] = winner

    # Persist everything in a single joblib pickle.
    out_dir = Path(opts.out_root)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"{opts.basename}_n2{opts.n2}.pkl"

    meta = {
        "n1": opts.n1, "n2": opts.n2, "d": opts.d, "sigma": opts.sigma,
        "sim_n": opts.sim_n,
        "raw_batch_start": opts.raw_batch_start,
        "raw_growth": opts.raw_growth,
        "max_batches": opts.max_batches,
        "seed": opts.seed,
        "note": "One-round synthetic retraining, per-direction, filter vs no-filter vs real-only (clean)",
    }
    payload = {
        "which_min": which_min,  # 0:real_only, 1:synth_filter, 2:synth_nofilter
        "bias": bias_vals,
        "width": width_vals,
        "loss_real_only": loss_real_only_all,
        "loss_synth_filter": loss_synth_fil_all,
        "loss_synth_nofilter": loss_synth_nofil_all,
        "meta": meta,
    }
    joblib.dump(payload, out_path)

    print(f"Saved results to: {out_path.resolve()}")


# Run the grid simulation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
