from __future__ import annotations
# Import modules
import sys
sys.path.append(".")
sys.path.append("../")
sys.path.append("../../")
import os
import numpy as np
import argparse
import time
from typing import Any, Dict, List
from tqdm import tqdm

# Torch
import torch
from torchvision.datasets import PCAM
from torchvision import transforms
from torch.utils.data import Subset

# Scratch modules
import src.utils.io as io
import src.utils.mmd as mmd
import src.methods.propensity_models as propensity_models
from src.datasets.generate_pcam_data import (CausalPCAMParams, simulate_causal_pcam_observed_tensors)
import src.methods.sinkhorn_div as sinkhorn
from src.methods.NF_conditional_dist_est import (
    estimate_P_matrix,
    ConditionalFlowEstimator,
)
import src.utils.stat as stat
import src.methods.eif as eif

import torchvision.models as models


# ------------------------------ CLI ------------------------------

def parse_arguments() -> argparse.ArgumentParser:
    """Build the command-line parser for this simulation script.

    The options are declared as an ordered table of ``(flag, add_argument
    keyword arguments)`` pairs and registered in that order, so the generated
    ``--help`` output lists them group by group. ``--run_name`` and
    ``--theta_values`` are the only required options.

    Returns:
        The configured parser. The caller is responsible for invoking
        ``parse_args()`` on it.
    """
    cli = argparse.ArgumentParser()

    # The table is rebuilt on every call so that list-valued defaults are
    # never shared between two parsers.
    option_table = [
        # ---------------- System / output ----------------
        ("--device", dict(type=int, default=0)),
        ("--seed", dict(type=int, default=10)),
        ("--n_sims", dict(type=int, default=1000)),
        ("--save_every", dict(type=int, default=5)),
        ("--verbose", dict(action="store_true")),
        ("--results_tag", dict(type=str, default="PCAM_P1_centered")),
        ("--resume", dict(action="store_true")),
        (
            "--run_name",
            dict(type=str, required=True, help="Unique name for this run/process."),
        ),
        # ---------------- Data ----------------
        ("--d_x", dict(type=int, default=3)),
        ("--pca_dim", dict(type=int, default=10)),
        # ------------ XGB propensity hyperparams -----------
        ("--xgb_max_depth", dict(type=int, default=3)),
        ("--xgb_min_child_weight", dict(type=float, default=1.0)),
        ("--xgb_gamma", dict(type=float, default=0.0)),
        ("--xgb_subsample", dict(type=float, default=0.8)),
        ("--xgb_colsample_bytree", dict(type=float, default=0.8)),
        ("--xgb_learning_rate", dict(type=float, default=0.05)),
        ("--xgb_n_estimators", dict(type=int, default=400)),
        ("--xgb_reg_lambda", dict(type=float, default=1.0)),
        ("--xgb_reg_alpha", dict(type=float, default=0.0)),
        ("--xgb_test_size", dict(type=float, default=0.2)),
        ("--xgb_threshold", dict(type=float, default=0.5)),
        ("--xgb_n_jobs", dict(type=int, default=-1)),
        ("--xgb_tree_method", dict(type=str, default="hist")),
        # ---------------- NF ----------------
        ("--nf_lr", dict(type=float, default=1e-3)),
        ("--nf_hidden_dim", dict(type=int, default=32)),
        ("--nf_n_layers", dict(type=int, default=4)),
        ("--nf_batch_size", dict(type=int, default=256)),
        ("--nf_n_epochs", dict(type=int, default=200)),
        # ---------------- OT / kernel bandwidth ----------------
        ("--eps", dict(type=float, default=250.0)),
        # --------- Hypothesis test ---------
        (
            "--hypothesis_test_cutoffs",
            dict(
                type=float,
                nargs="+",
                default=[0.90, 0.95, 0.99],
                help="List of quantile cutoffs in [0,1]. Ex: --hypothesis_test_cutoffs 0.9 0.95 0.99",
            ),
        ),
        # ---------------- grid ----------------
        (
            "--n_grid",
            dict(
                type=int,
                nargs="+",
                default=[5000],
                help="Sample sizes. Ex: --n_grid 500 1000 2000",
            ),
        ),
        (
            "--theta_values",
            dict(
                type=float,
                nargs="+",
                required=True,
                help="List of theta values to run. Ex: --theta_values 0.0 0.1 0.2",
            ),
        ),
    ]

    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli

def get_pcam_datasets():
    """Load the PCAM training split and partition it by label.

    The label-0 / label-non-0 index lists are cached as ``idx_0.pt`` and
    ``idx_1.pt`` under ``<data_dir>/pcam_cache`` so the (slow) full pass over
    the dataset only happens once. Both the cold-cache and the warm-cache path
    hand ``Subset`` the same kind of object (a 1-D ``torch.long`` tensor), so
    the first run behaves exactly like every later run.

    Returns:
        ``(dataset_1, dataset_0)``: the subset whose label is not 0 followed
        by the subset whose label is 0 (note the order).
    """
    base_dataset = PCAM(
        root=io.get_data_dir(),
        split="train",
        transform=transforms.ToTensor(),
        download=True,
    )

    cache_dir = os.path.join(io.get_data_dir(), "pcam_cache")
    os.makedirs(cache_dir, exist_ok=True)

    idx0_path = os.path.join(cache_dir, "idx_0.pt")
    idx1_path = os.path.join(cache_dir, "idx_1.pt")

    if os.path.exists(idx0_path) and os.path.exists(idx1_path):
        print("PCAM index cache found.")
        idx_1 = torch.load(idx1_path)
        idx_0 = torch.load(idx0_path)
    else:
        print("Building PCAM index cache...")

        zero_rows, other_rows = [], []
        # Index-based access: only ``__len__``/``__getitem__`` are relied on.
        for i in range(len(base_dataset)):
            _, y = base_dataset[i]
            if y == 0:
                zero_rows.append(i)
            else:
                other_rows.append(i)

        # Convert once and reuse the tensors, so what is returned now is the
        # same type that torch.load() will return on the next run.
        idx_0 = torch.tensor(zero_rows, dtype=torch.long)
        idx_1 = torch.tensor(other_rows, dtype=torch.long)
        torch.save(idx_0, idx0_path)
        torch.save(idx_1, idx1_path)

        print("Cache saved.")

    dataset_1 = Subset(base_dataset, idx_1)
    dataset_0 = Subset(base_dataset, idx_0)

    return dataset_1, dataset_0

def embed_images_batched(resnet, Y_cpu, batch_size=128, device="cuda"):
    """Apply ``resnet`` to ``Y_cpu`` in mini-batches with gradients disabled.

    Each slice of at most ``batch_size`` rows (along dim 0) is moved to
    ``device``, passed through ``resnet``, and its output is moved back to the
    CPU before the next slice is processed, so only one batch lives on the
    device at a time.

    Args:
        resnet: Callable applied to each batch (assumed to already live on
            ``device`` if it is a module).
        Y_cpu: Input tensor, sliced along its first dimension.
        batch_size: Maximum number of rows per forward pass.
        device: Anything accepted by ``torch.device``.

    Returns:
        The per-batch outputs concatenated along dim 0, on the CPU.
    """
    target = torch.device(device)
    n_rows = Y_cpu.size(0)

    with torch.no_grad():
        cpu_outputs = [
            resnet(Y_cpu[start:start + batch_size].to(target, non_blocking=True)).cpu()
            for start in range(0, n_rows, batch_size)
        ]
        return torch.cat(cpu_outputs, dim=0)

def _save_checkpoint(
    ckpt_path: str,
    args: argparse.Namespace,
    n_grid: List[int],
    theta_values: np.ndarray,
    dgp_params: Dict[str, Any],
    n_rows: int,
) -> None:
    """Write the run snapshot to ``ckpt_path`` through ``io.atomic_save_npz``.

    Both the periodic and the end-of-theta checkpoints go through this helper
    so the file always has the same set of keys.
    """
    io.atomic_save_npz(
        ckpt_path,
        run_name=np.array(args.run_name, dtype=object),
        n_grid=np.array(n_grid, dtype=int),
        theta_values=theta_values,
        # Alias kept because earlier end-of-theta snapshots used this key.
        theta_factors=theta_values,
        # store args + DGP params as object arrays
        args=np.array(vars(args), dtype=object),
        dgp_params=np.array(dgp_params, dtype=object),
        # for quick resume stats
        n_rows=np.array(n_rows, dtype=int),
        last_update_utc=np.array(time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), dtype=object),
    )


if __name__ == '__main__':
    # Script entry point: for every theta and every simulation, simulate a
    # PCAM-based dataset, embed the images, fit nuisance models on one half of
    # the sample and evaluate plug-in / one-step / two-step MMD and Sinkhorn
    # estimates on the other half. Rows are appended to results.csv in the run
    # directory and a checkpoint.npz snapshot is refreshed periodically.

    parser = parse_arguments()
    args = parser.parse_args()

    # Reproducibility
    io.set_reproducibility(args.seed)
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # Output layout
    root_results_dir = os.path.join(io.get_results_dir(), args.results_tag)
    io.ensure_dir(root_results_dir)
    run_dir = os.path.join(root_results_dir, args.run_name)
    io.ensure_dir(run_dir)

    # Logger
    log_path = os.path.join(run_dir, "run.log")
    logger = io.configure_logger(log_path, args.verbose)
    logger.info("Starting run: %s", args.run_name)
    logger.info("Device: %s", device)

    # Resume
    csv_path = os.path.join(run_dir, "results.csv")
    # Refuse to touch an existing results file unless the user asked to
    # resume; with --resume an existing file is exactly what we expect.
    if os.path.exists(csv_path) and not args.resume:
        raise RuntimeError(
            f"{csv_path} already exists. Choose a new --run_name, pass --resume, or delete the directory."
        )
    done = set()
    if args.resume:
        if os.path.exists(csv_path):
            done = io.load_done_keys_from_csv(csv_path)
        logger.info("Resume enabled. Found %d completed (theta, sim, n) triples in %s", len(done), csv_path)
    else:
        logger.info("Resume disabled.")

    # Files
    ckpt_path = os.path.join(run_dir, "checkpoint.npz")        # periodic snapshot
    meta_path = os.path.join(run_dir, "meta.json")             # run metadata

    # Persist metadata (atomic)
    meta = {
        "run_name": args.run_name,
        "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "args": vars(args),
        "device": str(device),
        "torch_version": torch.__version__,
        "numpy_version": np.__version__,
    }
    io.atomic_write_bytes(meta_path, io.stable_json_dumps(meta).encode("utf-8"))

    # Sample sizes
    n_grid = [int(x) for x in args.n_grid]
    # Each simulated dataset holds 2 * n_max rows so that, for every n in the
    # grid, the first n and the last n rows never overlap. Using max() keeps
    # this true even when --n_grid is not given in ascending order.
    n_max = max(n_grid)
    theta_values = np.array(args.theta_values, dtype=np.float32)

    # ---------------- Load PCAM dataset ----------------
    dataset_1, dataset_0 = get_pcam_datasets()

    # ---------------- Load Resnet model ----------------
    resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(device)
    # Replace the classification head so the network returns its features.
    resnet.fc = torch.nn.Identity()
    # The network is only used for inference; set eval mode once up front.
    resnet.eval()

    # ---------------- Data-generating parameters ----------------
    gamma0 = 0.0
    pi0 = 4.0

    # Use a single RNG for parameters (fixed)
    rng_params = np.random.default_rng(args.seed)
    gamma = rng_params.normal(0.0, 1.0, size=(args.d_x)) # treatment assignment
    pi_beta = 0.1 * np.ones(args.d_x) # baseline disease assignment
    q_beta = 0.01 * np.ones(args.d_x) # probability of cure

    # Propensity config (static across runs)
    xgb_cfg = propensity_models.build_xgb_cfg(args)

    pending_flush: List[Dict[str, Any]] = []

    # Per-theta offset for data seeds. 1000 reproduces the historical seeds
    # whenever n_sims <= 1000; a larger n_sims widens the stride so that
    # (theta, sim) pairs never share a seed.
    seed_stride = max(1000, args.n_sims)

    for i, theta in enumerate(theta_values):

        theta = float(theta)
        q0 = theta

        params = CausalPCAMParams(
                    d_x=args.d_x,
                    gamma0=gamma0,
                    gamma=gamma,
                    pi0=pi0,
                    pi_beta=pi_beta,
                    q0=q0,
                    q_beta=q_beta,
                )

        # DGP parameters stored alongside every checkpoint written for this theta.
        dgp_params = {
            "gamma0": gamma0,
            "gamma": gamma,
            "pi0": pi0,
            "pi_beta": pi_beta,
            "q0": q0,
            "q_beta": q_beta
        }

        # ---- Sim loop ----
        sim_loop = tqdm(range(args.n_sims), desc="Simulations", position=1, file=sys.stderr, leave=False)
        for j in sim_loop:

            # On resume, skip the expensive simulate/embed/PCA steps when every
            # sample size of this (theta, sim) pair is already recorded.
            if args.resume and all((int(i), int(j), int(n)) in done for n in n_grid):
                continue

            # IMPORTANT: unique RNG per (theta, sim) so sims are not identical across j
            rng_data = np.random.default_rng(args.seed + seed_stride * i + j)

            # Generate one dataset at max n (prefix used for each n)
            X_all, A_all, Y_all = simulate_causal_pcam_observed_tensors(
                n=int(2 * n_max),
                dataset_0=dataset_0,
                dataset_1=dataset_1,
                params=params,
                rng=rng_data,
                return_propensity=False,
                device=torch.device("cpu")
                )

            # get image embeddings
            Y_all = embed_images_batched(resnet, Y_all, batch_size=512, device=device)

            # get pca
            Y_all = stat.pca_torch(Y_all, args.pca_dim)
            d_y = Y_all.shape[-1]

            for n in n_grid:

                key = (int(i), int(j), int(n))
                if args.resume and key in done:
                    continue

                n = int(n)
                # First n rows fit the nuisance models ("_hat"); the last n
                # rows are the evaluation sample ("_n").
                X_hat, X_n = X_all[:n], X_all[-n:]
                A_hat, A_n = A_all[:n], A_all[-n:]
                Y_hat, Y_n = Y_all[:n], Y_all[-n:]

                # Cost / kernel
                C = sinkhorn.cost_matrix_on_atoms(Y_n, metric="sqeuclidean", normalize=False)
                # Informational only: the kernel/OT bandwidth actually used is args.eps.
                heuristic_eps = sinkhorn.median_pairwise(C)**2
                print(f"Heuristic eps: {heuristic_eps}")
                K = mmd.compute_rbf_kernel_from_cost(C, eps=args.eps)

                # ---- Propensity A|X ----

                propensity_model = propensity_models.xgb_propensity_model(X_hat, A_hat, xgb_cfg)  # expected (n,2) or similar
                P_ax_np = propensity_model.predict_proba(X_n)
                # Transpose and swap the two rows, so row 0 of P_ax is
                # predict_proba column 1 (presumably P(A=1 | X) -- confirm).
                P_ax = torch.tensor(P_ax_np, dtype=dtype).permute(1,0)[[1,0], :]
                e_scores = P_ax[0]
                # Keep scores strictly inside (0, 1); they are inverted below.
                e_scores = torch.clamp(e_scores, 1e-8, 1.0 - 1e-8)

                # ---- Conditional NF for Y|X,A ----
                estimator = ConditionalFlowEstimator(
                    y_dim=d_y,
                    x_dim=args.d_x,
                    device=device,
                    hidden_dim=args.nf_hidden_dim,
                    n_layers=args.nf_n_layers,
                    lr=args.nf_lr,
                    verbose=False
                )
                estimator.fit(
                        X_hat,
                        A_hat,
                        Y_hat,
                        batch_size=args.nf_batch_size,
                        n_epochs=args.nf_n_epochs,
                    )

                # ---- Estimate P(Y | X, A=a) on atoms Y ----
                ones_a = torch.ones((X_n.shape[0], 1), dtype=dtype)
                zeros_a = torch.zeros((X_n.shape[0], 1), dtype=dtype)

                estimator.model.eval()
                with torch.inference_mode():
                    P_hat1 = estimate_P_matrix(estimator.model, X_n, ones_a, Y_n, device=device, batch_size=512, atom_chunk_size=1024)   # (X, Y)
                    P_hat0 = estimate_P_matrix(estimator.model, X_n, zeros_a, Y_n, device=device, batch_size=512, atom_chunk_size=1024)  # (X, Y)

                # Collapse over X to get marginals over atoms Y
                P1 = sinkhorn._validate_histogram(torch.mean(P_hat1, dim=0), name="P1")
                P0 = sinkhorn._validate_histogram(torch.mean(P_hat0, dim=0), name="P0")

                # P_yax expected shape (Y, A, X) in your EIF functions
                P_yax = torch.stack((P_hat1, P_hat0), dim=1).permute(2, 1, 0).to(dtype=dtype)  # (Y, A, X)

                # ---------------- MMD (plugin + EIF) ----------------
                with torch.no_grad():
                    deltaP = (P1 - P0).reshape(-1, 1)
                    mmd_u1 = K @ deltaP
                    mmd_u2 = -mmd_u1
                    mmd_plugin = float((deltaP.T @ mmd_u1).item()) / 2

                # First-order EIF (u2 is the negated u1 direction)
                mmd_D1 = eif.first_order_EIF(
                    e=e_scores.to(device),
                    A=A_n.to(device),
                    u1=mmd_u1.to(device),
                    u2=mmd_u2.to(device),
                    P_yax=P_yax.to(device),
                )
                mmd_D1_mean = float(torch.mean(mmd_D1).item())
                mmd_D1_var = float(torch.var(mmd_D1, unbiased=False).item())
                mmd_D1_se = float(np.sqrt(mmd_D1_var / n))
                mmd_one_step = mmd_plugin + mmd_D1_mean
                # 1.96 * se half-width around the one-step estimate.
                mmd_ci_low = mmd_one_step - 1.96 * mmd_D1_se
                mmd_ci_high = mmd_one_step + 1.96 * mmd_D1_se

                # Second-order EIF
                omega = torch.vstack([1.0 / e_scores, -1.0 / (1.0 - e_scores)]).to(device=device, dtype=dtype)  # (2,n)
                O = eif.second_order_EIF_operator(omega=omega.to(device),
                                                  K=K.to(device),
                                                  A=A_n.to(device),
                                                  P_yax=P_yax.to(device),
                                                  P_ax=P_ax.to(device))  # (n,n)
                # The operator is applied a second time to the transpose of its own output.
                mmd_D2 = eif.second_order_EIF_operator(omega=omega.to(device),
                                                       K=O.T.to(device),
                                                       A=A_n.to(device),
                                                       P_yax=P_yax.to(device),
                                                       P_ax=P_ax.to(device))        # (n,n)
                mmd_U = stat.U_statistic_from_kernel(M=mmd_D2, w=None, weighted=False, exclude_diag=True)
                mmd_two_step = mmd_one_step + 0.5 * float(mmd_U)
                mmd_quant_out = stat.chaos_quantile(
                    mmd_D2 / (2*n),
                    cutoffs=args.hypothesis_test_cutoffs,
                    t_obs=n * mmd_two_step,
                    return_lambdas=False,
                )

                # ---------------- Sinkhorn divergence (plugin + EIF) ----------------

                ot_out = sinkhorn.sinkhorn_divergence_same_atoms(a=P1.to(device),
                                                                 b=P0.to(device),
                                                                 X=Y_n.to(device),
                                                                 eps=args.eps)
                sink_plugin = float(ot_out["S"].item() if torch.is_tensor(ot_out["S"]) else ot_out["S"])

                sink_u1 = ot_out["potentials"]["ab"][0] - ot_out["potentials"]["aa"][0]
                sink_u2 = ot_out["potentials"]["ab"][1] - ot_out["potentials"]["bb"][1]

                sink_D1 = eif.first_order_EIF(
                    e=e_scores.to(device),
                    A=A_n.to(device),
                    u1=sink_u1.to(device),
                    u2=sink_u2.to(device),
                    P_yax=P_yax.to(device),
                )
                sink_D1_mean = float(torch.mean(sink_D1).item())
                sink_D1_var = float(torch.var(sink_D1, unbiased=False).item())
                sink_D1_se = float(np.sqrt(sink_D1_var / n))
                sink_one_step = sink_plugin + sink_D1_mean
                sink_ci_low = sink_one_step - 1.96 * sink_D1_se
                sink_ci_high = sink_one_step + 1.96 * sink_D1_se

                # Second-order Sinkhorn EIF
                f = ot_out["potentials"]["bb"][0].reshape(-1, 1)
                g = ot_out["potentials"]["bb"][1].reshape(1, -1)
                Q = torch.exp((f + g - C.to(device)) / args.eps).to(dtype=dtype, device=device)

                omega = torch.vstack([1.0 / e_scores, -1.0 / (1.0 - e_scores)]).to(device=device, dtype=dtype)
                # NOTE(review): tK is computed but never used below (O2 is built
                # from Q). Confirm whether O2 should use tK or this call can go.
                tK = sinkhorn.stable_hadamard_M_from_density_Q(Q=Q, w=P1, tau=1e-6, ridge=1e-5)
                O2 = eif.second_order_EIF_operator(omega=omega.to(device),
                                                   K=Q.to(device),
                                                   A=A_n.to(device),
                                                   P_yax=P_yax.to(device),
                                                   P_ax=P_ax.to(device))  # (n,n)
                sink_D2 = args.eps * eif.second_order_EIF_operator(omega=omega.to(device),
                                                                   K=O2.T.to(device),
                                                                   A=A_n.to(device),
                                                                   P_yax=P_yax.to(device),
                                                                   P_ax=P_ax.to(device))
                sink_U = stat.U_statistic_from_kernel(M=sink_D2, w=None, weighted=False, exclude_diag=True)
                sink_two_step = sink_one_step + 0.5 * float(sink_U)
                sink_quant_out = stat.chaos_quantile(
                    sink_D2 / (2*n),
                    cutoffs=args.hypothesis_test_cutoffs,
                    t_obs=n * sink_two_step,
                    return_lambdas=False,
                )

                logger.info("theta=%.2f n=%.1f SIM=%.1f", theta, n, j+1)
                logger.info(
                    "MMD: plugin=%.6f one=%.6f two=%.6f p=%.6g",
                    mmd_plugin, mmd_one_step, mmd_two_step, float(mmd_quant_out["p_value"]),
                )
                logger.info(
                    "Sink: plugin=%.6f one=%.6f two=%.6f p=%.6g",
                    sink_plugin, sink_one_step, sink_two_step, float(sink_quant_out["p_value"]),
                )

                # ---------------- Record ----------------
                record = {
                    "theta_idx": int(i),
                    "theta_value": float(theta),
                    "sim_idx": int(j),
                    "n": int(n),
                    "cutoffs": [float(x) for x in args.hypothesis_test_cutoffs],
                    "eps": float(args.eps),
                    # MMD
                    "mmd_plugin": float(mmd_plugin),
                    "mmd_one_step": float(mmd_one_step),
                    "mmd_two_step": float(mmd_two_step),
                    "mmd_D1_mean": float(mmd_D1_mean),
                    "mmd_D1_se": float(mmd_D1_se),
                    "mmd_ci_low": float(mmd_ci_low),
                    "mmd_ci_high": float(mmd_ci_high),
                    "mmd_U_index": float(mmd_U),
                    "mmd_quantiles": io._json_default(mmd_quant_out["quantiles"]),
                    "mmd_p_value": float(mmd_quant_out["p_value"]),
                    # Sinkhorn
                    "sink_plugin": float(sink_plugin),
                    "sink_one_step": float(sink_one_step),
                    "sink_two_step": float(sink_two_step),
                    "sink_D1_mean": float(sink_D1_mean),
                    "sink_D1_se": float(sink_D1_se),
                    "sink_ci_low": float(sink_ci_low),
                    "sink_ci_high": float(sink_ci_high),
                    "sink_U_index": float(sink_U),
                    "sink_quantiles": io._json_default(sink_quant_out["quantiles"]),
                    "sink_p_value": float(sink_quant_out["p_value"]),
                }

                pending_flush.append(record)
                # Always track completed keys so the checkpoint's n_rows and
                # the log line below are correct with or without --resume.
                done.add(key)

                # Periodic flush: CSV rows + checkpoint snapshot
                if len(pending_flush) >= args.save_every:

                    io.flush_records_to_csv(pending_flush, csv_path)
                    pending_flush.clear()

                    _save_checkpoint(ckpt_path, args, n_grid, theta_values, dgp_params, len(done))
                    logger.info("Checkpointed: %s (rows=%d)", ckpt_path, len(done))

                # Drop the large per-iteration tensors before the next fit.
                del estimator, P_hat1, P_hat0, P_yax, P1, P0, K, C, Q, tK, O, O2, mmd_D2, sink_D2
                torch.cuda.empty_cache()

        # End theta: flush pending + checkpoint
        if pending_flush:
            io.flush_records_to_csv(pending_flush, csv_path)
            pending_flush.clear()

        _save_checkpoint(ckpt_path, args, n_grid, theta_values, dgp_params, len(done))
        logger.info("Theta %d done. Checkpointed: %s", i, ckpt_path)

    logger.info("Run complete. Results in: %s", run_dir)
      