#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
em_loop.py (EM orchestrator, robust save-path handling)

- Alternates E (pseudo-label bins) and M (train head/adapter) steps
- Resume support: --resume_from
- Adapter follow policy: --follow_adapter {auto,always,never}
- Logs each iteration to em_log.jsonl
- Robustly locates saved artifacts under either <out_dir>/ or <out_dir>/<dataset>/<exp_name>/

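Usage example (shown wrapped for readability; pass it as a single command line):
    python em_loop.py --base_model <model_dir> --head_init <confidence_head.pt>
        --target_pkl <hd_data_em.pkl> --dataset <dataset> --model_subdir llama2-7b
        --exp_root <exp_root> --exp_name <exp_name> --num_iters 3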
"""

import os
import sys
import json
import shlex
import argparse
import subprocess
from pathlib import Path
from datetime import datetime

def _abspath(p: str) -> str:
    """Return *p* as an absolute path string resolved against the CWD.

    Falsy inputs (``""`` / ``None``) are handed back untouched so that
    optional path arguments can flow through without special-casing.
    """
    if not p:
        return p
    resolved = Path(p).absolute()
    return str(resolved)

def run(cmd, env=None):
    """Echo *cmd* and execute it as a child process, aborting the EM loop on failure.

    Args:
        cmd: Argument vector (``str`` or path-like items); executed without a shell.
        env: Optional environment mapping for the child; ``None`` inherits ours.

    Raises:
        SystemExit: if the child exits with a non-zero return code.
    """
    argv = list(cmd)
    # Quote each argument so the echoed line is unambiguous and can be
    # copy-pasted into a shell (e.g. the E-step prefix holds spaces and "\n").
    # str() is for display only; subprocess receives the original items.
    shown = " ".join(shlex.quote(str(a)) for a in argv)
    print("\n>>>", shown, flush=True)
    result = subprocess.run(argv, env=env)
    if result.returncode != 0:
        raise SystemExit(f"Command failed with code {result.returncode}: {shown}")

def file_exists(path: Path) -> bool:
    """Report whether *path* exists, treating any probing error as "missing".

    This is a deliberate best-effort check used only to emit warnings, so
    permission or encoding problems simply yield ``False``.
    """
    found = False
    try:
        found = path.exists()
    except Exception:
        found = False
    return found

# File names train.py is expected to leave in its output directory.
_HEAD_FILE = "confidence_head.pt"
_ADAPTER_CFG = "adapter_config.json"
_ADAPTER_WEIGHTS = "adapter_model.safetensors"


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the EM orchestrator."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--num_iters", type=int, default=3)
    ap.add_argument("--resume_from", type=int, default=0, help="从该迭代编号开始(>=0)。>0 时自动衔接上一轮 head/adapter")

    # Shared model / paths.
    ap.add_argument("--base_model", required=True)
    ap.add_argument("--adapter_path", default="", help="LoRA 目录或 adapter_config.json 文件路径")
    ap.add_argument("--head_init", required=True, help="初始 confidence_head.pt")
    ap.add_argument("--target_pkl", required=True, help="目标域 hd_data_em.pkl（E-step 输入）")
    ap.add_argument("--dataset", required=True, help="train.py 的 --dataset（绝对路径或数据集名）")
    ap.add_argument("--model_subdir", required=True, help="如 'llama2-7b'，用于拼出 processed pkl 的落盘路径")
    ap.add_argument("--exp_root", required=True, help="每一轮训练输出目录的根路径")
    ap.add_argument("--exp_name", required=True, help="train.py 的 --exp_name，用于日志区分")
    ap.add_argument("--fixed_out_dir", type=str, default="",
                    help="Reuse this M-step output directory for every iteration "
                         "(default: <exp_root>/iter_<it>)")

    # Whether the E-step should follow the (LoRA) adapter produced by the M-step.
    ap.add_argument("--follow_adapter", type=str, default="auto",
                    choices=["auto", "always", "never"],
                    help="M-step 若产生 LoRA 产物，E-step 的 --adapter_dir 是否切换到最新 out_dir")

    # E-step controls (passed to the train_calib_em* script).
    ap.add_argument("--temp", type=float, default=0.7)
    ap.add_argument("--bias", type=float, default=0.6)
    ap.add_argument("--shape_mode", type=str, default="rank_push", choices=["none", "hist_eq", "rank_push"])
    ap.add_argument("--rank_alpha", type=float, default=3.5)
    ap.add_argument("--mix_lambda", type=float, default=0.9)
    ap.add_argument("--correctness_field", type=str, default="correctness")

    # M-step (train.py) controls.
    ap.add_argument("--num_train_epochs", type=int, default=3)
    ap.add_argument("--save_strategy", type=str, default="no")
    ap.add_argument("--logging_steps", type=int, default=1)
    ap.add_argument("--weight", type=float, default=1.0)
    ap.add_argument("--report_to", type=str, default="wandb")
    ap.add_argument("--distance", type=str, default="prob_init")
    ap.add_argument("--loss_type", type=str, default="distance_sft")
    ap.add_argument("--add_soft_prompts", type=str, default="false")  # passed through
    ap.add_argument("--debug_mode", type=str, default="false")        # passed through

    # Extra train.py arguments (may be empty), e.g.
    # --train_extra "--per_device_train_batch_size 4 --learning_rate 1e-4"
    ap.add_argument("--train_extra", type=str, default="")

    # Script paths (relative or absolute).
    ap.add_argument("--train_calib_script", type=str, default="train_calib_em.py")
    ap.add_argument("--train_script", type=str, default="train.py")
    return ap


def _artifact_dirs(run_dir: Path, dataset: str, exp_name: str):
    """Return the candidate save locations under *run_dir*: flat first, then nested."""
    # NOTE(review): if ``dataset`` is an absolute path, ``run_dir / dataset``
    # discards ``run_dir`` (pathlib semantics) — confirm this matches where
    # train.py actually saves in that case.
    return [run_dir, run_dir / dataset / exp_name]


def _find_head(run_dir: Path, dataset: str, exp_name: str):
    """Return the first existing confidence head under *run_dir* (flat, then nested), else None."""
    for cand_dir in _artifact_dirs(run_dir, dataset, exp_name):
        cand = cand_dir / _HEAD_FILE
        if cand.exists():
            return cand
    return None


def _pick_adapter_dir(run_dir: Path, dataset: str, exp_name: str, policy: str):
    """Return the adapter directory to switch to, or None to keep the current one.

    The candidate is the flat dir if it holds adapter_config.json, else the
    nested dir. ``auto`` follows only when that candidate holds both the config
    and the safetensors weights; ``always`` follows unconditionally; ``never``
    never follows.
    """
    if policy == "never":
        return None
    flat, nested = _artifact_dirs(run_dir, dataset, exp_name)
    cand = flat if (flat / _ADAPTER_CFG).exists() else nested
    complete = (cand / _ADAPTER_CFG).exists() and (cand / _ADAPTER_WEIGHTS).exists()
    if policy == "always" or complete:
        return str(cand)
    return None


def _e_step_cmd(args, script: str, head_path: str, adapter_path: str, out_pkl: str):
    """Build the E-step command that writes the pseudo-labelled (binned) pkl to *out_pkl*."""
    cmd = [sys.executable, script, "--base_model", args.base_model]
    # Only pass --adapter_dir when an adapter is set: an empty value would leave
    # a dangling flag that swallows the next option.
    if adapter_path:
        cmd += ["--adapter_dir", adapter_path]
    cmd += [
        "--head_path", head_path,
        "--target_pkl", args.target_pkl,
        "--out_pkl", out_pkl,
        "--prefix", "Q: {question}\nA:",
        # No --suffix: rely on the script's default.
        "--layer_index", "-1",
        "--temp", str(args.temp),
        "--bias", str(args.bias),
        "--shape_mode", args.shape_mode,
        "--rank_alpha", str(args.rank_alpha),
        "--mix_lambda", str(args.mix_lambda),
        "--correctness_field", args.correctness_field,
    ]
    return cmd


def _m_step_cmd(args, script: str, processed_name: str, out_dir: Path):
    """Build the M-step (train.py) command training on *processed_name* into *out_dir*."""
    cmd = [
        sys.executable, script,
        "--model_name_or_path", args.base_model,
        "--distance", args.distance,
        "--loss_type", args.loss_type,
        "--add_soft_prompts", args.add_soft_prompts,
        "--debug_mode", args.debug_mode,
        "--num_train_epochs", str(args.num_train_epochs),
        "--save_strategy", args.save_strategy,
        "--logging_steps", str(args.logging_steps),
        "--weight", str(args.weight),
        "--report_to", args.report_to,
        "--dataset", args.dataset,
        "--processed_data", processed_name,
        "--exp_name", args.exp_name,
        "--output_dir", str(out_dir),
    ]
    extra = args.train_extra.strip()
    if extra:
        cmd += shlex.split(extra)
    return cmd


def main():
    """Run the EM loop, alternating E-steps (pseudo-label bins) and M-steps (training).

    For each iteration ``it``:
      * E-step: run ``--train_calib_script`` with the current head/adapter to
        write ``processed_binmean_em_iter{it}.pkl``.
      * M-step: run ``--train_script`` on that pkl with output in
        ``<exp_root>/iter_{it}`` (or ``--fixed_out_dir`` when given).
      * Locate the new confidence head (flat or ``<out_dir>/<dataset>/<exp_name>/``)
        for the next iteration and, per ``--follow_adapter``, switch the
        E-step adapter directory.
      * Append one JSON line of metadata to ``<exp_root>/em_log.jsonl``.

    Raises:
        SystemExit: if a sub-command fails or an M-step leaves no head behind.
    """
    args = _build_parser().parse_args()

    exp_root = Path(args.exp_root).absolute()
    exp_root.mkdir(parents=True, exist_ok=True)
    log_path = exp_root / "em_log.jsonl"

    # Scripts are run as `python <abs path>`, so a missing file will make that step fail.
    train_calib_script = _abspath(args.train_calib_script)
    train_script = _abspath(args.train_script)
    if not file_exists(Path(train_calib_script)):
        print(f"[warn] train_calib_script not found at {train_calib_script}; the E-step will fail.", flush=True)
    if not file_exists(Path(train_script)):
        print(f"[warn] train_script not found at {train_script}; the M-step will fail.", flush=True)

    fixed_out_dir = Path(args.fixed_out_dir).absolute() if args.fixed_out_dir else None

    def out_dir_for(i: int) -> Path:
        """M-step output directory for iteration *i*."""
        return fixed_out_dir if fixed_out_dir is not None else exp_root / f"iter_{i}"

    # Initial head/adapter.
    head_path = _abspath(args.head_init)
    adapter_path = _abspath(args.adapter_path) if args.adapter_path else ""

    # Resume: when resume_from > 0, pick up the previous iteration's artifacts.
    start_it = max(0, int(args.resume_from))
    if start_it > 0:
        prev = out_dir_for(start_it - 1)
        prev_head = _find_head(prev, args.dataset, args.exp_name)
        if prev_head is not None:
            head_path = str(prev_head)
            print(f"[resume] use previous head: {head_path}")
        else:
            print(f"[resume][warn] no {_HEAD_FILE} under {prev}; keeping head {head_path}", flush=True)
        prev_adapter = _pick_adapter_dir(prev, args.dataset, args.exp_name, args.follow_adapter)
        if prev_adapter is not None:
            adapter_path = prev_adapter
            print(f"[resume] use previous adapter_dir: {adapter_path}")

    # Directory the E-step writes processed pkls into.
    dataset_root = Path(args.dataset)
    if dataset_root.exists():
        processed_dir = dataset_root / args.model_subdir
    else:
        # --dataset is a name, not a path: train.py resolves it itself, but the
        # E-step still needs a writable location.
        # NOTE(review): train.py only receives the file *name* via --processed_data;
        # confirm it can locate pkls written to this cache directory.
        processed_dir = exp_root / "processed_cache" / args.model_subdir
    processed_dir.mkdir(parents=True, exist_ok=True)

    for it in range(start_it, args.num_iters):
        print(f"\n========== EM Iteration {it} ==========")

        # ---------- E-step: produce the binned pkl ----------
        processed_name = f"processed_binmean_em_iter{it}.pkl"
        processed_abs = str(processed_dir / processed_name)
        run(_e_step_cmd(args, train_calib_script, head_path, adapter_path, processed_abs))

        # ---------- M-step: train & save a new head ----------
        out_dir = out_dir_for(it)
        out_dir.mkdir(parents=True, exist_ok=True)
        run(_m_step_cmd(args, train_script, processed_name, out_dir))

        # The next iteration uses the new head (flat or nested layout).
        new_head = _find_head(out_dir, args.dataset, args.exp_name)
        if new_head is None:
            raise SystemExit(
                f"[Error] {_HEAD_FILE} not found. Make sure train.py saved it in "
                f"{out_dir}/ or {out_dir}/{args.dataset}/{args.exp_name}/"
            )
        head_path = str(new_head)

        # Optionally switch the E-step adapter dir (same flat/nested handling).
        new_adapter = _pick_adapter_dir(out_dir, args.dataset, args.exp_name, args.follow_adapter)
        switched_adapter = new_adapter is not None
        if switched_adapter:
            adapter_path = new_adapter
            print(f"[EM] switch adapter_dir -> {adapter_path}")

        # Record this iteration's metadata.
        meta = {
            "iter": it,
            "time": datetime.now().isoformat(timespec="seconds"),
            "processed_pkl": processed_abs,
            "out_dir": str(out_dir),
            "head_path": head_path,
            "adapter_path": adapter_path,
            "switched_adapter": switched_adapter,
            "e_params": {
                "temp": args.temp,
                "bias": args.bias,
                "shape_mode": args.shape_mode,
                "rank_alpha": args.rank_alpha,
                "mix_lambda": args.mix_lambda,
            },
            "m_params": {
                "num_train_epochs": args.num_train_epochs,
                "save_strategy": args.save_strategy,
                "logging_steps": args.logging_steps,
                "weight": args.weight,
                "distance": args.distance,
                "loss_type": args.loss_type,
            },
            "train_extra": args.train_extra,
        }
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(meta, ensure_ascii=False) + "\n")

    print("\n[EM] All iterations finished successfully.")
    print("[EM] Latest head:", head_path)
    print("[EM] Latest adapter_dir:", adapter_path or "(unchanged / none)")

if __name__ == "__main__":
    # Script entry point; importing this module must not start the EM loop.
    # main() returns None, so the process still exits with status 0.
    sys.exit(main())
