import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from src.pe.base import TourLayer


class TourLayer_SIN(TourLayer):
    """Tour layer that adds sinusoidal position and tour-angle features.

    For every customer, four ``embedding_dim``-wide vectors are concatenated
    and projected back to ``embedding_dim`` by ``tour_combiner``:

    1. the customer's own embedding,
    2. the sum of the embeddings of all customers on the same tour,
    3. a sinusoidal encoding of the customer's position index,
    4. a learned linear embedding of ``(sin(angle), cos(angle))``.

    NOTE(review): ``self.embedding_dim``, ``self.feedforward_layer`` and
    ``self.add_and_normalize`` are not defined in this class; they are
    presumably provided by ``TourLayer`` -- confirm against the base class.
    """

    def __init__(self, **model_params):
        """Build the layer.

        Recognised optional keys in ``model_params`` (in addition to whatever
        the base class consumes):

        * ``fpe_base`` (float, default ``10_000.0``): base of the geometric
          frequency progression used by the sinusoidal encoding.
        * ``fpe_normalize`` (bool, default ``False``): stored on the instance
          only; ``forward`` does not currently apply any normalisation.
        """
        super().__init__(**model_params)
        embedding_dim = self.embedding_dim

        # Sinusoidal positional embedding parameters.
        self.fpe_base = float(model_params.get("fpe_base", 10_000.0))
        # NOTE: kept for configuration compatibility; not used by forward().
        self.fpe_normalize = bool(model_params.get("fpe_normalize", False))

        # Inverse frequencies for half of the dimensions (sin and cos share
        # them), as in the original Transformer positional encoding.
        half = max(1, embedding_dim // 2)
        inv_freq = torch.exp(
            -math.log(self.fpe_base) * torch.arange(0, half, dtype=torch.float32) / half
        )  # shape: [half]
        # Non-persistent: the buffer follows the module across devices but is
        # not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Projects (sin(angle), cos(angle)) to the embedding width.
        self.angle_embed = nn.Linear(2, embedding_dim)

        # Combiner over the four concatenated branches (4*D -> D). The
        # attribute name is kept so code elsewhere does not need to change.
        self.tour_combiner = nn.Linear(embedding_dim * 4, embedding_dim)

    def _sinusoidal(self, pos: torch.Tensor) -> torch.Tensor:
        """Return sinusoidal embeddings for the given positions.

        Args:
            pos: float tensor of positions, e.g. shape ``[B, N]``.

        Returns:
            float32 tensor of shape ``[*pos.shape, embedding_dim]`` laid out
            as ``[sin | cos | zero padding]``.
        """
        # [..., half] = [..., 1] * [half] via broadcasting (radians).
        angles = pos.unsqueeze(-1) * self.inv_freq
        pe = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # [..., 2*half]

        width = pe.shape[-1]
        if width < self.embedding_dim:
            # Odd embedding_dim: one trailing zero column is needed.
            pe = F.pad(pe, (0, self.embedding_dim - width), mode="constant", value=0.0)
        elif width > self.embedding_dim:
            # Only reachable for embedding_dim == 1, where half is clamped to
            # 1 and sin+cos would otherwise yield two columns.
            pe = pe[..., : self.embedding_dim]

        return pe.to(dtype=torch.float32)

    def forward(
        self,
        batch_size: int,
        num_customers: int,
        tour_index: torch.Tensor,   # shape: [B, N], values in {0..T} or -1 for unvisited
        out: torch.Tensor,           # shape: [B, N+1, D], index 0 is depot
        pos_index: torch.Tensor,     # shape: [B, N], position of each customer (see docstring)
        cur_dist: torch.Tensor,      # accepted for interface compatibility; unused here
        tour_angle: torch.Tensor     # shape: [B, N], angle per customer (see docstring)
    ) -> torch.Tensor:
        """Update customer embeddings with tour, position and angle context.

        Args:
            batch_size: batch dimension ``B`` of the inputs.
            num_customers: number of customers ``N`` (depot excluded).
            tour_index: integer tensor ``[B, N]`` with the tour id of every
                customer, or ``-1`` for unvisited customers. It is used as a
                ``scatter_add_``/``gather`` index and is not modified.
            out: embeddings ``[B, N+1, D]``; row 0 along dim 1 is the depot.
            pos_index: ``[B, N]`` tensor fed to the sinusoidal encoding after
                conversion to float32 (presumably the position of the
                customer within its tour -- confirm against the caller).
            cur_dist: not read by this layer.
            tour_angle: ``[B, N]`` tensor passed through ``sin``/``cos``, so
                presumably an angle in radians -- confirm against the caller.

        Returns:
            Tensor ``[B, N+1, D]``: the untouched depot embedding followed by
            the updated customer embeddings.
        """
        # Exclude the depot from customer embeddings.
        customer_embeddings = out[:, 1:]
        embed_dim = customer_embeddings.shape[2]

        # Unvisited customers (possible e.g. in the PCVRP) are marked with -1.
        # They are routed to an extra dummy tour with index max + 1. This is
        # done out-of-place so the caller's tensor keeps its -1 markers.
        unvisited = tour_index == -1
        all_nodes_visited = not bool(unvisited.any())
        num_tours = int(tour_index.max().item()) + 1
        if not all_nodes_visited:
            dummy_id = num_tours
            num_tours += 1
            tour_index = tour_index.masked_fill(unvisited, dummy_id)

        # SECTION: >>> PE start >>>
        # 1) Position branch (sin/cos); kept separate from customer_embeddings.
        # NOTE: self.fpe_normalize is not applied here.
        pos = pos_index.to(torch.float32)
        pos_emb = self._sinusoidal(pos).to(dtype=customer_embeddings.dtype)  # [B, N, D]

        # Customers on the dummy tour get an all-zero positional embedding.
        if not all_nodes_visited:
            pos_emb = pos_emb.masked_fill(unvisited.unsqueeze(-1), 0.0)

        # 2) Angle branch: embed (sin, cos) of the angle.
        angle_in = tour_angle.to(dtype=customer_embeddings.dtype)                   # [B, N]
        sin_cos = torch.stack([torch.sin(angle_in), torch.cos(angle_in)], dim=-1)  # [B, N, 2]
        angle_emb = self.angle_embed(sin_cos)                                       # [B, N, D]
        # SECTION: <<< PE end <<<

        # Sum the customer embeddings of every tour. new_zeros allocates on
        # the same device and with the same dtype as the embeddings.
        gather_index = tour_index.unsqueeze(-1).expand(-1, -1, embed_dim)
        tour_embeddings = customer_embeddings.new_zeros((batch_size, num_tours, embed_dim))
        tour_embeddings.scatter_add_(1, gather_index, customer_embeddings)

        # The dummy tour (last slot) carries no tour information.
        if not all_nodes_visited:
            tour_embeddings[:, -1] = 0

        # Broadcast each tour embedding back to the customers on that tour.
        customer_tour_embeddings = torch.gather(tour_embeddings, 1, gather_index)

        # SECTION: >>> PE Embedding Start >>>
        combined_embeddings = torch.cat(
            (customer_embeddings, customer_tour_embeddings, pos_emb, angle_emb), dim=2
        )
        # The view doubles as a shape check against batch_size / num_customers.
        combined_embeddings = combined_embeddings.view(
            batch_size, num_customers, self.embedding_dim * 4
        )
        # SECTION: <<< PE Embedding End <<<

        # Apply the tour combiner layer and activation.
        combined_embeddings = F.relu(self.tour_combiner(combined_embeddings))

        # Apply feedforward layer.
        combined_embeddings = self.feedforward_layer(combined_embeddings)

        # Normalize and add embeddings.
        normalized_embeddings = self.add_and_normalize(customer_embeddings, combined_embeddings)

        # Re-add the depot to the embeddings.
        return torch.cat((out[:, [0]], normalized_embeddings), dim=1)