from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import numpy as np  # noqa: E402

from unpaired_iv_experiments.dgp import (  # noqa: E402
    DiscreteEnvDGPCorrelated,
    LowRankEnvDGPConfig,
    LowRankEnvDGPCorrelated,
    make_dense_beta,
    make_sparse_beta,
)
from unpaired_iv_experiments.plotting import (  # noqa: E402
    apply_mpl_style,
    plot_experiment_vs_N,
    plot_grid_experiment,
)
from unpaired_iv_experiments.presets import (  # noqa: E402
    build_estimators_for_dense,
    build_estimators_for_sparse,
)
from unpaired_iv_experiments.runners import (  # noqa: E402
    run_grid_by_rN,
    run_grid_fixed_m_by_N,
)

from .utils import LABELS_DENSE, LABELS_SPARSE, save_pdf


@dataclass(frozen=True)
class Setting1Config:
    """Immutable defaults for experiment 1.

    Experiment 1 uses a sparse beta with ``m`` held fixed and ``d > m``.
    It sweeps the total sample size over ``N_list``.
    """

    # Held fixed over the whole N sweep (handed to the runner as ``m_fixed``).
    m: int = 100
    # Passed to ``make_sparse_beta``; intentionally larger than ``m``.
    d: int = 200
    # Passed to ``make_sparse_beta`` (presumably the number of non-zeros).
    s_star: int = 10
    # 400 doubled eight times: 400, 800, ..., 102400.
    N_list: tuple[int, ...] = tuple(400 * 2**step for step in range(9))
    # Number of repetitions per grid point (forwarded as ``n_rep``).
    n_rep: int = 50


@dataclass(frozen=True)
class Setting2Config:
    """Immutable defaults for experiment 2.

    Experiment 2 uses a dense beta with many instruments. It sweeps a
    grid over the ratio ``r`` (n/m held fixed) and the sample size ``N``.
    """

    # Passed to ``make_dense_beta``.
    d: int = 2
    # Values of the fixed ratio n/m swept by ``run_grid_by_rN``.
    r_list: tuple[int, ...] = (4, 8, 16, 32)
    # Total sample sizes, doubling from 400 up to 51200.
    N_list: tuple[int, ...] = (400, 800, 1600, 3200, 6400, 12800, 25600, 51200)
    # Number of repetitions per (r, N) cell.
    n_rep: int = 50


@dataclass(frozen=True)
class Setting4Config:
    """Immutable defaults for experiment 4.

    Experiment 4 uses a dense beta with ``m`` held fixed and ``d < m``.
    It sweeps the total sample size over ``N_list``.
    """

    # Held fixed over the whole N sweep (handed to the runner as ``m_fixed``).
    m: int = 50
    # Passed to ``make_dense_beta``; intentionally smaller than ``m``.
    d: int = 2
    # 800 doubled seven times: 800, 1600, ..., 102400.
    N_list: tuple[int, ...] = tuple(800 * 2**step for step in range(8))
    # Number of repetitions per grid point.
    n_rep: int = 50


@dataclass(frozen=True)
class Setting5Config:
    """Immutable defaults for experiment 5.

    Experiment 5 uses low-rank environment means and a sparse beta. It
    sweeps a grid over the ratio ``r`` (n/m held fixed) and the sample
    size ``N``.
    """

    # Forwarded as ``LowRankEnvDGPConfig(k=...)`` (presumably the rank).
    k: int = 60
    # Dimension passed to both the low-rank DGP and ``make_sparse_beta``.
    d: int = 100
    # Passed to ``make_sparse_beta`` (presumably the number of non-zeros).
    s_star: int = 10
    # Values of the fixed ratio n/m swept by ``run_grid_by_rN``.
    r_list: tuple[int, ...] = (4, 8, 16, 32)
    # 400 doubled seven times: 400, 800, ..., 51200.
    N_list: tuple[int, ...] = tuple(400 * 2**step for step in range(8))
    # Number of repetitions per (r, N) cell.
    n_rep: int = 50


def _derive_seed(seq: np.random.SeedSequence) -> int:
    """Return one 32-bit integer seed generated from ``seq``.

    Every experiment below uses this conversion to turn a child
    ``SeedSequence`` into the plain ``int`` the runners take as ``rng_seed``.
    """
    return int(seq.generate_state(1, dtype=np.uint32)[0])


def main(
    results_dir: Path | str = "results_current",
    base_seed: int = 1234,
) -> None:
    """Run experiments 1, 2, 4 and 5 and save one PDF figure for each.

    Args:
        results_dir: Output directory for the figures. It is created (with
            any missing parents) if absent. Defaults to ``"results_current"``,
            resolved against the current working directory.
        base_seed: Entropy for the master ``np.random.SeedSequence``. Every
            per-experiment seed is derived from it, so a run is reproducible
            from this single value. Defaults to ``1234``.
    """
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    apply_mpl_style()

    dgp = DiscreteEnvDGPCorrelated()

    # One child sequence per experiment slot. Slot 2 is never used below
    # (presumably reserved for an experiment 3 that is not run here). Keep
    # spawning five children so the seeds of experiments 4 and 5 stay put.
    master_ss = np.random.SeedSequence(base_seed)
    exp_ss = master_ss.spawn(5)

    labels = LABELS_DENSE
    labels_sparse = LABELS_SPARSE

    # -------------------------
    # Experiment 1 (Setting 1): Sparse beta, fixed m, fixed d>m, N->∞ via r->∞
    # -------------------------
    set1 = Setting1Config()
    est1 = build_estimators_for_sparse()

    # NOTE(review): the DGP *class* (not the ``dgp`` instance) is passed here;
    # confirm this matches make_sparse_beta's expected first argument.
    beta_sparse_1 = make_sparse_beta(DiscreteEnvDGPCorrelated, set1.d, set1.s_star)

    rng_seed1 = _derive_seed(exp_ss[0])
    print(f"running experiment 1 with rng_seed {rng_seed1}")
    res1 = run_grid_fixed_m_by_N(
        dgp=dgp,
        estimators=est1,
        m_fixed=set1.m,
        N_list=list(set1.N_list),
        beta_fn=beta_sparse_1,
        n_rep=set1.n_rep,
        rng_seed=rng_seed1,
        use_cache=False,
    )
    fig1 = plot_experiment_vs_N(
        res1, labels_sparse, N_list=list(set1.N_list), y_lim=(-2, 30)
    )
    save_pdf(fig1, results_dir, "exp01_setting1_fixed_m_sparse_mae_l1_ablation")

    # -------------------------
    # Experiment 2 (Setting 2): Dense beta, high-dimensional instruments
    # (m, n to infty and n/m = r fixed)
    # -------------------------
    set2 = Setting2Config()
    est2 = build_estimators_for_dense()

    beta_dense_2 = make_dense_beta(set2.d)

    rng_seed2 = _derive_seed(exp_ss[1])
    print(f"running experiment 2 with rng_seed {rng_seed2}")
    res2 = run_grid_by_rN(
        dgp=dgp,
        estimators=est2,
        r_list=list(set2.r_list),
        N_list=list(set2.N_list),
        beta_fn=beta_dense_2,
        n_rep=set2.n_rep,
        rng_seed=rng_seed2,
        return_rep_errs=True,
        store_estimates_for=["up_gmm_hd", "up_gmm_hd_analytic"],
        use_cache=False,
    )

    fig2 = plot_grid_experiment(
        res2, list(set2.r_list), labels, N_list=list(set2.N_list), y_lim=(-0.01, 0.4)
    )
    save_pdf(
        fig2,
        results_dir,
        "exp02_setting2_many_instruments_dense_grid_mae_l1_ablation",
    )

    # -------------------------
    # Experiment 4 (Setting 4): Fixed m, fixed d<m, dense beta
    # -------------------------
    set4 = Setting4Config()
    est4 = build_estimators_for_dense()

    beta_dense_4 = make_dense_beta(set4.d)

    rng_seed4 = _derive_seed(exp_ss[3])
    print(f"running experiment 4 with rng_seed {rng_seed4}")
    res4 = run_grid_fixed_m_by_N(
        dgp=dgp,
        estimators=est4,
        m_fixed=set4.m,
        N_list=list(set4.N_list),
        beta_fn=beta_dense_4,
        n_rep=set4.n_rep,
        rng_seed=rng_seed4,
        use_cache=False,
    )

    fig4 = plot_experiment_vs_N(
        res4, labels, N_list=list(set4.N_list), y_lim=(-0.01, 0.2)
    )
    save_pdf(fig4, results_dir, "exp04_setting4_fixed_m_dense_mae_l1_ablation")

    # -------------------------
    # Experiment 5: Low-rank env means (m,n->∞ with n/m=r fixed) + sparse beta.
    # -------------------------
    set5 = Setting5Config()
    est5 = build_estimators_for_sparse()
    # Split slot 4 once more. The DGP's own generator and the runner's seed
    # then come from separate child sequences.
    ss5_dgp, ss5_run = exp_ss[4].spawn(2)
    dgp5 = LowRankEnvDGPCorrelated(
        LowRankEnvDGPConfig(k=set5.k), d=set5.d, rng=np.random.default_rng(ss5_dgp)
    )

    # NOTE(review): the beta is built with DiscreteEnvDGPCorrelated although
    # this experiment simulates from LowRankEnvDGPCorrelated. Confirm this
    # is intentional and not a copy-paste from experiment 1.
    beta_sparse_5 = make_sparse_beta(DiscreteEnvDGPCorrelated, set5.d, set5.s_star)

    rng_seed5 = _derive_seed(ss5_run)
    print(f"running experiment 5 with rng_seed {rng_seed5}")
    res5 = run_grid_by_rN(
        dgp=dgp5,
        estimators=est5,
        r_list=list(set5.r_list),
        N_list=list(set5.N_list),
        beta_fn=beta_sparse_5,
        n_rep=set5.n_rep,
        rng_seed=rng_seed5,
        use_cache=False,
    )
    fig5 = plot_grid_experiment(
        res5,
        list(set5.r_list),
        labels_sparse,
        N_list=list(set5.N_list),
        y_lim=(0.1, 30),
    )
    save_pdf(
        fig5,
        results_dir,
        "exp05_lowrank_env_sparse_beta_grid_mae_l1_ablation",
    )


if __name__ == "__main__":
    # Run the experiments only when executed directly, not on import. The
    # relative ``from .utils import ...`` above means this file has to be
    # run as part of its package (``python -m <package>.<module>``), not as
    # a bare script path.
    main()
