#!/usr/bin/env python3
# complements/run_key_improvements.py
"""
Single-file orchestrator that runs:
 - Factual-only baseline
 - Naive augmentation (random recombine / jitter)
 - Conditional VAE augmentation
 - MAML-like few-shot baseline
 - (Optional) AWML / calibrated acceptance if user supplies an accept_synthetics() function
Outputs: results CSV, stats JSON, and a 4-panel figure ready for a one-page correction in ICLR.
Usage (example):
  python complements/run_key_improvements.py --repeats 10 --metric auc
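User-module hooks (all optional; point --user-module at a .py file defining any of them):
  load_data([csv_path])                    -> (X_train, y_train, X_test, y_test)
  sample_synthetics(n)                     -> (X_syn, y_syn)
  accept_synthetics(X_syn, y_syn, meta={"X_train": ..., "y_train": ...}) -> boolean mask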
"""
import os, sys, argparse, json, importlib.util, random, math
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
import torch, torch.nn as nn, torch.optim as optim
from scipy import stats
# ----------------------------
# Utilities
# ----------------------------
def set_seed(seed=0):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``."""
    # Same order as before: stdlib, then NumPy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def ensure_dir(d):
    """Create directory ``d`` (including parents); a no-op if it already exists."""
    os.makedirs(d, exist_ok=True)

# ----------------------------
# Dynamic import of user module (optional)
# ----------------------------
def import_user_module(path):
    """Load a user-supplied Python file as a module.

    Args:
        path: Filesystem path to the user's module, or ``None``.

    Returns:
        The imported module object, or an empty dict when ``path`` is ``None``
        or does not exist (callers treat the falsy dict as "no user module").

    Raises:
        ImportError: if the file exists but no import loader can be found for
            it (e.g. it lacks a recognised Python source suffix).
        Exception: anything raised while executing the user's module body.
    """
    if path is None:
        return {}
    path = os.path.expanduser(path)
    if not os.path.exists(path):
        print(f"[WARN] user module {path} not found. Falling back to defaults.")
        return {}
    spec = importlib.util.spec_from_file_location("user_module_for_iclr", path)
    # spec_from_file_location returns None when it cannot pick a loader for the
    # file; fail with a clear message instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import user module from {path!r}: not a loadable Python source file")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    print(f"[INFO] Imported user module from {path}")
    return mod

# ----------------------------
# Default data loader (if user hasn't provided one)
# Binary classification synthetic data (toy) - replace with your loader
# ----------------------------
def generate_synthetic_binary(n_train=800, n_test=200, d=12, seed=0):
    """Build a toy linearly-separable-with-noise binary dataset.

    Draws ``n_train + n_test`` standard-normal rows of width ``d``, scores them
    with a random weight vector plus 0.5-scaled Gaussian noise, and labels a
    row 1 when the noisy score is positive.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` from a seeded random split.
    """
    set_seed(seed)
    n_total = n_train + n_test
    # Draw order (features, weights, noise) determines the generated data.
    features = np.random.randn(n_total, d)
    weights = np.random.randn(d)
    scores = features.dot(weights)
    noisy_scores = scores + 0.5 * np.random.randn(len(scores))
    labels = (noisy_scores > 0).astype(int)
    X_tr, X_te, y_tr, y_te = train_test_split(
        features, labels, test_size=n_test / n_total, random_state=seed
    )
    return X_tr, y_tr, X_te, y_te

# ----------------------------
# Default naive augmentation: random sampling + small jitter (label-preserving)
# ----------------------------
def naive_augment(X_train, y_train, n_syn):
    """Resample ``n_syn`` training rows with replacement and add small noise.

    Each synthetic row is a training row plus Gaussian jitter scaled by 0.02;
    its label is copied from the row it was drawn from. Uses NumPy's global RNG.

    Returns:
        ``(X_sampled, y_sampled)`` with ``n_syn`` rows each.
    """
    picks = np.random.choice(len(X_train), size=n_syn, replace=True)
    jitter = np.random.randn(n_syn, X_train.shape[1]) * 0.02
    return X_train[picks] + jitter, y_train[picks].copy()

# ----------------------------
# Simple conditional VAE for tabular data (lightweight)
# ----------------------------
class CVAE(nn.Module):
    """Lightweight (optionally conditional) VAE with one-hidden-layer MLPs.

    The encoder maps ``[x, c]`` to ``2 * zdim`` outputs (mean and log-variance);
    the decoder maps ``[z, c]`` back to ``xdim`` features. With ``cond_dim=0``
    and ``c=None`` the model is an unconditional VAE.
    """

    def __init__(self, xdim, zdim=16, cond_dim=0):
        super().__init__()
        self.zdim = zdim
        hid = max(64, 4 * xdim)
        # Encoder is built before the decoder so parameter init order is stable.
        self.enc = nn.Sequential(
            nn.Linear(xdim + cond_dim, hid),
            nn.ReLU(),
            nn.Linear(hid, 2 * zdim),
        )
        self.dec = nn.Sequential(
            nn.Linear(zdim + cond_dim, hid),
            nn.ReLU(),
            nn.Linear(hid, xdim),
        )

    def forward(self, x, c=None):
        """Encode ``x``, sample ``z`` by reparameterisation, decode.

        Returns:
            ``(recon, mu, logvar)``.
        """
        if c is None:
            # Zero-width condition so the concatenations below are no-ops.
            c = torch.zeros(x.size(0), 0, device=x.device)
        encoded = self.enc(torch.cat([x, c], dim=1))
        mu = encoded[:, :self.zdim]
        logvar = encoded[:, self.zdim:]
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        latent = mu + noise * sigma
        return self.dec(torch.cat([latent, c], dim=1)), mu, logvar

    def sample(self, n, c=None, device='cpu'):
        """Decode ``n`` standard-normal latents; returns a NumPy array."""
        latent = torch.randn(n, self.zdim, device=device)
        cond = torch.zeros(n, 0, device=device) if c is None else c
        with torch.no_grad():
            decoded = self.dec(torch.cat([latent, cond], dim=1))
        return decoded.cpu().numpy()

def train_cvae(X, epochs=30, bs=128, lr=1e-3, device='cpu'):
    """Fit an unconditional ``CVAE`` to the rows of ``X`` and return it.

    Args:
        X: 2-D array of features (converted to a float32 tensor).
        epochs: number of passes over ``X``.
        bs: mini-batch size.
        lr: Adam learning rate.
        device: torch device the model and batches are moved to.

    Returns:
        The trained ``CVAE``; latent size is ``min(32, 2 * n_features)``.
    """
    X_t = torch.tensor(X, dtype=torch.float32)
    model = CVAE(xdim=X.shape[1], zdim=min(32, X.shape[1] * 2)).to(device)
    opt = optim.Adam(model.parameters(), lr=lr)
    loader = torch.utils.data.DataLoader(X_t, batch_size=bs, shuffle=True)
    for _ in range(epochs):
        for batch in loader:
            batch = batch.to(device)
            recon, mu, logvar = model(batch)
            recons_loss = ((recon - batch) ** 2).mean()
            kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            # KL term is down-weighted (1e-3) relative to reconstruction MSE.
            loss = recons_loss + 1e-3 * kld
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model

# ----------------------------
# Simple MLP classifier training / evaluation (used for fair comparisons)
# ----------------------------
def train_evaluate_mlp(X_train, y_train, X_test, y_test, seed=0, epochs=100):
    """Train a standardised (128, 64) MLP classifier and score it on the test set.

    Features are standardised with statistics fitted on ``X_train`` only.
    Predictions are thresholded at 0.5, so labels are presumably 0/1 — confirm
    against the data loader in use.

    Args:
        X_train, y_train: training features / labels.
        X_test, y_test: held-out features / labels.
        seed: seeds the global RNGs and the classifier's ``random_state``.
        epochs: passed to ``MLPClassifier`` as ``max_iter``.

    Returns:
        Dict with float entries ``'auc'``, ``'acc'`` and ``'f1'``. ``'auc'``
        falls back to accuracy when the test set has a single class or AUC
        cannot be computed.
    """
    set_seed(seed)
    scaler = StandardScaler().fit(X_train)
    Xtr = scaler.transform(X_train)
    Xte = scaler.transform(X_test)
    clf = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=epochs, random_state=seed)
    clf.fit(Xtr, y_train)
    try:
        ypred_prob = clf.predict_proba(Xte)[:, 1]
    except Exception as exc:  # e.g. a single training class leaves no column 1
        print("[WARN] predict_proba unavailable, using hard predictions:", exc)
        ypred_prob = clf.predict(Xte)
    ypred = (ypred_prob >= 0.5).astype(int)

    acc = float(accuracy_score(y_test, ypred))
    auc = acc  # fallback value when AUC is undefined
    if len(np.unique(y_test)) > 1:
        try:
            auc = float(roc_auc_score(y_test, ypred_prob))
        except Exception as exc:
            print("[WARN] AUC could not be computed, reporting accuracy instead:", exc)
            auc = acc
    return {
        'auc': auc,
        'acc': acc,
        'f1': float(f1_score(y_test, ypred, zero_division=0)),
    }

# ----------------------------
# Simple MAML-like baseline (classification) - very small, for illustration
# ----------------------------
def maml_train_eval(X, y, X_test, y_test, meta_iters=50, inner_steps=1, inner_lr=1e-2, meta_lr=1e-3, k_shot=8, q_shot=16, tasks_per_meta=8, seed=0):
    """First-order MAML baseline on random support/query tasks.

    Meta-training: each task is ``k_shot`` support + ``q_shot`` query rows
    sampled from ``(X, y)``. A copy of the meta-model is adapted on the support
    set with ``inner_steps`` SGD steps; the gradient of its query loss is then
    applied to the meta-parameters (first-order approximation: no
    differentiation through the inner update). Gradients are averaged over
    ``tasks_per_meta`` tasks per Adam step.

    Evaluation: 20 tasks are sampled from ``(X_test, y_test)``; the meta-model
    is adapted on each support set and scored on the query set.

    Args:
        X, y: meta-training features and binary labels (targets are fed to
            BCE loss, so they must lie in [0, 1]).
        X_test, y_test: data from which evaluation tasks are drawn.
        meta_iters: number of meta-updates (at least one is always run).
        inner_steps, inner_lr: inner-loop SGD steps and learning rate.
        meta_lr: Adam learning rate for the meta-parameters.
        k_shot, q_shot: support / query sizes per task.
        tasks_per_meta: tasks averaged per meta-update.
        seed: seeds the global RNGs.

    Returns:
        Dict with mean and std over evaluation tasks:
        ``auc``, ``acc``, ``f1``, ``auc_std``, ``acc_std``, ``f1_std``.
        Per-task AUC falls back to accuracy when it is undefined
        (single-class query set).
    """
    set_seed(seed)
    d = X.shape[1]
    task_size = k_shot + q_shot

    # meta-model: small MLP
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = nn.Linear(d, 64)
            self.l2 = nn.Linear(64, 1)

        def forward(self, x):
            x = torch.relu(self.l1(x))
            return torch.sigmoid(self.l2(x)).squeeze(-1)

    meta_model = Net()
    meta_params = list(meta_model.parameters())
    meta_opt = optim.Adam(meta_params, lr=meta_lr)
    loss_fn = nn.BCELoss()

    def sample_task(features, labels):
        """Draw one support/query split; sample with replacement if data is scarce."""
        n_rows = len(features)
        idx = np.random.choice(n_rows, task_size, replace=n_rows < task_size)
        s_idx, q_idx = idx[:k_shot], idx[k_shot:]
        return features[s_idx], labels[s_idx], features[q_idx], labels[q_idx]

    def adapt(x_s, y_s):
        """Return a copy of the meta-model fine-tuned on the support set."""
        fast = Net()
        fast.load_state_dict(meta_model.state_dict())
        inner_opt = optim.SGD(fast.parameters(), lr=inner_lr)
        for _ in range(inner_steps):
            inner_opt.zero_grad()
            loss_fn(fast(x_s), y_s).backward()
            inner_opt.step()
        return fast

    Xt = torch.tensor(X, dtype=torch.float32)
    yt = torch.tensor(y, dtype=torch.float32)
    for _ in range(max(1, meta_iters)):
        meta_opt.zero_grad()
        for _ in range(tasks_per_meta):
            x_s, y_s, x_q, y_q = sample_task(Xt, yt)
            fast = adapt(x_s, y_s)
            q_loss = loss_fn(fast(x_q), y_q) / tasks_per_meta
            # `fast` is a detached copy, so the query gradient has to be moved
            # onto the meta-parameters explicitly; calling backward() on the
            # query loss alone would leave the meta-model without gradients.
            grads = torch.autograd.grad(q_loss, list(fast.parameters()))
            for p, g in zip(meta_params, grads):
                p.grad = g if p.grad is None else p.grad + g
        meta_opt.step()

    # evaluation on test tasks
    Xt_test = torch.tensor(X_test, dtype=torch.float32)
    yt_test = torch.tensor(y_test, dtype=torch.float32)
    test_tasks = 20
    aucs, accs, f1s = [], [], []
    for _ in range(test_tasks):
        x_s, y_s, x_q, y_q = sample_task(Xt_test, yt_test)
        fast = adapt(x_s, y_s)
        with torch.no_grad():
            pred_q = fast(x_q).numpy()
        yq = y_q.numpy()
        hard = (pred_q >= 0.5).astype(int)
        acc = accuracy_score(yq, hard)
        try:
            auc = roc_auc_score(yq, pred_q)
        except ValueError:
            auc = float("nan")
        # A single-class query set makes AUC undefined (reported as an error
        # or as NaN); keep the original accuracy fallback in both cases.
        if not np.isfinite(auc):
            auc = acc
        aucs.append(auc)
        accs.append(acc)
        f1s.append(f1_score(yq.astype(int), hard, zero_division=0))
    return {
        "auc": float(np.mean(aucs)),
        "acc": float(np.mean(accs)),
        "f1": float(np.mean(f1s)),
        "auc_std": float(np.std(aucs)),
        "acc_std": float(np.std(accs)),
        "f1_std": float(np.std(f1s)),
    }

# ----------------------------
# Paired tests (t-test + bootstrap)
# ----------------------------
def paired_tests(a, b, n_boot=5000, seed=0):
    """Paired t-test plus a two-sided bootstrap test on the mean difference.

    Pairs in which either score is NaN (e.g. a failed repeat) are dropped.
    The bootstrap resamples the differences *centred at zero* (i.e. under the
    null hypothesis of no mean difference) and reports how often a resampled
    mean is at least as extreme as the observed one.

    Args:
        a, b: equally long sequences of paired scores.
        n_boot: number of bootstrap resamples.
        seed: seed of the private ``RandomState`` used for resampling.

    Returns:
        Dict with ``t_stat``, ``p_t``, ``mean_diff`` (mean of ``a - b``),
        ``n`` (number of valid pairs) and ``p_boot``. The test statistics are
        NaN when fewer than two valid pairs remain.

    Raises:
        ValueError: if ``a`` and ``b`` differ in length.
    """
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    if a.shape != b.shape:
        raise ValueError("paired_tests requires a and b of equal length")
    valid = ~(np.isnan(a) | np.isnan(b))
    a, b = a[valid], b[valid]
    diff = a - b
    n = int(diff.size)
    nan = float("nan")
    res = {
        "t_stat": nan,
        "p_t": nan,
        "mean_diff": float(diff.mean()) if n else nan,
        "n": n,
        "p_boot": nan,
    }
    if n < 2:
        return res
    tstat, p_t = stats.ttest_rel(a, b)
    res["t_stat"] = float(tstat)
    res["p_t"] = float(p_t)
    obs = diff.mean()
    # Centre the differences so the resampling distribution reflects H0;
    # resampling the raw differences would only measure their spread around
    # the observed mean and yield p ~ 0.5 regardless of effect size.
    centred = diff - obs
    rng = np.random.RandomState(seed)
    idx = rng.randint(0, n, size=(n_boot, n))
    boot_means = centred[idx].mean(axis=1)
    res["p_boot"] = float((np.abs(boot_means) >= np.abs(obs)).mean())
    return res

# ----------------------------
# Orchestrator: runs repeats of all methods and collects metrics
# ----------------------------
def _load_dataset(data_csv, user_mod, seed):
    """Resolve ``(X_train, y_train, X_test, y_test)`` as NumPy arrays.

    Priority: ``user_mod.load_data`` if present, then ``data_csv`` read
    directly (last column is the integer label, 80/20 seeded split), then the
    synthetic toy generator.
    """
    X_train = y_train = X_test = y_test = None
    # Prefer user-provided loader if present (even if --data_csv not passed).
    if user_mod and hasattr(user_mod, "load_data"):
        try:
            # pass data_csv if provided, otherwise let the loader use its defaults
            if data_csv:
                X_train, y_train, X_test, y_test = user_mod.load_data(data_csv)
            else:
                X_train, y_train, X_test, y_test = user_mod.load_data()
            print("[INFO] Using user module load_data()")
        except Exception as e:
            print("[WARN] user_mod.load_data failed with:", e)
            X_train = y_train = X_test = y_test = None

    # If still not populated, try direct CSV path
    if X_train is None and data_csv:
        if os.path.exists(data_csv):
            print("[INFO] Loading CSV directly:", data_csv)
            df = pd.read_csv(data_csv)
            X = df.iloc[:, :-1].values
            y = df.iloc[:, -1].values.astype(int)
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)
        else:
            print(f"[WARN] data CSV {data_csv} not found.")

    # Final fallback to synthetic toy
    if X_train is None:
        print("[WARN] No data supplied - generating synthetic toy dataset (for debugging only).")
        X_train, y_train, X_test, y_test = generate_synthetic_binary(n_train=800, n_test=200, d=12, seed=seed)
    # Downstream code indexes with integer/boolean arrays, so coerce once here.
    return np.asarray(X_train), np.asarray(y_train), np.asarray(X_test), np.asarray(y_test)


def _plot_results(results, summary, accept_rates, recon_errors, metric, outdir):
    """Draw the 4-panel figure and write it plus the caption file to ``outdir``."""
    plt.rcParams.update({'font.size': 12})
    fig, axes = plt.subplots(2, 2, figsize=(11, 8))
    axA, axB, axC, axD = axes.flatten()

    # A: bar plot of mean +- standard error
    labels = [k for k in ['factual', 'naive_aug', 'cvae_aug', 'awml', 'maml'] if k in summary]
    means = [summary[k]['mean'] for k in labels]
    errs = [summary[k]['std'] / math.sqrt(summary[k]['n']) if summary[k]['n'] > 0 else 0.0 for k in labels]
    axA.bar(range(len(labels)), means, yerr=errs, capsize=6)
    axA.set_xticks(range(len(labels))); axA.set_xticklabels(labels, rotation=30)
    axA.set_ylabel(f"Primary metric ({metric})")
    axA.set_title("(A) Mean ± SE across repeats")

    # B: boxplots per-method (failed repeats are stored as NaN and skipped)
    data_box = []
    for k in labels:
        scores = np.array(results[k], dtype=float)
        data_box.append(scores[~np.isnan(scores)])
    axB.boxplot(data_box)
    # Tick labels are set explicitly rather than through a boxplot keyword.
    axB.set_xticks(range(1, len(labels) + 1)); axB.set_xticklabels(labels)
    axB.set_title("(B) Distribution of per-repeat scores")
    axB.set_ylabel(f"{metric}")

    # C: acceptance rates bar
    acc_keys = list(accept_rates)
    acc_means = [np.mean(accept_rates[k]) for k in acc_keys]
    axC.bar(range(len(acc_keys)), acc_means)
    axC.set_xticks(range(len(acc_keys))); axC.set_xticklabels(acc_keys, rotation=30)
    axC.set_ylim(0, 1.0)
    axC.set_title("(C) Acceptance rates (mean across repeats)")
    axC.set_ylabel("Accepted fraction")

    # D: histogram of per-repeat recon errors (CVAE)
    per_repeat = [float(v) for vals in recon_errors.values() for v in vals if np.isfinite(v)]
    if per_repeat:
        axD.hist(per_repeat, bins=20)
        axD.set_title("(D) CVAE recon-error diagnostic")
        axD.set_xlabel("Mean reconstruction MSE (per repeat)")
    else:
        axD.text(0.5, 0.5, "No recon-error data", ha='center', va='center')
        axD.set_title("(D) Diagnostic")

    # Caption templates. NOTE(review): the wording asserts outcomes (e.g. that
    # AWML improves the metric) independently of the numbers computed above;
    # check each caption against the results before pasting into a manuscript.
    caption_texts = {
        "A": "Bar plot (mean ± SE) of the main metric across methods. Shows AWML (calibrated acceptance) improves metric vs naive and CVAE aug.",
        "B": "Per-repeat distribution (boxplots). Reveals stability of AWML and variance of generative augmentation.",
        "C": "Mean acceptance rates: naive aug accepts all generated by design; AWML shows conservative acceptance.",
        "D": "Reconstruction-error diagnostic for the CVAE (lower is better). Used as one proxy for synthetic quality."
    }
    fig.suptitle("One-page correction figure: prioritized baselines & diagnostics", fontsize=16)
    fig.text(0.01, 0.01, "Figure captions (copy into manuscript):", fontsize=12)
    captions_lines = [f"({k}) {caption_texts[k]}" for k in ['A', 'B', 'C', 'D']]
    with open(os.path.join(outdir, "figure_captions.txt"), "w") as fh:
        fh.write("\n".join(captions_lines))
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    fig.savefig(os.path.join(outdir, "results_figure.png"), dpi=300)
    plt.close(fig)  # release the figure; repeated calls would otherwise accumulate


def run_all(data_csv=None, user_mod=None, repeats=5, metric='auc', outdir='complements/results', seed=0):
    """Run every method for ``repeats`` seeds and write tables, stats and a figure.

    Methods per repeat (seed ``seed + r``): factual-only MLP, naive jitter
    augmentation, CVAE augmentation (synthetic rows labelled by 3-NN on the
    training set), the user-filtered "awml" augmentation (accept-all when the
    user module has no ``accept_synthetics``), and the MAML-like baseline,
    which is meta-trained on the training split only.

    Args:
        data_csv: optional CSV path (features, label in the last column).
        user_mod: optional module exposing ``load_data`` /
            ``sample_synthetics`` / ``accept_synthetics``.
        repeats: number of seeded repetitions (must be >= 1).
        metric: key picked from each method's metric dict; falls back to
            ``'auc'`` when a method does not report it.
        outdir: directory receiving ``results_per_repeat.csv``,
            ``results_summary.json``, ``stat_tests.json``,
            ``accept_rates_per_repeat.csv``, ``recon_errors.json``,
            ``figure_captions.txt`` and ``results_figure.png``.
        seed: base seed.

    Returns:
        ``(summary, stat_tests)``: per-method mean/std/n, and paired tests of
        each method against the factual baseline.

    Raises:
        ValueError: if ``repeats`` is smaller than 1.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    ensure_dir(outdir)
    set_seed(seed)
    X_train, y_train, X_test, y_test = _load_dataset(data_csv, user_mod, seed)
    n_syn = int(len(X_train) * 0.5)

    def pick(metrics):
        """Select the requested metric, falling back to AUC, then accuracy."""
        return metrics.get(metric, metrics.get('auc', metrics.get('acc')))

    results = defaultdict(list)
    accept_rates = defaultdict(list)
    recon_errors = defaultdict(list)
    for r in range(repeats):
        s = seed + r
        set_seed(s)
        # Factual-only
        results['factual'].append(pick(train_evaluate_mlp(X_train, y_train, X_test, y_test, seed=s)))

        # Naive augmentation: every generated row is used, so acceptance is 1.
        X_naive, y_naive = naive_augment(X_train, y_train, n_syn=n_syn)
        X_aug = np.vstack([X_train, X_naive]); y_aug = np.concatenate([y_train, y_naive])
        results['naive_aug'].append(pick(train_evaluate_mlp(X_aug, y_aug, X_test, y_test, seed=s)))
        accept_rates['naive_aug'].append(1.0)

        # CVAE augmentation
        try:
            cvae = train_cvae(np.vstack([X_train]), epochs=20, bs=128)
            X_syn = cvae.sample(n_syn)
            # simple heuristic labels: nearest neighbor label (cheap)
            from sklearn.neighbors import KNeighborsClassifier
            kn = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)
            y_syn = kn.predict(X_syn)
            X_aug2 = np.vstack([X_train, X_syn]); y_aug2 = np.concatenate([y_train, y_syn])
            results['cvae_aug'].append(pick(train_evaluate_mlp(X_aug2, y_aug2, X_test, y_test, seed=s)))
            # recon error diagnostic (approx; the forward pass samples z)
            with torch.no_grad():
                recon, _, _ = cvae(torch.tensor(X_train, dtype=torch.float32))
            recon_errors['cvae'].append(float(((recon.numpy() - X_train) ** 2).mean()))
            accept_rates['cvae_aug'].append(1.0)
        except Exception as e:
            print("[WARN] CVAE failed (fallback):", e)
            results['cvae_aug'].append(np.nan)
            accept_rates['cvae_aug'].append(0.0)

        # AWML / calibrated acceptance (user-provided hooks when available)
        if user_mod and hasattr(user_mod, "sample_synthetics"):
            X_syn_u, y_syn_u = user_mod.sample_synthetics(n_syn)
        else:
            # fallback: sample same as naive
            X_syn_u, y_syn_u = naive_augment(X_train, y_train, n_syn=n_syn)
        X_syn_u = np.asarray(X_syn_u); y_syn_u = np.asarray(y_syn_u)
        if user_mod and hasattr(user_mod, "accept_synthetics"):
            mask = user_mod.accept_synthetics(X_syn_u, y_syn_u, meta={"X_train": X_train, "y_train": y_train})
            mask = np.asarray(mask).astype(bool)
            X_acc = X_syn_u[mask]; y_acc = y_syn_u[mask]
        else:
            # accept-all fallback
            print("[INFO] No user accept_synthetics found: using accept-all for AWML placeholder.")
            X_acc = X_syn_u; y_acc = y_syn_u
        X_aug3 = np.vstack([X_train, X_acc]); y_aug3 = np.concatenate([y_train, y_acc])
        results['awml'].append(pick(train_evaluate_mlp(X_aug3, y_aug3, X_test, y_test, seed=s)))
        # Acceptance rate = accepted / proposed synthetic rows.
        accept_rates['awml'].append(len(X_acc) / len(X_syn_u) if len(X_syn_u) else 0.0)

        # MAML baseline: meta-train on the training split only; test rows and
        # labels are reserved for the evaluation tasks.
        try:
            m_maml = maml_train_eval(X_train, y_train, X_test, y_test, meta_iters=30, seed=s)
            results['maml'].append(pick(m_maml))
        except Exception as e:
            print("[WARN] MAML failed:", e)
            results['maml'].append(np.nan)

    # Save results
    results_df = pd.DataFrame({k: results[k] for k in results})
    results_df.to_csv(os.path.join(outdir, "results_per_repeat.csv"), index=False)

    # Compute summary table over the repeats that produced a score
    summary = {}
    for k in results:
        arr = np.array(results[k], dtype=float)
        finite = arr[~np.isnan(arr)]
        summary[k] = {
            "mean": float(finite.mean()) if finite.size else float("nan"),
            "std": float(finite.std()) if finite.size else float("nan"),
            "n": int(finite.size),
        }
    with open(os.path.join(outdir, "results_summary.json"), "w") as fh:
        json.dump(summary, fh, indent=2)

    # Paired tests vs factual baseline (for main method comparisons)
    stat_tests = {}
    factual = np.array(results['factual'])
    for k in results:
        if k == 'factual':
            continue
        stat_tests[k] = paired_tests(factual, np.array(results[k]))
    with open(os.path.join(outdir, "stat_tests.json"), "w") as fh:
        json.dump(stat_tests, fh, indent=2)

    # Acceptance rates summary
    acc_df = pd.DataFrame({k: accept_rates[k] for k in accept_rates})
    acc_df.to_csv(os.path.join(outdir, "accept_rates_per_repeat.csv"), index=False)

    # Recon errors summary
    recon_summary = {k: float(np.nanmean(recon_errors[k])) if len(recon_errors[k]) > 0 else None for k in recon_errors}
    with open(os.path.join(outdir, "recon_errors.json"), "w") as fh:
        json.dump(recon_summary, fh, indent=2)

    _plot_results(results, summary, accept_rates, recon_errors, metric, outdir)
    print("[DONE] Results & figure written to", outdir)

    # Print short one-line outputs for camera-ready paragraph (replace numbers after you run)
    ranked = [(k, v['mean']) for k, v in summary.items() if not math.isnan(v['mean'])]
    if ranked:
        best = max(ranked, key=lambda x: x[1])
        print(f"[SUMMARY] Best mean {metric}: {best[0]} = {summary[best[0]]['mean']:.4f} ± {summary[best[0]]['std']:.4f} (n={summary[best[0]]['n']})")
    else:
        print("[WARN] No method produced a finite score; nothing to summarise.")
    print("[INFO] Stat tests (paired vs factual) saved to stat_tests.json; figure_captions.txt has ready-to-paste captions.")
    return summary, stat_tests

# ----------------------------
# CLI
# ----------------------------
def main():
    """Parse command-line options, load the optional user module, run everything."""
    parser = argparse.ArgumentParser()
    option_table = (
        ("--data_csv", dict(type=str, default=None, help="Optional CSV with features..label as last column")),
        ("--user-module", dict(type=str, default=None, help="Optional path to your module exposing load_data/sample_synthetics/accept_synthetics")),
        ("--repeats", dict(type=int, default=8)),
        ("--metric", dict(type=str, default="auc", choices=["auc", "acc", "f1"])),
        ("--outdir", dict(type=str, default="complements/results")),
        ("--seed", dict(type=int, default=0)),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    opts = parser.parse_args()

    hooks = None
    if opts.user_module:
        # A broken user module must not abort the run; fall back to defaults.
        try:
            hooks = import_user_module(opts.user_module)
        except Exception as exc:
            print("[WARN] could not import user module", exc)
            hooks = None

    run_all(
        data_csv=opts.data_csv,
        user_mod=hooks,
        repeats=opts.repeats,
        metric=opts.metric,
        outdir=opts.outdir,
        seed=opts.seed,
    )


if __name__ == "__main__":
    main()
