"""
Behavior Cloning (BC) Model.

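Simple MLP that predicts a full action trajectory from the initial state.
- Input: s0 (initial state, obs_dim). If a (B, cond_steps, obs_dim) history
  is given, only the most recent step is used.
- Output: full action trajectory (horizon_steps x action_dim), squashed to
  [-1, 1] by a final Tanh.
- Loss: MSE between the predicted and ground-truth trajectories.

This is an open-loop baseline: the entire trajectory is predicted at once,
without re-observing the state between actions.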
"""

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple

from model.common.mlp import MLP, ResidualMLP

# Module-level logger; records are emitted under this module's dotted name.
log = logging.getLogger(__name__)

# Result container returned by BCModel.forward: ``trajectories`` holds the
# predicted actions; ``chains`` is always None for this model (presumably kept
# so the return type matches other models' Sample — confirm with callers).
Sample = namedtuple("Sample", ["trajectories", "chains"])


class BCNetwork(nn.Module):
    """
    MLP network for BC that maps an initial state to a full action trajectory.

    Architecture:
        state (obs_dim) -> MLP -> trajectory (horizon_steps * action_dim)

    The final layer uses a Tanh activation, so every predicted action value
    lies in [-1, 1] (actions are expected to be normalized to that range).

    Args:
        obs_dim: dimension of the input state.
        action_dim: dimension of a single action.
        horizon_steps: number of action steps predicted per forward pass.
        mlp_dims: hidden layer widths; any sequence of ints (list, tuple, ...).
        activation_type: hidden activation name forwarded to the MLP.
        residual_style: build a ``ResidualMLP`` instead of a plain ``MLP``.
        use_layernorm: forwarded to the MLP constructor.
        dropout: dropout rate, forwarded only to the plain ``MLP``. It is not
            passed to ``ResidualMLP``; a warning is logged if it is set there.
    """

    def __init__(
        self,
        obs_dim,
        action_dim,
        horizon_steps,
        mlp_dims=(512, 512, 512),
        activation_type="Mish",
        residual_style=True,
        use_layernorm=False,
        dropout=0.0,
    ):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        self.output_dim = horizon_steps * action_dim

        # Unpacking builds a fresh list, so it accepts any sequence type and
        # never aliases the caller's (or the default) container.
        common_kwargs = dict(
            dim_list=[obs_dim, *mlp_dims, self.output_dim],
            activation_type=activation_type,
            out_activation_type="Tanh",  # Actions are normalized to [-1, 1]
            use_layernorm=use_layernorm,
        )

        if residual_style:
            if dropout:
                # ResidualMLP is not given a dropout argument here, so make the
                # silently-ignored setting visible instead of dropping it.
                log.warning(
                    "BCNetwork: dropout=%s is ignored when residual_style=True",
                    dropout,
                )
            self.network = ResidualMLP(**common_kwargs)
        else:
            self.network = MLP(**common_kwargs, dropout=dropout)

    def forward(self, state):
        """
        Args:
            state: (B, obs_dim) initial state s0. Extra leading batch dims
                (..., obs_dim) are also preserved, provided the wrapped MLP
                acts on the last dimension.

        Returns:
            trajectory: (B, horizon_steps, action_dim) predicted action
                trajectory, or (..., horizon_steps, action_dim) in general.
        """
        output = self.network(state)
        # Split only the flat last dimension into (horizon_steps, action_dim);
        # for a (B, obs_dim) input this is exactly (B, horizon_steps, action_dim).
        return output.reshape(
            *output.shape[:-1], self.horizon_steps, self.action_dim
        )


class BCModel(nn.Module):
    """
    Behavior Cloning Model.

    Wraps a trajectory-predicting network (e.g. ``BCNetwork``) and provides the
    ``loss`` / ``forward`` training interface compatible with the codebase.

    Args:
        network: module mapping a (B, obs_dim) state to a
            (B, horizon_steps, action_dim) action trajectory.
        horizon_steps: number of predicted action steps (stored as attribute).
        obs_dim: observation dimension (stored as attribute).
        action_dim: action dimension (stored as attribute).
        network_path: optional checkpoint path. If the checkpoint has an
            ``"ema"`` entry those weights are loaded, otherwise ``"model"``.
        device: device the network is moved to and the checkpoint mapped onto.
        **kwargs: accepted and not used by this class.
    """

    def __init__(
        self,
        network,
        horizon_steps,
        obs_dim,
        action_dim,
        network_path=None,
        device="cuda:0",
        **kwargs,
    ):
        super().__init__()
        self.device = device
        self.horizon_steps = horizon_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        self.network = network.to(device)

        if network_path is not None:
            self._load_checkpoint(network_path)

        log.info(
            "Number of network parameters: %d",
            sum(p.numel() for p in self.parameters()),
        )

    def _load_checkpoint(self, network_path):
        """Load weights from ``network_path``, preferring EMA weights if present.

        Loading is non-strict, but missing and unexpected keys are logged as
        warnings so a mismatched checkpoint does not go unnoticed.
        """
        checkpoint = torch.load(
            network_path, map_location=self.device, weights_only=True
        )
        if "ema" in checkpoint:
            state_dict, suffix = checkpoint["ema"], " (ema weights)"
        else:
            state_dict, suffix = checkpoint["model"], ""

        result = self.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            log.warning(
                "Checkpoint %s is missing keys (left at init values): %s",
                network_path,
                result.missing_keys,
            )
        if result.unexpected_keys:
            log.warning(
                "Checkpoint %s has unexpected keys (ignored): %s",
                network_path,
                result.unexpected_keys,
            )
        log.info("Loaded BC model from %s%s", network_path, suffix)

    @staticmethod
    def _select_state(cond):
        """Return the state to condition on from ``cond["state"]``.

        A 3-D tensor is treated as (B, cond_steps, obs_dim) and reduced to its
        most recent step, giving (B, obs_dim); anything else is returned as is.
        """
        state = cond["state"]
        if state.dim() == 3:
            state = state[:, -1, :]  # Use most recent state: (B, obs_dim)
        return state

    def loss(self, true_action, cond):
        """
        Compute MSE loss between predicted and ground truth trajectory.

        Args:
            true_action: (B, horizon_steps, action_dim) - ground truth trajectory
            cond: dict with 'state' key containing (B, cond_steps, obs_dim)
                or (B, obs_dim)

        Returns:
            loss: scalar MSE loss (mean over all elements)
        """
        state = self._select_state(cond)
        pred_trajectory = self.network(state)  # (B, horizon_steps, action_dim)
        return F.mse_loss(pred_trajectory, true_action, reduction="mean")

    @torch.no_grad()
    def forward(self, cond, deterministic=True):
        """
        Forward pass for inference (gradients disabled).

        Args:
            cond: dict with 'state' key containing initial state(s)
            deterministic: bool (unused for BC, always deterministic)

        Returns:
            Sample: namedtuple whose ``trajectories`` field holds the predicted
                actions and whose ``chains`` field is None.
        """
        state = self._select_state(cond)
        trajectory = self.network(state)
        return Sample(trajectory, None)
