import os
import shutil
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Dict, Any, Tuple
import pandas as pd
import torch
import random
import pickle
import time

from circuit_utils import MlpCircuitStats, vision_circuit_stats, visualize_pruned_vision_model, \
    visualize_mlp_with_active_neurons, print_logits_for_adv_x
from config import config, env_info
from dataset_utils import (
    load_dataset, load_model, load_dataset_and_model, get_mean_activations_mnist,
    get_mean_activations_cifar10, get_mean_activations_gtsrb, get_mean_activations_taxinet, prepare_batch
)
from formal_pruning import formal_prune_mnist, \
    formal_prune_vision, find_adversarial_example, find_formal_adv_example, formal_patch_query
from formal_pruning import qmsc_formal_patching_mnist
from ab_crown_utils import load_x_from_file
from informal_pruning import informal_mnist_pruning, informal_vision_pruning, informal_mnist_kl_div_pruning
from models.model_factory import compare_model_predictions, load_mnist_model, load_cifar10_model, create_mask_from_comps, \
    load_gtsrb_model, load_taxinet_model
from contrastives import check_all_contrastive_subsets, verify_mhs_sufficiency, get_neurons_by_names, \
    get_contrastives_mhs
import glob


def create_experiment_dirs(base_dir, dataset, exp_name, exp_type, parameters):
    """Create the output directory tree for one experiment run and return its root.

    The run directory is ``<base_dir>/<dataset>/<exp_type>/<dir_name>`` and
    always gets ``plots``, ``logs`` and ``models`` sub-directories.
    ``dir_name`` starts with ``exp_name`` followed by whichever naming
    parameters are set, joined by underscores.

    Args:
        base_dir: Root directory for all experiment outputs.
        dataset: Dataset name (first path component under ``base_dir``).
        exp_name: Experiment name (first token of ``dir_name``).
        exp_type: Experiment type (second path component).
        parameters: Mapping of experiment parameters. Keys read here, all
            optional: ``prune_by``, ``patching``, ``metric``, ``epsilon``,
            ``delta``, ``frac``, ``patch_eps``, ``tau`` and ``run_id``.

    Returns:
        Path of the experiment directory that was created (or already existed).
    """

    def _without_underscores(value):
        # Underscores separate tokens in the directory name, so strip them
        # from free-text values; empty/missing values are dropped entirely.
        return value.replace("_", "") if value else None

    def _tagged(prefix, key):
        # Numeric parameters are kept even when they are 0, hence `is not None`.
        value = parameters.get(key)
        return f"{prefix}{value}" if value is not None else None

    patching = parameters.get('patching')
    param_str = "_".join(filter(None, [
        exp_name,
        _without_underscores(parameters.get('prune_by')),
        f"{patching}p" if patching else None,
        _without_underscores(parameters.get('metric')),
        _tagged('e', 'epsilon'),
        _tagged('d', 'delta'),
        _tagged('frac', 'frac'),
        _tagged('pe', 'patch_eps'),
        _tagged('tau', 'tau'),
    ]))

    run_id = parameters.get('run_id')
    if run_id is None or run_id == "":
        # For single runs, use timestamp and random id for uniqueness
        rand_id = random.randint(100000, 999999)
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        dir_name = f"{param_str}_{timestamp}_id{rand_id}"
    else:
        # For parallel runs, use the run_id for a deterministic path.
        # A run_id of 0 is valid and must take this branch too.
        dir_name = f"{param_str}_run{run_id}"

    experiment_dir = os.path.join(base_dir, dataset, exp_type, dir_name)
    for sub_dir in ('plots', 'logs', 'models'):
        os.makedirs(os.path.join(experiment_dir, sub_dir), exist_ok=True)
    return experiment_dir


def create_exp_paths(models_dir, logs_dir, plots_dir, job_id=None):
    """Build the dictionary of file and directory paths used by an experiment.

    When ``job_id`` is given, worker-specific ``worker_<job_id>``
    sub-directories are created under each of the three base directories and
    all per-worker artefacts point into them; otherwise the base directories
    themselves are used.

    Args:
        models_dir: Base directory for model files.
        logs_dir: Base directory for log files.
        plots_dir: Base directory for plots.
        job_id: Optional worker identifier; ``0`` is a valid id.

    Returns:
        Dict mapping path names to paths. The ``base_*`` entries and
        ``contrastive_results_dir`` always refer to the shared base
        directories, never to the worker sub-directories.
    """
    # If a job_id is provided, create worker-specific subdirectories.
    if job_id is not None:
        worker_tag = f'worker_{job_id}'
        worker_models_dir = os.path.join(models_dir, worker_tag)
        worker_logs_dir = os.path.join(logs_dir, worker_tag)
        worker_plots_dir = os.path.join(plots_dir, worker_tag)
        for directory in (worker_models_dir, worker_logs_dir, worker_plots_dir):
            os.makedirs(directory, exist_ok=True)
    else:
        worker_models_dir = models_dir
        worker_logs_dir = logs_dir
        worker_plots_dir = plots_dir

    def model_file(file_name):
        return os.path.join(worker_models_dir, file_name)

    exp_paths = {
        # Shared paths (used by aggregator and for final results)
        'base_models_dir': models_dir,
        'base_logs_dir': logs_dir,
        'base_plots_dir': plots_dir,
        'contrastive_results_dir': os.path.join(models_dir, 'contrastive_results'),

        # Worker-specific or general paths
        'path_to_save_mnist_dupnet': model_file('FORMAL_DUPNET_3_layered_mnist_10K_976.pth'),
        'path_to_save_cifar10-small_dupnet': model_file('FORMAL_DUPNET_cifar10_small_model.pth'),
        'path_to_save_cifar10-big_dupnet': model_file('FORMAL_DUPNET_cifar10_big_model.pth'),
        'path_to_save_taxinet_dupnet': model_file('FORMAL_DUPNET_taxinet_model.pth'),
        'path_to_save_gtsrb_dupnet': model_file('FORMAL_DUPNET_gtsrb_model.pth'),
        'path_to_save_mnist_tripnet': model_file('FORMAL_TRIPNET_3_layered_mnist_10K_976.pth'),
        'abcrown_specification_path': os.path.join(worker_logs_dir, 'abcrown_spec.yaml'),
        'adv_x_path': os.path.join(worker_plots_dir, 'adv_x.txt'),
        'mnist_sample_path': model_file('mnist_sample.npy'),
        'cifar10_sample_path': model_file('cifar10_sample.npy'),
        'taxinet_sample_path': model_file('taxinet_sample.npy'),
        'gtsrb_sample_path': model_file('gtsrb_sample.npy'),
        'pruned_net_path': model_file('pruned_net.pth'),
        'saved_dup_patch_net_path_mnist': model_file('dup_patch_net_mnist.pth'),
        'saved_dup_patch_net_path_cifar10-small': model_file('dup_patch_net_cifar10-small.pth'),
        # Consumers build this key as f'saved_dup_patch_net_path_{dataset}',
        # so every supported dataset name needs an entry, cifar10-big included.
        'saved_dup_patch_net_path_cifar10-big': model_file('dup_patch_net_cifar10-big.pth'),
        'saved_dup_patch_net_path_taxinet': model_file('dup_patch_net_taxinet.pth'),
        'saved_dup_patch_net_path_gtsrb': model_file('dup_patch_net_gtsrb.pth'),
        'main_log_file': os.path.join(worker_logs_dir, 'main.log'),
        'Z_mask_file': model_file('Z_mask_file'),
        # Resolved relative to this source file, not to the experiment directory.
        'customized_models_paths': os.path.abspath(os.path.join(os.path.dirname(__file__), 'models', 'model.py')),
    }
    return exp_paths

## todo fix patching

# Per-dataset informal-pruning toolchain, stored as a 4-tuple:
# (pruning function, visualizer, stats function, mean-activation function).
informal_methods = {
    "mnist": (
        informal_mnist_pruning,
        visualize_mlp_with_active_neurons,
        MlpCircuitStats.print_stats,
        get_mean_activations_mnist,
    ),
}
# The vision datasets share the pruning, visualization and stats functions
# and differ only in the mean-activation function.
informal_methods.update({
    dataset_name: (informal_vision_pruning, visualize_pruned_vision_model, vision_circuit_stats, mean_fn)
    for dataset_name, mean_fn in (
        ("cifar10-big", get_mean_activations_cifar10),
        ("cifar10-small", get_mean_activations_cifar10),
        ("gtsrb", get_mean_activations_gtsrb),
        ("taxinet", get_mean_activations_taxinet),
    )
})


class ExperimentRunner(ABC):
    """Base class for experiment runners.

    Provides the shared plumbing: creating the output directory layout,
    writing a per-experiment config/timing file, logging, and saving a model's
    state dict. Subclasses implement :meth:`run`.
    """

    def __init__(self, device: torch.device, logger=None):
        """Store the device and optional logger; outputs go under ``config['paths']['outputs_root']``."""
        self.device = device
        self.base_output_dir = config['paths']['outputs_root']
        self.logger = logger
        self.name = None
        self.dataset = None
        self.exp_type = None
        self.parameters = None

    def setup_experiment(self, name: str, dataset: str, exp_type: str,
                         parameters: Dict[str, Any]) -> Tuple[str, str, str, str]:
        """Create the experiment directories and path table, and start the config file.

        Sets ``experiment_dir``, ``plots_dir``, ``logs_dir``, ``models_dir``,
        ``exp_paths`` and ``conf_file`` on the instance. ``logs/conf.txt`` is
        (re)written with the experiment name, dataset, type and parameters.

        Returns:
            ``(experiment_dir, plots_dir, logs_dir, models_dir)``.
        """
        self.name = name
        self.dataset = dataset
        self.exp_type = exp_type
        self.parameters = parameters
        self.experiment_dir = create_experiment_dirs(self.base_output_dir, self.dataset, self.name, self.exp_type, self.parameters)
        self.plots_dir = os.path.join(self.experiment_dir, 'plots')
        self.logs_dir = os.path.join(self.experiment_dir, 'logs')
        self.models_dir = os.path.join(self.experiment_dir, 'models')
        job_id = self.parameters.get('job_id')
        self.exp_paths = create_exp_paths(self.models_dir, self.logs_dir, self.plots_dir, job_id=job_id)
        self.conf_file = os.path.join(self.logs_dir, 'conf.txt')
        with open(self.conf_file, 'w') as log_file:
            log_file.write(f"Experiment Name: {self.name}\n")
            log_file.write(f"Dataset: {self.dataset}\n")
            log_file.write(f"Type: {self.exp_type}\n")
            log_file.write(f"Parameters: {self.parameters}\n")

        return self.experiment_dir, self.plots_dir, self.logs_dir, self.models_dir

    def log(self, message: str, level=logging.INFO):
        """Log ``message`` at ``level`` through ``self.logger``; a no-op when no logger is set."""
        if self.logger:
            # Dispatch on the numeric level so that every level (including
            # CRITICAL) is emitted at the severity the caller asked for.
            self.logger.log(level, message)

    def log_experiment_start(self, conf_file):
        """Record the start time on the instance and append it, with the env info, to ``conf_file``."""
        start_time = datetime.now()
        self.start_time = start_time  # store as datetime object
        env_line = f"env info: {env_info}"
        if not env_line.endswith("\n"):
            # Keep the env info and the start time on separate lines.
            env_line += "\n"
        with open(conf_file, 'a') as log_file:
            log_file.write(env_line)
            log_file.write(f"Start Time: {start_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]}  (ms precision)\n")

        def _or_blank(value):
            return value if value is not None else ''

        self.log(f"experiment started: {_or_blank(self.name)}, dataset: {_or_blank(self.dataset)}, type: {_or_blank(self.exp_type)}, parameters: {_or_blank(self.parameters)}")

    def log_experiment_end(self, conf_file):
        """Append the end time and elapsed time to ``conf_file``.

        Requires :meth:`log_experiment_start` to have been called first, since
        the elapsed time is measured from ``self.start_time``.

        Returns:
            Elapsed time in seconds.
        """
        end_time = datetime.now()
        self.end_time = end_time  # store as datetime object
        elapsed_sec = (self.end_time - self.start_time).total_seconds()
        elapsed_ms = elapsed_sec * 1000
        elapsed_min = elapsed_sec / 60
        with open(conf_file, 'a') as log_file:
            log_file.write(f"End Time: {end_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]}  (ms precision)\n")
            log_file.write(f"Elapsed Time: {elapsed_ms:.2f} ms  ({elapsed_min:.2f} min)\n")
        self.log(f"experiment ended: {self.name if self.name is not None else ''}, elapsed: {elapsed_sec:.2f}s")
        return elapsed_sec

    @abstractmethod
    def run(self, name: str, exp_type: str, dataset: str, parameters: Dict[str, Any]):
        """Run the experiment; implemented by each concrete runner."""
        pass

    def set_up_logger(self):
        """Point the root logger at this experiment's ``main.log`` and the console.

        Replaces ``self.logger`` with the root logger and discards any handlers
        that were previously attached to it. Requires ``self.exp_paths``.
        """
        self.logger = logging.getLogger()
        # remove all handlers associated with the root logger
        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)
            # Close as well as detach, otherwise a previous FileHandler keeps
            # its log file open for the lifetime of the process.
            handler.close()

        # setup logging to main.log
        logging.basicConfig(
            level=logging.INFO,
            format='[%(asctime)s][%(levelname)s] %(message)s',
            handlers=[
                logging.FileHandler(self.exp_paths['main_log_file'], mode='a'),
                logging.StreamHandler()
            ]
        )

    def save_model(self, result, save_path: str):
        """Save ``result.state_dict()`` to ``save_path``; does nothing when ``result`` is None."""
        if result is not None:
            torch.save(result.state_dict(), save_path)
            logging.info(f"model saved to: {save_path}")


class FormalPruningRunner(ExperimentRunner):
    """Runs formal pruning on a model and stores the pruned net, plot and stats."""

    SUPPORTED_DATASETS = ("mnist", "cifar10-big", "cifar10-small", "gtsrb", "taxinet")

    def run(self, name, exp_type, dataset, parameters):
        """Load the dataset and model, then run formal pruning.

        Expects ``setup_experiment`` to have been called, since it reads
        ``exp_paths``, ``conf_file``, ``models_dir`` and ``plots_dir``.

        Returns:
            ``(pruned_net, components, elapsed_seconds, timeouts)``.

        Raises:
            ValueError: If ``dataset`` is not one of ``SUPPORTED_DATASETS``.
        """
        self.set_up_logger()
        if dataset not in self.SUPPORTED_DATASETS:
            raise ValueError(
                f"Unsupported dataset: {dataset}. "
                f"Supported datasets are: {', '.join(self.SUPPORTED_DATASETS)}"
            )

        sample_ids = parameters.get('sample_ids', None)
        ## TODO - load data once for efficiency
        full_net, model_path, train_gen, _, _, _, X, _ = load_dataset_and_model(
            dataset, self.device, self.exp_paths, sample_ids=sample_ids)
        conf_file = self.conf_file
        save_model_path = os.path.join(self.models_dir, f"{exp_type}_model.pth")
        plot_path = os.path.join(self.plots_dir, f"{name}_pruned_net.png")
        stats_path = os.path.join(self.plots_dir, f"{name}_stats.txt")

        self.log(f"running formal pruning: {name}, dataset: {dataset}, parameters: {parameters}")
        return self.run_formal_pruning(dataset, full_net, model_path, parameters, train_gen, X, self.exp_paths,
                                       conf_file, save_model_path, plot_path, stats_path, sample_ids)

    def run_formal_pruning(self, dataset, full_net, full_net_path, parameters, train_gen, X, exp_paths, conf_file,
                           save_model_path, plot_path=None, stats_path=None, sample_ids=None):
        """Prune ``full_net`` formally, time the run, and save the model, plot and stats.

        Args:
            dataset: One of ``SUPPORTED_DATASETS``.
            full_net: The unpruned model.
            full_net_path: Path of the unpruned model, forwarded to the pruning function.
            parameters: Pruning parameters, forwarded as keyword arguments to
                the pruning function. Keys read here: ``patching``,
                ``use_quasi`` and ``prune_by``. The caller's dict is not modified.
            train_gen: Training data source, used only for ``patching == "mean"``.
            X: Input sample(s) the circuit is computed for.
            exp_paths: Path table from ``create_exp_paths``.
            conf_file: Config/timing file to append start and end times to.
            save_model_path: Where to save the pruned model's state dict.
            plot_path: Optional plot output path; skipped when None.
            stats_path: Optional stats output path; skipped when None.
            sample_ids: Optional sample ids, forwarded to the stats function.

        Returns:
            ``(pruned_net, components, elapsed_seconds, timeouts)``.

        Raises:
            ValueError: If ``dataset`` is not supported.
        """
        patching = parameters.get('patching')

        if dataset == "mnist":
            experiment_function = formal_prune_mnist
            visualize_function = visualize_mlp_with_active_neurons
            stats_function = MlpCircuitStats.print_stats
        elif dataset in self.SUPPORTED_DATASETS:
            experiment_function = formal_prune_vision
            visualize_function = visualize_pruned_vision_model
            stats_function = vision_circuit_stats
        else:
            # Without this, the names above would be unbound further down.
            raise ValueError(
                f"Unsupported dataset: {dataset}. "
                f"Supported datasets are: {', '.join(self.SUPPORTED_DATASETS)}"
            )

        if patching == "mean":
            mean_activation_funcs = {
                "mnist": get_mean_activations_mnist,
                "cifar10-big": get_mean_activations_cifar10,
                "cifar10-small": get_mean_activations_cifar10,
                "taxinet": get_mean_activations_taxinet,
                "gtsrb": get_mean_activations_gtsrb,
            }
            mean_activations, weight_contrib = mean_activation_funcs[dataset](
                full_net, train_gen, self.device, num_samples=100)
            # Extend a copy so the caller's dict is not polluted with the
            # activation statistics.
            parameters = {**parameters, "data_dist": mean_activations, "weight_contrib": weight_contrib}

        self.log_experiment_start(conf_file)
        use_quasi = parameters.get('use_quasi')
        self.log(f"starting formal pruning function: {'qmsc_formal_patching_mnist' if use_quasi else experiment_function.__name__}")
        if use_quasi:
            pruned_net, components, qsmc_info = qmsc_formal_patching_mnist(
                dataset=dataset,
                full_net=full_net,
                X=X,
                full_net_path=full_net_path,
                device=self.device,
                exp_paths=exp_paths,
                **parameters
            )
            timeouts = []  # qmsc path does not return timeouts
            # Attach quasi-monte-carlo search meta info into components for downstream logging.
            if isinstance(components, dict):
                components['qsmc_info'] = qsmc_info
            self.log(f"qmsc summary: evaluations={qsmc_info['evaluations']}, yes_k={qsmc_info['yes_k']}, breaking_neuron={qsmc_info['breaking_neuron']}")
        else:
            pruned_net, components, timeouts = experiment_function(
                dataset=dataset,
                full_net=full_net,
                X=X,
                full_net_path=full_net_path,
                device=self.device,
                exp_paths=exp_paths,
                **parameters
            )
        self.log(f"formal pruning done. pruned_net: {type(pruned_net)}, components: {components}, timeouts: {timeouts}")
        # Stop the clock before saving/plotting so only pruning time is measured.
        elapsed_s = self.log_experiment_end(conf_file)
        self.save_model(pruned_net, save_model_path)
        if plot_path is not None:
            visualize_function(model=pruned_net, plot_path=plot_path, components=components, prune_by=parameters.get("prune_by"))
        if stats_path is not None:
            stats_function(dataset, full_net, pruned_net, X, self.device, save_to_path=stats_path, components=components, timeouts=timeouts, sample_ids=sample_ids)
        return pruned_net, components, elapsed_s, timeouts


class InformalPruningRunner(ExperimentRunner):
    """Runs informal pruning on a model and stores the pruned net, plot and stats."""

    SUPPORTED_DATASETS = ("mnist", "cifar10-big", "cifar10-small", "gtsrb", "taxinet")

    def run(self, name, exp_type, dataset, parameters):
        """Load the dataset and model, then run informal pruning.

        Expects ``setup_experiment`` to have been called, since it reads
        ``exp_paths``, ``conf_file``, ``models_dir`` and ``plots_dir``.

        Returns:
            ``(pruned_net, components, elapsed_seconds)``.

        Raises:
            ValueError: If ``dataset`` is not one of ``SUPPORTED_DATASETS``.
        """
        self.set_up_logger()
        if dataset not in self.SUPPORTED_DATASETS:
            raise ValueError(
                f"Unsupported dataset: {dataset}. "
                f"Must be one of: {', '.join(self.SUPPORTED_DATASETS)}."
            )

        sample_ids = parameters.get('sample_ids', None)

        full_net, _, train_gen, _, _, _, X, _ = load_dataset_and_model(
            dataset, self.device, self.exp_paths, sample_ids=sample_ids)
        plot_path = os.path.join(self.plots_dir, f"{name}_pruned_net.png")
        stats_path = os.path.join(self.plots_dir, f"{name}_stats.txt")
        save_model_path = os.path.join(self.models_dir, f"{name}_model.pth")
        self.log(f"running informal pruning: {name}, dataset: {dataset}, parameters: {parameters}")
        return self.run_informal_pruning(dataset, full_net, parameters, train_gen, X, self.conf_file,
                                         save_model_path, plot_path, stats_path, sample_ids)

    def run_informal_pruning(self, dataset, full_net, parameters, train_gen, X, conf_file, save_model_path,
                             plot_path=None, stats_path=None, sample_ids=None):
        """Prune ``full_net`` informally, time the run, and save the model, plot and stats.

        Args:
            dataset: Key into ``informal_methods``; an unknown name raises KeyError.
            full_net: The unpruned model.
            parameters: Pruning parameters, forwarded as keyword arguments to
                the pruning function. Keys read here: ``metric`` and
                ``patching``. The caller's dict is not modified.
            train_gen: Training data source, used only for ``patching == "mean"``.
            X: Input sample(s) the circuit is computed for.
            conf_file: Config/timing file to append start and end times to.
            save_model_path: Where to save the pruned model's state dict.
            plot_path: Optional plot output path; skipped when None.
            stats_path: Optional stats output path; skipped when None.
            sample_ids: Optional sample ids, forwarded to the stats function.

        Returns:
            ``(pruned_net, components, elapsed_seconds)``.
        """
        experiment_function, visualize_function, stats_function, get_mean_activations_func = informal_methods[dataset]
        # MNIST has a dedicated pruning routine for the KL-divergence metric.
        if dataset == 'mnist' and parameters.get('metric') == 'kl_div':
            experiment_function = informal_mnist_kl_div_pruning
        if parameters.get("patching") == "mean":
            mean_activations, weight_contrib = get_mean_activations_func(
                full_net, train_gen, self.device, num_samples=100)
            # Extend a copy so the caller's dict is not polluted with the
            # activation statistics.
            parameters = {**parameters, "data_dist": mean_activations, "weight_contrib": weight_contrib}
        self.log_experiment_start(conf_file)
        self.log(f"starting informal pruning function: {experiment_function.__name__}")
        pruned_net, components = experiment_function(dataset=dataset, full_net=full_net, X=X, device=self.device, **parameters)
        self.log(f"informal pruning done. pruned_net: {type(pruned_net)}, components: {components}")
        # Stop the clock before saving/plotting so only pruning time is measured.
        elapsed_s = self.log_experiment_end(conf_file)
        self.save_model(pruned_net, save_model_path)
        if plot_path is not None:
            visualize_function(model=pruned_net, plot_path=plot_path, components=components)
        if stats_path is not None:
            stats_function(dataset, full_net, pruned_net, X, self.device, save_to_path=stats_path, components=components, sample_ids=sample_ids)
        return pruned_net, components, elapsed_s

### todo modify adversarial to work with batch
class AdversarialExample(ExperimentRunner):
    """Searches for an adversarial example and compares two pruned networks on it."""

    SUPPORTED_DATASETS = ("mnist", "cifar10-big", "cifar10-small", "gtsrb")

    def run(self, name, exp_type, dataset, parameters):
        """Find an adversarial example and compare the formal and informal pruned nets on it.

        ``parameters`` must contain ``informal_pruned_net_path`` and
        ``formal_pruned_net_path``; all of ``parameters`` is also forwarded as
        keyword arguments to ``find_adversarial_example``. Both pruned-network
        files are copied into this experiment's models directory.

        Raises:
            ValueError: If ``dataset`` is unsupported, or if no adversarial
                example was found.
            KeyError: If a required path parameter is missing.
        """
        self.set_up_logger()
        if dataset not in self.SUPPORTED_DATASETS:
            raise ValueError(
                f"Unsupported dataset: {dataset}. "
                f"Must be one of: {', '.join(self.SUPPORTED_DATASETS)}."
            )

        # Read the required paths up front so a missing key fails before the
        # search rather than after it.
        informal_net_path = parameters['informal_pruned_net_path']
        formal_net_path = parameters['formal_pruned_net_path']

        _, full_net_path, _, _, _, _, x, _ = load_dataset_and_model(dataset, self.device, self.exp_paths)

        self.log_experiment_start(self.conf_file)
        adv_x = find_adversarial_example(dataset=dataset, x=x, full_net_path=full_net_path, device=self.device,
                                         adv_x_path=self.exp_paths['adv_x_path'], **parameters)
        self.log_experiment_end(self.conf_file)

        # copying models to experiment's output/models dir
        shutil.copy(informal_net_path,
                    os.path.join(self.models_dir, f"{exp_type}_informal_pruned_net.pth"))
        shutil.copy(formal_net_path,
                    os.path.join(self.models_dir, f"{exp_type}_formal_pruned_net.pth"))
        if adv_x is None:
            raise ValueError("Adversarial example not found")

        compare_model_predictions(formal_net_path, informal_net_path, adv_x,
                                  self.device, dataset, save_plot_path=os.path.join(self.plots_dir, "adv_x.png"),
                                  save_stats_path=os.path.join(self.plots_dir, "adv_x_logits.txt"))



class CollectPatching(ExperimentRunner):
    """An experiment runner that collects and evaluates circuits with different patching methods: zero, mean, and formal.

    For every batch and every (delta, patch_eps) combination, the model is
    pruned once per patching method, the resulting circuit is checked with
    ``formal_patch_query``, and one summary row is appended to a results CSV.
    """

    # Patching methods handled by InformalPruningRunner; 'formal' is handled separately.
    _INFORMAL_PATCH_METHODS = ('zero', 'mean')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.records = []
        self.csv_path = None

    def _make_suffixes(self, dataset: str, batch: list, method: str, tag: str = None): ## TODO merge with _make_suffixes in collect_circuits
        """Return the (model, plot, stats, adv_x) file names for one batch/method combination."""
        batch_id = "_".join(str(sid) for sid in batch)
        base = f"{dataset}_batch_{batch_id}_{method}"
        if tag:
            base += f"_{tag}"
        return (f"{base}.pth", f"{base}_pruned_net.png", f"{base}_stats.txt", f'{base}_adv_x.txt')

    def record(self, row):
        """Keep ``row`` in ``self.records`` and append it to the results CSV.

        The CSV header is written only when the file does not exist yet, so
        every row passed in must carry the same keys in the same order.
        """
        # lazy init CSV path
        if self.csv_path is None:
            exp_dir_name = os.path.basename(self.experiment_dir.rstrip("/"))
            self.csv_path = os.path.join(self.plots_dir, f"{self.current_dataset}_{exp_dir_name}_results.csv")
        self.records.append(row)
        pd.DataFrame([row]).to_csv(self.csv_path, mode='a', header=not os.path.exists(self.csv_path), index=False)

    def run(self, name, exp_type, dataset, parameters):
        """Run every patching method over every batch and parameter combination.

        Keys read from ``parameters``: ``batches`` and ``metric`` and
        ``patch_eps`` (required), ``patch_methods``, ``epsilon``,
        ``verify_patching_only``, ``frac`` (required for the
        ``winner_runner`` metric) and ``delta`` (used for all other metrics).
        ``patch_eps``, ``delta`` and ``frac`` may each be a single value or a list.

        Raises:
            ValueError: If the metric is ``winner_runner`` and ``frac`` is missing.
        """
        self.set_up_logger()
        self.current_dataset = dataset
        self.current_exp_type = exp_type

        batches = parameters["batches"]
        train_gen, test_data, _ = load_dataset(dataset)
        model, model_path = load_model(dataset, self.device)
        metric = parameters['metric']
        patch_methods = parameters.get('patch_methods', ['zero', 'mean', 'formal'])
        patch_epsilons = parameters['patch_eps']
        patch_epsilons = patch_epsilons if isinstance(patch_epsilons, list) else [patch_epsilons]
        epsilon = parameters.get('epsilon')
        verify_patching_only = bool(parameters.get("verify_patching_only", True))
        self.log(f"verify_patching_only: {verify_patching_only}, patch_epsilons: {patch_epsilons}, epsilon: {epsilon}, metric: {metric}")

        fracs = parameters.get('frac')
        if isinstance(fracs, (int, float)):
            # Accept a single fraction the same way a single patch_eps/delta is accepted.
            fracs = [fracs]
        if metric == 'winner_runner' and fracs is None:
            raise ValueError("parameter 'frac' is required when metric is 'winner_runner'")

        for batch in batches:
            x, wrld = prepare_batch(dataset, model, test_data, self.exp_paths, self.device, batch)
            if metric == 'winner_runner':
                for frac in fracs:
                    deltas = torch.floor(wrld * frac * 100) / 100  # compute batched (winner-runner)-delta for each sample
                    deltas = deltas.to(self.device)  # ensure delta is on the same device
                    for patch_eps in patch_epsilons:
                        run_params = dict(parameters, delta=deltas, frac=frac, epsilon=epsilon, patch_eps=patch_eps, verify_patching_only=verify_patching_only)
                        tag = f"delta{frac:.2f}_patchEps{patch_eps}"
                        self._process_batch(dataset, model, model_path, train_gen, x, run_params, batch, patch_methods, tag)
            else:
                deltas = parameters.get('delta')
                deltas = deltas if isinstance(deltas, list) else [deltas]
                for delta in deltas:
                    for patch_eps in patch_epsilons:
                        run_params = dict(parameters, delta=delta, epsilon=epsilon, patch_eps=patch_eps, verify_patching_only=verify_patching_only)
                        tag = f"delta{delta}_patchEps{patch_eps}"
                        self._process_batch(dataset, model, model_path, train_gen, x, run_params, batch, patch_methods, tag)
                        print("#" * 80)

    def _process_batch(self, dataset, model, model_path, train_gen, x, run_params, batch, patch_methods, tag):
        """Prune with each patching method for one batch/parameter combination and record a CSV row.

        A failure of the formal method is logged and that method's columns are
        left empty; the remaining methods still run. Every row has the same
        columns for a given ``patch_methods``, whether or not a method
        succeeded, so appended CSV rows stay aligned with the header.
        """
        formal_patching_net_touts_comps = None
        results = {}
        for method in patch_methods:
            save_suffix, plot_suffix, stats_suffix, adv_x_suffix = self._make_suffixes(dataset, batch, f"{tag}_{method}") # Append method name to tag for uniqueness
            if method == 'formal':
                try:
                    pruned_net, comps, elapsed, formal_patching_net_touts_comps = FormalPruningRunner(self.device, self.logger).run_formal_pruning(
                        dataset, model, model_path, {**run_params, 'patching': 'formal'},
                        train_gen, x, self.exp_paths, self.conf_file,
                        os.path.join(self.models_dir, save_suffix),
                        os.path.join(self.plots_dir, plot_suffix),
                        os.path.join(self.plots_dir, stats_suffix),
                        sample_ids=batch
                    )
                    status = 'success'
                except Exception as e:
                    # logging.exception keeps the traceback next to the message.
                    logging.exception(f"Formal pruning failed for batch {batch} with method '{method}': {e}")
                    pruned_net, comps, elapsed, formal_patching_net_touts_comps = None, None, None, None
                    status = 'failure'
            elif method in self._INFORMAL_PATCH_METHODS:
                pruner = InformalPruningRunner(self.device, self.logger)
                pruned_net, comps, elapsed = pruner.run_informal_pruning(
                    dataset, model,
                    {**run_params, 'patching': method}, train_gen, x, self.conf_file,
                    os.path.join(self.models_dir, save_suffix), os.path.join(self.plots_dir, plot_suffix),
                    os.path.join(self.plots_dir, stats_suffix), sample_ids=batch)
                status = 'success' if pruned_net is not None else 'failure'
            else:
                logging.warning(f"Unknown patch method '{method}' for batch {batch}; skipping")
                continue

            if status != 'success':
                continue
            # comps is used as a dict with 'active' and 'dead' entries here and
            # below; the exact schema comes from the pruning functions.
            circuit_size = len(comps.get('active', [])) if comps and comps.get('active') is not None else 0
            robustness = self._verify_patching_robustness(comps, dataset, method, model_path, run_params, x, os.path.join(self.plots_dir, adv_x_suffix))
            results[method] = {"pruned_net": pruned_net, "comps": comps, "elapsed": elapsed,
                               'size': circuit_size, 'save_suffix': save_suffix, 'robustness': robustness}

        formal_touts_amount = len(formal_patching_net_touts_comps) if formal_patching_net_touts_comps is not None else None

        row = {
            'batch': "_".join(str(sid) for sid in batch),
            'delta': run_params['delta'],
            'patch_eps': run_params['patch_eps'],
            'formal_patching_tout_comps': formal_patching_net_touts_comps,
            'formal_patching_touts_amount': formal_touts_amount,
            'tag': tag
        }
        # adding quasi columns (always present when 'formal' was requested,
        # so a failed formal run does not change the CSV schema)
        if 'formal' in patch_methods:
            formal_comps = results.get('formal', {}).get('comps')
            qsmc_info = formal_comps.get('qsmc_info') if formal_comps else None
            row['qmsc_evaluations'] = qsmc_info.get('evaluations') if qsmc_info else None
            row['qmsc_yes_k'] = qsmc_info.get('yes_k') if qsmc_info else None
            row['qmsc_breaking_neuron'] = qsmc_info.get('breaking_neuron') if qsmc_info else None

        for method in patch_methods:
            if method != 'formal' and method not in self._INFORMAL_PATCH_METHODS:
                continue
            outcome = results.get(method, {})
            row[f'{method}_size'] = outcome.get('size')
            row[f'{method}_time'] = outcome.get('elapsed')
            row[f'{method}_robustness'] = outcome.get('robustness')
        self.record(row)

    def _verify_patching_robustness(self, comps, dataset, method, model_path, run_params, x, adv_x_suffix):
        """Check a pruned circuit with ``formal_patch_query`` and classify the outcome.

        Returns:
            ``'robust'``, ``'not-robust'``, ``'timeout'``, or ``'error'`` when
            the check itself raised (the exception is logged, not propagated).
        """
        try:
            # No dead comps --> full network. No need to run verifier.
            if len(comps['dead']) == 0:
                logging.info("no dead components found, marking as robust")
                return 'robust'

            logging.info(f"Verifying {method} patching robustness, patch_eps: {run_params['patch_eps']}, delta: {run_params['delta']}, metric: {run_params['metric']}, full_model_path: {model_path}")
            mask = create_mask_from_comps(dataset, comps)
            _, ver_res = formal_patch_query(dataset, None, x, model_path, self.exp_paths[f'saved_dup_patch_net_path_{dataset}'],
                                            mask, self.device, run_params['metric'], run_params['epsilon'], run_params['delta'],
                                            run_params['patch_eps'], self.exp_paths, adv_x_suffix, verify_patching_only=run_params["verify_patching_only"], query_timeout=45)
            # An unsafe result outranks a timeout: one counterexample is enough.
            if ver_res['unsafe'] > 0:
                result = 'not-robust'
            elif ver_res['timeout'] > 0:
                result = 'timeout'
            else:
                result = 'robust'
        except Exception as e:
            logging.error(f"Robustness check failed for method '{method}': {e}")
            result = 'error'
        logging.info(f"ROBUSTNESS check for method '{method}': {result}")
        return result


class CollectCircuits(ExperimentRunner):
    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base runner and set up result bookkeeping."""
        super().__init__(*args, **kwargs)
        # Path of the results CSV; filled in by the first record() call.
        self.csv_path = None
        # Every row passed to record() so far.
        self.records = []

    def _make_suffixes(self, dataset: str, batch: list, tag: str = None):
        batch_id = "_".join(str(sid) for sid in batch)
        base = f"{dataset}_batch_{batch_id}"
        if tag: base += f"_{tag}"
        return (f"{base}.pth", f"{base}_pruned_net.png", f"{base}_stats.txt",)

    def record(self, row):
        """Remember ``row`` in ``self.records`` and append it to the results CSV.

        The CSV lives in ``self.plots_dir`` and is named
        ``<current_dataset>_<experiment dir name>_results.csv``; the path is
        computed on the first call and reused afterwards.
        """
        if self.csv_path is None:
            exp_dir_name = os.path.basename(self.experiment_dir.rstrip("/"))
            csv_name = f"{self.current_dataset}_{exp_dir_name}_results.csv"
            self.csv_path = os.path.join(self.plots_dir, csv_name)

        self.records.append(row)

        # Rows are appended one at a time; the header is only written when the
        # file does not exist yet.
        needs_header = not os.path.exists(self.csv_path)
        frame = pd.DataFrame([row])
        frame.to_csv(self.csv_path, mode='a', header=needs_header, index=False)

    def run(self, name, exp_type, dataset, parameters):
        """Prune and verify circuits for every batch over a parameter grid.

        For ``metric == 'winner_runner'`` the grid is ``frac x epsilon`` and a
        per-sample delta is derived from the value returned by
        ``prepare_batch``; for any other metric the grid is ``delta x epsilon``
        with the deltas taken from ``parameters['delta']``.

        Args:
            name: Experiment name (not used by this method).
            exp_type: Experiment type, stored on ``self.current_exp_type``.
            dataset: Dataset identifier, stored on ``self.current_dataset``.
            parameters: Experiment parameters. Reads 'batches', 'metric',
                'epsilon', and 'frac' (winner_runner) or 'delta' (otherwise).
                'epsilon', 'delta' and 'frac' may each be a single value or a
                list of values.
        """
        self.set_up_logger()
        self.current_dataset = dataset
        self.current_exp_type = exp_type
        batches = parameters.get('batches')

        train_gen, test_data, _ = load_dataset(dataset)
        model, model_path = load_model(dataset, self.device)
        metric = parameters.get('metric')
        fracs = parameters.get('frac')
        # Accept a single numeric frac as well as a list, mirroring how epsilon
        # and delta are normalised below (a bare number is not iterable).
        if isinstance(fracs, (int, float)):
            fracs = [fracs]
        epsilons = parameters.get('epsilon')
        epsilon_list = epsilons if isinstance(epsilons, list) else [epsilons]

        for batch in batches:
            # wrld can be a tensor of shape (batch_size,)
            x, wrld = prepare_batch(dataset, model, test_data, self.exp_paths, self.device, batch)
            if metric == 'winner_runner':
                for frac in fracs:
                    # Batched (winner-runner) delta for each sample, floored to two decimals.
                    deltas = torch.floor(wrld * frac * 100) / 100
                    deltas = deltas.to(self.device)  # ensure delta is on the same device
                    for epsilon in epsilon_list:
                        run_params = dict(parameters, delta=deltas, frac=frac, epsilon=epsilon)
                        frac_tag = f"delta{frac:.2f}_eps{epsilon}" if epsilon is not None else f"delta{frac:.2f}"
                        save_suffix, plot_suffix, stats_suffix = self._make_suffixes(dataset, batch, frac_tag)
                        self._prune_and_verify(dataset, model, model_path, train_gen, x, run_params, batch, save_suffix, plot_suffix, stats_suffix)
            else:
                deltas = parameters['delta'] if isinstance(parameters['delta'], list) else [parameters['delta']]
                for delta in deltas:
                    for epsilon in epsilon_list:
                        run_params = dict(parameters, delta=delta, epsilon=epsilon)
                        delta_tag = f"delta{delta}_eps{epsilon}" if epsilon is not None else f"delta{delta}"
                        save_suffix, plot_suffix, stats_suffix = self._make_suffixes(dataset, batch, delta_tag)
                        self._prune_and_verify(dataset, model, model_path, train_gen, x, run_params, batch, save_suffix, plot_suffix, stats_suffix)
                        print("#" * 80)

    def _prune_and_verify(self, dataset, model, full_net_path, train_gen, x, run_params, batch, save_suffix, plot_suffix, stats_suffix):
        """Prune one batch informally and formally, verify both circuits, record one row.

        The informal pruner runs unguarded. The formal pruner is wrapped so
        that any exception is logged and reported as
        ``formal_status == 'failure'``. Each circuit that was produced is then
        passed to ``_check_robustness``, and sizes, timings, verdicts and the
        formal timeouts are summarised in a single row handed to ``record``.
        """

        def robustness(prefix, adv_prefix, dead):
            """Verify the model saved as ``<prefix><save_suffix>`` in ``models_dir``."""
            return self._check_robustness(
                dataset=dataset,
                pruned_net_path=os.path.join(self.models_dir, prefix + save_suffix),
                full_net_path=full_net_path,
                full_net=model,
                sample_id="_".join(str(sid) for sid in batch),
                delta=run_params.get('delta'),
                epsilon=run_params.get('epsilon'),
                metric=run_params.get('metric'),
                X=x,
                adv_prefix=adv_prefix,
                dead_comps=dead,
                frac=run_params.get('frac'),
            )

        def difference(lhs, rhs):
            """Return ``lhs - rhs``, or None when either operand is missing."""
            if lhs is None or rhs is None:
                return None
            return lhs - rhs

        # --- informal pruning ---
        informal_runner = InformalPruningRunner(self.device, self.logger)
        inf_net, inf_comps, elapsed_inf = informal_runner.run_informal_pruning(
            dataset,
            model,
            run_params,
            train_gen,
            x,
            self.conf_file,
            os.path.join(self.models_dir, "informal_" + save_suffix),
            os.path.join(self.plots_dir, "informal_" + plot_suffix),
            os.path.join(self.plots_dir, "informal_" + stats_suffix),
            sample_ids=batch,
        )
        inf_active_comps = inf_comps.get('active', [])
        inf_dead_comps = inf_comps.get('dead')

        # A missing network is the only failure signal of the informal pruner.
        if inf_net is None:
            inf_status, informal_robust = 'failure', None
        else:
            inf_status = 'success'
            informal_robust = robustness("informal_", "informal_adv_x", inf_dead_comps)
            logging.info(f"informal pruning robustness for batch {batch} e={run_params.get('epsilon')} delta={run_params.get('delta')}: {informal_robust}")

        # --- formal pruning ---
        try:
            formal_runner = FormalPruningRunner(self.device, self.logger)
            _, f_comps, elapsed_f, touts = formal_runner.run_formal_pruning(
                dataset,
                model,
                full_net_path,
                run_params,
                train_gen,
                x,
                self.exp_paths,
                self.conf_file,
                os.path.join(self.models_dir, "formal_" + save_suffix),
                os.path.join(self.plots_dir, "formal_" + plot_suffix),
                os.path.join(self.plots_dir, "formal_" + stats_suffix),
                sample_ids=batch,
            )
            f_active_comps, f_dead_comps = f_comps.get('active', []), f_comps.get('dead')
            f_status = 'success'
        except Exception as e:
            # Everything derived from the formal run is reported as missing.
            f_active_comps, elapsed_f, touts = None, None, None
            f_status = 'failure'
            logging.error(f"formal failed for batch {batch}: {e}")

        logging.info(f"formal pruning result for batch {batch} e={run_params.get('epsilon')} delta={run_params.get('delta')}: {f_status}")

        if f_status == 'success':
            formal_robust = robustness("formal_", "formal_adv_x", f_dead_comps)
            logging.info(f"formal pruning robustness for batch {batch} e={run_params.get('epsilon')} delta={run_params.get('delta')}: {formal_robust}")
        else:
            formal_robust = None

        # --- summary row ---
        informal_size = None if inf_active_comps is None else len(inf_active_comps)
        formal_size = None if f_active_comps is None else len(f_active_comps)
        formal_touts_amount = None if touts is None else len(touts)
        formal_tout_comps = None if touts is None else str(touts)
        if formal_size is not None and formal_touts_amount is not None and formal_size > 0:
            # Stored as a ratio of timeouts to circuit size (not multiplied by 100).
            formal_touts_percentage = formal_touts_amount / formal_size
        else:
            formal_touts_percentage = None

        self.record({
            'metric': run_params.get('metric'),
            'batch': batch,  # Save as list, not string
            'delta': run_params.get('delta'),
            'frac': run_params.get('frac'),
            'epsilon': run_params.get('epsilon'),
            'informal_status': inf_status,
            'informal_time': elapsed_inf,
            'informal_size': informal_size,
            'informal_save_suffix': save_suffix,
            'informal_robust': informal_robust,
            'formal_status': f_status,
            'formal_time': elapsed_f,
            'formal_size': formal_size,
            'formal_save_suffix': save_suffix,
            'formal_robust': formal_robust,
            'formal_touts_amount': formal_touts_amount,
            'formal_tout_comps': formal_tout_comps,
            'formal_touts%': formal_touts_percentage,
            'size_diff': difference(formal_size, informal_size),
            'time_diff': difference(elapsed_f, elapsed_inf),
        })

    def _check_robustness(self, dataset: str,
                          pruned_net_path: str,
                          full_net_path: str,
                          full_net,
                          sample_id: int,
                          delta: float,
                          epsilon: float,
                          metric: str,
                          X,
                          adv_prefix: str,
                          dead_comps,
                          frac=None) -> str:

        logging.info(f"checking robustness for dataset {dataset}, pruned_net_path: {pruned_net_path}, full_net_path: {full_net_path}, sample_id: {sample_id}, delta: {delta}, epsilon: {epsilon}, metric: {metric}, dead_comps: {dead_comps}")
        # No dead comps --> full network. No need to run verifier.
        if len(dead_comps) == 0:
            logging.info("no dead components found, marking as robust")
            return 'robust'

        # load pruned net for verification
        if dataset.startswith("mnist"):
            net = load_mnist_model(pruned_net_path, self.device)
        elif dataset.startswith("cifar10"):
            net = load_cifar10_model(pruned_net_path, self.device, model_type=dataset)
        elif dataset == "taxinet":
            net = load_taxinet_model(pruned_net_path, self.device)
        elif dataset == "gtsrb":
            net = load_gtsrb_model(pruned_net_path, self.device)
        delta_suff = f"frac{frac:.3f}_" if (frac is not None) else ""
        adv_x_path = os.path.join(self.plots_dir,f"{adv_prefix}_sample{sample_id}_delta{delta_suff}_{epsilon:.3f}.txt")

        try:
            _, ver_res = find_formal_adv_example(
                dataset=dataset,
                pruned_net_path=pruned_net_path,
                full_net_path=full_net_path,
                X=X,
                adv_x_path=adv_x_path,
                device=self.device,
                epsilon=epsilon,
                metric=metric,
                delta=delta,
                exp_paths=self.exp_paths
            )

            if ver_res['unsafe'] > 0: result = 'not-robust'
            elif ver_res['timeout'] > 0: result = 'timeout'
            else: result = 'robust'

            if result == 'not-robust':
                adv_x = load_x_from_file(dataset, adv_x_path)
                stats_fn = f"{adv_prefix}_STATS_{sample_id}_{(f'{frac:.2f}' if frac is not None else f'{delta:.2f}')}_{epsilon:.3f}.txt"
                print_logits_for_adv_x(dataset, net, full_net, X, adv_x, self.device, save_to_path=os.path.join(self.plots_dir, stats_fn))
            return result

        except Exception as e:
            logging.error(f"robustness check failed for sample {sample_id}: {e}")
            return 'error'


class ContrastiveMHSRunner(ExperimentRunner):
    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base runner and set up result bookkeeping."""
        super().__init__(*args, **kwargs)
        # Path of the contrastive results CSV; filled in by the first record() call.
        self.csv_path = None
        # Every row passed to record() so far.
        self.records = []

    def record(self, row):
        """Remember ``row`` in ``self.records`` and append it to the contrastive results CSV.

        The CSV lives in ``self.plots_dir`` and is named
        ``<dataset>_<experiment dir name>_contrastive_results.csv``; the path
        is computed on the first call and reused afterwards.
        """
        if self.csv_path is None:
            exp_dir_name = os.path.basename(self.experiment_dir.rstrip("/"))
            csv_name = f"{self.dataset}_{exp_dir_name}_contrastive_results.csv"
            self.csv_path = os.path.join(self.plots_dir, csv_name)

        self.records.append(row)

        # The header is only written when the file does not exist yet.
        needs_header = not os.path.exists(self.csv_path)
        frame = pd.DataFrame([row])
        frame.to_csv(self.csv_path, mode='a', header=needs_header, index=False)

    def run(self, name, exp_type, dataset, parameters):
        """Process every batch in order, synchronising parallel jobs between batches.

        When ``parameters['job_id']`` is set, a job may only start batch ``i``
        once the ``_batch_done`` marker of batch ``i - 1`` exists in the
        contrastive results directory; it polls for that marker every 10
        seconds. Without a job id the batches simply run back to back.

        Args:
            name: Experiment name (not used by this method).
            exp_type: Experiment type (not used by this method).
            dataset: Dataset identifier, stored on ``self.dataset``.
            parameters: Experiment parameters; reads 'job_id', 'num_jobs' and
                'batches' (defaults to a single empty batch).
        """

        def wait_for(prev_batch):
            """Block until the done marker of ``prev_batch`` shows up on disk."""
            prev_batch_id_str = "_".join(map(str, prev_batch))
            prev_batch_done_file = os.path.join(
                self.exp_paths['contrastive_results_dir'],
                f"batch_{prev_batch_id_str}",
                '_batch_done',
            )
            self.log(f"Job {job_id} waiting for previous batch {prev_batch_id_str} to complete...")
            while not os.path.exists(prev_batch_done_file):
                time.sleep(10)
            self.log(f"Job {job_id} detected previous batch {prev_batch_id_str} is done. Proceeding to next batch.")

        self.set_up_logger()
        self.log_experiment_start(self.conf_file)
        self.dataset = dataset
        os.makedirs(self.exp_paths['contrastive_results_dir'], exist_ok=True)

        job_id = parameters.get('job_id')
        num_jobs = parameters.get('num_jobs')
        batches = parameters.get('batches', [[]])

        for i, batch in enumerate(batches):
            # Batch-level barrier: only relevant for parallel jobs, and never
            # for the first batch.
            if i > 0 and job_id is not None:
                wait_for(batches[i - 1])
            self.process_batch(batch, parameters, job_id, num_jobs)

        self.log_experiment_end(self.conf_file)

    def process_batch(self, batch, parameters, job_id, num_jobs):
        """Run the singleton / pair / triplet contrastive search for one batch.

        For each subset size (1, 2, 3) the job calls
        ``check_all_contrastive_subsets``, saves its partial result, and - when
        it is the aggregator or the only job - computes the MHS of all
        contrastive subsets found so far and verifies its sufficiency. The
        batch ends early once the MHS is reported 'SUFFICIENT' or a stop
        signal is found.

        Roles, derived from the arguments:
          * aggregator: ``job_id == 0``. Waits for ``num_jobs`` result files
            per stage, merges them, and writes the marker files below.
          * worker: ``job_id > 0`` with ``num_jobs`` set. Waits for the
            previous stage's ``_stage_done`` before starting a stage.
          * single job: ``num_jobs is None``. Does everything in-process.

        Files under ``<contrastive_results_dir>/batch_<ids>/``:
          * ``<stage>/results_job_<job_id>.pkl`` - partial results of one job.
          * ``<stage>/_stage_done`` - written by the aggregator when a stage ends.
          * ``<stage>/_stop_signal`` - stage-level stop request; checked here
            and also passed to ``check_all_contrastive_subsets``.
          * ``_stop_batch`` - written by the aggregator to stop the whole batch.
          * ``cumulative_noncontrastive.pkl`` - non-contrastive subsets
            accumulated over the finished stages.
          * ``_batch_done`` - written by the aggregator after recording the row.

        Args:
            batch: Sample ids of the batch.
            parameters: Experiment parameters; reads 'metric', 'epsilon',
                'delta', 'patch_eps' and optionally 'query_timeout' (default 45).
            job_id: Index of this job, or None.
            num_jobs: Total number of parallel jobs, or None for a single job.
        """
        is_aggregator = job_id is not None and job_id == 0
        single_job = num_jobs is None
        batch_start_time = time.time() if is_aggregator or single_job else None

        full_net, full_net_path, _, _, _, _, x, _ = load_dataset_and_model(self.dataset, self.device, self.exp_paths, sample_ids=batch)
        adv_x_path = self.exp_paths.get('adv_x_path')
        query_timeout = parameters.get('query_timeout', 45)
        batch_id_str = "_".join(map(str, batch))

        row = {
            'batch': batch_id_str,
            'epsilon': parameters.get('epsilon'),
            'patch_eps': parameters.get('patch_eps'),
            'delta': parameters.get('delta')
        }

        all_labels = ['singletons', 'pairs', 'triplets']
        # Initialize all potential columns to None so every row has the same
        # columns even when the batch stops before the later stages.
        for label in all_labels:
            row[f'{label}_mhs_size'] = None
            row[f'{label}_sufficiency'] = None
            row[f'{label}_timeouts'] = None

        prev_noncontrastives = set()
        current_union = set()

        # This stop signal is for the entire batch.
        batch_stop_signal_file = os.path.join(self.exp_paths['contrastive_results_dir'], f"batch_{batch_id_str}", '_stop_batch')

        for i, subset_size in enumerate([1, 2, 3]):
            stage_label = all_labels[i]
            stage_dir = os.path.join(self.exp_paths['contrastive_results_dir'], f"batch_{batch_id_str}", stage_label)
            os.makedirs(stage_dir, exist_ok=True)
            # This stop signal is for a specific stage, can be triggered by external factors (e.g. user cancellation).
            stage_stop_signal_file = os.path.join(stage_dir, '_stop_signal')
            stage_done_file = os.path.join(stage_dir, '_stage_done')

            # Path for the aggregator to save, and workers to load, the cumulative non-contrastive set.
            prev_stage_noncontrastive_file = os.path.join(self.exp_paths['contrastive_results_dir'], f"batch_{batch_id_str}", 'cumulative_noncontrastive.pkl')

            # If a stop signal is already present for this batch, abort.
            if os.path.exists(batch_stop_signal_file):
                self.log(f"Batch stop signal found for batch {batch_id_str}. Aborting stage '{stage_label}'.")
                break

            # If a stop signal is already present for this stage, abort.
            if os.path.exists(stage_stop_signal_file):
                self.log(f"Stop signal found for stage '{stage_label}', batch {batch_id_str}. Aborting.")
                break  # <-- Exits the stage loop for this batch

            # --- WORKER LOGIC ---
            if job_id is not None and num_jobs is not None and job_id > 0:
                # Wait for the previous stage to be completed by the aggregator
                if i > 0:
                    prev_stage_done_file = os.path.join(self.exp_paths['contrastive_results_dir'], f"batch_{batch_id_str}", all_labels[i-1], '_stage_done')
                    while not os.path.exists(prev_stage_done_file):
                        if os.path.exists(batch_stop_signal_file):
                            self.log(f"Worker {job_id} detected batch stop signal while waiting for stage {all_labels[i-1]}. Aborting batch.")
                            return
                        time.sleep(5)

                # Re-check after the wait. This is only race-free because the
                # aggregator always writes the batch stop signal BEFORE the
                # stage done marker: a worker that sees the done marker is
                # then guaranteed to also see a pending stop.
                if os.path.exists(batch_stop_signal_file):
                    self.log(f"Worker {job_id} detected batch stop signal for batch {batch_id_str} before starting stage '{stage_label}'. Aborting.")
                    return

            # --- ALL JOBS (WORKERS + AGGREGATOR) ---
            # Load the cumulative non-contrastive set from the previous stage.
            # NOTE: pickle is only acceptable here because these files are
            # produced by this experiment's own jobs; never point the results
            # directory at untrusted data.
            if i > 0 and os.path.exists(prev_stage_noncontrastive_file):
                with open(prev_stage_noncontrastive_file, 'rb') as f:
                    prev_noncontrastives = pickle.load(f)
                self.log(f"Job {job_id} loaded {len(prev_noncontrastives)} non-contrastive subsets from previous stages.")

            # Do the computation for the current stage
            contrastive, noncontrastive, timeouts = check_all_contrastive_subsets(
                subset_size=subset_size, dataset=self.dataset, net=full_net, x=x,
                full_net_path=full_net_path,
                noncontrastive_subsets_previous_sizes=prev_noncontrastives,
                device=self.device, metric=parameters['metric'],
                epsilon=parameters['epsilon'], delta=parameters['delta'],
                patch_eps=parameters['patch_eps'],
                exp_paths=self.exp_paths, adv_x_path=adv_x_path,
                query_timeout=query_timeout, verbose=False,
                job_id=job_id, num_jobs=num_jobs,
                stop_signal_file=stage_stop_signal_file
            )
            # Save partial results. The aggregator counts `results_job_*.pkl`
            # files and unpickles them as soon as enough exist, so the file is
            # written under a temporary name (not matched by that pattern) and
            # renamed into place once complete.
            partial_result_path = os.path.join(stage_dir, f'results_job_{job_id}.pkl')
            tmp_result_path = partial_result_path + '.tmp'
            with open(tmp_result_path, 'wb') as f:
                pickle.dump({'contrastive': contrastive, 'noncontrastive': noncontrastive, 'timeouts': timeouts}, f)
            os.replace(tmp_result_path, partial_result_path)
            self.log(f"Job {job_id} finished computation for stage '{stage_label}', batch {batch_id_str}.")

            # --- AGGREGATOR ONLY LOGIC ---
            if is_aggregator:
                # Wait for all worker jobs to save their results for the current stage
                expected_files = num_jobs # Aggregator waits for all N files, including its own.
                while True:
                    files = glob.glob(os.path.join(stage_dir, 'results_job_*.pkl'))
                    if os.path.exists(stage_stop_signal_file):
                        self.log(f"Aggregator detected stop signal for stage '{stage_label}'. Aborting batch.")
                        # Create the batch stop signal first, so that a worker woken
                        # by the done marker below cannot start another stage.
                        with open(batch_stop_signal_file, 'w') as f:
                            f.write(f"Stop signal detected by aggregator at stage {stage_label}.")
                        self.log(f"Aggregator created batch stop signal due to stop file: {batch_stop_signal_file}")
                        # --- Create stage_done file so workers can exit gracefully ---
                        with open(stage_done_file, 'w') as f:
                            f.write('done')
                        self.log(f"Aggregator created done signal (due to stop): {stage_done_file}")
                        break  # <-- Exits the wait loop for this stage
                    if len(files) >= expected_files:
                        self.log(f"Aggregator found all {len(files)} result files for stage '{stage_label}'.")
                        break
                    self.log(f"Aggregator waiting for workers on stage '{stage_label}'. Found {len(files)}/{expected_files} files.")
                    time.sleep(10)

                if os.path.exists(batch_stop_signal_file): break  # <-- Exits the stage loop for this batch

                # Aggregate results
                all_contrastive, all_noncontrastive, all_timeouts = set(), set(), set()
                files = glob.glob(os.path.join(stage_dir, 'results_job_*.pkl'))
                for f_path in files:
                    with open(f_path, 'rb') as f:
                        data = pickle.load(f)
                        all_contrastive.update(data['contrastive'])
                        all_noncontrastive.update(data['noncontrastive'])
                        all_timeouts.update(data['timeouts'])

                self.log(f"Aggregator aggregated {len(all_timeouts)} timeouts for stage '{stage_label}'.")

                logging.info(f"Aggregator aggregated results for stage '{stage_label}', batch {batch_id_str}. Contrastive: {all_contrastive}")
                logging.info(f"Aggregator aggregated results for stage '{stage_label}', batch {batch_id_str}. Non-contrastive: {all_noncontrastive}")

                # --- Process aggregated results and prepare for next stage ---
                prev_noncontrastives.update(all_noncontrastive)
                # Save the cumulative non-contrastive set for the next stage's workers
                with open(prev_stage_noncontrastive_file, 'wb') as f:
                    pickle.dump(prev_noncontrastives, f)
                self.log(f"Aggregator saved {len(prev_noncontrastives)} cumulative non-contrastive subsets.")

                current_union.update(all_contrastive)
                mhs = get_contrastives_mhs(current_union)
                mhs_names = get_neurons_by_names(mhs)

                mhs_sufficiency_status = verify_mhs_sufficiency(adv_x_path, self.dataset, parameters['delta'], self.device, parameters['epsilon'],
                    self.exp_paths, full_net_path, parameters['metric'], mhs, full_net, parameters['patch_eps'], query_timeout=120, x=x, verbose=1)

                row[f'{stage_label}_mhs_size'] = len(mhs)
                row[f'{stage_label}_sufficiency'] = mhs_sufficiency_status
                row[f'{stage_label}_timeouts'] = len(all_timeouts)
                self.log(f"Stage '{stage_label}' complete for batch {batch_id_str}. MHS size: {len(mhs)}. Sufficient: {mhs_sufficiency_status}")

                plot_path = os.path.join(self.plots_dir, f"mhs_{stage_label}_batch_{batch_id_str}.png")
                visualize_mlp_with_active_neurons(full_net, plot_path, {"active": mhs_names, "granularity": "neurons"})

                sufficient = mhs_sufficiency_status == 'SUFFICIENT'
                if sufficient:
                    self.log(f"Sufficient MHS found at stage '{stage_label}'. Ending processing for batch {batch_id_str}.")
                    # Create a stop signal for the whole batch to notify workers.
                    # It must exist before the done marker below, otherwise a
                    # worker could see "stage done" and start the next stage.
                    with open(batch_stop_signal_file, 'w') as f:
                        f.write(f"Sufficient MHS found by aggregator at stage {stage_label}.")
                    self.log(f"Aggregator created batch stop signal: {batch_stop_signal_file}")

                # Create the done file to signal workers to proceed
                with open(stage_done_file, 'w') as f:
                    f.write('done')
                self.log(f"Aggregator created done signal: {stage_done_file}")

                if sufficient:
                    break

            # --- SINGLE JOB (NO SLURM) LOGIC ---
            elif single_job:
                self.log(f"Stage '{stage_label}' had {len(timeouts)} timeouts.")

                logging.info(f"stage '{stage_label}', batch {batch_id_str}. Contrastive: {contrastive}")
                logging.info(f"stage '{stage_label}', batch {batch_id_str}. Non-contrastive: {noncontrastive}")

                prev_noncontrastives.update(noncontrastive)
                current_union.update(contrastive)
                mhs = get_contrastives_mhs(current_union)
                mhs_names = get_neurons_by_names(mhs)

                mhs_sufficiency_status = verify_mhs_sufficiency(adv_x_path, self.dataset, parameters['delta'], self.device, parameters['epsilon'],
                    self.exp_paths, full_net_path, parameters['metric'], mhs, full_net, parameters['patch_eps'], query_timeout=120, x=x, verbose=1)

                row[f'{stage_label}_mhs_size'] = len(mhs)
                row[f'{stage_label}_sufficiency'] = mhs_sufficiency_status
                row[f'{stage_label}_timeouts'] = len(timeouts)
                self.log(f"Stage '{stage_label}' complete for batch {batch_id_str}. MHS size: {len(mhs)}. Sufficient: {mhs_sufficiency_status}")

                plot_path = os.path.join(self.plots_dir, f"mhs_{stage_label}_batch_{batch_id_str}.png")
                visualize_mlp_with_active_neurons(full_net, plot_path, {"active": mhs_names, "granularity": "neurons"})

                if mhs_sufficiency_status == 'SUFFICIENT':
                    self.log(f"Sufficient MHS found at stage '{stage_label}'. Ending processing for batch {batch_id_str}.")
                    break

        # --- Record results for the batch (only aggregator or single job does this) ---
        if is_aggregator or single_job:
            row['time'] = time.time() - batch_start_time
            self.record(row)
            # Aggregator signals that this batch is complete for all other jobs.
            if is_aggregator:
                batch_done_file = os.path.join(self.exp_paths['contrastive_results_dir'], f"batch_{batch_id_str}", '_batch_done')
                with open(batch_done_file, 'w') as f:
                    f.write('done')
                self.log(f"Aggregator created batch done signal: {batch_done_file}")


def get_runner(exp_type: str) -> "type[ExperimentRunner] | None":
    """Look up the runner class registered for an experiment type.

    The class itself is returned, not an instance; the caller is expected to
    instantiate it.

    Args:
        exp_type: Experiment type key, e.g. 'formal_pruning' or 'contrastive_mhs'.

    Returns:
        The matching runner class, or None when ``exp_type`` is not registered.
    """
    runners = {
        'formal_pruning': FormalPruningRunner,
        'informal_pruning': InformalPruningRunner,
        'adversarial_example': AdversarialExample,
        'collect_circuits': CollectCircuits,
        'collect_patching': CollectPatching,
        'contrastive_mhs': ContrastiveMHSRunner,
    }
    return runners.get(exp_type)
