from __future__ import annotations

from pathlib import Path
from typing import Dict

from .utils import LABELS_DENSE, pick_labels, save_pdf
from unpaired_iv_experiments.dgp import DiscreteEnvDGP, make_dense_beta
from unpaired_iv.estimators import (
    TSIV,
    Estimator,
    UPGMMHDAnalytic,
    UPGMMHDAnalyticConfig,
)
from unpaired_iv_experiments.runners import run_heatmap_by_m_tau
from unpaired_iv_experiments.plotting import apply_mpl_style, plot_two_heatmaps


def build_estimators_for_dense() -> Dict[str, Estimator]:
    """Construct the estimators compared in the dense-beta experiments.

    Returns:
        A fresh mapping from estimator key to estimator instance:

        * ``"ts_iv"`` -> ``TSIV()``
        * ``"up_gmm_hd_analytic"`` -> ``UPGMMHDAnalytic`` built from
          ``UPGMMHDAnalyticConfig(l1=False)``.
    """
    two_sample = TSIV()
    # NOTE(review): l1=False presumably disables an L1/sparsity option in the
    # analytic UP-GMM estimator (dense beta design) — confirm against the config.
    analytic = UPGMMHDAnalytic(UPGMMHDAnalyticConfig(l1=False))
    return dict(ts_iv=two_sample, up_gmm_hd_analytic=analytic)


def run_heatmaps(
    *,
    out_dir: str | Path = "results_current",
    n_rep: int = 50,
    rng_seed: int = 123,
    use_cache: bool = True,
) -> None:
    """Run the (m, tau) sweep for the dense-beta design and save the heatmap figure.

    The estimators keyed by ``heat_keys`` are evaluated by
    ``run_heatmap_by_m_tau`` over every combination of ``m_list`` x
    ``tau_list``. The resulting matrices are drawn side by side with
    ``plot_two_heatmaps`` and saved with ``save_pdf`` under ``out_dir`` with the
    name ``"heatmap_m_tau"``. Calling with no arguments reproduces the original
    hard-coded settings.

    Args:
        out_dir: Directory for the output figure; created (with parents) if
            it does not exist.
        n_rep: Repetition count forwarded to ``run_heatmap_by_m_tau``.
        rng_seed: Seed forwarded to ``run_heatmap_by_m_tau``.
        use_cache: Forwarded to ``run_heatmap_by_m_tau``; controls whether
            its result cache may be used.
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    apply_mpl_style()

    dgp = DiscreteEnvDGP()

    # Estimators shown in the two heatmap panels, in panel order. Selecting the
    # estimators by these keys keeps the estimator dict, panel keys and labels
    # in sync instead of repeating the key list by hand.
    heat_keys = ("ts_iv", "up_gmm_hd_analytic")
    est_all = build_estimators_for_dense()
    estimators = {key: est_all[key] for key in heat_keys}
    labels = pick_labels(heat_keys, LABELS_DENSE)

    # Dense beta for heatmaps (keep it fixed across runs for interpretability).
    # NOTE(review): d is presumably the dimension of beta — confirm in make_dense_beta.
    d = 2
    dense_beta = make_dense_beta(d)

    # -------------------------
    # Heatmap: sweep number of environments m against tau = n / m
    # (observations per environment, matching the y-axis label).
    # -------------------------
    m_list = [50, 100, 200, 400, 800, 1600, 3200]
    tau_list = [2, 4, 8, 16, 32]

    # NOTE(review): the cache tag says "continuous" although the DGP is
    # DiscreteEnvDGP, and it does not encode n_rep / rng_seed. If the runner's
    # cache key does not include those, non-default values may hit stale
    # results — confirm against run_heatmap_by_m_tau before relying on it.
    mats_mtau = run_heatmap_by_m_tau(
        dgp=dgp,
        estimators=estimators,
        m_list=m_list,
        tau_list=tau_list,
        beta_fn=dense_beta,
        n_rep=n_rep,
        rng_seed=rng_seed,
        use_cache=use_cache,
        cache_tag="heatmap_m_tau_continuous",
    )

    fig = plot_two_heatmaps(
        mats=mats_mtau,
        est_keys=heat_keys,
        estimator_labels=labels,
        x_vals=m_list,
        y_vals=tau_list,
        x_label="number of environments ($m$) ($\\times 10^2$)",
        y_label="observations per\nenvironment ($\\tau=n/m$)",
        log_x=True,
        log_y=True,
        y_ticks=tau_list,
        y_ticklabels=[str(t) for t in tau_list],
        x_ticks=m_list,
        # Tick labels are m / 100 to match the "x 10^2" unit in the x label.
        x_ticklabels=[f"{m / 100:g}" for m in m_list],
        figsize=(10, 4),
        font_size=18,
        title_size=20,
        x_tick_rotation=0,
        x_tick_ha="center",
    )
    save_pdf(fig, out_path, "heatmap_m_tau")


if __name__ == "__main__":
    # Script entry point: regenerate the (m, tau) heatmap with default settings.
    run_heatmaps()
