import torch
import torch.nn as nn
import numpy as np
from gns_mpm import graph_network
from torch_geometric.nn import radius_graph
from typing import Dict

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import torch
import numpy as np
import time
from tqdm import tqdm

# Logger shared by the TensorRT builder, parser and runtime below
# (minimum reported severity: WARNING).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Builder workspace memory-pool limit: 1 GiB.
MAX_WORKSPACE_SIZE = 1024 ** 3
# TensorRT float32 element type; not referenced by the code shown in this module.
DTYPE = trt.float32


def build_engine(onnx_file_path, num_nodes, num_edges, node_dim, edge_dim):
    """Parse an ONNX model and build a deserialized TensorRT engine from it.

    A single optimization profile is attached for the three GNN inputs.
    For each input, the varying dimension ranges from 1 up to twice the
    nominal size given here, with the nominal size as the "opt" shape.
    FP16 is enabled whenever the platform reports fast FP16 support.

    Args:
        onnx_file_path: Path of the ONNX file to parse.
        num_nodes: Nominal number of nodes (profile "opt" row count).
        num_edges: Nominal number of edges (profile "opt" edge count).
        node_dim: Width of each row of ``node_features``.
        edge_dim: Width of each row of ``edge_features``.

    Returns:
        The engine produced by ``trt.Runtime.deserialize_cuda_engine``.

    Raises:
        RuntimeError: If parsing the ONNX file or building the engine fails.
    """
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    # Input name -> (min, opt, max) shapes for the optimization profile.
    # Insertion order matches the order in which set_shape is called.
    profile_shapes = {
        "node_features": ((1, node_dim),
                          (num_nodes, node_dim),
                          (num_nodes * 2, node_dim)),
        "edge_index": ((2, 1),
                       (2, num_edges),
                       (2, num_edges * 2)),
        "edge_features": ((1, edge_dim),
                          (num_edges, edge_dim),
                          (num_edges * 2, edge_dim)),
    }

    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(network_flags) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, MAX_WORKSPACE_SIZE)
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)

        with open(onnx_file_path, 'rb') as model_file:
            onnx_bytes = model_file.read()
        if not parser.parse(onnx_bytes):
            for error_idx in range(parser.num_errors):
                print(parser.get_error(error_idx))
            raise RuntimeError("ONNX parsing failed")

        profile = builder.create_optimization_profile()
        for input_name, (min_shape, opt_shape, max_shape) in profile_shapes.items():
            profile.set_shape(input_name, min=min_shape, opt=opt_shape, max=max_shape)
        config.add_optimization_profile(profile)

        plan = builder.build_serialized_network(network, config)
        if plan is None:
            raise RuntimeError("Failed to build serialized engine")

        runtime = trt.Runtime(TRT_LOGGER)
        return runtime.deserialize_cuda_engine(plan)


def _concrete_binding_shape(engine, context, index):
    """Return a fully specified (no -1 dims) shape for binding ``index``.

    If ``context`` is given, its binding shape is used. Otherwise the
    engine's shape is used, and a dynamic input is sized for the maximum
    shape of optimization profile 0. A shape that is still dynamic raises
    ``ValueError``.
    """
    source = context if context is not None else engine
    shape = tuple(source.get_binding_shape(index))
    if all(dim >= 0 for dim in shape):
        return shape
    if context is None and engine.binding_is_input(index):
        # get_profile_shape returns (min, opt, max); allocate for the largest.
        return tuple(engine.get_profile_shape(0, index)[2])
    raise ValueError(
        f"Binding {engine.get_binding_name(index)!r} has unresolved dynamic "
        f"shape {shape}; pass an execution context whose input shapes are set.")


def allocate_buffers(engine, context=None):
    """Allocate a page-locked host buffer and a device buffer for every binding.

    Bindings are visited in index order. Each binding gets a host array from
    ``cuda.pagelocked_empty`` (dtype taken from the engine) and a device
    allocation of the same byte size. The pair is appended to the input or
    output lists depending on ``engine.binding_is_input``.

    Engines built with dynamic shapes report ``-1`` for dynamic dimensions,
    which cannot be used to size a buffer. Such shapes are resolved as
    follows:

    * when ``context`` is given, its (concrete) binding shapes are used, so
      set the input shapes on it before calling;
    * otherwise dynamic inputs are sized for profile 0's maximum shape;
    * a dynamic output without ``context`` raises ``ValueError``.

    Args:
        engine: TensorRT engine (binding-index API).
        context: Optional execution context with input shapes already set.

    Returns:
        Tuple ``(h_inputs, d_inputs, h_outputs, d_outputs)`` of lists.

    Raises:
        ValueError: If a binding's shape cannot be resolved.
    """
    h_inputs = []
    d_inputs = []
    h_outputs = []
    d_outputs = []

    for i in range(engine.num_bindings):
        binding = engine.get_binding_name(i)
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        shape = _concrete_binding_shape(engine, context, i)
        host_mem = cuda.pagelocked_empty(trt.volume(shape), dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        if engine.binding_is_input(i):
            h_inputs.append(host_mem)
            d_inputs.append(device_mem)
        else:
            h_outputs.append(host_mem)
            d_outputs.append(device_mem)

    return h_inputs, d_inputs, h_outputs, d_outputs


def do_inference(context, h_inputs, d_inputs, h_outputs, d_outputs, iterations=None,
                 warmup=10, calibration_seconds=1.0, target_seconds=3.0):
    """Benchmark ``context.execute_v2`` and return the mean latency in ms.

    Each host input buffer is copied to its device buffer once. Then
    ``warmup`` untimed executions run, followed by ``iterations`` timed
    executions. Outputs are not copied back; ``h_outputs`` is accepted only
    for interface symmetry.

    When ``iterations`` is None it is chosen automatically. Starting from 100,
    the batch size is doubled until one timed batch takes at least
    ``calibration_seconds``. The throughput of that batch then sets enough
    iterations for roughly ``target_seconds`` of timed work (minimum 1).

    Args:
        context: TensorRT execution context with input shapes already set.
        h_inputs, d_inputs: Host/device input buffers, in binding order.
        h_outputs, d_outputs: Host/device output buffers, in binding order.
        iterations: Number of timed executions, or None for auto-calibration.
        warmup: Number of untimed executions before timing.
        calibration_seconds: Minimum duration of the calibration batch.
        target_seconds: Approximate duration of the timed run when calibrating.

    Returns:
        Mean wall-clock time per execution, in milliseconds.
    """
    bindings = [int(d) for d in d_inputs + d_outputs]

    for i, host_buf in enumerate(h_inputs):
        cuda.memcpy_htod(d_inputs[i], host_buf)

    def _timed_runs(count):
        """Run the engine ``count`` times and return the elapsed seconds."""
        start = time.perf_counter()
        for _ in range(count):
            # execute_v2 is the synchronous execution API, so no explicit
            # stream synchronization is needed around the timer.
            context.execute_v2(bindings=bindings)
        return time.perf_counter() - start

    for _ in range(warmup):
        context.execute_v2(bindings=bindings)

    if iterations is None:
        calib_iterations = 100
        calib_elapsed = _timed_runs(calib_iterations)
        while calib_elapsed < calibration_seconds:
            calib_iterations *= 2
            calib_elapsed = _timed_runs(calib_iterations)
        # Throughput of the batch that was actually timed.
        fps = calib_iterations / calib_elapsed
        iterations = max(1, int(fps * target_seconds))

    start = time.perf_counter()
    for _ in tqdm(range(iterations)):
        context.execute_v2(bindings=bindings)
    elapsed_time = time.perf_counter() - start
    return elapsed_time / iterations * 1000


def compute_latency_ms_tensorrt_gnn(model, num_nodes, num_edges, node_dim, edge_dim,
                                    iterations=None, onnx_path="model.onnx"):
    """Export a GNN to ONNX, build a TensorRT engine and measure its latency.

    Random inputs of the nominal sizes are generated on CUDA, the model is
    exported to ``onnx_path`` with dynamic node/edge axes, and an engine is
    built with ``build_engine``. The inputs are copied into the host buffers
    from ``allocate_buffers`` and the engine is benchmarked with
    ``do_inference``.

    Args:
        model: ``torch.nn.Module`` taking ``(node_features, edge_index,
            edge_features)``; moved to CUDA and put in eval mode.
        num_nodes, num_edges: Nominal graph size used for the dummy inputs.
        node_dim, edge_dim: Feature widths of nodes and edges.
        iterations: Timed iterations, or None to let ``do_inference`` choose.
        onnx_path: Where the intermediate ONNX file is written.

    Returns:
        Mean latency per inference in milliseconds.
    """
    model = model.cuda().eval()

    # Dummy inputs (use int32 for edge_index!)
    node_features = torch.randn(num_nodes, node_dim, dtype=torch.float32, device='cuda')
    edge_index = torch.randint(0, num_nodes, (2, num_edges), dtype=torch.int32, device='cuda')
    edge_features = torch.randn(num_edges, edge_dim, dtype=torch.float32, device='cuda')
    torch_inputs = [node_features, edge_index, edge_features]

    torch.onnx.export(
        model,
        tuple(torch_inputs),
        onnx_path,
        input_names=["node_features", "edge_index", "edge_features"],
        output_names=["output"],
        dynamic_axes={
            "node_features": {0: "num_nodes"},
            "edge_index": {1: "num_edges"},
            "edge_features": {0: "num_edges"},
        },
        opset_version=16
    )

    engine = build_engine(onnx_path, num_nodes, num_edges, node_dim, edge_dim)
    h_inputs, d_inputs, h_outputs, d_outputs = allocate_buffers(engine)

    # Copy torch tensors to host memory. Only the leading elements are filled,
    # so this also works when a host buffer was sized for a larger shape.
    for i, tensor in enumerate(torch_inputs):
        flat = tensor.cpu().numpy().ravel()
        np.copyto(h_inputs[i][:flat.size], flat)

    with engine.create_execution_context() as context:
        # Bindings 0..2 are assumed to be the three inputs, in export order.
        for binding_index, tensor in enumerate(torch_inputs):
            context.set_binding_shape(binding_index, tensor.shape)
        latency = do_inference(context, h_inputs, d_inputs, h_outputs, d_outputs, iterations)

    return latency

import torch_tensorrt
def tensorrt_torch_latency(model, node_features, edge_index, edge_features,
                           iterations=1000, warmup=10, enabled_precisions=None):
    """Compile a model with Torch-TensorRT and return its mean latency in ms.

    The model is put in eval mode, traced with ``torch.jit.trace`` on the
    given inputs, and compiled with ``torch_tensorrt.compile`` using
    fixed input shapes taken from those inputs. After ``warmup`` untimed
    calls, ``iterations`` calls are timed between ``torch.cuda.synchronize``
    barriers.

    Args:
        model: ``torch.nn.Module`` taking ``(node_features, edge_index,
            edge_features)``.
        node_features, edge_index, edge_features: Example inputs (CUDA
            tensors); ``edge_index`` is declared to TensorRT as int64.
        iterations: Number of timed calls (must be positive).
        warmup: Number of untimed calls before timing.
        enabled_precisions: Set of dtypes passed to ``torch_tensorrt.compile``;
            defaults to ``{torch.float}`` (FP32 only).

    Returns:
        Mean latency per call in milliseconds (also printed).

    Raises:
        ValueError: If ``iterations`` is not positive.
    """
    if iterations <= 0:
        raise ValueError("iterations must be positive")
    if enabled_precisions is None:
        enabled_precisions = {torch.float}

    model.eval()
    scripted_model = torch.jit.trace(model, (node_features, edge_index, edge_features))

    # Compile the traced module with TensorRT.
    trt_model = torch_tensorrt.compile(
        scripted_model,
        inputs=[
            torch_tensorrt.Input(node_features.shape),
            torch_tensorrt.Input(edge_index.shape, dtype=torch.int64),
            torch_tensorrt.Input(edge_features.shape),
        ],
        enabled_precisions=enabled_precisions,
    )

    with torch.no_grad():
        # Warm-up calls, excluded from timing.
        for _ in range(warmup):
            trt_model(node_features, edge_index, edge_features)

        # CUDA work is asynchronous: synchronize before and after the timed region.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iterations):
            trt_model(node_features, edge_index, edge_features)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start

    # Average over the number of calls actually timed.
    avg_latency_ms = elapsed / iterations * 1000
    print(f"Average latency: {avg_latency_ms:.2f} ms")
    return avg_latency_ms

class LearnedSimulator(nn.Module):
  """Learned simulator from https://arxiv.org/pdf/2002.09405.pdf.

  Wraps a ``graph_network.EncodeProcessDecode`` network. Before the network,
  ``_encoder_preprocessor`` turns a particle position history into graph
  inputs. After it, ``_decoder_postprocessor`` uses an Euler update to turn
  the predicted normalized acceleration back into positions.
  """

  def __init__(
          self,
          particle_dimensions: int,
          nnode_in: int,
          nedge_in: int,
          latent_dim: int,
          nmessage_passing_steps: int,
          nmlp_layers: int,
          mlp_hidden_dim: int,
          connectivity_radius: float,
          boundaries: np.ndarray,
          normalization_stats: dict,
          nparticle_types: int,
          particle_type_embedding_size: int,
          boundary_clamp_limit: float = 1.0,
          device="cpu"
  ):
    """Initializes the model.

    Args:
      particle_dimensions: Dimensionality of the problem; also the number of
        node outputs of the network (one acceleration component per axis).
      nnode_in: Number of node input features.
      nedge_in: Number of edge input features.
      latent_dim: Size of the latent dimension (e.g. 128).
      nmessage_passing_steps: Number of message passing steps.
      nmlp_layers: Number of hidden layers in each MLP (typically 2).
      mlp_hidden_dim: Width of the MLP hidden layers.
      connectivity_radius: Radius within which particles are connected.
      boundaries: Array of shape (dim, 2) holding the lower (column 0) and
        upper (column 1) boundary of the domain along each dimension.
      normalization_stats: Mapping with keys "velocity" and "acceleration",
        each holding "mean" and "std" entries, matching the dimensionality
        of the problem.
      nparticle_types: Number of different particle types.
      particle_type_embedding_size: Embedding size for the particle type.
      boundary_clamp_limit: Bound used to clamp boundary distances after they
        are divided by the connectivity radius.
      device: Runtime device (cuda or cpu).
    """
    super(LearnedSimulator, self).__init__()
    self._boundaries = boundaries
    self._connectivity_radius = connectivity_radius
    self._normalization_stats = normalization_stats
    self._nparticle_types = nparticle_types
    self._boundary_clamp_limit = boundary_clamp_limit

    # One learned vector of size particle_type_embedding_size per particle type.
    self._particle_type_embedding = nn.Embedding(
        nparticle_types, particle_type_embedding_size)

    self._encode_process_decode = graph_network.EncodeProcessDecode(
        nnode_in_features=nnode_in,
        nnode_out_features=particle_dimensions,
        nedge_in_features=nedge_in,
        latent_dim=latent_dim,
        nmessage_passing_steps=nmessage_passing_steps,
        nmlp_layers=nmlp_layers,
        mlp_hidden_dim=mlp_hidden_dim)

    self._device = device

  def forward(self):
    """No-op placeholder required by ``nn.Module``.

    Use ``predict_positions`` / ``predict_accelerations`` instead.
    """
    pass

  def _compute_graph_connectivity(
          self,
          node_features: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          radius: float,
          add_self_edges: bool = True,
          max_num_neighbors: int = 128):
    """Generate graph edges to all particles within a threshold radius.

    Args:
      node_features: Node positions with shape (nparticles, dim).
      nparticles_per_example: 1-D per-example particle counts; particles are
        assumed to be laid out example by example.
      radius: Threshold to construct edges to all particles within the radius.
      add_self_edges: Whether to include self edges (default: True).
      max_num_neighbors: Maximum neighbours per particle passed to
        ``radius_graph`` (default: 128).

    Returns:
      Tuple ``(edge_index[0], edge_index[1])`` of 1-D index tensors from
      ``radius_graph``; callers unpack them as ``(senders, receivers)``.
    """
    # Example id for every particle, e.g. counts [2, 3] -> [0, 0, 1, 1, 1].
    counts = torch.as_tensor(
        nparticles_per_example, dtype=torch.long, device=self._device)
    batch_ids = torch.repeat_interleave(
        torch.arange(counts.numel(), device=self._device), counts)

    # radius_graph accepts r < radius not r <= radius
    # A torch tensor list of source and target nodes with shape (2, nedges)
    edge_index = radius_graph(
        node_features, r=radius, batch=batch_ids, loop=add_self_edges,
        max_num_neighbors=max_num_neighbors)

    return edge_index[0, :], edge_index[1, :]

  def _encoder_preprocessor(
          self,
          position_sequence: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: torch.Tensor = None):
    """Build graph inputs from a position history.

    Node features are concatenated in this order:
      * normalized velocity history, flattened over time and space;
      * boundary distances divided by the connectivity radius and clamped to
        +/- ``boundary_clamp_limit`` (2 * dim values);
      * particle-type embedding (only if ``nparticle_types > 1``);
      * material property (only if given).
    Edge features are the relative displacement of each edge's two endpoints
    divided by the connectivity radius, followed by its Euclidean norm.

    Args:
      position_sequence: Particle positions with shape
        (nparticles, nsteps, dim); the last step is the current position.
      nparticles_per_example: Number of particles per example.
      particle_types: Particle types with shape (nparticles).
      material_property: Optional per-particle value (e.g. friction angle
        normalized by tan()), reshaped to (nparticles, 1).

    Returns:
      Tuple ``(node_features, edge_index, edge_features)`` with shapes
      (nparticles, nnode_features), (2, nedges) and (nedges, dim + 1).
    """
    nparticles = position_sequence.shape[0]
    most_recent_position = position_sequence[:, -1]  # (nparticles, dim)
    velocity_sequence = time_diff(position_sequence)

    senders, receivers = self._compute_graph_connectivity(
        most_recent_position, nparticles_per_example, self._connectivity_radius)

    node_features = []

    # Normalized velocity sequence, merging spatial and time axes.
    velocity_stats = self._normalization_stats["velocity"]
    normalized_velocity_sequence = (
        velocity_sequence - velocity_stats['mean']) / velocity_stats['std']
    node_features.append(
        normalized_velocity_sequence.contiguous().view(nparticles, -1))

    # Normalized, clipped distances to the lower and upper boundaries.
    # Boundaries have shape (dim, 2): column 0 lower, column 1 upper.
    boundaries = torch.as_tensor(
        self._boundaries, dtype=torch.float32, device=self._device)
    distance_to_lower_boundary = most_recent_position - boundaries[:, 0][None]
    distance_to_upper_boundary = boundaries[:, 1][None] - most_recent_position
    distance_to_boundaries = torch.cat(
        [distance_to_lower_boundary, distance_to_upper_boundary], dim=1)
    node_features.append(torch.clamp(
        distance_to_boundaries / self._connectivity_radius,
        -self._boundary_clamp_limit, self._boundary_clamp_limit))

    # Particle type embedding (skipped when there is only one type).
    if self._nparticle_types > 1:
      node_features.append(self._particle_type_embedding(particle_types))

    # Optional material property as one extra node feature.
    if material_property is not None:
      node_features.append(material_property.view(nparticles, 1))

    # Relative displacement between the two endpoints of each edge,
    # normalized by the radius: shape (nedges, dim).
    normalized_relative_displacements = (
        most_recent_position[senders, :] -
        most_recent_position[receivers, :]
    ) / self._connectivity_radius
    # Normalized distance per edge: shape (nedges, 1).
    normalized_relative_distances = torch.norm(
        normalized_relative_displacements, dim=-1, keepdim=True)

    return (torch.cat(node_features, dim=-1),
            torch.stack([senders, receivers]),
            torch.cat([normalized_relative_displacements,
                       normalized_relative_distances], dim=-1))

  def _decoder_postprocessor(
          self,
          normalized_acceleration: torch.Tensor,
          position_sequence: torch.Tensor) -> tuple:
    """Compute the next state from a normalized acceleration.

    The network output is de-normalized with the "acceleration" statistics,
    then integrated with an explicit Euler step using dt = 1.

    Args:
      normalized_acceleration: Normalized acceleration (nparticles, dim).
      position_sequence: Positions with shape (nparticles, nsteps, dim);
        the last two steps define the current velocity.

    Returns:
      Tuple ``(new_position, new_velocity, acceleration)``, each of shape
      (nparticles, dim).
    """
    acceleration_stats = self._normalization_stats["acceleration"]
    acceleration = (
        normalized_acceleration * acceleration_stats['std']
    ) + acceleration_stats['mean']

    most_recent_position = position_sequence[:, -1]
    most_recent_velocity = most_recent_position - position_sequence[:, -2]

    # TODO: Fix dt
    new_velocity = most_recent_velocity + acceleration  # * dt = 1
    new_position = most_recent_position + new_velocity  # * dt = 1
    return new_position, new_velocity, acceleration

  def _predict_with_edges(
          self,
          current_positions: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: torch.Tensor = None) -> tuple:
    """Run preprocessing, the network and the Euler update; also return edges."""
    node_features, edge_index, edge_features = self._encoder_preprocessor(
        current_positions, nparticles_per_example, particle_types,
        material_property)
    predicted_normalized_acceleration = self._encode_process_decode(
        node_features, edge_index, edge_features)
    next_positions, next_velocity, next_acc = self._decoder_postprocessor(
        predicted_normalized_acceleration, current_positions)
    return next_positions, next_velocity, next_acc, edge_index

  def predict_positions(
          self,
          current_positions: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: torch.Tensor = None) -> tuple:
    """Predict the next positions from a position history.

    Args:
      current_positions: Position history (nparticles, nsteps, dim).
      nparticles_per_example: Number of particles per example.
      particle_types: Particle types with shape (nparticles).
      material_property: Optional per-particle material property
        (e.g. friction angle normalized by tan()).

    Returns:
      Tuple ``(next_positions, next_velocity, next_acceleration)``.
    """
    next_positions, next_velocity, next_acc, _ = self._predict_with_edges(
        current_positions, nparticles_per_example, particle_types,
        material_property)
    return next_positions, next_velocity, next_acc

  def predict_positions_index(
          self,
          current_positions: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: torch.Tensor = None) -> tuple:
    """Like ``predict_positions`` but also returns the graph's edge_index.

    Args:
      current_positions: Position history (nparticles, nsteps, dim).
      nparticles_per_example: Number of particles per example.
      particle_types: Particle types with shape (nparticles).
      material_property: Optional per-particle material property
        (e.g. friction angle normalized by tan()).

    Returns:
      Tuple ``(next_positions, next_velocity, next_acceleration, edge_index)``
      where ``edge_index`` has shape (2, nedges).
    """
    return self._predict_with_edges(
        current_positions, nparticles_per_example, particle_types,
        material_property)

  def predict_accelerations(
          self,
          next_positions: torch.Tensor,
          position_sequence_noise: torch.Tensor,
          position_sequence: torch.Tensor,
          nparticles_per_example: torch.Tensor,
          particle_types: torch.Tensor,
          material_property: torch.Tensor = None):
    """Produces normalized predicted and target accelerations for training.

    Args:
      next_positions: Tensor of shape (nparticles_in_batch, dim) with the
        positions the model should output given the inputs.
      position_sequence_noise: Tensor of the same shape as `position_sequence`
        with the noise to apply to each particle.
      position_sequence: Particle positions (nparticles, nsteps, dim).
      nparticles_per_example: Number of particles per example.
      particle_types: Particle types with shape (nparticles).
      material_property: Optional per-particle material property
        (e.g. friction angle normalized by tan()).

    Returns:
      Tuple ``(predicted, target)`` of normalized accelerations, each of
      shape (nparticles_in_batch, dim).
    """
    # Add noise to the input position sequence and run the forward pass on it.
    noisy_position_sequence = position_sequence + position_sequence_noise
    node_features, edge_index, edge_features = self._encoder_preprocessor(
        noisy_position_sequence, nparticles_per_example, particle_types,
        material_property)
    predicted_normalized_acceleration = self._encode_process_decode(
        node_features, edge_index, edge_features)

    # Calculate the target acceleration, using an `adjusted_next_position` that
    # is shifted by the noise in the last input position.
    next_position_adjusted = next_positions + position_sequence_noise[:, -1]
    target_normalized_acceleration = self._inverse_decoder_postprocessor(
        next_position_adjusted, noisy_position_sequence)
    # As a result the inverted Euler update in the `_inverse_decoder` produces:
    # * A target acceleration that does not explicitly correct for the noise in
    #   the input positions, as the `next_position_adjusted` is different
    #   from the true `next_position`.
    # * A target acceleration that exactly corrects noise in the input velocity
    #   since the target next velocity calculated by the inverse Euler update
    #   as `next_position_adjusted - noisy_position_sequence[:,-1]`
    #   matches the ground truth next velocity (noise cancels out).

    return predicted_normalized_acceleration, target_normalized_acceleration

  def _inverse_decoder_postprocessor(
          self,
          next_position: torch.Tensor,
          position_sequence: torch.Tensor):
    """Inverse of `_decoder_postprocessor`.

    Args:
      next_position: Tensor of shape (nparticles_in_batch, dim) with the
        positions the model should output given the inputs.
      position_sequence: Particle positions (nparticles, nsteps, dim).

    Returns:
      normalized_acceleration (torch.Tensor): Normalized acceleration.
    """
    previous_position = position_sequence[:, -1]
    previous_velocity = previous_position - position_sequence[:, -2]
    next_velocity = next_position - previous_position
    acceleration = next_velocity - previous_velocity

    acceleration_stats = self._normalization_stats["acceleration"]
    return (acceleration - acceleration_stats['mean']) / acceleration_stats['std']

  def save(
          self,
          path: str = 'model.pt'):
    """Save the model's ``state_dict`` to ``path``.

    Args:
      path: Model path
    """
    torch.save(self.state_dict(), path)

  def load(
          self,
          path: str):
    """Load a ``state_dict`` from ``path``, mapping all tensors to CPU.

    Args:
      path: Model path
    """
    self.load_state_dict(torch.load(path, map_location=torch.device('cpu')))


def time_diff(
        position_sequence: torch.tensor) -> torch.tensor:
  """Return first-order differences of consecutive steps along axis 1.

  Args:
    position_sequence: Positions whose axis 1 is time, expected as
      (nparticles, nsteps, dim).

  Returns:
    torch.tensor: Step-to-step displacements (velocities with dt = 1); axis 1
    is one element shorter than the input's.
  """
  later_positions = position_sequence[:, 1:]
  earlier_positions = position_sequence[:, :-1]
  return later_positions - earlier_positions
