import torch
import torch.nn as nn
import numpy as np
from Gnn_ddim import graph_network
from torch_geometric.nn import radius_graph
from torch_geometric.utils import to_dense_adj, to_dense_batch
import torch_geometric.utils
from typing import Dict

import torch.nn.functional as F
from src.diffusion.noise_schedule import PredefinedNoiseSchedule
from src.diffusion import diffusion_utils
from src import utils

from src.models.transformer_model import GraphTransformer

from src.diffusion.noise_schedule import DiscreteUniformTransition, PredefinedNoiseScheduleDiscrete,\
    MarginalUniformTransition

from improved_diffusion import dist_util
from improved_diffusion.resample import UniformSampler
from improved_diffusion.script_util import create_gaussian_diffusion
from Gnn_ddim.control_network import Control_adapter

class LearnedSimulator_control(nn.Module):
  """Learned simulator from https://arxiv.org/pdf/2002.09405.pdf."""

  def __init__(
          self,
          particle_dimensions: int,
          nnode_in: int,
          nedge_in: int,
          latent_dim: int,
          nmessage_passing_steps: int,
          nmlp_layers: int,
          mlp_hidden_dim: int,
          connectivity_radius: float,
          boundaries: np.ndarray,
          normalization_stats: dict,
          nparticle_types: int,
          particle_type_embedding_size: int,
          boundary_clamp_limit: float = 1.0,
          control_timestep_embed_dim = 16,
          channels=None,
          nums_rb=2,
          cin=192,
          dt = 0.0025,
          mode='train',
          device="cpu"
  ):
    """Initializes the model.

    Args:
      particle_dimensions: Dimensionality of the problem; used as the number
        of node output features of the graph network.
      nnode_in: Number of node input features of the graph network.
      nedge_in: Number of edge input features of the graph network.
      latent_dim: Size of the latent dimension of the graph network.
      nmessage_passing_steps: Number of message passing steps.
      nmlp_layers: Number of MLP layers, forwarded to the graph network.
      mlp_hidden_dim: Hidden size of the MLPs, forwarded to the graph network.
      connectivity_radius: Radius used to build the particle graph and to
        normalize boundary distances and relative displacements.
      boundaries: Array of shape (dim, 2); column 0 holds the lower and
        column 1 the upper boundary along each dimension.
      normalization_stats: Mapping with keys "acceleration" and "velocity";
        each entry is indexed with 'mean' and 'std'.
      nparticle_types: Number of different particle types.
      particle_type_embedding_size: Embedding size for the particle type.
      boundary_clamp_limit: Normalized distances to the boundaries are clamped
        to [-boundary_clamp_limit, boundary_clamp_limit].
      control_timestep_embed_dim: Width of the control time-step embedding MLP.
      channels: Forwarded to `Control_adapter`. Defaults to
        [32, 64, 128, 128]; a fresh list is created per instance.
      nums_rb: Forwarded to `Control_adapter`.
      cin: Forwarded to `Control_adapter`.
      dt: Stored as `self.dt`; `predict_positions_light` divides the
        de-normalized acceleration by `dt ** 2`.
      mode: 'train' builds the diffusion without timestep respacing; any other
        value builds it with `timestep_respacing='ddim8'`.
      device: Runtime device (cuda or cpu).

    """
    super(LearnedSimulator_control, self).__init__()
    # nn.Module.__init__ already leaves `self.training` set to True, so it is
    # not assigned again here.

    # Avoid a mutable default argument shared between instances.
    if channels is None:
      channels = [32, 64, 128, 128]

    self._boundaries = boundaries
    self._connectivity_radius = connectivity_radius
    self._normalization_stats = normalization_stats
    self._nparticle_types = nparticle_types
    self._boundary_clamp_limit = boundary_clamp_limit
    self._device = device
    self.dt = dt

    # Embedding table of shape (nparticle_types, particle_type_embedding_size).
    self._particle_type_embedding = nn.Embedding(
        nparticle_types, particle_type_embedding_size)

    # Graph network used as the denoiser inside the diffusion sampler.
    self._encode_process_decode = graph_network.EncodeProcessDecode(
        nnode_in_features=nnode_in,
        nnode_out_features=particle_dimensions,
        nedge_in_features=nedge_in,
        latent_dim=latent_dim,
        nmessage_passing_steps=nmessage_passing_steps,
        nmlp_layers=nmlp_layers,
        mlp_hidden_dim=mlp_hidden_dim,
        timestep_embed_dim=16)

    # Encoder for the control image and MLP embedding of the control time step.
    self.control_encoder = Control_adapter(
        channels=channels,
        nums_rb=nums_rb,
        cin=cin)
    self._control_time_embedding = nn.Sequential(
        nn.Linear(1, control_timestep_embed_dim),
        nn.SiLU(),
        nn.Linear(control_timestep_embed_dim, control_timestep_embed_dim)
    )

    # `self.T` only parameterizes the predefined cosine schedule `self.gamma`;
    # the Gaussian diffusion below uses `self.num_timesteps` steps.
    diffusion_noise_schedule = 'cosine'
    self.T = 500
    self.gamma = PredefinedNoiseSchedule(diffusion_noise_schedule, timesteps=self.T)

    self.num_timesteps = 2000
    # Training uses the full schedule; every other mode samples with the
    # 'ddim8' respacing. All remaining settings are shared.
    timestep_respacing = '' if mode == 'train' else 'ddim8'
    self.diffusion = create_gaussian_diffusion(
        steps=self.num_timesteps,
        noise_schedule='cosine',
        use_kl=False,
        predict_xstart=True,
        predict_v=False,
        rescale_timesteps=False,
        rescale_learned_sigmas=False,
        timestep_respacing=timestep_respacing,
    )
    self.schedule_sampler = UniformSampler(self.diffusion)

  def forward(self):
    """Placeholder for the `nn.Module` forward pass; does nothing.

    Prediction is exposed through the `predict_*` methods of this class
    rather than through `forward`.

    Returns:
      None.
    """
    return None

  def _compute_graph_connectivity(
          self,
          node_features: torch.tensor,
          nparticles_per_example: torch.tensor,
          radius: float,
          add_self_edges: bool = True):
    """Generate graph edges between all particles within a threshold radius.

    Args:
      node_features: Node features with shape (nparticles, dim); passed to
        `radius_graph` as the point coordinates.
      nparticles_per_example: Iterable with the number of particles of each
        example in the batch.
      radius: Threshold to construct edges to all particles within the radius.
      add_self_edges: Boolean flag to include self edges (default: True).

    Returns:
      Tuple `(edge_index[0], edge_index[1])` of 1-D index tensors. Note that
      callers in this class unpack the pair as `senders, receivers`.
    """
    # One example id per particle, e.g. counts [2, 3] -> [0, 0, 1, 1, 1].
    # Built with a vectorized repeat instead of a Python list holding one
    # element per particle; the only Python-level loop is over examples.
    counts = torch.as_tensor(
        [int(n) for n in nparticles_per_example],
        dtype=torch.long, device=self._device)
    batch_ids = torch.repeat_interleave(
        torch.arange(counts.numel(), device=self._device), counts)

    # radius_graph accepts r < radius not r <= radius.
    # Result is a tensor of node index pairs with shape (2, nedges).
    edge_index = radius_graph(
        node_features, r=radius, batch=batch_ids, loop=add_self_edges, max_num_neighbors=128)

    # The flow direction when using in combination with message passing is
    # "source_to_target".
    receivers = edge_index[0, :]
    senders = edge_index[1, :]

    return receivers, senders

  def _encoder_preprocessor_predict(
          self,
          position_sequence: torch.tensor,
          nparticles_per_example: torch.tensor,
          particle_types: torch.tensor,
          material_property: torch.tensor = None):
    """Builds graph inputs from a position sequence, keeping the boundary and
    particle-type features separate from the velocity features.

    Args:
      position_sequence: A sequence of particle positions, indexed as
        (nparticles, time, dim); the last time entry is the current position.
      nparticles_per_example: Number of particles per example in the batch.
      particle_types: Particle types with shape (nparticles).
      material_property: Optional per-particle material property; reshaped to
        (nparticles, 1) and appended to the node features.

    Returns:
      A 4-tuple `(node_features, extra_features, edge_index, edge_features)`:
        node_features: flattened normalized velocities, followed by the
          material property when given.
        extra_features: clamped normalized distances to the lower/upper
          boundaries, followed by the particle-type embedding when there is
          more than one particle type.
        edge_index: stacked `[senders, receivers]` with shape (2, nedges).
        edge_features: normalized relative displacement and its norm, with
          shape (nedges, dim + 1).
    """
    num_nodes = position_sequence.shape[0]
    latest_position = position_sequence[:, -1]
    radius = self._connectivity_radius
    velocities = time_diff(position_sequence)

    # Graph connectivity around the current positions.
    senders, receivers = self._compute_graph_connectivity(
        latest_position, nparticles_per_example, radius)

    # --- Node features: normalized velocities with time and space merged.
    vel_stats = self._normalization_stats["velocity"]
    scaled_velocities = (velocities - vel_stats['mean']) / vel_stats['std']
    node_parts = [scaled_velocities.view(num_nodes, -1)]
    if material_property is not None:
      node_parts.append(material_property.view(num_nodes, 1))

    # --- Extra features: distances to the walls, in units of the radius and
    # clamped to +/- the configured limit. `bounds` is (dim, 2) with the
    # lower boundary in column 0 and the upper boundary in column 1.
    bounds = torch.tensor(
        self._boundaries, requires_grad=False).float().to(self._device)
    gap_to_lower = latest_position - bounds[:, 0][None]
    gap_to_upper = bounds[:, 1][None] - latest_position
    wall_gaps = torch.cat([gap_to_lower, gap_to_upper], dim=1)
    extra_parts = [torch.clamp(
        wall_gaps / radius,
        -self._boundary_clamp_limit, self._boundary_clamp_limit)]
    if self._nparticle_types > 1:
      extra_parts.append(self._particle_type_embedding(particle_types))

    # --- Edge features: relative displacement (nedges, dim) and its length
    # (nedges, 1), both normalized by the connectivity radius.
    relative_offsets = (
        latest_position[senders, :] - latest_position[receivers, :]
    ) / radius
    relative_lengths = torch.norm(relative_offsets, dim=-1, keepdim=True)

    return (torch.cat(node_parts, dim=-1),
            torch.cat(extra_parts, dim=-1),
            torch.stack([senders, receivers]),
            torch.cat([relative_offsets, relative_lengths], dim=-1))

  def _encoder_preprocessor(
          self,
          position_sequence: torch.tensor,
          nparticles_per_example: torch.tensor,
          particle_types: torch.tensor,
          velocity_sequence = None,
          cotrol_time_step = None,
          material_property: torch.tensor = None):
    """Builds node features, edge indices and edge features for the network.

    Args:
      position_sequence: A sequence of particle positions, indexed as
        (nparticles, time, dim); the last time entry is the current position.
      nparticles_per_example: Number of particles per example in the batch.
      particle_types: Particle types with shape (nparticles).
      velocity_sequence: Optional precomputed velocities. When given, the
        first entry along axis 1 is dropped and the remainder is used instead
        of finite differences of `position_sequence`.
      cotrol_time_step: Optional control time step (parameter name spelled as
        in the signature). It is cast to float, embedded by the control
        time-step MLP and expanded to one row per particle.
      material_property: Optional per-particle material property; reshaped to
        (nparticles, 1).

    Returns:
      A 3-tuple `(node_features, edge_index, edge_features)`. Node features
      are concatenated in this order: flattened normalized velocities, clamped
      normalized boundary distances, particle-type embedding (only when there
      is more than one type), control time embedding (when given), material
      property (when given). `edge_index` is the stacked
      `[senders, receivers]` with shape (2, nedges); `edge_features` holds the
      normalized relative displacement and its norm, shape (nedges, dim + 1).
    """
    num_nodes = position_sequence.shape[0]
    latest_position = position_sequence[:, -1]
    radius = self._connectivity_radius

    if velocity_sequence is None:
      velocities = time_diff(position_sequence)
    else:
      velocities = velocity_sequence[:, 1:]

    # Graph connectivity around the current positions.
    senders, receivers = self._compute_graph_connectivity(
        latest_position, nparticles_per_example, radius)

    # Normalized velocities with time and space axes merged.
    vel_stats = self._normalization_stats["velocity"]
    scaled_velocities = (velocities - vel_stats['mean']) / vel_stats['std']
    node_parts = [scaled_velocities.contiguous().view(num_nodes, -1)]

    # Distances to the walls in units of the radius, clamped to +/- the
    # configured limit. `bounds` is (dim, 2): lower in column 0, upper in 1.
    bounds = torch.tensor(
        self._boundaries, requires_grad=False).float().to(self._device)
    gap_to_lower = latest_position - bounds[:, 0][None]
    gap_to_upper = bounds[:, 1][None] - latest_position
    wall_gaps = torch.cat([gap_to_lower, gap_to_upper], dim=1)
    node_parts.append(torch.clamp(
        wall_gaps / radius,
        -self._boundary_clamp_limit, self._boundary_clamp_limit))

    # Learned particle-type embedding, skipped for a single particle type.
    if self._nparticle_types > 1:
      node_parts.append(self._particle_type_embedding(particle_types))

    # The control time embedding is shared by all particles.
    if cotrol_time_step is not None:
      time_embedding = self._control_time_embedding(cotrol_time_step.float())
      node_parts.append(time_embedding.expand(num_nodes, -1))

    if material_property is not None:
      node_parts.append(material_property.contiguous().view(num_nodes, 1))

    # Edge features: relative displacement (nedges, dim) and its length
    # (nedges, 1), both normalized by the connectivity radius.
    relative_offsets = (
        latest_position[senders, :] - latest_position[receivers, :]
    ) / radius
    relative_lengths = torch.norm(relative_offsets, dim=-1, keepdim=True)

    return (torch.cat(node_parts, dim=-1),
            torch.stack([senders, receivers]),
            torch.cat([relative_offsets, relative_lengths], dim=-1))

  def _decoder_postprocessor(
          self,
          normalized_acceleration: torch.tensor,
          position_sequence: torch.tensor) -> torch.tensor:
    """ Compute new position based on acceleration and current position.
    The model produces the output in normalized space so we apply inverse
    normalization.

    Args:
      normalized_acceleration: Normalized acceleration (nparticles, dim).
      position_sequence: Position sequence of shape (nparticles, dim).

    Returns:
      torch.tensor: New position of the particles.

    """
    # Extract real acceleration values from normalized values
    acceleration_stats = self._normalization_stats["acceleration"]
    acceleration = (
        normalized_acceleration * acceleration_stats['std']
    ) + acceleration_stats['mean']

    # predict velocity
    # velocity_stats = self._normalization_stats["velocity"]
    # velocity = (
    #     normalized_acceleration * velocity_stats['mean']
    # ) + velocity_stats['std']
    # Use an Euler integrator to go from acceleration to position, assuming
    # a dt=1 corresponding to the size of the finite difference.
    most_recent_position = position_sequence[:, -1]
    most_recent_velocity = most_recent_position - position_sequence[:, -2]

    # TODO: Fix dt
    new_velocity = most_recent_velocity + acceleration  # * dt = 1
    # new_velocity = velocity
    new_position = most_recent_position + new_velocity  # * dt = 1
    return new_position

  def predict_positions(
          self,
          current_positions: torch.tensor,
          nparticles_per_example: torch.tensor,
          particle_types: torch.tensor,
          material_property: torch.tensor = None,
          batch: torch.tensor = None) -> torch.tensor:
    """Predict the next positions by sampling an acceleration with DDIM.

    Starting from Gaussian noise, `self.diffusion.ddim_sample` is applied for
    every step index from `self.num_timesteps - 1` down to 0, using the graph
    network as the denoiser. The final sample is treated as a normalized
    acceleration and integrated by `_decoder_postprocessor`.

    Args:
      current_positions: Position history of the particles; it is passed on as
        the position sequence, so it is indexed as (nparticles, time, dim).
      nparticles_per_example: Number of particles per example in the batch.
      particle_types: Particle types with shape (nparticles).
      material_property: Optional per-particle material property.
      batch: Unused; kept for interface compatibility.

    Returns:
      next_positions (torch.tensor): Next position of particles.
    """
    # `material_property` defaults to None in the preprocessor as well, so a
    # single call covers both cases. The extra (boundary / particle-type)
    # features are not consumed by the sampler below.
    node_features, _, edge_index, edge_features = self._encoder_preprocessor_predict(
        current_positions, nparticles_per_example, particle_types, material_property)

    # NOTE(review): this flips only this module's own flag; unlike
    # `self.eval()` it does not switch the submodules. Kept as in the original.
    self.training = False

    # Initial noise has one row per particle and one column per coordinate
    # (taken from the input instead of being hard-coded to 2).
    acc = torch.randn(
        (node_features.shape[0], current_positions.shape[-1]),
        device=self._device)

    # NOTE(review): the loop runs over `self.num_timesteps` indices; confirm
    # this matches the diffusion object when timestep respacing is active.
    sample_fn = self.diffusion.ddim_sample
    for i in reversed(range(self.num_timesteps)):
      t = torch.tensor([i], device=self._device)
      out = sample_fn(
            self._encode_process_decode,
            acc,
            t,
            node_features,
            edge_index,
            edge_features,
            clip_denoised=False
      )
      acc = out["sample"]

    next_positions = self._decoder_postprocessor(
        acc, current_positions)

    return next_positions

  def predict_positions_light(
          self,
          current_positions: torch.tensor,
          nparticles_per_example: torch.tensor,
          particle_types: torch.tensor,
          material_property: torch.tensor = None,
          velocity_sequence: torch.tensor = None,
          control_img: torch.tensor = None,
          cotrol_time_step: torch.tensor = None,
          batch: torch.tensor = None) -> tuple:
    """Sample an acceleration with the DDIM sampling loop.

    Args:
      current_positions: Position history of the particles; it is passed on as
        the position sequence, so it is indexed as (nparticles, time, dim).
      nparticles_per_example: Number of particles per example in the batch.
      particle_types: Particle types with shape (nparticles).
      material_property: Optional per-particle material property.
      velocity_sequence: Optional precomputed velocities forwarded to
        `_encoder_preprocessor`.
      control_img: Optional control image; a leading batch axis is added
        before it is encoded by `self.control_encoder`.
      cotrol_time_step: Optional control time step (parameter name spelled as
        in the signature), forwarded to `_encoder_preprocessor`.
      batch: Unused; kept for interface compatibility.

    Returns:
      Tuple `(acceleration, sample)`: `sample` is the raw output of the
      sampling loop, requested with shape (1, nparticles, dim);
      `acceleration` is that sample de-normalized with the "acceleration"
      statistics and divided by `self.dt ** 2`.
    """
    # Every optional input is passed by keyword. The original forwarded
    # `material_property` positionally, which bound it to the
    # `velocity_sequence` parameter and dropped the other two.
    node_features, edge_index, edge_features = self._encoder_preprocessor(
        current_positions, nparticles_per_example, particle_types,
        velocity_sequence=velocity_sequence,
        cotrol_time_step=cotrol_time_step,
        material_property=material_property)

    # NOTE(review): this flips only this module's own flag; unlike
    # `self.eval()` it does not switch the submodules. Kept as in the original.
    self.training = False

    # Always bind `control_feat` so that a missing control image no longer
    # raises UnboundLocalError.
    # NOTE(review): confirm the sampler / network accept `control_feat=None`.
    control_feat = None
    if control_img is not None:
      control_feat = self.control_encoder(control_img.unsqueeze(0))

    sample = self.diffusion.ddim_sample_loop(
          self._encode_process_decode,
          (1, node_features.shape[0], current_positions.shape[-1]),
          node_features,
          edge_index,
          edge_features,
          control_feat=control_feat,
          clip_denoised=False
      )

    # Undo the normalization, then rescale by the squared time step.
    acceleration_stats = self._normalization_stats["acceleration"]
    acceleration = (
        sample * acceleration_stats['std']
    ) + acceleration_stats['mean']
    acceleration = acceleration / (self.dt**2)
    return acceleration, sample

  def predict_accelerations(
          self,
          next_positions: torch.tensor,
          position_sequence_noise: torch.tensor,
          position_sequence: torch.tensor,
          nparticles_per_example: torch.tensor,
          particle_types: torch.tensor,
          material_property: torch.tensor = None,
          batch: torch.tensor = None):
    """Produces normalized and predicted acceleration targets.

    NOTE(review): this method does not match the `_encoder_preprocessor`
    defined in this class and appears to be a leftover from an earlier
    design; see the inline notes below before relying on it.

    Args:
      next_positions: Tensor of shape (nparticles_in_batch, dim) with the
        positions the model should output given the inputs.
      position_sequence_noise: Tensor of the same shape as `position_sequence`
        with the noise to apply to each particle.
      position_sequence: A sequence of particle positions. Shape is
        (nparticles, 6, dim). Includes current + last 5 positions.
      nparticles_per_example: Number of particles per example. Default is 2
        examples per batch.
      particle_types: Particle types with shape (nparticles).
      material_property: Friction angle normalized by tan() with shape (nparticles).
      batch: Forwarded to `utils.to_dense`; presumably the example index of
        each particle - TODO confirm.

    Returns:
      Tuple `(predicted_normalized_acceleration,
      target_normalized_acceleration)`.

    """

    # Add noise to the input position sequence.
    noisy_position_sequence = position_sequence + position_sequence_noise

    # Perform the forward pass with the noisy position sequence.
    # NOTE(review): `_encoder_preprocessor` in this class returns a 3-tuple
    # (node_features, edge_index, edge_features), so unpacking five values
    # here would fail. In addition, the 4th positional argument binds to its
    # `velocity_sequence` parameter and the 5th to `cotrol_time_step`, not to
    # a next-position or material-property argument.
    if material_property is not None:
        node_features, extra_features, edge_index, edge_features, gt_edge = self._encoder_preprocessor(
            noisy_position_sequence, nparticles_per_example, particle_types, next_positions, material_property)
    else:
        node_features, extra_features, edge_index, edge_features, gt_edge = self._encoder_preprocessor(
            noisy_position_sequence, nparticles_per_example, particle_types, next_positions)    # original author's note: node_features([1356, 30]) = (particle number, feature)

    # NOTE(review): this sets only this module's own flag; it does not switch
    # the submodules the way `self.train()` would.
    self.training = True
    # Empty global-feature tensor with a leading size of 1 (the original
    # comment labels that axis "batch_size"); it is created on CUDA whenever
    # CUDA is available, regardless of `self._device`.
    y = torch.empty(1, 0, device='cuda' if torch.cuda.is_available() else 'cpu')
    dense_data, node_mask, edge_mask = utils.to_dense(node_features, edge_index, edge_features, batch)
    dense_data = dense_data.mask(node_mask)
    X, E = dense_data.X, dense_data.E
    # NOTE(review): `self.norm_values` / `self.norm_biases` are only present
    # as commented-out assignments in the visible `__init__`, and
    # `self.apply_noise` is not defined in the visible part of this file;
    # confirm they exist before calling this method.
    normalized_data = utils.normalize(X, E, y, self.norm_values, self.norm_biases, node_mask)
    noisy_data = self.apply_noise(normalized_data.X, normalized_data.E, normalized_data.y, node_mask)

    # Convert the noised dense data back to a sparse representation.
    X_noise, edge_index, E_noise, batch_noise = utils.from_dense(noisy_data['X_t'], noisy_data['E_t'], node_mask, edge_mask, batch_size=1)
    # assert edge_index==edge_index_noise, "from_dense index error"
    # assert batch_noise==batch, "from_dense batch error"

    # dense_extra, _, _ = utils.to_dense(extra_features, edge_index, gt_edge, batch)
    # extra_features = extra_features.mask(node_mask)
    # Noised node features are concatenated with the extra features; the
    # noised `y_t` and the diffusion time `t` are stacked into the global
    # input `Y` of the network.
    X_all = torch.cat((X_noise, extra_features), dim=1).float()
    Y = torch.hstack((noisy_data['y_t'], noisy_data['t'])).float() 
    predicted_normalized_acceleration = self._encode_process_decode(X_all, edge_index, E_noise.float(), Y)

    # unique_batches = batch.unique()
    # batch_sizes = torch.bincount(batch)
    # output = []
    # for batch_idx in unique_batches:
    #   batch_data = pred.X[batch_idx, :batch_sizes[batch_idx], :]
    #   output.append(batch_data)

    # predicted_normalized_velocity = torch.cat(output, dim=0)
    # predicted_normalized_acceleration = self._encode_process_decode(
    #     node_features, edge_index, edge_features)
    # Calculate the target acceleration, using an `adjusted_next_position `that
    # is shifted by the noise in the last input position.
    next_position_adjusted = next_positions + position_sequence_noise[:, -1]
    # NOTE(review): two values are unpacked from
    # `_inverse_decoder_postprocessor` here, while
    # `predict_accelerations_light` assigns its result to a single name;
    # confirm which return shape the helper actually has.
    target_normalized_acceleration, target_normalized_velocity = self._inverse_decoder_postprocessor(
        next_position_adjusted, noisy_position_sequence)
    # As a result the inverted Euler update in the `_inverse_decoder` produces:
    # * A target acceleration that does not explicitly correct for the noise in
    #   the input positions, as the `next_position_adjusted` is different
    #   from the true `next_position`.
    # * A target acceleration that exactly corrects noise in the input velocity
    #   since the target next velocity calculated by the inverse Euler update
    #   as `next_position_adjusted - noisy_position_sequence[:,-1]`
    #   matches the ground truth next velocity (noise cancels out).

    return predicted_normalized_acceleration, target_normalized_acceleration

  def predict_accelerations_light(
          self,
          next_positions: torch.tensor,
          position_sequence_noise: torch.tensor,
          position_sequence: torch.tensor,
          nparticles_per_example: torch.tensor,
          particle_types: torch.tensor,
          velocity_sequence: torch.tensor = None,
          acceleration_sequence:  torch.tensor = None,
          control_img: torch.tensor = None,
          cotrol_time_step: torch.tensor = None,
          material_property: torch.tensor = None,
          batch: torch.tensor = None):
    """Computes the diffusion training loss for one noisy input sequence.

    The noisy positions are encoded into graph features, the normalized
    target acceleration is derived with `_inverse_decoder_postprocessor`, and
    both are handed to `self.diffusion.training_losses` together with a
    timestep drawn from `self.schedule_sampler`.

    Args:
      next_positions: Tensor of shape (nparticles_in_batch, dim) with the
        positions the model should output given the inputs.
      position_sequence_noise: Tensor of the same shape as `position_sequence`
        with the noise to apply to each particle.
      position_sequence: A sequence of particle positions. Shape is
        (nparticles, 6, dim). Includes current + last 5 positions.
      nparticles_per_example: Number of particles per example.
      particle_types: Particle types with shape (nparticles).
      velocity_sequence: Optional. Forwarded to the encoder preprocessor (only
        when `material_property` is None) and to the target computation.
      acceleration_sequence: Optional. Forwarded to
        `_inverse_decoder_postprocessor`.
      control_img: Optional. When given, a leading dimension is added and the
        result is passed through `self.control_encoder`; the output is given
        to the diffusion loss as the last argument. Otherwise None is passed.
      cotrol_time_step: Optional. Forwarded to the encoder preprocessor (only
        when `material_property` is None).
      material_property: Optional. When given, it is passed as the fourth
        positional argument of the encoder preprocessor.
      batch: Not used by this method.

    Returns:
      The mean of the `"loss"` entry returned by
      `self.diffusion.training_losses`, multiplied by the sampler weights.

    """
    # Perturb the input history with the supplied noise.
    perturbed_sequence = position_sequence + position_sequence_noise

    # Build graph inputs from the perturbed history.
    # NOTE(review): the material_property path does not forward
    # velocity_sequence / cotrol_time_step -- confirm this is intended.
    if material_property is None:
      encoded = self._encoder_preprocessor(
          perturbed_sequence, nparticles_per_example, particle_types,
          velocity_sequence=velocity_sequence, cotrol_time_step=cotrol_time_step)
    else:
      encoded = self._encoder_preprocessor(
          perturbed_sequence, nparticles_per_example, particle_types, material_property)
    node_features, edge_index, edge_features = encoded

    # Shift the target position by the noise of the last input step before
    # inverting the decoder post-processing.
    shifted_next_positions = next_positions + position_sequence_noise[:, -1]
    target_normalized_acceleration = self._inverse_decoder_postprocessor(
        shifted_next_positions,
        perturbed_sequence,
        velocity_sequence=velocity_sequence,
        acceleration_sequence=acceleration_sequence)

    # A single timestep (and its weight) is drawn per call.
    timesteps, sample_weights = self.schedule_sampler.sample(1, self._device)

    control_feat = (
        self.control_encoder(control_img.unsqueeze(0))
        if control_img is not None else None)

    loss_terms = self.diffusion.training_losses(
        self._encode_process_decode,
        target_normalized_acceleration.unsqueeze(0),
        timesteps,
        node_features,
        edge_index,
        edge_features,
        control_feat)

    return (loss_terms["loss"] * sample_weights).mean()

  def _inverse_decoder_postprocessor(
        self,
        next_position: torch.tensor,
        position_sequence: torch.tensor,
        velocity_sequence: torch.tensor = None,
        acceleration_sequence: torch.tensor = None):
    """Inverse of `_decoder_postprocessor`.

    Args:
      next_position: Tensor of shape (nparticles_in_batch, dim) with the
        positions the model should output given the inputs.
      position_sequence: A sequence of particle positions. Shape is
        (nparticles, 6, dim). Includes current + last 5 positions.
      velocity_sequence: Optional sequence of particle velocities. Shape is
        (nparticles, 6, dim).
      acceleration_sequence: Optional sequence of particle accelerations. Shape is
        (nparticles, 6, dim).

    Returns:
      normalized_acceleration (torch.tensor): Normalized acceleration.

    """
    if acceleration_sequence is not None:
        # If we have acceleration sequence, use the most recent one
        next_acceleration = acceleration_sequence[:, -1]
    else:
        # Otherwise compute from velocity or position sequences
        if velocity_sequence is not None:
            previous_velocity = velocity_sequence[:, -1]
            next_velocity = velocity_sequence[:, -2]  # Assuming this is the predicted next velocity
            next_acceleration = next_velocity - previous_velocity
        else:
            # Fall back to original position-based calculation
            previous_position = position_sequence[:, -1]
            previous_velocity = previous_position - position_sequence[:, -2]
            next_velocity = next_position - previous_position
            next_acceleration = next_velocity - previous_velocity

    # Normalize the acceleration
    acceleration_stats = self._normalization_stats["acceleration"]
    normalized_acceleration = (
        next_acceleration - acceleration_stats['mean']) / acceleration_stats['std']
    
    return normalized_acceleration

  def save(
          self,
          path: str = 'model.pt'):
    """Save model state

    Args:
      path: Model path
    """
    torch.save(self.state_dict(), path)

  def load(
          self,
          path: str):
    """Load model state from file

    Args:
      path: Model path
    """
    self.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)

  # def apply_noise(self, X, E, y, node_mask):
  #   """ Sample noise and apply it to the data. """
  #   # Sample a timestep t.
  #   # When evaluating, the loss for t=0 is computed separately
  #   # import pdb; pdb.set_trace()
  #   lowest_t = 0 if self.training else 1
  #   t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1)
  #   s_int = t_int - 1

  #   t_float = t_int / self.T
  #   s_float = s_int / self.T

  #   # beta_t and alpha_s_bar are used for denoising/loss computation
  #   beta_t = self.noise_schedule(t_normalized=t_float)                         # (bs, 1)
  #   alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)      # (bs, 1)
  #   alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)      # (bs, 1)

  #   Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device='cpu')  # (bs, dx_in, dx_out), (bs, de_in, de_out)
  #   assert (abs(Qtb.X.sum(dim=2) - 1.) < 1e-4).all(), Qtb.X.sum(dim=2) - 1
  #   assert (abs(Qtb.E.sum(dim=2) - 1.) < 1e-4).all()

  #   # Compute transition probabilities
  #   probX = X @ Qtb.X  # (bs, n, dx_out)
  #   probE = E @ Qtb.E.unsqueeze(1)  # (bs, n, n, de_out)

  #   sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask)

  #   # import pdb; pdb.set_trace()
  #   X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
  #   E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
  #   assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

  #   z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)

  #   noisy_data = {'t_int': t_int, 't': t_float, 'beta_t': beta_t, 'alpha_s_bar': alpha_s_bar,
  #                 'alpha_t_bar': alpha_t_bar, 'X_t': z_t.X, 'E_t': z_t.E, 'y_t': z_t.y, 'node_mask': node_mask}
  #   return noisy_data  

  def apply_noise(self, X, E, y, node_mask):
    """Noise X, E and y at a randomly drawn diffusion step.

    One integer step per batch entry is drawn, the schedule `self.gamma` is
    evaluated at that step and the one before it, and Gaussian noise from
    `diffusion_utils.sample_feature_noise` is mixed into the inputs.

    Returns:
      A dict holding the normalized steps ('t', 's'), the inflated gammas
      ('gamma_t', 'gamma_s'), the sampled noise ('epsX', 'epsE', 'epsy'),
      the noised inputs ('X_t', 'E_t', 'y_t') and the integer step ('t_int').
    """
    x_shape = X.size()

    # When evaluating, the loss for t=0 is computed separately, so the
    # smallest step drawn outside training is 1.
    min_step = 0 if self.training else 1

    # One step per batch entry, cast to X's type.
    t_int = torch.randint(min_step, self.T + 1, size=(X.size(0), 1))
    t_int = t_int.type_as(X).float()  # (bs, 1)

    # Normalize both the drawn step and its predecessor by T. The negative
    # predecessor that appears for t=0 is never used downstream.
    s_normalized = (t_int - 1) / self.T
    t_normalized = t_int / self.T

    # Schedule values, inflated to be broadcastable against X.
    gamma_s = diffusion_utils.inflate_batch_array(self.gamma(s_normalized), x_shape)
    gamma_t = diffusion_utils.inflate_batch_array(self.gamma(t_normalized), x_shape)

    # Mixing coefficients derived from gamma_t.
    alpha_t = diffusion_utils.alpha(gamma_t, x_shape)
    sigma_t = diffusion_utils.sigma(gamma_t, x_shape)

    noise = diffusion_utils.sample_feature_noise(
        x_shape, E.size(), y.size(), node_mask).type_as(X)

    # z_t = alpha_t * x + sigma_t * eps, with the coefficients reshaped to
    # match E (one extra dim) and y (one fewer dim).
    noised_X = alpha_t * X + sigma_t * noise.X
    noised_E = alpha_t.unsqueeze(1) * E + sigma_t.unsqueeze(1) * noise.E
    noised_y = alpha_t.squeeze(1) * y + sigma_t.squeeze(1) * noise.y

    return {
        't': t_normalized,
        's': s_normalized,
        'gamma_t': gamma_t,
        'gamma_s': gamma_s,
        'epsX': noise.X,
        'epsE': noise.E,
        'epsy': noise.y,
        'X_t': noised_X,
        'E_t': noised_E,
        'y_t': noised_y,
        't_int': t_int,
    }
  
  def apply_noise_light(self, X, y):
    """Noise X and y only (no edge features) at a randomly drawn diffusion step.

    Follows the same steps as `apply_noise` but draws its noise from
    `diffusion_utils.sample_acc_noise_light` and leaves edges out.

    Returns:
      A dict holding the normalized steps ('t', 's'), the inflated gammas
      ('gamma_t', 'gamma_s'), the sampled noise ('epsX', 'epsy'), the noised
      inputs ('X_t', 'y_t') and the integer step ('t_int').
    """
    x_shape = X.size()

    # When evaluating, the loss for t=0 is computed separately, so the
    # smallest step drawn outside training is 1.
    min_step = 0 if self.training else 1

    # One step per batch entry, cast to X's type.
    t_int = torch.randint(min_step, self.T + 1, size=(X.size(0), 1))
    t_int = t_int.type_as(X).float()  # (bs, 1)

    # Normalize both the drawn step and its predecessor by T. The negative
    # predecessor that appears for t=0 is never used downstream.
    s_normalized = (t_int - 1) / self.T
    t_normalized = t_int / self.T

    # Schedule values, inflated to be broadcastable against X.
    gamma_s = diffusion_utils.inflate_batch_array(self.gamma(s_normalized), x_shape)
    gamma_t = diffusion_utils.inflate_batch_array(self.gamma(t_normalized), x_shape)

    # Mixing coefficients derived from gamma_t.
    alpha_t = diffusion_utils.alpha(gamma_t, x_shape)
    sigma_t = diffusion_utils.sigma(gamma_t, x_shape)

    eps_X, eps_y = diffusion_utils.sample_acc_noise_light(x_shape, y.size())

    # z_t = alpha_t * x + sigma_t * eps; y uses the coefficients with dim 1
    # squeezed out.
    noised_X = alpha_t * X + sigma_t * eps_X
    noised_y = alpha_t.squeeze(1) * y + sigma_t.squeeze(1) * eps_y

    return {
        't': t_normalized,
        's': s_normalized,
        'gamma_t': gamma_t,
        'gamma_s': gamma_s,
        'epsX': eps_X,
        'epsy': eps_y,
        'X_t': noised_X,
        'y_t': noised_y,
        't_int': t_int,
    }
  
  def compute_extra_data(self, noisy_data):
    """Assemble the additional network inputs derived from `noisy_data`.

    The outputs of `self.extra_features` and `self.domain_features` are
    concatenated along the last dimension for X, E and y; the normalized
    timestep `noisy_data['t']` is then appended to y along dim 1.

    Returns:
      A `utils.PlaceHolder` with the concatenated X, E and y.
    """
    generic = self.extra_features(noisy_data)
    domain = self.domain_features(noisy_data)

    # Pair up the X / E / y members of both feature sets, in that order.
    merged_X, merged_E, merged_y = (
        torch.cat((getattr(generic, field), getattr(domain, field)), dim=-1)
        for field in ("X", "E", "y"))

    # The timestep travels with the global features.
    merged_y = torch.cat((merged_y, noisy_data['t']), dim=1)

    return utils.PlaceHolder(X=merged_X, E=merged_E, y=merged_y)

  def sample_p_zs_given_zt(self, s, t, X_t, X_con, E_index, E_con, y_con):
    """Samples from zs ~ p(zs | zt). Only used during sampling.

    The network (`self._encode_process_decode`) is evaluated on the
    conditioning node features concatenated with the current noisy state, and
    its output is used as the x_0 estimate from which the mean of the
    transition is built.

    Returns:
      The first element returned by `diffusion_utils.sample_normal_acc`.
    """
    state_shape = X_t.size()

    # Schedule values at both steps, inflated to broadcast against X_t.
    gamma_s = diffusion_utils.inflate_batch_array(self.gamma(s), state_shape)
    gamma_t = diffusion_utils.inflate_batch_array(self.gamma(t), state_shape)

    # The variance term returned first is not needed by the x_0 form below.
    _, sigma_t_given_s, alpha_t_given_s = diffusion_utils.sigma_and_alpha_t_given_s(
        gamma_t, gamma_s, state_shape)
    sigma_s = diffusion_utils.sigma(gamma_s, target_shape=state_shape)
    sigma_t = diffusion_utils.sigma(gamma_t, target_shape=state_shape)

    # Network inputs: conditioning features next to the noisy state (leading
    # dim squeezed out), and the global vector with the timestep appended.
    node_inputs = torch.cat((X_con, X_t.squeeze(0)), dim=1).float()
    global_inputs = torch.hstack((y_con, t)).float()

    # The network output is treated as the x_0 estimate.
    x_0 = self._encode_process_decode(node_inputs, E_index, E_con, global_inputs)

    # NOTE(review): the mean weights x_0 by alpha_t_given_s and X_t by
    # sigma_t_given_s / sigma_t -- verify against the intended posterior
    # derivation.
    mu_X = alpha_t_given_s * x_0 + sigma_t_given_s * X_t / sigma_t

    # Standard deviation of p(zs | zt).
    sigma = sigma_t_given_s * sigma_s / sigma_t

    # Draw zs; the second returned value is not used.
    z_s, _ = diffusion_utils.sample_normal_acc(mu_X, y_con, sigma)
    return z_s

  def sample_continuous_graph_given_z0(self, X_0, X_con, E_index, E_con, y_0):
    """Final sampling step: return the network's prediction at t = 0.

    The conditioning node features are concatenated with `X_0`, a zero
    timestep is appended to `y_0`, and the output of
    `self._encode_process_decode` on those inputs is returned unchanged (no
    extra noise is sampled).
    """
    # NOTE(review): the t=0 schedule quantities (`_sigma_ratio` here, and
    # `_sigma_0` / `_alpha_0` further down) are evaluated but never used in
    # the result; they are kept only to preserve the exact sequence of calls.
    gamma_0 = self.gamma(torch.zeros(size=(X_0.size(0), 1), device=X_0.device))
    _sigma_ratio = diffusion_utils.SNR(-0.5 * gamma_0).unsqueeze(1)

    # Network inputs at timestep zero.
    zero_step = torch.zeros(y_0.shape[0], 1).type_as(y_0)
    node_inputs = torch.cat((X_con, X_0.squeeze(0)), dim=1).float()
    global_inputs = torch.hstack((y_0, zero_step)).float()

    x0_pred = self._encode_process_decode(node_inputs, E_index, E_con, global_inputs)

    _sigma_0 = diffusion_utils.sigma(gamma_0, target_shape=x0_pred.size())
    _alpha_0 = diffusion_utils.alpha(gamma_0, target_shape=x0_pred.size())

    return x0_pred
  
def time_diff(
        position_sequence: torch.tensor) -> torch.tensor:
  """First-order finite difference along the step axis (dim 1).

  Args:
    position_sequence: Input position sequence & shape(nparticles, 6 steps, dim)

  Returns:
    torch.tensor: Velocity sequence, one step shorter than the input along
      dim 1.
  """
  later_steps = position_sequence[:, 1:]
  earlier_steps = position_sequence[:, :-1]
  return later_steps - earlier_steps
