from pathlib import Path
import argparse
import hashlib
import json
import os
import platform
import random
import shutil
import time

import numpy as np
import pytorch_lightning as pl
import torch as T
import torch.nn as nn
from pytorch_lightning.loggers import TensorBoardLogger

from NCMCounterfactuals.src.ds.causal_graph import CausalGraph
from NCMCounterfactuals.src.pipeline.relational_id_pipeline import RelationalIDPipeline
from NCMCounterfactuals.src.scm.ctm import RelationalCTM
from NCMCounterfactuals.src.scm.ncm.mle_ncm import MLE_NCM
from NCMCounterfactuals.src.scm.nn.gumbel_mlp import GumbelMLP
from NCMCounterfactuals.src.scm.ncm.role_aggregator import build_role_modules
from simple_traffic_id_experiment import (
    acquire_lock,
    make_binary_query,
    read_cg_with_metadata,
    release_lock,
    write_json_atomic,
)


def template_of(node):
    """Return the cell template name for a node, or ``None``.

    Nodes named ``C<...>_X`` map to ``"C_X"`` and ``C<...>_Y`` to ``"C_Y"``;
    every other name has no template.
    """
    if not node.startswith("C"):
        return None
    for suffix in ("_X", "_Y"):
        if node.endswith(suffix):
            return "C" + suffix
    return None




def role_template_of(node):
    """Return the role template key ``R_<a>_<b>`` for a role node, or ``None``.

    A role node is an underscore-separated name with at least four parts whose
    first part is ``"R"``; the key is built from its third and fourth parts.
    """
    pieces = node.split("_")
    is_role_node = len(pieces) >= 4 and pieces[0] == "R"
    return f"R_{pieces[2]}_{pieces[3]}" if is_role_node else None


def v_size_for_cg(cg, role_aggregators, count_bits=3):
    """Return the per-node value width for every node of ``cg``.

    Every node has width 1 except role nodes whose role template is mapped to
    the ``"count"`` aggregator (case-insensitive) in ``role_aggregators``;
    those use ``count_bits``.
    """

    def uses_count_encoding(node):
        role_key = role_template_of(node)
        if role_key is None:
            return False
        agg_name = role_aggregators.get(role_key)
        return agg_name is not None and str(agg_name).lower() == "count"

    return {node: (count_bits if uses_count_encoding(node) else 1) for node in cg}


def build_reps(cg):
    """Pick one representative node per cell template present in ``cg``.

    ``C0_X`` / ``C0_Y`` are preferred when they exist; otherwise the first
    node (in ``cg`` iteration order) carrying that template is used.
    """
    reps = {
        tmpl: node
        for tmpl, node in (("C_X", "C0_X"), ("C_Y", "C0_Y"))
        if node in cg
    }
    for node in cg:
        tmpl = template_of(node)
        if tmpl is not None:
            # Keep an already-chosen (preferred or earlier) representative.
            reps.setdefault(tmpl, node)
    return reps


def build_rctm(cg, role_aggregators, seed, template_funcs=None):
    """Construct a ``RelationalCTM`` for ``cg`` with this experiment's settings.

    Representatives come from ``build_reps`` and value widths from
    ``v_size_for_cg``; ``regions`` and ``c2_scale`` are fixed at 5 and 1.0.
    ``template_funcs`` (optional) is forwarded so several CTMs can share
    template mechanisms.
    """
    representatives = build_reps(cg)
    sizes = v_size_for_cg(cg, role_aggregators)
    ctm_kwargs = dict(
        template_of=template_of,
        reps=representatives,
        v_size=sizes,
        regions=5,
        c2_scale=1.0,
        role_aggregators=role_aggregators,
        seed=seed,
        template_funcs=template_funcs,
    )
    return RelationalCTM(cg, **ctm_kwargs)


def pad_template_funcs(template_funcs):
    """Wrap template functions so missing exogenous inputs are zero-filled.

    Args:
        template_funcs: mapping template name -> dict with a ``"func"`` callable
            ``func(v_raw, u_raw)`` and optional ``"u_keys"``, ``"parent_keys"``
            and ``"agg_templates"`` entries.

    Returns:
        A new mapping with the same template names. Each ``"func"`` is a wrapper
        that, when some of the template's ``u_keys`` are absent from ``u_raw``,
        calls the original function with a *copy* of ``u_raw`` in which the
        missing keys are ``zeros((n, 1))``. Here ``n`` and the device are taken
        from the first ``u_raw`` value, else the first ``v_raw`` value, else a
        ``(1, 1)`` CPU placeholder. The caller's ``u_raw`` dict is never modified.
    """
    padded = {}
    for tmpl_name, tmpl in template_funcs.items():
        func = tmpl["func"]
        u_keys = list(tmpl.get("u_keys", []))

        # Default arguments bind the current func/u_keys (avoids late binding).
        def wrapped(v_raw, u_raw, _func=func, _u_keys=u_keys):
            missing = [k for k in _u_keys if k not in u_raw]
            if missing:
                if u_raw:
                    sample = next(iter(u_raw.values()))
                elif v_raw:
                    sample = next(iter(v_raw.values()))
                else:
                    sample = T.zeros((1, 1))
                n = sample.shape[0]
                device = sample.device
                # Pad a copy: writing into the caller's dict would leak zero
                # noise into state the caller (or later templates) may reuse.
                u_raw = dict(u_raw)
                for k in missing:
                    u_raw[k] = T.zeros((n, 1), device=device)
            return _func(v_raw, u_raw)

        padded[tmpl_name] = {
            "func": wrapped,
            "parent_keys": tmpl.get("parent_keys", []),
            "u_keys": u_keys,
            "agg_templates": tmpl.get("agg_templates", []),
        }
    return padded


class SharedModuleWrapperMLE(nn.Module):
    """Adapt a shared template module to one concrete node's input names.

    ``pa_key_map`` / ``u_key_map`` map this node's parent / exogenous names to
    the names the shared module was built with. Inputs not listed in a map are
    dropped. Mapped tensors are moved to the device of the shared module's
    first parameter.
    """

    def __init__(self, shared, pa_key_map, u_key_map):
        super().__init__()
        self.shared = shared
        self.pa_key_map = pa_key_map
        self.u_key_map = u_key_map

    @staticmethod
    def _rename(values, key_map, device):
        # Re-key ``values`` from this node's names to the shared module's names.
        return {dst: values[src].to(device) for src, dst in key_map.items()}

    def forward(self, pa, u, v=None, n=None):
        target_device = next(self.shared.parameters()).device
        return self.shared(
            self._rename(pa, self.pa_key_map, target_device),
            self._rename(u, self.u_key_map, target_device),
            v=v,
            n=n,
        )


def _alignment_key(name):
    """Sort key that lines up corresponding inputs across graph instances.

    Names are grouped by cell template (``C_X``/``C_Y``) or role template
    (``R_X_Y``/``R_X_X``) first and by their own name second. So the ``C0_X``
    parent of ``C0_Y`` and the ``C2_X`` parent of ``C2_Y`` sort into the same
    slot. Names without a template fall back to their own name.
    """
    name_str = str(name)
    group = template_of(name_str) or role_template_of(name_str)
    return (name_str if group is None else group, name_str)


def _aligned_inputs(cg, node):
    """Return ``(parents, confounders)`` of ``node`` in a deterministic order.

    NOTE(review): ``cg.pa[node]`` / ``cg.v2c2[node]`` may be unordered sets
    (the previous code wrapped them in ``list(...)``). If so, their iteration
    order depends on string hashing and is unrelated between nodes, which made
    the positional cur->rep mapping arbitrary. Parents are therefore sorted by
    template so corresponding parents pair up. Confounders are sorted by
    ``str``, which only guarantees determinism, not a semantic pairing.
    """
    parents = sorted(cg.pa[node], key=_alignment_key)
    confounders = sorted(cg.v2c2[node], key=str)
    return parents, confounders


def _wrap_shared_modules(cg, modules, shared_templates_mod, reps, cg_ref):
    """Replace each templated node's module in ``modules`` with a shared wrapper.

    Each node whose template has a representative in ``reps`` gets a
    ``SharedModuleWrapperMLE`` around ``shared_templates_mod[template]``. Its
    parents/confounders are mapped onto those of the representative in
    ``cg_ref``. Raises ValueError if the input counts differ. Mutates and
    returns ``modules``.
    """
    for node in cg:
        template = template_of(node)
        if template not in reps:
            continue
        rep = reps[template]
        rep_pa, rep_u = _aligned_inputs(cg_ref, rep)
        cur_pa, cur_u = _aligned_inputs(cg, node)
        if len(rep_pa) != len(cur_pa) or len(rep_u) != len(cur_u):
            raise ValueError(
                f"Incompatible parents/u for template {template}: {node} "
                f"(pa {len(cur_pa)} vs {len(rep_pa)}, u {len(cur_u)} vs {len(rep_u)})"
            )
        modules[node] = SharedModuleWrapperMLE(
            shared_templates_mod[template],
            dict(zip(cur_pa, rep_pa)),
            dict(zip(cur_u, rep_u)),
        )
    return modules


def _shared_device(shared_templates_mod):
    """Return the device of the first parameter among the shared modules, or None."""
    for module in shared_templates_mod.values():
        for param in module.parameters():
            return param.device
    return None


def build_shared_role_ncms(source_cgs, target_cg, hyperparams, role_aggregators):
    """Build source/target MLE NCMs whose templated nodes share one GumbelMLP each.

    One ``GumbelMLP`` is created per cell template, sized from that template's
    representative node in ``target_cg`` (the reference graph). Every templated
    node in every source graph and in the target graph is wired to it. Role
    nodes use the modules from ``build_role_modules``.

    Returns:
        ``(source_ncms, target_ncm, shared_templates_mod, reps, cg_ref)``, where
        ``cg_ref`` is ``target_cg``.

    Raises:
        ValueError: if a node's input counts do not match its representative's.
    """
    reps = build_reps(target_cg)
    cg_ref = target_cg
    v_size_ref = v_size_for_cg(cg_ref, role_aggregators)
    u_size = hyperparams.get("u-size", 1)
    shared_templates_mod = {}
    for template, rep in reps.items():
        if rep not in cg_ref:
            raise ValueError(f"Representative {rep} not found in reference CG.")
        rep_pa, rep_u = _aligned_inputs(cg_ref, rep)
        shared_templates_mod[template] = GumbelMLP(
            {p: v_size_ref[p] for p in rep_pa},
            {u: u_size for u in rep_u},
            1,
            h_layers=hyperparams.get("h-layers", 2),
            h_size=hyperparams.get("h-size", 128),
        )

    def make_ncm(cg):
        modules = _wrap_shared_modules(
            cg,
            build_role_modules(cg, role_aggregators),
            shared_templates_mod,
            reps,
            cg_ref,
        )
        return MLE_NCM(
            cg,
            v_size=v_size_for_cg(cg, role_aggregators),
            default_u_size=u_size,
            f=modules,
            hyperparams=hyperparams,
        )

    source_ncms = [make_ncm(cg) for cg in source_cgs]
    target_ncm = make_ncm(target_cg)
    return source_ncms, target_ncm, shared_templates_mod, reps, cg_ref


def build_eval_ncm(target_cg, shared_templates_mod, hyperparams, role_aggregators, reps, cg_ref):
    """Build an MLE NCM on ``target_cg`` that reuses the shared template modules.

    The parent/confounder alignment is identical to the one used by
    ``build_shared_role_ncms``, so a shared module sees its inputs in the same
    slots at evaluation time as during training. The NCM is moved to the
    shared modules' device when any shared parameter exists.
    """
    modules = _wrap_shared_modules(
        target_cg,
        build_role_modules(target_cg, role_aggregators),
        shared_templates_mod,
        reps,
        cg_ref,
    )
    eval_ncm = MLE_NCM(
        target_cg,
        v_size=v_size_for_cg(target_cg, role_aggregators),
        default_u_size=hyperparams.get("u-size", 1),
        f=modules,
        hyperparams=hyperparams,
    )
    shared_device = _shared_device(shared_templates_mod)
    if shared_device is not None:
        eval_ncm.to(shared_device)
    return eval_ncm


class QueryDeltaLogger(pl.Callback):
    """Lightning callback that tracks target-query estimates during training.

    At train start (epoch 0) and every ``log_every_epochs`` epochs, it
    evaluates each configured query ``P(var=1 | do(do_vars=1))``. The NCM used
    is assembled from the (currently training) shared template modules. Two
    scalars per query are written to ``trainer.logger.experiment``:

    * ``delta/<target>/<var>_do_<do_vars>``: ``|ncm_value - true_value|``
    * ``query_value/<target>/<var>_do_<do_vars>``: the NCM estimate

    Ground-truth values come from ``target_ctms`` and are computed once in
    ``__init__``. Queries without a ground-truth value are skipped entirely.

    Args:
        target_cgs: mapping target name -> causal graph to evaluate on.
        target_ctms: mapping target name -> ground-truth model with ``compute_ctf``.
        queries_by_target: mapping target name -> list of ``(var, do_vars)``.
        hyperparams, role_aggregators, shared_templates_mod, reps, cg_ref:
            forwarded to ``build_eval_ncm``.
        log_every_epochs: logging period in epochs. ``None`` or ``<= 0``
            disables periodic logging; the epoch-0 log still happens.
        eval_n: sample count passed to ``compute_ctf``.
    """

    def __init__(
        self,
        target_cgs,
        target_ctms,
        queries_by_target,
        hyperparams,
        role_aggregators,
        shared_templates_mod,
        reps,
        cg_ref,
        log_every_epochs=10,
        eval_n=10000,
    ):
        super().__init__()
        self.target_cgs = target_cgs
        self.target_ctms = target_ctms
        self.queries_by_target = queries_by_target
        self.hyperparams = hyperparams
        self.role_aggregators = role_aggregators
        self.shared_templates_mod = shared_templates_mod
        self.reps = reps
        self.cg_ref = cg_ref
        self.log_every_epochs = log_every_epochs
        self.eval_n = eval_n
        self._true_cache = self._build_true_cache()
        # Eval NCMs wrap the shared modules by reference, so one per target can
        # be reused across epochs and always reflects the current weights.
        self._eval_ncms = {}

    def _build_true_cache(self):
        """Compute ground-truth query values keyed by ``(target, var, do_vars)``."""
        cache = {}
        for target_name, target_ctm in self.target_ctms.items():
            for var, do_vars in self.queries_by_target.get(target_name, []):
                do_vals = {k: 1 for k in do_vars}
                q = make_binary_query(var, do_vals)
                with T.no_grad():
                    true_val = target_ctm.compute_ctf(q, n=self.eval_n, evaluating=True)
                cache[(target_name, var, tuple(do_vars))] = float(true_val)
        return cache

    def _get_eval_ncm(self, target_name, target_cg):
        """Return the cached eval NCM for ``target_name``, building it on first use."""
        eval_ncm = self._eval_ncms.get(target_name)
        if eval_ncm is None:
            eval_ncm = build_eval_ncm(
                target_cg,
                shared_templates_mod=self.shared_templates_mod,
                hyperparams=self.hyperparams,
                role_aggregators=self.role_aggregators,
                reps=self.reps,
                cg_ref=self.cg_ref,
            )
            self._eval_ncms[target_name] = eval_ncm
        return eval_ncm

    def _log_epoch(self, trainer, pl_module, epoch):
        if trainer.logger is None:
            return
        for target_name, target_cg in self.target_cgs.items():
            eval_ncm = self._get_eval_ncm(target_name, target_cg)
            for var, do_vars in self.queries_by_target.get(target_name, []):
                true_val = self._true_cache.get((target_name, var, tuple(do_vars)))
                if true_val is None:
                    # No ground truth to compare against: skip the (costly)
                    # NCM query instead of computing and discarding it.
                    continue
                do_vals = {k: 1 for k in do_vars}
                q = make_binary_query(var, do_vals)
                with T.no_grad():
                    ncm_val = float(eval_ncm.compute_ctf(q, n=self.eval_n, evaluating=True))
                delta = abs(ncm_val - true_val)
                do_tag = "+".join(do_vars)
                tag = f"delta/{target_name}/{var}_do_{do_tag}"
                val_tag = f"query_value/{target_name}/{var}_do_{do_tag}"
                trainer.logger.experiment.add_scalar(tag, delta, epoch)
                trainer.logger.experiment.add_scalar(val_tag, ncm_val, epoch)

    def on_train_start(self, trainer, pl_module):
        self._log_epoch(trainer, pl_module, epoch=0)

    def on_train_epoch_end(self, trainer, pl_module):
        # Guard against a zero/negative/None period, which would otherwise
        # raise ZeroDivisionError/TypeError in the modulo below.
        if self.log_every_epochs is None or self.log_every_epochs <= 0:
            return
        epoch = trainer.current_epoch + 1
        if epoch % self.log_every_epochs != 0:
            return
        self._log_epoch(trainer, pl_module, epoch)


def evaluate_targets(
    label,
    target_cgs,
    target_ctms,
    queries_by_target,
    shared_templates_mod,
    hyperparams,
    role_aggregators,
    reps,
    cg_ref,
    eval_n=10000,
):
    """Evaluate the configured queries on each target with the shared modules.

    For every target that has an entry in ``queries_by_target``:
    * an eval NCM is built with ``build_eval_ncm``;
    * each ``(var, do_vars)`` query ``P(var=1 | do(do_vars=1))`` is computed on
      it, and on the matching ``target_ctms`` entry when one exists;
    * one line per query is printed.

    Targets without queries are skipped before any NCM is built.

    Args:
        eval_n: sample count passed to ``compute_ctf`` (default 10000).

    Returns:
        dict target name -> list of ``{"var", "do_vals", "value", "true_value"}``
        records; ``true_value`` is ``None`` when the target has no CTM.
    """
    results = {}
    print(f"\n=== Target query eval for {label} ===")
    for target_name, target_cg in target_cgs.items():
        if target_name not in queries_by_target:
            continue
        target_ncm = build_eval_ncm(
            target_cg,
            shared_templates_mod,
            hyperparams=hyperparams,
            role_aggregators=role_aggregators,
            reps=reps,
            cg_ref=cg_ref,
        )
        target_ctm = target_ctms.get(target_name)
        results[target_name] = []
        for var, do_vars in queries_by_target[target_name]:
            do_vals = {k: 1 for k in do_vars}
            q = make_binary_query(var, do_vals)
            # Pure evaluation: no autograd graph needed (matches QueryDeltaLogger).
            with T.no_grad():
                q_val = target_ncm.compute_ctf(q, n=eval_n, evaluating=True)
                true_val = None
                if target_ctm is not None:
                    true_val = target_ctm.compute_ctf(q, n=eval_n, evaluating=True)
            do_str = ",".join([f"{k}=1" for k in do_vals.keys()])
            print(
                f"{target_name}: P({var}=1 | do({do_str})) = {q_val} (true={true_val})"
            )
            results[target_name].append(
                {
                    "var": var,
                    "do_vals": do_vals,
                    "value": float(q_val),
                    "true_value": None if true_val is None else float(true_val),
                }
            )
    return results


def build_run_id(
    label,
    source_paths,
    target_path,
    hyperparams,
    trainer_cfg,
    trial_idx,
    seed,
    query_name,
    kind,
    role_agg,
):
    """Return a filesystem-safe, content-addressed identifier for one run.

    The id is a readable prefix (label, role aggregator, query, kind, trial)
    followed by the first 10 hex chars of a SHA-256 over the JSON-serialized
    run configuration. ``logger``/``callbacks`` are excluded from the hashed
    trainer config. A ``devices`` list/tuple is hashed as its length. ``+`` is
    replaced by ``_`` in the readable parts.
    """
    excluded = {"logger", "callbacks"}
    hashed_cfg = {}
    for key, value in trainer_cfg.items():
        if key in excluded:
            continue
        hashed_cfg[key] = value
    if isinstance(hashed_cfg.get("devices"), (list, tuple)):
        hashed_cfg["devices"] = len(hashed_cfg["devices"])

    payload = dict(
        label=label,
        sources=[str(path) for path in source_paths],
        target=str(target_path),
        hyperparams=hyperparams,
        trainer_cfg=hashed_cfg,
        trial_idx=trial_idx,
        seed=seed,
        query=query_name,
        kind=kind,
        role_agg=role_agg,
    )
    encoded = json.dumps(payload, sort_keys=True).encode("utf-8")
    digest = hashlib.sha256(encoded).hexdigest()[:10]

    def _safe(text):
        return text.replace("+", "_")

    parts = [
        _safe(label),
        "role",
        _safe(str(role_agg)),
        _safe(query_name),
        f"{kind}",
        "trial",
        f"{trial_idx}",
        digest,
    ]
    return "_".join(parts)


def lock_path_for(root_name, run_id):
    """Return ``<script dir>/<root_name>_locks/<run_id>.lock``.

    The lock directory is created if needed; the lock file itself is not.
    """
    script_dir = Path(__file__).resolve().parent
    lock_dir = script_dir.joinpath(f"{root_name}_locks")
    lock_dir.mkdir(parents=True, exist_ok=True)
    return lock_dir.joinpath(f"{run_id}.lock")


def results_dir_for(root_name, label, trial_idx):
    """Return (creating it) ``<script dir>/<root_name>/<label>/trial<trial_idx>``.

    ``+`` in ``label`` is replaced by ``_`` to keep the path filesystem-safe.
    """
    base = Path(__file__).resolve().parent.joinpath(root_name)
    trial_dir = base.joinpath(label.replace("+", "_"), f"trial{trial_idx}")
    trial_dir.mkdir(parents=True, exist_ok=True)
    return trial_dir




def run_training(
    label,
    source_paths,
    target_path,
    hyperparams,
    role_aggregators,
    seed,
    trainer_cfg,
    target_cgs,
    target_ctms,
    queries_by_target,
    template_funcs,
    query,
    query_sign,
):
    """Train shared-template NCMs for one bound of ``query``, then evaluate.

    Steps:
    1. Sample observational data (``trainer_cfg["n_samples"]`` rows) from a
       ``RelationalCTM`` for each source graph.
    2. Build source/target NCMs with shared template modules.
    3. Fit a ``RelationalIDPipeline`` with a ``QueryDeltaLogger`` callback.
    4. Evaluate the target queries.

    ``query_sign == -1`` is reported as the ``"max"`` run, anything else as
    ``"min"``.

    Returns:
        ``(duration_sec, eval_results, target_ncm)``.
    """
    source_cgs = [read_cg_with_metadata(str(path)) for path in source_paths]
    target_cg = read_cg_with_metadata(str(target_path))

    # A single empty intervention set, i.e. observational data only. The same
    # list object is shared by every source spec.
    interventions = [{}]
    n_samples = trainer_cfg["n_samples"]
    source_specs = []
    for source_cg in source_cgs:
        generator = build_rctm(source_cg, role_aggregators, seed, template_funcs=template_funcs)
        source_specs.append(
            {
                "cg": source_cg,
                "dat_sets": [generator(n=n_samples, do=do_set) for do_set in interventions],
                "do_var_list": interventions,
            }
        )

    source_ncms, target_ncm, shared_templates, reps, cg_ref = build_shared_role_ncms(
        source_cgs=source_cgs,
        target_cg=target_cg,
        hyperparams=hyperparams,
        role_aggregators=role_aggregators,
    )

    all_cgs = [*source_cgs, target_cg]
    role_vars = {name for graph in all_cgs for name in graph if name.startswith("R_")}
    pipeline = RelationalIDPipeline(
        source_specs=source_specs,
        target_spec={"cg": target_cg},
        hyperparams=hyperparams,
        source_ncms=source_ncms,
        target_ncm=target_ncm,
        query=query,
        query_sign=query_sign,
        role_vars=role_vars,
    )

    delta_logger = QueryDeltaLogger(
        target_cgs=target_cgs,
        target_ctms=target_ctms,
        queries_by_target=queries_by_target,
        hyperparams=hyperparams,
        role_aggregators=role_aggregators,
        shared_templates_mod=shared_templates,
        reps=reps,
        cg_ref=cg_ref,
        log_every_epochs=trainer_cfg.get("log_every_epochs", 10),
        eval_n=trainer_cfg["eval_n"],
    )
    trainer = pl.Trainer(
        max_epochs=trainer_cfg["max_epochs"],
        log_every_n_steps=1,
        accelerator=trainer_cfg["accelerator"],
        devices=trainer_cfg["devices"],
        limit_train_batches=trainer_cfg["limit_train_batches"],
        logger=trainer_cfg.get("logger"),
        callbacks=[delta_logger],
    )

    if query_sign == -1:
        kind = "max"
    else:
        kind = "min"
    print(f"\n=== Training {label} ({kind}) ===")
    t0 = time.perf_counter()
    trainer.fit(pipeline)
    elapsed = time.perf_counter() - t0
    print(f"Done {label} ({kind}) in {elapsed:.2f}s")

    eval_results = evaluate_targets(
        label,
        target_cgs,
        target_ctms,
        queries_by_target,
        shared_templates_mod=shared_templates,
        hyperparams=hyperparams,
        role_aggregators=role_aggregators,
        reps=reps,
        cg_ref=cg_ref,
    )
    return elapsed, eval_results, pipeline.target_ncm


if __name__ == "__main__":
    RANDOM_SEED = 7

    def seed_all(seed):
        """Seed the Python, NumPy, torch and Lightning RNGs with ``seed``."""
        random.seed(seed)
        np.random.seed(seed)
        T.manual_seed(seed)
        pl.seed_everything(seed, workers=True)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--graph",
        choices=["bow", "iv"],
        required=True,
        help="Which traffic graph family to run.",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("--accelerator", default="gpu", help="Trainer accelerator")
    parser.add_argument("--devices", type=int, default=1, help="Number of devices to use")
    parser.add_argument("--gpu-id", type=int, default=None, help="GPU id to use (overrides --devices)")
    parser.add_argument("--max-epochs", type=int, default=200, help="Max training epochs")
    parser.add_argument("--limit-train-batches", type=float, default=1.0)
    parser.add_argument("--num-trials", type=int, default=1)
    parser.add_argument("--trial-idx", type=int, default=None, help="Run only this trial index (1-based).")
    parser.add_argument("--n-samples", type=int, default=10000)
    parser.add_argument("--eval-n", type=int, default=10000)
    parser.add_argument(
        "--role-agg",
        choices=["strict_maj", "weak_maj", "or", "and", "min", "max", "sum", "mean", "count"],
        default="strict_maj",
        help="Aggregator for all role nodes (count encodes as bits; capped at 5).",
    )
    parser.add_argument(
        "--sweep-lr",
        action="store_true",
        help="Enable learning rate sweep using --lr-sweep values.",
    )
    parser.add_argument(
        "--lr-sweep",
        nargs="+",
        type=float,
        default=[1e-3],
        help="Learning rates to sweep.",
    )
    args = parser.parse_args()
    if not args.sweep_lr:
        args.lr_sweep = [1e-3]

    cg_dir = Path(__file__).resolve().parent / "NCMCounterfactuals" / "dat" / "cg"
    cg_a = cg_dir / f"{args.graph}_trafficA.cg"
    cg_b = cg_dir / f"{args.graph}_trafficB.cg"
    if not cg_a.exists() or not cg_b.exists():
        raise FileNotFoundError(f"Missing CG files for graph '{args.graph}'.")

    role_aggregators = {
        "R_X_Y": args.role_agg,
        "R_X_X": args.role_agg,
    }

    hyperparams = {
        "lr": 1e-3,
        "data-bs": 1000,
        "ncm-bs": 1000,
        "h-layers": 2,
        "h-size": 128,
        "u-size": 1,
        "max-query-iters": 300,
        "min-lambda": 0.001,
        "max-lambda": 1.0,
        "mc-sample-size": 5000,
        "full-batch": False,
    }

    devices = [args.gpu_id] if args.gpu_id is not None else args.devices
    trainer_cfg = {
        "max_epochs": args.max_epochs,
        "limit_train_batches": args.limit_train_batches,
        "accelerator": args.accelerator,
        "devices": devices,
        "n_samples": args.n_samples,
        "eval_n": args.eval_n,
    }

    if args.debug:
        trainer_cfg.update({"max_epochs": 10, "limit_train_batches": 2, "n_samples": 500})
        hyperparams["mc-sample-size"] = 100
        hyperparams["ncm-bs"] = 200
        hyperparams["fast-counts"] = True
        hyperparams["fast-counts-n"] = 200
        hyperparams["profile-counts"] = True
        hyperparams["profile-likelihood"] = True
        hyperparams["profile-every"] = 1

    target_cgs = {"B": read_cg_with_metadata(str(cg_b))}
    queries_by_target = {
        "B": [
            ("C2_Y", ["C2_X"]),
            ("C2_Y", ["C1_X"]),
        ]
    }
    query_specs = [
        ("C2_Y_do_C2_X", ("C2_Y", ["C2_X"])),
        ("C2_Y_do_C1_X", ("C2_Y", ["C1_X"])),
    ]

    root_name = f"{args.graph}_traffic_id_results"
    durations = {}
    if args.trial_idx is not None:
        trial_range = [args.trial_idx]
    else:
        trial_range = range(1, args.num_trials + 1)
    for trial_idx in trial_range:
        trial_seed = RANDOM_SEED + trial_idx - 1
        seed_all(trial_seed)

        base_ctm = build_rctm(read_cg_with_metadata(str(cg_a)), role_aggregators, trial_seed)
        shared_template_funcs = pad_template_funcs(base_ctm.template_funcs)
        target_ctms = {
            "B": build_rctm(
                read_cg_with_metadata(str(cg_b)),
                role_aggregators,
                seed=None,
                template_funcs=shared_template_funcs,
            )
        }

        for lr in args.lr_sweep:
            hyperparams["lr"] = lr
            for query_name, (var, do_vars) in query_specs:
                do_vals = {k: 1 for k in do_vars}
                query = make_binary_query(var, do_vals)
                for kind, sign in [("max", -1), ("min", 1)]:
                    label = f"{args.graph}_A_to_B_{query_name}"
                    run_id = build_run_id(
                        label,
                        [cg_a],
                        target_path=cg_b,
                        hyperparams=hyperparams,
                        trainer_cfg=trainer_cfg,
                        trial_idx=trial_idx,
                        seed=trial_seed,
                        query_name=query_name,
                        kind=kind,
                        role_agg=args.role_agg,
                    )
                    lock_path = lock_path_for(root_name, run_id)
                    results_dir = results_dir_for(root_name, label, trial_idx)
                    safe_label = label.replace("+", "_")
                    results_path = (
                        results_dir
                        / f"results_{safe_label}_{query_name}_{kind}__{run_id}.json"
                    )
                    hyperparams_path = (
                        results_dir
                        / f"hyperparams_{safe_label}_{query_name}_{kind}__{run_id}.json"
                    )
                    if results_path.exists():
                        print(
                            f"Skipping {label} trial {trial_idx} lr={lr} {query_name} {kind}: "
                            f"results already exist at {results_path}"
                        )
                        continue
                    if not acquire_lock(lock_path):
                        print(
                            f"Skipping {label} trial {trial_idx} lr={lr} {query_name} {kind}: "
                            f"lock held at {lock_path}"
                        )
                        continue
                    try:
                        # Re-check under the lock: another worker may have
                        # finished this run between the first check and our
                        # acquiring the lock.
                        if results_path.exists():
                            print(
                                f"Skipping {label} trial {trial_idx} lr={lr} {query_name} {kind}: "
                                f"results already exist at {results_path}"
                            )
                            continue
                        safe_role_agg = str(args.role_agg).replace("+", "_")
                        # ":g" keeps distinct learning rates distinct; ":.0e"
                        # rounded e.g. 1.5e-3 and 2e-3 into the same directory.
                        tb_log_dir = (
                            results_dir
                            / "tb"
                            / f"role_{safe_role_agg}"
                            / f"lr_{lr:g}"
                            / f"{query_name}_{kind}"
                        )
                        logger = TensorBoardLogger(
                            save_dir=str(tb_log_dir),
                            name="",
                            version="",
                            default_hp_metric=False,
                        )
                        trainer_cfg["logger"] = logger
                        # Reseed per run so each run_id (which hashes the seed)
                        # is reproducible regardless of which earlier runs in
                        # this process were executed or skipped.
                        seed_all(trial_seed)
                        duration, eval_results, target_ncm = run_training(
                            label,
                            [cg_a],
                            target_path=cg_b,
                            hyperparams=hyperparams,
                            role_aggregators=role_aggregators,
                            seed=trial_seed,
                            trainer_cfg=trainer_cfg,
                            target_cgs=target_cgs,
                            target_ctms=target_ctms,
                            queries_by_target=queries_by_target,
                            template_funcs=shared_template_funcs,
                            query=query,
                            query_sign=sign,
                        )
                        durations[f"{label}:trial{trial_idx}:lr{lr}:{query_name}:{kind}"] = duration
                        hp_subset = {
                            "lr": hyperparams.get("lr"),
                            "h-size": hyperparams.get("h-size"),
                            "h-layers": hyperparams.get("h-layers"),
                            "u-size": hyperparams.get("u-size"),
                            "mc-sample-size": hyperparams.get("mc-sample-size"),
                            "ncm-bs": hyperparams.get("ncm-bs"),
                            "data-bs": hyperparams.get("data-bs"),
                        }
                        payload = {
                            "run_id": run_id,
                            "graph": args.graph,
                            "query_name": query_name,
                            "kind": kind,
                            "trial_idx": trial_idx,
                            "seed": trial_seed,
                            "source": str(cg_a),
                            "target": str(cg_b),
                            "hyperparams": hp_subset,
                            "trainer_cfg": {
                                "max_epochs": trainer_cfg.get("max_epochs"),
                                "limit_train_batches": trainer_cfg.get("limit_train_batches"),
                                "accelerator": trainer_cfg.get("accelerator"),
                                "devices": trainer_cfg.get("devices"),
                                "n_samples": trainer_cfg.get("n_samples"),
                                "eval_n": trainer_cfg.get("eval_n"),
                            },
                            "queries_by_target": queries_by_target,
                            "duration_sec": duration,
                            "eval_results": eval_results,
                            "role_agg": args.role_agg,
                            "lr": lr,
                        }
                        write_json_atomic(results_path, payload)
                        write_json_atomic(hyperparams_path, hp_subset)
                    finally:
                        release_lock(lock_path)
