# -*- coding: utf-8 -*-
"""
ICLR unified runner.
Runs Mode 1 (once) and Modes 2-4 (per model) over selected models,
and aggregates outputs into CSVs and a single Excel workbook.
"""

import os
import argparse
from pathlib import Path

from models import ExperimentConfig, get_models
from data_prep import (
    ensure_dirs as ensure_data_dirs,
    create_or_load_mode1,
    create_or_load_mode2,
    create_or_load_mode3,
    create_or_load_mode4,
)
from mode1 import run_mode1
from mode2_dpo import run_mode2_for_model, aggregate_mode2
from mode3 import run_mode3_for_model, aggregate_mode3
from mode4 import run_mode4_for_model, aggregate_mode4
from utils import ensure_dirs, to_excel


def main(args=None):
    """Run the full ICLR pipeline and return a process exit code.

    Prepares the output and data directories, runs Mode 1 once, runs
    Modes 2-4 for every model returned by ``get_models``, aggregates the
    per-mode results, and finally attempts to bundle the aggregated results
    into a single Excel workbook.

    Args:
        args: Optional namespace-like object (typically produced by
            ``argparse``). The attributes ``fast``, ``out``, ``models``,
            ``real_dpo`` and ``use_semantic`` are read when present; a
            missing attribute is treated as "not set". ``None`` runs with
            the ``ExperimentConfig`` defaults.

    Returns:
        int: Always ``0``. A failed Excel export is only reported as a
        warning and does not change the return value.
    """
    cfg = ExperimentConfig()

    # Existing switches. Read with getattr (like the feature flags below) so
    # that a partially populated namespace passed by a programmatic caller
    # does not raise AttributeError.
    if args and getattr(args, "fast", False):
        cfg.fast_mode = True
    if args and getattr(args, "out", None):
        cfg.out_root = args.out
    if args and getattr(args, "models", None):
        cfg.model_ids = args.models

    # Feature flags
    real_dpo = bool(getattr(args, "real_dpo", False))            # True -> run real DPO (longer)
    use_semantic = bool(getattr(args, "use_semantic", False))    # True -> enable semantic matching in Mode 3

    # ---- output dirs
    out_root = Path(cfg.out_root)
    mode1_dir = out_root / "mode1_sanitization"
    mode2_dir = out_root / "mode2_alignment"
    mode3_dir = out_root / "mode3_guards"
    mode4_dir = out_root / "mode4_redteam"
    figures_dir = out_root / "figures"
    ensure_dirs(mode1_dir, mode2_dir, mode3_dir, mode4_dir, figures_dir)

    # ---- prepare data (fast mode uses the smaller row/pair counts)
    data_dirs = ensure_data_dirs(base=str(out_root / "data"))
    corpus_csv = create_or_load_mode1(data_dirs, n_rows=800 if not cfg.fast_mode else 200)
    dpo_pairs  = create_or_load_mode2(data_dirs, n_pairs=120 if not cfg.fast_mode else 60)
    guard_csv  = create_or_load_mode3(data_dirs)
    redteam_csv= create_or_load_mode4(data_dirs)

    # ---- Mode 1 (global once)
    run_mode1(corpus_csv, str(mode1_dir))

    # ---- Modes 2-4 per model
    mlist = get_models(cfg)
    mode2_rows = []
    mode3_all  = []
    mode4_rows = []

    for i, m in enumerate(mlist):
        # Mode 2: DPO+LoRA (real_dpo controls whether the fuller training path
        # is taken; on native Windows use_quantization=False is recommended).
        # The seed is offset by the model index so each model gets its own seed.
        m2 = run_mode2_for_model(
            m, dpo_pairs_path=dpo_pairs,
            out_dir=str(mode2_dir / m.model_id),
            use_quantization=False,
            fast_mode=(cfg.fast_mode and not real_dpo),  # real_dpo=True -> fast_mode=False
            seed=41 + i
        )
        mode2_rows.append(m2)

        # Mode 3: guard evaluation (semantic matching can optionally be enabled)
        m3_df = run_mode3_for_model(
            model_id=m.model_id, guard_levels=cfg.guard_levels,
            prompts_csv=guard_csv, out_dir=str(mode3_dir / m.model_id),
            use_semantic=use_semantic
        )
        mode3_all.append(m3_df)

        # Mode 4: red team (simulation/statistics; n_trials can be raised in
        # mode4.py to obtain more stable confidence intervals)
        m4 = run_mode4_for_model(
            model_id=m.model_id, redteam_csv=redteam_csv,
            out_dir=str(mode4_dir / m.model_id), seed=123 + i
        )
        mode4_rows.append(m4)

    # ---- aggregate per-mode CSVs
    mode2_csv = aggregate_mode2(mode2_rows, str(mode2_dir))
    mode3_csv = aggregate_mode3(mode3_all, str(mode3_dir))
    mode4_csv = aggregate_mode4(mode4_rows, str(mode4_dir))

    # ---- Excel workbook (Mode 2-4), and optionally include multi-dataset Mode 1 combined summary
    # Export is best-effort: any failure is downgraded to a warning because the
    # aggregated per-mode outputs above remain available.
    try:
        excel_path = out_root / "iclr_mode2_3_4_summary.xlsx"

        sheets = {
            "Mode2_alignment": mode2_csv,
            "Mode3_guard": mode3_csv,
            "Mode4_redteam": mode4_csv,
        }
        # If the "multi-dataset Mode 1" script was run beforehand and produced
        # a combined summary table, add it to the workbook automatically.
        combined_mode1 = out_root / "results" / "mode1_combined_summary.csv"
        if combined_mode1.exists():
            sheets["Mode1_multi_dataset"] = str(combined_mode1)

        to_excel(sheets, str(excel_path))
        print(f"Excel workbook: {excel_path}")
    except Exception as ex:
        print(f"[WARN] Excel export failed: {ex}. CSVs are available instead.")

    # ---- summary
    print("\n=== ICLR Unified Runner Complete ===")
    print(f"Mode 1 summary: {mode1_dir/'mode1_sanitization_results.csv'}")
    print(f"Mode 2 summary: {mode2_csv}")
    print(f"Mode 3 summary: {mode3_csv}")
    print(f"Mode 4 summary: {mode4_csv}")
    return 0


if __name__ == "__main__":
    # Command-line entry point: declare the options as a table, register them
    # in order, parse the command line, and exit with main()'s return value.
    cli = argparse.ArgumentParser()

    option_specs = (
        # Existing switches
        (("--fast",), {
            "action": "store_true",
            "help": "Use smaller demo datasets / fewer steps.",
        }),
        (("--out",), {
            "type": str,
            "default": None,
            "help": "Output root directory (default: ./iclr_results)",
        }),
        (("--models",), {
            "nargs": "+",
            "help": "Override model list (IDs defined in models.py)",
        }),
        # Feature switches
        (("--real-dpo",), {
            "action": "store_true",
            "help": "Run real DPO training (longer). Default uses fast/demo path.",
        }),
        (("--use-semantic",), {
            "action": "store_true",
            "help": "Enable semantic similarity in Mode 3 guard evaluation.",
        }),
    )
    for flag_names, flag_options in option_specs:
        cli.add_argument(*flag_names, **flag_options)

    args = cli.parse_args()
    raise SystemExit(main(args))
