"""
Implicit Behavior Cloning (IBC) with Energy-Based Model.

Energy network E(s0, trajectory) -> scalar energy.
- Input: s0 (initial state) concatenated with trajectory
- Output: scalar energy value
- Training: InfoNCE/contrastive loss with negative samples
- Inference: Langevin dynamics or gradient descent to find low-energy trajectory

Reference:
    Florence et al., "Implicit Behavioral Cloning", CoRL 2021
    https://arxiv.org/abs/2109.00137
"""

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple

from model.common.mlp import MLP, ResidualMLP

log = logging.getLogger(name=__name__)

# Result container: `trajectories` carries the sampled action sequence and
# `chains` is an optional slot for intermediate samples (IBCModel.forward
# fills it with None).
Sample = namedtuple("Sample", ["trajectories", "chains"])


class IBCEnergyNetwork(nn.Module):
    """
    Energy network for IBC.

    Architecture:
        concat(state, trajectory) -> MLP -> scalar energy

    Lower energy = better trajectory for given state.
    """

    def __init__(
        self,
        obs_dim,
        action_dim,
        horizon_steps,
        mlp_dims=(512, 512, 512),
        activation_type="Mish",
        residual_style=True,
        use_layernorm=False,
        dropout=0.0,
    ):
        """
        Args:
            obs_dim: size of the state vector.
            action_dim: size of a single action.
            horizon_steps: number of actions per trajectory; the trajectory is
                consumed flattened, i.e. horizon_steps * action_dim values.
            mlp_dims: hidden-layer widths (any sequence of ints).
            activation_type: activation name forwarded to MLP / ResidualMLP.
            residual_style: build a ResidualMLP instead of a plain MLP.
            use_layernorm: forwarded to the MLP constructor.
            dropout: dropout rate; only forwarded to the plain MLP. The
                ResidualMLP branch is constructed without it, so a non-zero
                value there only triggers a warning.
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.horizon_steps = horizon_steps
        self.trajectory_dim = horizon_steps * action_dim

        input_dim = obs_dim + self.trajectory_dim
        output_dim = 1  # Scalar energy
        # Fresh list each call: the default is an immutable tuple, so no
        # state can leak between instances through a shared default.
        dim_list = [input_dim, *mlp_dims, output_dim]

        if residual_style:
            if dropout:
                logging.getLogger(__name__).warning(
                    "dropout=%s is ignored when residual_style=True "
                    "(ResidualMLP is built without a dropout argument)",
                    dropout,
                )
            self.network = ResidualMLP(
                dim_list=dim_list,
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
            )
        else:
            self.network = MLP(
                dim_list=dim_list,
                activation_type=activation_type,
                out_activation_type="Identity",
                use_layernorm=use_layernorm,
                dropout=dropout,
            )

    def forward(self, state, trajectory):
        """
        Compute energy for (state, trajectory) pairs.

        Args:
            state: (B, obs_dim) or (B, N, obs_dim) - initial state(s)
            trajectory: (B, trajectory_dim) or (B, N, trajectory_dim) - trajectory/trajectories

        Returns:
            energy: (B,) or (B, N) - scalar energy values
        """
        if trajectory.dim() == 3:
            # (B, N, trajectory_dim), e.g. N = 1 positive + K negatives.
            B, N, _ = trajectory.shape
            if state.dim() == 2:
                # Share one state across all N candidate trajectories.
                state = state.unsqueeze(1).expand(-1, N, -1)  # (B, N, obs_dim)
            x = torch.cat([state, trajectory], dim=-1)  # (B, N, input_dim)
            # The MLP expects a 2-D batch: fold N into the batch dimension.
            energy = self.network(x.reshape(B * N, -1))  # (B * N, 1)
            return energy.reshape(B, N)

        # Standard case: (B, trajectory_dim)
        x = torch.cat([state, trajectory], dim=-1)  # (B, input_dim)
        return self.network(x).squeeze(-1)  # (B,)


class IBCModel(nn.Module):
    """
    Implicit Behavior Cloning Model.

    Trains an energy-based model with an InfoNCE loss against uniform random
    negatives. At inference it minimizes the energy over trajectories, using
    either multi-chain Langevin dynamics or Adam-based gradient descent.
    """

    def __init__(
        self,
        network,
        horizon_steps,
        obs_dim,
        action_dim,
        network_path=None,
        device="cuda:0",
        # Training parameters
        num_negative_samples=256,
        # Langevin dynamics parameters
        langevin_steps=100,
        langevin_step_size=0.01,
        langevin_noise_scale=0.01,
        langevin_clip_value=1.0,
        # Optimization-based inference
        use_gradient_descent=False,
        gd_steps=100,
        gd_lr=0.1,
        # Temperature for InfoNCE
        temperature=1.0,
        # Number of independent Langevin chains run at inference
        langevin_num_chains=4,
        **kwargs,
    ):
        """
        Args:
            network: energy module called as network(state, trajectory).
            horizon_steps, obs_dim, action_dim: problem dimensions; trajectories
                are handled flattened as horizon_steps * action_dim values.
            network_path: optional checkpoint; its "ema" entry is preferred,
                otherwise its "model" entry is loaded (non-strict).
            device: device for the network and for sampled tensors.
            num_negative_samples: uniform negatives per example in the loss.
            langevin_*: Langevin step count, step size, noise std, and the
                symmetric clamp range (the clamp is also used by gradient descent).
            use_gradient_descent: use Adam instead of Langevin at inference.
            gd_steps, gd_lr: Adam iterations and learning rate.
            temperature: InfoNCE temperature.
            langevin_num_chains: independent Langevin chains; the lowest-energy
                result is kept per batch element. Must be >= 1.
            **kwargs: ignored extra config keys.

        Raises:
            ValueError: if langevin_num_chains < 1.
        """
        super().__init__()
        if langevin_num_chains < 1:
            raise ValueError(
                f"langevin_num_chains must be >= 1, got {langevin_num_chains}"
            )
        self.device = device
        self.horizon_steps = horizon_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.trajectory_dim = horizon_steps * action_dim

        self.network = network.to(device)

        # Training params
        self.num_negative_samples = num_negative_samples
        self.temperature = temperature

        # Langevin dynamics params
        self.langevin_steps = langevin_steps
        self.langevin_step_size = langevin_step_size
        self.langevin_noise_scale = langevin_noise_scale
        self.langevin_clip_value = langevin_clip_value
        self.langevin_num_chains = langevin_num_chains

        # Gradient descent params
        self.use_gradient_descent = use_gradient_descent
        self.gd_steps = gd_steps
        self.gd_lr = gd_lr

        if network_path is not None:
            checkpoint = torch.load(
                network_path, map_location=device, weights_only=True
            )
            if "ema" in checkpoint:
                self.load_state_dict(checkpoint["ema"], strict=False)
                log.info("Loaded IBC model from %s (ema weights)", network_path)
            else:
                self.load_state_dict(checkpoint["model"], strict=False)
                log.info("Loaded IBC model from %s", network_path)

        log.info(
            "Number of network parameters: %d",
            sum(p.numel() for p in self.parameters()),
        )

    @staticmethod
    def _current_state(cond):
        """Return cond["state"]; for a 3-D (B, T, obs_dim) history keep only the last step."""
        state = cond["state"]
        if state.dim() == 3:
            state = state[:, -1, :]  # (B, obs_dim)
        return state

    def _generate_negative_samples(self, B, num_samples):
        """
        Generate negative trajectory samples.

        Args:
            B: batch size
            num_samples: number of negative samples per batch item

        Returns:
            negatives: (B, num_samples, trajectory_dim) - uniform random in [-1, 1]
        """
        return torch.rand(
            B, num_samples, self.trajectory_dim, device=self.device
        ) * 2 - 1

    def loss(self, true_action, cond):
        """
        Compute InfoNCE contrastive loss.

        The loss encourages low energy for positive (ground truth) trajectories
        and high energy for negative (random) trajectories.

        InfoNCE: -log( exp(-E(s, a_pos) / tau) / sum_i exp(-E(s, a_i) / tau) )

        Args:
            true_action: (B, horizon_steps, action_dim) - ground truth trajectory
            cond: dict with 'state' key

        Returns:
            loss: scalar InfoNCE loss
        """
        B = true_action.shape[0]
        state = self._current_state(cond)

        # reshape (not view) so non-contiguous action tensors are accepted.
        positive = true_action.reshape(B, self.trajectory_dim)  # (B, trajectory_dim)
        negatives = self._generate_negative_samples(B, self.num_negative_samples)

        # Positive first, so it is class 0 for the cross-entropy below.
        all_trajectories = torch.cat(
            [positive.unsqueeze(1), negatives], dim=1
        )  # (B, 1 + K, trajectory_dim)

        energies = self.network(state, all_trajectories)  # (B, 1 + K)

        # Lower energy = higher probability.
        logits = -energies / self.temperature  # (B, 1 + K)
        labels = torch.zeros(B, dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits, labels)

    def _langevin_dynamics(self, state, initial_trajectory=None):
        """
        Use Langevin dynamics to find a low-energy trajectory.

        x_{t+1} = clamp(x_t - step_size * grad_x E(s, x) + noise)

        Gradients are taken w.r.t. the trajectory only, so network parameters'
        .grad is left untouched. Grad mode is enabled internally, so this also
        works when called from a torch.no_grad() context.

        Args:
            state: (B, obs_dim)
            initial_trajectory: optional (B, trajectory_dim) starting point;
                uniform random in [-1, 1] when None.

        Returns:
            trajectory: (B, trajectory_dim) - optimized trajectory (detached)
        """
        B = state.shape[0]
        clip = self.langevin_clip_value

        if initial_trajectory is None:
            trajectory = torch.rand(B, self.trajectory_dim, device=self.device) * 2 - 1
        else:
            # detach first so the copy is a fresh leaf we can require grad on.
            trajectory = initial_trajectory.detach().clone()

        with torch.enable_grad():
            for _ in range(self.langevin_steps):
                trajectory.requires_grad_(True)
                energy = self.network(state, trajectory)  # (B,)
                grad = torch.autograd.grad(energy.sum(), trajectory)[0]

                # Update outside autograd; the result is a new leaf tensor.
                with torch.no_grad():
                    noise = torch.randn_like(trajectory) * self.langevin_noise_scale
                    trajectory = (
                        trajectory - self.langevin_step_size * grad + noise
                    ).clamp(-clip, clip)

        return trajectory.detach()

    def _gradient_descent(self, state, initial_trajectory=None):
        """
        Use Adam to find a low-energy trajectory.

        Only the trajectory's gradient is computed (torch.autograd.grad), so
        network parameters' .grad is not polluted by inference.

        Args:
            state: (B, obs_dim)
            initial_trajectory: optional (B, trajectory_dim) starting point;
                uniform random in [-1, 1] when None.

        Returns:
            trajectory: (B, trajectory_dim) - optimized trajectory (detached)
        """
        B = state.shape[0]
        clip = self.langevin_clip_value

        if initial_trajectory is None:
            init = torch.rand(B, self.trajectory_dim, device=self.device) * 2 - 1
        else:
            init = initial_trajectory.detach().clone()

        trajectory = nn.Parameter(init)
        optimizer = torch.optim.Adam([trajectory], lr=self.gd_lr)

        with torch.enable_grad():
            for _ in range(self.gd_steps):
                energy = self.network(state, trajectory)
                # Overwrites .grad each step (equivalent to zero_grad + backward
                # for the trajectory) without accumulating into the network.
                trajectory.grad = torch.autograd.grad(energy.sum(), trajectory)[0]
                optimizer.step()

                # Clip to valid range
                with torch.no_grad():
                    trajectory.clamp_(-clip, clip)

        return trajectory.detach()

    def _sample_and_select_best(self, state, num_candidates=512):
        """
        Sample multiple trajectories and select the one with lowest energy.

        This is a derivative-free alternative to Langevin dynamics; here it is
        used to initialize gradient descent.

        Args:
            state: (B, obs_dim)
            num_candidates: number of random candidates to evaluate

        Returns:
            trajectory: (B, trajectory_dim) - best trajectory
        """
        B = state.shape[0]

        candidates = torch.rand(
            B, num_candidates, self.trajectory_dim, device=self.device
        ) * 2 - 1

        with torch.no_grad():
            energies = self.network(state, candidates)  # (B, num_candidates)

        best_idx = energies.argmin(dim=-1)  # (B,)
        return candidates[torch.arange(B, device=candidates.device), best_idx]

    def _flatten_demo(self, demo_trajectory, B):
        """
        Flatten a demo to (B, trajectory_dim), broadcasting a single demo over the batch.

        Raises:
            ValueError: if the demo has neither 1 nor B rows once flattened.
        """
        demo_flat = demo_trajectory.reshape(-1, self.trajectory_dim).to(self.device)
        if demo_flat.shape[0] == 1:
            return demo_flat.expand(B, -1)
        if demo_flat.shape[0] != B:
            raise ValueError(
                f"demo_trajectory must hold 1 or {B} trajectories, "
                f"got {demo_flat.shape[0]}"
            )
        return demo_flat

    @staticmethod
    def _init_near_demo(demo_flat, noise_std):
        """Perturb a (B, trajectory_dim) demo with Gaussian noise and clamp to [-1, 1]."""
        noise = torch.randn_like(demo_flat) * noise_std
        return (demo_flat + noise).clamp(-1.0, 1.0)

    def _langevin_multi_chain(self, state, demo_flat=None, demo_noise_std=0.1):
        """
        Run langevin_num_chains Langevin chains and keep the best per batch element.

        Args:
            state: (B, obs_dim)
            demo_flat: optional (B, trajectory_dim) demo; each chain starts from
                an independently perturbed copy. Random init when None.
            demo_noise_std: std of the per-chain demo perturbation.

        Returns:
            (B, trajectory_dim) trajectory; each row comes from the chain whose
            final energy was lowest for that row.
        """
        best_traj = None
        best_energy = None
        for _ in range(self.langevin_num_chains):
            chain_init = (
                None if demo_flat is None
                else self._init_near_demo(demo_flat, demo_noise_std)
            )
            traj = self._langevin_dynamics(state, initial_trajectory=chain_init)
            with torch.no_grad():
                energy = self.network(state, traj)  # (B,)

            if best_traj is None:
                best_traj, best_energy = traj, energy
            else:
                # Per-sample selection: a batch-mean comparison could hand a
                # sample a worse chain than one that was available.
                better = energy < best_energy
                best_traj = torch.where(better.unsqueeze(-1), traj, best_traj)
                best_energy = torch.where(better, energy, best_energy)
        return best_traj

    @torch.no_grad()
    def forward(self, cond, deterministic=True, demo_trajectory=None, demo_noise_std=0.1):
        """
        Forward pass for inference using energy minimization.

        Args:
            cond: dict with 'state' key
            deterministic: accepted for API compatibility; currently not used.
            demo_trajectory: optional demo of shape (1, horizon_steps, action_dim)
                or (B, horizon_steps, action_dim). If provided, optimization is
                initialized near the demo instead of at random
                ("demo-initialized inference" for adaptation).
            demo_noise_std: std of noise added when initializing from the demo

        Returns:
            Sample: namedtuple whose `trajectories` field is
            (B, horizon_steps, action_dim); `chains` is None.
        """
        state = self._current_state(cond)
        B = state.shape[0]

        demo_flat = None
        if demo_trajectory is not None:
            demo_flat = self._flatten_demo(demo_trajectory, B)

        # The optimizers below enable grad internally despite @torch.no_grad.
        if self.use_gradient_descent:
            if demo_flat is not None:
                initial_trajectory = self._init_near_demo(demo_flat, demo_noise_std)
            else:
                # Best-of-N random sampling as the starting point.
                initial_trajectory = self._sample_and_select_best(state, num_candidates=256)
            trajectory = self._gradient_descent(state, initial_trajectory=initial_trajectory)
        else:
            trajectory = self._langevin_multi_chain(
                state, demo_flat=demo_flat, demo_noise_std=demo_noise_std
            )

        trajectory = trajectory.reshape(B, self.horizon_steps, self.action_dim)
        return Sample(trajectory, None)
