import argparse
import os
import pickle as pkl
import numpy as np
import torch
from torch.optim import Adam
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegression
from estimation_weighted import *


# ================= Utils =================

def flatten_full(X_list):
    """
    Return one difference vector per observation: alt1 minus alt0.

    Mirrors the flattening used in debias.py. The first feature of each
    alternative is dropped (it is assumed to be constant).
    """
    diffs = []
    for alternatives in X_list:
        alt1_tail = alternatives[1][1:]
        alt0_tail = alternatives[0][1:]
        diffs.append(alt1_tail - alt0_tail)
    return diffs


def get_diff_features(X_list):
    """
    Vectorised alt1-minus-alt0 feature differences.

    X_list is converted to an array and indexed as (n, alternative, feature);
    column 0 of the feature axis is skipped. Returns an (n, d-1) array.
    """
    stacked = np.asarray(X_list)
    alt1 = stacked[:, 1, 1:]
    alt0 = stacked[:, 0, 1:]
    return alt1 - alt0


def one_hot_z(z):
    """
    One-hot encode labels taking values in {-1, 0, 1}.

    Column 0 encodes z == -1, column 1 encodes z == 0, column 2 encodes z == 1.

    Args:
        z: 1-D sequence of labels; values are cast to int.

    Returns:
        Float array of shape (len(z), 3) with exactly one 1 per row.

    Raises:
        ValueError: if any label lies outside {-1, 0, 1}. Without this check
            a label below -1 would produce a negative column index, which
            NumPy wraps around (e.g. z == -2 would silently be encoded as
            z == 1).
    """
    z = np.asarray(z, dtype=int)
    if np.any((z < -1) | (z > 1)):
        raise ValueError("z must only contain values in {-1, 0, 1}.")
    z_oh = np.zeros((len(z), 3))
    # Shift labels by one so that -1/0/1 map to columns 0/1/2.
    z_oh[np.arange(len(z)), z + 1] = 1
    return z_oh


def create_z_interactions(X_diff, z, top_indices=None):
    """
    Build interaction features z * X_diff[:, top_indices].

    Args:
        X_diff: (n, d) array-like of difference features.
        z: length-n sequence of numeric labels; reshaped to a column so it
            broadcasts across the selected feature columns.
        top_indices: column indices of X_diff to interact with z. Defaults to
            [1, 6, 0, 8, 9] when omitted (a None sentinel is used instead of
            a mutable list default shared between calls).

    Returns:
        Array of shape (n, len(top_indices)).
    """
    if top_indices is None:
        top_indices = [1, 6, 0, 8, 9]
    X_diff = np.asarray(X_diff)  # accept nested lists as well as arrays
    z_numeric = np.asarray(z).reshape(-1, 1)
    return z_numeric * X_diff[:, top_indices]


def prepare_features(X_list, z, include_z_onehot=True, include_interactions=True):
    """
    Assemble the design matrix used alongside the z labels.

    Columns are, in order: the alt1-alt0 difference features, then
    (optionally) the one-hot encoding of z, then (optionally) the z-by-feature
    interaction terms. Blocks are joined horizontally.
    """
    diff = get_diff_features(X_list)
    onehot_block = [one_hot_z(z)] if include_z_onehot else []
    interaction_block = [create_z_interactions(diff, z)] if include_interactions else []
    return np.hstack([diff, *onehot_block, *interaction_block])

def mnl_prob(X_i: np.ndarray, beta: np.ndarray) -> np.ndarray:
    """
    Multinomial-logit choice probabilities for a single observation.

    X_i: (J, d) alternative features, beta: (d,) coefficients.
    Returns the softmax of the utilities X_i @ beta, shape (J,).
    """
    utilities = X_i @ beta  # (J,)
    # Shift by the largest utility so the exponentials cannot overflow.
    shifted = utilities - np.max(utilities)
    weights = np.exp(shifted)
    return weights / np.sum(weights)

def score_psi(
    X_i: np.ndarray,
    y_onehot_i: np.ndarray,
    w_i: float,
    g_hat_i: np.ndarray,
    e_hat_i: float,
    beta: np.ndarray,
) -> np.ndarray:
    """
    Orthogonal score psi_i for the MNL model (a vector in R^d).

    Computes X_i^T [ p_i(beta) - g_hat_i + (w_i / e_hat_i) * (g_hat_i - y_i) ],
    where p_i(beta) are the model choice probabilities for this observation.
    """
    probs = mnl_prob(X_i, beta)  # (J,)
    model_gap = probs - g_hat_i
    ipw_weight = w_i / e_hat_i
    label_gap = g_hat_i - y_onehot_i
    return X_i.T @ (model_gap + ipw_weight * label_gap)  # (d,)

def compute_se(
    X_mnl: np.ndarray,
    y_onehot: np.ndarray,
    w: np.ndarray,
    g_hat: np.ndarray,
    e_hat: np.ndarray,
    beta_hat: np.ndarray,
) -> np.ndarray:
    """
    Sandwich standard errors for beta_hat using the supplied nuisances.

    The "meat" is the average outer product of the per-row scores from
    score_psi; the "bread" is the inverse of the average MNL Hessian
    X_i^T (diag(p_i) - p_i p_i^T) X_i. Returns sqrt(diag(bread @ meat @ bread / n)).
    """
    n_obs, _, dim = X_mnl.shape

    scores = np.zeros((n_obs, dim), dtype=float)
    hessian = np.zeros((dim, dim), dtype=float)
    for row in range(n_obs):
        X_row = X_mnl[row]
        scores[row] = score_psi(
            X_row, y_onehot[row], w[row], g_hat[row], e_hat[row], beta_hat
        )
        probs = mnl_prob(X_row, beta_hat)
        curvature = np.diag(probs) - np.outer(probs, probs)
        hessian += X_row.T @ curvature @ X_row

    meat = (scores.T @ scores) / n_obs
    hessian /= n_obs

    # Fall back to the pseudo-inverse when the Hessian is singular.
    try:
        bread = np.linalg.inv(hessian)
    except np.linalg.LinAlgError:
        bread = np.linalg.pinv(hessian)

    cov_beta = bread @ meat @ bread / n_obs
    # Clip tiny negative variances (rounding) before taking the square root.
    variances = np.clip(np.diag(cov_beta), 0.0, np.inf)
    return np.sqrt(variances)


# ================= Estimation Logic =================

def run_primary_only(X_all, y_real, real_rows):
    """
    Fit on the primary (real) rows only, matching debias.py primary weights.

    The weight matrix passed to fit is the one-hot encoding of the realised
    choice y (two alternatives), and the fit always uses seed=0.
    """
    primary_X = []
    realised = []
    for r in real_rows:
        primary_X.append(X_all[r])
        realised.append(y_real[r])

    choice_idx = np.asarray(realised, dtype=int)
    # Row-select from a 2x2 identity to obtain the integer one-hot weights.
    one_hot_weights = np.eye(2, dtype=int)[choice_idx]
    return fit(primary_X, one_hot_weights, seed=0)


def run_dml(
    X_all, y_real, y_aug,
    real_rows, aug_rows,
    n_folds=5, clip_eps=0.02,
    seed=0, n_epochs=4000, lr=5e-3
):
    """
    Cross-fitted estimate of binary MNL coefficients from primary + augmented rows.

    Steps:
      1. Build (n, J, d) feature tensors by dropping the first (constant)
         column of every alternative.
      2. Cross-fit g(x, z) = P(choice = 1) on the primary rows with one
         logistic regression per observed z value; primary rows receive
         out-of-fold predictions, augmented rows the average over folds.
         Whenever no model can be fit for a z value, the primary class prior
         is used instead.
      3. Use a constant propensity e = n_p / (n_p + n_a), clipped to
         [clip_eps, 1 - clip_eps].
      4. Minimise cross-entropy against tau = g + (w / e) * (y - g) with Adam.
      5. Compute sandwich standard errors with compute_se.

    Args:
        X_all: indexable collection of per-row (J, d+1) feature arrays.
        y_real: realised choices (0/1), indexable by row id.
        y_aug: auxiliary labels z, indexable by row id.
        real_rows: row ids of the primary sample (y observed).
        aug_rows: row ids of the augmented sample (only z used).
        n_folds: number of KFold splits for cross-fitting g.
        clip_eps: clipping bound for the constant propensity.
        seed: seed for KFold, LogisticRegression and torch.
        n_epochs: number of Adam steps.
        lr: Adam learning rate.

    Returns:
        (beta_hat, se): coefficient estimates and their standard errors,
        both arrays of length d.

    Raises:
        ValueError: if both samples are empty, or if primary and augmented
            rows have different feature dimensions. KFold additionally
            rejects an n_folds larger than the number of primary rows.
    """
    J = 2  # binary-choice setting

    # -----------------
    # Inputs
    # -----------------
    real_rows = np.asarray(real_rows, dtype=int)
    aug_rows = np.asarray(aug_rows, dtype=int)

    # np.asarray lets plain lists be fancy-indexed by the row-id arrays.
    y_real = np.asarray(y_real)
    y_aug = np.asarray(y_aug)
    y_p = np.asarray(y_real[real_rows], dtype=int)
    z_p = np.asarray(y_aug[real_rows], dtype=int)
    z_a = np.asarray(y_aug[aug_rows], dtype=int)

    n_p = len(y_p)
    n_a = len(z_a)
    if n_p == 0 and n_a == 0:
        raise ValueError("Both primary and augmented samples are empty.")
    n_total = n_p + n_a

    def stack_mnl(rows):
        # Drop the first (constant) column, keep both alternatives: (n, J, d)
        return np.stack([np.asarray(X_all[r])[:, 1:] for r in rows], axis=0)

    X_p_mnl = stack_mnl(real_rows) if n_p > 0 else None
    X_a_mnl = stack_mnl(aug_rows) if n_a > 0 else None

    if X_p_mnl is not None and X_a_mnl is not None and X_p_mnl.shape[2] != X_a_mnl.shape[2]:
        raise ValueError(f"Feature dim mismatch: primary d={X_p_mnl.shape[2]} vs aug d={X_a_mnl.shape[2]}")

    # An empty side gets a zero-row placeholder with the SAME trailing shape
    # and dtype as the populated side, so the concatenation below is valid
    # even when there are no primary rows (a (0, 2, 0) placeholder is not).
    reference = X_p_mnl if X_p_mnl is not None else X_a_mnl
    if X_p_mnl is None:
        X_p_mnl = np.zeros((0,) + reference.shape[1:], dtype=reference.dtype)
    if X_a_mnl is None:
        X_a_mnl = np.zeros((0,) + reference.shape[1:], dtype=reference.dtype)
    d = reference.shape[2]

    # -----------------
    # Cross-fit g on primary, stratified by z
    # g_hat is a (J,) vector; it is modelled via the prob of choosing alt1
    # -----------------
    class_prior = float(np.mean(y_p)) if n_p > 0 else 0.5

    if n_p > 0:
        # Features for g: alt1 - alt0 differences, shape (n, d)
        X_flat_p = X_p_mnl[:, 1, :] - X_p_mnl[:, 0, :]
        X_flat_a = X_a_mnl[:, 1, :] - X_a_mnl[:, 0, :]

        def predict_alt1(models, x_row, zv):
            # Predict P(choice=1) with the model for this z value, or fall
            # back to the class prior when no model could be fit for it.
            model = models.get(int(zv))
            if model is None:
                return class_prior
            return model.predict_proba(x_row)[0, 1]

        g1_oof = np.zeros(n_p, dtype=float)      # primary OOF prob(choice=1)
        g1_aug_sum = np.zeros(n_a, dtype=float)  # augmented prob summed over folds
        aug_fold_counts = 0

        z_values = np.unique(z_p)
        kf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
        for tr, te in kf.split(X_flat_p):
            X_tr, y_tr, z_tr = X_flat_p[tr], y_p[tr], z_p[tr]

            # One model per z value; strata without both classes get none.
            g_models = {}
            for zv in z_values:
                mask = (z_tr == zv)
                if np.sum(mask) >= 2 and len(np.unique(y_tr[mask])) == 2:
                    g_models[int(zv)] = LogisticRegression(
                        C=0.05, max_iter=2000, random_state=seed
                    ).fit(X_tr[mask], y_tr[mask])

            # Out-of-fold predictions on the held-out primary rows
            for i in te:
                g1_oof[i] = predict_alt1(g_models, X_flat_p[i:i + 1], z_p[i])

            # Every fold's models predict ALL augmented rows; averaged below
            if n_a > 0:
                g1_aug_sum += np.array(
                    [predict_alt1(g_models, X_flat_a[j:j + 1], z_a[j]) for j in range(n_a)],
                    dtype=float,
                )
                aug_fold_counts += 1

        g1_aug = g1_aug_sum / max(1, aug_fold_counts)
    else:
        # no primary => no cross-fit; use class_prior everywhere
        g1_oof = np.zeros(0, dtype=float)
        g1_aug = np.full(n_a, class_prior, dtype=float)

    # g_hat rows are [P(choice=0), P(choice=1)]; primary rows come first
    g1 = np.concatenate([g1_oof, g1_aug])
    g_hat = np.column_stack([1.0 - g1, g1])  # (n, J)

    # -----------------
    # e_hat: constant propensity of being primary (clipped)
    # -----------------
    e_const = float(np.clip(n_p / n_total, clip_eps, 1.0 - clip_eps))
    e_hat = np.full(n_total, e_const, dtype=float)

    # w: indicator of being primary in the combined sample
    w = np.concatenate([np.ones(n_p, dtype=float), np.zeros(n_a, dtype=float)])

    # -----------------
    # Construct combined arrays (primary rows first, then augmented)
    # -----------------
    X_mnl = np.concatenate([X_p_mnl, X_a_mnl], axis=0)

    # Augmented labels are unobserved; their rows stay all-zero, which is
    # irrelevant because w = 0 removes them from every (y - g) term.
    y_onehot = np.zeros((n_total, J), dtype=float)
    y_onehot[np.arange(n_p), y_p] = 1.0

    # -----------------
    # Point estimate beta_hat: minimize CE with tau = g + (w/e)*(y-g)
    # (augmented: w=0 => tau=g)
    # -----------------
    tau = g_hat + (w[:, None] / e_hat[:, None]) * (y_onehot - g_hat)  # (n,J)

    # Numerical safety: project tau back onto the probability simplex.
    # NOTE(review): for primary rows (w=1, e<=1, g in [0,1]) this clipping
    # maps tau to exactly the one-hot label, so the point estimate does not
    # see the 1/e up-weighting while compute_se uses the unclipped score.
    # Confirm this is intended.
    tau = np.clip(tau, 0.0, 1.0)
    tau = tau / np.clip(tau.sum(axis=1, keepdims=True), 1e-12, np.inf)

    # Torch optimization
    torch.manual_seed(seed)
    X_t = torch.tensor(X_mnl, dtype=torch.float64)    # (n,J,d)
    tau_t = torch.tensor(tau, dtype=torch.float64)    # (n,J)

    beta = torch.nn.Parameter(torch.zeros(d, dtype=torch.float64))
    opt = Adam([beta], lr=lr)

    for _ in range(n_epochs):
        opt.zero_grad()
        u = torch.einsum("njd,d->nj", X_t, beta)  # (n,J)
        u = u - torch.max(u, dim=1, keepdim=True).values
        p = torch.softmax(u, dim=1)               # (n,J)
        loss = -torch.mean(torch.sum(tau_t * torch.log(torch.clamp(p, 1e-12, 1.0)), dim=1))
        loss.backward()
        opt.step()

    beta_hat = beta.detach().cpu().numpy()

    # -----------------
    # Sandwich standard errors from the orthogonal score
    # -----------------
    se = compute_se(X_mnl, y_onehot, w, g_hat, e_hat, beta_hat)

    return beta_hat, se


# ================= Main =================

def main(
    dx: int,
    n_samples: int,
    n_trials: int,
    n_real: int,
    method: str,
    n_folds: int,
    clip_eps: float,
):
    """
    Run the augmentation sweep and pickle the results under res_dml/.

    The sampling path follows the debias.py idea:
      - a single RandomState(0) stream is shared by all trials;
      - each trial draws participants once and expands each participant j
        into rows 5j .. 5j+4;
      - the first n_real rows are the primary sample, and every n_aug uses
        the prefix rows[n_real : n_real + n_aug] of that same path.

    n_samples is accepted for CLI compatibility only; the 1200-row training
    file is always loaded. With n_aug == 0 only the primary fit is run and
    the stored standard error is None.
    """
    os.makedirs("res_dml", exist_ok=True)

    # NOTE: unpickling runs arbitrary code; only load trusted data files.
    train_path = f"data/train_{method}_{dx}_1200.pkl"
    with open(train_path, "rb") as fh:
        data = pkl.load(fh)[0]

    y_real = np.asarray(data["y"])
    y_aug = np.asarray(data["y_aug"])
    X_all = list(data["X"])

    n_max_aug = 1000
    step_aug = 100
    # Population size and draw size are the same number of participants.
    n_participants = int((n_real + n_max_aug) / 5)
    rng = np.random.RandomState(0)  # single global RNG stream (important)

    result_keys = (
        "n_real_list",
        "n_aug_list",
        "sample_id_list",
        "params_list",
        "se_list",
    )
    results = {key: [] for key in result_keys}

    for sid in range(n_trials):
        # Draw the sampling path once per trial.
        participants = rng.choice(n_participants, size=n_participants, replace=False)
        rows = np.asarray(
            [row for j in participants for row in range(j * 5, (j * 5) + 5)],
            dtype=int,
        )
        real_r = rows[:n_real]

        # Every n_aug reuses a prefix of the same path.
        for n_aug in range(0, n_max_aug + 1, step_aug):
            if n_aug > 0:
                beta, se = run_dml(
                    X_all, y_real, y_aug,
                    real_r, rows[n_real : n_real + n_aug],
                    n_folds=n_folds,
                    clip_eps=clip_eps,
                    seed=0,  # seed is deliberately fixed at 0
                )
            else:
                beta = run_primary_only(X_all, y_real, real_r)
                se = None

            results["n_real_list"].append(n_real)
            results["n_aug_list"].append(n_aug)
            results["sample_id_list"].append(sid)
            results["params_list"].append(beta)
            results["se_list"].append(se)

    res_file = f"res_dml/dml_{method}_{n_real}_{n_max_aug}_{n_trials}.pkl"
    with open(res_file, "wb") as out:
        pkl.dump(results, out, protocol=pkl.HIGHEST_PROTOCOL)

    print(f"Saved: {res_file}")


if __name__ == "__main__":
    # Command-line entry point: declare the flags, parse, and hand off to main().
    parser = argparse.ArgumentParser()
    cli_spec = (
        ("--dx", {"type": int, "default": 11}),
        # --n_samples is kept for CLI compatibility; ignored for file loading
        ("--n_samples", {"type": int, "default": 1200}),
        ("--n_trials", {"type": int, "required": True}),
        ("--n_real", {"type": int, "required": True}),
        ("--method", {"type": str, "required": True}),
        ("--n_folds", {"type": int, "default": 5}),
        ("--clip_epsilon", {"type": float, "default": 0.02}),
    )
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)
    cli = parser.parse_args()

    main(
        dx=cli.dx,
        n_samples=cli.n_samples,
        n_trials=cli.n_trials,
        n_real=cli.n_real,
        method=cli.method,
        n_folds=cli.n_folds,
        clip_eps=cli.clip_epsilon,
    )