from __future__ import annotations

from pathlib import Path

from typing import Dict

import numpy as np  # noqa: E402

from unpaired_iv_experiments.dgp import (  # noqa: E402
    ContinuousIVDGP,
    LowRankContinuousIVDGP,
    LowRankContinuousIVDGPConfig,
    make_dense_beta,
    make_sparse_beta,
)
from .utils import LABELS_DENSE, LABELS_SPARSE, save_pdf
from unpaired_iv.estimators import (
    TS2SLS,
    TSIV,
    UPGMM,
    Estimator,
    UPGMMConfig,
    UPGMMHDAnalytic,
    UPGMMHDAnalyticConfig,
)
from unpaired_iv_experiments.configs import (  # noqa: E402
    Setting1Config,
    Setting2Config,
    Setting5Config,
)
from unpaired_iv_experiments.runners import (  # noqa: E402
    run_grid_by_rN,
    run_grid_fixed_m_by_N,
)
from unpaired_iv_experiments.plotting import (  # noqa: E402
    apply_mpl_style,
    plot_experiment_vs_N,
    plot_grid_experiment,
)


def build_estimators_for_dense() -> Dict[str, Estimator]:
    """Build the estimator suite used with a dense coefficient vector.

    None of the estimators use L1 penalisation (``l1=False``). The UP-GMM
    estimator uses the optimal weighting matrix and ``split_B=0``.

    Returns:
        Mapping from a short string identifier to a freshly constructed
        estimator instance. Insertion order is ts_2sls, ts_iv, up_gmm,
        up_gmm_hd_analytic.
    """
    estimators: Dict[str, Estimator] = {}
    estimators["ts_2sls"] = TS2SLS()
    estimators["ts_iv"] = TSIV()

    gmm_config = UPGMMConfig(l1=False, use_optimal_weight=True, split_B=0)
    estimators["up_gmm"] = UPGMM(gmm_config)

    hd_config = UPGMMHDAnalyticConfig(l1=False)
    estimators["up_gmm_hd_analytic"] = UPGMMHDAnalytic(hd_config)

    return estimators


def build_estimators_for_sparse() -> Dict[str, Estimator]:
    """Build the estimator suite used with a sparse coefficient vector.

    Both UP-GMM variants run with an L1 penalty (``l1=True``,
    ``lam_scale=1.0``) and ``post_refit=True``. That flag presumably means an
    unpenalised refit on the selected support; confirm in the estimator code.

    Returns:
        Mapping from a short string identifier to a freshly constructed
        estimator instance. Insertion order is ts_2sls, ts_iv, up_gmm,
        up_gmm_hd_analytic.
    """
    two_stage_ls = TS2SLS()
    two_stage_iv = TSIV()
    penalised_gmm = UPGMM(
        UPGMMConfig(
            l1=True,
            lam_scale=1.0,
            post_refit=True,
            use_optimal_weight=True,
            split_B=0,
        )
    )
    # NOTE(review): lam_effective_n="m" presumably ties the penalty level to
    # the instrument-sample size m; confirm against UPGMMHDAnalyticConfig.
    penalised_hd = UPGMMHDAnalytic(
        UPGMMHDAnalyticConfig(
            l1=True, lam_scale=1.0, post_refit=True, lam_effective_n="m"
        )
    )
    return dict(
        ts_2sls=two_stage_ls,
        ts_iv=two_stage_iv,
        up_gmm=penalised_gmm,
        up_gmm_hd_analytic=penalised_hd,
    )


def _derive_seed(seq: np.random.SeedSequence) -> int:
    """Draw one 32-bit integer seed from *seq* and return it as a Python int."""
    (word,) = seq.generate_state(1, dtype=np.uint32)
    return int(word)


def main() -> None:
    """Run the continuous-instrument experiments and save one PDF per experiment.

    Each PDF is written with ``save_pdf`` into ``results_current/``, which is
    created if missing. Each experiment takes its RNG seed from its own child
    of one master ``SeedSequence`` (base seed 1234), so reruns are
    reproducible. The grid runners are called with ``use_cache=True``.
    """
    out_dir = Path("results_current")
    out_dir.mkdir(parents=True, exist_ok=True)
    apply_mpl_style()

    # Shared DGP for experiments 1 and 2: Gaussian instrument in R^m with a
    # first-stage map Π shared within each dataset.
    base_dgp = ContinuousIVDGP()

    base_seed = 1234
    # Children 2 and 3 are not consumed below. Spawning five keeps
    # experiment 5's stream pinned to child index 4.
    children = np.random.SeedSequence(base_seed).spawn(5)

    # ---- Experiment 1 (Setting 1) ----
    # Sparse beta, fixed m, fixed d > m, N -> ∞ via r -> ∞.
    cfg1 = Setting1Config()
    sparse_estimators_1 = build_estimators_for_sparse()
    beta1 = make_sparse_beta(ContinuousIVDGP, cfg1.d, cfg1.s_star)

    seed1 = _derive_seed(children[0])
    print(f"running continuous experiment 1 with rng_seed {seed1}")
    results1 = run_grid_fixed_m_by_N(
        dgp=base_dgp,
        estimators=sparse_estimators_1,
        m_fixed=cfg1.m,
        N_list=list(cfg1.N_list),
        beta_fn=beta1,
        n_rep=cfg1.n_rep,
        rng_seed=seed1,
        use_cache=True,
    )
    figure1 = plot_experiment_vs_N(
        results1,
        LABELS_SPARSE,
        N_list=list(cfg1.N_list),
        y_lim=(-2, 30),
        legend_above=True,
    )
    save_pdf(figure1, out_dir, "exp01_setting1_fixed_m_sparse_mae_l1_continuous")

    # ---- Experiment 2 (Setting 2) ----
    # Dense beta, high-dimensional instruments: m, n -> ∞ with n/m = r fixed.
    cfg2 = Setting2Config()
    dense_estimators = build_estimators_for_dense()
    beta2 = make_dense_beta(cfg2.d)

    seed2 = _derive_seed(children[1])
    print(f"running continuous experiment 2 with rng_seed {seed2}")
    results2 = run_grid_by_rN(
        dgp=base_dgp,
        estimators=dense_estimators,
        r_list=list(cfg2.r_list),
        N_list=list(cfg2.N_list),
        beta_fn=beta2,
        n_rep=cfg2.n_rep,
        rng_seed=seed2,
        use_cache=True,
    )
    figure2 = plot_grid_experiment(
        results2,
        list(cfg2.r_list),
        LABELS_DENSE,
        N_list=list(cfg2.N_list),
        y_lim=(-0.05, 1.0),
    )
    save_pdf(
        figure2,
        out_dir,
        "exp02_setting2_many_instruments_dense_grid_mae_l1_continuous",
    )

    # ---- Experiment 5 ----
    # Low-rank first-stage Π (m, n -> ∞ with n/m = r fixed) plus sparse beta.
    cfg5 = Setting5Config()
    sparse_estimators_5 = build_estimators_for_sparse()
    # One child seeds the DGP's own generator and the other seeds the runner.
    dgp_seq5, run_seq5 = children[4].spawn(2)
    lowrank_dgp = LowRankContinuousIVDGP(
        LowRankContinuousIVDGPConfig(k=cfg5.k),
        d=cfg5.d,
        rng=np.random.default_rng(dgp_seq5),
    )
    # NOTE(review): beta is built with the ContinuousIVDGP class even though
    # this experiment samples from LowRankContinuousIVDGP. Confirm that this
    # is intended.
    beta5 = make_sparse_beta(ContinuousIVDGP, cfg5.d, cfg5.s_star)

    seed5 = _derive_seed(run_seq5)
    print(f"running continuous experiment 5 with rng_seed {seed5}")
    results5 = run_grid_by_rN(
        dgp=lowrank_dgp,
        estimators=sparse_estimators_5,
        r_list=list(cfg5.r_list),
        N_list=list(cfg5.N_list),
        beta_fn=beta5,
        n_rep=cfg5.n_rep,
        rng_seed=seed5,
        use_cache=True,
    )
    figure5 = plot_grid_experiment(
        results5,
        list(cfg5.r_list),
        LABELS_SPARSE,
        N_list=list(cfg5.N_list),
        y_lim=(0.1, 30),
    )
    save_pdf(
        figure5,
        out_dir,
        "exp05_lowrank_env_sparse_beta_grid_mae_l1_continuous",
    )


if __name__ == "__main__":
    # Script entry point: run every experiment defined in main().
    main()
