import torch
import torch.nn as nn
import torch.nn.functional as F

from src.pe.base import TourLayer


class TourLayer_APE(TourLayer):
    """Tour layer with absolute positional (APE) and tour-angle embeddings.

    For every customer the layer concatenates four D-dimensional vectors
    (customer embedding, summed embedding of its tour, learned absolute
    position embedding, and a linear projection of (sin, cos) of the tour
    angle) and maps them back to D dimensions with ``tour_combiner``.

    NOTE(review): ``self.embedding_dim``, ``self.feedforward_layer`` and
    ``self.add_and_normalize`` are not defined here; they are presumably
    provided by ``TourLayer`` -- confirm against the base class.
    """

    def __init__(self, **model_params):
        """Build the positional/angle embedding modules.

        Args:
            **model_params: Forwarded unchanged to ``TourLayer.__init__``.
                The optional key ``"max_tour_positions"`` (default 256)
                sets the size of the absolute position embedding table.
        """
        super().__init__(**model_params)
        embedding_dim = self.embedding_dim

        # Pick a generous table size; positions beyond it are clamped to the
        # last entry at runtime instead of raising an index error.
        self.max_tour_positions = int(model_params.get("max_tour_positions", 256))
        self.abs_pos_embed = nn.Embedding(self.max_tour_positions, embedding_dim)
        # Projects the 2-vector (sin(angle), cos(angle)) to the embedding size.
        self.angle_embed = nn.Linear(2, embedding_dim)

        # 4D -> D combiner. The attribute name is kept so code that refers to
        # ``tour_combiner`` elsewhere keeps working (presumably this replaces
        # a 2D -> D combiner created by the base class -- confirm).
        self.tour_combiner = nn.Linear(embedding_dim * 4, embedding_dim)

    def forward(
        self,
        batch_size: int,
        num_customers: int,
        tour_index: torch.Tensor,   # shape: [B, N], tour id per customer, -1 for unvisited
        out: torch.Tensor,           # shape: [B, N+1, D], index 0 is depot
        pos_index: torch.Tensor,     # shape: [B, N], position within the tour, -1 for unvisited
        cur_dist: torch.Tensor,      # currently unused by this layer
        tour_angle: torch.Tensor     # shape: [B, N], angle passed to sin/cos (radians)
    ) -> torch.Tensor:
        """Mix tour, position and angle information into customer embeddings.

        Args:
            batch_size: Number of instances B in the batch.
            num_customers: Number of customers N (excluding the depot).
            tour_index: Integer tensor [B, N] with the tour id of every
                customer, or -1 for unvisited customers. It is NOT modified.
            out: Node embeddings [B, N+1, D]; row 0 is the depot.
            pos_index: Integer tensor [B, N] used as index into the absolute
                position embedding table (clamped to the table range).
            cur_dist: Accepted for interface compatibility; not used here.
            tour_angle: Tensor [B, N]; its sine and cosine are embedded.

        Returns:
            Tensor [B, N+1, D]: the untouched depot embedding followed by
            the updated customer embeddings.
        """
        # Exclude the depot from customer embeddings
        customer_embeddings = out[:, 1:]
        embed_dim = customer_embeddings.shape[2]

        # Largest real tour id as a plain int, so it can be used as a size.
        max_tour_index = int(tour_index.max())

        # Unvisited customers (possible e.g. in the PCVRP) are marked with -1.
        # They are collected in a dummy tour with id max_tour_index + 1.
        # masked_fill is out-of-place: the caller's tensor keeps its -1
        # markers, so passing the same tour_index to another layer still
        # detects the unvisited customers.
        unvisited = tour_index == -1
        all_nodes_visited = not bool(unvisited.any())
        if not all_nodes_visited:
            max_tour_index += 1
            tour_index = tour_index.masked_fill(unvisited, max_tour_index)

        # SECTION: >>> PE start >>>

        # Clamp pos_index into the valid range of the embedding table
        # (-1 becomes 0, overly long tours share the last entry).
        clamped_pos_index = pos_index.clamp(min=0, max=self.max_tour_positions - 1)
        pos_emb = self.abs_pos_embed(clamped_pos_index)  # [B, N, D]

        # Customers in the dummy tour get no positional signal.
        if not all_nodes_visited:
            pos_emb = pos_emb.masked_fill(unvisited.unsqueeze(-1), 0.0)

        pos_emb = pos_emb.to(dtype=customer_embeddings.dtype)

        # Encode the angle as (sin, cos) so that it is continuous across the
        # 2*pi wrap-around, then project it to D dimensions.
        sin_cos = torch.stack([torch.sin(tour_angle), torch.cos(tour_angle)], dim=-1)
        angle_emb = self.angle_embed(sin_cos)  # [B, N, D]

        # SECTION: <<< PE end <<<

        # One accumulator row per tour; new_zeros inherits dtype AND device
        # from the customer embeddings, so this also works on GPU.
        tour_embeddings = customer_embeddings.new_zeros(
            batch_size, max_tour_index + 1, embed_dim
        )

        # Sum the embeddings of all customers belonging to the same tour.
        scatter_index = tour_index[:, :, None].expand(-1, -1, embed_dim)
        tour_embeddings.scatter_add_(1, scatter_index, customer_embeddings)

        # The dummy tour (last row) must stay neutral.
        if not all_nodes_visited:
            tour_embeddings[:, -1] = 0

        # Give every customer the embedding of the tour it belongs to.
        customer_tour_embeddings = torch.gather(tour_embeddings, 1, scatter_index)

        # SECTION: >>> PE Embedding Start >>>

        combined_embeddings = torch.cat(
            (customer_embeddings, customer_tour_embeddings, pos_emb, angle_emb), dim=2
        )
        # The view doubles as a shape check against batch_size / num_customers.
        combined_embeddings = combined_embeddings.view(
            batch_size, num_customers, self.embedding_dim * 4
        )

        # SECTION: <<< PE Embedding End <<<

        # Apply the tour combiner layer and activation
        combined_embeddings = F.relu(self.tour_combiner(combined_embeddings))

        # Apply feedforward layer
        combined_embeddings = self.feedforward_layer(combined_embeddings)

        # Residual connection + normalization with the original embeddings
        normalized_embeddings = self.add_and_normalize(customer_embeddings, combined_embeddings)

        # Re-add the depot to the embeddings
        out = torch.cat((out[:, [0]], normalized_embeddings), dim=1)

        return out