import torch
from typing import Tuple, List

from tsp_env import TSPEnvironment
from tsp_policy import TSPPolicy
from tsp_policy_two_stage import TSPTwoStagePolicy


@torch.no_grad()
def solve_tsp_with_env_policy(
    # String forward reference: TSP_net is neither imported nor defined in this
    # module, so a bare-name annotation would raise NameError as soon as the
    # module is imported.
    model: "TSP_net",
    x: torch.Tensor,
    action_k: int,
    state_k: List[int],
    deterministic: bool = False,
    if_use_local_mask: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Roll out a full tour using the separated environment and policy.

    Runs under ``torch.no_grad()``, so the returned log-probabilities carry no
    autograd graph.

    Args:
        model: Network wrapped by ``TSPPolicy``.
        x: Problem instances handed to ``TSPEnvironment``. ``x.size(1)`` is
            treated as the number of nodes and ``x.size(0)`` as the batch size
            (assumed layout -- confirm against ``TSPEnvironment``).
        action_k: Forwarded unchanged to ``TSPPolicy.select_action``.
        state_k: Forwarded unchanged to ``TSPPolicy.select_action``.
        deterministic: Forwarded unchanged to ``TSPPolicy.select_action``.
        if_use_local_mask: Forwarded unchanged to ``TSPPolicy.select_action``.

    Returns:
        ``(tours, sum_log_probs)`` matching the outputs of ``TSP_net.forward``.
        ``tours`` is whatever ``env.get_tour_tensor()`` yields;
        ``sum_log_probs`` is the per-step log-probabilities stacked on dim 1
        and summed over it, or zeros of shape ``(x.size(0),)`` when no step
        was taken.
    """
    env = TSPEnvironment(x)
    policy = TSPPolicy(model)

    log_probs = []
    obs = env.observation()
    # Run for nb_nodes-1 steps after start; stop early if the env reports done.
    for _ in range(x.size(1) - 1):
        action, log_prob, _ = policy.select_action(
            obs,
            action_k=action_k,
            state_k=state_k,
            deterministic=deterministic,
            if_use_local_mask=if_use_local_mask,
        )
        obs, done = env.step(action)
        log_probs.append(log_prob)
        if done:
            break

    tours = env.get_tour_tensor()
    if log_probs:
        sum_log_probs = torch.stack(log_probs, dim=1).sum(dim=1)
    else:
        # No step was taken (e.g. a single-node instance): torch.stack cannot
        # take an empty list, so fall back to zeros on x's device.
        sum_log_probs = torch.zeros((x.size(0),), device=x.device)
    return tours, sum_log_probs


@torch.no_grad()
def solve_tsp_with_two_stage_policy(
    args_stage1,
    args_stage2,
    x: torch.Tensor,
    k_promising: int,
):
    """Roll out a tour by repeatedly querying a ``TSPTwoStagePolicy``.

    The policy is intended to work in two stages (behaviour lives in
    ``TSPTwoStagePolicy``, not visible here):
      Stage 1: encode all unvisited nodes and sample k promising actions.
      Stage 2: re-score those k with the stage-2 network and pick the best.

    NOTE(review): ``args_stage2`` is accepted but never used in this function;
    the policy is built from ``args_stage1`` only. Confirm whether
    ``TSPTwoStagePolicy`` should also receive it.

    Returns a ``(tours, sum_log_probs)`` pair: the tensor from
    ``env.get_tour_tensor()`` and the per-step log-probabilities stacked on
    dim 1 and summed over it (zeros of shape ``(x.size(0),)`` if no step ran).
    """
    environment = TSPEnvironment(x)
    two_stage_policy = TSPTwoStagePolicy(args_stage1)

    step_log_probs = []
    observation = environment.observation()
    max_steps = x.size(1) - 1
    for _ in range(max_steps):
        chosen, step_log_prob, _ = two_stage_policy.select_action(
            observation, k_promising=k_promising
        )
        observation, finished = environment.step(chosen)
        step_log_probs.append(step_log_prob)
        if finished:
            break

    tours = environment.get_tour_tensor()
    if not step_log_probs:
        # Nothing to stack: return zeros on the same device as the input.
        return tours, torch.zeros((x.size(0),), device=x.device)
    return tours, torch.stack(step_log_probs, dim=1).sum(dim=1)
