#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
XGBoost fusion baseline: ModernBERT CLS embeddings + structured metadata.
-----------------------------------------------------------------------------

Run:
CUDA_VISIBLE_DEVICES=0 python scripts/xgb_fusion.py \
  --split_dir /path/to/outputs/splits_with_feats \
  --txt_dir   /path/to/data/raw/txt \
  --out_dir   checkpoints/xgb_fusion \
  --batch     8
"""
import json, argparse, itertools, joblib
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from sklearn.preprocessing import StandardScaler, OneHotEncoder, label_binarize
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report
from xgboost import XGBClassifier

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel


# --------------------------------------------------------------------------- #
# 1. CLI
# --------------------------------------------------------------------------- #
def parse_args():
    """Parse the command-line options of the fusion baseline.

    Returns:
        argparse.Namespace with ``split_dir``, ``txt_dir`` and ``out_dir``
        (all required, converted to ``Path``) and ``batch`` (int, default 8).
    """
    parser = argparse.ArgumentParser()

    # The three directory flags share the same shape: required, parsed as Path.
    directory_flags = (
        ("--split_dir", "Directory containing *_full.csv files with metadata."),
        ("--txt_dir", "Directory with raw *.txt FDA label files."),
        ("--out_dir", None),
    )
    for flag, help_text in directory_flags:
        parser.add_argument(flag, required=True, type=Path, help=help_text)

    parser.add_argument(
        "--batch",
        type=int,
        default=8,
        help="Batch size for ModernBERT encoding.",
    )
    return parser.parse_args()


# --------------------------------------------------------------------------- #
# 2. Embedding Helper
# --------------------------------------------------------------------------- #
class TxtDataset(Dataset):
    """Map-style dataset that returns the text content of one file per index.

    Args:
        paths: Files to read; item ``i`` is the decoded content of ``paths[i]``.
    """

    def __init__(self, paths: List[Path]):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding; undecodable bytes are dropped
        # (errors="ignore") rather than raising.
        return self.paths[idx].read_text(encoding="utf-8", errors="ignore")

@torch.no_grad()
def cls_embeddings(encoder, tokenizer, paths: List[Path], batch: int = 8,
                   max_length: int = 8192) -> np.ndarray:
    """Encode text files and return the first-token hidden state of each one.

    Args:
        encoder: Model whose output exposes ``last_hidden_state``; it is put in
            eval mode and called without gradient tracking.
        tokenizer: Callable that turns a list of strings into a dict of
            tensors (called with ``return_tensors="pt"``).
        paths: Text files to encode, one row of the result per file, in order.
        batch: Number of files tokenized and encoded per forward pass.
        max_length: Token limit passed to the tokenizer; longer texts are
            truncated.

    Returns:
        Array of shape (N, hidden) with one vector per entry of ``paths``.

    Raises:
        ValueError: If ``paths`` is empty (``np.vstack`` has nothing to stack).
    """
    # Follow the encoder's own device instead of assuming CUDA, so inputs and
    # weights always end up on the same device.
    device = next(encoder.parameters()).device

    loader = DataLoader(TxtDataset(paths), batch_size=batch, shuffle=False)
    encoder.eval()

    all_vecs = []
    for texts in tqdm(loader, desc=f"CLS {len(paths)}"):
        enc = tokenizer(list(texts), return_tensors="pt", padding=True,
                        truncation=True, max_length=max_length)
        enc = {k: v.to(device) for k, v in enc.items()}
        out = encoder(**enc)
        # Position 0 of the sequence axis is used as the document vector.
        all_vecs.append(out.last_hidden_state[:, 0].cpu().numpy())   # (B, hidden)
    return np.vstack(all_vecs)                                       # (N, hidden)


# --------------------------------------------------------------------------- #
# 3. Build metadata matrix
# --------------------------------------------------------------------------- #
# Column names must match the CSV headers exactly (including the double
# spaces in some "Total #  of ..." names), so they are listed verbatim.

# Columns cast to float, blank -> 0, then standardized in build_meta().
NUMERIC_COLS = [
    "total_studies",
    "age_min",
    "age_max",
    "Number of Centers",
    "Number of Countries",
    "Patients Enrolled",
    "Patients Analyzed",
    "Total # of Hispanic/Latino",
    "Total # of Non-Hispanic/Non-Latino",
    "Total #  of Unknown Ethnicity",
    "Total #  of Asian",
    "Total #  of Black",
    "Total #  of White",
]

# Flag columns cast to int in build_meta(); absent columns are treated as 0.
BOOL_COLS = [
    "Efficacy",
    "Safety",
    "Pharmacokinetic",
    "Pharmacodynamic",
    "Tolerability",
    "Other_Type",
    "Randomized_DoubleBlind",
    "Randomized_SingleBlind",
    "Open_Label",
    "Placebo_Control",
    "Active_Comparator",
    "Dose_Escalation",
    "Population_PK",
    "Other_Design",
    "Studied in Neonates",
    "Indicated in Neonates",
]

# Columns one-hot encoded in build_meta().
CAT_COLS = [
    "Type of Legislation",
    "Therapeutic Category",
    "Dosage Form(s)",
    "Route(s) of Administration",
]

def build_meta(df: pd.DataFrame,
               scaler: StandardScaler = None,
               ohe: OneHotEncoder = None,
               fit: bool = False) -> Tuple[np.ndarray, StandardScaler, OneHotEncoder]:
    """Turn the structured metadata columns of ``df`` into one feature matrix.

    The matrix is the horizontal concatenation of, in this order:
    standardized ``NUMERIC_COLS``, integer ``BOOL_COLS`` and one-hot encoded
    ``CAT_COLS``. ``df`` itself is not modified.

    Args:
        df: Frame holding the metadata columns. All ``NUMERIC_COLS`` and
            ``CAT_COLS`` must be present; missing ``BOOL_COLS`` count as 0.
        scaler: Fitted ``StandardScaler``; required when ``fit`` is False.
        ohe: Fitted ``OneHotEncoder``; required when ``fit`` is False.
        fit: If True, fit a new scaler and encoder on ``df`` (any passed-in
            ``scaler``/``ohe`` are ignored); otherwise only transform.

    Returns:
        ``(features, scaler, ohe)`` where ``features`` has one row per row of
        ``df`` and the scaler/encoder are the ones used for the transform.

    Raises:
        ValueError: If ``fit`` is False and ``scaler`` or ``ohe`` is None.
    """
    if not fit and (scaler is None or ohe is None):
        raise ValueError(
            "build_meta(fit=False) needs an already fitted scaler and ohe; "
            "call it with fit=True on the training split first."
        )

    # ----- numeric -----
    # Blank cells become NaN for the float cast and are then imputed with 0.
    num = df[NUMERIC_COLS].replace("", np.nan).astype(float).fillna(0).values
    if fit:
        scaler = StandardScaler()
        num = scaler.fit_transform(num)
    else:
        num = scaler.transform(num)

    # ----- boolean -----
    # reindex() returns a new frame and fills absent flag columns with 0, so
    # the caller's DataFrame is never mutated.
    bool_ = (df.reindex(columns=BOOL_COLS, fill_value=0)
               .replace("", 0)
               .astype(int)
               .values)

    # ----- categorical -----
    cat = df[CAT_COLS].fillna("UNK").astype(str)
    if fit:
        # handle_unknown="ignore": categories unseen during fit encode as all zeros.
        ohe = OneHotEncoder(handle_unknown="ignore", sparse_output=False, dtype=np.float32)
        cat_mat = ohe.fit_transform(cat)
    else:
        cat_mat = ohe.transform(cat)

    # ---- fuse ----
    return np.hstack([num, bool_, cat_mat]), scaler, ohe


# --------------------------------------------------------------------------- #
# 4. Main
# --------------------------------------------------------------------------- #
def main():
    """Run the fusion baseline end to end: embed, fuse, train, evaluate, save.

    Steps:
      1. Read ``train_full.csv``, ``dev_full.csv`` and ``test_full.csv`` from
         ``--split_dir`` and map the ``label`` column to integer class ids.
      2. Encode ``<txt_dir>/<canon_id>.txt`` for every row with the frozen
         encoder (``cls_embeddings``).
      3. Build the metadata matrix (``build_meta``), fitting the scaler and
         one-hot encoder on the training split only.
      4. Train an ``XGBClassifier`` on the concatenated features, with the dev
         split as evaluation set, and save it to ``xgb_fusion_model.joblib``.
      5. Print test metrics and write them to ``test_metrics.json``.

    Raises:
        ValueError: If a split contains a label that is not in ``label2id``.
        FileNotFoundError: If a text file referenced by ``canon_id`` is missing.
    """
    args = parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    # ---- load splits ----
    split_names = ("train", "dev", "test")
    splits = {
        name: pd.read_csv(args.split_dir / f"{name}_full.csv", dtype=str).fillna("")
        for name in split_names
    }

    label2id = {"NotExtrapolated": 0, "Partial": 1, "Full": 2, "Unlabeled": 3}
    for name, frame in splits.items():
        raw = frame["label"].str.strip()
        encoded = raw.map(label2id)
        # Name the offending labels instead of failing on the NaN -> int cast.
        if encoded.isna().any():
            unknown = sorted(set(raw[encoded.isna()]))
            raise ValueError(
                f"{name}_full.csv has labels outside {sorted(label2id)}: {unknown}"
            )
        frame["label"] = encoded.astype(int)

    y_train = splits["train"]["label"].values
    y_dev   = splits["dev"]["label"].values
    y_test  = splits["test"]["label"].values

    # ---- resolve text files; fail before loading the model if any is missing ----
    txt_paths = {
        name: [args.txt_dir / f"{cid}.txt" for cid in frame["canon_id"]]
        for name, frame in splits.items()
    }
    missing = [p for paths in txt_paths.values() for p in paths if not p.exists()]
    if missing:
        preview = ", ".join(str(p) for p in missing[:5])
        raise FileNotFoundError(
            f"{len(missing)} label text file(s) not found, e.g.: {preview}"
        )

    # ---- ModernBERT encoder (frozen) ----
    model_name = "Simonlee711/Clinical_ModernBERT"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder   = AutoModel.from_pretrained(model_name).cuda()
    for p in encoder.parameters():       # freeze
        p.requires_grad = False

    # ---- CLS embeddings (train, dev, test in that order) ----
    text_feats = {
        name: cls_embeddings(encoder, tokenizer, txt_paths[name], batch=args.batch)
        for name in split_names
    }

    # ---- metadata: scaler / encoder are fitted on train only ----
    X_tr_meta, scaler, ohe = build_meta(splits["train"], fit=True)
    X_dev_meta, _, _       = build_meta(splits["dev"],  scaler, ohe, fit=False)
    X_te_meta, _, _        = build_meta(splits["test"], scaler, ohe, fit=False)

    # ---- fuse ----
    X_train = np.hstack([text_feats["train"], X_tr_meta])
    X_dev   = np.hstack([text_feats["dev"],   X_dev_meta])
    X_test  = np.hstack([text_feats["test"],  X_te_meta])

    print("Fused dims →", X_train.shape, X_test.shape)

    # ---- XGBoost ----
    # `use_label_encoder` is intentionally not passed: it is deprecated in
    # XGBoost and only triggers an "unused parameter" warning on recent
    # releases; the labels are already integer-encoded above.
    clf = XGBClassifier(
        objective="multi:softprob", num_class=len(label2id),
        eval_metric="mlogloss",
        max_depth=6, learning_rate=0.1, n_estimators=300,
        subsample=0.9, colsample_bytree=0.9, random_state=42,
    )
    clf.fit(X_train, y_train,
            eval_set=[(X_dev, y_dev)],
            verbose=50)

    # Persist the model right after training so a failure in the evaluation
    # below cannot discard the fitted classifier.
    joblib.dump(clf, args.out_dir / "xgb_fusion_model.joblib")

    # ---- evaluation ----
    class_ids = sorted(label2id.values())
    y_pred = clf.predict(X_test)
    y_prob = clf.predict_proba(X_test)
    y_test_1hot = label_binarize(y_test, classes=class_ids)

    # roc_auc_score raises ValueError when a class never occurs in the test
    # split (or the probability matrix has a different number of columns);
    # report the AUC as undefined instead of aborting the whole run.
    try:
        macro_auc = roc_auc_score(y_test_1hot, y_prob, average="macro", multi_class="ovr")
    except ValueError as err:
        print(f"[warn] macro AUC is undefined for this test split: {err}")
        macro_auc = float("nan")

    metrics = {
        "accuracy"    : accuracy_score(y_test, y_pred),
        "macro_f1"    : f1_score(y_test, y_pred, average="macro"),
        "weighted_f1" : f1_score(y_test, y_pred, average="weighted"),
        "macro_auc"   : macro_auc,
    }

    print("\n-- TEST REPORT --")
    print(classification_report(y_test, y_pred, digits=3))
    for k, v in metrics.items():
        print(f"{k:12}: {v:.3f}")

    # ---- save metrics ----
    # NaN is not valid JSON, so an undefined metric is written as null.
    with open(args.out_dir / "test_metrics.json", "w", encoding="utf-8") as fh:
        json.dump({k: (None if np.isnan(v) else float(v)) for k, v in metrics.items()},
                  fh, indent=2)
    print(f"\n✓ saved model + metrics to {args.out_dir}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
