"""
search_mp.py
~~~~~~~~~~~~
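Top-level helper so that multiprocessing can pickle it.
Each worker calls the *original* Search._test_one_batch
with batch_size == 1. This helper does no timing of its own: any
per-instance runtime limit (e.g. max_runtime) must be enforced by
Search itself, which receives a copy of tester_params.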
"""

import copy
import time
import numpy as np

# Defined at module top level so multiprocessing can pickle it by reference;
# parameters are positional so callers can pass plain argument tuples.
def _worker_test_one_batch(
        env_params,
        tester_params,
        deconstruction_state_dicts,
        episode_idx,
        seed=None
    ):
    """Solve one instance with a freshly built ``Search`` and return plain values.

    Args:
        env_params: Environment configuration. Deep-copied before being
            passed to ``Search`` so the caller's object is never mutated.
        tester_params: Tester configuration. Must contain the keys
            ``'nb_iterations'`` and ``'aug_factor'``. Deep-copied before
            being passed to ``Search``.
        deconstruction_state_dicts: Optional sequence of dicts, each with
            a ``"cfg"`` mapping (keyword arguments for ``Model``) and a
            ``"weights"`` entry passed to ``Model.load_state_dict``. When
            falsy (``None`` or empty), ``search.deconstruction_operator``
            is not overwritten.
        episode_idx: Instance index, forwarded to
            ``Search._test_one_batch`` as ``episode``.
        seed: If not ``None``, seeds ``torch``, ``numpy`` and ``random``
            before ``Search`` is constructed.

    Returns:
        Tuple ``(cost, runtime, nb_iter, solution)``. ``cost`` is
        ``float(costs[0])``, ``runtime`` is ``float(rt)``, ``nb_iter`` is
        ``int(nb_iter)``, and ``solution`` is ``solutions[0]`` as returned
        by ``Search._test_one_batch``.
    """
    # Heavy imports stay inside the subprocess so the main process
    # doesn't pay the cost until the pool is launched.
    from .search_sa_agent import Search
    from .model import Model
    import random
    import torch

    # Optional deterministic seed per worker; done before Search is built
    # so any randomness in its construction is also covered.
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    # Re-instantiate Search (this also reloads pybind SISRs). Deep copies
    # keep the caller's config objects untouched even when this function
    # is invoked in-process rather than through a pool.
    search = Search(copy.deepcopy(env_params), copy.deepcopy(tester_params))

    # Re-hydrate deconstruction models, if any, replacing whatever
    # operators Search set up on its own.
    if deconstruction_state_dicts:
        search.deconstruction_operator = []
        for sd in deconstruction_state_dicts:
            mdl = Model(**sd["cfg"])
            mdl.load_state_dict(sd["weights"])
            mdl.eval()
            search.deconstruction_operator.append({"model": mdl})

    # One instance per worker => batch_size == 1.
    costs, rt, nb_iter, solutions = search._test_one_batch(
        batch_size=1,
        nb_iterations=tester_params['nb_iterations'],
        episode=episode_idx,
        aug_factor=tester_params['aug_factor'],
    )

    # Collapse to Python scalars before sending back across the process
    # boundary (cheaper to pickle than tensors/arrays).
    return float(costs[0]), float(rt), int(nb_iter), solutions[0]
