"""
Trajectory Encoder/Decoder for learning latent embeddings of state-action sequences

Works with RELATIVE coordinates where s0 is always [0,0,0,...]:
- Encodes (s0_rel, a0, s1_rel, a1, ..., s_{T-1}_rel, a_{T-1}) -> z
- Decodes z -> (a0, s1_rel, a1, s2_rel, ..., a_{T-1})

No s0 input needed for decoding since s0_rel = [0,0,0,...] by construction!
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first token sequences.

    A ``(max_len, d_model)`` table is built once in the constructor and stored
    as the non-trainable buffer ``pe``. Even feature columns hold sines and odd
    feature columns hold cosines of ``position * 10000^(-2i / d_model)``.

    Args:
        d_model: width of each token embedding. Both even and odd widths are
            supported.
        max_len: number of positions in the precomputed table, i.e. the longest
            sequence ``forward`` accepts.
    """
    def __init__(self, d_model, max_len=215):  # 107 steps * 2 = 214 tokens + 1 buffer
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` with the positional table added along the sequence axis.

        Args:
            x: tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the table was
                built with.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len, :].unsqueeze(0)


class TrajectoryEncoder(nn.Module):
    """
    Encoder: (s0, a0, s1, a1, ..., s_{T-1}, a_{T-1}) -> z (latent embedding)

    States and actions are projected to a shared token width, interleaved into
    one sequence of 2*T tokens, and passed through a transformer encoder with no
    attention mask (every token attends to every other token). The encoded last
    token is mapped to the mean and log-variance of a diagonal Gaussian.

    Args:
        state_dim: size of each state vector.
        action_dim: size of each action vector.
        hidden_dim: token / transformer model width.
        num_layers: number of transformer encoder layers.
        num_heads: number of attention heads per layer.
        latent_dim: size of the latent vector z.
        max_len: size of the positional-encoding table, i.e. the largest number
            of tokens (2 * horizon) that can be encoded. The default of 215
            covers 107 steps.
        dropout: dropout probability used inside the transformer layers.
    """
    def __init__(self, state_dim=22, action_dim=8, hidden_dim=128, num_layers=4, num_heads=4, latent_dim=2,
                 max_len=215, dropout=0.1):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.max_len = max_len

        # Project state and action to same token dimension
        self.state_proj = nn.Linear(state_dim, hidden_dim)
        self.action_proj = nn.Linear(action_dim, hidden_dim)

        # Positional encoding over the interleaved token sequence
        self.pos_encoding = PositionalEncoding(hidden_dim, max_len=max_len)

        # Bidirectional transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output projection: aggregated token -> latent mean / log-variance
        self.to_latent_mean = nn.Linear(hidden_dim, latent_dim)
        self.to_latent_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, states, actions):
        """
        states: (batch, horizon, state_dim)
        actions: (batch, horizon, action_dim)

        Returns:
            z_mean: (batch, latent_dim) - mean of latent
            z_logvar: (batch, latent_dim) - log variance of latent
            z: (batch, latent_dim) - sampled latent in training mode,
               equal to z_mean in eval mode

        Raises:
            ValueError: if states and actions disagree on (batch, horizon), or
                if 2 * horizon exceeds the max_len given at construction.
        """
        batch_size, horizon, _ = states.shape

        # Interleaving pairs s_t with a_t, so both inputs must share (B, T).
        if tuple(actions.shape[:2]) != (batch_size, horizon):
            raise ValueError(
                f"states and actions must share (batch, horizon); got "
                f"{tuple(states.shape[:2])} and {tuple(actions.shape[:2])}"
            )
        if horizon * 2 > self.max_len:
            raise ValueError(
                f"trajectory needs {horizon * 2} tokens but max_len is {self.max_len}"
            )

        # Project states and actions to token embeddings
        state_tokens = self.state_proj(states)  # (B, T, hidden_dim)
        action_tokens = self.action_proj(actions)  # (B, T, hidden_dim)

        # Interleave: [s0, a0, s1, a1, ..., s_{T-1}, a_{T-1}]
        tokens = torch.stack([state_tokens, action_tokens], dim=2)  # (B, T, 2, hidden_dim)
        tokens = tokens.reshape(batch_size, horizon * 2, self.hidden_dim)  # (B, 2T, hidden_dim)

        # Add positional encoding
        tokens = self.pos_encoding(tokens)  # (B, 2T, hidden_dim)

        # Bidirectional transformer encoding (no mask is passed)
        encoded = self.transformer(tokens)  # (B, 2T, hidden_dim)

        # Aggregate: use the last token as the trajectory representation
        traj_repr = encoded[:, -1, :]  # (B, hidden_dim)

        # Project to latent space
        z_mean = self.to_latent_mean(traj_repr)  # (B, latent_dim)
        z_logvar = self.to_latent_logvar(traj_repr)  # (B, latent_dim)

        # Reparameterization trick: sample only while training so that
        # evaluation is deterministic.
        if self.training:
            std = torch.exp(0.5 * z_logvar)
            eps = torch.randn_like(std)
            z = z_mean + eps * std
        else:
            z = z_mean

        # NO tanh! Let z spread naturally in unbounded space.
        # KL loss will keep it reasonably close to N(0, I) prior.

        # Return: z_mean for KL loss and plotting, z for decoder/DTW
        return z_mean, z_logvar, z


class TrajectoryDecoder(nn.Module):
    """
    One-shot MLP decoder: z -> (s1, ..., s_{T-1}) and (a0, ..., a_{T-1}).

    The same MLP is applied independently to every timestep. Its input is the
    latent z concatenated with a normalized time index in [0, 1], so the output
    depends on z alone (no ground truth is fed in, hence no teacher-forcing
    leak).

    NOTE: s0 is not returned. In relative coordinates s0_rel is always zero, so
    the state output computed for timestep 0 is discarded.

    NOTE: `num_layers` and `num_heads` are accepted but not used by this class;
    the MLP always has three hidden layers followed by a linear output layer.
    """
    def __init__(self, state_dim=22, action_dim=8, hidden_dim=256, num_layers=4, num_heads=4, latent_dim=2, horizon=4):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.horizon = horizon

        # Layer widths: (z + scalar time) -> hidden -> hidden -> hidden
        widths = [latent_dim + 1, hidden_dim, hidden_dim, hidden_dim]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Output head: one state vector and one action vector per timestep
        layers.append(nn.Linear(hidden_dim, state_dim + action_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, z, states=None, actions=None):
        """
        z: (batch, latent_dim) - latent embedding
        states: unused, accepted only for call compatibility
        actions: unused, accepted only for call compatibility

        Returns:
            pred_states: (batch, horizon-1, state_dim) - states s1 .. s_{T-1}
            pred_actions: (batch, horizon, action_dim) - actions a0 .. a_{T-1}
        """
        steps = self.horizon

        # Normalized time index 0 .. 1, broadcast over the batch: (B, T, 1)
        time_axis = torch.linspace(0, 1, steps, device=z.device)
        time_axis = time_axis.view(1, steps, 1).expand(z.shape[0], steps, 1)

        # One copy of the latent per timestep: (B, T, latent_dim)
        latent_per_step = z[:, None, :].expand(-1, steps, -1)

        # (B, T, latent_dim + 1) -> (B, T, state_dim + action_dim)
        trajectory = self.mlp(torch.cat((latent_per_step, time_axis), dim=-1))

        # Leading state_dim features are states, the remainder are actions
        all_states = trajectory[..., :self.state_dim]
        pred_actions = trajectory[..., self.state_dim:]

        # Drop timestep 0: s0 is always zero and is not predicted
        return all_states[:, 1:, :], pred_actions


class TrajectoryVAE(nn.Module):
    """Trajectory VAE: a TrajectoryEncoder and a TrajectoryDecoder built with shared settings."""
    def __init__(self, state_dim=22, action_dim=8, hidden_dim=128, num_layers=4, num_heads=4, latent_dim=2, horizon=4):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.horizon = horizon

        # Both halves receive the same dimensions; only the decoder takes the horizon.
        shared = dict(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            latent_dim=latent_dim,
        )
        self.encoder = TrajectoryEncoder(**shared)
        self.decoder = TrajectoryDecoder(horizon=horizon, **shared)

    def forward(self, states, actions):
        """
        Encode a trajectory and reconstruct it from the resulting latent.

        states: (batch, horizon, state_dim) - RELATIVE coordinates
        actions: (batch, horizon, action_dim) - RELATIVE coordinates

        Returns a dict with keys 'z_mean', 'z_logvar', 'z', 'pred_states'
        and 'pred_actions'.
        """
        mean, logvar, latent = self.encoder(states, actions)

        # states/actions are forwarded to match the decoder's call signature
        recon_states, recon_actions = self.decoder(latent, states, actions)

        result = {}
        result['z_mean'] = mean
        result['z_logvar'] = logvar
        result['z'] = latent
        result['pred_states'] = recon_states
        result['pred_actions'] = recon_actions
        return result

    def encode(self, states, actions):
        """Run the encoder and return only the latent mean (no tanh applied)."""
        mean, _, _ = self.encoder(states, actions)
        return mean

    def decode(self, z):
        """Run the decoder on z alone and return its output (no s0 needed)."""
        return self.decoder(z)
