import torch
import torch.nn as nn
import torch.nn.functional as F

from src.pe.base import TourLayer


class TourLayer_RPE(TourLayer):
    """Tour layer that adds relative-position and tour-angle encodings.

    For every customer the layer concatenates four D-dimensional parts:
    the customer embedding, the summed embedding of the tour the customer
    belongs to, an embedding of the customer's relative position inside
    that tour, and an embedding of the customer's tour angle. The result is
    projected back to D dimensions, passed through ``self.feedforward_layer``
    and merged with the input via ``self.add_and_normalize``.

    ``self.embedding_dim``, ``self.feedforward_layer`` and
    ``self.add_and_normalize`` are not defined here; they are expected to be
    provided by ``TourLayer``.
    """

    def __init__(self, **model_params):
        super().__init__(**model_params)
        embedding_dim = self.embedding_dim

        # Projects the scalar relative position of a customer within its
        # tour to a D-dimensional embedding.
        self.rel_pos_proj = nn.Linear(1, embedding_dim)

        # Projects the (sin, cos) pair of the tour angle to a D-dimensional embedding.
        self.angle_embed = nn.Linear(2, embedding_dim)

        # Fuses the four concatenated D-dimensional parts (customer, tour,
        # relative position, angle) back to D dimensions. The attribute name is
        # kept as-is to minimize code changes elsewhere (presumably it replaces
        # a combiner of the same name in the base class -- TODO confirm).
        self.tour_combiner = nn.Linear(embedding_dim * 4, embedding_dim)

    def forward(
        self,
        batch_size: int,
        num_customers: int,
        tour_index: torch.Tensor,
        out: torch.Tensor,
        pos_index: torch.Tensor,
        cur_dist: torch.Tensor,
        tour_angle: torch.Tensor
    ) -> torch.Tensor:
        """Update the customer embeddings with tour and positional context.

        Args:
            batch_size: Batch size B.
            num_customers: Number of customers N (depot excluded).
            tour_index: Integer tensor [B, N] with the tour id of each customer,
                or -1 for customers that are not part of any tour. It is not
                modified by this method.
            out: Node embeddings [B, N+1, D]; index 0 along dim 1 is the depot.
            pos_index: Tensor [B, N] with the position of each customer inside
                its tour. Entries of unvisited customers are ignored.
            cur_dist: Accepted for interface compatibility; not used here.
            tour_angle: Tensor [B, N] with one angle per customer, in radians
                (it is fed to ``torch.sin`` / ``torch.cos``).

        Returns:
            The depot embedding followed by the updated customer embeddings,
            concatenated along dim 1.
        """
        # Exclude the depot from customer embeddings
        customer_embeddings = out[:, 1:]
        embedding_dim = customer_embeddings.shape[2]

        # Unvisited customers (possible for example in the PCVRP) are marked
        # with -1. They are collected in a dummy tour appended after the last
        # real tour so that the scatter/gather indices below stay valid.
        unvisited = tour_index == -1
        all_nodes_visited = not bool(unvisited.any())
        num_tours = int(tour_index.max().item()) + 1
        if not all_nodes_visited:
            # Out-of-place on purpose: overwriting the caller's tensor would
            # erase the -1 markers, so a later call receiving the same tensor
            # could no longer tell that these customers are unvisited.
            tour_index = tour_index.masked_fill(unvisited, num_tours)
            num_tours += 1  # the dummy tour takes the last id

        # SECTION: >>> PE start >>>

        # Count the customers of every tour by scattering ones into a [B, T]
        # buffer, then look up the size of each customer's own tour.
        tour_sizes = torch.zeros(
            tour_index.shape[0], num_tours, dtype=torch.long, device=tour_index.device
        )
        tour_sizes.scatter_add_(1, tour_index, torch.ones_like(tour_index))
        tour_len = torch.gather(tour_sizes, 1, tour_index)  # [B, N]

        # Relative position: pos_index / max(tour_len - 1, 1). The clamp avoids
        # a division by zero for tours that contain a single customer.
        denom = (tour_len - 1).clamp_min(1).to(torch.float32)
        rel_pos = (pos_index.to(torch.float32) / denom).unsqueeze(-1)  # [B, N, 1]

        # Project the relative position to the embedding dimension. The input
        # is cast to the projector's parameter dtype so the layer also works
        # when the module is not held in float32.
        rel_pos_emb = self.rel_pos_proj(rel_pos.to(self.rel_pos_proj.weight.dtype))  # [B, N, D]

        # Customers of the dummy tour get no positional signal (their pos_index
        # entries carry no meaning), so their embedding is zeroed.
        if not all_nodes_visited:
            rel_pos_emb = rel_pos_emb.masked_fill(unvisited.unsqueeze(-1), 0.0)

        rel_pos_emb = rel_pos_emb.to(dtype=customer_embeddings.dtype)

        # Encode the angle as (sin, cos) before projecting it.
        # NOTE(review): unlike the relative position, the angle embedding is
        # not zeroed for unvisited customers -- confirm this is intended.
        sin_cos = torch.stack([torch.sin(tour_angle), torch.cos(tour_angle)], dim=-1)
        angle_emb = self.angle_embed(sin_cos.to(self.angle_embed.weight.dtype))
        angle_emb = angle_emb.to(dtype=customer_embeddings.dtype)

        # SECTION: <<< PE end <<<

        # Sum the customer embeddings of every tour. The buffer is created from
        # customer_embeddings so it shares its dtype and device.
        tour_embeddings = customer_embeddings.new_zeros(batch_size, num_tours, embedding_dim)
        scatter_index = tour_index.unsqueeze(-1).expand(-1, -1, embedding_dim)
        tour_embeddings.scatter_add_(1, scatter_index, customer_embeddings)

        # The dummy tour is not a real tour: its embedding stays all zero
        if not all_nodes_visited:
            tour_embeddings[:, -1] = 0

        # Give every customer the embedding of the tour it belongs to
        customer_tour_embeddings = torch.gather(tour_embeddings, 1, scatter_index)

        # SECTION: >>> PE Embedding Start >>>

        combined_embeddings = torch.cat(
            (customer_embeddings, customer_tour_embeddings, rel_pos_emb, angle_emb), dim=2
        )
        # The view doubles as a shape check against batch_size / num_customers
        combined_embeddings = combined_embeddings.view(batch_size, num_customers, self.embedding_dim * 4)

        # SECTION: <<< PE Embedding End <<<

        # Apply the tour combiner layer and activation
        combined_embeddings = F.relu(self.tour_combiner(combined_embeddings))

        # Apply feedforward layer
        combined_embeddings = self.feedforward_layer(combined_embeddings)

        # Normalize and add embeddings
        normalized_embeddings = self.add_and_normalize(customer_embeddings, combined_embeddings)

        # Re-add the depot to the embeddings
        out = torch.cat((out[:, [0]], normalized_embeddings), dim=1)

        return out