from __future__ import annotations

from typing import Any, Callable, Dict, List, Optional

import numpy as np
from tqdm import tqdm

from .cache import (
    cache_load,
    cache_save,
    cache_version,
    estimators_signature,
    jsonable,
)
from unpaired_iv.estimators import Estimator
from unpaired_iv.stats import mean_ci

from .dgp import DiscreteEnvDGP


def run_grid_fixed_m_by_N(
    dgp: DiscreteEnvDGP,
    estimators: Dict[str, Estimator],
    m_fixed: int,
    N_list: List[int],
    beta_fn: Callable[[np.random.Generator], np.ndarray],
    n_rep: int,
    rng_seed: int = 0,
    r_x_mult: float = 1.0,
    r_y_mult: float = 1.0,
    *,
    use_cache: bool = True,
    cache_dir: str = "results_current/cache",
    cache_tag: str = "",
) -> Dict[str, List[Dict[str, float]]]:
    """Run experiments with fixed m and increasing total N.

    For every budget ``N`` in ``N_list`` the per-environment sample sizes are
    ``r_y = floor(r_y_mult * r_base)`` and ``r_x = floor(r_x_mult * r_base)``
    with ``r_base = floor(N / (m_fixed * (r_x_mult + r_y_mult)))``. Because
    every step rounds down, the realised budget ``m_fixed * (r_y + r_x)``
    does not exceed ``N`` (up to a tiny floating-point tolerance).

    Each of the ``n_rep`` replications draws ``beta`` with ``beta_fn``,
    samples data with ``dgp.sample_unpaired`` and records, per estimator, the
    L1 error ``sum(|bhat - beta|)``. The errors of a grid point are then
    summarised with ``mean_ci``.

    Args:
        dgp: Data-generating process providing ``sample_unpaired``.
        estimators: Mapping name -> estimator exposing ``fit(data, rng=...)``.
        m_fixed: Number of environments ``m`` passed to the DGP.
        N_list: Requested total budgets.
        beta_fn: Draws the true coefficient vector from a generator.
        n_rep: Replications per grid point.
        rng_seed: Seed of the master generator that seeds each replication.
        r_x_mult: Relative weight of ``r_x`` in the budget split.
        r_y_mult: Relative weight of ``r_y`` in the budget split.
        use_cache: Look up / store results via ``cache_load``/``cache_save``.
        cache_dir: Cache directory.
        cache_tag: Extra string mixed into the cache key.

    Returns:
        Mapping estimator name -> one record per entry of ``N_list`` with the
        float-valued keys ``N`` (realised budget, NaN for infeasible points),
        ``m``, ``r_y``, ``r_x``, ``mean``, ``lower``, ``upper`` and
        ``median``. A point is infeasible when ``m_fixed < 2``, ``r_y < 1``
        or ``r_x < 1``; its summaries are NaN.

    Raises:
        ValueError: If ``r_x_mult + r_y_mult <= 0``.
    """
    denom = r_x_mult + r_y_mult
    # Validate before touching the cache or opening a progress bar, so a bad
    # call cannot leave a dangling tqdm instance behind.
    if denom <= 0:
        raise ValueError("Need r_x_mult + r_y_mult > 0.")

    if use_cache:
        payload = {
            "version": cache_version(),
            "cache_tag": str(cache_tag),
            "dgp": jsonable(getattr(dgp, "cfg", None)),
            "estimators": estimators_signature(estimators),
            "m_fixed": int(m_fixed),
            "N_list": list(map(int, N_list)),
            "n_rep": int(n_rep),
            "rng_seed": int(rng_seed),
            "r_x_mult": float(r_x_mult),
            "r_y_mult": float(r_y_mult),
            "beta_fn": jsonable(beta_fn),
        }
        cached = cache_load(cache_dir, "run_grid_fixed_m_by_N", payload)
        if cached is not None:
            return cached

    rng = np.random.default_rng(rng_seed)
    results: Dict[str, List[Dict[str, float]]] = {k: [] for k in estimators}

    # Use the exact (possibly fractional) divisor. Truncating it with int()
    # skews r_base and lets m_fixed * (r_y + r_x) overshoot N when
    # m_fixed * denom is not an integer.
    divisor = m_fixed * denom

    with tqdm(total=len(N_list) * n_rep, desc="run_grid_fixed_m_by_N") as pbar:
        for N in N_list:
            # The 1e-9 slack absorbs float noise (e.g. 0.1 + 0.2 != 0.3) so an
            # exactly divisible budget is not rounded down by one.
            r_base = int(np.floor(N / divisor + 1e-9)) if divisor > 0 else 0
            r_y = int(np.floor(r_y_mult * r_base))
            r_x = int(np.floor(r_x_mult * r_base))

            if m_fixed < 2 or r_y < 1 or r_x < 1:
                # Infeasible grid point: keep one record per N for every
                # estimator, with NaN summaries and float-typed sizes.
                for k in estimators:
                    results[k].append(
                        {
                            "N": float("nan"),
                            "m": float(m_fixed),
                            "r_y": float(r_y),
                            "r_x": float(r_x),
                            "mean": np.nan,
                            "lower": np.nan,
                            "upper": np.nan,
                            "median": np.nan,
                        }
                    )
                pbar.update(n_rep)
                continue

            N_eff = m_fixed * (r_y + r_x)

            errs: Dict[str, List[float]] = {k: [] for k in estimators}
            for _ in range(n_rep):
                # Fresh generator seeded from the master rng; it drives beta,
                # the sampled data and every estimator's fit.
                local = np.random.default_rng(rng.integers(1_000_000_000))
                beta = beta_fn(local)
                data = dgp.sample_unpaired(
                    m=m_fixed, r_y=r_y, r_x=r_x, beta=beta, rng=local
                )

                for k, est in estimators.items():
                    bhat = est.fit(data, rng=local)
                    errs[k].append(float(np.sum(np.abs(bhat - beta))))

                pbar.update(1)

            for k in estimators:
                mu, lo, up, med = mean_ci(errs[k])
                results[k].append(
                    {
                        "N": float(N_eff),
                        "m": float(m_fixed),
                        "r_y": float(r_y),
                        "r_x": float(r_x),
                        "mean": float(mu),
                        "lower": float(lo),
                        "upper": float(up),
                        "median": float(med),
                    }
                )

    if use_cache:
        cache_save(cache_dir, "run_grid_fixed_m_by_N", payload, results)
    return results


def run_grid_by_rN(
    dgp: DiscreteEnvDGP,
    estimators: Dict[str, Estimator],
    r_list: List[int],
    N_list: List[int],
    beta_fn: Callable[[np.random.Generator], np.ndarray],
    n_rep: int,
    rng_seed: int = 0,
    *,
    use_cache: bool = True,
    cache_dir: str = "results_current/cache",
    cache_tag: str = "",
    return_rep_errs: bool = False,
    store_estimates_for: Optional[List[str]] = None,
) -> Dict[int, Dict[str, List[Dict[str, Any]]]]:
    """Run experiments varying r and total N.

    For every ``r`` in ``r_list`` and ``N`` in ``N_list`` the design uses
    ``r_y = r_x = r`` samples per environment and ``m = N // (2 * r)``
    environments. Grid points with ``m < 2`` are not simulated and get NaN
    summaries. Each of the ``n_rep`` replications draws ``beta`` with
    ``beta_fn``, samples data with ``dgp.sample_unpaired`` and records, per
    estimator, the L1 error ``sum(|bhat - beta|)``. The errors are then
    summarised with ``mean_ci``.

    Args:
        dgp: Data-generating process providing ``sample_unpaired``.
        estimators: Mapping name -> estimator exposing ``fit(data, rng=...)``.
        r_list: Per-environment sample sizes to try (each must be >= 1).
        N_list: Total budgets to try.
        beta_fn: Draws the true coefficient vector from a generator.
        n_rep: Replications per grid point.
        rng_seed: Seed of the master generator that seeds each replication.
        use_cache: Look up / store results via ``cache_load``/``cache_save``.
        cache_dir: Cache directory.
        cache_tag: Extra string mixed into the cache key.
        return_rep_errs: Also store the per-replication errors under
            ``"rep_errs"``.
        store_estimates_for: Estimator names for which the true ``beta``,
            the estimate ``bhat`` (as lists) and the replication seeds are
            stored under ``"rep_beta"``, ``"rep_bhat"`` and ``"rep_seeds"``.

    Returns:
        ``results[r][name]`` is a list with one record per ``N`` holding
        ``N``, ``m``, ``mean``, ``lower``, ``upper``, ``median`` plus the
        optional per-replication fields described above.

    Raises:
        ValueError: If any ``r`` in ``r_list`` is smaller than 1.
    """
    # r == 0 used to crash with ZeroDivisionError in N // (2 * r), and
    # negative r silently yielded all-NaN rows; reject both explicitly.
    bad_r = [int(r) for r in r_list if int(r) < 1]
    if bad_r:
        raise ValueError(f"All r in r_list must be >= 1, got {bad_r}.")

    store_set = set(store_estimates_for) if store_estimates_for is not None else set()

    if use_cache:
        payload = {
            "version": cache_version(),
            "cache_tag": str(cache_tag),
            "dgp": jsonable(getattr(dgp, "cfg", None)),
            "estimators": estimators_signature(estimators),
            "r_list": list(map(int, r_list)),
            "N_list": list(map(int, N_list)),
            "n_rep": int(n_rep),
            "rng_seed": int(rng_seed),
            "beta_fn": jsonable(beta_fn),
            "return_rep_errs": bool(return_rep_errs),
            "store_estimates_for": sorted(store_set),
        }
        cached = cache_load(cache_dir, "run_grid_by_rN", payload)
        if cached is not None:
            return cached

    rng = np.random.default_rng(rng_seed)
    results: Dict[int, Dict[str, List[Dict[str, Any]]]] = {
        int(r): {k: [] for k in estimators} for r in r_list
    }

    total_steps = len(r_list) * len(N_list) * n_rep
    with tqdm(total=total_steps, desc="run_grid_by_rN") as pbar:
        for r in map(int, r_list):
            for N in map(int, N_list):
                m = N // (2 * r)

                if m < 2:
                    # Infeasible grid point: emit placeholder records so every
                    # estimator still has one entry per N.
                    for k in estimators:
                        rec: Dict[str, Any] = {
                            "N": float(N),
                            "m": float(m),
                            "mean": np.nan,
                            "lower": np.nan,
                            "upper": np.nan,
                            "median": np.nan,
                        }
                        if return_rep_errs:
                            rec["rep_errs"] = [float("nan")] * int(n_rep)
                        if k in store_set:
                            rec["rep_beta"] = []
                            rec["rep_bhat"] = []
                            rec["rep_seeds"] = []
                        results[r][k].append(rec)
                    pbar.update(n_rep)
                    continue

                errs: Dict[str, List[float]] = {k: [] for k in estimators}
                rep_errs: Dict[str, List[float]] = {k: [] for k in estimators}
                rep_beta: Dict[str, List[List[float]]] = {k: [] for k in estimators}
                rep_bhat: Dict[str, List[List[float]]] = {k: [] for k in estimators}
                rep_seeds: List[int] = []

                for _ in range(n_rep):
                    # The seed is kept so a replication can be replayed later.
                    local_seed = int(rng.integers(1_000_000_000))
                    local = np.random.default_rng(local_seed)
                    beta = beta_fn(local)
                    data = dgp.sample_unpaired(
                        m=m, r_y=r, r_x=r, beta=beta, rng=local
                    )

                    rep_seeds.append(local_seed)
                    for k, est in estimators.items():
                        bhat = est.fit(data, rng=local)
                        err = float(np.sum(np.abs(bhat - beta)))
                        errs[k].append(err)
                        if return_rep_errs:
                            rep_errs[k].append(err)
                        if k in store_set:
                            rep_beta[k].append(beta.tolist())
                            rep_bhat[k].append(np.asarray(bhat).tolist())

                    pbar.update(1)

                for k in estimators:
                    mu, lo, up, med = mean_ci(errs[k])
                    rec = {
                        "N": float(N),
                        "m": float(m),
                        "mean": float(mu),
                        "lower": float(lo),
                        "upper": float(up),
                        "median": float(med),
                    }
                    if return_rep_errs:
                        rec["rep_errs"] = rep_errs[k]
                    if k in store_set:
                        rec["rep_beta"] = rep_beta[k]
                        rec["rep_bhat"] = rep_bhat[k]
                        # Copy per estimator: sharing one mutable list across
                        # records aliases them (and a cache round-trip would
                        # silently break that aliasing).
                        rec["rep_seeds"] = list(rep_seeds)
                    results[r][k].append(rec)

    if use_cache:
        cache_save(cache_dir, "run_grid_by_rN", payload, results)
    return results


def run_heatmap_by_m_tau(
    dgp: DiscreteEnvDGP,
    estimators: Dict[str, Estimator],
    m_list: List[int],
    tau_list: List[int],
    beta_fn: Callable[[np.random.Generator], np.ndarray],
    n_rep: int,
    rng_seed: int = 0,
    *,
    use_cache: bool = True,
    cache_dir: str = "results_current/cache",
    cache_tag: str = "",
) -> Dict[str, np.ndarray]:
    """Run heatmap experiments over m and tau.

    For every ``m`` in ``m_list`` and ``tau`` in ``tau_list`` the design uses
    ``m`` environments with ``r_y = r_x = tau`` samples each. Each cell holds
    the mean, over ``n_rep`` replications, of the L1 error
    ``sum(|bhat - beta|)``.

    Args:
        dgp: Data-generating process providing ``sample_unpaired``.
        estimators: Mapping name -> estimator exposing ``fit(data, rng=...)``.
        m_list: Numbers of environments (heatmap columns).
        tau_list: Per-environment sample sizes (heatmap rows).
        beta_fn: Draws the true coefficient vector from a generator.
        n_rep: Replications per cell.
        rng_seed: Seed of the master generator that seeds each replication.
        use_cache: Look up / store results via ``cache_load``/``cache_save``.
        cache_dir: Cache directory.
        cache_tag: Extra string mixed into the cache key.

    Returns:
        Mapping estimator name -> array of shape
        ``(len(tau_list), len(m_list))`` indexed ``[tau_index, m_index]``.
        Cells with ``m < 2`` or ``tau < 1`` are infeasible and left as NaN,
        consistent with the other grid runners in this module, which skip
        ``m < 2``.
    """
    if use_cache:
        payload = {
            "version": cache_version(),
            "cache_tag": str(cache_tag),
            "dgp": jsonable(getattr(dgp, "cfg", None)),
            "estimators": estimators_signature(estimators),
            "m_list": list(map(int, m_list)),
            "tau_list": list(map(int, tau_list)),
            "n_rep": int(n_rep),
            "rng_seed": int(rng_seed),
            "beta_fn": jsonable(beta_fn),
        }
        cached = cache_load(cache_dir, "run_heatmap_by_m_tau", payload)
        if cached is not None:
            return cached

    rng = np.random.default_rng(rng_seed)

    # Start from NaN so skipped (infeasible) cells need no explicit fill.
    mats = {k: np.full((len(tau_list), len(m_list)), np.nan) for k in estimators}
    total_steps = len(m_list) * len(tau_list) * n_rep

    with tqdm(total=total_steps, desc="run_heatmap_by_m_tau") as pbar:
        for i, m in enumerate(m_list):
            m = int(m)
            for j, tau in enumerate(tau_list):
                tau = int(tau)
                if m < 2 or tau < 1:
                    pbar.update(n_rep)
                    continue

                errs: Dict[str, List[float]] = {k: [] for k in estimators}
                for _ in range(n_rep):
                    # Fresh generator seeded from the master rng; it drives
                    # beta, the sampled data and every estimator's fit.
                    local = np.random.default_rng(rng.integers(1_000_000_000))
                    beta = beta_fn(local)
                    data = dgp.sample_unpaired(
                        m=m, r_y=tau, r_x=tau, beta=beta, rng=local
                    )

                    for k, est in estimators.items():
                        bhat = est.fit(data, rng=local)
                        errs[k].append(float(np.sum(np.abs(bhat - beta))))
                    pbar.update(1)

                for k in estimators:
                    mats[k][j, i] = np.mean(errs[k])

    if use_cache:
        cache_save(cache_dir, "run_heatmap_by_m_tau", payload, mats)
    return mats


def run_heatmap_by_N_tau(
    dgp: DiscreteEnvDGP,
    estimators: Dict[str, Estimator],
    N_list: List[int],
    tau_list: List[int],
    beta_fn: Callable[[np.random.Generator], np.ndarray],
    n_rep: int,
    rng_seed: int = 0,
    *,
    use_cache: bool = True,
    cache_dir: str = "results_current/cache",
    cache_tag: str = "",
) -> Dict[str, np.ndarray]:
    """Run heatmap experiments over N and tau.

    For every budget ``N`` in ``N_list`` and ``tau`` in ``tau_list`` the
    design uses ``r_y = r_x = tau`` samples per environment and
    ``m = N // (2 * tau)`` environments. Each cell holds the mean, over
    ``n_rep`` replications, of the L1 error ``sum(|bhat - beta|)``.

    Args:
        dgp: Data-generating process providing ``sample_unpaired``.
        estimators: Mapping name -> estimator exposing ``fit(data, rng=...)``.
        N_list: Total budgets (heatmap columns).
        tau_list: Per-environment sample sizes (heatmap rows).
        beta_fn: Draws the true coefficient vector from a generator.
        n_rep: Replications per cell.
        rng_seed: Seed of the master generator that seeds each replication.
        use_cache: Look up / store results via ``cache_load``/``cache_save``.
        cache_dir: Cache directory.
        cache_tag: Extra string mixed into the cache key.

    Returns:
        Mapping estimator name -> array of shape
        ``(len(tau_list), len(N_list))`` indexed ``[tau_index, N_index]``.
        Cells with ``tau < 1`` or ``m < 2`` are infeasible and left as NaN.
    """
    if use_cache:
        payload = {
            "version": cache_version(),
            "cache_tag": str(cache_tag),
            "dgp": jsonable(getattr(dgp, "cfg", None)),
            "estimators": estimators_signature(estimators),
            "N_list": list(map(int, N_list)),
            "tau_list": list(map(int, tau_list)),
            "n_rep": int(n_rep),
            "rng_seed": int(rng_seed),
            "beta_fn": jsonable(beta_fn),
        }
        cached = cache_load(cache_dir, "run_heatmap_by_N_tau", payload)
        if cached is not None:
            return cached

    rng = np.random.default_rng(rng_seed)

    # Start from NaN so skipped (infeasible) cells need no explicit fill.
    mats = {k: np.full((len(tau_list), len(N_list)), np.nan) for k in estimators}
    total_steps = len(N_list) * len(tau_list) * n_rep

    with tqdm(total=total_steps, desc="run_heatmap_by_N_tau") as pbar:
        for i, N in enumerate(N_list):
            N = int(N)
            for j, tau in enumerate(tau_list):
                tau = int(tau)
                # tau < 1 is infeasible; guarding here also avoids the
                # ZeroDivisionError that tau == 0 used to raise.
                m = N // (2 * tau) if tau >= 1 else 0
                if m < 2:
                    pbar.update(n_rep)
                    continue

                errs: Dict[str, List[float]] = {k: [] for k in estimators}
                for _ in range(n_rep):
                    # Fresh generator seeded from the master rng; it drives
                    # beta, the sampled data and every estimator's fit.
                    local = np.random.default_rng(rng.integers(1_000_000_000))
                    beta = beta_fn(local)
                    data = dgp.sample_unpaired(
                        m=m, r_y=tau, r_x=tau, beta=beta, rng=local
                    )

                    for k, est in estimators.items():
                        bhat = est.fit(data, rng=local)
                        errs[k].append(float(np.sum(np.abs(bhat - beta))))
                    pbar.update(1)

                for k in estimators:
                    mats[k][j, i] = np.mean(errs[k])

    if use_cache:
        cache_save(cache_dir, "run_heatmap_by_N_tau", payload, mats)
    return mats

