import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.pe.base import TourLayer


class TourLayer_ROPE(TourLayer):
    """Tour layer whose positional branch is a rotary (RoPE) embedding.

    Each customer embedding is fused with four D-dim features:
    its own content embedding, the sum of the content embeddings of every
    customer in the same tour, a RoPE-rotated learnable anchor that encodes
    the customer's position, and a linear projection of (sin, cos) of the
    customer's tour angle. The 4*D concatenation is fused back to D and
    residually merged via the base-class ``feedforward_layer`` /
    ``add_and_normalize``.

    Recognized ``model_params`` (in addition to those consumed by
    ``TourLayer``):
        rope_base (float, default 10000.0): base of the geometric frequency
            schedule ``inv_freq[k] = rope_base ** (-k / half)``.
        rope_normalize (bool, default False): divide each position by
            ``tour_len - 1`` (clamped to >= 1) of its tour.
        rope_theta_scale (float, default 1.0): multiplier applied to every
            rotation angle (equivalently, to every position).
    """

    def __init__(self, **model_params):
        super().__init__(**model_params)
        D = self.embedding_dim

        # ---- RoPE params ----
        self.rope_base = float(model_params.get("rope_base", 10_000.0))
        self.rope_normalize = bool(model_params.get("rope_normalize", False))
        self.rope_theta_scale = float(model_params.get("rope_theta_scale", 1.0))

        # Channels are rotated in (even, odd) pairs; with an odd D the last
        # channel has no partner and is passed through unrotated.
        self._has_tail = D % 2 != 0
        self._paired = D - 1 if self._has_tail else D

        half = self._paired // 2
        inv_freq = torch.exp(
            -math.log(self.rope_base) * torch.arange(0, half, dtype=torch.float32) / half
        )
        # Derived purely from hyper-parameters, so it is not saved in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Learnable anchor that RoPE rotates to produce a D-dim positional feature.
        # Initialized small to avoid dominating content early on.
        self.rope_anchor = nn.Parameter(torch.randn(1, 1, D) * 0.02)

        # Angle branch: (sin, cos) of the tour angle -> D.
        self.angle_embed = nn.Linear(2, D)

        # Combiner: 4D -> D (customer, tour_sum, pos_emb_rope, angle_emb).
        self.tour_combiner = nn.Linear(D * 4, D)

    @staticmethod
    def _rotate_pairs(x_even, x_odd, cos, sin):
        """Apply the 2-D rotation [[cos, -sin], [sin, cos]] to (x_even, x_odd)."""
        x_rot_even = x_even * cos - x_odd * sin
        x_rot_odd = x_even * sin + x_odd * cos
        return x_rot_even, x_rot_odd

    def _rope_rotate(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        """Rotate the first ``self._paired`` channels of ``x`` by position ``pos``.

        Channel pair ``(2k, 2k + 1)`` is rotated by the angle
        ``pos * inv_freq[k] * rope_theta_scale``. An unpaired trailing channel
        (odd embedding size) is returned unchanged.

        Angles are always computed in float32. Low-precision dtypes such as
        bf16 cannot represent large integer positions exactly, which would
        make distinct positions share one rotation.

        Args:
            x: ``[..., D]`` tensor, e.g. ``[B, N, D]``.
            pos: positions with shape ``x.shape[:-1]``, e.g. ``[B, N]``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        theta = pos.to(torch.float32).unsqueeze(-1) * self.inv_freq.to(torch.float32)  # [..., H]
        if self.rope_theta_scale != 1.0:
            # Scaling every angle is equivalent to scaling every position.
            theta = theta * self.rope_theta_scale
        cos = torch.cos(theta)
        sin = torch.sin(theta)

        x_even = x[..., 0:self._paired:2]  # [..., H]
        x_odd = x[..., 1:self._paired:2]   # [..., H]
        x_even_rot, x_odd_rot = self._rotate_pairs(x_even, x_odd, cos, sin)

        # Re-interleave: [..., H, 2] -> [..., 2H] restores even/odd channel order.
        x_rot = torch.stack((x_even_rot, x_odd_rot), dim=-1).flatten(-2).to(x.dtype)
        if self._has_tail:
            x_rot = torch.cat((x_rot, x[..., self._paired:]), dim=-1)
        return x_rot

    def _tour_positions(
        self, pos_index: torch.Tensor, tour_index: torch.Tensor, num_tours: int
    ) -> torch.Tensor:
        """Return ``pos_index`` as float32, optionally normalized per tour.

        With ``rope_normalize`` every position is divided by ``tour_len - 1``
        (clamped to >= 1). Here ``tour_len`` is the number of customers that
        share the same ``tour_index``. If ``pos_index`` counts from 0 within
        each tour (TODO confirm against caller), this maps positions to [0, 1].
        """
        pos = pos_index.to(torch.float32)
        if not self.rope_normalize:
            return pos
        # Per-tour customer counts via scatter_add: [B, T] memory instead of
        # the [B, N, T] a one-hot encoding would need.
        counts = torch.zeros(
            tour_index.shape[0], num_tours, dtype=torch.long, device=tour_index.device
        )
        counts.scatter_add_(1, tour_index, torch.ones_like(tour_index))
        tour_len = counts.gather(1, tour_index)  # [B, N]
        return pos / (tour_len - 1).clamp_min(1).to(torch.float32)

    def forward(
        self,
        batch_size: int,
        num_customers: int,
        tour_index: torch.Tensor,   # [B, N]
        out: torch.Tensor,          # [B, N+1, D]
        pos_index: torch.Tensor,    # [B, N]
        cur_dist: torch.Tensor,     # [B, N]  (kept for signature parity; unused here)
        tour_angle: torch.Tensor    # [B, N]
    ) -> torch.Tensor:
        """Enrich customer embeddings with tour, RoPE-position and angle context.

        Args:
            batch_size: B.
            num_customers: N.
            tour_index: ``[B, N]`` tour id per customer; ``-1`` marks an
                unvisited customer. The tensor is not modified.
            out: ``[B, N+1, D]`` node embeddings; row 0 is the depot.
            pos_index: ``[B, N]`` position of each customer in its tour.
            cur_dist: unused, kept for signature parity with other tour layers.
            tour_angle: ``[B, N]`` angle per customer (radians are assumed,
                since ``sin``/``cos`` are applied).

        Returns:
            ``[B, N+1, D]`` tensor: the depot row unchanged, followed by the
            updated customer embeddings.
        """
        # Exclude the depot from customer embeddings.
        customer_embeddings = out[:, 1:]  # [B, N, D]
        B, N, D = customer_embeddings.shape
        dtype = customer_embeddings.dtype

        # Unvisited customers (-1) are grouped into one extra "dummy" tour with
        # id max+1. Remap a local copy rather than writing into the caller's
        # tensor. An in-place rewrite erases the -1 markers, so a later call
        # with the same tensor would silently skip the dummy-tour masking.
        max_tour_index = int(tour_index.max())
        unvisited = tour_index == -1
        has_dummy = bool(unvisited.any())
        if has_dummy:
            max_tour_index += 1
            tour_index = tour_index.masked_fill(unvisited, max_tour_index)
        num_tours = max_tour_index + 1

        # SECTION: >>> PE start >>>
        # RoPE positional branch: rotate the learnable anchor by each position.
        # The rotation runs in float32 and the result is cast to the content dtype.
        pos = self._tour_positions(pos_index, tour_index, num_tours)  # [B, N]
        anchor = self.rope_anchor.to(
            device=customer_embeddings.device, dtype=torch.float32
        ).expand(B, N, D)
        pos_emb = self._rope_rotate(anchor, pos).to(dtype)  # [B, N, D]

        # Neutral (zero) positional feature for customers in the dummy tour.
        if has_dummy:
            pos_emb = pos_emb.masked_fill(unvisited.unsqueeze(-1), 0.0)

        # Angle branch.
        angle_in = tour_angle.to(dtype=dtype)  # [B, N]
        sin_cos = torch.stack((torch.sin(angle_in), torch.cos(angle_in)), dim=-1)  # [B, N, 2]
        angle_emb = self.angle_embed(sin_cos)  # [B, N, D]
        # SECTION: <<< PE end <<<

        # Tour embeddings: sum of customer embeddings per tour. new_zeros keeps
        # the content dtype AND device (plain torch.zeros would land on CPU).
        index = tour_index.unsqueeze(-1).expand(-1, -1, D)  # [B, N, D]
        tour_embeddings = customer_embeddings.new_zeros(batch_size, num_tours, D)
        tour_embeddings.scatter_add_(1, index, customer_embeddings)

        # The dummy tour is the last one; it carries no tour context.
        if has_dummy:
            tour_embeddings[:, -1] = 0

        # Gather each customer's tour embedding.
        customer_tour_embeddings = torch.gather(tour_embeddings, 1, index)  # [B, N, D]

        # SECTION: >>> PE Embedding Start >>>
        # Concat content, tour-sum, RoPE-pos, angle -> 4*D.
        combined_embeddings = torch.cat(
            (customer_embeddings, customer_tour_embeddings, pos_emb, angle_emb), dim=2
        )
        combined_embeddings = combined_embeddings.view(
            batch_size, num_customers, self.embedding_dim * 4
        )
        # SECTION: <<< PE Embedding End <<<

        # Fuse, residual-normalize, and re-attach the depot row.
        combined_embeddings = F.relu(self.tour_combiner(combined_embeddings))
        combined_embeddings = self.feedforward_layer(combined_embeddings)
        normalized_embeddings = self.add_and_normalize(customer_embeddings, combined_embeddings)
        return torch.cat((out[:, [0]], normalized_embeddings), dim=1)