"""
Parameterized Behavior Cloning (PBC) Model.

Same FiLM architecture as pdp-film but with direct action reconstruction.
- Input: s0 (initial state), z (trajectory embedding from pretrained encoder)
- Output: Full action trajectory (horizon_steps x action_dim)
- Loss: MSE between predicted and ground truth trajectory

This combines the FiLM conditioning mechanism from parameterized diffusion
with the simple regression objective from behavior cloning.
"""

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple

from model.common.mlp import MLP
from model.diffusion.parameterized_diffusion import FiLMLayer

log = logging.getLogger(__name__)

# Return container of ParameterizedBCModel.forward: `trajectories` holds the
# predicted action trajectory; `chains` is always None for this BC model.
Sample = namedtuple("Sample", ["trajectories", "chains"])


class FiLMBCNetwork(nn.Module):
    """
    FiLM-based BC network that maps (state, z) to a full action trajectory.

    Architecture mirrors FiLMDiffusionMLP but without:
    - Time embedding (no diffusion timestep)
    - Noisy action input (directly outputs actions)

    Args:
        obs_dim: Observation/state dimension (width of the flattened state
            fed to the trunk when ``cond_mlp_dims`` is None).
        action_dim: Action dimension
        horizon_steps: Length of output trajectory
        z_dim: Dimension of trajectory embedding z
        mlp_dims: Sequence of hidden layer widths (default: five layers of
            768). Must be non-empty.
        cond_mlp_dims: Optional MLP dims for processing state; None feeds the
            flattened state directly into the trunk.
        z_mlp_dims: MLP dims for processing z before FiLM (default: (16, 32));
            None passes z through unchanged.
        activation_type: Trunk activation, one of "Mish", "ReLU", "SiLU".
            Any other value falls back to Mish in the trunk (a warning is
            logged); the value is still forwarded unchanged to the auxiliary
            z/state MLPs.
        out_activation_type: "Tanh" bounds outputs to [-1, 1]; any other
            value leaves the output projection unbounded.
        use_layernorm: Whether to apply LayerNorm after each trunk linear layer
        residual_style: Whether to add residual connections where shapes match
        film_at_layers: Which trunk layers get FiLM: "all", "middle", or an
            iterable of layer indices.

    Raises:
        ValueError: if ``mlp_dims`` is empty or ``film_at_layers`` is a
            string other than "all" / "middle".
    """

    # Trunk activations selectable by name; anything else falls back to Mish.
    _ACTIVATIONS = {"Mish": nn.Mish, "ReLU": nn.ReLU, "SiLU": nn.SiLU}

    def __init__(
        self,
        obs_dim,
        action_dim,
        horizon_steps,
        z_dim,
        mlp_dims=(768, 768, 768, 768, 768),
        cond_mlp_dims=None,
        z_mlp_dims=(16, 32),
        activation_type="Mish",
        out_activation_type="Tanh",
        use_layernorm=True,
        residual_style=True,
        film_at_layers="all",
    ):
        super().__init__()
        # Normalize to plain lists: accepts lists, tuples or config sequences,
        # and the tuple defaults above avoid shared mutable default arguments.
        mlp_dims = list(mlp_dims)
        if not mlp_dims:
            raise ValueError("mlp_dims must contain at least one layer width")

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        self.z_dim = z_dim
        output_dim = action_dim * horizon_steps

        # Process z for FiLM conditioning
        if z_mlp_dims is not None:
            z_mlp_dims = list(z_mlp_dims)
            self.z_processor = MLP(
                [z_dim] + z_mlp_dims,
                activation_type=activation_type,
                out_activation_type="Identity",
            )
            z_processed_dim = z_mlp_dims[-1]
        else:
            self.z_processor = nn.Identity()
            z_processed_dim = z_dim

        # Optional MLP for conditioning state
        if cond_mlp_dims is not None:
            cond_mlp_dims = list(cond_mlp_dims)
            self.cond_mlp = MLP(
                [obs_dim] + cond_mlp_dims,
                activation_type=activation_type,
                out_activation_type="Identity",
            )
            cond_processed_dim = cond_mlp_dims[-1]
        else:
            cond_processed_dim = obs_dim

        # Input: concatenate [state, z_processed]
        # (NO time_emb, NO noisy actions unlike diffusion)
        input_dim = cond_processed_dim + z_processed_dim

        # Determine which layers get FiLM
        num_layers = len(mlp_dims)
        if film_at_layers == "all":
            film_indices = set(range(num_layers))
        elif film_at_layers == "middle":
            film_indices = {num_layers // 2}
        elif isinstance(film_at_layers, str):
            # Without this check a string would be treated as an iterable of
            # characters and fail later with an opaque TypeError.
            raise ValueError(
                f"film_at_layers must be 'all', 'middle' or an iterable of "
                f"layer indices, got {film_at_layers!r}"
            )
        else:
            film_indices = set(film_at_layers)

        # Build residual blocks with FiLM
        self.input_proj = nn.Linear(input_dim, mlp_dims[0])

        self.blocks = nn.ModuleList()
        self.layer_norms = nn.ModuleList()
        self.film_layers = nn.ModuleDict()
        self.residual_style = residual_style
        self.use_layernorm = use_layernorm

        # Trunk activation (unknown names fall back to Mish, as before).
        activation_cls = self._ACTIVATIONS.get(activation_type)
        if activation_cls is None:
            log.warning(
                "Unrecognized activation_type %r for FiLMBCNetwork trunk; "
                "falling back to Mish",
                activation_type,
            )
            activation_cls = nn.Mish
        self.activation = activation_cls()

        # Per layer: Linear, then optional LayerNorm, then optional FiLM.
        # Creation order is kept stable so seeded initialization is unchanged.
        for i, out_dim in enumerate(mlp_dims):
            # Block 0 consumes input_proj's output, whose width is mlp_dims[0].
            in_dim = mlp_dims[i - 1] if i > 0 else mlp_dims[0]

            self.blocks.append(nn.Linear(in_dim, out_dim))

            if use_layernorm:
                self.layer_norms.append(nn.LayerNorm(out_dim))

            if i in film_indices:
                self.film_layers[str(i)] = FiLMLayer(out_dim, z_processed_dim)

        # Output projection
        self.output_proj = nn.Linear(mlp_dims[-1], output_dim)

        # Output activation (Tanh for normalized actions)
        if out_activation_type == "Tanh":
            self.out_activation = nn.Tanh()
        else:
            self.out_activation = nn.Identity()

        self.z_processed_dim = z_processed_dim

        # Register z_empty buffer (for inference without z)
        self.register_buffer("z_empty", torch.zeros(1, z_dim))

    def forward(self, state, z=None):
        """
        Predict a full action trajectory.

        Args:
            state: (B, ...) initial state s0; all dims after the batch dim are
                flattened, so the flattened width must match the trunk input.
            z: (B, z_dim) - trajectory embedding (None uses the all-zeros
                ``z_empty`` buffer for every batch element)

        Returns:
            trajectory: (B, horizon_steps, action_dim)
        """
        B = state.shape[0]

        # Use z_empty if z is None
        if z is None:
            z = self.z_empty.expand(B, -1)

        # Flatten state. reshape (not view) also handles non-contiguous
        # inputs, e.g. slices taken from a (B, T, obs_dim) history tensor.
        state = state.reshape(B, -1)

        # Process state
        if hasattr(self, "cond_mlp"):
            state = self.cond_mlp(state)

        # Process z for FiLM
        z_processed = self.z_processor(z)

        # Concatenate inputs (state + z_processed)
        x = torch.cat([state, z_processed], dim=-1)

        # Input projection
        x = self.input_proj(x)

        # Forward through residual blocks with FiLM modulation
        last_index = len(self.blocks) - 1
        for i, block in enumerate(self.blocks):
            residual = x

            # Linear transformation
            x = block(x)

            # LayerNorm before FiLM (if enabled)
            if self.use_layernorm and i < len(self.layer_norms):
                x = self.layer_norms[i](x)

            # Apply FiLM modulation
            if str(i) in self.film_layers:
                x = self.film_layers[str(i)](x, z_processed)

            # Activation (skipped after the last block)
            if i < last_index:
                x = self.activation(x)

            # Residual connection (only when dimensions match)
            if self.residual_style and residual.shape == x.shape:
                x = x + residual

        # Output
        out = self.output_proj(x)
        out = self.out_activation(out)

        return out.view(B, self.horizon_steps, self.action_dim)


class ParameterizedBCModel(nn.Module):
    """
    Parameterized BC Model wrapper.

    Wraps a network that maps ``(state, z)`` to an action trajectory (e.g.
    FiLMBCNetwork) and provides the training (``loss``) and inference
    (``forward``) interface.

    Args:
        network: nn.Module called as ``network(state, z)``; moved to ``device``.
        horizon_steps: Trajectory length (stored as an attribute).
        obs_dim: Observation dimension (stored as an attribute).
        action_dim: Action dimension (stored as an attribute).
        network_path: Optional checkpoint path. Weights come from the ``"ema"``
            entry when present, otherwise from ``"model"``, and are loaded with
            ``strict=False`` (mismatched keys are logged as warnings).
        device: Device the network is moved to and the checkpoint mapped onto.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        network,
        horizon_steps,
        obs_dim,
        action_dim,
        network_path=None,
        device="cuda:0",
        **kwargs,
    ):
        super().__init__()
        self.device = device
        self.horizon_steps = horizon_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        self.network = network.to(device)

        if network_path is not None:
            self._load_checkpoint(network_path)

        log.info(
            "Number of network parameters: %d",
            sum(p.numel() for p in self.parameters()),
        )

    def _load_checkpoint(self, network_path):
        """Load weights from ``network_path``, preferring the EMA weights."""
        checkpoint = torch.load(
            network_path, map_location=self.device, weights_only=True
        )
        use_ema = "ema" in checkpoint
        result = self.load_state_dict(
            checkpoint["ema" if use_ema else "model"], strict=False
        )
        # strict=False silently skips mismatched keys; report them so a
        # checkpoint with a different key layout is not loaded as a no-op.
        if result.missing_keys:
            log.warning(
                "Missing keys when loading %s: %s", network_path, result.missing_keys
            )
        if result.unexpected_keys:
            log.warning(
                "Unexpected keys when loading %s: %s",
                network_path,
                result.unexpected_keys,
            )
        if use_ema:
            log.info("Loaded PBC model from %s (ema weights)", network_path)
        else:
            log.info("Loaded PBC model from %s", network_path)

    @staticmethod
    def _initial_state(cond):
        """Return ``cond["state"]``; a 3-D state keeps only its last step.

        A 3-D state is treated as (B, cond_steps, obs_dim) and reduced to the
        most recent step, giving (B, obs_dim). Other ranks pass through as-is.
        """
        state = cond["state"]
        if state.dim() == 3:
            state = state[:, -1, :]
        return state

    def loss(self, true_action, cond, z_embedding):
        """
        Compute MSE loss between predicted and ground truth trajectory.

        Args:
            true_action: (B, horizon_steps, action_dim) - ground truth trajectory
            cond: dict with 'state' key containing (B, cond_steps, obs_dim)
                or (B, obs_dim)
            z_embedding: (B, z_dim) - trajectory embedding

        Returns:
            loss: scalar MSE loss (mean over all elements)
        """
        state = self._initial_state(cond)
        pred_trajectory = self.network(state, z_embedding)
        return F.mse_loss(pred_trajectory, true_action, reduction="mean")

    @torch.no_grad()
    def forward(self, cond, z=None, deterministic=True):
        """
        Forward pass for inference (gradients disabled).

        Args:
            cond: dict with 'state' key (3-D states use their last step)
            z: (B, z_dim) trajectory embedding (optional; None is passed
                through to the network)
            deterministic: bool (unused, always deterministic for BC)

        Returns:
            Sample: namedtuple with the predicted trajectory and ``chains=None``
        """
        trajectory = self.network(self._initial_state(cond), z)
        return Sample(trajectory, None)
