import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.pe.base import TourLayer


class TourLayer_HADES_NORM_SIN(TourLayer):
    """Tour-aware customer layer with a normalized-distance sinusoidal PE.

    For every customer, four D-wide branches are concatenated:

    1. the customer's own embedding,
    2. the sum of the embeddings of all customers assigned to the same tour,
    3. sin/cos Fourier features of the customer's distance-so-far, normalized
       per tour to ``r`` in [0, 1] and scaled by the learnable ``pos_gate``,
    4. a linear projection of ``(sin(tour_angle), cos(tour_angle))``.

    The concatenation is fused by ``tour_combiner`` + ReLU and
    ``self.feedforward_layer``. It is then combined with the customer embedding
    through ``self.add_and_normalize``; both of these come from the
    ``TourLayer`` base, which is not shown here. The depot row (index 0 of
    ``out``) is passed through unchanged.

    Optional ``model_params`` read by this class:
        dist_fpe_base (float, default 10000.0): base of the log-spaced inverse
            frequencies used by the distance PE.
        dist_norm_reduce (str, default "sum"): per-tour statistic of
            ``cur_dist`` used as the normaliser for ``r``. Valid values are
            ``"sum"`` (original behaviour) or ``"amax"`` (largest
            distance-so-far on the tour).
    """

    _DIST_NORM_REDUCTIONS = ("sum", "amax")

    def __init__(self, **model_params):
        super().__init__(**model_params)
        D = self.embedding_dim

        # Base of the inverse frequencies inv_freq[i] = base ** (-i / half).
        # The frequencies span (1/base, 1], so a *larger* base gives *lower*
        # frequencies.
        # NOTE(review): r is in [0, 1], so every angle r * inv_freq is <= 1 rad
        # and most channels barely vary with r. A scale on r may be needed for
        # this PE to be discriminative.
        self.dist_fpe_base = float(model_params.get("dist_fpe_base", 10_000.0))

        # NOTE(review): cur_dist is documented as distance-so-far within a tour.
        # Summing cumulative distances ("sum") does not yield the tour length, so
        # the last customer of a tour gets r < 1 (e.g. 0.5 for cumulative
        # distances [1, 2, 3]). "amax" normalizes by the largest distance-so-far
        # on the tour instead. "sum" stays the default to keep trained models
        # unchanged.
        self.dist_norm_reduce = str(model_params.get("dist_norm_reduce", "sum"))
        if self.dist_norm_reduce not in self._DIST_NORM_REDUCTIONS:
            raise ValueError(
                f"dist_norm_reduce must be one of {self._DIST_NORM_REDUCTIONS}, "
                f"got {self.dist_norm_reduce!r}"
            )

        # Log-spaced inverse frequencies, as in Transformer PEs, applied to r.
        half = max(1, D // 2)
        inv_freq = torch.exp(
            -math.log(self.dist_fpe_base) * torch.arange(0, half, dtype=torch.float32) / half
        )  # [half]
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Angle branch: (sin(angle), cos(angle)) -> D.
        self.angle_embed = nn.Linear(2, D)

        # Combiner over the four concatenated D-wide branches -> D.
        self.tour_combiner = nn.Linear(D * 4, D)

        # Learnable scalar gate on the distance PE, initialised to 1.0.
        self.pos_gate = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def _fourier_from_r(self, r: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
        """Map normalized distances to sinusoidal features.

        Args:
            r: ``[B, N]`` tensor, expected to lie in [0, 1].
            out_dtype: dtype of the returned tensor.

        Returns:
            ``[B, N, D]`` tensor laid out as ``[sin(r * inv_freq), cos(r * inv_freq)]``.
            When ``2 * half < D`` (odd D) it is zero-padded on the right. When
            ``2 * half > D`` (only possible for D == 1) it is truncated to D.
        """
        angles = r.unsqueeze(-1) * self.inv_freq  # [B, N, half]
        pe = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # [B, N, 2*half]

        D = self.embedding_dim
        width = pe.shape[-1]
        if width < D:
            pe = F.pad(pe, (0, D - width), value=0.0)
        elif width > D:
            pe = pe[..., :D]
        return pe.to(dtype=out_dtype)

    @staticmethod
    def _assign_dummy_tour(tour_index: torch.Tensor):
        """Move unvisited customers (tour id -1) onto a fresh dummy tour.

        The dummy id is one past the current maximum tour id. The input tensor
        is never modified in place.

        Returns:
            ``(tour_index, has_dummy)``, where ``has_dummy`` tells whether any
            -1 entry was replaced.
        """
        unvisited = tour_index == -1
        if not bool(unvisited.any()):
            return tour_index, False
        return tour_index.masked_fill(unvisited, tour_index.max() + 1), True

    def _relative_distance(self, cur_dist, tour_index, num_tours, dummy_id, device, dtype):
        """Normalize ``cur_dist`` by a per-tour statistic so that ``r`` is in [0, 1].

        Args:
            cur_dist: ``[B, N]`` distance-so-far of each customer.
            tour_index: ``[B, N]`` non-negative tour id per customer.
            num_tours: number of distinct tour ids, including the dummy tour.
            dummy_id: id of the dummy tour, or ``None`` if there is none.
            device, dtype: device and dtype of the customer embeddings.

        Returns:
            ``[B, N]`` tensor ``r``. Customers on the dummy tour get 0.
        """
        # Work in at least float32: in float16, 1e-8 underflows to 0 and the
        # division could produce NaN when a tour's normaliser is 0.
        compute_dtype = torch.promote_types(dtype, torch.float32)
        dist = cur_dist.to(device=device, dtype=compute_dtype).unsqueeze(-1)  # [B, N, 1]
        idx = tour_index.unsqueeze(-1)  # [B, N, 1]

        per_tour = torch.zeros(dist.shape[0], num_tours, 1, dtype=compute_dtype, device=device)
        if self.dist_norm_reduce == "sum":
            per_tour.scatter_add_(1, idx, dist)
        else:  # "amax"
            per_tour.scatter_reduce_(1, idx, dist, reduce="amax", include_self=True)

        denom = torch.gather(per_tour, 1, idx)  # [B, N, 1]
        r = (dist / (denom + 1e-8)).squeeze(-1).clamp(0, 1)  # [B, N]

        # Keep dummy-tour customers neutral.
        if dummy_id is not None:
            r = r.masked_fill(tour_index == dummy_id, 0.0)
        return r

    @staticmethod
    def _tour_context(customer_embeddings, tour_index, num_tours, dummy_id):
        """Return, for every customer, the summed embeddings of its tour.

        Args:
            customer_embeddings: ``[B, N, D]``.
            tour_index: ``[B, N]`` non-negative tour id per customer.
            num_tours: number of distinct tour ids, including the dummy tour.
            dummy_id: id of the dummy tour, whose sum is zeroed, or ``None``.

        Returns:
            ``[B, N, D]`` tensor.
        """
        B, _, D = customer_embeddings.shape
        idx = tour_index.unsqueeze(-1).expand(-1, -1, D)

        # new_zeros inherits the embeddings' device and dtype, so this also
        # works for CUDA inputs.
        tour_embeddings = customer_embeddings.new_zeros(B, num_tours, D)
        tour_embeddings.scatter_add_(1, idx, customer_embeddings)

        # Unvisited customers must not see each other as tour mates.
        if dummy_id is not None:
            tour_embeddings[:, dummy_id] = 0

        return torch.gather(tour_embeddings, 1, idx)

    def forward(
        self,
        batch_size: int,
        num_customers: int,
        tour_index: torch.Tensor,   # [B, N], -1 for unvisited
        out: torch.Tensor,          # [B, N+1, D]
        pos_index: torch.Tensor,    # [B, N] (kept for signature parity; unused here)
        cur_dist: torch.Tensor,     # [B, N] distance-so-far within tour
        tour_angle: torch.Tensor    # [B, N]
    ) -> torch.Tensor:
        """Update the customer rows of ``out`` with tour and positional context.

        Args:
            batch_size: B.
            num_customers: N.
            tour_index: ``[B, N]`` tour id per customer, -1 for unvisited.
            out: ``[B, N+1, D]`` node embeddings, with the depot at index 0.
            pos_index: unused; kept for signature parity with other layers.
            cur_dist: ``[B, N]`` distance-so-far within the customer's tour.
            tour_angle: ``[B, N]`` angle fed through sin/cos.

        Returns:
            ``[B, N+1, D]`` tensor. The depot row is unchanged and the customer
            rows are updated.
        """
        customer_embeddings = out[:, 1:]  # [B, N, D]; row 0 is the depot
        device = customer_embeddings.device
        dtype = customer_embeddings.dtype

        tour_index, has_dummy = self._assign_dummy_tour(tour_index)
        num_tours = int(tour_index.max().item()) + 1  # includes the dummy tour, if any
        dummy_id = num_tours - 1 if has_dummy else None

        # SECTION: >>> PE start >>>
        # Relative distance r in [0, 1] per tour -> gated sin/cos features.
        r = self._relative_distance(cur_dist, tour_index, num_tours, dummy_id, device, dtype)
        pos_emb = self._fourier_from_r(r, out_dtype=dtype) * self.pos_gate.to(dtype)  # [B, N, D]

        # Angle branch: (sin, cos) of the tour angle -> D.
        angle = tour_angle.to(device=device, dtype=dtype)
        angle_emb = self.angle_embed(
            torch.stack([torch.sin(angle), torch.cos(angle)], dim=-1)
        )  # [B, N, D]
        # SECTION: <<< PE end <<<

        customer_tour_embeddings = self._tour_context(
            customer_embeddings, tour_index, num_tours, dummy_id
        )  # [B, N, D]

        # SECTION: >>> PE Embedding Start >>>
        combined_embeddings = torch.cat(
            (customer_embeddings, customer_tour_embeddings, pos_emb, angle_emb), dim=2
        )
        combined_embeddings = combined_embeddings.view(batch_size, num_customers, self.embedding_dim * 4)
        # SECTION: <<< PE Embedding End <<<

        # Fuse the four branches.
        combined_embeddings = F.relu(self.tour_combiner(combined_embeddings))
        combined_embeddings = self.feedforward_layer(combined_embeddings)

        # Residual + norm with respect to the content branch.
        normalized_embeddings = self.add_and_normalize(customer_embeddings, combined_embeddings)

        # Re-attach the untouched depot row.
        return torch.cat((out[:, [0]], normalized_embeddings), dim=1)