from __future__ import annotations

from typing import Dict

from unpaired_iv.estimators.base import Estimator
from unpaired_iv.estimators.baselines import TS2SLS, TSIV
from unpaired_iv.estimators.upgmm import UPGMM, UPGMMConfig
from unpaired_iv.estimators.upgmm_hd import UPGMMHDAnalytic, UPGMMHDAnalyticConfig


def build_estimators_for_dense() -> Dict[str, Estimator]:
    """Build estimator presets for dense coefficients.

    Every call constructs brand-new estimator instances and returns them in a
    fresh dict keyed by preset name. Insertion order is ``ts_2sls``,
    ``ts_iv``, ``up_gmm``, ``up_gmm_hd_analytic``. Both UP-GMM variants are
    configured with ``l1=False``.
    """
    presets: Dict[str, Estimator] = {}

    # Two-sample baselines, default construction.
    presets["ts_2sls"] = TS2SLS()
    presets["ts_iv"] = TSIV()

    # UP-GMM: optimal weighting enabled, split_B=0.
    presets["up_gmm"] = UPGMM(
        UPGMMConfig(l1=False, use_optimal_weight=True, split_B=0)
    )

    # Analytic high-dimensional UP-GMM variant.
    presets["up_gmm_hd_analytic"] = UPGMMHDAnalytic(
        UPGMMHDAnalyticConfig(l1=False)
    )

    # Disabled presets, kept for reference. NaiveOLS, UPGMMHD, UPGMMHD_Moment
    # and JN are not imported at the top of this module, so they must be
    # imported before any of these can be re-enabled:
    #   "naive_ols":        NaiveOLS()
    #   "up_gmm_hd":        UPGMMHD(UPGMMHDConfig(K=2, redraw_B=10, l1=False))
    #   "up_gmm_hd_moment": UPGMMHD_Moment(
    #                           UPGMMHDMomentConfig(K=2, redraw_B=10, l1=False))
    #   "jn":               JN(JNConfig(l1=False, split_B=10))

    return presets


def build_estimators_for_sparse() -> Dict[str, Estimator]:
    """Build estimator presets for sparse coefficients.

    Every call constructs brand-new estimator instances and returns them in a
    fresh dict keyed by preset name. Insertion order is ``ts_2sls``,
    ``ts_iv``, ``up_gmm``, ``up_gmm_hd_analytic``. Both UP-GMM variants share
    the penalty settings ``l1=True, lam_scale=1.0, post_refit=True``.
    """
    # Penalty keyword arguments shared by every L1-configured preset below.
    penalty = {"l1": True, "lam_scale": 1.0, "post_refit": True}

    presets: Dict[str, Estimator] = {}

    # Two-sample baselines, default construction.
    presets["ts_2sls"] = TS2SLS()
    presets["ts_iv"] = TSIV()

    # UP-GMM: shared penalty plus optimal weighting, split_B=0.
    presets["up_gmm"] = UPGMM(
        UPGMMConfig(**penalty, use_optimal_weight=True, split_B=0)
    )

    # Analytic high-dimensional UP-GMM variant; lam_effective_n="m".
    presets["up_gmm_hd_analytic"] = UPGMMHDAnalytic(
        UPGMMHDAnalyticConfig(**penalty, lam_effective_n="m")
    )

    # Disabled presets, kept for reference. NaiveOLS, UPGMMHD, UPGMMHD_Moment
    # and JN are not imported at the top of this module, so they must be
    # imported before any of these can be re-enabled:
    #   "naive_ols":        NaiveOLS()
    #   "up_gmm_hd":        UPGMMHD(UPGMMHDConfig(
    #                           K=2, redraw_B=10, **penalty, lam_effective_n="m"))
    #   "up_gmm_hd_moment": UPGMMHD_Moment(UPGMMHDMomentConfig(
    #                           K=2, redraw_B=10, **penalty, lam_effective_n="m"))
    #   "jn":               JN(JNConfig(
    #                           **penalty, split_B=10, lam_effective_n="m"))

    return presets

