import argparse
import logging
from .utils.io import read_yaml, _deep_merge
from .runners.run_ablation import run_ablation
from .runners.run_poison import run_poison_experiments
from .runners.run_kfold_cv_compare import run_kfold_cv_compare
from .utils.logging import setup_logging
from .utils.build import build_from_cfg

# Keys whose joint presence means the config already carries built objects,
# so build_from_cfg can be skipped.
_PREBUILT_KEYS = ("model", "dataset", "loss_fn")


def _experiment_entries(cfg):
    """Return the list of experiment override entries described by *cfg*.

    A non-empty ``matrix`` value yields its own entries.  A missing, ``None``
    or empty ``matrix`` yields a single empty override, so the base
    configuration runs exactly once.
    """
    return list(cfg.get("matrix") or [{}])


def _ensure_built(cfg, device, *, with_ctors=False):
    """Return *cfg* with model/dataset/loss objects attached, building if needed.

    If *cfg* already contains ``model``, ``dataset`` and ``loss_fn``, it is
    returned unchanged.  Otherwise ``build_from_cfg`` is called and a new dict
    is returned.  The input dict is not mutated.  The new dict has ``model``,
    ``dataset``, ``loss_fn`` and ``labels`` set from the build result.

    Args:
        cfg: Merged experiment configuration.
        device: Device value forwarded to ``build_from_cfg`` (may be ``None``).
        with_ctors: When true (ablation mode), also attach ``model_ctor``,
            ``dataset_ctor`` and ``loss_fn_ctor``.  In that case ``model`` is a
            fresh ``model_ctor()`` instance rather than the instance returned
            by ``build_from_cfg``.

    Returns:
        The original *cfg*, or a shallow copy with the built objects added.
    """
    if all(key in cfg for key in _PREBUILT_KEYS):
        return cfg
    (model, model_ctor, dataset, dataset_ctor,
     loss_fn, loss_fn_ctor, labels) = build_from_cfg(cfg, device=device)
    if with_ctors:
        # Ablation also receives the constructors.  Presumably this lets the
        # runner re-instantiate components per run.
        # NOTE(review): the `model` built above is discarded in favour of a
        # fresh model_ctor() instance, as before -- confirm this is intended.
        return {
            **cfg,
            "model": model_ctor(),
            "model_ctor": model_ctor,
            "dataset": dataset,
            "dataset_ctor": dataset_ctor,
            "loss_fn": loss_fn,
            "loss_fn_ctor": loss_fn_ctor,
            "labels": labels,
        }
    return {**cfg, "model": model, "dataset": dataset, "loss_fn": loss_fn, "labels": labels}


def main():
    """Command-line entry point: run ablation, poisoning or k-fold experiments.

    Reads the YAML file given by ``--config``.  Every top-level key except
    ``matrix`` forms the base configuration.  Each ``matrix`` entry is a
    mapping of overrides, optionally with a ``name``.  It is deep-merged over
    the base and executed by the runner selected with ``--mode``.  Without a
    (non-empty) ``matrix``, the base configuration runs once as ``exp_0``.

    Returns ``None`` on purpose.  That keeps the function safe to wrap as
    ``sys.exit(main())``, where a non-integer return value would be treated
    as a failure.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to the YAML experiment config.")
    ap.add_argument("--mode", choices=["ablation", "poison", "kfold"], required=True,
                    help="Which experiment runner to use.")
    ap.add_argument("--logdir", default=None, help="Log directory passed to setup_logging.")
    ap.add_argument("--device", default=None,
                    help="Device forwarded to build_from_cfg and the poison/kfold runners.")
    args = ap.parse_args()

    cfg = read_yaml(args.config)
    logger = setup_logging(log_dir=args.logdir, level="DEBUG")

    # Everything except the matrix is shared by every experiment.
    base_cfg = {k: v for k, v in cfg.items() if k != "matrix"}
    experiments = _experiment_entries(cfg)
    total = len(experiments)
    results = {}
    for i, exp in enumerate(experiments):
        if isinstance(exp, dict):
            exp_name = exp.get("name", f"exp_{i}")
            overrides = exp
        else:
            logger.warning("Matrix entry %d is not a mapping (%r); running the base config", i, exp)
            exp_name = f"exp_{i}"
            overrides = {}
        if exp_name in results:
            logger.warning("Duplicate experiment name %r; its earlier result will be overwritten", exp_name)

        merged = _deep_merge(base_cfg, overrides)
        logger.info("Running experiment %d/%d: %s", i + 1, total, exp_name)
        if args.mode == "ablation":
            # The ablation runner takes no device argument.
            out = run_ablation(_ensure_built(merged, args.device, with_ctors=True))
        else:
            runner = run_poison_experiments if args.mode == "poison" else run_kfold_cv_compare
            out = runner(_ensure_built(merged, args.device), device=args.device)
        results[exp_name] = out

    # Summarise per experiment.  Results without .keys() are reported by
    # type name instead of aborting the whole summary.
    keys_by_name = {
        name: list(res.keys()) if hasattr(res, "keys") else type(res).__name__
        for name, res in results.items()
    }
    logger.info("Completed %d experiments. Result keys by name: %s", len(results), keys_by_name)


# The imports above are package-relative, so run this module with
# ``python -m <package>.<module>``.  Executing the file directly fails at import.
if __name__ == "__main__":
    main()
