from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple

import numpy as np  # noqa: E402

from unpaired_iv.estimators import (
    TSIV,
    UPGMM,
    UPGMMHD,
    Estimator,
    UPGMMConfig,
    UPGMMHDAnalytic,
    UPGMMHDAnalyticConfig,
    UPGMMHDConfig,
)
from unpaired_iv_experiments.configs import (  # noqa: E402
    Setting1Config,
    Setting2Config,
    Setting4Config,
    Setting5Config,
)
from unpaired_iv_experiments.dgp import (  # noqa: E402
    DiscreteEnvDGP,
    LowRankEnvDGP,
    LowRankEnvDGPConfig,
    make_dense_beta,
    make_sparse_beta,
)
from unpaired_iv_experiments.plotting import (  # noqa: E402
    apply_mpl_style,
    plot_agreement_grid_scatter,
    plot_experiment_vs_N,
    plot_grid_experiment,
)
from unpaired_iv_experiments.runners import (  # noqa: E402
    run_grid_by_rN,
    run_grid_fixed_m_by_N,
)

from .utils import LABELS_DENSE, LABELS_SPARSE, save_pdf


def build_estimators_for_dense() -> Dict[str, Estimator]:
    """Assemble the estimators compared in the dense-beta experiments.

    Returns:
        A mapping from result key (``"ts_iv"``, ``"up_gmm"``,
        ``"up_gmm_hd_analytic"``) to a freshly constructed estimator. Both
        UP-GMM variants are configured with ``l1=False``.
    """
    estimators: Dict[str, Estimator] = {}
    estimators["ts_iv"] = TSIV()
    estimators["up_gmm"] = UPGMM(
        UPGMMConfig(l1=False, use_optimal_weight=True, split_B=0)
    )
    estimators["up_gmm_hd_analytic"] = UPGMMHDAnalytic(
        UPGMMHDAnalyticConfig(l1=False)
    )
    return estimators


def build_estimators_for_sparse() -> Dict[str, Estimator]:
    """Assemble the estimators compared in the sparse-beta experiments.

    Both UP-GMM variants share the same ``l1`` / ``lam_scale`` /
    ``post_refit`` settings; each then adds its own specific options.

    Returns:
        A mapping from result key (``"ts_iv"``, ``"up_gmm"``,
        ``"up_gmm_hd_analytic"``) to a freshly constructed estimator.
    """
    shared_penalty = dict(l1=True, lam_scale=1.0, post_refit=True)

    estimators: Dict[str, Estimator] = {"ts_iv": TSIV()}
    estimators["up_gmm"] = UPGMM(
        UPGMMConfig(**shared_penalty, use_optimal_weight=True, split_B=0)
    )
    estimators["up_gmm_hd_analytic"] = UPGMMHDAnalytic(
        UPGMMHDAnalyticConfig(**shared_penalty, lam_effective_n="m")
    )
    return estimators


# Master seed for the whole script. Every experiment derives its own child
# seed from this, so experiments are reproducible but not identical copies
# of each other.
BASE_SEED = 1234


def _child_seed(ss: np.random.SeedSequence) -> int:
    """Draw one deterministic 32-bit integer seed from ``ss``."""
    return int(ss.generate_state(1, dtype=np.uint32)[0])


def _build_estimators_for_dense_agreement() -> Dict[str, Estimator]:
    """Build the grid-based and analytic SplitUP estimators for the agreement study."""
    return {
        "up_gmm_hd": UPGMMHD(UPGMMHDConfig(K=2, redraw_B=10, l1=False)),
        "up_gmm_hd_analytic": UPGMMHDAnalytic(UPGMMHDAnalyticConfig(l1=False)),
    }


@dataclass(frozen=True)
class _Setting2AgreementConfig:
    """Simulation grid for the Setting-2 analytic-vs-grid agreement study."""

    d: int = 2
    r_list: Tuple[int, ...] = (4, 8, 16, 32)
    N_list: Tuple[int, ...] = (800, 1600, 3200, 6400, 12800)
    n_rep: int = 50


def _run_experiment1(
    dgp: DiscreteEnvDGP, ss: np.random.SeedSequence, results_dir: Path
) -> None:
    """Experiment 1 (Setting 1): sparse beta, fixed m, fixed d>m, N->inf via r->inf."""
    set1 = Setting1Config()
    est1 = build_estimators_for_sparse()

    beta_sparse_1 = make_sparse_beta(DiscreteEnvDGP, set1.d, set1.s_star)

    rng_seed1 = _child_seed(ss)
    print(f"running experiment 1 with rng_seed {rng_seed1}")
    res1 = run_grid_fixed_m_by_N(
        dgp=dgp,
        estimators=est1,
        m_fixed=set1.m,
        N_list=list(set1.N_list),
        beta_fn=beta_sparse_1,
        n_rep=set1.n_rep,
        rng_seed=rng_seed1,
        use_cache=True,
    )
    fig1 = plot_experiment_vs_N(
        res1, LABELS_SPARSE, N_list=list(set1.N_list), y_lim=(-2, 30)
    )
    save_pdf(fig1, results_dir, "exp01_setting1_fixed_m_sparse_mae_l1")


def _run_experiment2(
    dgp: DiscreteEnvDGP, ss: np.random.SeedSequence, results_dir: Path
) -> None:
    """Experiment 2 (Setting 2): dense beta, many instruments (m, n -> inf, n/m = r fixed)."""
    set2 = Setting2Config()
    est2 = build_estimators_for_dense()

    beta_dense_2 = make_dense_beta(set2.d)

    rng_seed2 = _child_seed(ss)
    print(f"running experiment 2 with rng_seed {rng_seed2}")
    res2 = run_grid_by_rN(
        dgp=dgp,
        estimators=est2,
        r_list=list(set2.r_list),
        N_list=list(set2.N_list),
        beta_fn=beta_dense_2,
        n_rep=set2.n_rep,
        rng_seed=rng_seed2,
        return_rep_errs=True,
        store_estimates_for=["up_gmm_hd", "up_gmm_hd_analytic"],
        use_cache=True,
    )

    fig2 = plot_grid_experiment(
        res2,
        list(set2.r_list),
        LABELS_DENSE,
        N_list=list(set2.N_list),
        y_lim=(-0.01, 0.4),
    )
    save_pdf(fig2, results_dir, "exp02_setting2_many_instruments_dense_grid_mae_l1")


def _run_experiment2_agreement(
    dgp: DiscreteEnvDGP, ss: np.random.SeedSequence, results_dir: Path
) -> None:
    """Experiment 2 (Setting 2): agreement between analytic and grid-based SplitUP."""
    set2_agreement = _Setting2AgreementConfig()
    est2_agreement = _build_estimators_for_dense_agreement()

    beta_dense_2_agreement = make_dense_beta(set2_agreement.d)

    rng_seed2_agreement = _child_seed(ss)
    print(f"running experiment 2 (agreement) with rng_seed {rng_seed2_agreement}")
    res2_agreement = run_grid_by_rN(
        dgp=dgp,
        estimators=est2_agreement,
        r_list=list(set2_agreement.r_list),
        N_list=list(set2_agreement.N_list),
        beta_fn=beta_dense_2_agreement,
        n_rep=set2_agreement.n_rep,
        rng_seed=rng_seed2_agreement,
        return_rep_errs=True,
        store_estimates_for=["up_gmm_hd", "up_gmm_hd_analytic"],
        use_cache=True,
    )

    fig2_agree = plot_agreement_grid_scatter(
        res2_agreement,
        r_list=list(set2_agreement.r_list),
        key_base="up_gmm_hd",
        key_ana="up_gmm_hd_analytic",
    )
    save_pdf(
        fig2_agree,
        results_dir,
        "exp02_agreement_up_gmm_hd_vs_analytic_scatter_grid",
    )


def _run_experiment4(
    dgp: DiscreteEnvDGP, ss: np.random.SeedSequence, results_dir: Path
) -> None:
    """Experiment 4 (Setting 4): fixed m, fixed d<m, dense beta."""
    set4 = Setting4Config()
    est4 = build_estimators_for_dense()

    beta_dense_4 = make_dense_beta(set4.d)

    rng_seed4 = _child_seed(ss)
    print(f"running experiment 4 with rng_seed {rng_seed4}")
    res4 = run_grid_fixed_m_by_N(
        dgp=dgp,
        estimators=est4,
        m_fixed=set4.m,
        N_list=list(set4.N_list),
        beta_fn=beta_dense_4,
        n_rep=set4.n_rep,
        rng_seed=rng_seed4,
        use_cache=True,
    )

    fig4 = plot_experiment_vs_N(
        res4, LABELS_DENSE, N_list=list(set4.N_list), y_lim=(-0.01, 0.2)
    )
    save_pdf(fig4, results_dir, "exp04_setting4_fixed_m_dense_mae_l1")


def _run_experiment5(ss: np.random.SeedSequence, results_dir: Path) -> None:
    """Experiment 5: low-rank env means (m, n -> inf with n/m = r fixed) + sparse beta."""
    set5 = Setting5Config()
    est5 = build_estimators_for_sparse()
    # Separate streams: one seeds the DGP's own construction, the other the run.
    ss5_dgp, ss5_run = ss.spawn(2)
    dgp5 = LowRankEnvDGP(
        LowRankEnvDGPConfig(k=set5.k), d=set5.d, rng=np.random.default_rng(ss5_dgp)
    )

    # NOTE(review): this passes the DiscreteEnvDGP class although this
    # experiment samples from LowRankEnvDGP -- confirm that make_sparse_beta's
    # first argument does not depend on the DGP actually used.
    beta_sparse_5 = make_sparse_beta(DiscreteEnvDGP, set5.d, set5.s_star)

    rng_seed5 = _child_seed(ss5_run)
    print(f"running experiment 5 with rng_seed {rng_seed5}")
    res5 = run_grid_by_rN(
        dgp=dgp5,
        estimators=est5,
        r_list=list(set5.r_list),
        N_list=list(set5.N_list),
        beta_fn=beta_sparse_5,
        n_rep=set5.n_rep,
        rng_seed=rng_seed5,
        use_cache=True,
    )
    fig5 = plot_grid_experiment(
        res5,
        list(set5.r_list),
        LABELS_SPARSE,
        N_list=list(set5.N_list),
        y_lim=(0.1, 30),
    )
    save_pdf(fig5, results_dir, "exp05_lowrank_env_sparse_beta_grid_mae_l1")


def main() -> None:
    """Run all experiments and save their figures under ``results_current/``.

    A single master ``SeedSequence(BASE_SEED)`` is spawned into five child
    sequences. Each experiment draws its integer seed from its own child, so
    every experiment is reproducible and no two share a seed.
    """
    results_dir = Path("results_current")
    results_dir.mkdir(parents=True, exist_ok=True)
    apply_mpl_style()

    dgp = DiscreteEnvDGP()

    exp_ss = np.random.SeedSequence(BASE_SEED).spawn(5)

    _run_experiment1(dgp, exp_ss[0], results_dir)
    _run_experiment2(dgp, exp_ss[1], results_dir)
    # Slot 2 (no standalone Experiment 3) feeds the agreement study; reusing
    # exp_ss[1] would hand it exactly Experiment 2's seed.
    _run_experiment2_agreement(dgp, exp_ss[2], results_dir)
    _run_experiment4(dgp, exp_ss[3], results_dir)
    _run_experiment5(exp_ss[4], results_dir)


if __name__ == "__main__":
    # Script entry point: run every experiment when executed directly.
    main()
