import torch
from torch.distributions.categorical import Categorical
from torch_cluster import knn
from typing import Optional, Dict, Any, Tuple, List

from utils.utils_for_model import create_distance_mask_for_knn
from TSP_net import TSP_net


class TSPPolicy:
    """Policy that selects the next TSP node in one step using a ``TSP_net`` model.

    This mirrors the per-step logic inside ``TSP_net.forward`` but does not manage
    environment state: the caller supplies the current observation and is
    responsible for applying the chosen action.
    """

    def __init__(self, model: TSP_net):
        # The wrapped network; its sub-modules (action/state encoders, query MLP,
        # K/V projections and decoder) are called directly in ``select_action``.
        self.model = model

    @torch.no_grad()
    def select_action(
        self,
        obs: Dict[str, torch.Tensor],
        action_k: int,
        state_k: List[int],
        deterministic: bool = False,
        if_use_local_mask: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Select the next global node index for each instance in the batch.

        Args:
            obs: Observation dict with keys:
                - ``"x"``: node features indexed as ``(bsz, nb_nodes, ...)``; the
                  trailing size is assumed to equal ``model.dim_input``.
                - ``"last_visited_node"``: one point per instance (reshaped to
                  ``(-1, model.dim_input)`` as the kNN query).
                - ``"first_visited_node"``: forwarded to the state encoders.
                - ``"mask_global"``: boolean ``(bsz, nb_nodes)`` mask, ``True`` for
                  nodes that are still available. Every instance must have the
                  same number of available nodes.
            action_k: Number of nearest available nodes used as candidate actions
                (clipped to the number of available nodes).
            state_k: Neighbourhood size for each state encoder; must contain at
                least ``model.num_state_encoder`` entries, each (after clipping to
                the number of available nodes) no smaller than the clipped
                ``action_k``.
            deterministic: If True take the most probable candidate, otherwise
                sample from the decoder's distribution.
            if_use_local_mask: If True build a distance mask over the action
                candidates and pass it to the action encoder and the decoder.

        Returns:
            ``(action_global_idx, log_prob, info)`` where
            - action_global_idx: ``(bsz,)`` long tensor of chosen node indices in
              the original graph indexing;
            - log_prob: ``(bsz,)`` log-probability of the chosen actions (computed
              under ``torch.no_grad``, so it carries no gradient);
            - info: ``{"candidate_global_idx": (bsz, k_action) global indices of the
              candidates, "candidate_probs": decoder probabilities over them}``.

        Raises:
            ValueError: if instances have differing numbers of available nodes,
                if no node is available, or if a state neighbourhood is smaller
                than the action neighbourhood.
        """
        x: torch.Tensor = obs["x"]
        last_visited_node: torch.Tensor = obs["last_visited_node"]
        first_visited_node: torch.Tensor = obs["first_visited_node"]
        mask_global: torch.Tensor = obs["mask_global"]

        dim_input = self.model.dim_input
        bsz = x.shape[0]
        zero_to_bsz = torch.arange(bsz, device=x.device)

        # Boolean indexing flattens across the batch; reshaping back to (bsz, -1)
        # only recovers per-instance rows if every row has the same count. With
        # unequal counts the reshape can still succeed (total divisible by bsz)
        # and silently mix nodes from different instances, so check explicitly.
        available_per_instance = mask_global.sum(dim=1)
        if bsz > 0 and not bool(torch.all(available_per_instance == available_per_instance[0])):
            raise ValueError("mask_global must mark the same number of available nodes in every instance")

        # Global indices of the available nodes, one row per instance: (bsz, num_nodes).
        all_idx = torch.arange(x.shape[1], device=x.device).repeat((bsz, 1))
        unvisited_matrix = torch.reshape(all_idx[mask_global], (bsz, -1))
        num_nodes = unvisited_matrix.size(1)
        if num_nodes == 0:
            raise ValueError("no available nodes left to select from")

        # Sub-graph of available nodes. b_graph is the per-point batch vector
        # [0]*num_nodes + [1]*num_nodes + ... used both for gathering and by knn.
        b_graph = zero_to_bsz.repeat_interleave(num_nodes)
        graph = x[b_graph, unvisited_matrix.view(-1)].view((bsz, -1, dim_input))

        k_action = min(action_k, num_nodes)
        k_state = min(max(state_k), num_nodes) if self.model.num_state_encoder > 0 else k_action

        # Each state encoder contributes emb_state[:, :k_action] as candidate
        # embeddings, while rows temp_k and temp_k+1 are its query tokens. A
        # neighbourhood smaller than k_action would either mix those query rows
        # into the candidates or fail later with an opaque shape error.
        for i in range(self.model.num_state_encoder):
            if min(state_k[i], num_nodes) < k_action:
                raise ValueError(
                    f"state_k[{i}]={state_k[i]} is smaller than action_k={action_k} "
                    f"(with {num_nodes} available nodes)"
                )

        # kNN around each instance's last visited node. Row 1 of the output indexes
        # the flattened sub-graph, so modulo num_nodes gives a per-instance local
        # index. The prefix slices below assume torch_cluster returns each query's
        # neighbours contiguously and nearest-first (as TSP_net.forward does).
        knn_output = knn(
            graph.view((-1, dim_input)),
            last_visited_node.view((-1, dim_input)),
            k_state,
            b_graph,
            zero_to_bsz,
        )
        knn_idx = (knn_output[1, :] % num_nodes).view((bsz, k_state)).contiguous()

        # Action encoder: rows [:k_action] are candidate embeddings, row k_action
        # is the query token.
        action_idx = knn_idx[:, :k_action].contiguous()
        action_mask: Optional[torch.Tensor] = None
        if if_use_local_mask:
            action_mask = create_distance_mask_for_knn(last_visited_node, action_idx, graph)
        emb_action = self.model.action_encoder(graph, action_idx, last_visited_node, mask=action_mask)
        emb_q = emb_action[:, k_action:(k_action + 1), :]
        emb_other = emb_action[:, :k_action, :]

        # State encoders: append their two query tokens to the query and their
        # candidate embeddings to emb_other, both along the feature dimension.
        for i in range(self.model.num_state_encoder):
            temp_k = min(state_k[i], num_nodes)
            temp_idx = knn_idx[:, :temp_k].contiguous()
            emb_state = self.model.state_encoders[i](graph, temp_idx, last_visited_node, first_visited_node)
            emb_q = torch.cat(
                (emb_q, emb_state[:, temp_k:(temp_k + 1), :], emb_state[:, (temp_k + 1):(temp_k + 2), :]),
                dim=2,
            )
            emb_other = torch.cat((emb_other, emb_state[:, :k_action, :]), dim=2)

        # Map local candidate indices back to original graph indices: (bsz, k_action).
        candidate_global_idx = torch.gather(unvisited_matrix, 1, action_idx)

        mask_for_decoder = action_mask.bool() if action_mask is not None else None

        # Decode one step: project query and candidate embeddings and let the
        # decoder produce probabilities over the k_action candidates.
        h_q = self.model.query_mlp(emb_q)
        K_att_decoder = self.model.WK_att_decoder(emb_other)
        V_att_decoder = self.model.WV_att_decoder(emb_other)
        prob_next_node = self.model.decoder(h_q, K_att_decoder, V_att_decoder, mask_for_decoder)

        if deterministic:
            idx = torch.argmax(prob_next_node, dim=1)
        else:
            idx = Categorical(prob_next_node).sample()

        chosen_global_idx = candidate_global_idx[zero_to_bsz, idx]
        log_prob = torch.log(prob_next_node[zero_to_bsz, idx])

        info: Dict[str, Any] = {
            "candidate_global_idx": candidate_global_idx,
            "candidate_probs": prob_next_node,
        }
        return chosen_global_idx, log_prob, info
