"""
Trajectory Encoder/Decoder for pick_place task with normalized EE input.

KEY DESIGN - solving the mode collapse problem:
- Encoder input: Trajectory-frame normalized EE (3D: progress, perp1, perp2)
  This is IDENTICAL for REACH and CARRY with the same CP params!
- Decoder output: Full state (22D) + action (8D) trajectory
  This is a MUCH HARDER task than reconstructing the 3D curve.

Why this works:
1. Encoder sees identical input for REACH/CARRY with same CP → produces identical z_mean
   (the sampled z is only identical in eval mode; during training it is drawn around z_mean)
2. Decoder must reconstruct the full 30D trajectory from the 2D z plus context → forces z to be discriminative
3. Decoder gets extra context (s0, trajectory_length, time) to make the task feasible but not trivial

The decoder receives:
- z: 2D latent encoding trajectory shape (the CP pattern)
- s0: Initial state (22D) - tells decoder WHERE the trajectory starts
- trajectory_length: scalar - tells decoder how FAR to go
- time: [0,1] progress - tells decoder WHEN in the trajectory

This way, z encodes HOW the trajectory curves (the mode), while other inputs provide context.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (as in "Attention Is All You Need").

    Precomputes a (max_len, d_model) table in which even feature columns hold
    sin(pos * w_k) and odd feature columns hold cos(pos * w_k), then adds the
    first ``seq_len`` rows of it to the input.

    Args:
        d_model: Feature size of the tokens. Any positive size is supported,
            including odd sizes (the last column is then a sine column).
        max_len: Longest sequence that can be encoded.
    """
    def __init__(self, d_model, max_len=175):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies feed the cosine slots.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x``.

        Args:
            x: (batch, seq_len, d_model) tokens.

        Returns:
            Tensor of the same shape as ``x`` with the encoding added.

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:seq_len, :].unsqueeze(0)


class NormalizedEEEncoder(nn.Module):
    """
    Encoder: Trajectory-frame normalized EE trajectory (3D) -> z (2D)

    Input: (progress, perp1_offset, perp2_offset) at each timestep.
    Per the module design, this representation is identical for REACH and
    CARRY with the same CP, so both map to the same z_mean.

    The encoder works in four steps:
    1. Project each timestep to ``hidden_dim`` and add sinusoidal positional
       encodings.
    2. Run a bidirectional (no causal mask) transformer over the sequence.
    3. Mean-pool over time into one vector per trajectory.
    4. Map that vector through two linear heads to the mean and log-variance
       of a diagonal Gaussian posterior.

    Args:
        ee_dim: Feature size of each input timestep.
        hidden_dim: Transformer width (must be divisible by num_heads).
        num_layers: Number of transformer encoder layers.
        num_heads: Attention heads per layer.
        latent_dim: Size of z.
        max_len: Longest sequence supported by the positional-encoding table
            (defaults to 256, the previously hard-coded value).
    """
    def __init__(self, ee_dim=3, hidden_dim=128, num_layers=4, num_heads=4, latent_dim=2,
                 max_len=256):
        super().__init__()
        self.ee_dim = ee_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Project EE to hidden dim
        self.ee_proj = nn.Linear(ee_dim, hidden_dim)

        # Positional encoding (table length bounds the accepted sequence length)
        self.pos_encoding = PositionalEncoding(hidden_dim, max_len=max_len)

        # Bidirectional transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output projection to latent space
        self.to_latent_mean = nn.Linear(hidden_dim, latent_dim)
        self.to_latent_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, ee_normalized, padding_mask=None):
        """
        Encode a batch of normalized EE trajectories.

        Args:
            ee_normalized: (B, T, ee_dim) trajectory-frame normalized EE
                [progress, perp1, perp2].
            padding_mask: Optional bool tensor (B, T); True marks padded
                timesteps. Padded steps are ignored as attention keys and
                excluded from the mean pooling. None (default) treats every
                timestep as valid, which is the original behavior.

        Returns:
            z_mean: (B, latent_dim)
            z_logvar: (B, latent_dim)
            z: (B, latent_dim) - sampled via reparameterization in training
               mode, equal to z_mean in eval mode
        """
        # Project to hidden dimension, then add positional encoding
        tokens = self.ee_proj(ee_normalized)  # (B, T, hidden_dim)
        tokens = self.pos_encoding(tokens)

        # Bidirectional transformer encoding
        encoded = self.transformer(tokens, src_key_padding_mask=padding_mask)  # (B, T, hidden_dim)

        # Aggregate over time: plain mean, or mean over valid steps only
        if padding_mask is None:
            traj_repr = encoded.mean(dim=1)  # (B, hidden_dim)
        else:
            valid = (~padding_mask).unsqueeze(-1).to(encoded.dtype)  # (B, T, 1)
            # clamp avoids 0/0 for a row with no valid steps
            traj_repr = (encoded * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)

        # Project to latent space
        z_mean = self.to_latent_mean(traj_repr)
        z_logvar = self.to_latent_logvar(traj_repr)

        # Reparameterization trick: sample in training, use the mean otherwise
        if self.training:
            std = torch.exp(0.5 * z_logvar)
            eps = torch.randn_like(std)
            z = z_mean + eps * std
        else:
            z = z_mean

        return z_mean, z_logvar, z


class FullTrajectoryDecoder(nn.Module):
    """
    Decoder: z + s0 + trajectory_length + time -> full state + action trajectory

    The decoder must reconstruct the FULL (state_dim + action_dim) trajectory,
    which is 30D with the defaults (22D state + 8D action). This is much
    harder than reconstructing just the 3D normalized EE curve.

    Inputs:
    - z (latent_dim): Encodes trajectory shape/mode (the CP pattern)
    - s0 (state_dim): Initial state - tells decoder WHERE the trajectory starts
    - trajectory_length (1D): Scalar length feature - tells decoder HOW FAR
    - time (1D): Progress in [0, 1] - tells decoder WHEN in the trajectory

    An MLP is applied independently at each of the ``horizon`` timesteps.
    The timesteps share z, s0 and length and differ only in the time input.
    This design ensures z must encode the trajectory SHAPE (how it curves),
    since s0 and length only provide endpoint information, not the path.
    """
    def __init__(self, state_dim=22, action_dim=8, hidden_dim=256, latent_dim=2, horizon=64):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.horizon = horizon

        # Input: z + s0 + traj_length(1) + time(1)  (= 26 with the defaults)
        input_dim = latent_dim + state_dim + 1 + 1

        # MLP decoder with larger hidden dim for harder task
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, state_dim + action_dim)
        )

    def forward(self, z, s0, trajectory_length):
        """
        Decode a latent code plus context into a state/action trajectory.

        Args:
            z: (B, latent_dim) - trajectory shape embedding
            s0: (B, state_dim) - initial state (absolute coordinates)
            trajectory_length: (B, 1) or (B,) - scalar length feature per sample

        Returns:
            pred_states: (B, horizon-1, state_dim) - predicted states s1...s_{T-1}
                (intended as s0-relative targets)
            pred_actions: (B, horizon, action_dim) - predicted actions a0...a_{T-1}
                (intended as s0-relative targets)
        """
        batch_size = z.shape[0]
        device = z.device

        # Accept a flat (B,) length as well as (B, 1).
        if trajectory_length.dim() == 1:
            trajectory_length = trajectory_length.unsqueeze(-1)

        # Time conditioning: [0, 1/(T-1), ..., 1], in z's dtype so cat sees one dtype
        t = torch.linspace(0, 1, self.horizon, device=device, dtype=z.dtype)
        t = t.unsqueeze(0).unsqueeze(2).expand(batch_size, -1, -1)  # (B, T, 1)

        # Repeat per-sample inputs for each timestep
        z_rep = z.unsqueeze(1).expand(-1, self.horizon, -1)  # (B, T, latent_dim)
        s0_rep = s0.unsqueeze(1).expand(-1, self.horizon, -1)  # (B, T, state_dim)
        length_rep = trajectory_length.unsqueeze(1).expand(-1, self.horizon, -1)  # (B, T, 1)

        # Concatenate all inputs
        inp = torch.cat([z_rep, s0_rep, length_rep, t], dim=-1)  # (B, T, input_dim)

        # Generate trajectory
        out = self.mlp(inp)  # (B, T, state_dim + action_dim)

        # Split into states and actions
        pred_states_full = out[:, :, :self.state_dim]  # (B, T, state_dim)
        pred_actions = out[:, :, self.state_dim:]  # (B, T, action_dim)

        # Drop the t=0 state: in s0-relative coordinates it is always zero
        pred_states = pred_states_full[:, 1:, :]  # (B, T-1, state_dim)

        return pred_states, pred_actions


class TrajectoryVAENormalizedEE(nn.Module):
    """
    VAE that encodes normalized EE and decodes full state+action trajectory.

    Key insight (module design): REACH and CARRY with same CP have IDENTICAL
    normalized EE curves, so they get IDENTICAL z_mean embeddings. The decoder
    task (reconstructing the full (state_dim + action_dim) trajectory) is meant
    to be hard enough that z must be discriminative.

    Intended properties:
    1. Same CP → same z_mean (because encoder input is identical)
    2. Different CP → different z (because decoder needs z to be discriminative)
    3. Well-separated clusters — this relies on a training loss (e.g. a DTW term)
       that is NOT defined in this class; NOTE(review): confirm in the trainer.

    The decoder uses ``hidden_dim * 2`` as its width.
    """
    def __init__(self, state_dim=22, action_dim=8, hidden_dim=128, num_layers=4,
                 num_heads=4, latent_dim=2, horizon=64):
        super().__init__()

        # Encoder: normalized EE (3D) -> z (latent_dim)
        self.encoder = NormalizedEEEncoder(
            ee_dim=3,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            latent_dim=latent_dim
        )

        # Decoder: z + context -> full trajectory (state_dim + action_dim)
        self.decoder = FullTrajectoryDecoder(
            state_dim=state_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim * 2,  # Larger for harder task
            latent_dim=latent_dim,
            horizon=horizon
        )

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.horizon = horizon
        self.latent_dim = latent_dim

    def forward(self, ee_normalized, s0, trajectory_length, states_rel=None, actions_rel=None):
        """
        Encode the normalized EE curve and decode a full trajectory.

        Args:
            ee_normalized: (B, T, 3) - trajectory-frame normalized EE
            s0: (B, state_dim) - initial state (absolute)
            trajectory_length: (B, 1) - EE trajectory length feature
            states_rel: optional (B, T, state_dim) s0-relative target states.
                Not used by this method; accepted so existing training code
                that passes targets keeps working. May be omitted at inference.
            actions_rel: optional (B, T, action_dim) s0-relative target
                actions. Not used by this method; see ``states_rel``.

        Returns:
            dict with 'z_mean', 'z_logvar', 'z', 'pred_states'
            (B, horizon-1, state_dim) and 'pred_actions' (B, horizon, action_dim).
        """
        # Encode normalized EE curve
        z_mean, z_logvar, z = self.encoder(ee_normalized)

        # Decode to full trajectory
        pred_states, pred_actions = self.decoder(z, s0, trajectory_length)

        return {
            'z_mean': z_mean,
            'z_logvar': z_logvar,
            'z': z,
            'pred_states': pred_states,  # (B, T-1, state_dim)
            'pred_actions': pred_actions  # (B, T, action_dim)
        }

    def encode(self, ee_normalized):
        """Encode a normalized EE trajectory to the latent mean (no sampling)."""
        z_mean, _, _ = self.encoder(ee_normalized)
        return z_mean

    def decode(self, z, s0, trajectory_length):
        """Decode a latent code plus context to (pred_states, pred_actions)."""
        return self.decoder(z, s0, trajectory_length)
