from __future__ import annotations

import os
import sys
import time
import argparse
import logging
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
from tqdm.auto import tqdm

# ------------------------ Path setup  --------------------------

# Make the current directory and its two parent directories importable.
# Presumably this lets the `src.*` imports below resolve when the script is
# launched from the repo root or from a nested scripts folder -- confirm
# against how the script is actually invoked.
sys.path.extend((".", "../", "../../"))

import src.utils.io as io
from src.datasets.generate_simulated_data import (
    MultivarOutcomeRegressionParams,
    simulate_multivariate_outcome_regression,
)
import src.methods.sinkhorn_div as sinkhorn
import src.methods.eif as eif
import src.utils.mmd as mmd
import src.utils.stat as stat
import src.methods.propensity_models as propensity_models
from src.methods.NF_conditional_dist_est import (
    estimate_P_matrix,
    ConditionalFlowEstimator,
)

# ------------------------------ CLI ------------------------------

def parse_arguments() -> argparse.ArgumentParser:
    """Build the command-line parser for the simulation script.

    The parser is returned un-parsed; the caller is expected to invoke
    ``parse_args`` on it. ``--run_name`` and ``--theta_values`` are required;
    every other option has a default.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser()

    def register(specs):
        # Each spec is (flag, type, default); registration order is kept so
        # the generated --help listing is unchanged.
        for flag, kind, default in specs:
            cli.add_argument(flag, type=kind, default=default)

    # ---------------- System / output ----------------
    register([
        ("--device", int, 0),
        ("--seed", int, 0),
        ("--n_sims", int, 1000),
        ("--save_every", int, 10),
    ])
    cli.add_argument("--verbose", action="store_true")
    cli.add_argument("--results_tag", type=str, default="splitting")
    cli.add_argument("--resume", action="store_true")
    cli.add_argument(
        "--run_name",
        type=str,
        required=True,
        help="Unique name for this run/process.",
    )

    # ---------------- Data ----------------
    register([
        ("--d_x", int, 3),
        ("--d_y", int, 2),
    ])

    # ------------ XGB propensity hyperparams -----------
    register([
        ("--xgb_max_depth", int, 3),
        ("--xgb_min_child_weight", float, 1.0),
        ("--xgb_gamma", float, 0.0),
        ("--xgb_subsample", float, 0.8),
        ("--xgb_colsample_bytree", float, 0.8),
        ("--xgb_learning_rate", float, 0.05),
        ("--xgb_n_estimators", int, 400),
        ("--xgb_reg_lambda", float, 1.0),
        ("--xgb_reg_alpha", float, 0.0),
        ("--xgb_test_size", float, 0.2),
        ("--xgb_threshold", float, 0.5),
        ("--xgb_n_jobs", int, -1),
        ("--xgb_tree_method", str, "hist"),
    ])

    # ---------------- NF ----------------
    register([
        ("--nf_lr", float, 1e-3),
        ("--nf_hidden_dim", int, 32),
        ("--nf_n_layers", int, 4),
        ("--nf_batch_size", int, 256),
        ("--nf_n_epochs", int, 200),
    ])

    # ---------------- OT / kernel bandwidth ----------------
    register([("--eps", float, 10.0)])

    # --------- Hypothesis test ---------
    cli.add_argument(
        "--hypothesis_test_cutoffs",
        type=float,
        nargs="+",
        default=[0.90, 0.95, 0.99],
        help="List of quantile cutoffs in [0,1]. Ex: --hypothesis_test_cutoffs 0.9 0.95 0.99",
    )

    # ---------------- grid ----------------
    cli.add_argument(
        "--n_grid",
        type=int,
        nargs="+",
        default=[250, 500, 1000, 2000, 4000],
        help="Sample sizes. Ex: --n_grid 500 1000 2000",
    )
    cli.add_argument(
        "--theta_values",
        type=float,
        nargs="+",
        required=True,
        help="List of theta values to run. Ex: --theta_values 0.0 0.1 0.2",
    )

    return cli

@torch.no_grad()
def compute_truth(
    delta: np.ndarray,
    B: np.ndarray,
    T: np.ndarray,
    Sigma: np.ndarray,
    eps: float,
    device: torch.device,
    dtype: torch.dtype,
) -> Tuple[float, float]:
    """Evaluate the closed-form Gaussian Sinkhorn divergence and squared MMD.

    The two Gaussians compared are N(delta, (B+T)(B+T)^T + Sigma) and
    N(0, B B^T + Sigma); both helper calls receive the same ``eps``.

    Args:
        delta: Mean of the first Gaussian (the second has zero mean).
        B: Matrix entering both covariances.
        T: Matrix added to ``B`` in the first covariance only.
        Sigma: Covariance term added to both.
        eps: Value forwarded as ``eps`` to both divergence helpers.
        device: Torch device for the intermediate tensors.
        dtype: Torch dtype for the intermediate tensors.

    Returns:
        ``(sink_true, mmd_true)`` as Python floats.
    """

    def _scalar(value) -> float:
        # The helpers may hand back either a tensor or a plain number.
        if torch.is_tensor(value):
            return float(value.item())
        return float(value)

    shifted = B + T
    cov_treated = shifted @ shifted.T + Sigma
    cov_control = B @ B.T + Sigma

    mean_treated = torch.tensor(delta, dtype=dtype, device=device)
    mean_control = torch.zeros_like(mean_treated)
    cov_treated_t = torch.tensor(cov_treated, dtype=dtype, device=device)
    cov_control_t = torch.tensor(cov_control, dtype=dtype, device=device)

    sink_value = sinkhorn.sinkhorn_divergence_gaussians(
        mean_treated, cov_treated_t, mean_control, cov_control_t, eps=eps
    )
    mmd_value = mmd.squared_mmd_gaussians(
        mean_treated, cov_treated_t, mean_control, cov_control_t, eps=eps
    )
    return _scalar(sink_value), _scalar(mmd_value)


def simulation_seed(base_seed: int, theta_idx: int, sim_idx: int, n_sims: int) -> int:
    """Return a distinct RNG seed for each (theta index, simulation index) pair.

    The stride between consecutive theta indices is ``max(1000, n_sims)``.
    For ``n_sims <= 1000`` this reproduces ``base_seed + 1000 * theta_idx +
    sim_idx``; for larger ``n_sims`` the stride grows so seeds of different
    theta indices cannot overlap.

    Args:
        base_seed: User-provided base seed.
        theta_idx: Position of theta in the theta grid.
        sim_idx: Simulation replicate index, expected in ``[0, n_sims)``.
        n_sims: Number of replicates per theta.

    Returns:
        The integer seed.
    """
    stride = max(1000, int(n_sims))
    return int(base_seed) + stride * int(theta_idx) + int(sim_idx)


def check_results_path(csv_path: str, resume: bool) -> bool:
    """Check whether the results file may be written to.

    Args:
        csv_path: Path of the results CSV.
        resume: Whether the run was started with ``--resume``.

    Returns:
        ``True`` if ``csv_path`` already exists, ``False`` otherwise.

    Raises:
        RuntimeError: If the file exists and ``resume`` is false, so an
            existing run is never overwritten or appended to by accident.
    """
    exists = os.path.exists(csv_path)
    if exists and not resume:
        raise RuntimeError(
            f"{csv_path} already exists. Choose a new --run_name, delete the directory, or pass --resume."
        )
    return exists


def save_checkpoint(
    ckpt_path: str,
    *,
    run_name: str,
    n_grid: List[int],
    theta_values: np.ndarray,
    sink_true_by_theta: np.ndarray,
    mmd_true_by_theta: np.ndarray,
    args_dict: Dict[str, Any],
    dgp_params: Dict[str, Any],
    n_rows: int,
) -> None:
    """Write one checkpoint snapshot via ``io.atomic_save_npz``.

    Both the periodic and the end-of-theta snapshots go through this helper
    so they always carry the same set of keys. The theta grid is stored
    under ``theta_values`` and also under ``theta_factors``, because the two
    original call sites used different names.
    """
    io.atomic_save_npz(
        ckpt_path,
        run_name=np.array(run_name, dtype=object),
        n_grid=np.array(n_grid, dtype=int),
        theta_values=theta_values,
        theta_factors=theta_values,  # legacy alias of theta_values
        sink_true_by_theta=sink_true_by_theta,
        mmd_true_by_theta=mmd_true_by_theta,
        # store args + DGP params as object arrays
        args=np.array(args_dict, dtype=object),
        dgp_params=np.array(dgp_params, dtype=object),
        # for quick resume stats
        n_rows=np.array(n_rows, dtype=int),
        last_update_utc=np.array(time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), dtype=object),
    )


if __name__ == '__main__':

    parser = parse_arguments()
    args = parser.parse_args()

    # Reproducibility
    io.set_reproducibility(args.seed)
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # Output layout: <results_dir>/<results_tag>/<run_name>/
    root_results_dir = os.path.join(io.get_results_dir(), args.results_tag)
    io.ensure_dir(root_results_dir)
    run_dir = os.path.join(root_results_dir, args.run_name)
    io.ensure_dir(run_dir)

    # Logger
    log_path = os.path.join(run_dir, "run.log")
    logger = io.configure_logger(log_path, args.verbose)
    logger.info("Starting run: %s", args.run_name)
    logger.info("Device: %s", device)

    # Resume: an existing results file is an error unless --resume was given.
    csv_path = os.path.join(run_dir, "results.csv")
    csv_exists = check_results_path(csv_path, args.resume)
    done = set()
    if args.resume:
        if csv_exists:
            done = set(io.load_done_keys_from_csv(csv_path))
        logger.info("Resume enabled. Found %d completed (theta, sim, n) triples in %s", len(done), csv_path)
    else:
        logger.info("Resume disabled.")

    # Files
    ckpt_path = os.path.join(run_dir, "checkpoint.npz")        # periodic snapshot
    meta_path = os.path.join(run_dir, "meta.json")             # run metadata

    # Persist metadata (atomic)
    meta = {
        "run_name": args.run_name,
        "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "args": vars(args),
        "device": str(device),
        "torch_version": torch.__version__,
        "numpy_version": np.__version__,
    }
    io.atomic_write_bytes(meta_path, io.stable_json_dumps(meta).encode("utf-8"))

    # Sample sizes
    n_grid = [int(x) for x in args.n_grid]
    # The simulated dataset is split into a leading block (nuisance fitting)
    # and a trailing block (evaluation). Size it from the LARGEST n so the two
    # blocks never overlap, even if --n_grid is not given in increasing order.
    n_max = max(n_grid)
    theta_values = np.array(args.theta_values, dtype=np.float32)
    n_theta = len(theta_values)

    # ---------------- Data-generating parameters ----------------
    gamma0 = 0.0
    gamma = np.ones(args.d_x, dtype=float)

    # Use a single RNG for parameters (fixed)
    # NOTE(review): this shares its seed with the data RNG of (theta 0, sim 0);
    # confirm whether that coupling is acceptable.
    rng_params = np.random.default_rng(args.seed)
    B = rng_params.normal(0.0, 1.0, size=(args.d_y, args.d_x))
    T_base = rng_params.normal(0.0, 0.5, size=(args.d_y, args.d_x))
    Sigma = np.eye(args.d_y, dtype=float)

    dgp_params = {
        "gamma0": gamma0,
        "gamma": gamma,
        "B": B,
        "T_base": T_base,
        "Sigma": Sigma,
    }

    # Truth arrays (filled in as each theta is processed)
    sink_true_by_theta = np.full(n_theta, np.nan, dtype=float)
    mmd_true_by_theta  = np.full(n_theta, np.nan, dtype=float)

    # Propensity config (static across runs)
    xgb_cfg = propensity_models.build_xgb_cfg(args)

    pending_flush: List[Dict[str, Any]] = []

    for i, theta in enumerate(theta_values):

        theta = float(theta)
        delta = theta * np.ones(args.d_y, dtype=float)
        T = theta * T_base

        params = MultivarOutcomeRegressionParams(
            d_x=args.d_x,
            d_y=args.d_y,
            gamma0=gamma0,
            gamma=gamma,
            B=B,
            T=T,
            delta=delta,
            Sigma=Sigma,
        )

        # ---- Truth ----
        sink_true, mmd_true = compute_truth(delta=delta, B=B, T=T, Sigma=Sigma, eps=args.eps, device=device, dtype=dtype)
        sink_true_by_theta[i] = sink_true
        mmd_true_by_theta[i] = mmd_true

        # ---- Sim loop ----
        sim_loop = tqdm(range(args.n_sims), desc="Simulations", position=1, file=sys.stderr, leave=False)
        for j in sim_loop:

            # When resuming, skip the simulation entirely if every sample size
            # for this (theta, sim) pair has already been recorded.
            if args.resume and all((int(i), int(j), int(n)) in done for n in n_grid):
                continue

            # IMPORTANT: unique RNG per (theta, sim) so sims are not identical across j
            rng_data = np.random.default_rng(simulation_seed(args.seed, i, j, args.n_sims))

            # Generate one dataset of size 2 * max(n_grid); for each n the first
            # n rows are used for fitting and the last n rows for evaluation.
            X_np, A_np, Y_np = simulate_multivariate_outcome_regression(
                n=int(2 * n_max),
                params=params,
                rng=rng_data,
                return_propensity=False,
            )

            # Torch tensors once; slice views
            X_all = torch.tensor(X_np, dtype=dtype)
            A_all = torch.tensor(A_np, dtype=torch.long)
            Y_all = torch.tensor(Y_np, dtype=dtype)

            for n in n_grid:

                key = (int(i), int(j), int(n))
                if args.resume and key in done:
                    continue

                n = int(n)
                # *_hat: fitting split (first n rows); *_n: evaluation split (last n rows)
                X_hat, X_n = X_all[:n], X_all[-n:]
                A_hat, A_n = A_all[:n], A_all[-n:]
                Y_hat, Y_n = Y_all[:n], Y_all[-n:]

                # Cost / kernel
                C = sinkhorn.cost_matrix_on_atoms(Y_n, metric="sqeuclidean", normalize=False)
                # Printed for diagnostics only; the kernel below uses args.eps.
                heuristic_eps = sinkhorn.median_pairwise(C)**2
                print(f"Heuristic eps: {heuristic_eps}")
                K = mmd.compute_rbf_kernel_from_cost(C, eps=args.eps)

                # ---- Propensity A|X ----

                propensity_model = propensity_models.xgb_propensity_model(X_hat, A_hat, xgb_cfg)  # expected (n,2) or similar
                P_ax_np = propensity_model.predict_proba(X_n.detach().cpu().to(torch.float32).numpy())
                # Transpose and swap the two rows so that row 0 is the column-1
                # probability (presumably P(A=1 | X) -- confirm class ordering).
                P_ax = torch.tensor(P_ax_np, dtype=dtype).permute(1,0)[[1,0], :]
                e_scores = P_ax[0]
                # Keep scores strictly inside (0, 1) so the inverse weights stay finite.
                e_scores = torch.clamp(e_scores, 1e-8, 1.0 - 1e-8)

                # ---- Conditional NF for Y|X,A ----
                estimator = ConditionalFlowEstimator(
                    y_dim=args.d_y,
                    x_dim=args.d_x,
                    device=device,
                    hidden_dim=args.nf_hidden_dim,
                    n_layers=args.nf_n_layers,
                    lr=args.nf_lr,
                )
                estimator.fit(
                    X_hat,
                    A_hat,
                    Y_hat,
                    batch_size=args.nf_batch_size,
                    n_epochs=args.nf_n_epochs,
                )

                # ---- Estimate P(Y | X, A=a) on atoms Y ----
                ones_a = torch.ones((X_n.shape[0], 1), dtype=dtype)
                zeros_a = torch.zeros((X_n.shape[0], 1), dtype=dtype)

                estimator.model.eval()
                with torch.inference_mode():
                    P_hat1 = estimate_P_matrix(estimator.model, X_n, ones_a, Y_n, device=device)   # (X, Y)
                    P_hat0 = estimate_P_matrix(estimator.model, X_n, zeros_a, Y_n, device=device)  # (X, Y)

                # Collapse over X to get marginals over atoms Y
                P1 = sinkhorn._validate_histogram(torch.mean(P_hat1, dim=0), name="P1")
                P0 = sinkhorn._validate_histogram(torch.mean(P_hat0, dim=0), name="P0")

                # P_yax expected shape (Y, A, X) in your EIF functions
                P_yax = torch.stack((P_hat1, P_hat0), dim=1).permute(2, 1, 0).to(dtype=dtype)  # (Y, A, X)

                # ---------------- MMD (plugin + EIF) ----------------
                with torch.no_grad():
                    deltaP = (P1 - P0).reshape(-1, 1)
                    mmd_u1 = K @ deltaP
                    mmd_u2 = -mmd_u1
                    mmd_plugin = float((deltaP.T @ mmd_u1).item()) / 2

                # First-order EIF (BUGFIX: u2 was incorrectly set to mmd_D1 in original)
                mmd_D1 = eif.first_order_EIF(
                    e=e_scores.to(device),
                    A=A_n.to(device),
                    u1=mmd_u1.to(device),
                    u2=mmd_u2.to(device),
                    P_yax=P_yax.to(device),
                )
                mmd_D1_mean = float(torch.mean(mmd_D1).item())
                mmd_D1_var = float(torch.var(mmd_D1, unbiased=False).item())
                mmd_D1_se = float(np.sqrt(mmd_D1_var / n))
                mmd_one_step = mmd_plugin + mmd_D1_mean
                # Normal-approximation interval (z = 1.96)
                mmd_ci_low = mmd_one_step - 1.96 * mmd_D1_se
                mmd_ci_high = mmd_one_step + 1.96 * mmd_D1_se

                # Second-order EIF
                # omega depends only on e_scores, so it is built once and reused
                # for the Sinkhorn second-order term below.
                omega = torch.vstack([1.0 / e_scores, -1.0 / (1.0 - e_scores)]).to(device=device, dtype=dtype)  # (2,n)
                O = eif.second_order_EIF_operator(omega=omega.to(device),
                                                  K=K.to(device),
                                                  A=A_n.to(device),
                                                  P_yax=P_yax.to(device),
                                                  P_ax=P_ax.to(device))  # (n,n)
                mmd_D2 = eif.second_order_EIF_operator(omega=omega.to(device),
                                                       K=O.T.to(device),
                                                       A=A_n.to(device),
                                                       P_yax=P_yax.to(device),
                                                       P_ax=P_ax.to(device))        # (n,n)
                mmd_U = stat.U_statistic_from_kernel(M=mmd_D2, w=None, weighted=False, exclude_diag=True)
                mmd_two_step = mmd_one_step + 0.5 * float(mmd_U)
                mmd_quant_out = stat.chaos_quantile(
                    mmd_D2 / (2*n),
                    cutoffs=args.hypothesis_test_cutoffs,
                    t_obs=n * mmd_two_step,
                    return_lambdas=False,
                )

                # ---------------- Sinkhorn divergence (plugin + EIF) ----------------

                ot_out = sinkhorn.sinkhorn_divergence_same_atoms(a=P1.to(device),
                                                                 b=P0.to(device),
                                                                 X=Y_n.to(device),
                                                                 eps=args.eps)
                sink_plugin = float(ot_out["S"].item() if torch.is_tensor(ot_out["S"]) else ot_out["S"])

                sink_u1 = ot_out["potentials"]["ab"][0] - ot_out["potentials"]["aa"][0]
                sink_u2 = ot_out["potentials"]["ab"][1] - ot_out["potentials"]["bb"][1]

                sink_D1 = eif.first_order_EIF(
                    e=e_scores.to(device),
                    A=A_n.to(device),
                    u1=sink_u1.to(device),
                    u2=sink_u2.to(device),
                    P_yax=P_yax.to(device),
                )
                sink_D1_mean = float(torch.mean(sink_D1).item())
                sink_D1_var = float(torch.var(sink_D1, unbiased=False).item())
                sink_D1_se = float(np.sqrt(sink_D1_var / n))
                sink_one_step = sink_plugin + sink_D1_mean
                sink_ci_low = sink_one_step - 1.96 * sink_D1_se
                sink_ci_high = sink_one_step + 1.96 * sink_D1_se

                # Second-order Sinkhorn EIF
                f = ot_out["potentials"]["aa"][0].reshape(-1, 1)
                g = ot_out["potentials"]["aa"][1].reshape(1, -1)
                Q = torch.exp((f + g - C.to(device)) / args.eps).to(dtype=dtype, device=device)

                # NOTE(review): tK is computed but never used; O2 below is built
                # from Q. Confirm whether K=tK was intended.
                tK = sinkhorn.stable_hadamard_M_from_density_Q(Q=Q, w=P1, tau=1e-6, ridge=1e-5)
                O2 = eif.second_order_EIF_operator(omega=omega.to(device),
                                                   K=Q.to(device),
                                                   A=A_n.to(device),
                                                   P_yax=P_yax.to(device),
                                                   P_ax=P_ax.to(device))  # (n,n)
                sink_D2 = args.eps * eif.second_order_EIF_operator(omega=omega.to(device),
                                                                   K=O2.T.to(device),
                                                                   A=A_n.to(device),
                                                                   P_yax=P_yax.to(device),
                                                                   P_ax=P_ax.to(device))
                sink_U = stat.U_statistic_from_kernel(M=sink_D2, w=None, weighted=False, exclude_diag=True)
                sink_two_step = sink_one_step + 0.5 * float(sink_U)
                sink_quant_out = stat.chaos_quantile(
                    sink_D2 / (2*n),
                    cutoffs=args.hypothesis_test_cutoffs,
                    t_obs=n * sink_two_step,
                    return_lambdas=False,
                )

                logger.info("theta=%.2f n=%.1f SIM=%.1f", theta, n, j+1)
                logger.info(
                    "MMD: truth=%.6f plugin=%.6f one=%.6f two=%.6f p=%.6g",
                    mmd_true, mmd_plugin, mmd_one_step, mmd_two_step, float(mmd_quant_out["p_value"]),
                )
                logger.info(
                    "Sink: truth=%.6f plugin=%.6f one=%.6f two=%.6f p=%.6g",
                    sink_true, sink_plugin, sink_one_step, sink_two_step, float(sink_quant_out["p_value"]),
                )

                # ---------------- Record ----------------
                record = {
                    "theta_idx": int(i),
                    "theta_value": float(theta),
                    "sim_idx": int(j),
                    "n": int(n),
                    "cutoffs": [float(x) for x in args.hypothesis_test_cutoffs],
                    "eps": float(args.eps),
                    # MMD
                    "mmd_true": float(mmd_true),
                    "mmd_plugin": float(mmd_plugin),
                    "mmd_one_step": float(mmd_one_step),
                    "mmd_two_step": float(mmd_two_step),
                    "mmd_D1_mean": float(mmd_D1_mean),
                    "mmd_D1_se": float(mmd_D1_se),
                    "mmd_ci_low": float(mmd_ci_low),
                    "mmd_ci_high": float(mmd_ci_high),
                    "mmd_U_index": float(mmd_U),
                    "mmd_quantiles": io._json_default(mmd_quant_out["quantiles"]),
                    "mmd_p_value": float(mmd_quant_out["p_value"]),
                    # Sinkhorn
                    "sink_true": float(sink_true),
                    "sink_plugin": float(sink_plugin),
                    "sink_one_step": float(sink_one_step),
                    "sink_two_step": float(sink_two_step),
                    "sink_D1_mean": float(sink_D1_mean),
                    "sink_D1_se": float(sink_D1_se),
                    "sink_ci_low": float(sink_ci_low),
                    "sink_ci_high": float(sink_ci_high),
                    "sink_U_index": float(sink_U),
                    "sink_quantiles": io._json_default(sink_quant_out["quantiles"]),
                    "sink_p_value": float(sink_quant_out["p_value"]),
                }

                pending_flush.append(record)
                # Track every completed key (not only under --resume) so the
                # checkpoint's n_rows and the "rows=" log line are accurate.
                done.add(key)

                # Periodic flush: CSV rows + checkpoint snapshot
                if len(pending_flush) >= args.save_every:

                    io.flush_records_to_csv(pending_flush, csv_path)
                    pending_flush.clear()

                    save_checkpoint(
                        ckpt_path,
                        run_name=args.run_name,
                        n_grid=n_grid,
                        theta_values=theta_values,
                        sink_true_by_theta=sink_true_by_theta,
                        mmd_true_by_theta=mmd_true_by_theta,
                        args_dict=vars(args),
                        dgp_params=dgp_params,
                        n_rows=len(done),
                    )
                    logger.info("Checkpointed: %s (rows=%d)", ckpt_path, len(done))

                # Drop the large per-iteration tensors before the next fit.
                del estimator, P_hat1, P_hat0, P_yax, P1, P0, K, C, Q, tK, O, O2, mmd_D2, sink_D2
                torch.cuda.empty_cache()

        # End theta: flush pending + checkpoint
        if pending_flush:
            io.flush_records_to_csv(pending_flush, csv_path)
            pending_flush.clear()

        save_checkpoint(
            ckpt_path,
            run_name=args.run_name,
            n_grid=n_grid,
            theta_values=theta_values,
            sink_true_by_theta=sink_true_by_theta,
            mmd_true_by_theta=mmd_true_by_theta,
            args_dict=vars(args),
            dgp_params=dgp_params,
            n_rows=len(done),
        )
        logger.info("Theta %d done. Checkpointed: %s", i, ckpt_path)

    logger.info("Run complete. Results in: %s", run_dir)
      