import math

import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple, Type
# from diffusers.schedulers.scheduling_ddim import DDIMScheduler
# from prismatic.vla.constants import ACTION_DIM, NUM_ACTIONS_CHUNK
# Module-level sizing constants read by RegressionActionHead below.
ACTION_DIM = 2  # 7 for robot
NUM_ACTIONS_CHUNK = 10  # 8 for robot

class SinusoidalPositionalEncoding(nn.Module):
    """
    Sine- and cosine-based positional encoding that produces embeddings of a batch of timesteps.

    For example, at train time, the input might be a batch of 32 randomly sampled diffusion timesteps -> shape (32,)
    Then the output would be a batch of 32 timestep embeddings -> shape (32, D)

    The first D/2 output channels hold the sines and the last D/2 hold the cosines of the
    timestep scaled by geometrically spaced frequencies running from 1 down to 1/10000.

    Adapted from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/positional_embedding.py
    """

    def __init__(self, dim):
        """
        Args:
            dim: Size D of the produced embedding. Must be even; this is checked in `forward`.
        """
        super().__init__()
        self.dim = dim  # dimensionality of the positional encoding

    def forward(self, x):
        """
        Embed a batch of timesteps.

        Args:
            x: 1-D tensor of timesteps, shape (batch_size,).

        Returns:
            Tensor of shape (batch_size, D) laid out as [sin(...), cos(...)] along the last axis.

        Raises:
            AssertionError: If `self.dim` is odd.
        """
        # x: (batch_size,)
        device = x.device
        assert self.dim % 2 == 0, f"# dimensions must be even but got {self.dim}"
        half_dim = self.dim // 2
        # With dim == 2 there is a single frequency and `half_dim - 1` is 0; dividing by it would
        # turn the whole embedding into NaN. Clamping the denominator to 1 leaves that single
        # frequency at exp(0) == 1 and does not alter the result for any dim >= 4.
        denominator = max(half_dim - 1, 1)
        exponent = torch.arange(half_dim, device=device) * -math.log(10000) / denominator  # shape: (D/2,)
        emb = torch.exp(exponent)  # shape: (D/2,)
        emb = x[:, None] * emb[None, :]  # shape: (batch_size, 1) * (1, D/2) -> (batch_size, D/2)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # shape: (batch_size, D)
        return emb

class MLPResNetBlock(nn.Module):
    """Residual feedforward block computing ``ReLU(Linear(LayerNorm(x))) + x``."""

    def __init__(self, dim):
        """
        Args:
            dim: Width of both the input and the output of the block.
        """
        super().__init__()
        self.dim = dim
        # Feedforward branch, similar to the ones in Transformers. The normalization comes first,
        # following the "Pre-Layer Normalization" ordering described in
        # https://arxiv.org/pdf/2002.04745.pdf
        pre_norm = nn.LayerNorm(dim)
        projection = nn.Linear(dim, dim)
        activation = nn.ReLU()
        self.ffn = nn.Sequential(pre_norm, projection, activation)

    def forward(self, x):
        """Apply the feedforward branch and add the untouched input back (skip connection)."""
        # x: (batch_size, hidden_dim)
        return self.ffn(x) + x


class MLPResNet(nn.Module):
    """
    MLP built from a normalized input projection, a stack of residual blocks, and a
    normalized output projection.
    """

    def __init__(self, num_blocks, input_dim, hidden_dim, output_dim):
        """
        Args:
            num_blocks: How many MLPResNetBlock modules to stack between the two projections.
            input_dim: Size of the last axis of the input.
            hidden_dim: Width used by the residual blocks.
            output_dim: Size of the last axis of the output.
        """
        super().__init__()
        # Input stem: LayerNorm -> Linear -> ReLU.
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        # Residual trunk.
        self.mlp_resnet_blocks = nn.ModuleList(
            [MLPResNetBlock(dim=hidden_dim) for _ in range(num_blocks)]
        )
        # Output head: LayerNorm -> Linear.
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map x of shape (batch_size, input_dim) to a tensor of shape (batch_size, output_dim)."""
        hidden = self.relu(self.fc1(self.layer_norm1(x)))  # shape: (batch_size, hidden_dim)
        for residual_block in self.mlp_resnet_blocks:
            hidden = residual_block(hidden)  # shape: (batch_size, hidden_dim)
        return self.fc2(self.layer_norm2(hidden))  # shape: (batch_size, output_dim)


class RegressionActionHead(nn.Module):
    """
    Simple MLP-based action head that generates continuous actions via L1 regression.

    The incoming hidden states are regrouped into NUM_ACTIONS_CHUNK steps and every step is
    passed through an MLPResNet that emits `action_dim` values.
    """

    def __init__(
        self,
        input_dim=3584,  # 4096
        hidden_dim=3584,  # 4096
        action_dim=7,
    ):
        """
        Args:
            input_dim: Width of one incoming hidden-state vector.
            hidden_dim: Width of the residual MLP.
            action_dim: Number of values predicted per chunk step.
        """
        super().__init__()
        self.action_dim = action_dim
        # NOTE: the MLP input width is scaled by the module-level ACTION_DIM constant, not by the
        # `action_dim` argument, so each chunk step is expected to carry ACTION_DIM hidden-state
        # vectors of size `input_dim`.
        self.model = MLPResNet(
            num_blocks=2, input_dim=input_dim * ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim
        )

    def forward(self, actions_hidden_states):
        """
        Predict an action chunk from the action-token hidden states.

        Args:
            actions_hidden_states: Last hidden states of the Transformer corresponding to the
                action tokens in the sequence, shape (batch_size, chunk_len * action_dim, hidden_dim).

        Returns:
            Tensor of shape (batch_size, NUM_ACTIONS_CHUNK, action_dim).
        """
        batch_size = actions_hidden_states.shape[0]
        # Fold all hidden states that belong to the same chunk step into a single feature vector.
        rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1)
        return self.model(rearranged_actions_hidden_states)

# def init_module(
#     module_class: Type[nn.Module],
#     module_name: str,
#     cfg: FinetuneConfig,
#     device_id: int,
#     module_args: dict,
#     to_bf16: bool = False,
#     find_unused_params: bool = False,
# ) -> DDP:
#     """
#     Initializes a module, optionally loads checkpoint, moves to device, and wraps with DDP.

#     Args:
#         module_class (Type[nn.Module]): Class of PyTorch module to initialize.
#         module_name (str): Name of model component to load checkpoint for.
#         cfg (FinetuneConfig): Training configuration.
#         device_id (int): Device ID.
#         module_args (dict): Args for initializing the module.
#         to_bf16 (bool): Whether to convert to torch.bfloat16 data type.
#         find_unused_params (bool): Whether to detect parameters without gradients in distributed training.

#     Returns:
#         DistributedDataParallel: PyTorch module wrapped with DDP.
#     """
#     module = module_class(**module_args)
#     count_parameters(module, module_name)

#     if cfg.resume:
#         state_dict = load_checkpoint(module_name, cfg.vla_path, cfg.resume_step)
#         module.load_state_dict(state_dict)

#     if to_bf16:
#         module = module.to(torch.bfloat16)
#     module = module.to(device_id)

#     return wrap_ddp(module, device_id, find_unused_params)