"""
EE Trajectory Encoder for pick-and-place task.

This encoder ONLY uses the normalized end-effector trajectory (3D: progress, perp1, perp2)
to learn latent embeddings. This ensures that REACH and CARRY trajectories with the same
control point parameters will produce the same embedding, since their normalized EE shapes
are identical.

Key difference from trajectory_encoder.py:
- Only uses 3D EE trajectory, not full 22D state or 8D actions
- Only a small MLP decoder for reconstruction (the z-conditioned policy only needs the encoding)
- Simple transformer encoder on the 3D curve
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding for transformer inputs.

    Precomputes a (max_len, d_model) table where even feature columns hold
    sin(position * freq) and odd feature columns hold cos(position * freq),
    and stores it as the non-trainable buffer ``pe``.

    Args:
        d_model: feature dimension of the tokens (even or odd).
        max_len: longest sequence length that can be encoded.
    """
    def __init__(self, d_model, max_len=128):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        x: (batch, seq_len, d_model)

        Returns x with the first seq_len rows of the encoding table added
        (broadcast over the batch dimension).

        Raises:
            ValueError: if seq_len exceeds the max_len given at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional encoding max_len {max_len}"
            )
        return x + self.pe[:seq_len, :].unsqueeze(0)


class EETrajectoryEncoder(nn.Module):
    """
    Encoder that takes ONLY the normalized EE trajectory and outputs a latent z.

    Input: ee_traj (B, T, ee_dim) where each point is [progress, perp1_offset, perp2_offset]
    Output: z (B, latent_dim)

    Because only the normalized EE curve is consumed, two trajectories with the
    same normalized curve produce the same latent mean.

    Args:
        ee_dim: number of features per trajectory point (default 3).
        hidden_dim: transformer model width.
        num_layers: number of transformer encoder layers.
        num_heads: number of attention heads per layer.
        latent_dim: size of the latent vector (default 2).
        dropout: dropout probability inside the transformer layers.
        max_len: longest trajectory length (T) the positional encoding supports.
    """
    def __init__(self, ee_dim=3, hidden_dim=128, num_layers=4, num_heads=4, latent_dim=2,
                 dropout=0.1, max_len=128):
        super().__init__()
        self.ee_dim = ee_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Project EE points to hidden dimension
        self.ee_proj = nn.Linear(ee_dim, hidden_dim)

        # Positional encoding for temporal information
        self.pos_encoding = PositionalEncoding(hidden_dim, max_len=max_len)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output projection to latent space (VAE style)
        self.to_latent_mean = nn.Linear(hidden_dim, latent_dim)
        self.to_latent_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, ee_traj):
        """
        ee_traj: (B, T, ee_dim) - normalized EE trajectory [progress, perp1, perp2]

        Returns:
            z_mean: (B, latent_dim)
            z_logvar: (B, latent_dim)
            z: (B, latent_dim) - sampled (training) or mean (eval)
        """
        # Project to hidden dimension
        tokens = self.ee_proj(ee_traj)  # (B, T, hidden_dim)

        # Add positional encoding
        tokens = self.pos_encoding(tokens)

        # Transformer encoding; no attention mask is passed, so every token
        # attends to the whole trajectory.
        encoded = self.transformer(tokens)  # (B, T, hidden_dim)

        # Use last token as trajectory representation
        traj_repr = encoded[:, -1, :]  # (B, hidden_dim)

        # Project to latent space
        z_mean = self.to_latent_mean(traj_repr)
        z_logvar = self.to_latent_logvar(traj_repr)

        # Reparameterization trick: sample only while training so that
        # evaluation returns the deterministic mean.
        if self.training:
            std = torch.exp(0.5 * z_logvar)
            eps = torch.randn_like(std)
            z = z_mean + eps * std
        else:
            z = z_mean

        return z_mean, z_logvar, z


class EETrajectoryDecoder(nn.Module):
    """
    Decoder that reconstructs the normalized EE trajectory from z.

    Every output point is produced by the same MLP applied to the latent
    vector concatenated with a scalar time value in [0, 1].

    Input: z (B, latent_dim)
    Output: ee_traj (B, T, ee_dim)

    Args:
        ee_dim: number of features per generated trajectory point.
        hidden_dim: width of the MLP hidden layers.
        latent_dim: size of the latent vector z.
        horizon: default number of timesteps (T) to generate.
    """
    def __init__(self, ee_dim=3, hidden_dim=128, latent_dim=2, horizon=64):
        super().__init__()
        self.ee_dim = ee_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.horizon = horizon

        # MLP decoder with time conditioning
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim + 1, hidden_dim),  # +1 for time
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, ee_dim)
        )

    def forward(self, z, horizon=None):
        """
        z: (B, latent_dim)
        horizon: optional number of timesteps to generate; defaults to the
            horizon given at construction.

        Returns: ee_traj (B, T, ee_dim)
        """
        steps = self.horizon if horizon is None else horizon
        batch_size = z.shape[0]

        # Time conditioning: [0, 1/(T-1), ..., 1]. Cast to z's dtype so the
        # concatenated input matches the MLP weights when the model runs in a
        # non-default precision (torch.cat would otherwise promote the input).
        t = torch.linspace(0, 1, steps, device=z.device).to(z.dtype)
        t = t.unsqueeze(0).unsqueeze(2).expand(batch_size, -1, -1)  # (B, T, 1)

        # Repeat z for each timestep
        z_repeated = z.unsqueeze(1).expand(-1, steps, -1)  # (B, T, latent_dim)

        # Concatenate z and time
        inp = torch.cat([z_repeated, t], dim=-1)  # (B, T, latent_dim + 1)

        # Generate trajectory
        return self.mlp(inp)  # (B, T, ee_dim)


class EETrajectoryVAE(nn.Module):
    """
    VAE over normalized EE trajectories: an EETrajectoryEncoder paired with
    an EETrajectoryDecoder.

    Only the normalized EE curve [progress, perp1, perp2] is encoded and
    reconstructed, so the embedding depends on nothing but that curve.
    """
    def __init__(self, ee_dim=3, hidden_dim=128, num_layers=4, num_heads=4, latent_dim=2, horizon=64):
        super().__init__()
        self.ee_dim = ee_dim
        self.horizon = horizon

        # Settings shared by both halves of the model.
        shared = dict(ee_dim=ee_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)

        # Encoder is built before the decoder so parameter initialisation
        # consumes the RNG in the same order as before.
        self.encoder = EETrajectoryEncoder(num_layers=num_layers, num_heads=num_heads, **shared)
        self.decoder = EETrajectoryDecoder(horizon=horizon, **shared)

    def forward(self, ee_traj):
        """
        Encode a trajectory and reconstruct it from the resulting latent.

        ee_traj: (B, T, 3) - normalized EE trajectory

        Returns a dict with keys z_mean, z_logvar, z, pred_ee_traj.
        """
        mean, logvar, latent = self.encoder(ee_traj)
        outputs = {'z_mean': mean, 'z_logvar': logvar, 'z': latent}
        outputs['pred_ee_traj'] = self.decoder(latent)
        return outputs

    def encode(self, ee_traj):
        """Return the latent mean produced by the encoder for ee_traj."""
        return self.encoder(ee_traj)[0]

    def decode(self, z):
        """Return the trajectory the decoder generates from latent z."""
        reconstruction = self.decoder(z)
        return reconstruction
