import torch.nn as nn
import numpy as np
from gns import graph_network
from torch_geometric.nn import radius_graph
from typing import Dict

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # Ensures CUDA context initialization
import torch
import torch.nn as nn
import numpy as np
import time
import os
from tqdm import tqdm
from typing import List, Dict, Tuple, Any, Optional
from typing import cast

# TensorRT logger shared by the builder, ONNX parser and runtime created in
# build_engine; constructed with WARNING as its severity level.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Module-level constants.
# Workspace memory-pool limit handed to the TensorRT builder config.
MAX_WORKSPACE_SIZE = 1 << 30  # 2**30 bytes (1 GiB)
# Default path of the ONNX file exported (and later removed) by
# compute_latency_ms_tensorrt_gnn.
DEFAULT_ONNX_FILE = "model.onnx"

def build_engine(
    onnx_file_path: str,
    engine_file_path: Optional[str] = None,
    use_fp16: bool = False
) -> trt.ICudaEngine:
    """
    Builds a TensorRT engine from an ONNX file, or loads a cached one.

    If `engine_file_path` is given and already exists, the serialized engine
    is loaded from it and the ONNX file is not parsed at all (so `use_fp16`
    has no effect on a cached engine). Otherwise the ONNX model is parsed,
    an engine is built and, when `engine_file_path` is given, the serialized
    engine is written to that path.

    Args:
        onnx_file_path: Path to the ONNX model file.
        engine_file_path: Optional path to save/load the serialized engine.
        use_fp16: Whether to enable FP16 precision (only applied when the
            builder reports fast FP16 support).

    Returns:
        A TensorRT ICudaEngine.

    Raises:
        RuntimeError: If ONNX parsing fails, the engine cannot be built, or
            a serialized engine cannot be deserialized.
    """
    if engine_file_path and os.path.exists(engine_file_path):
        print(f"Loading engine from file: {engine_file_path}")
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            cached_engine = runtime.deserialize_cuda_engine(f.read())
            # Fail loudly here instead of handing None back to the caller.
            if cached_engine is None:
                raise RuntimeError(
                    f"Failed to deserialize engine from file: {engine_file_path}")
            return cached_engine

    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(EXPLICIT_BATCH) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, MAX_WORKSPACE_SIZE)

        if use_fp16:
            if builder.platform_has_fast_fp16:
                config.set_flag(trt.BuilderFlag.FP16)
                print("FP16 mode enabled.")
            else:
                print("FP16 mode requested, but not supported on this platform.")

        print(f"Parsing ONNX file: {onnx_file_path}")
        with open(onnx_file_path, 'rb') as model_file:
            if not parser.parse(model_file.read()):
                print("ERROR: ONNX parsing failed.")
                for error_idx in range(parser.num_errors):
                    print(parser.get_error(error_idx))
                raise RuntimeError("ONNX parsing failed")
        print("ONNX parsing successful.")

        # Print network information for debugging
        print("\nNetwork Information:")
        for i in range(network.num_inputs):
            tensor = network.get_input(i)
            print(f"Input {i}: {tensor.name}, shape: {tensor.shape}")
        for i in range(network.num_outputs):
            tensor = network.get_output(i)
            print(f"Output {i}: {tensor.name}, shape: {tensor.shape}")

        print("\nBuilding TensorRT engine... (This may take a while)")
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            raise RuntimeError("Failed to build serialized engine")
        print("TensorRT engine built successfully.")

        if engine_file_path:
            print(f"Saving engine to file: {engine_file_path}")
            with open(engine_file_path, "wb") as f:
                f.write(serialized_engine)

        with trt.Runtime(TRT_LOGGER) as runtime:
            engine = runtime.deserialize_cuda_engine(serialized_engine)
            if engine is None:
                raise RuntimeError("Failed to deserialize the freshly built engine")
            # Print engine information for debugging
            print("\nEngine Information:")
            for i in range(engine.num_io_tensors):
                name = engine.get_tensor_name(i)
                print(f"Tensor {i}: {name}")
        return engine

def allocate_buffers(
    engine: trt.ICudaEngine,
    context: trt.IExecutionContext
) -> Tuple[Dict[str, Any], Dict[str, Any], List[int]]:
    """Allocates host and device buffers for every I/O tensor of a TRT engine.

    Args:
        engine: Engine whose I/O tensors (names, shapes, dtypes) are queried.
        context: Execution context; accepted for interface compatibility but
            not used by this function.

    Returns:
        A tuple `(host_buffers_map, device_buffers_map, ordered_device_buffers)`:
        flat page-locked host arrays keyed by tensor name, device allocations
        keyed by tensor name, and the device pointers as ints in engine
        tensor-index order.

    Raises:
        ValueError: If a tensor has a dynamic (-1) dimension, or its volume
            cannot be computed or is negative.
    """
    host_buffers_map: Dict[str, np.ndarray] = {}
    device_buffers_map: Dict[str, cuda.DeviceAllocation] = {}
    # Filled in tensor-index order, so position i holds tensor i's pointer.
    ordered_device_buffers: List[int] = []

    for i in range(engine.num_io_tensors):
        binding_name = engine.get_tensor_name(i)
        # Get shape and element type from the engine
        shape = engine.get_tensor_shape(binding_name)
        dtype = trt.nptype(engine.get_tensor_dtype(binding_name))

        # If a dimension is still dynamic (-1), it's problematic for allocation.
        if any(dim == -1 for dim in shape):
            raise ValueError(
                f"Binding '{binding_name}' has dynamic dimension {shape}. "
                "Cannot determine allocation size. Please ensure all dimensions are fixed."
            )

        try:
            size = trt.volume(shape)
        except Exception as e:
            # Chain the original error so its traceback is kept as the cause.
            raise ValueError(
                f"Error calculating volume for binding '{binding_name}' with shape {shape}. Original error: {e}"
            ) from e

        if size < 0:  # Should not happen if trt.volume worked and shape is valid
            raise ValueError(f"Calculated size for binding '{binding_name}' is negative ({size}) with shape {shape}.")

        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)

        host_buffers_map[binding_name] = host_mem
        device_buffers_map[binding_name] = device_mem
        ordered_device_buffers.append(int(device_mem))

    return host_buffers_map, device_buffers_map, ordered_device_buffers


def do_inference(
    context: trt.IExecutionContext,
    host_buffers_map: Dict[str, np.ndarray],
    device_buffers_map: Dict[str, cuda.DeviceAllocation],
    ordered_device_buffers: List[int],
    input_names: List[str],
    output_names: List[str],
    iterations: Optional[int] = None
) -> float:
    """Runs the TensorRT engine repeatedly and returns the mean latency.

    Inputs are copied host-to-device once, 10 warm-up executions are run,
    then `iterations` timed executions follow (shown with a tqdm progress
    bar). Afterwards the outputs are copied device-to-host into
    `host_buffers_map`.

    Args:
        context: Execution context used for `execute_v2`.
        host_buffers_map: Host arrays keyed by tensor name.
        device_buffers_map: Device allocations keyed by tensor name.
        ordered_device_buffers: Device pointers passed as `bindings`.
        input_names: Names of tensors to upload before running.
        output_names: Names of tensors to download after running.
        iterations: Number of timed executions. When None, a count is
            chosen automatically from short probe runs.

    Returns:
        Average latency per execution in milliseconds.

    Raises:
        ValueError: If `iterations` is given but not positive.
        KeyError: If an input or output name has no allocated buffer.
    """
    if iterations is not None and iterations <= 0:
        # Would otherwise end in a division by zero (or a negative latency).
        raise ValueError(f"iterations must be a positive integer, got {iterations}.")

    # Copy input data from host to device
    for name in input_names:
        if name not in device_buffers_map or name not in host_buffers_map:
            raise KeyError(f"Input tensor name '{name}' not found in allocated buffers.")
        cuda.memcpy_htod(device_buffers_map[name], host_buffers_map[name])

    # Warm-up runs
    for _ in range(10):
        context.execute_v2(bindings=ordered_device_buffers)

    if iterations is None:
        probe_iterations = 100
        elapsed_time_s = 0.0
        # Grow the probe until one pass lasts at least 0.5 s or the probe
        # size exceeds 6400. Every pass either enlarges the probe or breaks,
        # so the loop always terminates.
        while elapsed_time_s < 0.5 and probe_iterations <= 6400:
            t_start = time.monotonic()  # Use monotonic for timing
            for _ in range(probe_iterations):
                context.execute_v2(bindings=ordered_device_buffers)
            cuda.Context.synchronize()  # Ensure all GPU operations are done before measuring time
            elapsed_time_s = time.monotonic() - t_start

            if elapsed_time_s < 1e-5:
                # Too fast to measure reliably: just try a much larger probe.
                probe_iterations *= 10
                continue

            # Aim for roughly 1.5 s of measurement, but grow by at most 10x
            # per pass so one noisy probe cannot blow up the count.
            est_iter_for_1s = probe_iterations / elapsed_time_s
            target_iterations = max(probe_iterations, int(est_iter_for_1s * 1.5))
            target_iterations = min(target_iterations, probe_iterations * 10)
            iterations = target_iterations
            if target_iterations <= probe_iterations:
                break  # No further growth is possible.
            probe_iterations = target_iterations
        iterations = max(50, iterations or 100)
        print(f"Auto-determined iterations for measurement: {iterations}")

    # Timed inference
    t_start = time.monotonic()
    for _ in tqdm(range(iterations), desc="Measuring Latency"):
        context.execute_v2(bindings=ordered_device_buffers)
    cuda.Context.synchronize()  # Crucial for accurate timing
    elapsed_time_s = time.monotonic() - t_start

    # Copy output data from device to host
    for name in output_names:
        if name not in device_buffers_map or name not in host_buffers_map:
            raise KeyError(f"Output tensor name '{name}' not found in allocated buffers.")
        cuda.memcpy_dtoh(host_buffers_map[name], device_buffers_map[name])

    latency_ms = (elapsed_time_s / iterations) * 1000
    return latency_ms

def compute_latency_ms_tensorrt_gnn(
    model: torch.nn.Module,
    num_nodes: int,
    num_edges: int,
    nnode_in_dim: int, 
    nedge_in_dim: int,
    onnx_input_names: List[str],
    onnx_output_names: List[str],
    onnx_file_path: str = DEFAULT_ONNX_FILE,
    engine_file_path: Optional[str] = None,
    use_fp16: bool = False,
    iterations: Optional[int] = None
) -> float:
    """Computes the inference latency of a GNN model using TensorRT.

    Random dummy inputs `(node_features, edge_index, edge_features)` are
    created on CUDA, the model is exported to ONNX at `onnx_file_path`, an
    engine is built (or loaded) via `build_engine`, and the latency is
    measured with `do_inference`. The ONNX file is removed afterwards,
    whether or not the benchmark succeeded.

    Args:
        model: Module called as `model(node_features, edge_index, edge_features)`.
        num_nodes: Number of graph nodes in the dummy input; must be positive.
        num_edges: Number of graph edges in the dummy input.
        nnode_in_dim: Feature width of each node.
        nedge_in_dim: Feature width of each edge.
        onnx_input_names: Names for the three ONNX inputs, in call order.
        onnx_output_names: Names for the ONNX outputs.
        onnx_file_path: Where the temporary ONNX export is written.
        engine_file_path: Optional path to save/load the serialized engine.
        use_fp16: Whether to request FP16 precision when building.
        iterations: Number of timed runs; None lets `do_inference` choose.

    Returns:
        Average latency per inference in milliseconds.

    Raises:
        ValueError: If `num_nodes` is not positive, the number of input
            names does not match the dummy inputs, an input name is missing
            from the engine, or a host buffer size does not match its input.
        RuntimeError: If no engine could be obtained.
    """
    # Validate before anything is allocated from these sizes.
    if num_nodes <= 0:
        raise ValueError("num_nodes must be positive for edge_index generation.")

    model = model.cuda().eval()

    # Prepare dummy inputs based on nnode_in_dim and nedge_in_dim;
    # edge_index values are drawn from [0, num_nodes - 1].
    node_features_torch = torch.randn(num_nodes, nnode_in_dim, dtype=torch.float32, device='cuda')
    edge_index_torch = torch.randint(0, num_nodes, (2, num_edges), dtype=torch.int32, device='cuda')
    edge_features_torch = torch.randn(num_edges, nedge_in_dim, dtype=torch.float32, device='cuda')

    dummy_inputs_tuple = (node_features_torch, edge_index_torch, edge_features_torch)

    if len(onnx_input_names) != len(dummy_inputs_tuple):
        raise ValueError("Length of onnx_input_names must match the number of dummy inputs.")
    torch_inputs_map = dict(zip(onnx_input_names, dummy_inputs_tuple))

    try:
        # The export lives inside the try so that a partially written ONNX
        # file is also cleaned up when the export itself fails.
        print(f"Exporting model to ONNX: {onnx_file_path}")
        torch.onnx.export(
            model,
            dummy_inputs_tuple,
            onnx_file_path,
            input_names=onnx_input_names,
            output_names=onnx_output_names,
            opset_version=17,  # Using 17 for better compatibility with modern ops
            do_constant_folding=True,  # Optimize constant folding
            verbose=True  # Enable verbose output to see the actual tensor names
        )
        print("Model exported to ONNX successfully.")

        engine = build_engine(
            onnx_file_path,
            engine_file_path=engine_file_path,
            use_fp16=use_fp16
        )
        if engine is None:
            raise RuntimeError("build_engine did not return a usable TensorRT engine.")

        with engine.create_execution_context() as context:
            # Collect the engine's I/O tensor names in index order.
            binding_names = [
                engine.get_tensor_name(i) for i in range(engine.num_io_tensors)
            ]
            print("\nAvailable bindings in engine:")
            for i, name in enumerate(binding_names):
                print(f"Binding {i}: {name}")

            print("\nInput tensors we're trying to use:")
            for name in onnx_input_names:
                print(f"Input name: {name}")

            # Every requested input must exist as an engine tensor.
            for name, tensor_data in torch_inputs_map.items():
                if name not in binding_names:
                    raise ValueError(f"Input tensor '{name}' not found in engine bindings. Available bindings: {binding_names}")
                print(f"Input tensor '{name}' shape: {tensor_data.shape}")

            # Allocate buffers
            host_buffers, device_buffers, ordered_dev_bufs = allocate_buffers(engine, context)

            # Copy PyTorch input tensors to host buffers
            for name, tensor_data in torch_inputs_map.items():
                if name not in host_buffers:
                    print(f"Warning: Input tensor '{name}' not found in host_buffers. Skipping copy.")
                    continue

                numpy_data = tensor_data.cpu().numpy()
                host_buffer = host_buffers[name]
                if host_buffer.shape == numpy_data.shape:
                    np.copyto(host_buffer, numpy_data)
                elif host_buffer.size == numpy_data.size:
                    # Same element count but different layout: copy flat.
                    np.copyto(host_buffer.reshape(-1), numpy_data.ravel())
                    print(f"Copied '{name}' using reshaped/ravel due to shape mismatch: host {host_buffer.shape}, tensor {numpy_data.shape}")
                else:
                    raise ValueError(f"Size mismatch for input {name}: host buffer size {host_buffer.size} (shape {host_buffer.shape}), "
                                     f"tensor data size {numpy_data.size} (shape {numpy_data.shape})")

            # Run inference and measure latency
            latency = do_inference(
                context, host_buffers, device_buffers, ordered_dev_bufs,
                onnx_input_names, onnx_output_names, iterations
            )
    finally:
        if os.path.exists(onnx_file_path):
            print(f"Cleaning up ONNX file: {onnx_file_path}")
            os.remove(onnx_file_path)

    return latency


import torch_tensorrt
def tensorrt_torch_latency(model, node_features, edge_index, edge_features,
                           iterations=994, warmup=10):
  """Benchmark `model` after compiling it with Torch-TensorRT.

  The model is traced with `torch.jit.trace` on the given inputs, compiled
  with `torch_tensorrt.compile` for fixed input shapes, warmed up, and then
  timed over `iterations` calls. The total elapsed time is printed.

  Args:
    model: Module called as `model(node_features, edge_index, edge_features)`.
    node_features: Example node-feature tensor (declared as float32 input).
    edge_index: Example edge-index tensor; cast to int32 after tracing.
    edge_features: Example edge-feature tensor (declared as float32 input).
    iterations: Number of timed calls (default 994, the previous fixed value).
    warmup: Number of untimed warm-up calls (default 10).

  Returns:
    float: Average latency per call in milliseconds.

  Raises:
    ValueError: If `iterations` is not positive.
  """
  if iterations <= 0:
    raise ValueError(f"iterations must be a positive integer, got {iterations}.")

  model.eval()
  # NOTE(review): tracing happens with the caller's edge_index dtype, while
  # the compiled model is fed the int32 copy below - confirm this is intended.
  scripted_model = torch.jit.trace(model, (node_features, edge_index, edge_features))
  edge_index = edge_index.to(torch.int32)

  # Compile with TensorRT for the exact shapes of the example inputs.
  trt_model = torch_tensorrt.compile(
      scripted_model,
      inputs=[
          torch_tensorrt.Input(node_features.shape, dtype=torch.float32),
          torch_tensorrt.Input(edge_index.shape, dtype=torch.int32),
          torch_tensorrt.Input(edge_features.shape, dtype=torch.float32)
      ],
      enabled_precisions={torch.float},  # FP32 only; FP16 is not enabled here
  )
  with torch.no_grad():
    # Warm-up calls, excluded from the measurement.
    for _ in range(warmup):
        _ = trt_model(node_features, edge_index, edge_features)

    # Synchronize before and after so only the timed calls are measured.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        _ = trt_model(node_features, edge_index, edge_features)
    torch.cuda.synchronize()
    total_elapsed_s = time.perf_counter() - start

  # Same printed report as before: total elapsed seconds for the timed loop.
  print(f": {total_elapsed_s:.2f} s")
  return total_elapsed_s / iterations * 1000


class LearnedSimulator(nn.Module):
  """Learned simulator from https://arxiv.org/pdf/2002.09405.pdf.

  Wraps a `graph_network.EncodeProcessDecode` network with the feature
  construction (`_encoder_preprocessor`) and the Euler update
  (`_decoder_postprocessor`) used to step particle positions forward.
  """

  def __init__(
          self,
          particle_dimensions: int,
          nnode_in: int,
          nedge_in: int,
          latent_dim: int,
          nmessage_passing_steps: int,
          nmlp_layers: int,
          mlp_hidden_dim: int,
          connectivity_radius: float,
          boundaries: np.ndarray,
          normalization_stats: dict,
          nparticle_types: int,
          particle_type_embedding_size: int,
          boundary_clamp_limit: float = 1.0,
          device="cpu"
  ):
    """Initializes the model.

    Args:
      particle_dimensions: Dimensionality of the problem.
      nnode_in: Number of node inputs.
      nedge_in: Number of edge inputs.
      latent_dim: Size of latent dimension (128)
      nmessage_passing_steps: Number of message passing steps.
      nmlp_layers: Number of hidden layers in the MLP (typically of size 2).
      mlp_hidden_dim: Hidden dimension passed to the EncodeProcessDecode MLPs.
      connectivity_radius: Scalar with the radius of connectivity.
      boundaries: Array of 2-tuples, containing the lower and upper boundaries
        of the cuboid containing the particles along each dimensions, matching
        the dimensionality of the problem.
      normalization_stats: Dictionary with statistics with keys "acceleration"
        and "velocity", each indexed with 'mean' and 'std', matching the
        dimensionality of the problem.
      nparticle_types: Number of different particle types.
      particle_type_embedding_size: Embedding size for the particle type.
      boundary_clamp_limit: Symmetric limit used to clamp the normalized
        distance-to-boundary node features.
      device: Runtime device (cuda or cpu).

    """
    super().__init__()
    self._boundaries = boundaries
    self._connectivity_radius = connectivity_radius
    self._normalization_stats = normalization_stats
    self._nparticle_types = nparticle_types
    self._boundary_clamp_limit = boundary_clamp_limit

    # Particle type embedding has shape
    # (nparticle_types, particle_type_embedding_size), e.g. (9, 16).
    self._particle_type_embedding = nn.Embedding(
        nparticle_types, particle_type_embedding_size)

    # Initialize the EncodeProcessDecode
    self._encode_process_decode = graph_network.EncodeProcessDecode(
        nnode_in_features=nnode_in,
        nnode_out_features=particle_dimensions,
        nedge_in_features=nedge_in,
        latent_dim=latent_dim,
        nmessage_passing_steps=nmessage_passing_steps,
        nmlp_layers=nmlp_layers,
        mlp_hidden_dim=mlp_hidden_dim)

    self._device = device

  def forward(self):
    """No-op placeholder.

    The model is driven through `predict_positions`,
    `predict_positions_index` and `predict_accelerations` instead.
    """
    pass

  def _compute_graph_connectivity(
          self,
          node_features: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          radius: float,
          add_self_edges: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate graph edges to all particles within a threshold radius

    Args:
      node_features: Node features with shape (nparticles, dim).
      nparticles_per_example: Number of particles per example. Default is 2
        examples per batch.
      radius: Threshold to construct edges to all particles within the radius.
      add_self_edges: Boolean flag to include self edge (default: True)

    Returns:
      Tuple `(edge_index[0], edge_index[1])` of the `radius_graph` result,
      each with shape (nedges,).
    """
    # Example id for every particle: example i contributes n copies of i.
    # torch.full avoids materializing a Python list with one entry per
    # particle.
    batch_ids = torch.cat(
        [torch.full((int(n),), i, dtype=torch.long)
         for i, n in enumerate(nparticles_per_example)]).to(self._device)

    # radius_graph accepts r < radius not r <= radius
    # A torch tensor list of source and target nodes with shape (2, nedges)
    edge_index = radius_graph(
        node_features, r=radius, batch=batch_ids, loop=add_self_edges, max_num_neighbors=128)

    # The flow direction when using in combination with message passing is
    # "source_to_target". The rows are named the way the caller unpacks them
    # (row 0 as senders, row 1 as receivers); the return order is unchanged.
    senders = edge_index[0, :]
    receivers = edge_index[1, :]

    return senders, receivers

  def _encoder_preprocessor(
          self,
          position_sequence: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: Optional[torch.Tensor] = None):
    """Extracts node and edge features from the position sequence.

    Args:
      position_sequence: A sequence of particle positions. Shape is
        (nparticles, 6, dim). Includes current + last 5 positions
      nparticles_per_example: Number of particles per example. Default is 2
        examples per batch.
      particle_types: Particle types with shape (nparticles).
      material_property: Friction angle normalized by tan() with shape (nparticles)

    Returns:
      Tuple of node_features (nparticles, nfeatures; e.g. 30 for 2D without
      material property), edge_index (2, nedges) and edge_features
      (nedges, dim + 1).
    """
    nparticles = position_sequence.shape[0]
    most_recent_position = position_sequence[:, -1]  # (n_nodes, dim)
    velocity_sequence = time_diff(position_sequence)

    # Connectivity of the graph: two index tensors of shape (nedges,)
    senders, receivers = self._compute_graph_connectivity(
        most_recent_position, nparticles_per_example, self._connectivity_radius)
    node_features = []

    # Normalized velocity sequence, merging spatial and time axis.
    velocity_stats = self._normalization_stats["velocity"]
    normalized_velocity_sequence = (
        velocity_sequence - velocity_stats['mean']) / velocity_stats['std']
    flat_velocity_sequence = normalized_velocity_sequence.view(
        nparticles, -1)
    # With 5 previous steps and dim 2:
    # node_features shape (nparticles, 5 * 2 = 10)
    node_features.append(flat_velocity_sequence)

    # Normalized clipped distances to lower and upper boundaries.
    # boundaries are an array of shape [num_dimensions, 2], where the second
    # axis, provides the lower/upper boundaries.
    boundaries = torch.tensor(
        self._boundaries, requires_grad=False).float().to(self._device)
    distance_to_lower_boundary = (
        most_recent_position - boundaries[:, 0][None])
    distance_to_upper_boundary = (
        boundaries[:, 1][None] - most_recent_position)
    distance_to_boundaries = torch.cat(
        [distance_to_lower_boundary, distance_to_upper_boundary], dim=1)
    normalized_clipped_distance_to_boundaries = torch.clamp(
        distance_to_boundaries / self._connectivity_radius,
        -self._boundary_clamp_limit, self._boundary_clamp_limit)
    # The distance to 4 boundaries (top/bottom/left/right) in 2D
    # node_features shape (nparticles, 10+4)
    node_features.append(normalized_clipped_distance_to_boundaries)

    # Particle type (only embedded when there is more than one type)
    if self._nparticle_types > 1:
      particle_type_embeddings = self._particle_type_embedding(
          particle_types)
      node_features.append(particle_type_embeddings)
    # Final node_features shape (nparticles, 30) for 2D (if material_property is not valid in training example)
    # 30 = 10 (5 velocity sequences*dim) + 4 boundaries + 16 particle embedding

    # Material property
    if material_property is not None:
      material_property = material_property.view(nparticles, 1)
      node_features.append(material_property)
    # Final node_features shape (nparticles, 31) for 2D
    # 31 = 10 (5 velocity sequences*dim) + 4 boundaries + 16 particle embedding + 1 material property

    # Collect edge features.
    edge_features = []

    # Relative displacement normalized to radius, with shape (nedges, dim)
    normalized_relative_displacements = (
        most_recent_position[senders, :] -
        most_recent_position[receivers, :]
    ) / self._connectivity_radius

    # Add relative displacement between two particles as an edge feature
    edge_features.append(normalized_relative_displacements)

    # Add relative distance between 2 particles with shape (nedges, 1)
    # Edge features has a final shape of (nedges, dim + 1)
    normalized_relative_distances = torch.norm(
        normalized_relative_displacements, dim=-1, keepdim=True)
    edge_features.append(normalized_relative_distances)

    return (torch.cat(node_features, dim=-1),
            torch.stack([senders, receivers]),
            torch.cat(edge_features, dim=-1))

  def _decoder_postprocessor(
          self,
          normalized_acceleration: torch.Tensor,
          position_sequence: torch.Tensor
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """ Compute new position based on acceleration and current position.
    The model produces the output in normalized space so we apply inverse
    normalization.

    Args:
      normalized_acceleration: Normalized acceleration (nparticles, dim).
      position_sequence: Position sequence with time on axis 1; only the
        last two steps are read.

    Returns:
      Tuple `(new_position, new_velocity, acceleration)`.

    """
    # Extract real acceleration values from normalized values
    acceleration_stats = self._normalization_stats["acceleration"]
    acceleration = (
        normalized_acceleration * acceleration_stats['std']
    ) + acceleration_stats['mean']

    # Use an Euler integrator to go from acceleration to position, assuming
    # a dt=1 corresponding to the size of the finite difference.
    most_recent_position = position_sequence[:, -1]
    most_recent_velocity = most_recent_position - position_sequence[:, -2]

    # TODO: Fix dt
    new_velocity = most_recent_velocity + acceleration  # * dt = 1
    new_position = most_recent_position + new_velocity  # * dt = 1
    return new_position, new_velocity, acceleration

  def _predict_step(
          self,
          current_positions: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: Optional[torch.Tensor] = None):
    """Shared body of `predict_positions` and `predict_positions_index`.

    Returns:
      Tuple `(next_positions, next_velocity, next_acc, edge_index)`.
    """
    # Passing material_property=None is the same as omitting it.
    node_features, edge_index, edge_features = self._encoder_preprocessor(
        current_positions, nparticles_per_example, particle_types,
        material_property)
    predicted_normalized_acceleration = self._encode_process_decode(
        node_features, edge_index, edge_features)
    next_positions, next_velocity, next_acc = self._decoder_postprocessor(
        predicted_normalized_acceleration, current_positions)
    return next_positions, next_velocity, next_acc, edge_index

  def predict_positions(
          self,
          current_positions: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: Optional[torch.Tensor] = None
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Predict position based on acceleration.

    Args:
      current_positions: Sequence of particle positions with time on axis 1
        (the most recent step last).
      nparticles_per_example: Number of particles per example. Default is 2
        examples per batch.
      particle_types: Particle types with shape (nparticles).
      material_property: Friction angle normalized by tan() with shape (nparticles)

    Returns:
      Tuple `(next_positions, next_velocity, next_acc)`.
    """
    next_positions, next_velocity, next_acc, _ = self._predict_step(
        current_positions, nparticles_per_example, particle_types,
        material_property)
    return next_positions, next_velocity, next_acc

  def predict_positions_index(
          self,
          current_positions: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: Optional[torch.Tensor] = None
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Predict position based on acceleration and also return the edges.

    Args:
      current_positions: Sequence of particle positions with time on axis 1
        (the most recent step last).
      nparticles_per_example: Number of particles per example. Default is 2
        examples per batch.
      particle_types: Particle types with shape (nparticles).
      material_property: Friction angle normalized by tan() with shape (nparticles)

    Returns:
      Tuple `(next_positions, next_velocity, next_acc, edge_index)`, where
      `edge_index` is the (2, nedges) connectivity used for this step.
    """
    return self._predict_step(
        current_positions, nparticles_per_example, particle_types,
        material_property)

  def predict_accelerations(
          self,
          next_positions: torch.Tensor,
          position_sequence_noise: torch.Tensor,
          position_sequence: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: Optional[torch.Tensor] = None):
    """Produces normalized and predicted acceleration targets.

    Args:
      next_positions: Tensor of shape (nparticles_in_batch, dim) with the
        positions the model should output given the inputs.
      position_sequence_noise: Tensor of the same shape as `position_sequence`
        with the noise to apply to each particle.
      position_sequence: A sequence of particle positions. Shape is
        (nparticles, 6, dim). Includes current + last 5 positions.
      nparticles_per_example: Number of particles per example. Default is 2
        examples per batch.
      particle_types: Particle types with shape (nparticles).
      material_property: Friction angle normalized by tan() with shape (nparticles).

    Returns:
      Tensors of shape (nparticles_in_batch, dim) with the predicted and target
        normalized accelerations.

    """

    # Add noise to the input position sequence.
    noisy_position_sequence = position_sequence + position_sequence_noise

    # Perform the forward pass with the noisy position sequence.
    # (Passing material_property=None is the same as omitting it.)
    node_features, edge_index, edge_features = self._encoder_preprocessor(
        noisy_position_sequence, nparticles_per_example, particle_types,
        material_property)
    predicted_normalized_acceleration = self._encode_process_decode(
        node_features, edge_index, edge_features)
    # Calculate the target acceleration, using an `adjusted_next_position `that
    # is shifted by the noise in the last input position.
    next_position_adjusted = next_positions + position_sequence_noise[:, -1]
    target_normalized_acceleration = self._inverse_decoder_postprocessor(
        next_position_adjusted, noisy_position_sequence)
    # As a result the inverted Euler update in the `_inverse_decoder` produces:
    # * A target acceleration that does not explicitly correct for the noise in
    #   the input positions, as the `next_position_adjusted` is different
    #   from the true `next_position`.
    # * A target acceleration that exactly corrects noise in the input velocity
    #   since the target next velocity calculated by the inverse Euler update
    #   as `next_position_adjusted - noisy_position_sequence[:,-1]`
    #   matches the ground truth next velocity (noise cancels out).

    return predicted_normalized_acceleration, target_normalized_acceleration

  def _inverse_decoder_postprocessor(
          self,
          next_position: torch.Tensor,
          position_sequence: torch.Tensor) -> torch.Tensor:
    """Inverse of `_decoder_postprocessor`.

    Args:
      next_position: Tensor of shape (nparticles_in_batch, dim) with the
        positions the model should output given the inputs.
      position_sequence: A sequence of particle positions. Shape is
        (nparticles, 6, dim). Includes current + last 5 positions.

    Returns:
      normalized_acceleration (torch.Tensor): Normalized acceleration.

    """
    previous_position = position_sequence[:, -1]
    previous_velocity = previous_position - position_sequence[:, -2]
    next_velocity = next_position - previous_position
    acceleration = next_velocity - previous_velocity

    acceleration_stats = self._normalization_stats["acceleration"]
    normalized_acceleration = (
        acceleration - acceleration_stats['mean']) / acceleration_stats['std']
    return normalized_acceleration

  def save(
          self,
          path: str = 'model.pt'):
    """Save model state

    Args:
      path: Model path
    """
    torch.save(self.state_dict(), path)

  def load(
          self,
          path: str):
    """Load model state from file

    The checkpoint is mapped to CPU and loaded with `strict=False`, so keys
    that are missing from or unexpected in the checkpoint do not raise.

    Args:
      path: Model path
    """
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    self.load_state_dict(state_dict, strict=False)

  def test_tensorrt_performance(
          self,
          num_nodes: int,
          num_edges: int,
          iterations: Optional[int] = None,
          use_fp16: bool = False,
          engine_file_path: Optional[str] = None) -> float:
    """Test the performance of the GNN model using TensorRT.

    Args:
      num_nodes: Number of nodes in the graph
      num_edges: Number of edges in the graph
      iterations: Number of iterations for performance testing
      use_fp16: Whether to use FP16 precision
      engine_file_path: Path to save/load the TensorRT engine

    Returns:
      float: Average inference latency in milliseconds
    """
    # Define input and output names for ONNX export
    onnx_input_names = ["node_features", "edge_index", "edge_features"]
    onnx_output_names = ["output"]

    # Get input dimensions from the model
    nnode_in = self._encode_process_decode.nnode_in
    nedge_in = self._encode_process_decode.nedge_in

    # compute_latency_ms_tensorrt_gnn builds its own dummy inputs from these
    # sizes, so none are allocated here.
    latency = compute_latency_ms_tensorrt_gnn(
        model=self._encode_process_decode,
        num_nodes=num_nodes,
        num_edges=num_edges,
        nnode_in_dim=nnode_in,
        nedge_in_dim=nedge_in,
        onnx_input_names=onnx_input_names,
        onnx_output_names=onnx_output_names,
        use_fp16=use_fp16,
        engine_file_path=engine_file_path,
        iterations=iterations
    )

    return latency

def time_diff(
        position_sequence: torch.tensor) -> torch.tensor:
  """Step-to-step finite difference of a position sequence.

  Each output step is the position at one step minus the position at the
  step before it, taken along axis 1, so the result has one step fewer
  than the input.

  Args:
    position_sequence: Input position sequence & shape(nparticles, 6 steps, dim)

  Returns:
    torch.tensor: Velocity sequence
  """
  later_steps = position_sequence[:, 1:]
  earlier_steps = position_sequence[:, :-1]
  return later_steps - earlier_steps
