"""Position embeddings."""

import math
import torch
import torch.nn as nn


class FourierFeatures(nn.Module):
    """Sinusoidal features at log-spaced frequencies.

    Stores ``dim // 2`` frequencies, evenly spaced in log-space from 1 up to
    ``max_freq``, in a buffer named ``freqs``. The encoded output therefore
    has ``2 * (dim // 2)`` channels (``dim`` when ``dim`` is even, ``dim - 1``
    when it is odd).
    """

    def __init__(self, dim: int, max_freq: float = 10.0):
        super().__init__()
        num_freqs = dim // 2
        log_freqs = torch.linspace(0, math.log(max_freq), num_freqs)
        # A buffer rather than a parameter: it follows the module across
        # devices and is stored in the state dict, but is never trained.
        self.register_buffer("freqs", torch.exp(log_freqs))

    def forward(self, x):
        """Encode ``x`` as ``[sin(x * f), cos(x * f)]`` along a new last axis.

        A trailing axis is appended to ``x`` and broadcast against the
        frequency bank; all sine channels come first, then all cosine
        channels.
        """
        phases = x[..., None] * self.freqs
        return torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)


class DeltaPositionEmbedding(nn.Module):
    """Position embeddings based on calendar distance from endpoint.

    Every time step ``t`` of a length-``L`` sequence is assigned the delta
    ``(L - 1 - t) * stride``, so the last step always has delta 0. The delta
    is embedded either with Fourier features followed by a linear projection
    (``use_fourier=True``) or with a learned lookup table
    (``use_fourier=False``). A learned per-slot embedding is added to the
    delta embedding, and a learned CLS embedding can be prepended.

    Args:
        dim: Width of the produced embeddings.
        max_delta: Largest delta with its own row in the lookup table; larger
            deltas are clamped to it. Only used when ``use_fourier`` is False.
        num_slots: Number of learned slot embeddings; upper bound for ``K``
            in :meth:`forward`.
        use_fourier: Select the Fourier encoder instead of the lookup table.
        max_freq: Highest frequency handed to the Fourier encoder.
    """

    def __init__(
        self,
        dim: int,
        max_delta: int = 600,
        num_slots: int = 24,
        use_fourier: bool = True,
        max_freq: float = 10.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_delta = max_delta
        self.num_slots = num_slots
        self.use_fourier = use_fourier

        if use_fourier:
            # FourierFeatures emits 2 * (dim // 2) channels, which is dim - 1
            # for odd dim. Size the projection from that width so odd dims do
            # not fail with a shape mismatch at forward time.
            fourier_dim = 2 * (dim // 2)
            self.delta_encoder = nn.Sequential(
                FourierFeatures(dim, max_freq=max_freq),
                nn.Linear(fourier_dim, dim),
            )
        else:
            # One row per integer delta in [0, max_delta].
            self.delta_embed = nn.Embedding(max_delta + 1, dim)
        self.slot_embed = nn.Parameter(torch.randn(1, 1, num_slots, dim) * 0.02)
        self.cls_embed = nn.Parameter(torch.randn(1, 1, dim) * 0.02)

    def forward(
        self,
        L: int,
        K: int,
        stride: int = 1,
        include_cls: bool = True,
        device=None,
    ) -> torch.Tensor:
        """Build the position embeddings for ``L`` time steps of ``K`` slots.

        Args:
            L: Number of time steps.
            K: Number of slots per time step; must satisfy
                ``0 <= K <= num_slots``.
            stride: Multiplier applied to the step distance to obtain deltas.
            include_cls: Prepend the CLS embedding as the first token.
            device: Device on which the delta indices are created. Defaults
                to the device of the module's parameters; the module's
                weights are not moved, so it should match their device.

        Returns:
            Tensor of shape ``(1, L * K, dim)``, or ``(1, 1 + L * K, dim)``
            when ``include_cls`` is True. Tokens are ordered time-major: the
            ``K`` slots of step 0, then the ``K`` slots of step 1, and so on.

        Raises:
            ValueError: If ``K`` is outside ``[0, num_slots]``.
        """
        # Without this check an oversized K only surfaces as an opaque
        # reshape error, and a negative K silently slices from the end.
        if not 0 <= K <= self.num_slots:
            raise ValueError(f"K must be in [0, {self.num_slots}], got {K}")
        if device is None:
            device = self.cls_embed.device

        time_indices = torch.arange(L, device=device)
        deltas = (L - 1 - time_indices) * stride  # last step -> delta 0

        if self.use_fourier:
            delta_emb = self.delta_encoder(deltas.float())
        else:
            # Deltas beyond the table share the final row.
            delta_emb = self.delta_embed(deltas.clamp(max=self.max_delta).long())

        delta_emb = delta_emb.unsqueeze(1)  # (L, 1, dim)
        slot_emb = self.slot_embed[:, :, :K, :].squeeze(0)  # (1, K, dim)
        pos = (delta_emb + slot_emb).reshape(1, L * K, self.dim)
        if include_cls:
            pos = torch.cat([self.cls_embed, pos], dim=1)
        return pos

    def get_cls_embed(self) -> torch.Tensor:
        """Return the learned CLS embedding parameter of shape ``(1, 1, dim)``."""
        return self.cls_embed

