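"""
Run a simple synthetic experiment for the Factorized Scheduling Principle Framework.

A FactorizedSchedulingModel is trained on randomly sampled slates with noisy
rewards from the known ground-truth score ``S_star``. The script then reports
function-fit and ranking metrics, slate-size transfer results, and diagnostic
plots of the learned components against their ground-truth counterparts.

Dependencies: numpy, matplotlib
"""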

from __future__ import annotations

from pathlib import Path
import numpy as np

from factorized_model import FactorizedSchedulingModel
from synthetic_env import (
    S_star,
    phi1_star, phi2_star, phi3_star, phi4_star,
    psi12_star,
    evaluate_function_fit,
    transfer_test,
    plot_components_1d,
    plot_pair_slice,
    plot_interaction_3d,
)


def main(
    *,
    seed: int = 0,
    steps: int = 30_000,
    out_dir: str | Path = "figs_synthetic",
    make_plots: bool = True,
) -> None:
    """Train a factorized scheduling model on the synthetic environment and report results.

    Pipeline:
      1. For each step, sample a slate of ``slate_N`` candidates with ``K``
         features drawn uniformly from [0, 1). Pick one candidate uniformly at
         random and observe the reward ``S_star(x_a)`` plus Gaussian noise.
      2. Call ``model.train_step`` once per sampled slate.
      3. Print function-fit RMSE and pairwise ranking accuracy at ``slate_N``.
      4. Print transfer statistics for several slate sizes.
      5. Optionally save diagnostic plots of the learned components.

    Calling ``main()`` with no arguments reproduces the original fixed
    experiment exactly.

    Args:
        seed: Seed for the data-sampling RNG and the model. The two evaluation
            stages use ``seed + 1`` and ``seed + 2``.
        steps: Number of training steps (one slate per step). Lower it for a
            quick smoke run.
        out_dir: Directory for the diagnostic figures; created if missing.
        make_plots: If False, skip the plotting stage and do not create
            ``out_dir``.
    """
    # -------------------------
    # Reproducibility
    # -------------------------
    rng = np.random.default_rng(seed)
    # Offset the evaluation seeds from the training seed so an evaluation
    # stream never coincides with the training stream. With the default
    # seed=0 these are the original fixed values 1 and 2.
    eval_seed = seed + 1
    transfer_seed = seed + 2

    # -------------------------
    # Model / training config
    # -------------------------
    K = 4  # features per candidate (one learned phi component per feature)
    # Feature-index pairs that get a pairwise interaction term; (0, 1) is the
    # pair compared against psi12_star in the plots below.
    interactions = {(0, 1)}

    model = FactorizedSchedulingModel(
        K=K,
        M_per_dim=30,
        interactions=interactions,
        M_pair=15,
        lr=0.03,
        l2=1e-4,
        seed=seed,
    )

    slate_N = 8  # candidates per slate during training and fit evaluation

    # Std. dev. of the Gaussian noise added to the true score S_star.
    sigma = 0.03

    # Forwarded unchanged to train_step; their semantics live in factorized_model.
    # NOTE(review): gamma looks like a discount factor. With 0.0, X_next
    # presumably does not influence the target; confirm in train_step.
    gamma = 0.0
    lam_val = 0.1

    # Optional smoothness regularization weights for the 1-D (phi) and
    # pairwise (psi) components.
    alpha_phi, beta_phi = 1e-3, 1e-4
    alpha_psi, beta_psi = 1e-3, 1e-4

    # Output: create the directory before the long training loop, so a bad
    # path fails fast.
    out_path = Path(out_dir)
    if make_plots:
        out_path.mkdir(parents=True, exist_ok=True)

    # -------------------------
    # Training loop
    # -------------------------
    # The order of RNG draws (slate, action, noise, next slate) is kept
    # fixed so results are reproducible for a given seed.
    for _ in range(steps):
        X = rng.uniform(0.0, 1.0, size=(slate_N, K))
        a = int(rng.integers(slate_N))  # uniformly random chosen candidate
        # X[a:a + 1] keeps the 2-D shape (1, K) expected by S_star.
        r = float(S_star(X[a:a + 1])[0] + rng.normal(0.0, sigma))
        X_next = rng.uniform(0.0, 1.0, size=(slate_N, K))

        model.train_step(
            X, a, r, X_next,
            gamma=gamma,
            lam_val=lam_val,
            alpha_phi=alpha_phi,
            beta_phi=beta_phi,
            alpha_psi=alpha_psi,
            beta_psi=beta_psi,
        )

    # -------------------------
    # Evaluation (fit + ranking)
    # -------------------------
    rmse, rank_acc = evaluate_function_fit(
        model,
        S_star,
        samples_L2=30_000,
        trials_rank=3_000,
        slate_N=slate_N,
        seed=eval_seed,
    )

    print(f"Eval@N={slate_N}: RMSE={rmse:.6f}")
    print(f"Eval@N={slate_N}: Pairwise ranking accuracy={rank_acc:.4f}")

    # -------------------------
    # Transfer across slate sizes
    # -------------------------
    results = transfer_test(
        model, S_star, Ns=(4, 8, 16, 32), trials=800, seed=transfer_seed
    )
    for N, stats in results.items():
        print(f"Transfer N={N}: {stats}")

    # -------------------------
    # Plots (optional diagnostics)
    # -------------------------
    if not make_plots:
        return

    save_dir = str(out_path)

    plot_components_1d(
        model,
        phi_true_fns=[phi1_star, phi2_star, phi3_star, phi4_star],
        save_dir=save_dir,
        prefix="k4",
    )

    plot_pair_slice(
        model,
        0, 1,
        psi_true_fn=psi12_star,
        save_dir=save_dir,
        prefix="k4",
    )

    plot_interaction_3d(
        model,
        0, 1,
        psi_true_fn=psi12_star,
        save_dir=save_dir,
        prefix="k4",
        gridN=61,
        center=True,
    )


# Script entry point: run the experiment with its default settings only when
# this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
