#!/usr/bin/env python3
# train_superconductor_surrogate.py
import argparse, json, os, random
import numpy as np
import pandas as pd
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
from joblib import dump
try:
    from scipy.stats import spearmanr
    HAVE_SCIPY = True
except Exception:
    HAVE_SCIPY = False

# ---------------- Utilities ----------------
def set_seed(seed: int):
    """Seed every RNG this script draws from with the same ``seed``.

    Covers Python's ``random`` module, NumPy's global RNG, torch's CPU
    generator and, only when CUDA is available, all CUDA device generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def pick_visible_indices(y, visible_frac=0.8):
    """Return indices of ``y`` whose value is at or below its ``visible_frac`` quantile.

    The comparison is ``<=``, so every entry tied with the threshold is kept.
    With ``visible_frac=1.0`` the threshold is the maximum and all indices are
    returned.
    """
    cutoff = np.quantile(y, visible_frac)
    keep_mask = y <= cutoff
    return np.nonzero(keep_mask)[0]

def train_val_split(n, val_frac=0.1, seed=0):
    """Randomly partition ``range(n)`` into (train_indices, val_indices).

    A seeded ``np.random.RandomState`` permutes the indices; the first
    ``round(val_frac * n)`` of the permutation form the validation split and
    the remainder the training split.
    """
    perm = np.random.RandomState(seed).permutation(n)
    cut = int(round(val_frac * n))
    val_part = perm[:cut]
    train_part = perm[cut:]
    return train_part, val_part

class TabularDataset(Dataset):
    """In-memory dataset yielding ``(features, target)`` float32 tensor pairs.

    Args:
        Xz: array-like of shape (n, d); copied and cast to float32.
        yz: array-like of shape (n,) or (n, k); copied and cast to float32.
            1-D targets are stored as an (n, 1) column so each item's target
            is a 1-D array (``torch.from_numpy`` rejects NumPy scalars) and
            matches a model output of shape (batch, 1).

    Raises:
        ValueError: if ``Xz`` and ``yz`` have a different number of rows.
    """

    def __init__(self, Xz, yz):
        self.Xz = np.array(Xz, dtype=np.float32)
        yz = np.array(yz, dtype=np.float32)
        if yz.ndim == 1:
            yz = yz.reshape(-1, 1)
        if self.Xz.shape[0] != yz.shape[0]:
            raise ValueError(
                f"Xz has {self.Xz.shape[0]} rows but yz has {yz.shape[0]}"
            )
        self.yz = yz

    def __len__(self):
        return self.Xz.shape[0]

    def __getitem__(self, i):
        # Row slices of the stored arrays are ndarrays, so from_numpy shares
        # memory with them instead of copying.
        return torch.from_numpy(self.Xz[i]), torch.from_numpy(self.yz[i])

class SurrogateMLP(nn.Module):
    """Three-layer MLP: d_in -> d_hidden -> d_hidden -> d_out.

    LeakyReLU activations sit between the linear layers, with none after the
    output layer. The layers live in ``self.net`` (an ``nn.Sequential``) at
    indices 0, 2 and 4, which fixes the ``state_dict`` key names.
    """

    def __init__(self, d_in=86, d_hidden=2048, d_out=1):
        super().__init__()
        widths = (d_in, d_hidden, d_hidden, d_out)
        layers = []
        for k, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if k > 0:
                layers.append(nn.LeakyReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

@torch.no_grad()
def evaluate_denorm(model, X_vis, y_vis, x_scaler, y_scaler, device, bs=1024):
    """Score ``model`` on unscaled data, measuring error in original target units.

    Features are standardized with ``x_scaler``, run through ``model`` in
    batches of ``bs`` rows on ``device``, and the predictions are mapped back
    with ``y_scaler.inverse_transform`` before comparing against ``y_vis``.

    Returns:
        dict with ``"mse"`` and ``"r2"`` (plus ``"spearman"`` when SciPy is
        importable). If ``X_vis`` has no rows, every value is NaN instead of
        the scaler / ``np.vstack`` raising on empty input.

    Side effect: leaves ``model`` in eval mode.
    """
    keys = ["mse", "r2"] + (["spearman"] if HAVE_SCIPY else [])
    if len(X_vis) == 0:
        return {k: float("nan") for k in keys}

    model.eval()
    Xz = x_scaler.transform(X_vis).astype(np.float32)
    preds = []
    for i in range(0, len(Xz), bs):
        xb = torch.from_numpy(Xz[i:i + bs]).to(device)
        preds.append(model(xb).cpu().numpy())
    pz = np.vstack(preds)
    yhat = y_scaler.inverse_transform(pz).ravel()
    y_true = np.asarray(y_vis).ravel()
    out = {
        "mse": float(mean_squared_error(y_true, yhat)),
        "r2": float(r2_score(y_true, yhat)),
    }
    if HAVE_SCIPY:
        out["spearman"] = float(spearmanr(y_true, yhat).correlation)
    return out

def save_ckpt(path, model, opt, epoch, best_val):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The saved dict has keys ``"epoch"``, ``"state_dict"`` (model weights),
    ``"optimizer"`` (optimizer state) and ``"best_val_mse_z"``. Missing parent
    directories are created; a bare filename is written to the current
    working directory.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "optimizer": opt.state_dict(),
        "best_val_mse_z": best_val,
    }, path)

# ---------------- Train ----------------
def _run_epoch(model, loader, loss_fn, device, opt=None):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    Trains (train mode + an optimizer step per batch) when ``opt`` is given;
    otherwise evaluates in eval mode with autograd disabled. Returns NaN for
    an empty loader.
    """
    training = opt is not None
    model.train(training)
    total, count = 0.0, 0
    with torch.set_grad_enabled(training):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = loss_fn(model(xb), yb)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            # Weight each batch mean by its size so a short final batch does not
            # skew the epoch average (assumes loss_fn averages over the batch,
            # as nn.MSELoss does by default).
            n = xb.shape[0]
            total += loss.item() * n
            count += n
    return total / count if count else float("nan")


def main():
    """Train the surrogate MLP and write model, scalers, bounds and config.

    Steps:
      1. Load ``--csv``. Every column except ``critical_temp`` and
         ``material`` is a feature.
      2. Keep the rows whose ``critical_temp`` is at or below the
         ``--visible-frac`` quantile, and split them into train/val.
      3. Standardize X and y with scalers fit on the train split.
      4. Train with Adam + MSE in standardized space, with early stopping on
         validation MSE.
      5. Restore the best weights and report validation metrics in original
         Tc units.

    Writes ``checkpoints/epoch_XXXX.pt`` and ``checkpoints/best.pt`` during
    training. At the end it writes ``surrogate.pt``, ``x_scaler.joblib``,
    ``y_scaler.joblib``, ``elem_cols.txt``, ``bounds.json`` and
    ``train_config.json`` under ``--outdir``.
    """
    # ArgumentParser's first positional argument is ``prog``; pass the text as
    # the description so usage/help output stays sensible.
    ap = argparse.ArgumentParser(
        description="Train Design-Bench-style surrogate on Superconductor.")
    ap.add_argument("--csv", type=str, default="unique_m.csv")
    ap.add_argument("--visible-frac", type=float, default=1.0)
    ap.add_argument("--val-frac", type=float, default=0.1)
    ap.add_argument("--batch-size", type=int, default=512)
    ap.add_argument("--epochs", type=int, default=10000)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--weight-decay", type=float, default=0.0)
    ap.add_argument("--hidden", type=int, default=2048)
    ap.add_argument("--patience", type=int, default=3000)
    ap.add_argument("--save-every", type=int, default=10, help="Checkpoint every N epochs")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--outdir", type=str, default="runs")
    args = ap.parse_args()

    # Fail fast on settings that would otherwise crash mid-run.
    # An empty validation split yields a NaN val loss: no best state would
    # ever be recorded, and the final evaluation would fail after training.
    if not 0.0 <= args.visible_frac <= 1.0:
        ap.error("--visible-frac must be in [0, 1]")
    if not 0.0 < args.val_frac < 1.0:
        ap.error("--val-frac must be in (0, 1)")
    if args.save_every < 1:
        ap.error("--save-every must be >= 1")  # used as a modulus below

    ckpt_dir = os.path.join(args.outdir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)  # also creates args.outdir
    set_seed(args.seed)

    # Load data
    df = pd.read_csv(args.csv)
    elem_cols = [c for c in df.columns if c not in ("critical_temp", "material")]
    X = df[elem_cols].to_numpy(np.float32)
    y = df["critical_temp"].to_numpy(np.float32).reshape(-1, 1)

    # Visible subset = bottom visible_frac by Tc
    vis_idx = pick_visible_indices(y.ravel(), args.visible_frac)
    X_vis, y_vis = X[vis_idx], y[vis_idx]

    # Train/val split on visible
    tr_idx, va_idx = train_val_split(len(X_vis), args.val_frac, args.seed)
    if len(tr_idx) == 0 or len(va_idx) == 0:
        ap.error(f"--val-frac {args.val_frac} leaves an empty train or validation "
                 f"split for {len(X_vis)} visible rows")
    X_tr, y_tr = X_vis[tr_idx], y_vis[tr_idx]
    X_va, y_va = X_vis[va_idx], y_vis[va_idx]

    # Per-dim bounds from VISIBLE data (used later by optimizer)
    x_min = X_vis.min(axis=0).tolist()
    x_max = X_vis.max(axis=0).tolist()
    best_visible_tc = float(y_vis.max())

    # Fit scalers on TRAIN (visible) only, so validation stays unseen.
    x_scaler = StandardScaler().fit(X_tr)
    y_scaler = StandardScaler().fit(y_tr)

    Xz_tr = x_scaler.transform(X_tr).astype(np.float32)
    yz_tr = y_scaler.transform(y_tr).astype(np.float32)
    Xz_va = x_scaler.transform(X_va).astype(np.float32)
    yz_va = y_scaler.transform(y_va).astype(np.float32)

    dl_tr = DataLoader(TabularDataset(Xz_tr, yz_tr), batch_size=args.batch_size, shuffle=True)
    dl_va = DataLoader(TabularDataset(Xz_va, yz_va), batch_size=args.batch_size, shuffle=False)

    # Model/opt
    device = torch.device(args.device)
    model = SurrogateMLP(d_in=len(elem_cols), d_hidden=args.hidden, d_out=1).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    loss_fn = nn.MSELoss()

    best_val = float("inf")
    best_state = None
    patience_left = args.patience

    for ep in range(1, args.epochs + 1):
        mean_tr = _run_epoch(model, dl_tr, loss_fn, device, opt)
        mean_val = _run_epoch(model, dl_va, loss_fn, device)
        if ep % 10 == 0 or ep == 1:
            print(f"Epoch {ep:4d} | train MSE(z): {mean_tr:.5f} | val MSE(z): {mean_val:.5f}")

        # Early-stopping bookkeeping; improvements must beat best by 1e-6.
        improved = mean_val < best_val - 1e-6
        if improved:
            best_val = mean_val
            # Keep a CPU copy so later optimizer steps don't overwrite it.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_left = args.patience
            save_ckpt(os.path.join(ckpt_dir, "best.pt"), model, opt, ep, best_val)
        else:
            patience_left -= 1

        # Periodic checkpoint, written after the update above so its recorded
        # best_val already accounts for this epoch.
        if ep % args.save_every == 0:
            save_ckpt(os.path.join(ckpt_dir, f"epoch_{ep:04d}.pt"), model, opt, ep, best_val)

        if not improved and patience_left <= 0:
            print(f"Early stopping at epoch {ep}. Best val MSE(z): {best_val:.6f}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    # Final metrics (original Tc units) on visible val
    metrics = evaluate_denorm(model, X_va, y_va, x_scaler, y_scaler, device)
    print("\nValidation metrics (original Tc space):")
    print(json.dumps(metrics, indent=2))

    # Persist artifacts
    torch.save({"state_dict": model.state_dict(),
                "d_in": len(elem_cols),
                "d_hidden": args.hidden},
               os.path.join(args.outdir, "surrogate.pt"))
    dump(x_scaler, os.path.join(args.outdir, "x_scaler.joblib"))
    dump(y_scaler, os.path.join(args.outdir, "y_scaler.joblib"))
    with open(os.path.join(args.outdir, "elem_cols.txt"), "w", encoding="utf-8") as f:
        f.writelines(c + "\n" for c in elem_cols)
    with open(os.path.join(args.outdir, "bounds.json"), "w", encoding="utf-8") as f:
        json.dump({"x_min": x_min, "x_max": x_max,
                   "best_visible_tc": best_visible_tc}, f, indent=2)
    with open(os.path.join(args.outdir, "train_config.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2)

    print(f"\nSaved model+scalers+bounds to: {args.outdir}")

# Run training only when executed as a script, not when imported (e.g. to reuse
# SurrogateMLP or the helpers above).
if __name__ == "__main__":
    main()
