from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable

import matplotlib.pyplot as plt

# Display labels keyed by method identifier (presumably for the dense
# setting -- confirm against callers). Values containing ``$...$`` are raw
# strings so the backslashes reach the renderer untouched.
# "up_gmm_hd" and "up_gmm_hd_analytic" intentionally share the same label.
LABELS_DENSE: Dict[str, str] = dict(
    naive_ols="Naive OLS",
    ts_2sls="TS-2SLS",
    ts_iv="TS-IV",
    up_gmm="UP-GMM",
    up_gmm_hd=r"$\mathrm{SplitUP}$",
    up_gmm_hd_moment=r"$\mathrm{SplitUP}$ (moment lasso)",
    up_gmm_hd_analytic=r"$\mathrm{SplitUP}$",
)

# Display labels keyed by method identifier (presumably for the sparse
# setting -- confirm against callers). The UP-GMM / SplitUP entries carry
# an ``($\ell_1$)`` suffix; raw strings keep the backslashes intact.
# "up_gmm_hd" and "up_gmm_hd_analytic" intentionally share the same label.
LABELS_SPARSE: Dict[str, str] = dict(
    naive_ols="Naive OLS",
    ts_2sls="TS-2SLS",
    ts_iv="TS-IV",
    up_gmm=r"UP-GMM ($\ell_1$)",
    up_gmm_hd=r"$\mathrm{SplitUP}$ ($\ell_1$)",
    up_gmm_hd_moment=r"$\mathrm{SplitUP}$ ($\ell_1$) (moment lasso)",
    up_gmm_hd_analytic=r"$\mathrm{SplitUP}$ ($\ell_1$)",
)


def pick_labels(keys: Iterable[str], labels: Dict[str, str]) -> Dict[str, str]:
    """Return the subset of *labels* whose keys appear in *keys*.

    Entries follow the order in which keys are first seen in *keys*.
    Keys that are absent from *labels* are silently skipped.
    """
    selected: Dict[str, str] = {}
    for key in keys:
        if key in labels:
            selected[key] = labels[key]
    return selected


def save_pdf(fig: plt.Figure, results_dir: Path, stem: str) -> None:
    """Write *fig* to ``results_dir / f"{stem}.pdf"`` and close it.

    Args:
        fig: Figure to save; it is closed before this function returns.
        results_dir: Output directory, created (with parents) if missing.
        stem: File name without the ``.pdf`` extension.

    Raises:
        Whatever ``Path.mkdir`` or ``Figure.savefig`` raises; the figure is
        still closed in that case.
    """
    try:
        results_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(results_dir / f"{stem}.pdf", bbox_inches="tight")
    finally:
        # Close even on failure so the figure does not stay registered with
        # pyplot when directory creation or saving raises.
        plt.close(fig)



