import copy
import sys
from curses import meta
from functools import partial
from typing import Any, List

import numpy
import torch
import torch.nn.functional as F

from sklearn import ensemble
from tensordict import LazyStackedTensorDict, pad_sequence, TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch.func import functional_call, stack_module_state

# sys.path.append("..")
# from ..data.pyaig.aig_env import AIGEnv
from torchrl.envs import EnvBase
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
from torchrl.modules import (
    ActorValueOperator,
    QValueModule,
    TanhModule,
    ValueOperator,
    VmapModule,
)

from rl import mcts_policy_new
from .mcts_node import MCTSNode
from .mcts_policy import (
    ActionExplorationModule,
    AlphaZeroConfig,
    AlphaZeroExpansionStrategy,
    DirichletNoiseModule,
    MctsPolicy,
    PuctSelectionPolicy,
    SimulatedSearchPolicy,
    UpdateTreeStrategy,
)


class FineTuningCollate(torch.nn.Module):
    """Collate a list of per-sample TensorDicts into one padded fine-tuning batch.

    For every sample, ``action_mask`` / ``action_dist`` (stored flat and
    viewable as ``(4, n, n)``) are zero-padded to the largest ``n`` in the
    batch and a boolean causal mask is built from ``num_inputs``. ``target``
    is then prepended to ``nodes`` and all ``nodes`` are padded to a common
    length together with an ``attention_mask``.
    """

    def __init__(self, device):
        super().__init__()
        # May be a ``torch.device`` or anything ``torch.device()`` accepts
        # (e.g. the string "cuda"); it only decides whether to pin memory.
        self.device = device
        self.dtype = torch.float32
        # Entries kept from each sample by ``prepare_sample``.
        self.keys = ["action_dist", "action_mask", "causal_mask", "reward"]
        # NOTE: currently unused -- the negation / permutation augmentation
        # that consumed it is disabled.
        self.negation_prob = 0.5
        # When True the causal-mask diagonal is one wider (see create_causal_mask).
        self.const_node = True

    def __call__(self, batch: List[TensorDict]) -> TensorDict:
        """Collate ``batch`` (input TensorDicts are cloned, never mutated)."""
        if not batch:
            raise ValueError("FineTuningCollate received an empty batch")
        batch = [td.clone() for td in batch]
        # Measure node counts before batch_nodes prepends the target rows.
        max_nodes = max(td["nodes"].shape[-2] for td in batch)
        new_batch = torch.stack([self.prepare_sample(td, max_nodes) for td in batch])
        nodes = self.batch_nodes(batch)
        new_batch["nodes"] = nodes["nodes"].to(self.dtype)
        new_batch["attention_mask"] = nodes["attention_mask"]

        # Normalise so that string devices such as "cuda" are handled too.
        if torch.device(self.device).type == "cuda":
            new_batch = new_batch.pin_memory()

        return new_batch

    def batch_nodes(self, batch: List[TensorDict]) -> TensorDict:
        """Prepend ``target`` to ``nodes`` in every sample and pad them together.

        Mutates each TensorDict's ``nodes`` entry in place. Returns a
        TensorDict with the padded ``nodes`` and the padding mask renamed from
        ``("masks", "nodes")`` to ``attention_mask``.
        """
        for td in batch:
            td["nodes"] = torch.cat((td["target"], td["nodes"]), dim=-2)
        padded_batch = pad_sequence(
            [td.select("nodes") for td in batch], return_mask=True
        )
        padded_batch.rename_key_(("masks", "nodes"), "attention_mask")
        return padded_batch

    def prepare_sample(self, td, max_len):
        """Pad per-sample action tensors to ``max_len`` and attach a causal mask.

        ``action_mask`` and ``action_dist`` are viewed as ``(4, n, n)`` (``n``
        = current node count), zero-padded to ``(4, max_len, max_len)`` and
        flattened again. ``causal_mask`` has shape
        ``(1, max_len + 1, max_len + 1)``. ``td`` is modified in place, and a
        selection of ``self.keys`` is returned.
        """
        sql = td["nodes"].shape[-2]
        num_inputs = torch.sym_int(td["num_inputs"])

        td["action_mask"] = F.pad(
            td["action_mask"].view(4, sql, sql),
            (0, max_len - sql, 0, max_len - sql),
            value=0,
        ).view(-1)

        td["action_dist"] = F.pad(
            td["action_dist"].view(4, sql, sql),
            (0, max_len - sql, 0, max_len - sql),
            value=0,
        ).view(-1)

        td["causal_mask"] = self.create_causal_mask(
            num_inputs,  # type: ignore
            1,  # type: ignore
            max_len,
        )

        # NOTE: target negation / random permutation augmentation is disabled.
        return td.select(*self.keys)

    def set_device(self, device):
        """Replace the target device (used only to decide memory pinning)."""
        self.device = device

    def create_causal_mask(self, num_inputs: int, num_ouputs: int, num_nodes: int):
        """Return a ``(1, num_nodes + num_ouputs, num_nodes + num_ouputs)`` bool mask.

        The mask is lower-triangular with diagonal offset
        ``num_inputs + num_ouputs``, reduced by one when ``self.const_node``
        is False. The misspelled ``num_ouputs`` name is kept for keyword
        compatibility.
        """
        diag = num_inputs + num_ouputs
        if not self.const_node:
            diag -= 1
        mask = torch.ones(
            (num_nodes + num_ouputs, num_nodes + num_ouputs), dtype=torch.bool
        ).tril_(diag)[None, :, :]
        return mask


def get_observation(nodes: torch.Tensor, target: torch.Tensor):
    """Build a float32 observation with ``target`` placed before ``nodes``.

    The tensors are joined along their second-to-last axis (target first),
    cast to float32 and given a new leading axis of size one.
    """
    joined = torch.cat([target, nodes], dim=-2)
    return joined.float().unsqueeze(0)


def get_observation_legacy(nodes: torch.Tensor, target: torch.Tensor):
    """Legacy observation layout: ``nodes`` followed by ``target``, as float32.

    Tensors are joined along their second-to-last axis; no batch axis is added.
    """
    joined = torch.cat([nodes, target], dim=-2)
    return joined.float()


def get_attention_mask(observation, src_mask, num_inputs):
    """Return an additive float32 attention mask of shape ``(B, 1, S, S)``.

    ``B = observation.shape[0]`` and ``S = observation.shape[1]``. A position
    is allowed when it lies on or below diagonal ``num_inputs``, or is in the
    last column, AND its key is set in ``src_mask`` (indexed ``[B, S]``).
    Allowed entries become 0.0; all others become the most negative float32.
    """
    batch_size, seq_len = observation.shape[0], observation.shape[1]
    device = observation.device

    allowed = torch.tril(
        torch.ones((seq_len, seq_len), dtype=torch.bool, device=device),
        diagonal=num_inputs,
    )
    allowed = allowed.unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1, 1)
    # Every query may always look at the final position.
    allowed[..., -1] = True
    # Drop padded keys for each sample.
    allowed &= src_mask[:, None, None, :]

    blocked = torch.finfo(torch.float32).min
    additive = torch.zeros(allowed.shape, dtype=torch.float32, device=device)
    return additive.masked_fill(~allowed, blocked)


def set_loss_mask(td: TensorDict) -> TensorDict:
    """Store a flattened lower-triangular ``loss_mask`` on ``td`` and return it.

    With ``n = len(td["nodes"])``, the mask is the ``(n, n)`` boolean
    lower-triangular matrix (diagonal included), tiled 4 times along a new
    leading axis and flattened to length ``4 * n * n``. ``td`` is modified
    in place.
    """
    n = len(td["nodes"])
    full = torch.ones((n, n), dtype=torch.bool, device=td.device)
    lower = torch.tril(full, diagonal=0)
    td["loss_mask"] = lower.repeat(4, 1, 1).view(-1)
    return td


# Module-level TensorDict wrappers around the mask helpers defined further
# down in this file. The lambdas look up ``get_causal_mask`` /
# ``loop_causal_mask`` at call time, which is what allows referencing them
# here before their definitions. ``get_causal_mask`` is called with its
# defaults (num_outputs=1, const_node=True).
causal_mask_module = TensorDictModule(
    module=lambda x, y: get_causal_mask(x, y),
    in_keys=["observation", "num_inputs"],
    out_keys=["causal_mask"],
)


# Same keys, but builds the masks sample by sample with a Python loop.
loop_causal_mask_module = TensorDictModule(
    module=lambda x, y: loop_causal_mask(x, y),
    in_keys=["observation", "num_inputs"],
    out_keys=["causal_mask"],
)

# VmapModule wrappers of the two modules above, vmapped over dim 0.
vm_causal_mask_module = VmapModule(causal_mask_module, 0)
vm_loop_causal_mask_module = VmapModule(loop_causal_mask_module, 0)


def loop_causal_mask(observations, num_inputs):
    """Build one causal mask per sample with a Python loop and concatenate them.

    Args:
        observations: batched observations indexed by sample along dim 0.
            ``observations[i]`` may be 2-D ``(seq, feat)`` or already carry a
            leading singleton axis ``(1, seq, feat)``.
        num_inputs: per-sample input counts, indexed along dim 0.

    Returns:
        Boolean mask of shape ``(batch, 1, seq, seq)``, where batch is
        ``num_inputs.shape[0]``.
    """
    masks = []
    for idx in range(num_inputs.shape[0]):
        sample = observations[idx]
        if sample.dim() == 2:
            # get_causal_mask repeats the mask ``observation.shape[0]`` times;
            # on a bare (seq, feat) slice that would be seq copies, not one.
            sample = sample.unsqueeze(0)
        masks.append(get_causal_mask(sample, num_inputs[idx]))
    return torch.cat(masks, dim=0)


def get_causal_mask_legacy(observation, num_inputs):
    """Legacy boolean causal mask of shape ``(1, 1, S, S)``.

    ``S = observation.shape[-2]``. Entries on or below diagonal
    ``num_inputs`` are True, and the last column is forced to True. Unlike
    ``get_causal_mask``, the mask is not repeated over the batch axis.
    """
    size = observation.shape[-2]
    offset = torch.sym_int(num_inputs)  # type: ignore
    lower = torch.tril(
        torch.ones((size, size), dtype=torch.bool, device=observation.device),
        diagonal=offset,
    )
    mask = lower.unsqueeze(0).unsqueeze(0)
    mask[..., -1] = True
    return mask


def get_causal_mask(observation, num_inputs, num_outputs=1, const_node=True):
    """Boolean causal mask of shape ``(B, 1, S, S)``.

    ``B = observation.shape[0]`` and ``S = observation.shape[-2]``. Entries on
    or below diagonal ``num_inputs + num_outputs`` are True; that diagonal is
    one lower when ``const_node`` is False.
    """
    offset = num_inputs + num_outputs
    if not const_node:
        offset = offset - 1
    seq_len = observation.shape[-2]
    lower = torch.tril(
        torch.ones((seq_len, seq_len), dtype=torch.bool, device=observation.device),
        diagonal=torch.sym_int(offset),  # type: ignore
    )
    return lower.unsqueeze(0).unsqueeze(0).repeat(observation.shape[0], 1, 1, 1)


def get_actor_value_model(model: torch.nn.Module, const_node=True):
    """Wrap ``model`` into a torchrl ``ActorValueOperator``.

    Pipeline (TensorDict keys):
      * ``observation`` = ``get_observation(nodes, target)``;
      * ``causal_mask`` = module-level ``get_causal_mask(observation,
        num_inputs, num_outputs, const_node=const_node)``;
      * ``hidden`` from ``model.get_hidden_module()`` applied to
        (observation, causal_mask, num_outputs);
      * policy: ``action_logits`` from ``model.get_policy_head()``, then
        flattened softmax ``action_value``, then ``QValueModule`` with
        ``action_mask``;
      * value: ``state_value`` from ``model.get_value_head()``, squashed into
        [-1, 1] by ``TanhModule``.

    Args:
        model: must provide ``get_hidden_module``, ``get_policy_head`` and
            ``get_value_head``.
        const_node: forwarded to ``get_causal_mask``.
    """
    # Reuse the module-level helper instead of a verbatim nested copy.
    partial_get_causal_mask = partial(get_causal_mask, const_node=const_node)

    causal_mask_module = TensorDictModule(
        module=lambda x, y, z: partial_get_causal_mask(x, y, z),
        in_keys=["observation", "num_inputs", "num_outputs"],
        out_keys=["causal_mask"],
    )

    observation_module = TensorDictModule(
        lambda x, y: get_observation(x, y),
        in_keys=["nodes", "target"],
        out_keys=["observation"],
    )

    hidden_module = TensorDictModule(
        module=model.get_hidden_module(),
        in_keys=["observation", "causal_mask", "num_outputs"],
        out_keys=["hidden"],
    )

    wrapped_hidden_module = TensorDictSequential(
        observation_module, causal_mask_module, hidden_module
    )

    policy_module = TensorDictModule(
        module=model.get_policy_head(),
        in_keys=["hidden", "causal_mask", "num_outputs"],
        out_keys=["action_logits"],
    )

    # The softmax is taken over all logits flattened into one vector.
    softmax_wrapper = TensorDictModule(
        lambda x: torch.softmax(x.view(-1), dim=-1),
        in_keys=["action_logits"],
        out_keys=["action_value"],
    )

    qvalue_wrapper = QValueModule(
        action_space="categorical",
        action_mask_key="action_mask",
    )

    wrapped_policy_module = TensorDictSequential(
        policy_module, softmax_wrapper, qvalue_wrapper
    )

    tanh_wrapper = TanhModule(
        in_keys=["state_value"],
        low=-1.0,
        high=1.0,
    )
    value_module = ValueOperator(
        module=model.get_value_head(),
        in_keys=["hidden"],
        out_keys=["state_value"],
    )

    wrapped_value_module = TensorDictSequential(value_module, tanh_wrapper)

    return ActorValueOperator(
        wrapped_hidden_module, wrapped_policy_module, wrapped_value_module
    )


def get_actor_value_model_new(model: torch.nn.Module, const_node=True):
    """Variant of ``get_actor_value_model`` whose policy head is a vmapped ensemble.

    The layers in ``model.get_policy_head().policy_layers`` are evaluated
    together with ``torch.vmap`` over parameters stacked by
    ``stack_module_state``. The ensemble axis ends up at dim 1 of
    ``action_logits``. Here the hidden module receives only ``observation``
    and ``causal_mask``, not ``num_outputs``.

    NOTE: ``stack_module_state`` returns new leaf tensors that are unrelated to
    the original layer parameters. The returned operator therefore uses a
    snapshot of the policy weights taken when this function runs, and later
    changes to ``model`` (training steps, ``load_state_dict``) are not seen.
    Rebuild the operator after changing the weights.

    Args:
        model: must provide ``get_hidden_module``, ``get_policy_head`` (with a
            ``policy_layers`` sequence of identically-shaped modules) and
            ``get_value_head``.
        const_node: forwarded to ``get_causal_mask``.
    """
    # Reuse the module-level helper instead of a verbatim nested copy.
    partial_get_causal_mask = partial(get_causal_mask, const_node=const_node)

    causal_mask_module = TensorDictModule(
        module=lambda x, y, z: partial_get_causal_mask(x, y, z),
        in_keys=["observation", "num_inputs", "num_outputs"],
        out_keys=["causal_mask"],
    )

    observation_module = TensorDictModule(
        lambda x, y: get_observation(x, y),
        in_keys=["nodes", "target"],
        out_keys=["observation"],
    )

    hidden_module = TensorDictModule(
        module=model.get_hidden_module(),
        in_keys=["observation", "causal_mask"],
        out_keys=["hidden"],
    )

    wrapped_hidden_module = TensorDictSequential(
        observation_module, causal_mask_module, hidden_module
    )

    policy_model = model.get_policy_head()
    # Weightless copy of one layer on the "meta" device: it only supplies the
    # architecture for functional_call; real weights come from ``params``.
    meta_policy_module = copy.deepcopy(policy_model.policy_layers[0]).to("meta")

    params, buffs = stack_module_state(policy_model.policy_layers)

    def fmodel(params, buffers, x, attn_mask, num_outputs):
        return functional_call(
            meta_policy_module, (params, buffers), (x, attn_mask, num_outputs)
        )

    # Map over the stacked ensemble axis of params/buffers only; the inputs
    # are shared by all members. Built once rather than on every call.
    vmapped_policy = torch.vmap(
        fmodel, in_dims=(0, 0, None, None, None), out_dims=1
    )

    ensemble_policy_module = TensorDictModule(
        module=lambda x, y, z: vmapped_policy(params, buffs, x, y, z),
        in_keys=["hidden", "causal_mask", "num_outputs"],
        out_keys=["action_logits"],
    )

    # The softmax is taken over all logits (all ensemble members) flattened.
    softmax_wrapper = TensorDictModule(
        lambda x: torch.softmax(x.view(-1), dim=-1),
        in_keys=["action_logits"],
        out_keys=["action_value"],
    )

    qvalue_wrapper = QValueModule(
        action_space="categorical",
        action_mask_key="action_mask",
    )

    wrapped_policy_module = TensorDictSequential(
        ensemble_policy_module, softmax_wrapper, qvalue_wrapper
    )

    tanh_wrapper = TanhModule(
        in_keys=["state_value"],
        low=-1.0,
        high=1.0,
    )
    value_module = ValueOperator(
        module=model.get_value_head(),
        in_keys=["hidden"],
        out_keys=["state_value"],
    )

    wrapped_value_module = TensorDictSequential(value_module, tanh_wrapper)

    return ActorValueOperator(
        wrapped_hidden_module, wrapped_policy_module, wrapped_value_module
    )


def get_actor_value_model_legacy(model: torch.nn.Module, const_node=True):
    """Build the legacy actor/value operator around ``model``.

    Legacy layout: ``observation`` is ``cat(nodes, target)`` along dim 0 with
    a leading singleton axis, cast to float32. The causal mask uses diagonal
    ``num_inputs`` (``num_inputs - 1`` when ``const_node`` is False) and
    always exposes the last column. The value is squashed into [0, 1].

    Args:
        model: must provide ``get_hidden_module``, ``get_policy_head`` and
            ``get_value_head``.
        const_node: when False, the mask diagonal is reduced by one.
    """

    def _legacy_causal_mask(observation, num_inputs):
        width = observation.shape[-2]
        band = torch.tril(
            torch.ones(
                (width, width),
                dtype=torch.bool,
                device=observation.device,
                requires_grad=False,
            ),
            diagonal=torch.sym_int(num_inputs),  # type: ignore
        )
        out = band.unsqueeze(0).unsqueeze(0).repeat(observation.shape[0], 1, 1, 1)
        # Every position may attend to the last one, where the target is appended.
        out[:, :, :, -1] = True
        return out

    def _mask_for(obs, n_inputs):
        return _legacy_causal_mask(obs, n_inputs if const_node else n_inputs - 1)

    shared = TensorDictSequential(
        TensorDictModule(
            lambda nodes, target: torch.cat((nodes, target), dim=0)
            .unsqueeze(0)
            .to(torch.float32),
            in_keys=["nodes", "target"],
            out_keys=["observation"],
        ),
        TensorDictModule(
            module=_mask_for,
            in_keys=["observation", "num_inputs"],
            out_keys=["causal_mask"],
        ),
        TensorDictModule(
            module=model.get_hidden_module(),
            in_keys=["observation", "causal_mask"],
            out_keys=["hidden"],
        ),
    )

    actor = TensorDictSequential(
        TensorDictModule(
            module=model.get_policy_head(),
            in_keys=["hidden", "causal_mask"],
            out_keys=["action_logits"],
        ),
        TensorDictModule(
            lambda logits: torch.softmax(logits.view(-1), dim=-1),
            in_keys=["action_logits"],
            out_keys=["action_value"],
        ),
        QValueModule(
            action_space="categorical",
            action_mask_key="action_mask",
        ),
    )

    critic = TensorDictSequential(
        ValueOperator(
            module=model.get_value_head(), in_keys=["hidden"], out_keys=["state_value"]
        ),
        TanhModule(in_keys=["state_value"], low=0.0, high=1.0),
    )

    return ActorValueOperator(shared, actor, critic)


def make_alpha_zero_actor(model, const_node=True):
    """Build an AlphaZero-style actor/value operator around ``model``.

    ``observation`` is ``cat(nodes, target)`` along dim 0 with a leading
    singleton axis, cast to float32. The causal mask is sized from
    ``observation.shape[1]``, uses diagonal ``num_inputs``
    (``num_inputs - 1`` when ``const_node`` is False) and always exposes the
    last column. The value is squashed into [0, 1].

    Args:
        model: must provide ``get_hidden_module``, ``get_policy_head`` and
            ``get_value_head``.
        const_node: when False, the mask diagonal is reduced by one.
    """

    def _alpha_zero_mask(observation, num_inputs):
        span = observation.shape[1]
        lower = torch.tril(
            torch.ones(
                (span, span),
                dtype=torch.bool,
                device=observation.device,
                requires_grad=False,
            ),
            diagonal=torch.sym_int(num_inputs),  # type: ignore
        )
        mask = lower.unsqueeze(0).unsqueeze(0).repeat(observation.shape[0], 1, 1, 1)
        # The last position (where the target is appended) is always visible.
        mask[:, :, :, -1] = True
        return mask

    def _mask_from_inputs(obs, n_inputs):
        return _alpha_zero_mask(obs, n_inputs if const_node else n_inputs - 1)

    trunk = TensorDictSequential(
        TensorDictModule(
            lambda nodes, target: torch.cat((nodes, target), dim=0)
            .unsqueeze(0)
            .to(torch.float32),
            in_keys=["nodes", "target"],
            out_keys=["observation"],
        ),
        TensorDictModule(
            module=_mask_from_inputs,
            in_keys=["observation", "num_inputs"],
            out_keys=["causal_mask"],
        ),
        TensorDictModule(
            module=model.get_hidden_module(),
            in_keys=["observation", "causal_mask"],
            out_keys=["hidden"],
        ),
    )

    policy_head = TensorDictSequential(
        TensorDictModule(
            module=model.get_policy_head(),
            in_keys=["hidden", "causal_mask"],
            out_keys=["action_logits"],
        ),
        TensorDictModule(
            lambda logits: torch.softmax(logits.view(-1), dim=-1),
            in_keys=["action_logits"],
            out_keys=["action_value"],
        ),
        QValueModule(
            action_space="categorical",
            action_mask_key="action_mask",
        ),
    )

    value_head = TensorDictSequential(
        ValueOperator(
            module=model.get_value_head(), in_keys=["hidden"], out_keys=["state_value"]
        ),
        TanhModule(in_keys=["state_value"], low=0.0, high=1.0),
    )

    return ActorValueOperator(trunk, policy_head, value_head)


def make_train_actor(model, const_node=True):
    """Build an ``ActorValueOperator`` for training on pre-collated batches.

    Unlike the inference builders, no observation or mask modules are
    prepended: ``observation``, ``attention_mask`` and ``action_mask`` must
    already be present in the input TensorDict. Policy logits are flattened
    from dim 1 onward and handed straight to ``QValueModule`` (no softmax).
    The value head reads ``hidden`` and ``("masks", "observation")`` and is
    squashed into [0, 1].

    Args:
        model: must provide ``get_hidden_module``, ``get_policy_head`` and
            ``get_value_head``.
        const_node: accepted for signature parity with the other builders;
            not used here.
    """
    backbone = TensorDictSequential(
        TensorDictModule(
            module=model.get_hidden_module(),
            in_keys=["observation", "attention_mask"],
            out_keys=["hidden"],
        )
    )

    actor = TensorDictSequential(
        TensorDictModule(
            module=model.get_policy_head(),
            in_keys=["hidden", "attention_mask"],
            out_keys=["action_logits"],
        ),
        # Collapse everything after the batch axis into one action axis.
        TensorDictModule(
            module=lambda logits: logits.flatten(start_dim=1),
            in_keys=["action_logits"],
            out_keys=["action_logits"],
        ),
        QValueModule(
            action_space="categorical",
            action_value_key="action_logits",
            action_mask_key="action_mask",
        ),
    )

    critic = TensorDictSequential(
        ValueOperator(
            module=model.get_value_head(),
            in_keys=["hidden", ("masks", "observation")],
            out_keys=["state_value"],
        ),
        TanhModule(in_keys=["state_value"], low=0.0, high=1.0),
    )

    return ActorValueOperator(backbone, actor, critic)


def generate_AIG(
    model: torch.nn.Module,
    aig_env: EnvBase,
    max_nodes: int,
    cfg: AlphaZeroConfig,
) -> bool:
    """Roll out ``aig_env`` under an MCTS simulated-search policy built from ``model``.

    Reads ``use_value_network``, ``c_puct``, ``inject_noise``,
    ``dirichlet_alpha``, ``num_simulations``, ``simulation_max_steps`` and
    ``reutilize_tree`` from ``cfg``. Dirichlet noise is used only when
    ``inject_noise`` is truthy and ``dirichlet_alpha`` is not None. The
    rollout runs under ``torch.no_grad``.

    Returns:
        The last ``("next", "terminated")`` entry of the rollout (a tensor,
        despite the ``bool`` annotation).
    """
    agent = get_actor_value_model(model)

    tree_updater = UpdateTreeStrategy(
        value_network=agent.get_value_operator(),
        use_value_network=cfg.use_value_network,
    )

    search = MctsPolicy(
        expansion_strategy=AlphaZeroExpansionStrategy(
            policy_module=agent.get_policy_operator(),
        ),
        selection_strategy=PuctSelectionPolicy(cfg.c_puct),
        exploration_strategy=ActionExplorationModule(),
    )

    wants_noise = cfg.inject_noise and cfg.dirichlet_alpha is not None
    noise = DirichletNoiseModule(cfg.dirichlet_alpha) if wants_noise else None

    simulated = SimulatedSearchPolicy(
        policy=search,
        tree_updater=tree_updater,
        env=aig_env,
        num_simulations=cfg.num_simulations,
        simulation_max_steps=cfg.simulation_max_steps,
        max_steps=max_nodes,
        noise_module=noise,
        reutilize_tree=cfg.reutilize_tree,
    )

    with torch.no_grad():
        trajectory = aig_env.rollout(
            policy=simulated, max_steps=max_nodes, return_contiguous=False
        )
        return trajectory["next", "terminated"][-1]  # type: ignore


def generate_AIG_greedy(
    model: torch.nn.Module,
    aig_env: EnvBase,
    max_nodes: int,
    cfg: AlphaZeroConfig,
) -> bool:
    """Roll out ``aig_env`` with the tree policy alone (no simulated search).

    A fresh ``MCTSNode`` root is attached to an ``MctsPolicy`` built from the
    model's policy operator. No tree updater, value network or noise is used,
    and only ``cfg.c_puct`` is read from ``cfg``. The rollout runs without
    gradient tracking and with ``ExplorationType.RANDOM``.

    Returns:
        The last ``("next", "terminated")`` entry of the rollout (a tensor,
        despite the ``bool`` annotation).
    """
    actor_value_agent = get_actor_value_model(model)

    expansion_strategy = AlphaZeroExpansionStrategy(
        policy_module=actor_value_agent.get_policy_operator(),
    )

    selection_strategy = PuctSelectionPolicy(cfg.c_puct)

    exploration_strategy = ActionExplorationModule()

    root = MCTSNode.root()
    mcts_policy = MctsPolicy(
        expansion_strategy=expansion_strategy,
        selection_strategy=selection_strategy,
        exploration_strategy=exploration_strategy,
    )
    mcts_policy.set_node(root)

    # Enter BOTH context managers. ``a and b`` evaluates to ``b`` alone, so
    # writing ``torch.no_grad() and set_exploration_type(...)`` would leave
    # autograd enabled for the whole rollout.
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        rollout = aig_env.rollout(
            policy=mcts_policy, max_steps=max_nodes, return_contiguous=False
        )
        return rollout["next", "terminated"][-1]  # type: ignore


def generate_AIG_new(
    model: torch.nn.Module,
    aig_env: EnvBase,
    max_nodes: int,
    cfg: AlphaZeroConfig,
) -> bool:
    """Roll out ``aig_env`` with the ``mcts_policy_new`` simulated-search policy.

    Mirrors ``generate_AIG`` but uses the classes from ``mcts_policy_new`` and
    does not pass ``reutilize_tree``. Dirichlet noise is used only when
    ``cfg.inject_noise`` is truthy and ``cfg.dirichlet_alpha`` is not None.
    The rollout runs under ``torch.no_grad``.

    Returns:
        The last ``("next", "terminated")`` entry of the rollout (a tensor,
        despite the ``bool`` annotation).
    """
    actor_value_agent = get_actor_value_model(model)

    tree_strategy = mcts_policy_new.UpdateTreeStrategy(
        value_network=actor_value_agent.get_value_operator(),
        use_value_network=cfg.use_value_network,
    )

    expansion_strategy = mcts_policy_new.AlphaZeroExpansionStrategy(
        policy_module=actor_value_agent.get_policy_operator(),
    )

    selection_strategy = mcts_policy_new.PUCTSelectionPolicy(cfg.c_puct)

    exploration_strategy = mcts_policy_new.ActionExplorationModule()

    mcts_policy = mcts_policy_new.MCTSPolicy(
        expansion_strategy=expansion_strategy,
        selection_strategy=selection_strategy,
        exploration_strategy=exploration_strategy,
    )

    # Same rule as generate_AIG: honour the ``inject_noise`` switch instead of
    # adding noise whenever an alpha happens to be configured.
    noise_module = None
    if cfg.inject_noise and cfg.dirichlet_alpha is not None:
        noise_module = mcts_policy_new.DirichletNoiseModule(cfg.dirichlet_alpha)

    policy = mcts_policy_new.SimulatedSearchPolicy(
        policy=mcts_policy,
        tree_updater=tree_strategy,
        env=aig_env,
        num_simulations=cfg.num_simulations,
        simulation_max_steps=cfg.simulation_max_steps,
        max_steps=max_nodes,
        noise_module=noise_module,
    )

    with torch.no_grad():
        rollout = aig_env.rollout(
            policy=policy, max_steps=max_nodes, return_contiguous=False
        )
        return rollout["next", "terminated"][-1]  # type: ignore
