"""
Behavior Cloning with Gaussian Mixture Model (BC-GMM) / Mixture Density Network (MDN).

Predicts a GMM distribution over full action trajectories from initial state.
- Input: s0 (initial state, obs_dim)
- Output: GMM parameters (component means, standard deviations, and mixture logits) for K mixture components
- Each component predicts full trajectory (horizon_steps x action_dim)
- Loss: Negative log-likelihood of GMM

This captures multi-modal behavior in demonstrations.
"""

import logging
import torch
import torch.nn as nn
import torch.distributions as D
from collections import namedtuple

from model.common.mlp import MLP, ResidualMLP

# Module-level logger for this file.
log = logging.getLogger(__name__)

# Return container for BCGMMModel.forward: `trajectories` holds the predicted
# action chunk and `chains` is always None for this model (presumably kept for
# interface parity with other policies in the codebase -- confirm).
Sample = namedtuple("Sample", "trajectories chains")


class BCGMMNetwork(nn.Module):
    """
    MLP network for BC-GMM that maps initial state to GMM parameters.

    Architecture (three independent MLP heads, all fed the same state):
        state (obs_dim) -> mean MLP    -> means:  (K, horizon_steps * action_dim), tanh-bounded
        state (obs_dim) -> logvar MLP  -> scales: (K, horizon_steps * action_dim)
                           (only when ``fixed_std`` is None)
        state (obs_dim) -> weights MLP -> logits: (K,) unnormalized mixture weights
    """

    def __init__(
        self,
        obs_dim,
        action_dim,
        horizon_steps,
        num_modes=5,
        mlp_dims=(512, 512, 512),
        activation_type="Mish",
        residual_style=True,
        use_layernorm=False,
        fixed_std=None,
        learn_fixed_std=False,
        std_min=0.01,
        std_max=1.0,
    ):
        """
        Args:
            obs_dim: size of the state vector.
            action_dim: size of a single action.
            horizon_steps: number of actions in a predicted trajectory.
            num_modes: number of mixture components K.
            mlp_dims: hidden layer sizes shared by all three MLP heads
                (any sequence; it is copied into a list).
            activation_type: activation name forwarded to the MLP constructor.
            residual_style: use ``ResidualMLP`` instead of ``MLP``.
            use_layernorm: forwarded to the MLP constructor.
            fixed_std: if given, use this std instead of a learned
                state-dependent variance head.
            learn_fixed_std: only effective together with ``fixed_std``: learn
                one log-variance per (mode, action dim), initialized at
                ``fixed_std**2`` and shared across the horizon. Ignored when
                ``fixed_std`` is None (the state-dependent head is used).
            std_min, std_max: bounds applied (via log-variance clamping) to
                learned standard deviations.
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        self.num_modes = num_modes
        self.trajectory_dim = horizon_steps * action_dim

        # Copy so the caller's sequence is never aliased and tuples/lists both work.
        hidden_dims = list(mlp_dims)
        input_dim = obs_dim
        output_dim_mean = self.trajectory_dim * num_modes

        model = ResidualMLP if residual_style else MLP

        # Mean head - outputs means for all K components at once.
        self.mlp_mean = model(
            dim_list=[input_dim] + hidden_dims + [output_dim_mean],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )

        # Variance handling
        self.use_fixed_std = fixed_std is not None
        self.fixed_std = fixed_std
        self.learn_fixed_std = learn_fixed_std

        if fixed_std is None:
            # State-dependent log-variance per component and trajectory entry.
            self.mlp_logvar = model(
                dim_list=[input_dim] + hidden_dims + [output_dim_mean],
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
            )
        elif learn_fixed_std:
            # One learnable log-variance per (mode, action dim), shared across the
            # horizon; torch.full guarantees a float tensor even for int fixed_std.
            self.logvar = nn.Parameter(
                torch.log(
                    torch.full((action_dim * num_modes,), float(fixed_std) ** 2)
                ),
                requires_grad=True,
            )

        # Bounds for log variance
        self.register_buffer("logvar_min", torch.log(torch.tensor(float(std_min) ** 2)))
        self.register_buffer("logvar_max", torch.log(torch.tensor(float(std_max) ** 2)))

        # Mixture weights head
        self.mlp_weights = model(
            dim_list=[input_dim] + hidden_dims + [num_modes],
            activation_type=activation_type,
            out_activation_type="Identity",
            use_layernorm=use_layernorm,
        )

    def forward(self, state):
        """
        Args:
            state: (B, obs_dim) - initial state s0

        Returns:
            means: (B, K, trajectory_dim) - tanh-bounded mean trajectories
            scales: (B, K, trajectory_dim) - standard deviations
            logits: (B, K) - mixture logits (unnormalized)
        """
        B = state.shape[0]

        # Means are squashed to [-1, 1] to match bounded (normalized) actions.
        out_mean = torch.tanh(self.mlp_mean(state))
        out_mean = out_mean.view(B, self.num_modes, self.trajectory_dim)

        if self.use_fixed_std and self.learn_fixed_std:
            # Learned per-(mode, action dim) std, tiled over the horizon: the flat
            # trajectory is laid out as (horizon_steps, action_dim), so repeating
            # the action_dim block horizon_steps times aligns each entry.
            out_logvar = torch.clamp(self.logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)
            out_scale = out_scale.view(1, self.num_modes, self.action_dim)
            out_scale = out_scale.expand(B, -1, -1)
            out_scale = out_scale.repeat(1, 1, self.horizon_steps)
        elif self.use_fixed_std:
            out_scale = torch.full_like(out_mean, self.fixed_std)
        else:
            # Also the path when learn_fixed_std=True but fixed_std=None: __init__
            # built mlp_logvar (not self.logvar) in that case.
            out_logvar = self.mlp_logvar(state)
            out_logvar = out_logvar.view(B, self.num_modes, self.trajectory_dim)
            out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max)
            out_scale = torch.exp(0.5 * out_logvar)

        out_weights = self.mlp_weights(state).view(B, self.num_modes)

        return out_mean, out_scale, out_weights


class BCGMMModel(nn.Module):
    """
    BC-GMM Model.

    Wraps BCGMMNetwork and provides training interface compatible with the codebase.

    The wrapped ``network`` is called with a (B, obs_dim) state and must return
    ``(means, scales, logits)`` where means/scales are (B, K, horizon_steps * action_dim)
    and logits are (B, K).
    """

    def __init__(
        self,
        network,
        horizon_steps,
        obs_dim,
        action_dim,
        network_path=None,
        device="cuda:0",
        **kwargs,
    ):
        """
        Args:
            network: module producing GMM parameters (see class docstring).
            horizon_steps: number of actions per predicted trajectory.
            obs_dim: size of the state vector.
            action_dim: size of a single action.
            network_path: optional checkpoint path; its ``"ema"`` entry is
                preferred, otherwise ``"model"`` is loaded (non-strict).
            device: device the network is moved to and checkpoints are mapped to.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.device = device
        self.horizon_steps = horizon_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        self.network = network.to(device)

        if network_path is not None:
            self._load_checkpoint(network_path)

        log.info(
            "Number of network parameters: %d",
            sum(p.numel() for p in self.parameters()),
        )

    def _load_checkpoint(self, network_path):
        """Load weights from ``network_path``, preferring ``"ema"`` over ``"model"``.

        Loading is non-strict; missing or unexpected keys are reported as a
        warning instead of being silently dropped.
        """
        checkpoint = torch.load(
            network_path, map_location=self.device, weights_only=True
        )
        if "ema" in checkpoint:
            incompatible = self.load_state_dict(checkpoint["ema"], strict=False)
            log.info("Loaded BC-GMM model from %s (ema weights)", network_path)
        else:
            incompatible = self.load_state_dict(checkpoint["model"], strict=False)
            log.info("Loaded BC-GMM model from %s", network_path)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            log.warning(
                "Non-strict load from %s: missing keys=%s, unexpected keys=%s",
                network_path,
                incompatible.missing_keys,
                incompatible.unexpected_keys,
            )

    @staticmethod
    def _get_state(cond):
        """Return ``cond["state"]``; a 3-D (B, T, obs_dim) input is reduced to its last step."""
        state = cond["state"]
        if state.dim() == 3:
            state = state[:, -1, :]
        return state

    @staticmethod
    def _build_distribution(means, scales, logits, deterministic=False):
        """Build the GMM and summary statistics from precomputed network outputs.

        Returns:
            dist: MixtureSameFamily over flattened trajectories (event dim = trajectory_dim)
            approx_entropy: batch mean of the weight-averaged component entropies
                (a lower bound on the true mixture entropy, which has no closed form)
            mean_std: batch mean of the weight-averaged per-component mean std
        """
        if deterministic:
            scales = torch.full_like(means, 1e-4)

        # Independent(..., 1) treats the trajectory dim as the event dim, so each
        # component is a diagonal Gaussian over the whole flattened trajectory.
        component_distribution = D.Independent(D.Normal(loc=means, scale=scales), 1)

        component_entropy = component_distribution.entropy()  # (B, K)
        weights = logits.softmax(-1)  # (B, K)
        approx_entropy = torch.mean(torch.sum(weights * component_entropy, dim=-1))
        mean_std = torch.mean(torch.sum(weights * scales.mean(-1), dim=-1))

        dist = D.MixtureSameFamily(
            mixture_distribution=D.Categorical(logits=logits),
            component_distribution=component_distribution,
        )
        return dist, approx_entropy, mean_std

    def get_distribution(self, cond, deterministic=False):
        """
        Build GMM distribution from network output.

        Args:
            cond: dict with 'state' key
            deterministic: if True, replace every component std by 1e-4

        Returns:
            dist: MixtureSameFamily distribution
            entropy: weight-averaged component entropy (lower bound on mixture entropy)
            std: mean std across components
        """
        means, scales, logits = self.network(self._get_state(cond))
        return self._build_distribution(means, scales, logits, deterministic)

    def loss(self, true_action, cond):
        """
        Compute negative log-likelihood loss for GMM.

        Args:
            true_action: (B, horizon_steps, action_dim) - ground truth trajectory
            cond: dict with 'state' key

        Returns:
            loss: scalar NLL loss (mean over the batch)
        """
        B = true_action.shape[0]

        # reshape (not view) so non-contiguous batches (slices, transposes) work.
        true_action_flat = true_action.reshape(B, -1)  # (B, trajectory_dim)

        dist, _, _ = self.get_distribution(cond, deterministic=False)

        nll = -dist.log_prob(true_action_flat)  # (B,)
        return nll.mean()

    @torch.no_grad()
    def forward(self, cond, deterministic=False, fixed_component=None):
        """
        Forward pass for inference.

        Args:
            cond: dict with 'state' key
            deterministic: if True, return the mean of the highest-weighted component
            fixed_component: int, if provided, always return this component's mean
                (for posterior selection); takes precedence over ``deterministic``

        Returns:
            Sample: namedtuple with trajectories of shape (B, horizon_steps, action_dim)
                and chains=None
        """
        state = self._get_state(cond)
        B = state.shape[0]

        means, scales, logits = self.network(state)

        if fixed_component is not None:
            trajectory = means[:, fixed_component]  # (B, trajectory_dim)
        elif deterministic:
            # softmax is monotonic, so the argmax of the logits is the argmax of the weights.
            best_idx = logits.argmax(dim=-1)  # (B,)
            batch_idx = torch.arange(B, device=means.device)
            trajectory = means[batch_idx, best_idx]  # (B, trajectory_dim)
        else:
            # Reuse the outputs computed above; get_distribution would run the network again.
            dist, _, _ = self._build_distribution(means, scales, logits)
            trajectory = dist.sample()  # (B, trajectory_dim)

        trajectory = trajectory.reshape(B, self.horizon_steps, self.action_dim)
        return Sample(trajectory, None)
