#!/usr/bin/env python3
# make_ckpt_yaml.py

import argparse
import os
import importlib
from copy import deepcopy
from dataclasses import is_dataclass, asdict
from pathlib import Path
from typing import Any, List, Tuple

from config import ColbertTrainingArguments, load_config

import yaml


def expand_preconfig_over_checkpoints(preconfig, checkpoints) -> List[Tuple[str, object]]:
    """Build one configuration per checkpoint from a shared base config.

    Every returned config is a deep copy of ``preconfig`` in which
    ``model_args["model_name_or_path"]`` points at the checkpoint, and
    ``tr_args["output_dir"]`` / ``tr_args["run_name"]`` are made unique by
    appending the checkpoint's final path component. ``preconfig`` itself is
    never modified.

    Args:
        preconfig: Nested mapping with a ``"tr_args"`` section holding
            ``"output_dir"`` (and optionally ``"run_name"``) and a
            ``"model_args"`` section.
        checkpoints: Iterable of checkpoint locations, given either as
            ``Path`` objects or as plain path strings.

    Returns:
        A list of ``(checkpoint_name, config)`` pairs in input order.

    Raises:
        KeyError: If a required section or ``tr_args["output_dir"]`` is
            missing from ``preconfig``.
    """
    tr_args = preconfig["tr_args"]
    base_output_dir = Path(tr_args["output_dir"])
    # ``.get(key, default)`` only covers a missing key; an explicit null or
    # empty run_name would otherwise yield names like "None-<ckpt>".
    base_run_name = tr_args.get("run_name") or "run"

    expanded = []
    for ckpt in checkpoints:
        # Accept plain strings as well as Path objects.
        ckpt = Path(ckpt)
        cfg = deepcopy(preconfig)

        run_name = f"{base_run_name}-{ckpt.name}"

        # 1) point the model to the checkpoint
        cfg["model_args"]["model_name_or_path"] = str(ckpt)

        # 2) unique output_dir
        cfg["tr_args"]["output_dir"] = str(base_output_dir / run_name)

        # 3) unique run_name
        cfg["tr_args"]["run_name"] = run_name

        expanded.append((ckpt.name, cfg))

    return expanded


def main():
    """Command-line entry point: write one YAML config per checkpoint.

    Loads the base configuration with ``load_config(..., return_dict=True)``,
    collects the sub-directories of the checkpoint directory that match
    ``--pattern``, expands the base configuration over them, and writes each
    resulting config to ``<out-dir>/<run_name>.yaml``.

    If no checkpoint directory matches, a warning is printed to stderr and
    no files are written; the exit status is unchanged.
    """
    import sys  # local import: only needed for the stderr warning below

    ap = argparse.ArgumentParser()
    ap.add_argument("preconfig", type=str,
                    help="Path to the preconfiguration YAML file.")
    ap.add_argument("--ckpt-dir", default=None, type=Path,
                    help="Directory with checkpoints.")
    ap.add_argument("--pattern", default="opt_step-*__merged",
                    help="Glob pattern for checkpoint dirs.")
    ap.add_argument("--out-dir", required=True, type=Path,
                    help="Where to write the YAML configs.")
    args = ap.parse_args()

    preconfig = load_config(args.preconfig, return_dict=True)
    # Without --ckpt-dir, the base config's model path is used as the
    # directory that is searched for checkpoints.
    ckpt_dir = args.ckpt_dir or Path(preconfig["model_args"]["model_name_or_path"])
    checkpoints = sorted(p for p in ckpt_dir.glob(args.pattern) if p.is_dir())
    if not checkpoints:
        # A wrong directory or pattern would otherwise look like a successful run.
        print(
            f"Warning: no checkpoint directories matching {args.pattern!r} "
            f"found in {ckpt_dir}",
            file=sys.stderr,
        )
    expanded = expand_preconfig_over_checkpoints(preconfig, checkpoints)
    args.out_dir.mkdir(parents=True, exist_ok=True)

    for ckpt_name, cfg in expanded:
        run_name = cfg["tr_args"].get("run_name", ckpt_name)
        out_path = args.out_dir / f"{run_name}.yaml"
        # Explicit encoding keeps the output independent of the platform locale.
        with out_path.open("w", encoding="utf-8") as f:
            yaml.dump(cfg, f, sort_keys=False)
        print(f"✅ Wrote {out_path}")

    print(f"Done. Generated {len(expanded)} YAML configs.")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()