"""
Parameterized diffusion models with different conditioning architectures for trajectory embedding z.

Three main approaches:
1. FiLM (Feature-wise Linear Modulation): z modulates features via scale and shift
2. Cross-Attention: z is used as key/value in cross-attention layers
3. AdaLN (Adaptive Layer Normalization): z modulates layer norm parameters (like in DiT)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from typing import Optional, Dict

from model.common.mlp import MLP, ResidualMLP
from model.diffusion.modules import SinusoidalPosEmb


class FiLMLayer(nn.Module):
    """
    FiLM (Feature-wise Linear Modulation) block.

    Two independent linear heads map a conditioning vector to a per-feature
    scale (gamma) and shift (beta). The features are then returned as
    ``x * (1 + gamma) + beta``.
    """
    def __init__(self, feature_dim: int, conditioning_dim: int):
        super().__init__()
        self.feature_dim = feature_dim
        self.scale_net = nn.Linear(conditioning_dim, feature_dim)
        self.shift_net = nn.Linear(conditioning_dim, feature_dim)

        # Zero every parameter of both heads so gamma = beta = 0 right after
        # construction, i.e. the layer starts out as the identity on x.
        for head in (self.scale_net, self.shift_net):
            for param in (head.weight, head.bias):
                nn.init.zeros_(param)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        """
        Modulate ``x`` with the scale/shift predicted from ``conditioning``.

        x: (B, feature_dim)
        conditioning: (B, conditioning_dim)
        """
        gamma = self.scale_net(conditioning)
        beta = self.shift_net(conditioning)
        gain = 1 + gamma
        return x * gain + beta


class FiLMDiffusionMLP(nn.Module):
    """
    Diffusion MLP with FiLM conditioning on trajectory embedding z.

    The (optionally MLP-processed) embedding z is used twice: it is
    concatenated to the network input (direct path) and it drives a
    ``FiLMLayer`` after selected hidden layers (modulation path).
    Each hidden block is Linear -> optional LayerNorm -> optional FiLM ->
    activation (skipped on the last block) -> optional residual add.

    Optionally includes target/start/end position conditioning (for
    multi-target tasks); those are read from ``cond`` and concatenated to the
    input.

    Constructor arguments:
        action_dim, horizon_steps: the output is (B, horizon_steps, action_dim).
        cond_dim: flattened size of ``cond["state"]``.
        z_dim: size of the trajectory embedding.
        time_dim: size of the diffusion-step embedding.
        mlp_dims: hidden widths of the main network.
        cond_mlp_dims: if given, ``cond["state"]`` is first passed through an
            MLP with these widths.
        z_mlp_dims: if given, z is first passed through an MLP with these
            widths; ``None`` uses z unchanged.
        activation_type: "Mish", "ReLU" or "SiLU" for the main blocks; any
            other value falls back to Mish.
        out_activation_type: accepted for signature compatibility; not used
            by this class.
        use_layernorm: add a LayerNorm after every hidden Linear.
        residual_style: add the block input to its output when shapes match.
        film_at_layers: "all", "middle", or an iterable of layer indices.
        target_dim, start_dim, end_dim: sizes of the optional position inputs
            (0 disables the corresponding input).
    """

    def __init__(
        self,
        action_dim,
        horizon_steps,
        cond_dim,
        z_dim,
        time_dim=16,
        mlp_dims=[512, 512, 512],
        cond_mlp_dims=None,
        z_mlp_dims=[64, 128],  # Process z before FiLM
        activation_type="Mish",
        out_activation_type="Identity",
        use_layernorm=True,  # Default to True for stability
        residual_style=True,  # Default to True for stability
        film_at_layers="all",  # "all", "middle", or list of layer indices
        target_dim=0,  # Target position dimension (0 to disable, 3 for xyz)
        start_dim=0,  # Start position dimension (0 to disable, 3 for xyz)
        end_dim=0,  # End position dimension (0 to disable, 3 for xyz)
    ):
        super().__init__()
        output_dim = action_dim * horizon_steps
        self.target_dim = target_dim
        self.start_dim = start_dim
        self.end_dim = end_dim

        # Time embedding
        self.time_embedding = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 2),
            nn.Mish(),
            nn.Linear(time_dim * 2, time_dim),
        )

        # Process z for FiLM conditioning
        if z_mlp_dims is not None:
            self.z_processor = MLP(
                # list(...) so tuples / config sequences concatenate cleanly
                [z_dim] + list(z_mlp_dims),
                activation_type=activation_type,
                out_activation_type="Identity",
            )
            z_processed_dim = z_mlp_dims[-1]
        else:
            self.z_processor = nn.Identity()
            z_processed_dim = z_dim

        # Optional MLP for conditioning state
        if cond_mlp_dims is not None:
            self.cond_mlp = MLP(
                [cond_dim] + list(cond_mlp_dims),
                activation_type=activation_type,
                out_activation_type="Identity",
            )
            cond_processed_dim = cond_mlp_dims[-1]
        else:
            cond_processed_dim = cond_dim

        # Input includes z for direct path (like concat) plus modulation via FiLM
        # Also includes target/start/end positions if specified
        input_dim = (
            time_dim
            + action_dim * horizon_steps
            + cond_processed_dim
            + z_processed_dim
            + target_dim
            + start_dim
            + end_dim
        )

        # Determine which layers get FiLM
        num_layers = len(mlp_dims)
        if film_at_layers == "all":
            film_indices = list(range(num_layers))
        elif film_at_layers == "middle":
            film_indices = [num_layers // 2]
        elif isinstance(film_at_layers, str):
            # Any other string would otherwise fail later with an opaque
            # "'in <string>' requires string as left operand" TypeError.
            raise ValueError(
                f"film_at_layers must be 'all', 'middle', or a sequence of "
                f"layer indices, got {film_at_layers!r}"
            )
        else:
            film_indices = list(film_at_layers)

        # Build residual blocks with FiLM
        self.input_proj = nn.Linear(input_dim, mlp_dims[0])

        self.blocks = nn.ModuleList()
        self.layer_norms = nn.ModuleList()
        self.film_layers = nn.ModuleDict()
        self.residual_style = residual_style
        self.use_layernorm = use_layernorm

        # Activation for the main blocks; unknown names fall back to Mish
        activation_classes = {"Mish": nn.Mish, "ReLU": nn.ReLU, "SiLU": nn.SiLU}
        self.activation = activation_classes.get(activation_type, nn.Mish)()

        for i in range(num_layers):
            # Block 0 consumes the output of input_proj, which is mlp_dims[0] wide
            in_dim = mlp_dims[i - 1] if i > 0 else mlp_dims[0]
            out_dim = mlp_dims[i]

            self.blocks.append(nn.Linear(in_dim, out_dim))

            if use_layernorm:
                self.layer_norms.append(nn.LayerNorm(out_dim))

            if i in film_indices:
                # ModuleDict keys must be strings
                self.film_layers[str(i)] = FiLMLayer(out_dim, z_processed_dim)

        # Output projection
        self.output_proj = nn.Linear(mlp_dims[-1], output_dim)

        self.time_dim = time_dim
        self.z_dim = z_dim
        self.z_processed_dim = z_processed_dim

        # Register z_empty as a buffer for CFG (initialized to zeros)
        self.register_buffer('z_empty', torch.zeros(1, z_dim))

    def forward(self, x, time, cond, z=None, **kwargs):
        """
        Predict the network output for a batch of noisy action sequences.

        x: (B, Ta, Da) - noisy actions
        time: (B,) tensor, 0-d tensor, or int - diffusion step; a single
            value is broadcast to the whole batch
        cond: dict with key 'state' and, when the matching ``*_dim`` is > 0,
            'target' / 'start' / 'end'
        z: (B, z_dim) - trajectory embedding (None uses z_empty for unconditional)
        **kwargs: ignored

        Returns a tensor of shape (B, Ta, Da).

        Raises RuntimeError if the concatenated input width does not match
        what ``input_proj`` was built for (e.g. an enabled position input is
        missing from ``cond``).
        """
        B, Ta, Da = x.shape

        # Use z_empty if z is None (for CFG unconditional path)
        if z is None:
            z = self.z_empty.expand(B, -1)

        # Flatten actions (reshape also accepts non-contiguous inputs)
        x = x.reshape(B, -1)

        # Process inputs
        state = cond["state"].reshape(B, -1)
        if hasattr(self, "cond_mlp"):
            state = self.cond_mlp(state)

        # Process z for FiLM
        z_processed = self.z_processor(z)

        # Time embedding: accept python ints and scalar tensors as documented
        if not torch.is_tensor(time):
            time = torch.as_tensor(time, device=x.device)
        if time.numel() == 1:
            time = time.reshape(1).expand(B)
        time = time.reshape(B, 1)
        time_emb = self.time_embedding(time).reshape(B, self.time_dim)

        # Concatenate all inputs including z (direct path)
        inputs = [x, time_emb, state, z_processed]

        # Add target / start / end positions when enabled and provided
        optional_inputs = (
            ("target", self.target_dim),
            ("start", self.start_dim),
            ("end", self.end_dim),
        )
        for key, dim in optional_inputs:
            if dim > 0 and key in cond:
                inputs.append(cond[key].reshape(B, -1))

        x = torch.cat(inputs, dim=-1)

        # Fail with an actionable message instead of a bare matmul shape error
        expected_dim = self.input_proj.in_features
        if x.shape[-1] != expected_dim:
            missing = [
                key for key, dim in optional_inputs if dim > 0 and key not in cond
            ]
            raise RuntimeError(
                f"FiLMDiffusionMLP expected a concatenated input of width "
                f"{expected_dim} but got {x.shape[-1]}; enabled cond keys "
                f"missing from cond: {missing}"
            )

        # Input projection
        x = self.input_proj(x)

        # Forward through residual blocks with FiLM modulation
        last_block = len(self.blocks) - 1
        for i, block in enumerate(self.blocks):
            residual = x

            # Linear transformation
            x = block(x)

            # LayerNorm before FiLM (if enabled)
            if self.use_layernorm and i < len(self.layer_norms):
                x = self.layer_norms[i](x)

            # Apply FiLM modulation
            film_key = str(i)
            if film_key in self.film_layers:
                x = self.film_layers[film_key](x, z_processed)

            # Activation
            if i < last_block:  # No activation after last block
                x = self.activation(x)

            # Residual connection (if dimensions match)
            if self.residual_style and residual.shape == x.shape:
                x = x + residual

        # Output
        out = self.output_proj(x)
        return out.view(B, Ta, Da)


class CrossAttentionDiffusionMLP(nn.Module):
    """
    Diffusion MLP with cross-attention conditioning on trajectory embedding z.

    The (optionally MLP-processed) embedding z is concatenated to the network
    input (direct path) and is also used as the key/value of a cross-attention
    layer after selected hidden layers. Each hidden block is
    Linear -> optional LayerNorm -> optional cross-attention (with its own
    residual add) -> activation (skipped on the last block) -> optional
    block-level residual add.

    NOTE: z is presented to the attention as a single key/value token, so the
    softmax over keys has exactly one entry and the attention output does not
    depend on the query; the layer effectively adds a learned projection of z.

    Constructor arguments:
        action_dim, horizon_steps: the output is (B, horizon_steps, action_dim).
        cond_dim: flattened size of ``cond["state"]``.
        z_dim: size of the trajectory embedding.
        time_dim: size of the diffusion-step embedding.
        mlp_dims: hidden widths of the main network; each width that gets a
            cross-attention layer is used as the attention ``embed_dim``
            together with ``num_heads``.
        cond_mlp_dims: if given, ``cond["state"]`` is first passed through an
            MLP with these widths.
        z_mlp_dims: if given, z is first passed through an MLP with these
            widths; ``None`` uses z unchanged.
        activation_type: "Mish", "ReLU" or "SiLU" for the main blocks; any
            other value falls back to Mish.
        out_activation_type: accepted for signature compatibility; not used
            by this class.
        use_layernorm: add LayerNorm after every hidden Linear and on every
            attention output.
        residual_style: add the block input to its output when shapes match.
        num_heads: number of attention heads.
        cross_attn_at_layers: "all", "middle", or an iterable of layer indices.
    """

    def __init__(
        self,
        action_dim,
        horizon_steps,
        cond_dim,
        z_dim,
        time_dim=16,
        mlp_dims=[512, 512, 512],
        cond_mlp_dims=None,
        z_mlp_dims=[64, 128],
        activation_type="Mish",
        out_activation_type="Identity",
        use_layernorm=True,  # Default to True for stability
        residual_style=True,  # Default to True for stability
        num_heads=8,
        cross_attn_at_layers="all",  # "all", "middle", or list of layer indices
    ):
        super().__init__()
        output_dim = action_dim * horizon_steps

        # Time embedding
        self.time_embedding = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 2),
            nn.Mish(),
            nn.Linear(time_dim * 2, time_dim),
        )

        # Process z for cross-attention
        if z_mlp_dims is not None:
            self.z_processor = MLP(
                # list(...) so tuples / config sequences concatenate cleanly
                [z_dim] + list(z_mlp_dims),
                activation_type=activation_type,
                out_activation_type="Identity",
            )
            z_processed_dim = z_mlp_dims[-1]
        else:
            self.z_processor = nn.Identity()
            z_processed_dim = z_dim

        # Optional MLP for conditioning state
        if cond_mlp_dims is not None:
            self.cond_mlp = MLP(
                [cond_dim] + list(cond_mlp_dims),
                activation_type=activation_type,
                out_activation_type="Identity",
            )
            cond_processed_dim = cond_mlp_dims[-1]
        else:
            cond_processed_dim = cond_dim

        # Input includes z for direct path (like concat) plus modulation via attention
        input_dim = (
            time_dim
            + action_dim * horizon_steps
            + cond_processed_dim
            + z_processed_dim
        )

        # Determine which layers get cross-attention
        num_layers = len(mlp_dims)
        if cross_attn_at_layers == "all":
            attn_indices = list(range(num_layers))
        elif cross_attn_at_layers == "middle":
            attn_indices = [num_layers // 2]
        elif isinstance(cross_attn_at_layers, str):
            # Any other string would otherwise fail later with an opaque
            # "'in <string>' requires string as left operand" TypeError.
            raise ValueError(
                f"cross_attn_at_layers must be 'all', 'middle', or a sequence "
                f"of layer indices, got {cross_attn_at_layers!r}"
            )
        else:
            attn_indices = list(cross_attn_at_layers)

        # Build residual blocks with cross-attention
        self.input_proj = nn.Linear(input_dim, mlp_dims[0])

        self.blocks = nn.ModuleList()
        self.block_norms = nn.ModuleList()
        self.cross_attn_layers = nn.ModuleDict()
        self.attn_norms = nn.ModuleDict()
        self.residual_style = residual_style
        self.use_layernorm = use_layernorm

        # Activation for the main blocks; unknown names fall back to Mish
        activation_classes = {"Mish": nn.Mish, "ReLU": nn.ReLU, "SiLU": nn.SiLU}
        self.activation = activation_classes.get(activation_type, nn.Mish)()

        for i in range(num_layers):
            # Block 0 consumes the output of input_proj, which is mlp_dims[0] wide
            in_dim = mlp_dims[i - 1] if i > 0 else mlp_dims[0]
            out_dim = mlp_dims[i]

            self.blocks.append(nn.Linear(in_dim, out_dim))

            if use_layernorm:
                self.block_norms.append(nn.LayerNorm(out_dim))

            if i in attn_indices:
                # Use same dimension as feature for simpler attention
                attn_dim = out_dim
                self.cross_attn_layers[str(i)] = nn.ModuleDict({
                    'query_proj': nn.Linear(out_dim, attn_dim),
                    'key_proj': nn.Linear(z_processed_dim, attn_dim),
                    'value_proj': nn.Linear(z_processed_dim, attn_dim),
                    'attention': nn.MultiheadAttention(
                        embed_dim=attn_dim,
                        num_heads=num_heads,
                        batch_first=True,
                    ),
                    'out_proj': nn.Linear(attn_dim, out_dim)
                })
                # LayerNorm for attention output
                if use_layernorm:
                    self.attn_norms[str(i)] = nn.LayerNorm(out_dim)

        # Output projection
        self.output_proj = nn.Linear(mlp_dims[-1], output_dim)

        self.time_dim = time_dim
        self.z_dim = z_dim
        self.z_processed_dim = z_processed_dim

        # Register z_empty as a buffer for CFG (initialized to zeros)
        self.register_buffer('z_empty', torch.zeros(1, z_dim))

    def forward(self, x, time, cond, z=None, **kwargs):
        """
        Predict the network output for a batch of noisy action sequences.

        x: (B, Ta, Da) - noisy actions
        time: (B,) tensor, 0-d tensor, or int - diffusion step; a single
            value is broadcast to the whole batch
        cond: dict with key 'state'
        z: (B, z_dim) - trajectory embedding (None uses z_empty for unconditional)
        **kwargs: ignored

        Returns a tensor of shape (B, Ta, Da).
        """
        B, Ta, Da = x.shape

        # Use z_empty if z is None (for CFG unconditional path)
        if z is None:
            z = self.z_empty.expand(B, -1)

        # Flatten actions (reshape also accepts non-contiguous inputs)
        x = x.reshape(B, -1)

        # Process inputs
        state = cond["state"].reshape(B, -1)
        if hasattr(self, "cond_mlp"):
            state = self.cond_mlp(state)

        # Process z for cross-attention
        z_processed = self.z_processor(z)
        z_seq = z_processed.unsqueeze(1)  # (B, 1, z_processed_dim) - single token

        # Time embedding: accept python ints and scalar tensors as documented
        if not torch.is_tensor(time):
            time = torch.as_tensor(time, device=x.device)
        if time.numel() == 1:
            time = time.reshape(1).expand(B)
        time = time.reshape(B, 1)
        time_emb = self.time_embedding(time).reshape(B, self.time_dim)

        # Concatenate all inputs including z (direct path)
        x = torch.cat([x, time_emb, state, z_processed], dim=-1)

        # Input projection
        x = self.input_proj(x)

        # Forward through residual blocks with cross-attention
        last_block = len(self.blocks) - 1
        for i, block in enumerate(self.blocks):
            residual = x

            # Linear transformation
            x = block(x)

            # LayerNorm after linear (if enabled)
            if self.use_layernorm and i < len(self.block_norms):
                x = self.block_norms[i](x)

            # Apply cross-attention if this layer has it
            attn_key = str(i)
            if attn_key in self.cross_attn_layers:
                attn_module = self.cross_attn_layers[attn_key]
                attn_residual = x

                # Project query, key, value
                q = attn_module['query_proj'](x).unsqueeze(1)  # (B, 1, attn_dim)
                k = attn_module['key_proj'](z_seq)  # (B, 1, attn_dim)
                v = attn_module['value_proj'](z_seq)  # (B, 1, attn_dim)

                # Apply cross-attention (attention weights are discarded)
                attn_out, _ = attn_module['attention'](q, k, v)

                # Project back and squeeze the length-1 sequence axis
                attn_out = attn_module['out_proj'](attn_out.squeeze(1))

                # LayerNorm for attention output
                if attn_key in self.attn_norms:
                    attn_out = self.attn_norms[attn_key](attn_out)

                # Residual connection for attention
                x = attn_residual + attn_out

            # Activation
            if i < last_block:  # No activation after last block
                x = self.activation(x)

            # Residual connection for the block (if dimensions match)
            if self.residual_style and residual.shape == x.shape:
                x = x + residual

        # Output
        out = self.output_proj(x)
        return out.view(B, Ta, Da)


class AdaLNDiffusionMLP(nn.Module):
    """
    Diffusion MLP with Adaptive Layer Normalization conditioning on z.

    Similar to DiT (Diffusion Transformers): z and the time embedding are
    concatenated, optionally processed by an MLP, and then projected to a
    per-layer (scale, shift) pair that modulates the output of an affine-free
    LayerNorm. Neither z nor time is concatenated to the main network input;
    both reach the network only through this modulation.

    Constructor arguments:
        action_dim, horizon_steps: the output is (B, horizon_steps, action_dim).
        cond_dim: flattened size of ``cond["state"]``.
        z_dim: size of the trajectory embedding.
        time_dim: size of the diffusion-step embedding.
        mlp_dims: hidden widths of the main network. Every Linear except the
            last is followed by AdaLN and an activation.
        cond_mlp_dims: if given, ``cond["state"]`` is first passed through an
            MLP with these widths.
        z_mlp_dims: widths of the MLP applied to ``[z, time_emb]``; ``None``
            uses the concatenation unchanged.
        activation_type: "Mish", "ReLU" or "SiLU".
        out_activation_type: accepted for signature compatibility; not used
            by this class.
        residual_style: add the input of a Linear (other than the first) to
            its modulated output when shapes match.
    """

    def __init__(
        self,
        action_dim,
        horizon_steps,
        cond_dim,
        z_dim,
        time_dim=16,
        mlp_dims=[512, 512, 512],
        cond_mlp_dims=None,
        z_mlp_dims=[64, 256],  # Larger for predicting all LN params
        activation_type="Mish",
        out_activation_type="Identity",
        residual_style=False,
    ):
        super().__init__()
        output_dim = action_dim * horizon_steps

        # Time embedding
        self.time_embedding = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 2),
            nn.Mish(),
            nn.Linear(time_dim * 2, time_dim),
        )

        # Process z (combined with time) for AdaLN
        if z_mlp_dims is not None:
            self.z_processor = MLP(
                # list(...) so tuples / config sequences concatenate cleanly
                [z_dim + time_dim] + list(z_mlp_dims),  # Combine z and time
                activation_type=activation_type,
                out_activation_type="Identity",
            )
            z_processed_dim = z_mlp_dims[-1]
        else:
            # Same convention as the FiLM / cross-attention variants:
            # None means "use the conditioning vector as-is".
            self.z_processor = nn.Identity()
            z_processed_dim = z_dim + time_dim

        # Optional MLP for conditioning state
        if cond_mlp_dims is not None:
            self.cond_mlp = MLP(
                [cond_dim] + list(cond_mlp_dims),
                activation_type=activation_type,
                out_activation_type="Identity",
            )
            cond_processed_dim = cond_mlp_dims[-1]
        else:
            cond_processed_dim = cond_dim

        # Main MLP with AdaLN
        input_dim = action_dim * horizon_steps + cond_processed_dim

        # Build network layers. self.layers interleaves Linear and activation
        # modules; layer_norms / adaln_projections hold one entry per Linear
        # except the last.
        self.layers = nn.ModuleList()
        self.layer_norms = nn.ModuleList()
        self.adaln_projections = nn.ModuleList()

        # NOTE(review): an activation_type outside this table adds no
        # activation modules at all (the FiLM / cross-attention variants fall
        # back to Mish instead). Left unchanged because inserting modules
        # would shift the indices of self.layers in saved checkpoints.
        activation_classes = {"Mish": nn.Mish, "ReLU": nn.ReLU, "SiLU": nn.SiLU}
        activation_cls = activation_classes.get(activation_type)

        dims = [input_dim] + list(mlp_dims)
        num_linear = len(dims) - 1
        for i in range(num_linear):
            self.layers.append(nn.Linear(dims[i], dims[i + 1]))

            if i < num_linear - 1:  # No LN / activation after last layer
                self.layer_norms.append(
                    nn.LayerNorm(dims[i + 1], elementwise_affine=False)
                )
                # Project z to scale and shift for this layer
                self.adaln_projections.append(
                    nn.Linear(z_processed_dim, 2 * dims[i + 1])
                )
                if activation_cls is not None:
                    self.layers.append(activation_cls())

        # Output projection
        self.output_proj = nn.Linear(mlp_dims[-1], output_dim)

        self.time_dim = time_dim
        self.z_dim = z_dim
        self.residual_style = residual_style

        # Register z_empty as a buffer for CFG (initialized to zeros)
        self.register_buffer('z_empty', torch.zeros(1, z_dim))

    def forward(self, x, time, cond, z=None, **kwargs):
        """
        Predict the network output for a batch of noisy action sequences.

        x: (B, Ta, Da) - noisy actions
        time: (B,) tensor, 0-d tensor, or int - diffusion step; a single
            value is broadcast to the whole batch
        cond: dict with key 'state'
        z: (B, z_dim) - trajectory embedding (None uses z_empty for unconditional)
        **kwargs: ignored

        Returns a tensor of shape (B, Ta, Da).
        """
        B, Ta, Da = x.shape

        # Use z_empty if z is None (for CFG unconditional path)
        if z is None:
            z = self.z_empty.expand(B, -1)

        # Flatten actions (reshape also accepts non-contiguous inputs)
        x = x.reshape(B, -1)

        # Process inputs
        state = cond["state"].reshape(B, -1)
        if hasattr(self, "cond_mlp"):
            state = self.cond_mlp(state)

        # Time embedding: accept python ints and scalar tensors as documented
        if not torch.is_tensor(time):
            time = torch.as_tensor(time, device=x.device)
        if time.numel() == 1:
            time = time.reshape(1).expand(B)
        time = time.reshape(B, 1)
        time_emb = self.time_embedding(time).reshape(B, self.time_dim)

        # Process z with time for AdaLN
        z_time = torch.cat([z, time_emb], dim=-1)
        z_processed = self.z_processor(z_time)

        # Concatenate main inputs (not z or time - they go through AdaLN)
        x = torch.cat([x, state], dim=-1)

        # Forward through layers with AdaLN
        num_adaln = len(self.layer_norms)
        linear_idx = 0  # counts Linear modules only; doubles as the AdaLN index
        for module in self.layers:
            if not isinstance(module, nn.Linear):
                # Activation function
                x = module(x)
                continue

            # Save input for residual if needed (never for the first Linear)
            residual = x if (self.residual_style and linear_idx > 0) else None

            x = module(x)

            # Apply AdaLN if not the last linear layer
            if linear_idx < num_adaln:
                # Get scale and shift from z
                scale_shift = self.adaln_projections[linear_idx](z_processed)
                scale, shift = scale_shift.chunk(2, dim=-1)

                # Apply layer norm then modulate
                x = self.layer_norms[linear_idx](x)
                x = x * (1 + scale) + shift

            # Add residual if applicable
            if residual is not None and residual.shape == x.shape:
                x = x + residual

            linear_idx += 1

        # Output
        out = self.output_proj(x)
        return out.view(B, Ta, Da)


# Registry keyed by conditioning scheme; each value is a model class name
# given as a string.
PARAMETERIZED_MODELS = dict(
    concat="ParameterizedDiffusionMLP",  # Original concatenation
    film="FiLMDiffusionMLP",
    cross_attention="CrossAttentionDiffusionMLP",
    adaln="AdaLNDiffusionMLP",
)