from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from causal_chambers_plotting import plot_single_panel  # noqa: E402
from tqdm import tqdm


# -----------------------------
# Imports from your repo
# -----------------------------
def _import_estimators():
    # Estimators
    from unpaired_iv.estimators import TSIV, UPGMM, UPGMMConfig

    # Prefer analytic UP-GMM-HD if present; otherwise fall back to split-based one.
    up_hd = None
    up_hd_name = None
    try:
        from unpaired_iv.estimators import UPGMMHDAnalytic, UPGMMHDAnalyticConfig

        up_hd = UPGMMHDAnalytic(UPGMMHDAnalyticConfig(l1=False))
        up_hd_name = "up_gmm_hd"
    except Exception:
        try:
            from unpaired_iv.estimators import UPGMMHD, UPGMMHDConfig

            up_hd = UPGMMHD(UPGMMHDConfig(K=2, redraw_B=10, l1=False))
            up_hd_name = "up_gmm_hd"
        except Exception as e:
            raise ImportError(
                "Could not import UP-GMM-HD estimator (analytic or split). "
                "Check your unpaired_iv.estimators exports."
            ) from e

    # Data container
    from unpaired_iv.data import UnpairedIVData

    return TSIV, UPGMM, UPGMMConfig, up_hd, up_hd_name, UnpairedIVData


# -----------------------------
# Core utilities
# -----------------------------
def ols_beta_x_given_red(df: pd.DataFrame) -> float:
    """
    Reference linear effect of ``ir_1`` on ``ir_2`` adjusting for the observed ``red``.

    Fits ``ir_2 ~ 1 + ir_1 + red`` by least squares and returns the ``ir_1``
    coefficient as a Python float.
    """
    target, treatment, confounder = (
        df[col].to_numpy(dtype=float) for col in ("ir_2", "ir_1", "red")
    )
    design = np.column_stack([np.ones_like(treatment), treatment, confounder])
    coef = np.linalg.lstsq(design, target, rcond=None)[0]
    return float(coef[1])


@dataclass(frozen=True)
class EnvCache:
    """
    Per-environment X/Y arrays used for fast subsampling and unpairing.

    Index k of ``X_by_env`` / ``Y_by_env`` belongs to environment ``env_values[k]``.
    The dataclass is frozen (attributes cannot be rebound), but the contained
    arrays and lists themselves remain mutable.
    """

    # Sorted unique environment ids (original labels).
    env_values: np.ndarray
    # Original environment id -> position in 0..m-1.
    env_to_pos: dict[int, int]
    # Per-environment X values, each of shape (r,).
    X_by_env: list[np.ndarray]
    # Per-environment Y values, each of shape (r,).
    Y_by_env: list[np.ndarray]


def build_env_cache(
    df: pd.DataFrame, r_expected: int = 8, seed: int = 0
) -> Tuple[pd.DataFrame, EnvCache]:
    """
    Down-sample ``df`` to exactly ``r_expected`` rows per environment and cache arrays.

    Environments are identified by the ``led_1_ir`` column. Environments with fewer
    than ``r_expected`` rows are dropped. Environments with more are down-sampled
    without replacement, using an RNG seeded with ``seed``, so every kept environment
    has the same r.

    Args:
        df: Input data; must contain ``red``, ``led_1_ir``, ``ir_1`` and ``ir_2``.
        r_expected: Number of rows to keep per environment.
        seed: Seed for the down-sampling RNG.

    Returns:
        ``(df_fixed, cache)``. ``df_fixed`` holds the retained rows, with all original
        columns and a fresh RangeIndex. ``cache`` holds, per environment in sorted-id
        order, the ``ir_1`` (X) and ``ir_2`` (Y) values.

    Raises:
        ValueError: if a required column is missing, or if no environment has at
            least ``r_expected`` rows.
    """
    rng = np.random.default_rng(seed)

    # Validate the columns relied on below (the frame itself is not column-subset).
    needed = ["red", "led_1_ir", "ir_1", "ir_2"]
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    counts = df["led_1_ir"].value_counts()
    ok_envs = counts[counts >= r_expected].index.to_numpy()
    if ok_envs.size == 0:
        largest = int(counts.max()) if len(counts) else 0
        raise ValueError(
            f"No environment (led_1_ir) has at least r_expected={r_expected} rows; "
            f"the largest environment has {largest} rows."
        )

    df_ok = df[df["led_1_ir"].isin(ok_envs)].copy()

    # Sample exactly r_expected rows per env. Environments are visited in order of
    # first appearance (sort=False), which fixes the order of RNG draws.
    parts = []
    for _, g in df_ok.groupby("led_1_ir", sort=False):
        if len(g) == r_expected:
            parts.append(g)
        else:
            # Draw positions rather than index labels: Generator.choice yields the
            # same draws either way, but positional selection cannot return extra
            # rows when the index contains duplicate labels.
            take = rng.choice(len(g), size=r_expected, replace=False)
            parts.append(g.iloc[take])
    df_fixed = pd.concat(parts, axis=0, ignore_index=True)

    env_values = np.sort(df_fixed["led_1_ir"].unique())
    env_to_pos = {int(e): i for i, e in enumerate(env_values.tolist())}

    # One groupby pass (O(N)) instead of one boolean mask per environment (O(m*N)).
    # groupby preserves the within-group row order of df_fixed.
    groups = {key: grp for key, grp in df_fixed.groupby("led_1_ir", sort=True)}

    X_by_env: List[np.ndarray] = []
    Y_by_env: List[np.ndarray] = []
    for e in env_values:
        ge = groups[e]  # exactly r_expected rows
        X_by_env.append(ge["ir_1"].to_numpy(dtype=float))
        Y_by_env.append(ge["ir_2"].to_numpy(dtype=float))

    cache = EnvCache(
        env_values=env_values,
        env_to_pos=env_to_pos,
        X_by_env=X_by_env,
        Y_by_env=Y_by_env,
    )
    return df_fixed, cache


def unpair_balanced_within_env(
    cache: EnvCache,
    env_pos_list: np.ndarray,  # positions in [0..m_total-1]
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Split each selected environment's rows into disjoint Y- and X-samples.

    For the j-th entry of ``env_pos_list``, that environment's rows are shuffled with
    ``rng``. The first ``r // 2`` go to the Y-sample and the remaining ``r - r // 2``
    go to the X-sample. Every row is used exactly once, so n/m stays fixed.

    Args:
        cache: Per-environment arrays (``X_by_env`` / ``Y_by_env``).
        env_pos_list: Positions into the cache; the j-th entry becomes column j
            of the indicator matrices.
        rng: Generator used for the within-environment shuffles.

    Returns:
      I_y: (n_y, m_sub) one-hot
      Y:   (n_y,)
      I_x: (n_x, m_sub) one-hot
      X:   (n_x, 1)
    """
    n_envs = int(env_pos_list.size)
    one_hot = np.eye(n_envs, dtype=float)

    if n_envs == 0:
        no_rows = np.zeros(0, dtype=int)
        return one_hot[no_rows], np.zeros(0), one_hot[no_rows], np.zeros((0, 1))

    y_codes: List[np.ndarray] = []
    y_chunks: List[np.ndarray] = []
    x_codes: List[np.ndarray] = []
    x_chunks: List[np.ndarray] = []

    for code, pos in enumerate(env_pos_list):
        xs = cache.X_by_env[int(pos)]
        ys = cache.Y_by_env[int(pos)]
        n_rows = int(xs.shape[0])
        assert n_rows == ys.shape[0]

        shuffled = rng.permutation(n_rows)
        cut = n_rows // 2
        to_y, to_x = shuffled[:cut], shuffled[cut:]

        y_chunks.append(np.asarray(ys[to_y], dtype=float))
        x_chunks.append(np.asarray(xs[to_x], dtype=float))
        y_codes.append(np.full(to_y.size, code, dtype=int))
        x_codes.append(np.full(to_x.size, code, dtype=int))

    I_y = one_hot[np.concatenate(y_codes)]
    I_x = one_hot[np.concatenate(x_codes)]
    Y = np.concatenate(y_chunks)
    X = np.concatenate(x_chunks).reshape(-1, 1)
    return I_y, Y, I_x, X


def pick_m_list(m_total: int, r_expected: int = 8, min_N: int = 400) -> List[int]:
    """
    Choose an increasing list of m values up to m_total (so N = m*r_expected increases),
    roughly similar density to your Setting 2 N_list.
    """
    # Target N ladder, then map to m via ceil(N/r) so all curves start at min_N.
    N_candidates = [400, 1600, 6400, 12000]
    N_candidates = [N for N in N_candidates if N >= min_N]
    m_list = [int(np.ceil(N / r_expected)) for N in N_candidates]
    m_list = [m for m in m_list if 2 <= m <= m_total]
    if m_total not in m_list:
        m_list.append(m_total)
    m_list = sorted(set(m_list))

    # ensure each m gives at least 2 envs and at least 2 samples per side (needs r>=2 anyway)
    return m_list


# -----------------------------
# Experiment runner
# -----------------------------
def run_one_dataset(
    name: str,
    df: pd.DataFrame,
    r_expected: int,
    m_list: List[int] | None,
    n_rep: int,
    seed: int,
) -> Tuple[float, pd.DataFrame, Dict[str, str]]:
    """
    Run the unpaired-IV subsampling experiment on one dataset.

    For every environment count ``m`` and every replicate, the function does three things:
    - draws ``m`` environments at random (all of them when ``m`` equals the total);
    - splits each environment's rows into disjoint X- and Y-samples;
    - fits every estimator and records the absolute error against an OLS reference.

    Args:
        name: Dataset name stored in the ``dataset`` column of the results.
        df: Raw data with columns ``red``, ``led_1_ir``, ``ir_1``, ``ir_2``.
        r_expected: Rows kept per environment (see ``build_env_cache``).
        m_list: Environment counts to evaluate. ``None`` or empty -> chosen by
            ``pick_m_list``; otherwise values outside ``[2, m_total]`` are dropped.
        n_rep: Replicates per ``m``.
        seed: Seed for environment down-sampling and the per-replicate RNG streams.

    Returns:
        ``(beta_true, res, labels)``. ``beta_true`` is the reference coefficient,
        ``res`` has one row per (m, replicate, estimator), and ``labels`` maps each
        estimator key to its display label.

    Raises:
        ValueError: if a user-supplied ``m_list`` has no value in ``[2, m_total]``.
    """
    TSIV, UPGMM, UPGMMConfig, up_hd, up_hd_name, UnpairedIVData = _import_estimators()

    # Reference effect: OLS adjusting for the observed confounder, computed on the
    # full input frame (before per-environment down-sampling).
    beta_true = ols_beta_x_given_red(df)

    _, cache = build_env_cache(df, r_expected=r_expected, seed=seed)
    m_total = cache.env_values.size
    if m_list is None or len(m_list) == 0:
        m_list = pick_m_list(m_total, r_expected=r_expected, min_N=400)
    else:
        m_list = [m for m in m_list if 2 <= m <= m_total]
        if not m_list:
            raise ValueError(f"Provided m_list has no valid values in [2, {m_total}]")

    estimators = {
        "ts_iv": TSIV(),
        "up_gmm": UPGMM(UPGMMConfig(l1=False, use_optimal_weight=True, split_B=0)),
        up_hd_name: up_hd,
    }
    # NOTE(review): the high-dimensional estimator is labelled "SplitUP" even when
    # _import_estimators returned the analytic variant; confirm the intended label.
    labels = {
        "ts_iv": "TS-IV",
        "up_gmm": "UP-GMM",
        up_hd_name: "SplitUP",
    }

    rows = []
    base_rng = np.random.default_rng(seed)
    env_positions_all = np.arange(m_total, dtype=int)

    # tqdm is imported unconditionally at module level, so no plain-iterator fallback.
    for m_sub in tqdm(m_list, desc=f"{name} m", leave=False):
        # total unpaired (X+Y) observations; ratio N/m = r_expected
        N_total = int(m_sub * r_expected)

        for rep in tqdm(range(n_rep), desc=f"{name} rep (m={m_sub})", leave=False):
            # Independent stream per replicate, derived from the dataset-level seed.
            rng = np.random.default_rng(base_rng.integers(0, 2**32 - 1))

            if m_sub == m_total:
                env_pos_list = env_positions_all
            else:
                env_pos_list = rng.choice(env_positions_all, size=m_sub, replace=False)

            # build unpaired data for this replicate
            I_y, Y, I_x, X = unpair_balanced_within_env(cache, env_pos_list, rng=rng)
            data = UnpairedIVData(I_y=I_y, Y=Y, I_x=I_x, X=X)

            # Estimators share the replicate's rng and consume it in dict order.
            for key, est in estimators.items():
                bhat = float(np.asarray(est.fit(data, rng=rng)).reshape(-1)[0])  # d=1
                rows.append(
                    {
                        "dataset": name,
                        "m": m_sub,
                        "r": r_expected,
                        "N": N_total,
                        "rep": rep,
                        "estimator": key,
                        "label": labels[key],
                        "beta_true": beta_true,
                        "beta_hat": bhat,
                        "abs_error": abs(bhat - beta_true),
                    }
                )

    return beta_true, pd.DataFrame(rows), labels


def summarize_results(res: pd.DataFrame) -> pd.DataFrame:
    """
    Aggregate per-replicate errors into one row per (dataset, N, m, r, estimator, label).

    Output columns, after the group keys:
    - ``mae``: mean absolute error;
    - ``sd``: sample std of the absolute error;
    - ``beta_true``;
    - ``n``: number of non-missing errors in the group;
    - ``se``: ``sd / sqrt(n)``.

    Rows are sorted by dataset, N and label.
    """
    keys = ["dataset", "N", "m", "r", "estimator", "label"]
    # Count inside the same aggregation, so n uses exactly the mean/std grouping.
    # Configurations that share an N but differ in m/r no longer pool their counts.
    # "count" skips NaN, matching the observations that enter the std.
    out = res.groupby(keys, as_index=False).agg(
        mae=("abs_error", "mean"),
        sd=("abs_error", "std"),
        beta_true=("beta_true", "first"),
        n=("abs_error", "count"),
    )
    # standard error of the mean absolute error
    out["se"] = out["sd"] / np.sqrt(out["n"].clip(lower=1))
    return out.sort_values(["dataset", "N", "label"]).reset_index(drop=True)


# -----------------------------
# Main
# -----------------------------
def main() -> None:
    """
    Command-line entry point for the causal-chambers m/n=16 experiment.

    Reads the CSV given by ``--standard-16`` and runs ``run_one_dataset`` with 16 rows
    per environment. It then summarises the per-replicate errors and writes a
    single-panel MAE plot (x axis: number of environments) into ``--outdir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--standard-16", type=str, default="./causal_chambers/df_standard_16.csv"
    )
    parser.add_argument("--outdir", type=str, default="results_causal_chambers")
    parser.add_argument("--n-rep", type=int, default=50)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--m-list",
        type=str,
        default="",
        help="Comma-separated m values (e.g. '50,100,200,400,800,1000'). If empty, auto-pick.",
    )
    parser.add_argument("--no-xlog", action="store_true")
    args = parser.parse_args()

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # An empty --m-list means "let run_one_dataset pick the values".
    m_list = (
        [int(tok) for tok in args.m_list.split(",") if tok.strip()]
        if args.m_list.strip()
        else None
    )

    dataset_name = "m/n=16"
    data_path = Path(args.standard_16)
    if not data_path.exists():
        raise FileNotFoundError(f"Missing dataset: {data_path}")
    df = pd.read_csv(data_path)

    _, res, estimator_labels = run_one_dataset(
        name=dataset_name,
        df=df,
        r_expected=16,
        m_list=m_list,
        n_rep=args.n_rep,
        seed=args.seed,
    )
    summary = summarize_results(res)

    # NOTE(review): the file name says "r8" although r_expected is 16 here; kept unchanged.
    pdf_path = outdir / "realworld_setting2_r8_mae_mn16.pdf"
    plot_single_panel(
        summary,
        pdf_path,
        estimator_labels=estimator_labels,
        dataset_name=dataset_name,
        x_log=not args.no_xlog,
        x_col="m",
        x_label="number of environments ($m$)",
    )

    print("\nSaved:")
    print(f"  {pdf_path}")


if __name__ == "__main__":
    main()
