from __future__ import annotations

from pathlib import Path
from typing import Dict

import numpy as np  # noqa: E402

from .utils import LABELS_SPARSE, save_pdf
from unpaired_iv_experiments.dgp import (  # noqa: E402
    DiscreteEnvDGP,
    DiscreteEnvDGPConfig,
    make_sparse_beta,
)
from unpaired_iv.estimators import (
    TSIV,
    UPGMM,
    Estimator,
    UPGMMConfig,
    UPGMMHDAnalytic,
    UPGMMHDAnalyticConfig,
)
from unpaired_iv_experiments.configs import (  # noqa: E402
    Setting6Config,
)
from unpaired_iv_experiments.runners import (  # noqa: E402
    run_grid_by_rN,
)
from unpaired_iv_experiments.plotting import (  # noqa: E402
    apply_mpl_style,
    plot_grid_experiment,
)


def build_estimators_for_sparse() -> Dict[str, Estimator]:
    """Return the estimators compared in the sparse-beta experiment.

    The result maps a short identifier to a freshly constructed estimator:

    * ``"ts_iv"``: a default-configured ``TSIV``.
    * ``"up_gmm"``: ``UPGMM`` with L1 enabled, post-refit, the optimal-weight
      option switched on and ``split_B=0``.
    * ``"up_gmm_hd_analytic"``: ``UPGMMHDAnalytic`` with L1 enabled,
      post-refit and ``lam_effective_n="m"``.

    Both UP-GMM variants use the same ``lam_scale`` (0.0002).
    NOTE(review): the exact meaning of each option is defined by the
    respective config class -- confirm there before tuning.
    """
    estimators: Dict[str, Estimator] = {}

    # Insertion order is the iteration order seen by downstream consumers.
    estimators["ts_iv"] = TSIV()

    gmm_config = UPGMMConfig(
        l1=True,
        lam_scale=0.0002,
        post_refit=True,
        use_optimal_weight=True,
        split_B=0,
    )
    estimators["up_gmm"] = UPGMM(gmm_config)

    hd_config = UPGMMHDAnalyticConfig(
        l1=True,
        lam_scale=0.0002,
        post_refit=True,
        lam_effective_n="m",
    )
    estimators["up_gmm_hd_analytic"] = UPGMMHDAnalytic(hd_config)

    return estimators


# Module-level data-generating process, constructed at import time and passed
# as ``dgp=`` to ``run_grid_by_rN`` in ``main``. Both gamma coefficients are
# 1.0 and every ``sigma_*`` / ``env_sigma`` argument is 0.02.
# NOTE(review): the semantics of each argument are defined by
# DiscreteEnvDGPConfig -- confirm there before changing values.
dgp = DiscreteEnvDGP(
    DiscreteEnvDGPConfig(
        gamma_x=1.0,
        gamma_y=1.0,
        sigma_u=0.02,
        sigma_x=0.02,
        sigma_eps=0.02,
        env_sigma=0.02,
    )
)


# Local alias for the shared LABELS_SPARSE constant; handed to
# ``plot_grid_experiment`` in ``main``.
labels_sparse = LABELS_SPARSE


def main() -> None:
    """Run experiment 6 (Setting 6) and save its grid plot as a PDF.

    Steps, in order:

    1. Ensure the ``results_current`` output directory exists and apply the
       project's matplotlib style.
    2. Derive the experiment's integer seed from a fixed base seed through
       ``numpy.random.SeedSequence``.
    3. Run the (r, N) grid with the module-level ``dgp`` and the estimators
       from ``build_estimators_for_sparse`` (caching enabled).
    4. Plot the results and write the figure into the output directory.
    """
    out_dir = Path("results_current")
    out_dir.mkdir(parents=True, exist_ok=True)
    apply_mpl_style()

    # Five child sequences are spawned from the base seed; this experiment
    # uses the child at index 1.
    # NOTE(review): the index looks inherited from a larger multi-experiment
    # script -- changing it changes the seed, so leave as is unless intended.
    base_seed = 1234
    child_seqs = np.random.SeedSequence(base_seed).spawn(5)

    # -------------------------
    # Experiment 6 (Setting 6): high-dimensional instruments
    # (m, n to infty and n/m = r fixed)
    # NOTE(review): the output file name says "dense", yet beta is built with
    # make_sparse_beta and the estimators are the sparse set -- confirm which
    # description is intended.
    # -------------------------
    cfg = Setting6Config()
    estimators = build_estimators_for_sparse()

    # NOTE(review): the DiscreteEnvDGP *class* (not the ``dgp`` instance) is
    # passed here -- verify against make_sparse_beta's signature.
    beta_fn = make_sparse_beta(DiscreteEnvDGP, cfg.d, cfg.s_star)

    # Collapse the child sequence into one plain Python int (one uint32 word).
    seed_words = child_seqs[1].generate_state(1, dtype=np.uint32)
    rng_seed6 = int(seed_words[0])
    print(f"running experiment 6 with rng_seed {rng_seed6}")

    # Fresh list copies are handed to each callee so neither can affect the
    # config or the other call's arguments.
    results = run_grid_by_rN(
        dgp=dgp,
        estimators=estimators,
        r_list=list(cfg.r_list),
        N_list=list(cfg.N_list),
        beta_fn=beta_fn,
        n_rep=cfg.n_rep,
        rng_seed=rng_seed6,
        use_cache=True,
    )

    figure = plot_grid_experiment(
        results,
        list(cfg.r_list),
        labels_sparse,
        N_list=list(cfg.N_list),
        y_lim=(-5, 50),
    )
    save_pdf(figure, out_dir, "exp06_setting6_many_instruments_dense_grid_mae_l1")


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
