from pathlib import Path
import argparse
import hashlib
import json
import os
import platform
import random
import shutil
import time

import numpy as np
import pytorch_lightning as pl
import torch as T
import torch.nn as nn
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.tensorboard import SummaryWriter

from NCMCounterfactuals.src.ds.causal_graph import CausalGraph
from NCMCounterfactuals.src.pipeline.relational_id_pipeline import RelationalIDPipeline
from NCMCounterfactuals.src.pipeline.mle_pipeline import MLEPipeline
from NCMCounterfactuals.src.scm.ctm import RelationalCTM
from NCMCounterfactuals.src.scm.ncm.mle_ncm import MLE_NCM
from NCMCounterfactuals.src.scm.nn.gumbel_mlp import GumbelMLP
from NCMCounterfactuals.src.scm.ncm.role_aggregator import build_role_modules


def template_of(node):
    """Return the template key of a concrete node name, or ``None``.

    A node belongs to a template when it starts with the template's letter
    and ends with its suffix, e.g. ``"S0_W" -> "S_W"``, ``"P3_Y" -> "P_Y"``.
    Patterns are tried in the order S_W, P_X, P_Y, C_B; the first match wins.
    """
    patterns = (
        ("S", "_W", "S_W"),
        ("P", "_X", "P_X"),
        ("P", "_Y", "P_Y"),
        ("C", "_B", "C_B"),
    )
    for head, tail, key in patterns:
        if node.startswith(head) and node.endswith(tail):
            return key
    return None


def make_binary_query(var, do_vals):
    """Build a one-term counterfactual query on ``var`` taking value 1.

    The result is a ``CTF`` holding a single ``CTFTerm`` over ``{var}`` with
    ``var_vals={var: 1}`` under the interventions ``do_vals`` (presumably
    P(var_{do_vals} = 1) — confirm against the CTF definition).
    """
    # Imported lazily, exactly as before, to keep this module's import light.
    from NCMCounterfactuals.src.ds.counterfactual import CTF, CTFTerm

    term = CTFTerm(vars={var}, do_vals=do_vals, var_vals={var: 1})
    return CTF({term})


def read_cg_with_metadata(cg_path):
    """Parse a ``.cg`` causal-graph file into a ``CausalGraph``.

    Blank lines are ignored. A line of the form ``<NAME>`` switches the current
    section to ``NAME`` (surrounding ``<``/``>`` characters stripped). In the
    ``NODES`` section each line must be a valid Python identifier and becomes a
    node. In the ``EDGES`` section a line containing ``<->`` is a bidirected edge.
    Otherwise, a line containing ``->`` is a directed edge. Endpoints are
    whitespace-stripped. Lines in any other section, or before the first header,
    are skipped.

    Raises:
        ValueError: for a non-identifier node line, an edge line with neither
            arrow, or an edge line that does not split into exactly two parts.
    """
    section = None
    node_names = []
    directed = []
    bidirected = []
    with open(cg_path) as cg_file:
        for raw in cg_file:
            text = raw.strip()
            if not text:
                continue
            if text.startswith("<") and text.endswith(">"):
                section = text.strip("<>")
            elif section == "NODES":
                if not text.isidentifier():
                    raise ValueError(f"invalid identifier: {text}")
                node_names.append(text)
            elif section == "EDGES":
                # "<->" must be tested before "->" because it contains it.
                for arrow, bucket in (("<->", bidirected), ("->", directed)):
                    if arrow in text:
                        src, dst = map(str.strip, text.split(arrow))
                        bucket.append((src, dst))
                        break
                else:
                    raise ValueError(f"invalid edge type: {text}")
    return CausalGraph(node_names, directed_edges=directed, bidirected_edges=bidirected)

def role_template_of(node):
    """Return the role template ``R_<a>_<b>`` for a role node name, else ``None``.

    A role node name has at least four ``_``-separated parts and starts with
    ``R``; the template is built from its third and fourth parts, so
    ``"R_0_W_X" -> "R_W_X"``. The second part and any parts after the fourth
    are ignored.
    """
    pieces = node.split("_")
    if len(pieces) < 4 or pieces[0] != "R":
        return None
    return f"R_{pieces[2]}_{pieces[3]}"

def v_size_for_cg(cg, role_aggregators, count_bits=3):
    """Return a ``{node: size}`` dict for every node of ``cg``.

    Every node has size 1, except role nodes (those for which
    ``role_template_of`` returns a key) whose aggregator in
    ``role_aggregators`` is ``"count"`` (case-insensitive). Those nodes get
    ``count_bits``.
    """

    def _size(node):
        role_key = role_template_of(node)
        if role_key is None:
            return 1
        agg_name = role_aggregators.get(role_key)
        is_count = agg_name is not None and str(agg_name).lower() == "count"
        return count_bits if is_count else 1

    return {node: _size(node) for node in cg}


def build_rctm(cg, role_aggregators, seed, template_funcs=None):
    """Construct the ground-truth ``RelationalCTM`` for ``cg``.

    Uses ``template_of`` to group nodes and fixed representatives
    S0_W / P0_X / P0_Y / C0_B per template. Variable sizes come from
    ``v_size_for_cg``. Passes ``regions=5`` and ``c2_scale=1.0``, and forwards
    ``role_aggregators``, ``seed`` and ``template_funcs`` unchanged.
    """
    representatives = dict(S_W="S0_W", P_X="P0_X", P_Y="P0_Y", C_B="C0_B")
    sizes = v_size_for_cg(cg, role_aggregators)
    return RelationalCTM(
        cg,
        template_of=template_of,
        reps=representatives,
        v_size=sizes,
        regions=5,
        c2_scale=1.0,
        role_aggregators=role_aggregators,
        seed=seed,
        template_funcs=template_funcs,
    )


class SharedModuleWrapperMLE(nn.Module):
    """Adapter that lets one shared module be called with a different node's keys.

    ``forward`` re-keys the caller's ``pa`` and ``u`` dicts through
    ``pa_key_map`` / ``u_key_map`` (caller key -> key the shared module expects).
    Only mapped keys are kept. Each value is moved to the device of the shared
    module's first parameter. The call is then delegated, with ``v`` and ``n``
    passed through.
    """

    def __init__(self, shared, pa_key_map, u_key_map):
        super().__init__()
        # Registered as a submodule; the same instance may sit in many wrappers.
        self.shared = shared
        self.pa_key_map = pa_key_map
        self.u_key_map = u_key_map

    @staticmethod
    def _rekey(values, key_map, device):
        """Return ``{new: values[old].to(device)}`` for every ``old -> new`` pair."""
        renamed = {}
        for old_key, new_key in key_map.items():
            renamed[new_key] = values[old_key].to(device)
        return renamed

    def forward(self, pa, u, v=None, n=None):
        device = next(iter(self.shared.parameters())).device
        parents = self._rekey(pa, self.pa_key_map, device)
        noise = self._rekey(u, self.u_key_map, device)
        return self.shared(parents, noise, v=v, n=n)


def build_shared_role_ncms(source_cgs, target_cg, hyperparams, role_aggregators):
    """Build source/target MLE_NCMs whose template nodes share one GumbelMLP per template.

    One ``GumbelMLP`` is created per template key in ``reps``. It is sized from
    the representative node's parents (``cg.pa``) and ``cg.v2c2`` entries in a
    reference graph: the first source graph, or ``target_cg`` when
    ``source_cgs`` is empty. In every source graph and in the target graph,
    each node whose ``template_of`` key is in ``reps`` then gets a
    ``SharedModuleWrapperMLE`` around that same module instance, so those
    nodes share parameters. All other nodes keep the modules returned by
    ``build_role_modules``.

    Args:
        source_cgs: list of source causal graphs (may be empty).
        target_cg: the target causal graph.
        hyperparams: dict; reads ``"u-size"`` (default 1), ``"h-layers"``
            (default 2) and ``"h-size"`` (default 128), and is passed through to
            every ``MLE_NCM``.
        role_aggregators: role template key -> aggregator name, forwarded to
            ``build_role_modules`` and ``v_size_for_cg``.

    Returns:
        ``(source_ncms, target_ncm, shared_templates_mod, reps, cg_ref)`` where
        ``shared_templates_mod`` maps template key -> shared GumbelMLP and
        ``cg_ref`` is the reference graph used for sizing and key order.

    Raises:
        ValueError: if a representative node is missing from the reference
            graph, or a template node's parent / ``v2c2`` count differs from
            its representative's.
    """
    # Template key -> representative node name in the reference graph.
    reps = {
        "S_W": "S0_W",
        "P_X": "P0_X",
        "P_Y": "P0_Y",
        "C_B": "C0_B",
    }
    cg_ref = source_cgs[0] if source_cgs else target_cg
    v_size_ref = v_size_for_cg(cg_ref, role_aggregators)
    shared_templates_mod = {}
    for template, rep in reps.items():
        if rep not in cg_ref:
            raise ValueError(f"Representative {rep} not found in reference CG.")
        rep_pa = list(cg_ref.pa[rep])
        rep_u = list(cg_ref.v2c2[rep])
        # Inputs are keyed by the representative's names; the wrappers below
        # rename every other node's inputs onto these keys. The positional `1`
        # is presumably the output size — confirm against GumbelMLP.
        shared_templates_mod[template] = GumbelMLP(
            {p: v_size_ref[p] for p in rep_pa},
            {u: hyperparams.get("u-size", 1) for u in rep_u},
            1,
            h_layers=hyperparams.get("h-layers", 2),
            h_size=hyperparams.get("h-size", 128),
        )

    def build_wrapped_modules(cg):
        # Start from the per-node modules, then swap template nodes onto the
        # shared modules.
        modules = build_role_modules(cg, role_aggregators)
        for node in cg:
            template = template_of(node)
            if template not in reps:
                continue
            rep = reps[template]
            rep_pa = list(cg_ref.pa[rep])
            rep_u = list(cg_ref.v2c2[rep])
            cur_pa = list(cg.pa[node])
            cur_u = list(cg.v2c2[node])
            if len(rep_pa) != len(cur_pa) or len(rep_u) != len(cur_u):
                raise ValueError(
                    f"Incompatible parents/u for template {template}: {node} "
                    f"(pa {len(cur_pa)} vs {len(rep_pa)}, u {len(cur_u)} vs {len(rep_u)})"
                )
            # Pair this node's keys with the representative's keys by position.
            # NOTE(review): correctness depends on cg.pa[...] / cg.v2c2[...]
            # iterating in a role-consistent order in every graph; if they are
            # sets the pairing may be arbitrary — confirm.
            pa_key_map = {cur_pa[i]: rep_pa[i] for i in range(len(cur_pa))}
            u_key_map = {cur_u[i]: rep_u[i] for i in range(len(cur_u))}
            modules[node] = SharedModuleWrapperMLE(
                shared_templates_mod[template], pa_key_map, u_key_map
            )
        return modules

    source_ncms = []
    for cg in source_cgs:
        modules = build_wrapped_modules(cg)
        source_ncms.append(
            MLE_NCM(
                cg,
                v_size=v_size_for_cg(cg, role_aggregators),
                default_u_size=hyperparams.get("u-size", 1),
                f=modules,
                hyperparams=hyperparams,
            )
        )
    target_modules = build_wrapped_modules(target_cg)
    target_ncm = MLE_NCM(
        target_cg,
        v_size=v_size_for_cg(target_cg, role_aggregators),
        default_u_size=hyperparams.get("u-size", 1),
        f=target_modules,
        hyperparams=hyperparams,
    )
    return source_ncms, target_ncm, shared_templates_mod, reps, cg_ref


def build_source_specs(
    source_cgs, role_aggregators, seed, n_samples, do_var_list, template_funcs=None
):
    """Sample training data for each source graph from its own ``RelationalCTM``.

    For every graph in ``source_cgs`` (in order), a CTM is built with
    ``build_rctm``, using the same ``seed`` and ``template_funcs`` for all
    graphs. ``n_samples`` rows are then drawn for each intervention in
    ``do_var_list``.

    Returns:
        list with one dict per source graph holding ``"cg"``, ``"dat_sets"``
        (one dataset per entry of ``do_var_list``) and ``"do_var_list"``.
    """

    def _spec(cg):
        generator = build_rctm(cg, role_aggregators, seed, template_funcs=template_funcs)
        samples = [generator(n=n_samples, do=intervention) for intervention in do_var_list]
        return {"cg": cg, "dat_sets": samples, "do_var_list": do_var_list}

    return [_spec(cg) for cg in source_cgs]


def build_eval_ncm(target_cg, shared_templates_mod, hyperparams, role_aggregators, reps, cg_ref):
    """Assemble an ``MLE_NCM`` for ``target_cg`` that reuses the shared template modules.

    Starts from ``build_role_modules(target_cg, role_aggregators)``. For every
    node whose ``template_of`` key is in ``reps``, its module is replaced with a
    ``SharedModuleWrapperMLE`` around ``shared_templates_mod[template]``. The
    wrapper renames the node's parent / ``v2c2`` keys, by position, to those of
    the template's representative node in ``cg_ref``. The resulting NCM is
    moved to the device of the first shared module's first parameter.

    Args:
        target_cg: graph to build the NCM for.
        shared_templates_mod: non-empty dict template key -> shared module.
        hyperparams: dict; ``"u-size"`` (default 1) is read and the whole dict
            is passed to ``MLE_NCM``.
        role_aggregators: forwarded to ``build_role_modules`` / ``v_size_for_cg``.
        reps: template key -> representative node name in ``cg_ref``.
        cg_ref: reference graph that fixes each template's input key order.

    Returns:
        The assembled ``MLE_NCM``, on the shared modules' device.

    Raises:
        ValueError: if ``shared_templates_mod`` is empty, or a template node's
            parent / ``v2c2`` count differs from its representative's.
    """
    if not shared_templates_mod:
        # Fail clearly instead of leaking StopIteration from the device lookup.
        raise ValueError("shared_templates_mod is empty; cannot determine the shared device.")
    modules = build_role_modules(target_cg, role_aggregators)
    first_shared = next(iter(shared_templates_mod.values()))
    shared_device = next(first_shared.parameters()).device
    for node in target_cg:
        template = template_of(node)
        if template not in reps:
            continue
        rep = reps[template]
        rep_pa = list(cg_ref.pa[rep])
        rep_u = list(cg_ref.v2c2[rep])
        cur_pa = list(target_cg.pa[node])
        cur_u = list(target_cg.v2c2[node])
        if len(rep_pa) != len(cur_pa) or len(rep_u) != len(cur_u):
            raise ValueError(
                f"Incompatible parents/u for template {template}: {node} "
                f"(pa {len(cur_pa)} vs {len(rep_pa)}, u {len(cur_u)} vs {len(rep_u)})"
            )
        # Lengths are equal (checked above), so zip pairs every key positionally.
        pa_key_map = dict(zip(cur_pa, rep_pa))
        u_key_map = dict(zip(cur_u, rep_u))
        modules[node] = SharedModuleWrapperMLE(
            shared_templates_mod[template], pa_key_map, u_key_map
        )
    eval_ncm = MLE_NCM(
        target_cg,
        v_size=v_size_for_cg(target_cg, role_aggregators),
        default_u_size=hyperparams.get("u-size", 1),
        f=modules,
        hyperparams=hyperparams,
    )
    eval_ncm.to(shared_device)
    return eval_ncm


class QueryDeltaLogger(pl.Callback):
    """Periodically evaluate the shared templates on each target and log query errors.

    Evaluation runs at train start (epoch 0) and after every
    ``log_every_epochs``-th epoch. Each time, a fresh evaluation NCM is
    assembled for every target graph from the shared template modules (via
    ``build_eval_ncm``). Each configured query ``P(var=1 | do(do_vars=1))`` is
    then estimated with ``eval_n`` samples. For each query:

    * ``|estimate - truth|`` is logged to the trainer's logger under
      ``delta/<target>/<var>_do_<do1+do2+...>``;
    * when ``agg_state_path`` is set, the squared error is folded into a
      running mean stored in that JSON file, keyed by epoch and then query.
      Repeated runs that point at the same file therefore accumulate one mean
      per epoch. If ``agg_tb_log_dir`` is set, those means are also written to
      a separate TensorBoard writer under ``mse_mean/<target>/<var>_do_<...>``.

    Ground-truth values come from ``target_ctms`` and are computed once at
    construction. Nothing is evaluated when the trainer has no logger.
    """

    def __init__(
        self,
        target_cgs,
        target_ctms,
        queries_by_target,
        hyperparams,
        role_aggregators,
        shared_templates_mod,
        reps,
        cg_ref,
        log_every_epochs=10,
        eval_n=10000,
        agg_state_path=None,
        agg_tb_log_dir=None,
    ):
        super().__init__()
        self.target_cgs = target_cgs
        self.target_ctms = target_ctms
        self.queries_by_target = queries_by_target
        self.hyperparams = hyperparams
        self.role_aggregators = role_aggregators
        self.shared_templates_mod = shared_templates_mod
        self.reps = reps
        self.cg_ref = cg_ref
        self.log_every_epochs = log_every_epochs
        self.eval_n = eval_n
        self.agg_state_path = Path(agg_state_path) if agg_state_path else None
        self.agg_tb_log_dir = Path(agg_tb_log_dir) if agg_tb_log_dir else None
        self._agg_writer = None
        # Ground truth does not change during training: compute it once.
        self._true_cache = self._build_true_cache()
        if self.agg_state_path is not None:
            self.agg_state_path.parent.mkdir(parents=True, exist_ok=True)
        if self.agg_tb_log_dir is not None:
            self.agg_tb_log_dir.mkdir(parents=True, exist_ok=True)
            self._agg_writer = SummaryWriter(log_dir=str(self.agg_tb_log_dir))

    def _load_agg_state(self):
        """Return the persisted aggregate state, or a fresh one if missing/unreadable."""
        fresh = {"epochs": {}}
        if self.agg_state_path is None or not self.agg_state_path.exists():
            return fresh
        try:
            with open(self.agg_state_path) as handle:
                state = json.load(handle)
        except (OSError, ValueError):
            # I/O failure or malformed JSON (json.JSONDecodeError is a ValueError).
            return fresh
        # Anything other than a dict cannot hold an "epochs" table.
        return state if isinstance(state, dict) else fresh

    def _build_true_cache(self):
        """Map ``(target, var, tuple(do_vars))`` to the ground-truth query value."""
        cache = {}
        for target_name, target_ctm in self.target_ctms.items():
            for var, do_vars in self.queries_by_target.get(target_name, []):
                do_vals = {k: 1 for k in do_vars}
                q = make_binary_query(var, do_vals)
                true_val = target_ctm.compute_ctf(q, n=self.eval_n, evaluating=True)
                cache[(target_name, var, tuple(do_vars))] = float(true_val)
        return cache

    def _log_epoch(self, trainer, pl_module, epoch):
        """Evaluate every target query, log per-query deltas and update the aggregate."""
        if trainer.logger is None:
            return
        mse_by_query = {}
        for target_name, target_cg in self.target_cgs.items():
            eval_ncm = build_eval_ncm(
                target_cg,
                shared_templates_mod=self.shared_templates_mod,
                hyperparams=self.hyperparams,
                role_aggregators=self.role_aggregators,
                reps=self.reps,
                cg_ref=self.cg_ref,
            )
            for var, do_vars in self.queries_by_target.get(target_name, []):
                do_vals = {k: 1 for k in do_vars}
                q = make_binary_query(var, do_vals)
                # Inference only: don't build an autograd graph through the
                # trainable shared modules (matches MLEQueryDeltaLogger).
                with T.no_grad():
                    ncm_val = eval_ncm.compute_ctf(q, n=self.eval_n, evaluating=True)
                true_val = self._true_cache.get((target_name, var, tuple(do_vars)))
                if true_val is None:
                    continue
                error = float(ncm_val) - true_val
                do_tag = "+".join(do_vars)
                trainer.logger.experiment.add_scalar(
                    f"delta/{target_name}/{var}_do_{do_tag}", abs(error), epoch
                )
                mse_by_query[f"{target_name}/{var}_do_{do_tag}"] = error * error

        if self.agg_state_path is None or not mse_by_query:
            return
        self._update_agg_state(mse_by_query, epoch)

    def _update_agg_state(self, mse_by_query, epoch):
        """Fold this run's squared errors into the per-epoch running means on disk."""
        state = self._load_agg_state()
        epochs = state.setdefault("epochs", {})
        epoch_key = str(epoch)
        epoch_state = epochs.get(epoch_key, {"count": 0, "means": {}})
        count = epoch_state.get("count", 0)
        means = epoch_state.get("means", {})
        for key, value in mse_by_query.items():
            if count == 0 or key not in means:
                # First contribution for this query at this epoch.
                means[key] = value
            else:
                # Incremental mean over `count + 1` contributions.
                prev = means[key]
                means[key] = prev + (value - prev) / (count + 1)
        epoch_state["count"] = count + 1
        epoch_state["means"] = means
        epochs[epoch_key] = epoch_state
        write_json_atomic(self.agg_state_path, state)
        if self._agg_writer is not None:
            for key, mean_val in means.items():
                self._agg_writer.add_scalar(f"mse_mean/{key}", mean_val, epoch)

    def on_train_end(self, trainer, pl_module):
        if self._agg_writer is not None:
            self._agg_writer.flush()
            self._agg_writer.close()

    def on_train_start(self, trainer, pl_module):
        self._log_epoch(trainer, pl_module, epoch=0)

    def on_train_epoch_end(self, trainer, pl_module):
        if (trainer.current_epoch + 1) % self.log_every_epochs != 0:
            return
        epoch = trainer.current_epoch + 1
        self._log_epoch(trainer, pl_module, epoch)


class MLEQueryDeltaLogger(pl.Callback):
    """Log |NCM estimate - ground truth| for one target's queries during MLE training.

    Ground-truth values come from ``target_ctm`` and are computed once at
    construction. Logging happens at train start (epoch 0) and after every
    ``log_every_epochs``-th epoch. It goes to ``trainer.logger.experiment``
    under ``delta/<target_name>/<var>_do_<do1+do2+...>``, with the estimate
    taken from ``pl_module.ncm``. Nothing is logged when the trainer has no
    logger.
    """

    def __init__(self, target_name, target_ctm, queries_by_target, eval_n=10000, log_every_epochs=10):
        super().__init__()
        self.target_name = target_name
        self.target_ctm = target_ctm
        self.queries_by_target = queries_by_target
        self.eval_n = eval_n
        self.log_every_epochs = log_every_epochs
        self._true_cache = self._build_true_cache()

    def _queries(self):
        """Return the ``(var, do_vars)`` pairs configured for this target."""
        return self.queries_by_target.get(self.target_name, [])

    def _build_true_cache(self):
        """Map ``(var, tuple(do_vars))`` to the ground-truth query value."""
        cache = {}
        for var, do_vars in self._queries():
            query = make_binary_query(var, dict.fromkeys(do_vars, 1))
            truth = self.target_ctm.compute_ctf(query, n=self.eval_n, evaluating=True)
            cache[(var, tuple(do_vars))] = float(truth)
        return cache

    def _log_epoch(self, trainer, pl_module, epoch):
        logger = trainer.logger
        if logger is None:
            return
        ncm = pl_module.ncm
        for var, do_vars in self._queries():
            query = make_binary_query(var, dict.fromkeys(do_vars, 1))
            with T.no_grad():
                estimate = ncm.compute_ctf(query, n=self.eval_n, evaluating=True)
            truth = self._true_cache.get((var, tuple(do_vars)))
            if truth is None:
                continue
            joined = "+".join(do_vars)
            logger.experiment.add_scalar(
                f"delta/{self.target_name}/{var}_do_{joined}",
                abs(float(estimate) - truth),
                epoch,
            )

    def on_train_start(self, trainer, pl_module):
        self._log_epoch(trainer, pl_module, epoch=0)

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch + 1
        if epoch % self.log_every_epochs == 0:
            self._log_epoch(trainer, pl_module, epoch)


def evaluate_targets(
    label,
    target_cgs,
    target_ctms,
    queries_by_target,
    shared_templates_mod,
    hyperparams,
    role_aggregators,
    reps,
    cg_ref,
    eval_n=10000,
):
    """Evaluate the trained shared templates on every target graph's queries.

    For each target in ``target_cgs`` that has an entry in
    ``queries_by_target``, an evaluation NCM is built with ``build_eval_ncm``.
    Each query ``P(var=1 | do(do_vars=1))`` is estimated with ``eval_n``
    samples. When ``target_ctms`` has a ground-truth CTM for the target, the
    same query is also evaluated on it. Every result is printed and collected.

    Args:
        label: run label, used in the printed header.
        target_cgs: target name -> causal graph.
        target_ctms: target name -> ground-truth CTM (targets may be missing).
        queries_by_target: target name -> list of ``(var, do_vars)`` pairs.
        shared_templates_mod, hyperparams, role_aggregators, reps, cg_ref:
            forwarded to ``build_eval_ncm``.
        eval_n: samples per ``compute_ctf`` call (default 10000).

    Returns:
        dict target name -> list of ``{"var", "do_vals", "value", "true_value"}``
        dicts. ``true_value`` is ``None`` when no ground-truth CTM is available.
    """
    results = {}
    print(f"\n=== Target query eval for {label} ===")
    for target_name, target_cg in target_cgs.items():
        if target_name not in queries_by_target:
            # Skip before building an NCM that would never be queried.
            continue
        target_ncm = build_eval_ncm(
            target_cg,
            shared_templates_mod,
            hyperparams=hyperparams,
            role_aggregators=role_aggregators,
            reps=reps,
            cg_ref=cg_ref,
        )
        target_ctm = target_ctms.get(target_name)
        target_results = results[target_name] = []
        for var, do_vars in queries_by_target[target_name]:
            do_vals = {k: 1 for k in do_vars}
            q = make_binary_query(var, do_vals)
            # Inference only: no autograd graph through the shared modules
            # (matches evaluate_targets_mle).
            with T.no_grad():
                q_val = target_ncm.compute_ctf(q, n=eval_n, evaluating=True)
            true_val = None
            if target_ctm is not None:
                true_val = target_ctm.compute_ctf(q, n=eval_n, evaluating=True)
            do_str = ",".join(f"{k}=1" for k in do_vals)
            print(
                f"{target_name}: P({var}=1 | do({do_str})) = {q_val} (true={true_val})"
            )
            target_results.append(
                {
                    "var": var,
                    "do_vals": do_vals,
                    "value": float(q_val),
                    "true_value": None if true_val is None else float(true_val),
                }
            )
    return results


def evaluate_targets_mle(target_name, target_ncm, target_ctm, queries_by_target, eval_n):
    """Evaluate one target's queries on a trained NCM and on the ground-truth CTM.

    Each query ``P(var=1 | do(do_vars=1))`` is estimated with ``eval_n``
    samples. The NCM is run under ``T.no_grad()``.

    Returns:
        ``{target_name: [{"var", "do_vals", "value", "true_value"}, ...]}`` in
        query order; the list is empty when the target has no queries.
    """
    rows = []
    for var, do_vars in queries_by_target.get(target_name, []):
        do_vals = dict.fromkeys(do_vars, 1)
        query = make_binary_query(var, do_vals)
        with T.no_grad():
            estimate = target_ncm.compute_ctf(query, n=eval_n, evaluating=True)
        truth = target_ctm.compute_ctf(query, n=eval_n, evaluating=True)
        rows.append(
            {
                "var": var,
                "do_vals": do_vals,
                "value": float(estimate),
                "true_value": float(truth),
            }
        )
    return {target_name: rows}


class ProjectedCTM:
    """Wrap a CTM so that sampled data only keeps an allowed set of variables.

    Calling the wrapper forwards all arguments to ``base_ctm``. It then drops
    every entry of the returned mapping whose key is not in ``allowed_keys``.
    ``compute_ctf`` is forwarded unchanged, so queries are still answered by
    the full underlying model.
    """

    def __init__(self, base_ctm, allowed_keys):
        self.base_ctm = base_ctm
        self.allowed_keys = set(allowed_keys)

    def __call__(self, *args, **kwargs):
        sample = self.base_ctm(*args, **kwargs)
        keep = self.allowed_keys
        projected = {}
        for name, values in sample.items():
            if name in keep:
                projected[name] = values
        return projected

    def compute_ctf(self, *args, **kwargs):
        return self.base_ctm.compute_ctf(*args, **kwargs)


def filter_dataset(dat, allowed_keys):
    """Return a new dict with only the entries of ``dat`` whose key is in ``allowed_keys``.

    Insertion order of ``dat`` is preserved; ``dat`` itself is not modified.
    """
    kept = {}
    for key, value in dat.items():
        if key in allowed_keys:
            kept[key] = value
    return kept


def run_mle_baseline(
    label,
    target_name,
    baseline_cg,
    baseline_generator,
    eval_ctm,
    hyperparams,
    trainer_cfg,
    queries_by_target,
    dat_sets,
    do_var_list,
):
    """Train a target-only MLE NCM baseline and evaluate it on the target's queries.

    A copy of ``hyperparams`` is extended with:

    * ``"do-var-list"``;
    * ``"full-batch"`` (default False);
    * ``"eval-query"``, built from the first query of ``target_name``.

    An ``MLEPipeline`` is fitted on ``dat_sets`` with an
    ``MLEQueryDeltaLogger`` callback. The trained ``pipeline.ncm`` is then
    evaluated with ``evaluate_targets_mle``.

    Returns:
        ``(duration, eval_results)``: seconds spent in ``trainer.fit`` and the
        dict returned by ``evaluate_targets_mle``.
    """
    eval_query_var, eval_query_do = queries_by_target[target_name][0]
    baseline_hparams = {
        **hyperparams,
        "do-var-list": do_var_list,
        "full-batch": hyperparams.get("full-batch", False),
        "eval-query": [
            (make_binary_query(eval_query_var, dict.fromkeys(eval_query_do, 1)), 1)
        ],
    }

    pipeline = MLEPipeline(
        generator=baseline_generator,
        do_var_list=do_var_list,
        dat_sets=dat_sets,
        cg=baseline_cg,
        dim=1,
        hyperparams=baseline_hparams,
    )
    delta_logger = MLEQueryDeltaLogger(
        target_name=target_name,
        target_ctm=eval_ctm,
        queries_by_target=queries_by_target,
        eval_n=trainer_cfg["eval_n"],
        log_every_epochs=trainer_cfg.get("log_every_epochs", 10),
    )
    trainer = pl.Trainer(
        max_epochs=trainer_cfg["max_epochs"],
        log_every_n_steps=1,
        accelerator=trainer_cfg["accelerator"],
        devices=trainer_cfg["devices"],
        limit_train_batches=trainer_cfg["limit_train_batches"],
        logger=trainer_cfg.get("logger"),
        callbacks=[delta_logger],
    )

    print(f"\n=== Training target-only baseline {label} ===")
    t_begin = time.perf_counter()
    trainer.fit(pipeline)
    duration = time.perf_counter() - t_begin
    print(f"Done target-only baseline {label} in {duration:.2f}s")
    return duration, evaluate_targets_mle(
        target_name,
        pipeline.ncm,
        eval_ctm,
        queries_by_target,
        trainer_cfg["eval_n"],
    )


def build_run_id(
    label,
    source_paths,
    target_path,
    hyperparams,
    trainer_cfg,
    trial_idx,
    seed,
    queries_by_target,
    role_agg,
):
    """Return a deterministic, filesystem-friendly identifier for one run.

    The id has the form ``<label>_role_<role_agg>_trial_<trial_idx>_<digest>``.
    ``+`` is replaced by ``_`` in the label and in the aggregator name.
    ``digest`` is the first 10 hex characters of the SHA-256 of a sorted-key
    JSON dump of every run-defining input. Before hashing, the ``logger`` and
    ``callbacks`` entries of ``trainer_cfg`` are dropped, and a list/tuple
    ``devices`` is reduced to its length (so ``[0]`` and ``[3]`` hash
    identically). ``trainer_cfg`` itself is not modified.
    """
    hashed_cfg = dict(trainer_cfg)
    for unhashable in ("logger", "callbacks"):
        hashed_cfg.pop(unhashable, None)
    if isinstance(hashed_cfg.get("devices"), (list, tuple)):
        hashed_cfg["devices"] = len(hashed_cfg["devices"])

    fingerprint = json.dumps(
        {
            "label": label,
            "sources": list(map(str, source_paths)),
            "target": str(target_path),
            "hyperparams": hyperparams,
            "trainer_cfg": hashed_cfg,
            "trial_idx": trial_idx,
            "seed": seed,
            "queries": queries_by_target,
            "role_agg": role_agg,
        },
        sort_keys=True,
    )
    digest = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()[:10]
    label_part = label.replace("+", "_")
    agg_part = str(role_agg).replace("+", "_")
    return f"{label_part}_role_{agg_part}_trial_{trial_idx}_{digest}"


def lock_path_for(run_id):
    """Return the lock-file path for ``run_id``, creating the lock directory if needed.

    Locks live in ``less_simple_traffic_id_locks/`` next to this script and
    are named ``<run_id>.lock``. The lock file itself is not created here.
    """
    script_dir = Path(__file__).resolve().parent
    locks = script_dir.joinpath("less_simple_traffic_id_locks")
    locks.mkdir(parents=True, exist_ok=True)
    return locks.joinpath(f"{run_id}.lock")


def results_dir_for(label, trial_idx):
    """Return (creating if needed) the results directory for ``label`` and ``trial_idx``.

    Layout: ``less_simple_traffic_id_results/<label>/trial<trial_idx>/``
    next to this script, with ``+`` in ``label`` replaced by ``_``.
    """
    base = Path(__file__).resolve().parent.joinpath("less_simple_traffic_id_results")
    trial_dir = base.joinpath(label.replace("+", "_"), f"trial{trial_idx}")
    trial_dir.mkdir(parents=True, exist_ok=True)
    return trial_dir


def _lock_is_held(lock_path, unreadable_grace_seconds):
    """Return True if the existing lock file at ``lock_path`` should be respected."""
    try:
        with open(lock_path) as handle:
            data = json.load(handle)
    except FileNotFoundError:
        # Removed since our O_EXCL attempt failed: the lock is free.
        return False
    except (OSError, ValueError):
        # Empty, partial or corrupt JSON. The creator may still be between
        # os.open() and json.dump(), so only reclaim it once it is old enough
        # to have been abandoned.
        try:
            age = time.time() - lock_path.stat().st_mtime
        except FileNotFoundError:
            return False
        return age < unreadable_grace_seconds
    pid = data.get("pid") if isinstance(data, dict) else None
    if pid is None:
        return False
    # NOTE(review): the pid is only checked on this host; the recorded "host"
    # is not compared, so a live lock written from another machine sharing
    # this directory can still be reclaimed here.
    try:
        os.kill(pid, 0)
    except PermissionError:
        # The process exists but belongs to another user: still alive.
        return True
    except (OSError, TypeError, ValueError, OverflowError):
        # ProcessLookupError (no such process) or an unusable pid value.
        return False
    return True


def acquire_lock(lock_path, unreadable_grace_seconds=60.0):
    """Try to create ``lock_path`` exclusively; return True if this process now holds it.

    The lock file receives a JSON payload with this process's ``pid``,
    ``host`` and ``started_at`` time. When the file already exists, it is
    respected (False is returned) if its pid is alive on this host, including
    processes owned by another user. It is also respected if it cannot be
    parsed but is younger than ``unreadable_grace_seconds``, since the creator
    may not have written its payload yet.

    Otherwise the lock is considered stale: it is removed and re-created.
    Losing that re-creation race to another process returns False rather than
    raising. Reclaiming is not fully race-free: two processes reclaiming the
    same stale file at once can both succeed.

    Args:
        lock_path: path of the lock file (``str`` or ``Path``).
        unreadable_grace_seconds: age below which an unparsable lock file is
            still treated as held.

    Returns:
        True if the lock was acquired, False if it is held by someone else.
    """
    lock_path = Path(lock_path)
    payload = {
        "pid": os.getpid(),
        "host": platform.node(),
        "started_at": time.time(),
    }
    flags = os.O_CREAT | os.O_EXCL | os.O_WRONLY
    try:
        fd = os.open(str(lock_path), flags)
    except FileExistsError:
        if _lock_is_held(lock_path, unreadable_grace_seconds):
            return False
        try:
            lock_path.unlink()
        except FileNotFoundError:
            pass
        try:
            fd = os.open(str(lock_path), flags)
        except FileExistsError:
            # Another process re-created the lock between our unlink and open.
            return False
    with os.fdopen(fd, "w") as handle:
        json.dump(payload, handle)
    return True


def release_lock(lock_path):
    """Delete the lock file at ``lock_path``; a missing file is silently ignored."""
    try:
        os.remove(lock_path)
    except FileNotFoundError:
        pass


def write_json_atomic(path, payload):
    """Atomically replace ``path`` with ``payload`` serialised as JSON.

    The data is first written to a hidden temporary file in the same
    directory as ``path``, flushed to disk, and then moved over ``path`` with
    ``os.replace``. That rename is atomic because source and destination share
    a filesystem, so concurrent readers see either the old or the new
    contents, never a truncated file. Staging in ``/tmp`` would not give this
    guarantee, since a cross-filesystem move degrades to copy + delete. If
    serialisation fails, the temporary file is removed and the existing
    ``path`` is left untouched.

    Args:
        path: destination file (``str`` or ``Path``); its directory must exist.
        payload: JSON-serialisable object, written with ``indent=2`` and
            sorted keys.

    Raises:
        TypeError / ValueError: if ``payload`` is not JSON-serialisable.
        OSError: on filesystem errors.
    """
    path = Path(path)
    tmp_path = path.with_name(f".{path.name}.{os.getpid()}.tmp")
    try:
        with open(tmp_path, "w") as handle:
            json.dump(payload, handle, indent=2, sort_keys=True)
            handle.flush()
            os.fsync(handle.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass
        raise


def run_training(
    label,
    source_paths,
    target_path,
    hyperparams,
    role_aggregators,
    seed,
    trainer_cfg,
    target_cgs,
    target_ctms,
    queries_by_target,
    template_funcs,
    agg_state_path=None,
    agg_tb_log_dir=None,
):
    """Train shared-template NCMs on sampled source data, then evaluate the target queries.

    Steps:

    1. Read the source and target ``.cg`` files.
    2. Sample ``trainer_cfg["n_samples"]`` observational rows per source graph
       (``do_var_list = [{}]``) via ``build_source_specs``.
    3. Build source/target NCMs that share one module per template
       (``build_shared_role_ncms``).
    4. Fit a ``RelationalIDPipeline`` with a ``QueryDeltaLogger`` callback.
    5. Evaluate every target's queries with ``evaluate_targets``.

    Args:
        label: run label used in printed messages and passed to
            ``evaluate_targets``.
        source_paths: paths of the source ``.cg`` files.
        target_path: path of the ``.cg`` file for the training-time target NCM.
        hyperparams: dict forwarded to the NCM builders and the pipeline.
        role_aggregators: role template key -> aggregator name.
        seed: seed passed to each source ``RelationalCTM``.
        trainer_cfg: reads ``n_samples``, ``max_epochs``, ``accelerator``,
            ``devices``, ``limit_train_batches`` and ``eval_n``, plus optional
            ``logger`` and ``log_every_epochs`` (default 10).
        target_cgs / target_ctms: target name -> graph / ground-truth CTM,
            used only for logging and final evaluation.
        queries_by_target: target name -> list of ``(var, do_vars)`` pairs.
        template_funcs: forwarded to ``build_source_specs`` / ``build_rctm``.
        agg_state_path, agg_tb_log_dir: forwarded to ``QueryDeltaLogger``.

    Returns:
        ``(duration, eval_results)``: seconds spent in ``trainer.fit`` and the
        dict returned by ``evaluate_targets``.
    """
    source_cgs = [read_cg_with_metadata(str(p)) for p in source_paths]
    target_cg = read_cg_with_metadata(str(target_path))

    # Observational data only: a single empty intervention.
    do_var_list = [{}]
    n_samples = trainer_cfg["n_samples"]
    source_specs = build_source_specs(
        source_cgs,
        role_aggregators,
        seed,
        n_samples,
        do_var_list,
        template_funcs=template_funcs,
    )
    # The target contributes only its graph (no sampled data) to the pipeline.
    target_spec = {"cg": target_cg}

    (
        source_ncms,
        target_ncm,
        shared_templates_mod,
        reps,
        cg_ref,
    ) = build_shared_role_ncms(
        source_cgs=source_cgs,
        target_cg=target_cg,
        hyperparams=hyperparams,
        role_aggregators=role_aggregators,
    )

    # Names of all role nodes ("R_" prefix) across the source and target graphs.
    role_vars = {v for cg in source_cgs + [target_cg] for v in cg if v.startswith("R_")}
    # NOTE(review): no query objective is passed (query=None); how query_sign=-1
    # is used in that case is defined by RelationalIDPipeline — confirm.
    pipeline = RelationalIDPipeline(
        source_specs=source_specs,
        target_spec=target_spec,
        hyperparams=hyperparams,
        source_ncms=source_ncms,
        target_ncm=target_ncm,
        query=None,
        query_sign=-1,
        role_vars=role_vars,
    )

    trainer = pl.Trainer(
        max_epochs=trainer_cfg["max_epochs"],
        log_every_n_steps=1,
        accelerator=trainer_cfg["accelerator"],
        devices=trainer_cfg["devices"],
        limit_train_batches=trainer_cfg["limit_train_batches"],
        logger=trainer_cfg.get("logger"),
        callbacks=[
            QueryDeltaLogger(
                target_cgs=target_cgs,
                target_ctms=target_ctms,
                queries_by_target=queries_by_target,
                hyperparams=hyperparams,
                role_aggregators=role_aggregators,
                shared_templates_mod=shared_templates_mod,
                reps=reps,
                cg_ref=cg_ref,
                log_every_epochs=trainer_cfg.get("log_every_epochs", 10),
                eval_n=trainer_cfg["eval_n"],
                agg_state_path=agg_state_path,
                agg_tb_log_dir=agg_tb_log_dir,
            )
        ],
    )

    print(f"\n=== Training {label} ===")
    start = time.perf_counter()
    trainer.fit(pipeline)
    duration = time.perf_counter() - start
    print(f"Done {label} in {duration:.2f}s")
    # NOTE(review): trainer_cfg["eval_n"] is not forwarded here, so the final
    # evaluation uses evaluate_targets' own sample count.
    eval_results = evaluate_targets(
        label,
        target_cgs,
        target_ctms,
        queries_by_target,
        shared_templates_mod=shared_templates_mod,
        hyperparams=hyperparams,
        role_aggregators=role_aggregators,
        reps=reps,
        cg_ref=cg_ref,
    )
    return duration, eval_results


if __name__ == "__main__":
    RANDOM_SEED = 7
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("--accelerator", default="gpu", help="Trainer accelerator")
    parser.add_argument("--devices", type=int, default=1, help="Number of devices to use")
    parser.add_argument("--gpu-id", type=int, default=None, help="GPU id to use (overrides --devices)")
    parser.add_argument("--max-epochs", type=int, default=200, help="Max training epochs")
    parser.add_argument("--limit-train-batches", type=float, default=1.0)
    parser.add_argument("--num-trials", type=int, default=1)
    parser.add_argument("--trial-idx", type=int, default=None, help="Run only this trial index (1-based).")
    parser.add_argument(
        "--run-target-only-baseline",
        action="store_true",
        help="Run target-only MLE baseline runs.",
    )
    parser.add_argument(
        "--baseline-only",
        action="store_true",
        help="Skip RNCM training and run only the target-only baseline.",
    )
    parser.add_argument(
        "--role-agg",
        choices=["strict_maj", "weak_maj", "or", "and", "min", "max", "sum", "mean", "count"],
        default="strict_maj",
        help="Aggregator for all role nodes (count encodes as bits; capped at 5).",
    )
    parser.add_argument(
        "--sweep-lr",
        action="store_true",
        help="Enable learning rate sweep using --lr-sweep values.",
    )
    parser.add_argument(
        "--lr-sweep",
        nargs="+",
        type=float,
        default=[4e-3, 1e-3, 3e-4, 1e-4],
        help="Learning rates to sweep.",
    )
    parser.add_argument(
        "--sources",
        nargs="+",
        default=["A", "B", "C", "A+B", "A+B+C"],
        help="Source combos to train (e.g., A B C A+B).",
    )
    args = parser.parse_args()
    if not args.sweep_lr:
        args.lr_sweep = [1e-3]

    cg_dir = Path(__file__).resolve().parent / "NCMCounterfactuals" / "dat" / "cg"
    cg_a = cg_dir / "less_simple_trafficA.cg"
    cg_b = cg_dir / "less_simple_trafficB.cg"
    cg_c = cg_dir / "less_simple_trafficC.cg"
    naive_cg_a = cg_dir / "naive_less_simple_trafficA.cg"
    naive_cg_b = cg_dir / "naive_less_simple_trafficB.cg"
    naive_cg_c = cg_dir / "naive_less_simple_trafficC.cg"

    role_aggregators = {
        "R_W_X": args.role_agg,
        "R_W_Y": args.role_agg,
        "R_X_B": args.role_agg,
        "R_Y_B": args.role_agg,
        "R_W_B": args.role_agg,
    }

    hyperparams = {
        "lr": 1e-3,
        "data-bs": 1000,
        "ncm-bs": 1000,
        "h-layers": 2,
        "h-size": 128,
        "u-size": 1,
        "max-query-iters": 300,
        "min-lambda": 0.001,
        "max-lambda": 1.0,
        "mc-sample-size": 5000,
        "full-batch": False,
    }

    devices = [args.gpu_id] if args.gpu_id is not None else args.devices
    trainer_cfg = {
        "max_epochs": args.max_epochs,
        "limit_train_batches": args.limit_train_batches,
        "accelerator": args.accelerator,
        "devices": devices,
        "n_samples": 10000,
        "eval_n": 10000,
    }

    if args.debug:
        trainer_cfg.update({"max_epochs": 10, "limit_train_batches": 2, "n_samples": 500})
        hyperparams["mc-sample-size"] = 100
        hyperparams["ncm-bs"] = 200
        hyperparams["fast-counts"] = True
        hyperparams["fast-counts-n"] = 200
        hyperparams["profile-counts"] = True
        hyperparams["profile-likelihood"] = True
        hyperparams["profile-every"] = 1

    target_cgs = {
        "A": read_cg_with_metadata(str(cg_a)),
        "B": read_cg_with_metadata(str(cg_b)),
        "C": read_cg_with_metadata(str(cg_c)),
    }
    naive_target_cgs = {
        "A": read_cg_with_metadata(str(naive_cg_a)),
        "B": read_cg_with_metadata(str(naive_cg_b)),
        "C": read_cg_with_metadata(str(naive_cg_c)),
    }
    queries_by_target = {
        "A": [
            ("C0_B", ["P0_X", "P1_X", "S0_W"]),
            ("C1_B", ["P0_X"]),
        ],
        "B": [
            ("C0_B", ["S0_W", "S1_W", "P0_X", "P1_X"]),
            ("C1_B", ["S1_W", "P1_X"]),
        ],
        "C": [
            ("C0_B", ["S0_W", "S1_W", "P0_X", "P1_X"]),
            ("C1_B", ["P1_X", "P2_X", "S1_W"]),
            ("C2_B", ["P2_X"]),
        ],
    }

    cg_map = {"A": cg_a, "B": cg_b, "C": cg_c}
    durations = {}
    if args.trial_idx is not None:
        trial_range = [args.trial_idx]
    else:
        trial_range = range(1, args.num_trials + 1)
    for trial_idx in trial_range:
        trial_seed = RANDOM_SEED + trial_idx - 1
        random.seed(trial_seed)
        np.random.seed(trial_seed)
        T.manual_seed(trial_seed)
        pl.seed_everything(trial_seed, workers=True)
        base_ctm_cg = next(iter(target_cgs.values()))
        base_ctm = build_rctm(base_ctm_cg, role_aggregators, trial_seed)
        shared_template_funcs = base_ctm.template_funcs
        target_ctms = {
            name: build_rctm(
                cg, role_aggregators, seed=None, template_funcs=shared_template_funcs
            )
            for name, cg in target_cgs.items()
        }
        do_var_list = [{}]
        target_obs_data = {
            name: [ctm(n=trainer_cfg["n_samples"], do=do_set) for do_set in do_var_list]
            for name, ctm in target_ctms.items()
        }
        if args.debug:
            p_func_ids = {
                name: id(ctm.template_funcs["P_X"]["func"])
                for name, ctm in target_ctms.items()
            }
            print(f"[debug] trial {trial_idx} P_X template func ids:", p_func_ids)
        for lr in args.lr_sweep:
            hyperparams["lr"] = lr
            for target_name, target_path in cg_map.items():
                if args.run_target_only_baseline:
                    baseline_label = f"naive_target_only_{target_name}"
                    run_id = build_run_id(
                        baseline_label,
                        [target_path],
                        target_path,
                        hyperparams,
                        trainer_cfg,
                        trial_idx,
                        trial_seed,
                        queries_by_target,
                        args.role_agg,
                    )
                    lock_path = lock_path_for(run_id)
                    results_dir = results_dir_for(baseline_label, trial_idx)
                    safe_label = baseline_label.replace("+", "_")
                    results_path = results_dir / f"results_source{safe_label}__{run_id}.json"
                    hyperparams_path = results_dir / f"hyperparams_source{safe_label}__{run_id}.json"
                    if results_path.exists():
                        print(
                            f"Skipping {baseline_label} trial {trial_idx} lr={lr}: "
                            f"results already exist at {results_path}"
                        )
                        continue
                    if not acquire_lock(lock_path):
                        print(
                            f"Skipping {baseline_label} trial {trial_idx} lr={lr}: "
                            f"lock held at {lock_path}"
                        )
                        continue
                    try:
                        safe_role_agg = str(args.role_agg).replace("+", "_")
                        tb_log_dir = (
                            results_dir / "tb" / f"role_{safe_role_agg}" / f"lr_{lr:.0e}"
                        )
                        logger = TensorBoardLogger(
                            save_dir=str(tb_log_dir),
                            name="",
                            version="",
                            default_hp_metric=False,
                        )
                        trainer_cfg["logger"] = logger
                        naive_cg = naive_target_cgs[target_name]
                        allowed_keys = set(naive_cg)
                        baseline_generator = ProjectedCTM(
                            target_ctms[target_name], allowed_keys
                        )
                        baseline_dat_sets = [
                            filter_dataset(dat, allowed_keys)
                            for dat in target_obs_data[target_name]
                        ]
                        try:
                            duration, eval_results = run_mle_baseline(
                                baseline_label,
                                target_name,
                                naive_cg,
                                baseline_generator,
                                target_ctms[target_name],
                                hyperparams,
                                trainer_cfg=trainer_cfg,
                                queries_by_target=queries_by_target,
                                dat_sets=baseline_dat_sets,
                                do_var_list=do_var_list,
                            )
                            durations[f"{baseline_label}:trial{trial_idx}:lr{lr}"] = duration
                            hp_subset = {
                                "lr": hyperparams.get("lr"),
                                "h-size": hyperparams.get("h-size"),
                                "h-layers": hyperparams.get("h-layers"),
                                "u-size": hyperparams.get("u-size"),
                                "mc-sample-size": hyperparams.get("mc-sample-size"),
                                "ncm-bs": hyperparams.get("ncm-bs"),
                                "data-bs": hyperparams.get("data-bs"),
                            }
                            payload = {
                                "run_id": run_id,
                                "combo": baseline_label,
                                "trial_idx": trial_idx,
                                "seed": trial_seed,
                                "sources": [str(target_path)],
                                "target": str(target_path),
                                "hyperparams": hp_subset,
                                "trainer_cfg": {
                                    "max_epochs": trainer_cfg.get("max_epochs"),
                                    "limit_train_batches": trainer_cfg.get("limit_train_batches"),
                                    "accelerator": trainer_cfg.get("accelerator"),
                                    "devices": trainer_cfg.get("devices"),
                                    "n_samples": trainer_cfg.get("n_samples"),
                                    "eval_n": trainer_cfg.get("eval_n"),
                                },
                                "queries_by_target": queries_by_target,
                                "duration_sec": duration,
                                "eval_results": eval_results,
                                "completed_at": time.time(),
                            }
                            write_json_atomic(results_path, payload)
                            write_json_atomic(
                                hyperparams_path,
                                {
                                    "hyperparams": hp_subset,
                                    "trainer_cfg": {
                                        "max_epochs": trainer_cfg.get("max_epochs"),
                                        "limit_train_batches": trainer_cfg.get("limit_train_batches"),
                                        "accelerator": trainer_cfg.get("accelerator"),
                                        "devices": trainer_cfg.get("devices"),
                                        "n_samples": trainer_cfg.get("n_samples"),
                                        "eval_n": trainer_cfg.get("eval_n"),
                                    },
                                },
                            )
                            print(f"Wrote baseline results to {results_path}")
                        except KeyboardInterrupt:
                            cleanup_paths = [
                                results_path,
                                hyperparams_path,
                                results_path.with_suffix(results_path.suffix + ".tmp"),
                                hyperparams_path.with_suffix(hyperparams_path.suffix + ".tmp"),
                            ]
                            for path in cleanup_paths:
                                try:
                                    path.unlink()
                                except FileNotFoundError:
                                    pass
                            shutil.rmtree(tb_log_dir, ignore_errors=True)
                            raise
                    finally:
                        release_lock(lock_path)
                if args.baseline_only:
                    continue
                for combo in args.sources:
                    parts = [p.strip() for p in combo.split("+") if p.strip()]
                    if not parts:
                        continue
                    try:
                        source_paths = [cg_map[p] for p in parts]
                    except KeyError as exc:
                        raise ValueError(
                            f"Unknown source '{exc.args[0]}' in combo '{combo}'."
                        ) from exc
                    if source_paths[0] != target_path:
                        continue
                    target_path = source_paths[0]
                    run_id = build_run_id(
                        combo,
                        source_paths,
                        target_path,
                        hyperparams,
                        trainer_cfg,
                        trial_idx,
                        trial_seed,
                        queries_by_target,
                        args.role_agg,
                    )
                    lock_path = lock_path_for(run_id)
                    results_dir = results_dir_for(combo, trial_idx)
                    safe_label = combo.replace("+", "_")
                    results_path = results_dir / f"results_source{safe_label}__{run_id}.json"
                    hyperparams_path = results_dir / f"hyperparams_source{safe_label}__{run_id}.json"
                    if results_path.exists():
                        print(
                            f"Skipping {combo} trial {trial_idx} lr={lr}: results already exist at {results_path}"
                        )
                        continue
                    if not acquire_lock(lock_path):
                        print(
                            f"Skipping {combo} trial {trial_idx} lr={lr}: lock held at {lock_path}"
                        )
                        continue
                    try:
                        safe_role_agg = str(args.role_agg).replace("+", "_")
                        tb_log_dir = (
                            results_dir / "tb" / f"role_{safe_role_agg}" / f"lr_{lr:.0e}"
                        )
                        tb_agg_dir = (
                            results_dir
                            / "tb_agg"
                            / f"role_{safe_role_agg}"
                            / f"lr_{lr:.0e}"
                        )
                        agg_state_path = tb_agg_dir / "mse_across_trials.json"
                        agg_tb_log_dir = tb_agg_dir / "events"
                        logger = TensorBoardLogger(
                            save_dir=str(tb_log_dir),
                            name="",
                            version="",
                            default_hp_metric=False,
                        )
                        trainer_cfg["logger"] = logger
                        try:
                            duration, eval_results = run_training(
                                combo,
                                source_paths,
                                target_path=target_path,
                                hyperparams=hyperparams,
                                role_aggregators=role_aggregators,
                                seed=trial_seed,
                                trainer_cfg=trainer_cfg,
                                target_cgs=target_cgs,
                                target_ctms=target_ctms,
                                queries_by_target=queries_by_target,
                                template_funcs=shared_template_funcs,
                                agg_state_path=agg_state_path,
                                agg_tb_log_dir=agg_tb_log_dir,
                            )
                            durations[f"{combo}:trial{trial_idx}:lr{lr}"] = duration
                            hp_subset = {
                                "lr": hyperparams.get("lr"),
                                "h-size": hyperparams.get("h-size"),
                                "h-layers": hyperparams.get("h-layers"),
                                "u-size": hyperparams.get("u-size"),
                                "mc-sample-size": hyperparams.get("mc-sample-size"),
                                "ncm-bs": hyperparams.get("ncm-bs"),
                                "data-bs": hyperparams.get("data-bs"),
                            }
                            payload = {
                                "run_id": run_id,
                                "combo": combo,
                                "trial_idx": trial_idx,
                                "seed": trial_seed,
                                "sources": [str(p) for p in source_paths],
                                "target": str(target_path),
                                "hyperparams": hp_subset,
                                "trainer_cfg": {
                                    "max_epochs": trainer_cfg.get("max_epochs"),
                                    "limit_train_batches": trainer_cfg.get(
                                        "limit_train_batches"
                                    ),
                                    "accelerator": trainer_cfg.get("accelerator"),
                                    "devices": trainer_cfg.get("devices"),
                                    "n_samples": trainer_cfg.get("n_samples"),
                                    "eval_n": trainer_cfg.get("eval_n"),
                                },
                                "queries_by_target": queries_by_target,
                                "duration_sec": duration,
                                "eval_results": eval_results,
                                "completed_at": time.time(),
                            }
                            write_json_atomic(results_path, payload)
                            write_json_atomic(
                                hyperparams_path,
                                {
                                    "hyperparams": hp_subset,
                                    "trainer_cfg": {
                                        "max_epochs": trainer_cfg.get("max_epochs"),
                                        "limit_train_batches": trainer_cfg.get(
                                            "limit_train_batches"
                                        ),
                                        "accelerator": trainer_cfg.get("accelerator"),
                                        "devices": trainer_cfg.get("devices"),
                                        "n_samples": trainer_cfg.get("n_samples"),
                                        "eval_n": trainer_cfg.get("eval_n"),
                                    },
                                },
                            )
                            print(f"Wrote results to {results_path}")
                        except KeyboardInterrupt:
                            cleanup_paths = [
                                results_path,
                                hyperparams_path,
                                results_path.with_suffix(results_path.suffix + ".tmp"),
                                hyperparams_path.with_suffix(hyperparams_path.suffix + ".tmp"),
                            ]
                            for path in cleanup_paths:
                                try:
                                    path.unlink()
                                except FileNotFoundError:
                                    pass
                            shutil.rmtree(tb_log_dir, ignore_errors=True)
                            raise
                    finally:
                        release_lock(lock_path)

    print("\n=== Timing Summary ===")
    for key, dur in durations.items():
        print(f"{key}: {dur:.2f}s")
