"""Implementations of various action heads, which serve as alternatives to VLM sequential token prediction."""

import math
from typing import Dict, List, Tuple

import cvxpy as cp
import numpy as np
import torch
import torch.nn as nn
from cvxpylayers.torch import CvxpyLayer
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

from prismatic.vla.constants import (
    ACTION_DIM,
    ACTION_TOKEN_BEGIN_IDX,
    IGNORE_INDEX,
    NUM_ACTIONS_CHUNK,
    PROPRIO_DIM,
    STOP_INDEX,
)


class SinusoidalPositionalEncoding(nn.Module):
    """
    Sine- and cosine-based positional encoding that produces embeddings of a batch of timesteps.

    For example, at train time, the input might be a batch of 32 randomly sampled diffusion timesteps -> shape (32,)
    Then the output would be a batch of 32 timestep embeddings -> shape (32, D)

    The first D/2 output channels hold the sines and the last D/2 channels hold the cosines of the
    timestep scaled by geometrically spaced frequencies ranging from 1 down to 1/10000.

    Adapted from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/positional_embedding.py
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # dimensionality of the positional encoding (must be even; checked in forward)

    def forward(self, x):
        """
        Embed a batch of scalar timesteps.

        Args:
            x: tensor of shape (batch_size,) holding the timesteps.

        Returns:
            Tensor of shape (batch_size, dim): concatenation of sin and cos features.

        Raises:
            AssertionError: if `self.dim` is odd.
        """
        # x: (batch_size,)
        device = x.device
        assert self.dim % 2 == 0, f"# dimensions must be even but got {self.dim}"
        half_dim = self.dim // 2
        # With dim == 2 there is a single frequency and `half_dim - 1` is zero; dividing by it would
        # turn the whole embedding into NaN. Clamp the denominator so that case uses frequency 1.
        denominator = max(half_dim - 1, 1)
        exponent = torch.arange(half_dim, device=device) * -math.log(10000) / denominator  # shape: (D/2,)
        emb = torch.exp(exponent)  # shape: (D/2,)
        emb = x[:, None] * emb[None, :]  # shape: (batch_size, 1) * (1, D/2) -> (batch_size, D/2)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # shape: (batch_size, D)
        return emb


class MLPResNetBlock(nn.Module):
    """Residual MLP block: LayerNorm -> Linear -> ReLU, with the block input added back to the result."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Feedforward stack, similar to the ones in Transformers. The normalization comes first,
        # i.e. the "Pre-Layer Normalization" ordering described in https://arxiv.org/pdf/2002.04745.pdf
        stages = [nn.LayerNorm(dim), nn.Linear(dim, dim), nn.ReLU()]
        self.ffn = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the feedforward stack to `x` (last dimension of size `dim`) and add the skip connection."""
        # x: (batch_size, hidden_dim)
        return self.ffn(x) + x


class MLPResNet(nn.Module):
    """
    Residual MLP: an input projection, a stack of `MLPResNetBlock`s, and an output projection.

    All layers act on the last dimension of the input, so any leading dimensions pass through unchanged.
    """

    def __init__(self, num_blocks, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Input stage: normalize, then project input_dim -> hidden_dim.
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        # Residual trunk operating at hidden_dim.
        self.mlp_resnet_blocks = nn.ModuleList([MLPResNetBlock(dim=hidden_dim) for _ in range(num_blocks)])
        # Output stage: normalize, then project hidden_dim -> output_dim.
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map `x` of shape (..., input_dim) to a tensor of shape (..., output_dim)."""
        hidden = self.relu(self.fc1(self.layer_norm1(x)))  # (..., hidden_dim)
        for residual_block in self.mlp_resnet_blocks:
            hidden = residual_block(hidden)  # (..., hidden_dim)
        return self.fc2(self.layer_norm2(hidden))  # (..., output_dim)


class L1RegressionActionHead(nn.Module):
    """
    Simple MLP-based action head that generates continuous actions via L1 regression.

    The head only produces predictions; no loss is computed inside this class.
    """

    def __init__(
        self,
        input_dim=4096,
        hidden_dim=4096,
        action_dim=7,
    ):
        super().__init__()
        self.action_dim = action_dim
        # `predict_action` folds the token embeddings of one action into a single vector, so the MLP
        # input width is ACTION_DIM embeddings of size input_dim laid side by side.
        flattened_input_dim = input_dim * ACTION_DIM
        self.model = MLPResNet(
            num_blocks=2, input_dim=flattened_input_dim, hidden_dim=hidden_dim, output_dim=action_dim
        )

    def predict_action(self, actions_hidden_states):
        """
        Regress a chunk of continuous actions from the action-token hidden states.

        Args:
            actions_hidden_states: last hidden states of Transformer corresponding to action tokens in
                sequence - shape: (batch_size, chunk_len * action_dim, hidden_dim)

        Returns:
            Predicted actions; the MLP maps the last dimension to `action_dim`, giving
            (batch_size, NUM_ACTIONS_CHUNK, action_dim).
        """
        num_samples = actions_hidden_states.shape[0]
        # Regroup the token axis: one row per action in the chunk, with that action's embeddings concatenated.
        per_action_states = actions_hidden_states.reshape(num_samples, NUM_ACTIONS_CHUNK, -1)
        return self.model(per_action_states)


class NoisePredictionModel(nn.Module):
    """
    Diffusion noise prediction model that takes an observation embedding (which fuses the
    noisy action, diffusion timestep, and image-language observation embeddings) and
    outputs a noise prediction.

    This is a thin wrapper around a two-block `MLPResNet`.
    """

    def __init__(
        self,
        transformer_hidden_dim,  # Transformer hidden embedding size
        hidden_dim,  # MLP hidden size
        action_dim=7,  # action dimensionality
    ):
        super().__init__()
        mlp_kwargs = dict(
            num_blocks=2,
            input_dim=transformer_hidden_dim,
            hidden_dim=hidden_dim,
            output_dim=action_dim,
        )
        self.mlp_resnet = MLPResNet(**mlp_kwargs)

    def forward(
        self,
        obs,
    ):
        """
        Predict the noise for a batch of observation embeddings.

        Args:
            obs: observation embeddings to condition the generation on
                - shape: (batch_size, chunk_len, rearranged_hidden_dim=action_dim*hidden_dim)

        Returns:
            Predicted noise. The MLP only transforms the last dimension, so leading dimensions are kept
            and the last dimension becomes `action_dim`.
        """
        return self.mlp_resnet(obs)


class DiffusionActionHead(nn.Module):
    """
    Simple MLP-based action head that generates continuous actions via conditional denoising diffusion process.

    Loosely inspired by: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/transformer_for_diffusion.py

    Args:
        input_dim: size of one Transformer hidden-state vector fed to `predict_noise`.
        hidden_dim: hidden width of the noise-prediction MLP (also the timestep-embedding size).
        action_dim: dimensionality of the predicted noise / action.
        num_diffusion_steps_train: number of diffusion timesteps used by the DDIM scheduler at train time.
    """

    def __init__(
        self,
        input_dim=4096,
        hidden_dim=4096,
        action_dim=7,
        num_diffusion_steps_train=50,
    ):
        super().__init__()
        self.action_dim = action_dim
        # `predict_noise` concatenates the ACTION_DIM token embeddings of each action, so the MLP input
        # width is derived from the Transformer embedding size (`input_dim`), matching
        # L1RegressionActionHead. Previously `hidden_dim` was used here and `input_dim` was ignored,
        # which only worked when both values were equal.
        self.noise_predictor = NoisePredictionModel(
            transformer_hidden_dim=input_dim * ACTION_DIM, hidden_dim=hidden_dim, action_dim=action_dim
        )
        self.num_diffusion_steps_train = num_diffusion_steps_train
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=num_diffusion_steps_train, beta_schedule="squaredcos_cap_v2"
        )
        # NOTE(review): the timestep embedding is documented below as (B, llm_dim) but is sized by
        # `hidden_dim`; confirm whether it should follow `input_dim` when the two differ.
        self.time_encoder = SinusoidalPositionalEncoding(dim=hidden_dim)

    def sample_noisy_actions(self, ground_truth_actions):
        """
        Samples noise and applies noise to ground-truth actions to produce noisy actions, which are
        used as input in the noise prediction network. Returns noise, noisy actions, and the
        corresponding diffusion timestep embeddings.

        Args:
            ground_truth_actions: ground-truth actions - shape: (batch_size, chunk_len, action_dim)

        Returns:
            dict with keys:
                - "noise": the sampled Gaussian noise, same shape/dtype/device as the actions
                - "noisy_actions": actions after closed-form forward diffusion
                - "diffusion_timestep_embeddings": timestep embeddings of shape (batch_size, 1, D)
        """
        batch_size = ground_truth_actions.shape[0]
        device = ground_truth_actions.device
        # Sample random noise with shape equal to actions, used for closed-form forward diffusion.
        # The shape is taken from the actions themselves rather than from global constants so the
        # noise always lines up with what `add_noise` receives.
        noise = torch.randn(
            size=ground_truth_actions.shape, device=device, dtype=ground_truth_actions.dtype
        )  # (B, chunk_len, action_dim)
        # Sample random diffusion timesteps (one for each action in batch).
        timesteps = torch.randint(
            low=0, high=self.noise_scheduler.config.num_train_timesteps, size=(batch_size,), device=device
        )
        # Add noise to clean actions according to the magnitude at each diffusion timestep via
        # closed-form forward diffusion.
        noisy_actions = self.noise_scheduler.add_noise(
            ground_truth_actions, noise, timesteps
        )  # (B, chunk_len, action_dim)

        # Get diffusion timestep embeddings as well
        diffusion_timestep_embeddings = self.time_encoder(timesteps).to(
            device=noisy_actions.device, dtype=noisy_actions.dtype
        )  # (B, llm_dim)
        diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1)  # (B, 1, llm_dim)

        return dict(
            noise=noise,
            noisy_actions=noisy_actions,
            diffusion_timestep_embeddings=diffusion_timestep_embeddings,
        )

    def predict_noise(self, actions_hidden_states):
        """
        Given a batch of last hidden Transformer layer embeddings (which fuse the vision-language observation embeddings,
        noisy action embeddings, and diffusion timestep embedding), predicts the noise applied to the actions.

        Args:
            actions_hidden_states: last hidden states of Transformer corresponding to action tokens in
                sequence - shape: (batch_size, chunk_len * action_dim, hidden_dim)

        Returns:
            Noise prediction; the last dimension has size `action_dim`.
        """
        batch_size = actions_hidden_states.shape[0]
        rearranged_actions_hidden_states = actions_hidden_states.reshape(
            batch_size, NUM_ACTIONS_CHUNK, -1
        )  # (batch_size, chunk_len, action_dim * hidden_dim)
        # Get diffusion model's noise prediction.
        return self.noise_predictor(rearranged_actions_hidden_states)


class LeastL1RegressionLayer(nn.Module):
    """
    CVX-based L1 regression layer for continuous actions.

    Wraps the problem  min_x sum(r)  s.t.  -r <= A x - b <= r  in a `CvxpyLayer`, with `A` and `b`
    as problem parameters and `x` as the returned variable.
    """

    def __init__(self, n: int, k: int, v: int):
        super().__init__()
        self.n, self.k, self.v = n, k, v
        num_rows = n * v  # one scalar residual per (sample, output-dim) pair

        # Decision variables: regression coefficients and per-row residual bounds.
        self.x = cp.Variable(k)
        self.r = cp.Variable(num_rows)

        # Inputs supplied at call time.
        self.A_param = cp.Parameter((num_rows, k))
        self.b_param = cp.Parameter(num_rows)

        # Minimizing the sum of r while r bounds the residual from both sides makes r track |A x - b|.
        l1_objective = cp.Minimize(cp.sum(self.r))
        fit_error = self.A_param @ self.x - self.b_param
        envelope = [self.r >= fit_error, self.r >= -fit_error]

        self.layer = CvxpyLayer(
            cp.Problem(l1_objective, envelope),
            parameters=[self.A_param, self.b_param],
            variables=[self.x],
        )

    def forward(self, A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Solve the L1 regression problem for one set of features and targets.

        Args:
            A: Feature tensor, shape (n, k, v)
            b: Target tensor, shape (n, v)

        Returns:
            Optimal coefficients, shape (k,)

        Raises:
            AssertionError: if the shapes do not match the (n, k, v) given at construction.
        """
        n, k, v = A.shape
        configured_shape = (self.n, self.k, self.v)
        assert (n, k, v) == configured_shape, (
            f"Shape mismatch: expected {(self.n, self.k, self.v)}, got {A.shape}"
        )
        assert b.shape == (n, v), f"Target shape mismatch: expected {(self.n, self.v)}, got {b.shape}"

        # Flatten for CVX: move k to the last axis so rows enumerate (sample, output-dim) pairs,
        # i.e. (n, k, v) -> (n, v, k) -> (nv, k), and (n, v) -> (nv,) in the same row order.
        num_rows = n * v
        features_2d = A.transpose(1, 2).reshape(num_rows, k)
        targets_1d = b.reshape(num_rows)

        [x_opt] = self.layer(features_2d, targets_1d)
        return x_opt


class FunctionEncoderActionHead(nn.Module):
    """
    Multi-dataset function-encoder (FE) head with per-dataset L1/L2 coefficients.

    An `MLPResNet` maps each action's hidden state to `k` basis-function outputs of size `action_dim`.
    The final action is a per-dataset linear combination of those basis outputs: the "l1" coefficient
    vector mixes the first `n_continuous` action dimensions and the "l2" vector mixes the remaining ones.

    Note: `dataset_coefficients` is a plain Python dict, not a registered buffer, so it is neither
    moved by `module.to(...)` nor included in `state_dict()`. Coefficients are therefore moved to the
    device of the basis predictions when they are applied.
    """

    def __init__(self, input_dim: int, hidden_dim: int, action_dim: int, k: int = 32, n_continuous: int = 6):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.action_dim = action_dim
        self.k = k  # number of basis functions
        self.n_continuous = n_continuous  # leading action dims combined with the "l1" coefficients
        self.n_discrete = action_dim - n_continuous  # trailing action dims combined with the "l2" coefficients

        self.model = MLPResNet(num_blocks=2, input_dim=input_dim, hidden_dim=hidden_dim, output_dim=k * action_dim)

        # Per-dataset coefficients storage
        # Will be populated with dataset_name -> {"l1": tensor, "l2": tensor}
        self.dataset_coefficients: Dict[str, Dict[str, torch.Tensor]] = {}

    def set_dataset_coefficients(self, dataset_name: str, l1_coeffs: torch.Tensor, l2_coeffs: torch.Tensor) -> None:
        """Store (or overwrite) the coefficient vectors of `dataset_name`, placed on the model's current device."""
        device = next(self.model.parameters()).device
        self.dataset_coefficients[dataset_name] = {
            "l1": l1_coeffs.to(device),
            "l2": l2_coeffs.to(device),
        }

    def has_dataset_coefficients(self, dataset_name: str) -> bool:
        """Check if coefficients exist for a dataset."""
        return dataset_name in self.dataset_coefficients

    def get_dataset_names(self) -> List[str]:
        """Get list of datasets with coefficients."""
        return list(self.dataset_coefficients)

    def _get_coefficients_for_dataset(self, dataset_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get L1 and L2 coefficients for a specific dataset.

        Raises:
            KeyError: if no coefficients were registered for `dataset_name`.
        """
        try:
            coeffs = self.dataset_coefficients[dataset_name]
        except KeyError:
            raise KeyError(
                f"No coefficients registered for dataset {dataset_name!r}; "
                f"registered datasets: {self.get_dataset_names()}"
            ) from None
        return coeffs["l1"], coeffs["l2"]

    def _compute_reg_loss(
        self,
        basis_predictions: torch.Tensor,
    ) -> torch.Tensor:
        """
        Gram-matrix regularization loss for the basis functions.

        The basis outputs are averaged over the batch and chunk axes to get one `action_dim`-sized
        vector per basis function; the loss is the mean squared deviation of the Gram matrix of these
        vectors from the identity, which pushes them towards an orthonormal set.

        Args:
            basis_predictions: Tensor of shape (batch_size, NUM_ACTIONS_CHUNK, k, action_dim)
                where k = number of basis functions.

        Returns:
            Scalar loss (torch.Tensor).
        """
        B, h, k, d = basis_predictions.shape

        # Average across batch and chunk to get a representative output per basis function
        z_mean = basis_predictions.reshape(B * h, k, d).mean(dim=0)  # (k, d)

        # Gram matrix of basis functions (k x k)
        gram = z_mean @ z_mean.T  # (k, k)

        # Target is identity (orthonormal basis)
        identity = torch.eye(k, device=basis_predictions.device, dtype=gram.dtype)

        # Penalize deviation from identity
        return (gram - identity).pow(2).mean()

    def _apply_dataset_coefficients(self, basis_predictions: torch.Tensor, dataset_names: List[str]) -> torch.Tensor:
        """
        Combine basis outputs into actions using each sample's dataset-specific coefficients.

        Args:
            basis_predictions: Tensor of shape (B, NUM_ACTIONS_CHUNK, k, action_dim).
            dataset_names: one dataset name per batch element.

        Returns:
            Tensor of shape (B, NUM_ACTIONS_CHUNK, action_dim).

        Raises:
            ValueError: if `dataset_names` does not have exactly one entry per batch element.
            KeyError: if a dataset has no registered coefficients.
        """
        batch_size = basis_predictions.shape[0]
        if len(dataset_names) != batch_size:
            raise ValueError(
                f"Expected one dataset name per sample ({batch_size}), got {len(dataset_names)}"
            )

        per_sample = [self._get_coefficients_for_dataset(name) for name in dataset_names]
        # The coefficient dict is not tracked by nn.Module, so make sure the stacked tensors live on
        # the same device as the predictions (a no-op when they already do).
        target_device = basis_predictions.device
        l1_coeffs = torch.stack([l1 for l1, _ in per_sample], dim=0).to(target_device)  # (B, k)
        l2_coeffs = torch.stack([l2 for _, l2 in per_sample], dim=0).to(target_device)  # (B, k)

        # (B, k) -> (B, 1, k, 1) for broadcasting with (B, NUM_ACTIONS_CHUNK, k, n_*)
        l1_coeffs = l1_coeffs.reshape(batch_size, 1, self.k, 1)
        l2_coeffs = l2_coeffs.reshape(batch_size, 1, self.k, 1)

        # Split into continuous and discrete parts
        continuous_preds = basis_predictions[..., : self.n_continuous]  # (B, NUM_ACTIONS_CHUNK, k, n_continuous)
        discrete_preds = basis_predictions[..., self.n_continuous :]  # (B, NUM_ACTIONS_CHUNK, k, n_discrete)

        # Weight each basis function and sum over the k dimension: (B, NUM_ACTIONS_CHUNK, n_*)
        continuous_actions = (continuous_preds * l1_coeffs).sum(dim=2)
        discrete_actions = (discrete_preds * l2_coeffs).sum(dim=2)

        return torch.cat([continuous_actions, discrete_actions], dim=-1)  # (B, NUM_ACTIONS_CHUNK, action_dim)

    def forward(self, hidden_states: torch.Tensor, dataset_names: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with per-sample dataset-specific coefficients.

        Args:
            hidden_states: 3-D tensor whose first dimension is the batch size.
            dataset_names: one dataset name per batch element.

        Returns:
            Tuple `(actions, reg_loss)`: actions of shape (B, NUM_ACTIONS_CHUNK, action_dim) and the
            scalar basis regularization loss.
        """
        basis_predictions, reg_loss = self.forward_basis_functions(
            hidden_states
        )  # (batch_size, NUM_ACTIONS_CHUNK, k, action_dim), scalar
        actions = self._apply_dataset_coefficients(basis_predictions, dataset_names)
        return actions, reg_loss

    def predict_action(
        self, hidden_states: torch.Tensor, dataset_names: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict actions from hidden states with dataset names; returns the same tuple as `forward`."""
        return self.forward(hidden_states, dataset_names)

    def forward_basis_functions(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate the basis functions on the hidden states.

        Args:
            hidden_states: 3-D tensor whose first dimension is the batch size.

        Returns:
            Tuple `(basis_predictions, reg_loss)`: basis outputs of shape
            (batch_size, NUM_ACTIONS_CHUNK, k, action_dim) and the scalar Gram-matrix regularization loss.
        """
        batch_size, _, _ = hidden_states.shape

        rearranged_actions_hidden_states = hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1)

        basis_predictions = self.model(
            rearranged_actions_hidden_states
        )  # (batch_size, NUM_ACTIONS_CHUNK, k * action_dim)
        # Squash the basis outputs into (-1, 1)
        basis_predictions = torch.tanh(0.5 * basis_predictions)

        # Reshape to (batch_size, NUM_ACTIONS_CHUNK, k, action_dim)
        basis_predictions = basis_predictions.reshape(batch_size, NUM_ACTIONS_CHUNK, self.k, self.action_dim)

        # Compute the Gram-matrix regularization loss across the k basis functions
        reg_loss = self._compute_reg_loss(basis_predictions)

        return basis_predictions, reg_loss
