import torch
import torch.nn as nn
from torch import Tensor

from ..utils import pairwise_distances
from .alibi import ALiBi
from .layers import EncoderLayer


class Model(nn.Module):
    """Encoder that scores which city to visit next from a given origin.

    Every city starts from a zero embedding. The origin and destination
    cities receive learned marker tokens, and optional features (random
    identifiers, projected coordinates) are added on top. The embeddings go
    through a stack of ``EncoderLayer`` modules, then the origin embedding
    is used as a query against all city embeddings to produce one logit per
    city.
    """

    def __init__(
        self,
        hidden_dim: int,
        ff_dim: int,
        n_heads: int,
        n_layers: int,
        use_alibi: bool = False,
        use_coords: bool = False,
        use_random_ids: bool = False,
        use_rope: bool = False,
        use_ssmax: bool = False,
    ):
        """Build the model.

        ---
        Args:
            hidden_dim: Size of the city embeddings.
            ff_dim: Forwarded to every ``EncoderLayer``
                (presumably the feed-forward width).
            n_heads: Forwarded to ``ALiBi`` and to every ``EncoderLayer``.
            n_layers: Number of stacked ``EncoderLayer`` modules.
            use_alibi: Give the layers ``self.alibi`` applied to the
                pairwise city distances.
            use_coords: Add a linear projection of the coordinates to the
                city embeddings.
            use_random_ids: Add standard-normal noise to the city
                embeddings, resampled at every forward call.
            use_rope: Give the raw coordinates to the layers
                (presumably used as rotary positions).
            use_ssmax: Give the layers the log of the number of attendable
                cities (plus one) for every query position.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.use_alibi = use_alibi
        self.use_coords = use_coords
        self.use_random_ids = use_random_ids
        self.use_rope = use_rope
        self.use_ssmax = use_ssmax

        # Built unconditionally so the set of submodules (and therefore the
        # state dict layout) does not depend on `use_alibi`.
        self.alibi = ALiBi(n_heads)
        # Learned markers added to the origin and destination embeddings.
        self.tokens = nn.ParameterDict(
            {
                "origin": nn.Parameter(torch.randn((hidden_dim,))),
                "destination": nn.Parameter(torch.randn((hidden_dim,))),
            }
        )
        if self.use_coords:
            self.proj_coords = nn.Linear(2, hidden_dim)
        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_dim, ff_dim, n_heads) for _ in range(n_layers)]
        )
        self.out_query = nn.Linear(hidden_dim, hidden_dim)
        self.out_key = nn.Linear(hidden_dim, hidden_dim)

    @torch.compile
    def forward(self, c: Tensor, m: Tensor, o: Tensor, e: Tensor) -> Tensor:
        """Predict the next city to visit.

        ---
        Args:
            c: Cities coordinates.
                Shape of [batch_size, n_cities, 2].
            m: Boolean mask of the cities that can still take part in the
                computation (True = kept, False = masked, e.g. already
                visited). The origin must be True in `m`, otherwise every
                logit of that sample is masked.
                Shape of [batch_size, n_cities].
            o: Origin index.
                Shape of [batch_size,].
            e: Destination index.
                Shape of [batch_size,].

        ---
        Returns:
            Unnormalised logits over the next city to visit. Masked cities,
            the origin and the destination are set to `-inf`; apply a
            softmax over the last dimension to obtain probabilities.
                Shape of [batch_size, n_cities].
        """
        batch_size, n_cities, _ = c.shape
        device = c.device
        # Follow the parameters' dtype so the in-place scatters below stay
        # valid when the module is cast (e.g. `.double()`).
        dtype = self.tokens["origin"].dtype

        # Pairwise mask: position (l, s) is kept only if both cities are kept.
        m = torch.einsum("bl,bs->bls", m, m)
        s = torch.log(m.sum(dim=-1) + 1) if self.use_ssmax else None
        d = self.alibi(torch.vmap(pairwise_distances)(c)) if self.use_alibi else None
        p = c if self.use_rope else None
        x = torch.zeros(
            (batch_size, n_cities, self.hidden_dim), device=device, dtype=dtype
        )
        # Batch index must live on the same device as the tensors it indexes
        # to avoid mixing CPU and accelerator index tensors.
        b = torch.arange(batch_size, device=device)

        # Mark the origin and destination cities with their learned tokens.
        x[b, o] = x[b, o] + self.tokens["origin"]
        x[b, e] = x[b, e] + self.tokens["destination"]

        if self.use_random_ids:
            x = x + torch.randn(
                (batch_size, n_cities, self.hidden_dim), device=device, dtype=dtype
            )

        if self.use_coords:
            x = x + self.proj_coords(c)

        for layer in self.layers:
            x = layer(x, m, s, d, p)

        # Predict logits of next city to visit: origin embedding as query
        # against every city embedding as key.
        q = self.out_query(x[b, o])
        k = self.out_key(x)
        qk = torch.einsum("be,bse->bs", q, k)

        # Mask origins and destinations. Advanced indexing returns a copy,
        # so the in-place writes below do not touch the pairwise mask.
        selectable = m[b, o]
        selectable[b, o] = False
        selectable[b, e] = False
        qk = torch.where(selectable, qk, -torch.inf)

        return qk
