"""
Trajectory Encoder/Decoder for pick_place task with endpoint conditioning.

Key difference from close_drawer:
- Encoder: Same as close_drawer (encodes full state+action trajectory)
- Decoder: Conditioned on z + endpoint_direction + time (not full s0)

This prevents decoder from overfitting to s0 while still providing
minimal context about WHERE the trajectory goes. The HOW (trajectory shape)
must come from z.

Works with RELATIVE coordinates where s0 is always [0,0,0,...]:
- Encodes (s0_rel, a0, s1_rel, a1, ..., s_{T-1}_rel, a_{T-1}) -> z
- Decodes (z, endpoint_direction) -> (a0, s1_rel, a1, s2_rel, ..., a_{T-1})
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for batch-first token sequences.

    A table of shape (max_len, d_model) is precomputed once and stored as the
    non-trainable buffer ``pe``. Even feature columns hold sines and odd
    feature columns hold cosines of ``position * 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature size of each token. Both even and odd sizes work.
        max_len: Number of positions in the table, i.e. the longest sequence
            that ``forward`` accepts.
    """
    def __init__(self, d_model, max_len=175):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so drop the surplus frequency instead of failing on a
        # shape mismatch. For even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional table to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor of the same shape as ``x`` with position ``i`` of the table
            added to token ``i`` of every batch element.

        Raises:
            ValueError: If ``seq_len`` exceeds the table length ``max_len``.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(0)
        if seq_len > table_len:
            # Fail with a clear message rather than an opaque broadcast error
            # (or a silent broadcast when the table has a single row).
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding max_len {table_len}"
            )
        return x + self.pe[:seq_len, :].unsqueeze(0)


class TrajectoryEncoder(nn.Module):
    """
    Encoder: (s0, a0, s1, a1, ..., s_{T-1}, a_{T-1}) -> z.

    States and actions are projected to a common token width, interleaved
    into a single sequence of 2T tokens, given sinusoidal positions and passed
    through an unmasked (bidirectional) transformer encoder. The final token
    of the encoded sequence is mapped to the mean and log-variance of a
    diagonal Gaussian over the latent.

    The original note states this matches the close_drawer encoder.
    """
    def __init__(self, state_dim=22, action_dim=8, hidden_dim=128, num_layers=4, num_heads=4, latent_dim=2):
        super().__init__()
        self.state_dim, self.action_dim = state_dim, action_dim
        self.hidden_dim, self.latent_dim = hidden_dim, latent_dim

        # Separate input projections so states and actions share one token width
        self.state_proj = nn.Linear(state_dim, hidden_dim)
        self.action_proj = nn.Linear(action_dim, hidden_dim)

        # 256 positions are available, i.e. at most 128 (state, action) pairs
        self.pos_encoding = PositionalEncoding(hidden_dim, max_len=256)

        # No attention mask is used, so every token attends to the full sequence
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=4 * hidden_dim,
                dropout=0.1,
                activation='gelu',
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        # Heads producing the Gaussian posterior parameters
        self.to_latent_mean = nn.Linear(hidden_dim, latent_dim)
        self.to_latent_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, states, actions):
        """
        Encode a trajectory into the parameters and a sample of the latent.

        states: (batch, horizon, state_dim) - s0-relative coordinates
        actions: (batch, horizon, action_dim) - s0-relative coordinates

        Returns a tuple (z_mean, z_logvar, z), each of shape (batch, latent_dim):
            z_mean: posterior mean
            z_logvar: posterior log variance
            z: reparameterized sample in training mode, z_mean itself in eval mode
        """
        batch_size, horizon, _ = states.shape

        # Pair each state token with its action token, then flatten the pair
        # axis so the sequence reads [s0, a0, s1, a1, ...] -> (B, 2T, hidden_dim)
        paired = torch.stack([self.state_proj(states), self.action_proj(actions)], dim=2)
        sequence = paired.reshape(batch_size, 2 * horizon, self.hidden_dim)

        encoded = self.transformer(self.pos_encoding(sequence))  # (B, 2T, hidden_dim)

        # The last token serves as the summary of the whole trajectory
        summary = encoded[:, -1, :]  # (B, hidden_dim)
        z_mean = self.to_latent_mean(summary)
        z_logvar = self.to_latent_logvar(summary)

        # Evaluation is deterministic: the latent is just the posterior mean
        if not self.training:
            return z_mean, z_logvar, z_mean

        # Reparameterization trick: z = mean + eps * std, eps ~ N(0, I)
        std = torch.exp(0.5 * z_logvar)
        noise = torch.randn_like(std)
        return z_mean, z_logvar, z_mean + noise * std


class TrajectoryDecoderWithEndpoint(nn.Module):
    """
    Decoder conditioned on z + endpoint_direction + time.

    endpoint_direction: normalized (end_ee - start_ee), 3D
    This tells decoder WHERE to go, but HOW (trajectory shape) must come from z.

    Key insight: By only providing direction (not magnitude or full s0),
    the decoder cannot memorize specific trajectories based on s0.
    It must rely on z for the trajectory shape/mode.

    Every timestep is decoded independently by the same MLP from the
    concatenation [z, endpoint_direction, t], where t is a normalized time
    in [0, 1].

    Args:
        state_dim: Size of each predicted state vector.
        action_dim: Size of each predicted action vector.
        hidden_dim: Width of the three hidden MLP layers.
        latent_dim: Size of the latent z.
        horizon: Number of timesteps T to decode.
    """
    def __init__(self, state_dim=22, action_dim=8, hidden_dim=256, latent_dim=2, horizon=64):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.horizon = horizon

        # Input: z(latent_dim) + endpoint_dir(3) + time(1)
        input_dim = latent_dim + 3 + 1

        # MLP decoder
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim + action_dim)
        )

    def forward(self, z, endpoint_direction):
        """
        z: (B, latent_dim) - latent embedding
        endpoint_direction: (B, 3) - normalized direction from start_ee to end_ee

        Returns:
            pred_states: (B, horizon-1, state_dim) - predicted states (s1, s2, ..., s_{T-1})
            pred_actions: (B, horizon, action_dim) - predicted actions (a0, a1, ..., a_{T-1})
        """
        batch_size = z.shape[0]

        # Time conditioning: [0, 1/(T-1), ..., 1]. The grid is computed in the
        # default float dtype and then cast to z's dtype; otherwise a
        # reduced-precision model (bf16/fp16) would see its input promoted to
        # float32 by the concatenation and fail inside the MLP.
        t = torch.linspace(0, 1, self.horizon, device=z.device).to(dtype=z.dtype)
        t = t.view(1, self.horizon, 1).expand(batch_size, -1, -1)  # (B, T, 1)

        # Keep the direction in z's dtype for the same reason (e.g. a float64
        # direction built from NumPy fed to a float32 model).
        endpoint_direction = endpoint_direction.to(dtype=z.dtype)

        # Repeat z and endpoint_direction for each timestep
        z_rep = z.unsqueeze(1).expand(-1, self.horizon, -1)  # (B, T, latent_dim)
        dir_rep = endpoint_direction.unsqueeze(1).expand(-1, self.horizon, -1)  # (B, T, 3)

        # Concatenate all inputs
        inp = torch.cat([z_rep, dir_rep, t], dim=-1)  # (B, T, latent_dim + 3 + 1)

        # Generate trajectory
        out = self.mlp(inp)  # (B, T, state_dim + action_dim)

        # Split into states and actions
        pred_states_full = out[:, :, :self.state_dim]  # (B, T, state_dim)
        pred_actions = out[:, :, self.state_dim:]  # (B, T, action_dim)

        # For states, we predict s1, s2, ..., s_{T-1} (skip s0 which is always zero in relative coords)
        pred_states = pred_states_full[:, 1:, :]  # (B, T-1, state_dim)

        return pred_states, pred_actions


class TrajectoryVAEWithEndpoint(nn.Module):
    """
    Trajectory VAE for the pick_place task.

    Encoder: full state+action trajectory -> z
    Decoder: z + endpoint_direction + time -> state+action trajectory

    The endpoint direction only tells the decoder WHERE the trajectory ends
    up; the trajectory's shape/mode (HOW) has to be carried by z.
    """
    def __init__(self, state_dim=22, action_dim=8, hidden_dim=128, num_layers=4, num_heads=4, latent_dim=2, horizon=64):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.horizon = horizon
        self.latent_dim = latent_dim

        # Dimensions that encoder and decoder must agree on
        shared_dims = dict(state_dim=state_dim, action_dim=action_dim, latent_dim=latent_dim)

        self.encoder = TrajectoryEncoder(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            **shared_dims,
        )
        # The decoder MLP is twice as wide as the encoder tokens (harder task)
        self.decoder = TrajectoryDecoderWithEndpoint(
            hidden_dim=2 * hidden_dim,
            horizon=horizon,
            **shared_dims,
        )

    def forward(self, states_rel, actions_rel, endpoint_direction):
        """
        Run a full encode -> decode pass.

        states_rel: (B, T, state_dim) - s0-relative states
        actions_rel: (B, T, action_dim) - s0-relative actions
        endpoint_direction: (B, 3) - normalized direction from start_ee to end_ee

        Returns a dict with keys z_mean, z_logvar, z, pred_states, pred_actions.
        """
        mu, logvar, latent = self.encoder(states_rel, actions_rel)
        recon_states, recon_actions = self.decoder(latent, endpoint_direction)
        return dict(
            z_mean=mu,
            z_logvar=logvar,
            z=latent,
            pred_states=recon_states,
            pred_actions=recon_actions,
        )

    def encode(self, states_rel, actions_rel):
        """Map a trajectory to the mean of its latent distribution."""
        return self.encoder(states_rel, actions_rel)[0]

    def decode(self, z, endpoint_direction):
        """Reconstruct (pred_states, pred_actions) from a latent and an endpoint direction."""
        return self.decoder(z, endpoint_direction)
