"""Latent flow models for continuous normalizing flows.

This module implements flow-based models that learn to transform latent
representations using neural ODEs and continuous normalizing flows.
"""

from dataclasses import dataclass
from typing import Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def sinusoidal_time_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Encode time values as fixed sine/cosine features.

    ``dim // 2`` frequencies are spaced geometrically between 1 and 1000
    (evenly spaced in log-space, then exponentiated). Each time value is
    multiplied by every frequency; the output holds all sines first, then
    all cosines. When ``dim`` is odd, a single zero column is appended so the
    last axis has exactly ``dim`` entries.

    Args:
        t: Time tensor of shape (...,); the docstring contract expects values
            in [0, 1], but no range check is performed.
        dim: Embedding dimension.

    Returns:
        Tensor of shape (..., dim), created on ``t.device``. Because the
        frequency row carries an explicit leading axis of size 1, a 0-d ``t``
        yields shape (1, dim) rather than (dim,).
    """
    n_freqs = dim // 2

    # Evenly spaced exponents in log-space give geometrically spaced
    # frequencies once exponentiated.
    log_freqs = torch.linspace(
        math.log(1.0), math.log(1000.0), n_freqs, device=t.device
    )
    freq_row = log_freqs.exp().unsqueeze(0)  # (1, n_freqs)

    # Broadcast every time value against every frequency: (..., n_freqs).
    phase = t.unsqueeze(-1) * freq_row
    features = torch.cat((phase.sin(), phase.cos()), dim=-1)

    if dim % 2 == 0:
        return features
    # Odd dim: sin/cos only cover dim - 1 columns, so right-pad one zero.
    return F.pad(features, (0, 1))


@dataclass
class FlowConfig:
    """Hyperparameters for the LatentFlowPrior model.

    Every field has a default, so ``FlowConfig()`` is a usable configuration.
    The declaration order below is also the positional-argument order of the
    generated constructor.

    Attributes:
        K: Number of latent tokens.
        d_latent: Latent dimension per token.
        d_model: Model dimension (embedding size) used inside the Transformer.
        nhead: Number of attention heads.
        layers: Number of Transformer layers.
        dim_ff: Feedforward dimension.
        dropout: Dropout probability.
        time_dim: Time embedding dimension.
    """

    # --- latent token layout ---
    K: int = 8
    d_latent: int = 128

    # --- Transformer encoder shape ---
    d_model: int = 256
    nhead: int = 8
    layers: int = 10
    dim_ff: int = 1024
    dropout: float = 0.1

    # --- time conditioning ---
    time_dim: int = 128


class LatentFlowPrior(nn.Module):
    """Token Transformer that predicts velocity v(z, t).

    Pipeline implemented by ``forward``:
      1. project each latent token from ``d_latent`` to ``d_model``;
      2. add a sinusoidal time embedding, projected to ``d_model``, to every
         token;
      3. apply LayerNorm, then the Transformer encoder stack;
      4. project back from ``d_model`` to ``d_latent``.
    """

    def __init__(self, cfg: FlowConfig):
        """Build the projections, the encoder stack and the input LayerNorm.

        Args:
            cfg: Model hyperparameters. Reads ``d_latent``, ``d_model``,
                ``nhead``, ``layers``, ``dim_ff``, ``dropout`` and
                ``time_dim``; the object is also kept as ``self.cfg``.
        """
        super().__init__()
        self.cfg = cfg
        self.z_in = nn.Linear(cfg.d_latent, cfg.d_model)
        self.z_out = nn.Linear(cfg.d_model, cfg.d_latent)

        layer = nn.TransformerEncoderLayer(
            d_model=cfg.d_model,
            nhead=cfg.nhead,
            dim_feedforward=cfg.dim_ff,
            dropout=cfg.dropout,
            batch_first=True,
            activation="gelu",
        )
        self.tr = nn.TransformerEncoder(layer, num_layers=cfg.layers)

        self.time_proj = nn.Linear(cfg.time_dim, cfg.d_model)
        self.ln = nn.LayerNorm(cfg.d_model)

    def forward(
        self, z: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Predict velocity field for continuous normalizing flow.

        Computes the velocity v(z, t) that defines the flow dynamics.
        The velocity is used in the ODE: dz/dt = v(z, t).

        Args:
            z: Latent tokens of shape (batch, K, d_latent).
            t: Time values in [0, 1]. Shape (batch,) is the normal case; a
               0-d scalar, shape (1,), or shape (batch, 1) is also accepted
               (the tensor is flattened, and a single value is shared by the
               whole batch).
            y: Optional conditional properties of shape (batch, n_properties).
               Currently unused but kept for API compatibility.

        Returns:
            Velocity tensor of shape (batch, K, d_latent).

        Raises:
            ValueError: If ``z`` is not exactly 3-dimensional (the shape
                unpacking fails).
            RuntimeError: If the number of time values is neither 1 nor
                ``batch`` (the ``expand`` below fails).
        """
        # Unpacking three values doubles as a check that z is exactly 3-D.
        B, K, _ = z.shape
        h = self.z_in(z)  # (B, K, d_model)

        # Flatten so that a scalar time (e.g. a single integration time shared
        # by the batch) or a (B, 1) column is handled like a (n,) vector,
        # without relying on how the embedding helper treats 0-d input.
        t = t.reshape(-1)
        temb = sinusoidal_time_embedding(t, self.cfg.time_dim)  # (n, time_dim)

        # The embedding's dtype follows `t` and the default dtype, not this
        # module's parameters. Cast it so the projection does not fail with a
        # dtype mismatch when the module is not float32 or `t` is float64.
        temb = temb.to(self.time_proj.weight.dtype)

        # (n, d_model) -> (n, 1, d_model) -> (B, K, d_model). `expand` only
        # accepts n == 1 or n == B, so it also validates the time batch size.
        ht = self.time_proj(temb)[:, None, :].expand(B, K, -1)
        h = h + ht

        h = self.ln(h)
        h = self.tr(h)
        v = self.z_out(h)
        return v
