import collections
import json
import os
import pickle
import glob
import re
import sys

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from absl import flags
from absl import app

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from gns_mpm import learned_simulator
from gns_mpm import noise_utils
from gns_mpm import reading_utils
from gns_mpm import data_loader
from gns_mpm import distribute

import matplotlib.pyplot as plt
from p2g_utils import run_simulation, run_simulation_mpm #, run_simulation_mpm_thershold
from p2g_utils import run_simulation_mpm_thershold
import time

# ---------------------------------------------------------------------------
# Command-line flags (absl). Values are read through ``FLAGS`` below.
# ---------------------------------------------------------------------------
flags.DEFINE_enum(
    'mode', 'train', ['train', 'valid', 'rollout'],
    help='Train model, validation or rollout evaluation.')
flags.DEFINE_integer('batch_size', 1, help='The batch size.')
flags.DEFINE_float('noise_std', 6.7e-4, help='The std deviation of the noise.')
# Sparse (GNS) and dense datasets; the rollout functions take both a sparse
# ``position`` and a ``position_dense`` trajectory.
flags.DEFINE_string('data_path', None, help='The dataset directory.')
flags.DEFINE_string('data_path_dense', None, help='The dataset directory.')
flags.DEFINE_string('model_path', 'models/', help=('The path for saving checkpoints of the model.'))
flags.DEFINE_string('output_path', 'rollouts/', help='The path for saving outputs (e.g. rollouts).')
flags.DEFINE_string('output_filename', 'rollout', help='Base name for saving the rollout')
flags.DEFINE_string('model_file', None, help=('Model filename (.pt) to resume from. Can also use "latest" to default to newest file.'))
flags.DEFINE_string('train_state_file', 'train_state.pt', help=('Train state filename (.pt) to resume from. Can also use "latest" to default to newest file.'))

# Training schedule.
flags.DEFINE_integer('ntraining_steps', int(2E7), help='Number of training steps.')
flags.DEFINE_integer('validation_interval', None, help='Validation interval. Set `None` if validation loss is not needed')
flags.DEFINE_integer('nsave_steps', int(100000), help='Number of steps at which to save the model.')

# Learning rate parameters
flags.DEFINE_float('lr_init', 1e-4, help='Initial learning rate.')
flags.DEFINE_float('lr_decay', 0.1, help='Learning rate decay.')
flags.DEFINE_integer('lr_decay_steps', int(5e6), help='Learning rate decay steps.')

# Device selection.
flags.DEFINE_integer("cuda_device_number", None, help="CUDA device (zero indexed), default is None so default CUDA device will be used.")
flags.DEFINE_integer("n_gpus", 1, help="The number of GPUs to utilize for training.")

FLAGS = flags.FLAGS

# (mean, std) pair; presumably for normalization statistics -- not referenced
# in the visible part of this file.
Stats = collections.namedtuple('Stats', ['mean', 'std'])

# Number of past frames fed to the simulator; the rollout functions build
# their input window from this many frames.
INPUT_SEQUENCE_LENGTH = 6  # So we can calculate the last 5 velocities.
NUM_PARTICLE_TYPES = 9
# Particles of this type are not predicted: the rollouts overwrite their
# positions with the ground-truth trajectory.
KINEMATIC_PARTICLE_ID = 3

from bolt.upsample_2d import random_sampling, octree_initial
from bolt.utils import knn_smoothing, bilateral_smoothing

from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

class ParticleReducer:
    """Coarsen a particle system to a fixed number of representatives via k-means.

    Particles are clustered on their standardized, concatenated
    ``[position, velocity]`` features (scikit-learn ``StandardScaler`` followed
    by ``KMeans``); each cluster becomes one representative particle.
    """

    def __init__(self, target_particles=256, random_state=42):
        # Number of representatives (= k-means clusters) to produce.
        self.n_clusters = target_particles
        # k-means++ seeding with 10 restarts; ``random_state`` fixes the seed.
        self.kmeans = KMeans(
            n_clusters=target_particles,
            init='k-means++',
            n_init=10,
            random_state=random_state,
        )
        # Puts position and velocity features on a common scale before clustering.
        self.scaler = StandardScaler()

    def reduce_system(self, positions, velocities):
        """Cluster the particles and return one (position, velocity) per cluster.

        Args:
          positions: [N, D] array of particle positions.
          velocities: [N, D] array of particle velocities.

        Returns:
          ``(new_positions, new_velocities)``, both ``[n_clusters, D]``.
          Positions are the position part of the cluster centres mapped back
          to the original scale; velocities are the mean velocity of each
          cluster's members (zero for a cluster without members).
          ``self.scaler`` and ``self.kmeans`` are refitted on every call.
        """
        scaled = self.scaler.fit_transform(np.hstack([positions, velocities]))
        self.kmeans.fit(scaled)

        centres = self.scaler.inverse_transform(self.kmeans.cluster_centers_)
        ndim = positions.shape[1]
        reduced_positions = centres[:, :ndim]

        reduced_velocities = np.zeros_like(reduced_positions)
        for cluster_id in range(self.n_clusters):
            members = self.kmeans.labels_ == cluster_id
            if np.any(members):
                reduced_velocities[cluster_id] = np.mean(velocities[members], axis=0)
        return reduced_positions, reduced_velocities

class ParticleReducerTorch:
    """Reduce a particle system with Lloyd's k-means implemented in PyTorch.

    Particles are clustered on their (unscaled) concatenated
    ``[position, velocity]`` features; the cluster centres are split back into
    representative positions and velocities.
    """

    def __init__(self, target_particles=256, max_iter=20, tol=1e-4, device='cuda'):
        self.n_clusters = target_particles
        self.max_iter = max_iter
        self.tol = tol
        self.device = torch.device(device)

    def reduce_system(self, positions, velocities):
        """Cluster ``N`` particles into at most ``target_particles`` centres.

        Args:
          positions: [N, D] tensor.
          velocities: [N, D] tensor.

        Returns:
          ``(new_positions, new_velocities)``, each ``[K, D]`` on ``self.device``
          with ``K = min(target_particles, N)``.
        """
        features = torch.cat(
            [positions.to(self.device), velocities.to(self.device)], dim=1)  # [N, 2D]
        n = features.shape[0]
        dim = positions.shape[1]
        # Never request more clusters than particles: with K > N the seed set
        # has only N rows and indexing centre k >= N would fail.
        k = min(self.n_clusters, n)
        if k == 0:
            return features[:0, :dim], features[:0, dim:]

        # Seed with K distinct random particles.
        centers = features[torch.randperm(n)[:k]]

        for _ in range(self.max_iter):
            labels = torch.cdist(features, centers).argmin(dim=1)  # [N]
            # Per-cluster mean in one pass: sum members, divide by member count.
            sums = torch.zeros_like(centers).index_add_(0, labels, features)
            counts = torch.bincount(labels, minlength=k).unsqueeze(1)  # [K, 1]
            # Empty clusters keep their previous centre.
            new_centers = torch.where(
                counts > 0, sums / counts.clamp(min=1).to(sums.dtype), centers)

            shift = (centers - new_centers).pow(2).sum()
            centers = new_centers
            if shift < self.tol:
                break

        # Split the centres back into position / velocity parts.
        return centers[:, :dim], centers[:, dim:]

class ParticleReducerTorchLite:
    """Position-only Lloyd k-means reducer in PyTorch.

    Clusters positions only and works on whatever device ``positions`` already
    lives on; ``device`` is stored but ``reduce_system`` does not move data.
    """

    def __init__(self, target_particles=256, max_iter=10, tol=1e-4, device='cuda'):
        self.n_clusters = target_particles
        self.max_iter = max_iter
        self.tol = tol
        # Kept for API symmetry with the other reducers; not used below.
        self.device = torch.device(device)

    def reduce_system(self, positions):
        """Cluster ``N`` positions into at most ``target_particles`` centres.

        Args:
          positions: [N, D] tensor.

        Returns:
          ``[K, D]`` tensor of centres with ``K = min(target_particles, N)``,
          on the same device as ``positions``.
        """
        n = positions.shape[0]
        # Never request more clusters than points (seed set would be too small).
        k = min(self.n_clusters, n)
        if k == 0:
            return positions[:0]

        # Seed with K distinct random points.
        centers = positions[torch.randperm(n)[:k]]

        for _ in range(self.max_iter):
            labels = torch.cdist(positions, centers, p=2).argmin(dim=1)  # [N]
            # Per-cluster mean in one pass: sum members, divide by member count.
            sums = torch.zeros_like(centers).index_add_(0, labels, positions)
            counts = torch.bincount(labels, minlength=k).unsqueeze(1)  # [K, 1]
            # Empty clusters keep their previous centre.
            new_centers = torch.where(
                counts > 0, sums / counts.clamp(min=1).to(sums.dtype), centers)

            shift = (centers - new_centers).pow(2).sum()
            centers = new_centers
            if shift < self.tol:
                break

        return centers

class ParticleReducerTorchBatch:
    """Batched Lloyd k-means over positions, one independent clustering per batch element."""

    def __init__(self, target_particles=256, max_iter=10, tol=1e-4, device='cuda'):
        self.n_clusters = target_particles
        self.max_iter = max_iter
        self.tol = tol
        self.device = torch.device(device)

    def reduce_system(self, positions):
        """Cluster each batch element's points into ``target_particles`` centres.

        Args:
          positions: [B, N, D] tensor (moved to ``self.device``).

        Returns:
          [B, K, D] tensor of centres, ``K = target_particles``. When N < K some
          centres are necessarily repeated.
        """
        b, n, d = positions.shape
        positions = positions.to(self.device)
        k = self.n_clusters

        # Seeds: K *distinct* points per batch element when possible. Sampling
        # with replacement can pick the same point twice; argmin then always
        # favours the first copy, so the duplicate cluster stays empty forever.
        if n >= k:
            seed_idx = torch.rand(b, n, device=self.device).argsort(dim=1)[:, :k]
        else:
            seed_idx = torch.randint(n, (b, k), device=self.device)
        centers = torch.gather(positions, 1, seed_idx.unsqueeze(-1).expand(-1, -1, d))

        # Loop invariants: squared norms of the points, flattened points, and
        # per-batch offsets giving every (batch, cluster) pair its own row in a
        # flat [B*K] accumulation table.
        pos_sq = (positions ** 2).sum(dim=2, keepdim=True)                  # [B, N, 1]
        flat_pos = positions.reshape(-1, d)                                  # [B*N, D]
        offsets = (torch.arange(b, device=self.device) * k).unsqueeze(1)    # [B, 1]

        for _ in range(self.max_iter):
            # Squared distances via |p|^2 - 2 p.c + |c|^2.
            cen_sq = (centers ** 2).sum(dim=2).unsqueeze(1)                  # [B, 1, K]
            dists = pos_sq - 2 * torch.matmul(positions, centers.transpose(1, 2)) + cen_sq
            labels = dists.argmin(dim=2)                                     # [B, N]

            flat_labels = (labels + offsets).reshape(-1)                     # [B*N]
            sums = torch.zeros(b * k, d, dtype=positions.dtype, device=self.device)
            sums.index_add_(0, flat_labels, flat_pos)
            counts = torch.bincount(flat_labels, minlength=b * k).unsqueeze(1)
            # Empty clusters keep their previous centre.
            new_centers = torch.where(
                counts > 0,
                sums / counts.clamp(min=1).to(sums.dtype),
                centers.reshape(b * k, d),
            ).reshape(b, k, d)

            shift = (centers - new_centers).pow(2).sum()
            centers = new_centers
            if shift < self.tol:
                break

        return centers  # [B, K, D]

def compute_latency_ms_pytorch(model, current_positions, n_particles_per_example, particle_types, material_property, iterations=None, device=None):
    """Measure the mean latency (ms) of one ``model.predict_positions`` call on CUDA.

    The model is put in eval mode, moved with ``.cuda()`` and warmed up with 10
    calls. If ``iterations`` is None, the iteration count is calibrated by
    doubling a batch (starting at 100 calls) until one batch takes at least one
    second, and then chosen so the final timed run lasts roughly six seconds.

    Args:
      model: Object exposing ``eval()``, ``cuda()`` and ``predict_positions``.
      current_positions, n_particles_per_example, particle_types,
        material_property: Fixed inputs forwarded to ``predict_positions``.
      iterations: Number of timed calls; calibrated automatically when None.
      device: Unused; kept for call-site compatibility.

    Returns:
      Average wall-clock latency per call in milliseconds.

    Raises:
      ValueError: if ``iterations`` is given and is less than 1.
    """
    if iterations is not None and iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    # NOTE: process-wide cuDNN settings; they are not restored afterwards.
    # deterministic=True restricts cuDNN to deterministic algorithms, which
    # limits what benchmark=True can choose from.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    model.eval()
    model = model.cuda()

    def _predict():
        # One forward pass on the fixed inputs; outputs are discarded.
        model.predict_positions(
            current_positions,
            nparticles_per_example=[n_particles_per_example],
            particle_types=particle_types,
            material_property=material_property,
        )

    def _timed(n_calls, progress=False):
        # Wall-clock seconds for n_calls passes, bracketed by synchronize() so
        # all queued CUDA work is included in the measurement.
        calls = tqdm(range(n_calls)) if progress else range(n_calls)
        torch.cuda.synchronize()
        t_start = time.perf_counter()
        for _ in calls:
            _predict()
        torch.cuda.synchronize()
        return time.perf_counter() - t_start

    with torch.no_grad():
        for _ in range(10):  # warm-up
            _predict()

        if iterations is None:
            # Double the batch size (100, 200, ...) until one batch runs >= 1 s;
            # the throughput is computed from the batch size actually timed.
            iterations = 50
            elapsed_time = 0.0
            while elapsed_time < 1:
                iterations *= 2
                elapsed_time = _timed(iterations)
            # Aim for a ~6 s final measurement.
            iterations = max(1, int(iterations / elapsed_time * 6))

        print('=========Speed Testing=========')
        elapsed_time = _timed(iterations, progress=True)
        latency = elapsed_time / iterations * 1000
    torch.cuda.empty_cache()
    # FPS = 1000 / latency (in ms)
    return latency

def plot_figure(first_positions, new_pos, index, output_dir="./thershold_gns_mpm"):
    """Save side-by-side 2-D scatter plots of two point sets as a PNG.

    Columns 0 and 1 of each array are plotted on a fixed [0, 1] x [0, 1]
    window: ``first_positions`` on the left in blue ("Original Points") and
    ``new_pos`` on the right in red ("Sampled Points"). The image is written to
    ``<output_dir>/scatter_plot_left_right_<index>.png``; ``output_dir`` is
    created if it does not exist.

    A fresh figure is created and closed on every call, so repeated calls
    neither draw on top of earlier scatter points nor accumulate open figures.
    """
    fig = plt.figure()
    try:
        panels = (
            (first_positions, 'blue', 'Original Points'),
            (new_pos, 'red', 'Sampled Points'),
        )
        for slot, (pts, color, title) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, 2, slot)
            ax.scatter(pts[:, 0], pts[:, 1], c=color, s=1)
            ax.set_title(title)
            ax.set_xlabel('X Coordinate')
            ax.set_ylabel('Y Coordinate')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
        fig.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        output_file_path = os.path.join(output_dir, f"scatter_plot_left_right_{index}.png")
        fig.savefig(output_file_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

from torch_cluster import fps

def sample_points(x, num_samples):
    """Downsample a point cloud with farthest point sampling (``torch_cluster.fps``).

    Args:
      x: (n, dim) tensor.
      num_samples: number of points to keep.

    Returns:
      (min(num_samples, n), dim) tensor of rows of ``x``. If ``num_samples >= n``,
      ``x`` is returned unchanged.
    """
    n = x.size(0)
    num_samples = int(num_samples)
    if num_samples >= n:
        # Nothing to drop (and fps only accepts 0 < ratio <= 1).
        return x
    if num_samples <= 0:
        return x[:0]

    batch = torch.zeros(n, dtype=torch.long, device=x.device)
    idx = fps(x, batch, ratio=num_samples / n)
    # fps keeps ceil(ratio * n) points with the ratio rounded to the tensor's
    # float precision, so it can return one point too many. Indices come out
    # in selection order, so truncating keeps a valid farthest-point subset.
    return x[idx[:num_samples]]

def grid_sample_2d(points, grid_size=0.05):
    """Voxel-grid downsampling: replace the points in each occupied cell by their mean.

    Args:
      points: [N, D] floating-point tensor (written for D == 2; any D works).
      grid_size: edge length of a grid cell.

    Returns:
      [U, D] tensor with one centroid per occupied cell, in the order in which
      ``torch.unique(..., dim=0)`` returns the cell coordinates. An empty input
      yields an empty [0, D] tensor.
    """
    if points.shape[0] == 0:
        return points[:0]

    coords = (points / grid_size).floor().int()
    unique_coords, inverse = torch.unique(coords, dim=0, return_inverse=True)
    n_cells = unique_coords.shape[0]

    # Sum the points of each cell and divide by the cell population in one
    # pass, instead of one full-size mask per cell.
    sums = torch.zeros(n_cells, points.shape[1], dtype=points.dtype, device=points.device)
    sums.index_add_(0, inverse, points)
    counts = torch.bincount(inverse, minlength=n_cells).unsqueeze(1).to(points.dtype)
    return sums / counts

def random_knn(
        position_sparse: torch.tensor,
        n_initial_sample: int,
        n_final_sample: int):
    """Densify a sparse point cloud with random sampling plus kNN smoothing.

    Draws ``n_final_sample - n_initial_sample`` new points with
    ``random_sampling`` from the sparse cloud, then hands the sparse cloud and
    the new points to ``knn_smoothing`` with ``k=10`` and returns its result
    (presumably a NumPy array -- the caller wraps it in ``torch.tensor``).

    Args:
      position_sparse: sparse positions tensor; copied to host NumPy first.
      n_initial_sample: number of sparse points.
      n_final_sample: desired number of points after densification.
    """
    sparse_np = position_sparse.cpu().numpy()
    n_extra = n_final_sample - n_initial_sample
    extra_points = random_sampling(sparse_np, n_extra)
    return knn_smoothing(sparse_np, extra_points, k=10)


# def acc_vel_diff(positions):
#   current_velocity = positions[-1] - positions[-2]
#   pre_velocity = positions[-2] - positions[-3]
#   current_acc = current_velocity - pre_velocity

#   return current_acc, current_velocity

def rollout(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.tensor,
        position_dense: torch.tensor,
        particle_types: torch.tensor,
        particle_types_dense: torch.tensor,
        material_property: torch.tensor,
        n_particles_per_example: torch.tensor,
        n_particles_per_example_dense: int,
        nsteps: int,
        device: torch.device,
        line: list):
  """Roll out a trajectory with the learned simulator, handing off to MPM when unstable.

  The simulator advances in strides of ``N`` (=2) frames and the frame in
  between is filled by linear interpolation. After each simulator step the
  particle-averaged cosine similarity between the two latest predicted
  accelerations is recorded; once more than 10 values exist and the mean of
  the last 10 falls below 0.8, ``run_simulation_mpm_thershold`` produces
  ``mpm_step`` frames, which are appended to the predictions and used to
  re-seed the simulator's input window.

  Args:
    simulator: Learned simulator.
    position: Sparse particle positions, (nparticles, timesteps, ndims).
    position_dense: Dense particle positions, indexed the same way.
    particle_types: Particle types with shape (nparticles).
    particle_types_dense: Types of the dense particles (only copied to the output).
    material_property: Passed to the simulator (described elsewhere as the
      friction angle normalized by tan()); may be None.
    n_particles_per_example: Number of sparse particles.
    n_particles_per_example_dense: Unused here.
    nsteps: Ignored; the number of steps is recomputed from ``position``.
    device: torch device.
    line: Unused.

  Returns:
    ``(output_dict, loss, velocitys, accs, rollout_time, count_nan)``, or six
    ``None`` values if the MPM solver returns ``None``.
  """
  # GNS stride: the model predicts every N-th frame; the N-1 frames in between
  # are linearly interpolated.
  N = 2
  rollout_length = position.shape[1] - N * 6

  # Input window: INPUT_SEQUENCE_LENGTH frames sampled every N frames.
  initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N
  initial_positions = position[:, initial_indices]
  initial_positions_dense = position_dense[:, initial_indices]

  # Number of frames to cover with GNS strides.
  ground_truth_start = initial_indices[-1] + N
  ground_truth_indices = np.arange(ground_truth_start, position.shape[1], N)
  nsteps = len(ground_truth_indices) * N

  # Ground truth at full frame rate, starting right after the last input frame.
  gt_initial = initial_indices[-1] + 1
  ground_truth_positions = position[:, gt_initial:]
  ground_truth_positions_dense = position_dense[:, gt_initial:]

  # Kinematic particles follow the prescribed trajectory. The mask depends
  # only on particle_types and the spatial dimension, so build it once.
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, initial_positions.shape[-1])
  # Interpolation weights of the intermediate frames (just 0.5 for N == 2).
  interp_steps = np.linspace(0, 1, N + 1)[1:-1]
  # MPM hand-off length; it must span N * INPUT_SEQUENCE_LENGTH frames so the
  # GNS input window can be re-seeded from the MPM output.
  mpm_step = max(50, N * INPUT_SEQUENCE_LENGTH)

  current_positions = initial_positions
  predictions = []
  velocitys = []
  accs = []
  cos_acc = []
  count_nan = 0
  step = 0
  time_step_gns = 0  # number of simulator calls
  time_step_mpm = 0  # MPM frames that fall inside the evaluated horizon
  while step < nsteps:
    # Next position with shape (nnodes, dim).
    next_position, next_velocity, next_acc = simulator.predict_positions(
        current_positions,
        nparticles_per_example=[n_particles_per_example],
        particle_types=particle_types,
        material_property=material_property
    )
    next_position = torch.where(
        kinematic_mask, ground_truth_positions[:, step], next_position)

    for alpha in interp_steps:
      interp_pos = current_positions[:, -1] * (1 - alpha) + next_position * alpha
      predictions.append(interp_pos)
    predictions.append(next_position)
    velocitys.append(next_velocity)
    accs.append(next_acc)

    step += N
    time_step_gns += 1

    # Stability signal: cosine similarity of the two latest accelerations,
    # averaged over particles.
    if len(accs) > 1:
      similarity = torch.nn.functional.cosine_similarity(
          accs[-1], accs[-2], dim=1
      )
      cos_acc.append(torch.mean(similarity).item())

    # Hand off to MPM when the last 10 similarities average below 0.8.
    if len(cos_acc) > 10 and torch.tensor(cos_acc)[-10:].mean().item() < 0.8:
      print('Go taichi simulate')
      # Velocity estimate over the last interpolated sub-frame.
      next_velocity = next_position - interp_pos
      future_step50, future_velocities = run_simulation_mpm_thershold(
          next_position, next_velocity, step=mpm_step)
      if future_step50 is None:
        # MPM failed: abort the whole rollout.
        return None, None, None, None, None, None
      predictions += [torch.tensor(arr, device=device) for arr in future_step50]

      # Re-seed the input window with every N-th of the last
      # N * INPUT_SEQUENCE_LENGTH MPM frames, oldest first.
      future_step50 = np.array(future_step50)
      end_idx = future_step50.shape[0]
      indices = torch.arange(end_idx - 1, end_idx - 1 - N * INPUT_SEQUENCE_LENGTH, step=-N)
      indices = indices.flip(0).cpu().numpy()
      current_positions = torch.tensor(future_step50[indices]).to(device).permute(1, 0, 2)

      # Restart the stability history after the hand-off.
      cos_acc = []
      accs = []
      step += mpm_step
      time_step_mpm += mpm_step - max(0, step - nsteps)
      continue

    # Slide the input window: drop the oldest frame, append the prediction.
    current_positions = torch.cat(
        [current_positions[:, 1:], next_position[:, None, :]], dim=1)

  # Estimated run time: per-step costs (presumably ms) for the watersand setup,
  # converted to seconds. Per-dataset (GNS, MPM) constants used previously:
  #   water2d (0.8096, 0.92228)    sand2d (0.6574, 1.8790)
  #   sandramps (0.7117, 2.44766)  waterramps (0.7454, 1.9572)
  rollout_time = (time_step_gns * 0.8209 + time_step_mpm * 107.864) / 1000

  # Predictions with shape (time, nnodes, dim).
  ground_truth_positions_dense = ground_truth_positions_dense[:, :rollout_length]
  predictions = torch.stack(predictions[:rollout_length])
  ground_truth_positions_dense = ground_truth_positions_dense.permute(1, 0, 2)

  if torch.isnan(predictions).any():
    # Diverged rollout: report zero loss and count it.
    loss = torch.tensor(0.0)
    count_nan += 1
  else:
    ground_truth_positions_dense = torch.clip(ground_truth_positions_dense, 0.1, 0.9)
    predictions = torch.clip(predictions, 0.1, 0.9)
    grid_m, grid_m_gt = run_simulation(predictions, ground_truth_positions_dense)
    # Element-wise relative squared error; epsilon avoids division by zero
    # where the ground-truth grid value is 0.
    epsilon = 1e-6
    loss = ((grid_m - grid_m_gt) ** 2) / ((grid_m_gt) ** 2 + epsilon)

  output_dict = {
      'initial_positions': initial_positions_dense.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions_dense.cpu().numpy(),
      'particle_types': particle_types_dense.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss, velocitys, accs, rollout_time, count_nan

def rollout_mpm(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.tensor,
        position_dense: torch.tensor,
        particle_types: torch.tensor,
        particle_types_dense: torch.tensor,
        material_property: torch.tensor,
        n_particles_per_example: torch.tensor,
        n_particles_per_example_dense: int,
        nsteps: int,
        device: torch.device):
  """
  Rolls out a trajectory with the learned simulator, switching to dense MPM when unstable.

  The simulator advances in strides of N (=2) frames; the frame in between is
  linearly interpolated. Predicted positions are clamped to [0.1, 0.9]. When
  more than 10 acceleration cosine similarities exist and the mean of the last
  10 is below 0.9, the sparse state is densified with ``random_knn``, advanced
  with ``run_simulation_mpm``, and the simulator's input window is re-seeded
  from the returned sparse frames.

  Args:
    simulator: Learned simulator.
    position: Sparse positions, (nparticles, timesteps, ndims) -- time is axis 1.
    position_dense: Dense positions, indexed the same way.
    particle_types: Particles types with shape (nparticles)
    particle_types_dense: Types of the dense particles (only copied to the output).
    material_property: Passed to the simulator (friction angle normalized by tan()); may be None.
    n_particles_per_example: Number of sparse particles.
    n_particles_per_example_dense: Target particle count for ``random_knn``.
    nsteps: Ignored; recomputed from ``position``.
    device: torch device.

  Returns:
    (output_dict, loss, velocitys, accs, rollout_time, count_nan); here
    rollout_time is the measured wall-clock time of the loop in seconds.
  """
  # GNS stride: predict every N-th frame, interpolate the ones in between.
  N = 2  
  rollout_length = position.shape[1] - N * 6

  # Input window: INPUT_SEQUENCE_LENGTH frames sampled every N frames.
  initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N
  initial_positions = position[:, initial_indices] 
  initial_positions_dense = position_dense[:, initial_indices] 

  ground_truth_start = initial_indices[-1] + N
  ground_truth_indices = np.arange(ground_truth_start, position.shape[1], N)
  # Ground truth at full frame rate, starting right after the last input frame.
  gt_initial = initial_indices[-1] + 1
  ground_truth_positions = position[:, gt_initial:]
  ground_truth_positions_dense = position_dense[:, gt_initial:]

  current_positions = initial_positions
  # ``predictions`` mixes sparse GNS frames with dense MPM frames;
  # ``predictions_vis`` holds only sparse frames (used for the output).
  predictions = []
  predictions_vis = []
  predictions_dense = []  # not used below
  velocitys = []
  accs = []
  cos_acc = []
  count_nan = 0
  # The three values below are not used later in this function.
  current_velocity = position[:, initial_indices[-1]] - position[:,initial_indices[-1]-1]
  pre_velocity = position[:,initial_indices[-1]-1] - position[:,initial_indices[-1]-2]
  current_acc = current_velocity - pre_velocity

  nsteps = len(ground_truth_indices)*N
  step = 0
  start_time = time.time()
  while step < nsteps:
    # Get next position with shape (nnodes, dim)
    next_position, next_velocity, next_acc = simulator.predict_positions(
        current_positions,
        nparticles_per_example=[n_particles_per_example],
        particle_types=particle_types,
        material_property=material_property
    )
    next_position = torch.clamp(next_position, min=0.1, max=0.9)
    # Update kinematic particles from prescribed trajectory.
    kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
    next_position_ground_truth = ground_truth_positions[:, step]
    kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, current_positions.shape[-1])
    next_position = torch.where(
        kinematic_mask, next_position_ground_truth, next_position)
    # Linear interpolation weights of the intermediate frames (0.5 for N == 2).
    interp_steps = np.linspace(0, 1, N+1)[1:-1]
    for alpha in interp_steps:
        interp_pos = current_positions[:, -1] * (1-alpha) + next_position * alpha

        predictions.append(interp_pos) 
        predictions_vis.append(interp_pos)

    predictions_vis.append(next_position)
    predictions.append(next_position)

    velocitys.append(next_velocity)
    accs.append(next_acc)

    step += N

    # Stability signal: cosine similarity of the two latest accelerations,
    # averaged over particles.
    if len(accs) > 1:      
      similarity = torch.nn.functional.cosine_similarity(
          accs[-1], accs[-2], dim=1
      )
      cos_acc.append(torch.mean(similarity).item())

    if len(cos_acc)>10: # enough history to decide on an MPM hand-off
      cos_seq = torch.tensor(cos_acc)
      cos_chunk = cos_seq[-10:]
      mean_cos = cos_chunk.mean().item()
      if mean_cos < 0.9: 
        print('Go taichi simulate')
        # Velocity estimate over the last interpolated sub-frame.
        next_velocity = next_position - interp_pos
        # Densify the sparse state before running MPM.
        next_position_dense = random_knn(next_position, int(n_particles_per_example), n_particles_per_example_dense)
        next_position_dense = torch.tensor(next_position_dense, device=device)
        # NOTE(review): unlike rollout(), there is no None check on this result.
        future_step50, future_velocities, future_step50_sparse = run_simulation_mpm(next_position, next_velocity, next_position_dense)
        future_tensors = [torch.tensor(arr, device=device) for arr in future_step50]
        future_tensors_sparse = [torch.tensor(arr, device=device) for arr in future_step50_sparse]
        predictions += future_tensors
        predictions_vis += future_tensors_sparse
        future_step50 = np.array(future_step50)
        future_step50_sparse = np.array(future_step50_sparse)
        # Re-seed the GNS input window with every N-th of the last
        # N * INPUT_SEQUENCE_LENGTH sparse MPM frames, oldest first.
        # NOTE(review): end_idx comes from the dense frames but indexes the
        # sparse ones -- assumes both have the same number of frames.
        end_idx = future_step50.shape[0]
        indices = torch.arange(end_idx - 1, end_idx - 1 - N * INPUT_SEQUENCE_LENGTH, step=-N)
        indices = indices.flip(0).cpu().numpy()
        current_positions = torch.tensor(future_step50_sparse[indices]).cuda().permute(1,0,2)

        # Restart the stability history after the hand-off.
        cos_acc = []
        accs = []
        # NOTE(review): hard-coded 50 assumes run_simulation_mpm returns exactly
        # 50 frames; otherwise ``step`` and the ground-truth index drift.
        step += 50

        continue
    # Shift `current_positions`, removing the oldest position in the sequence
    # and appending the next position at the end.
    current_positions = torch.cat(
        [current_positions[:, 1:], next_position[:, None, :]], dim=1)

  # Wall-clock time of the whole loop (includes MPM calls and prints).
  end_time = time.time()
  rollout_time = end_time-start_time
  ground_truth_positions_dense = ground_truth_positions_dense[:, :rollout_length]
  # NOTE(review): ``predictions`` stays a Python list (never stacked, unlike
  # rollout()) and mixes sparse and dense frames; run_simulation presumably
  # accepts a per-frame list -- confirm.
  predictions = predictions[:rollout_length]
  # NOTE(review): ``predictions_vis`` is not truncated to rollout_length, so
  # 'predicted_rollout' may be longer than 'ground_truth_rollout'.
  predictions_vis = torch.stack(predictions_vis)
  ground_truth_positions_dense = ground_truth_positions_dense.permute(1, 0, 2)

  grid_m, grid_m_gt = run_simulation(predictions, ground_truth_positions_dense)

  # Element-wise relative squared error; epsilon avoids division by zero.
  epsilon = 1e-6 
  loss = ((grid_m -  grid_m_gt) ** 2)/((grid_m_gt) ** 2 + epsilon)

  # Diverged rollout: report zero loss and count it.
  if torch.isnan(loss).any(): 
    loss = torch.tensor(0.0)
    count_nan += 1

  output_dict = {
      'initial_positions': initial_positions.permute(1, 0, 2).cpu().numpy(),
      'initial_positions_dense': initial_positions_dense.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions_vis.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions_dense.cpu().numpy(),
      'particle_types': particle_types_dense.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss, velocitys, accs, rollout_time, count_nan

def rollout_with_gt(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.tensor,
        particle_types: torch.tensor,
        material_property: torch.tensor,
        n_particles_per_example: torch.tensor,
        nsteps: int,
        device: torch.device):
  """Roll out a trajectory with teacher forcing (ground-truth history).

  At every step the model predicts the next position from the current input
  window, but the window is then advanced with the *ground-truth* frame rather
  than the prediction, so per-step errors do not accumulate.

  Args:
    simulator: Learned simulator; must provide `predict_positions_index`.
    position: Particle positions, indexed as `position[:, t]`, i.e. laid out as
      (nparticles, timesteps, ndims).
    particle_types: Particle types with shape (nparticles).
    material_property: Friction angle normalized by tan() with shape
      (nparticles), or None.
    n_particles_per_example: Number of particles in the example.
    nsteps: Ignored; the number of steps is derived from `position` (kept for
      interface compatibility).
    device: torch device.

  Returns:
    Tuple `(output_dict, loss, velocitys, accs, edge_indexs)`:
      output_dict: initial positions, predicted and ground-truth rollouts
        (time-major), particle types and material property as numpy arrays.
      loss: element-wise relative squared error with the same shape as the
        stacked predictions (time, nparticles, ndims).
      velocitys, accs, edge_indexs: per-step outputs of
        `simulator.predict_positions_index`.
  """
  # Temporal stride between frames fed to the model. With N == 1 there are no
  # interpolated frames.
  N = 1
  # Same stabiliser as `rollout`, so the relative error stays finite when a
  # ground-truth coordinate is exactly zero.
  epsilon = 1e-6

  initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N
  initial_positions = position[:, initial_indices]

  ground_truth_start = initial_indices[-1] + N
  ground_truth_indices = np.arange(ground_truth_start, position.shape[1], N)
  ground_truth_positions = position[:, initial_indices[-1] + 1:]

  # The kinematic mask does not change during the rollout; build it once.
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, position.shape[-1])
  # Fractions for linearly interpolated in-between frames (empty when N == 1).
  interp_steps = np.linspace(0, 1, N + 1)[1:-1]

  current_positions = initial_positions
  predictions = []
  velocitys = []
  accs = []
  edge_indexs = []

  # `nsteps` is recomputed from the available ground truth.
  nsteps = len(ground_truth_indices)
  for step in tqdm(range(nsteps), total=nsteps):
    # Predict the next position with shape (nparticles, ndims).
    next_position, next_velocity, next_acc, edge_index = simulator.predict_positions_index(
        current_positions,
        nparticles_per_example=[n_particles_per_example],
        particle_types=particle_types,
        material_property=material_property
    )
    # Kinematic particles follow their prescribed (ground-truth) trajectory.
    next_position_ground_truth = ground_truth_positions[:, step]
    next_position = torch.where(
        kinematic_mask, next_position_ground_truth, next_position)

    for alpha in interp_steps:
      predictions.append(current_positions[:, -1] * (1 - alpha) + next_position * alpha)

    predictions.append(next_position)
    velocitys.append(next_velocity)
    accs.append(next_acc)
    edge_indexs.append(edge_index)
    # Teacher forcing: drop the oldest frame and append the GROUND-TRUTH frame.
    current_positions = torch.cat(
        [current_positions[:, 1:], ground_truth_positions[:, step:step + 1, :]], dim=1)

  # Predictions with shape (time, nparticles, ndims).
  ground_truth_positions = ground_truth_positions[:, :len(predictions)]
  predictions = torch.stack(predictions)
  ground_truth_positions = ground_truth_positions.permute(1, 0, 2)

  # Relative squared error, stabilised exactly like `rollout`.
  loss = ((predictions - ground_truth_positions) ** 2) / (ground_truth_positions ** 2 + epsilon)

  output_dict = {
      'initial_positions': initial_positions.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions.cpu().numpy(),
      'particle_types': particle_types.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss, velocitys, accs, edge_indexs


def rollout_N(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.tensor,
        particle_types: torch.tensor,
        material_property: torch.tensor,
        n_particles_per_example: torch.tensor,
        nsteps: int,
        device: torch.device):
  """Roll out a trajectory as N interleaved, stride-N autoregressive sub-rollouts.

  For every offset in [0, N), the frames `offset, offset+N, offset+2N, ...` form
  an independent sub-sequence. The model is seeded with the first
  INPUT_SEQUENCE_LENGTH frames of that sub-sequence and then rolled out
  autoregressively, one stride (N frames) per step. The predictions of all
  offsets are written back into a single full-resolution tensor.

  Args:
    simulator: Learned simulator; must provide `predict_positions`.
    position: Particle positions, indexed as `position[:, t]`, i.e. laid out as
      (nparticles, timesteps, ndims).
    particle_types: Particle types with shape (nparticles).
    material_property: Friction angle normalized by tan() with shape
      (nparticles), or None.
    n_particles_per_example: Number of particles in the example.
    nsteps: Ignored; the number of steps is derived from `position` (kept for
      interface compatibility).
    device: torch device.

  Returns:
    Tuple `(output_dict, loss)`, where `loss` is the element-wise relative
    squared error over frames `N*INPUT_SEQUENCE_LENGTH` onward, time-major.
  """
  N = 10
  # Same stabiliser as `rollout`, so the relative error stays finite when a
  # ground-truth coordinate is exactly zero.
  epsilon = 1e-6
  total_predictions = torch.zeros_like(position)
  total_predictions[:, :INPUT_SEQUENCE_LENGTH] = position[:, :INPUT_SEQUENCE_LENGTH]

  # The kinematic mask does not change during the rollout; build it once.
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, position.shape[-1])

  for offset in range(N):
    initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N + offset
    initial_indices = initial_indices[initial_indices < position.shape[1]]
    if len(initial_indices) < INPUT_SEQUENCE_LENGTH:
      continue

    current_positions = position[:, initial_indices]
    # The history is sampled every N frames, so the frame the model predicts is
    # one stride (N) after the last input frame, not one frame after it.
    ground_truth_start = initial_indices[-1] + N
    gt_indices = np.arange(ground_truth_start, position.shape[1], N)

    for target_idx in gt_indices:
      # Next position with shape (nparticles, ndims).
      next_position, _ = simulator.predict_positions(
          current_positions,
          nparticles_per_example=[n_particles_per_example],
          particle_types=particle_types,
          material_property=material_property
      )

      # Kinematic particles follow their prescribed (ground-truth) trajectory.
      next_position_ground_truth = position[:, target_idx].to(device)
      next_position = torch.where(
          kinematic_mask, next_position_ground_truth, next_position)
      total_predictions[:, target_idx] = next_position.detach().to(total_predictions.device)
      # Autoregressive shift: drop the oldest frame, append the prediction.
      current_positions = torch.cat(
          [current_positions[:, 1:], next_position[:, None, :]], dim=1)

  # Every frame >= N*INPUT_SEQUENCE_LENGTH is covered by exactly one offset.
  # Predictions with shape (time, nparticles, ndims).
  predictions = total_predictions[:, N * INPUT_SEQUENCE_LENGTH:].permute(1, 0, 2)
  ground_truth_positions = position[:, N * INPUT_SEQUENCE_LENGTH:].permute(1, 0, 2)
  initial_positions = position[:, :N * INPUT_SEQUENCE_LENGTH]
  loss = ((predictions - ground_truth_positions) ** 2) / (ground_truth_positions ** 2 + epsilon)

  output_dict = {
      'initial_positions': initial_positions.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions.cpu().numpy(),
      'particle_types': particle_types.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss



def predict(device: str):
  """Roll out every trajectory of the evaluation split and report the error.

  Loads the simulator from `FLAGS.model_path + FLAGS.model_file`, runs `rollout`
  on each example of the sparse/dense dataset pair, and appends the per-example
  mean loss and rollout time to `loss_log.txt` / `time_log.txt` inside
  `FLAGS.output_path`. In `rollout` mode each rollout is also pickled there.

  Args:
    device: torch device used for the simulator and the input tensors.

  Raises:
    Exception: if the model file does not exist.
    NotImplementedError: if dataset entries have an unexpected length.
  """
  # Read metadata
  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)

  # Load simulator
  model_file = FLAGS.model_path + FLAGS.model_file
  if not os.path.exists(model_file):
    raise Exception(f"Model does not exist at {model_file}")
  simulator.load(model_file)

  simulator.to(device)
  simulator.eval()

  # Output path
  os.makedirs(FLAGS.output_path, exist_ok=True)

  # Use the `valid` split for eval mode when it exists, otherwise `test`.
  # (Must be an f-string; a plain string never matches an existing file.)
  has_valid_split = os.path.isfile(f"{FLAGS.data_path}valid.npz")
  split = 'test' if (FLAGS.mode == 'rollout' or not has_valid_split) else 'valid'

  # Get dataset (sparse and dense trajectories of the same split)
  ds = data_loader.get_data_loader_by_trajectories(
      path=f"{FLAGS.data_path}{split}.npz",
      path_dense=f"{FLAGS.data_path_dense}{split}.npz")
  # See if our dataset has material property as feature
  if len(ds.dataset._data[0]) == 3:  # `ds` has (positions, particle_type, material_property)
    material_property_as_feature = True
  elif len(ds.dataset._data[0]) == 2:  # `ds` only has (positions, particle_type)
    material_property_as_feature = False
  else:
    raise NotImplementedError

  loss_log_path = os.path.join(FLAGS.output_path, "loss_log.txt")
  time_log_path = os.path.join(FLAGS.output_path, "time_log.txt")

  eval_loss = []
  total_time = []
  count_nan_total = 0
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      positions_dense = features[3].to(device)
      if metadata['sequence_length'] is not None:
        # If `sequence_length` is predefined in metadata,
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        # If no predefined `sequence_length`, then get the sequence length
        nsteps = positions.shape[1] - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      particle_type_dense = features[4].to(device)
      if material_property_as_feature:
        # NOTE(review): features[3] is also read as `positions_dense` above;
        # confirm the dataset layout when material properties are present.
        material_property = features[2].to(device)
        n_particles_per_example = torch.tensor([int(features[3])], dtype=torch.int32).to(device)
      else:
        material_property = None
        n_particles_per_example = torch.tensor([int(features[2])], dtype=torch.int32).to(device)

      # Predict example rollout
      n_particles_per_example_dense = features[5]
      line = features[6]
      example_rollout, loss, _, _, rollout_time, count_nan = rollout(simulator,
                                      positions,
                                      positions_dense,
                                      particle_type,
                                      particle_type_dense,
                                      material_property,
                                      n_particles_per_example,
                                      n_particles_per_example_dense,
                                      nsteps,
                                      device,
                                      line=line)
      if example_rollout is None:
        continue
      count_nan_total += count_nan
      total_time.append(rollout_time)
      print("rollout_time: {}".format(rollout_time))
      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      # `rollout` substitutes a CPU scalar when the loss is NaN, so keep all
      # collected losses on CPU to allow concatenation across examples.
      eval_loss.append(torch.flatten(loss).cpu())
      with open(loss_log_path, "a") as f:
        f.write(f"{loss.mean().item()}\n")
      with open(time_log_path, "a") as f:
        f.write(f"{rollout_time}\n")

      # Save rollout in testing
      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = f'{FLAGS.output_filename}_ex{example_i}.pkl'
        filename = os.path.join(FLAGS.output_path, filename)
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

  print("Nan times: {}".format(count_nan_total))
  if not total_time:
    # Every example was skipped; avoid dividing by zero / concatenating nothing.
    print("No rollouts were evaluated.")
    return
  print("Average_time: {}".format(sum(total_time) / len(total_time)))
  print("Mean loss on rollout prediction: {}".format(
      torch.mean(torch.cat(eval_loss))))


def predict_acc_cosine(device: str):
  """Roll out each trajectory and relate the loss curve to acceleration smoothness.

  For every example this function does the following:
  - Computes the particle-averaged cosine similarity between consecutive
    predicted accelerations.
  - Plots it together with the min-max normalized per-step rollout loss.
  - Saves a scatter plot of their means over windows of 10 steps.
  Figures go to `loss_curve_acc_rollout_xy/<last component of FLAGS.output_path>/`.

  Args:
    device: torch device used for the simulator and the input tensors.

  Raises:
    Exception: if the model file does not exist.
    NotImplementedError: if dataset entries have an unexpected length.
  """

  def _min_max_normalize(values):
    # Rescale to [0, 1]; a constant sequence maps to zeros instead of NaN.
    shifted = values - values.min()
    span = shifted.max()
    return shifted / span if span > 0 else shifted

  # Read metadata
  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)

  # Load simulator
  model_file = FLAGS.model_path + FLAGS.model_file
  if not os.path.exists(model_file):
    raise Exception(f"Model does not exist at {model_file}")
  simulator.load(model_file)

  simulator.to(device)
  simulator.eval()

  # Output path
  os.makedirs(FLAGS.output_path, exist_ok=True)

  # Use the `valid` split for eval mode when it exists, otherwise `test`.
  has_valid_split = os.path.isfile(f"{FLAGS.data_path}valid.npz")
  split = 'test' if (FLAGS.mode == 'rollout' or not has_valid_split) else 'valid'

  # Get dataset (sparse and dense trajectories of the same split)
  ds = data_loader.get_data_loader_by_trajectories(
      path=f"{FLAGS.data_path}{split}.npz",
      path_dense=f"{FLAGS.data_path_dense}{split}.npz")
  # See if our dataset has material property as feature
  if len(ds.dataset._data[0]) == 3:  # `ds` has (positions, particle_type, material_property)
    material_property_as_feature = True
  elif len(ds.dataset._data[0]) == 2:  # `ds` only has (positions, particle_type)
    material_property_as_feature = False
  else:
    raise NotImplementedError

  path_name = FLAGS.output_path.strip('/').split('/')[-1]
  plot_dir = f"loss_curve_acc_rollout_xy/{path_name}"
  window_size = 10

  eval_loss = []
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      positions_dense = features[3].to(device)
      if metadata['sequence_length'] is not None:
        # If `sequence_length` is predefined in metadata,
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        # If no predefined `sequence_length`, then get the sequence length
        nsteps = positions.shape[1] - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      particle_type_dense = features[4].to(device)
      if material_property_as_feature:
        # NOTE(review): features[3] is also read as `positions_dense` above;
        # confirm the dataset layout when material properties are present.
        material_property = features[2].to(device)
        n_particles_per_example = torch.tensor([int(features[3])], dtype=torch.int32).to(device)
      else:
        material_property = None
        n_particles_per_example = torch.tensor([int(features[2])], dtype=torch.int32).to(device)

      n_particles_per_example_dense = features[5]
      # Predict example rollout. `rollout` returns six values:
      # (output_dict, loss, velocities, accelerations, rollout_time, count_nan).
      example_rollout, loss, _, acc, _, _ = rollout(simulator,
                                      positions,
                                      positions_dense,
                                      particle_type,
                                      particle_type_dense,
                                      material_property,
                                      n_particles_per_example,
                                      n_particles_per_example_dense,
                                      nsteps,
                                      device,
                                      None)
      if example_rollout is None:
        continue

      # Particle-averaged cosine similarity between consecutive accelerations;
      # entry t-1 compares steps t-1 and t.
      cos_acc = [
          torch.mean(torch.nn.functional.cosine_similarity(acc[t], acc[t - 1], dim=1)).item()
          for t in range(1, len(acc))
      ]
      # NOTE(review): assumes `loss` is at least 3-D with time on axis 0.
      loss_per_timestep = loss.mean(dim=(1, 2)).cpu().detach()
      normalized_loss = _min_max_normalize(loss_per_timestep)

      # Loss curve vs acceleration similarity (loss shifted by one step to
      # line up with `cos_acc`).
      plt.figure(figsize=(10, 6))
      plt.plot((normalized_loss[1:]).numpy(), label='MSE Loss', color='blue', alpha=0.7)
      plt.plot(cos_acc, label='Acceleration Cosine Similarity', color='orange', alpha=0.7)
      plt.xlabel("Time Step")
      plt.ylabel("Loss")
      plt.title("Loss over Time Steps")
      plt.grid(True)
      plt.legend()
      os.makedirs(plot_dir, exist_ok=True)
      plt.savefig(f"{plot_dir}/{example_i}.png", dpi=300, bbox_inches="tight")
      plt.close()

      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      eval_loss.append(torch.flatten(loss).cpu())

      # Save rollout in testing
      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = f'{FLAGS.output_filename}_ex{example_i}.pkl'
        filename = os.path.join(FLAGS.output_path, filename)
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

      # Windowed means of (cosine similarity, normalized loss).
      loss_seq = normalized_loss[1:].detach().cpu()
      cos_seq = torch.tensor(cos_acc, dtype=loss_seq.dtype)
      length = len(loss_seq)
      mean_points = []
      for start in range(0, length, window_size):
        end = min(start + window_size, length)
        mean_points.append((cos_seq[start:end].mean().item(),
                            loss_seq[start:end].mean().item()))

      if not mean_points:
        # Too few steps to form a single window; nothing to scatter.
        continue
      x_vals, y_vals = zip(*mean_points)
      plt.figure(figsize=(8, 6))
      plt.scatter(x_vals, y_vals, color='blue')
      plt.xlabel('Mean Cosine Accuracy (per window)')
      plt.ylabel('Mean Normalized Loss (per window)')
      plt.title('Scatter Plot of Cosine Accuracy vs Normalized Loss')
      plt.grid(True)
      plt.savefig(f"{plot_dir}/{example_i}_correlations.png", dpi=300, bbox_inches="tight")
      plt.close()

  if not eval_loss:
    print("No rollouts were evaluated.")
    return
  print("Mean loss on rollout prediction: {}".format(
      torch.mean(torch.cat(eval_loss))))


def predict_velocity_divergence_KD(device: str):
  """Relate the teacher-forced rollout loss to a KD-tree velocity-divergence estimate.

  For every example this function does the following:
  - Runs `rollout_with_gt`.
  - Estimates, per predicted frame, the mean squared velocity divergence. Each
    particle's velocity gradient is obtained by a least-squares fit over its
    nearest neighbours, found with a KD-tree.
  - Plots it against the min-max normalized per-step loss.
  - Saves a scatter plot of their means over windows of 10 steps.
  Figures go to
  `loss_curve_velocity_divergence_kd/<last component of FLAGS.output_path>/`.

  Args:
    device: torch device used for the simulator and the input tensors.

  Raises:
    Exception: if the model file does not exist.
    NotImplementedError: if dataset entries have an unexpected length.
  """
  from scipy.spatial import KDTree

  def _min_max_normalize(values):
    # Rescale to [0, 1]; a constant sequence maps to zeros instead of NaN.
    shifted = values - values.min()
    span = shifted.max()
    return shifted / span if span > 0 else shifted

  # Read metadata
  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)

  # Load simulator
  model_file = FLAGS.model_path + FLAGS.model_file
  if not os.path.exists(model_file):
    raise Exception(f"Model does not exist at {model_file}")
  simulator.load(model_file)

  simulator.to(device)
  simulator.eval()

  # Output path
  os.makedirs(FLAGS.output_path, exist_ok=True)

  # Use the `valid` split for eval mode when it exists, otherwise `test`.
  has_valid_split = os.path.isfile(f"{FLAGS.data_path}valid.npz")
  split = 'test' if (FLAGS.mode == 'rollout' or not has_valid_split) else 'valid'

  # Get dataset
  ds = data_loader.get_data_loader_by_trajectories(path=f"{FLAGS.data_path}{split}.npz")
  # See if our dataset has material property as feature
  if len(ds.dataset._data[0]) == 3:  # `ds` has (positions, particle_type, material_property)
    material_property_as_feature = True
  elif len(ds.dataset._data[0]) == 2:  # `ds` only has (positions, particle_type)
    material_property_as_feature = False
  else:
    raise NotImplementedError

  path_name = FLAGS.output_path.strip('/').split('/')[-1]
  plot_dir = f"loss_curve_velocity_divergence_kd/{path_name}"
  window_size = 10
  max_neighbors = 5  # includes the query particle itself

  eval_loss = []
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      if metadata['sequence_length'] is not None:
        # If `sequence_length` is predefined in metadata,
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        # If no predefined `sequence_length`, then get the sequence length
        nsteps = positions.shape[1] - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      if material_property_as_feature:
        material_property = features[2].to(device)
        n_particles_per_example = torch.tensor([int(features[3])], dtype=torch.int32).to(device)
      else:
        material_property = None
        n_particles_per_example = torch.tensor([int(features[2])], dtype=torch.int32).to(device)

      # Predict example rollout (teacher forcing)
      example_rollout, loss, velocities, _, _ = rollout_with_gt(simulator,
                                      positions,
                                      particle_type,
                                      material_property,
                                      n_particles_per_example,
                                      nsteps,
                                      device)

      frame_losses = []
      num_frames, num_particles, dim = example_rollout['predicted_rollout'].shape
      # Never ask the KD-tree for more neighbours than there are particles.
      k_neighbors = min(max_neighbors, num_particles)
      for t in range(num_frames):
        pos = example_rollout['predicted_rollout'][t]
        vel = velocities[t].cpu().numpy()

        # One batched neighbour query per frame; reshape so k == 1 still
        # yields a 2-D (num_particles, k) table.
        _, neighbor_table = KDTree(pos).query(pos, k=k_neighbors)
        neighbor_table = np.asarray(neighbor_table).reshape(num_particles, -1)

        divergences = np.zeros(num_particles)
        for i in range(num_particles):
          neighbor_indices = neighbor_table[i]
          relative_pos = pos[neighbor_indices] - pos[i]
          relative_vel = vel[neighbor_indices] - vel[i]
          if len(relative_pos) < 2:
            continue  # not enough neighbours; divergence stays 0
          # Least-squares velocity gradients of the first two components.
          # NOTE(review): only du/dx + dv/dy is used, i.e. this assumes 2-D data.
          grad_u, _, _, _ = np.linalg.lstsq(relative_pos, relative_vel[:, 0], rcond=None)
          grad_v, _, _, _ = np.linalg.lstsq(relative_pos, relative_vel[:, 1], rcond=None)
          divergences[i] = grad_u[0] + grad_v[1]

        frame_losses.append(np.mean(divergences ** 2))

      normalized_frame_loss = _min_max_normalize(np.array(frame_losses))
      loss_per_timestep = loss.mean(dim=(1, 2)).cpu().detach()
      normalized_loss = _min_max_normalize(loss_per_timestep)

      plt.figure(figsize=(10, 6))
      plt.plot(normalized_loss.numpy(), label='MSE Loss', color='blue', alpha=0.7)
      plt.plot(normalized_frame_loss, label='velocity divergence loss', color='orange', alpha=0.7)
      plt.xlabel("Time Step")
      plt.ylabel("Loss")
      plt.title("Loss over Time Steps")
      plt.grid(True)
      plt.legend()
      os.makedirs(plot_dir, exist_ok=True)
      plt.savefig(f"{plot_dir}/{example_i}.png", dpi=300, bbox_inches="tight")
      plt.close()

      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      eval_loss.append(torch.flatten(loss))

      # Save rollout in testing
      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = f'{FLAGS.output_filename}_ex{example_i}.pkl'
        filename = os.path.join(FLAGS.output_path, filename)
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

      # Windowed means. Both series have one entry per predicted frame, so they
      # are compared frame-for-frame (no one-step shift, unlike the
      # acceleration-cosine variant whose series starts at t=1).
      loss_seq = normalized_loss.detach().cpu()
      div_seq = torch.as_tensor(normalized_frame_loss, dtype=loss_seq.dtype)
      length = len(loss_seq)
      mean_points = []
      for start in range(0, length, window_size):
        end = min(start + window_size, length)
        mean_points.append((div_seq[start:end].mean().item(),
                            loss_seq[start:end].mean().item()))

      if not mean_points:
        continue
      x_vals, y_vals = zip(*mean_points)
      plt.figure(figsize=(8, 6))
      plt.scatter(x_vals, y_vals, color='blue')
      plt.xlabel('Mean velocity divergence KD (per window)')
      plt.ylabel('Mean Normalized Loss (per window)')
      plt.title('Scatter Plot of Velocity Divergence (KD) vs Normalized Loss')
      plt.grid(True)
      plt.savefig(f"{plot_dir}/{example_i}_correlations.png", dpi=300, bbox_inches="tight")
      plt.close()

  if not eval_loss:
    print("No rollouts were evaluated.")
    return
  print("Mean loss on rollout prediction: {}".format(
      torch.mean(torch.cat(eval_loss))))


def predict_velocity_divergence_edge(device: str):
  """Relate the teacher-forced rollout loss to a graph-based velocity-divergence estimate.

  Same analysis as `predict_velocity_divergence_KD`, except that each particle's
  neighbours are taken from the model's own `edge_index` for that step:
  particles `edge_index[1][edge_index[0] == i]`.
  Figures go to
  `loss_curve_velocity_divergence_edge/<last component of FLAGS.output_path>/`.

  Args:
    device: torch device used for the simulator and the input tensors.

  Raises:
    Exception: if the model file does not exist.
    NotImplementedError: if dataset entries have an unexpected length.
  """

  def _min_max_normalize(values):
    # Rescale to [0, 1]; a constant sequence maps to zeros instead of NaN.
    shifted = values - values.min()
    span = shifted.max()
    return shifted / span if span > 0 else shifted

  # Read metadata
  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)

  # Load simulator
  model_file = FLAGS.model_path + FLAGS.model_file
  if not os.path.exists(model_file):
    raise Exception(f"Model does not exist at {model_file}")
  simulator.load(model_file)

  simulator.to(device)
  simulator.eval()

  # Output path
  os.makedirs(FLAGS.output_path, exist_ok=True)

  # Use the `valid` split for eval mode when it exists, otherwise `test`.
  has_valid_split = os.path.isfile(f"{FLAGS.data_path}valid.npz")
  split = 'test' if (FLAGS.mode == 'rollout' or not has_valid_split) else 'valid'

  # Get dataset
  ds = data_loader.get_data_loader_by_trajectories(path=f"{FLAGS.data_path}{split}.npz")
  # See if our dataset has material property as feature
  if len(ds.dataset._data[0]) == 3:  # `ds` has (positions, particle_type, material_property)
    material_property_as_feature = True
  elif len(ds.dataset._data[0]) == 2:  # `ds` only has (positions, particle_type)
    material_property_as_feature = False
  else:
    raise NotImplementedError

  path_name = FLAGS.output_path.strip('/').split('/')[-1]
  plot_dir = f"loss_curve_velocity_divergence_edge/{path_name}"
  window_size = 10

  eval_loss = []
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      if metadata['sequence_length'] is not None:
        # If `sequence_length` is predefined in metadata,
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        # If no predefined `sequence_length`, then get the sequence length
        nsteps = positions.shape[1] - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      if material_property_as_feature:
        material_property = features[2].to(device)
        n_particles_per_example = torch.tensor([int(features[3])], dtype=torch.int32).to(device)
      else:
        material_property = None
        n_particles_per_example = torch.tensor([int(features[2])], dtype=torch.int32).to(device)

      # Predict example rollout (teacher forcing)
      example_rollout, loss, velocities, _, edge_indexs = rollout_with_gt(simulator,
                                      positions,
                                      particle_type,
                                      material_property,
                                      n_particles_per_example,
                                      nsteps,
                                      device)

      frame_losses = []
      num_frames, num_particles, dim = example_rollout['predicted_rollout'].shape
      particle_ids = np.arange(num_particles)
      for t in range(num_frames):
        pos = example_rollout['predicted_rollout'][t]
        vel = velocities[t].cpu().numpy()
        edge_index = edge_indexs[t].cpu().numpy()

        # Group edge_index[1] by edge_index[0] once per frame instead of
        # scanning every edge for every particle (O(E log E) vs O(N*E)).
        # A stable sort keeps each group in its original edge order.
        order = np.argsort(edge_index[0], kind="stable")
        sorted_first = edge_index[0][order]
        sorted_second = edge_index[1][order]
        group_start = np.searchsorted(sorted_first, particle_ids, side="left")
        group_end = np.searchsorted(sorted_first, particle_ids, side="right")

        divergences = np.zeros(num_particles)
        for i in range(num_particles):
          neighbor_indices = sorted_second[group_start[i]:group_end[i]]
          # Need at least dim + 1 neighbours for a determined fit; otherwise 0.
          if len(neighbor_indices) < dim + 1:
            continue

          relative_pos = pos[neighbor_indices] - pos[i]      # (k, dim)
          relative_vel = vel[neighbor_indices] - vel[i]      # (k, dim)
          # Least-squares velocity gradients of the first two components.
          # NOTE(review): only du/dx + dv/dy is used, i.e. this assumes 2-D data.
          grad_u, _, _, _ = np.linalg.lstsq(relative_pos, relative_vel[:, 0], rcond=None)
          grad_v, _, _, _ = np.linalg.lstsq(relative_pos, relative_vel[:, 1], rcond=None)
          divergences[i] = grad_u[0] + grad_v[1]

        frame_losses.append(np.mean(divergences ** 2))

      normalized_frame_loss = _min_max_normalize(np.array(frame_losses))
      loss_per_timestep = loss.mean(dim=(1, 2)).cpu().detach()
      normalized_loss = _min_max_normalize(loss_per_timestep)

      plt.figure(figsize=(10, 6))
      plt.plot(normalized_loss.numpy(), label='MSE Loss', color='blue', alpha=0.7)
      plt.plot(normalized_frame_loss, label='velocity divergence loss', color='orange', alpha=0.7)
      plt.xlabel("Time Step")
      plt.ylabel("Loss")
      plt.title("Loss over Time Steps")
      plt.grid(True)
      plt.legend()
      os.makedirs(plot_dir, exist_ok=True)
      plt.savefig(f"{plot_dir}/{example_i}.png", dpi=300, bbox_inches="tight")
      plt.close()

      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      eval_loss.append(torch.flatten(loss))

      # Save rollout in testing
      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = f'{FLAGS.output_filename}_ex{example_i}.pkl'
        filename = os.path.join(FLAGS.output_path, filename)
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

      # Windowed means. Both series have one entry per predicted frame, so they
      # are compared frame-for-frame (no one-step shift).
      loss_seq = normalized_loss.detach().cpu()
      div_seq = torch.as_tensor(normalized_frame_loss, dtype=loss_seq.dtype)
      length = len(loss_seq)
      mean_points = []
      for start in range(0, length, window_size):
        end = min(start + window_size, length)
        mean_points.append((div_seq[start:end].mean().item(),
                            loss_seq[start:end].mean().item()))

      if not mean_points:
        continue
      x_vals, y_vals = zip(*mean_points)
      plt.figure(figsize=(8, 6))
      plt.scatter(x_vals, y_vals, color='blue')
      plt.xlabel('Mean velocity divergence edge (per window)')
      plt.ylabel('Mean Normalized Loss (per window)')
      plt.title('Scatter Plot of Velocity Divergence (edge) vs Normalized Loss')
      plt.grid(True)
      plt.savefig(f"{plot_dir}/{example_i}_correlations.png", dpi=300, bbox_inches="tight")
      plt.close()

  if not eval_loss:
    print("No rollouts were evaluated.")
    return
  print("Mean loss on rollout prediction: {}".format(
      torch.mean(torch.cat(eval_loss))))


def optimizer_to(optim, device):
  """Move the tensors stored in an optimizer's state to ``device`` in place.

  Each value of ``optim.state`` is handled as follows:
  - A tensor value is moved directly.
  - A dict value has each of its tensor entries moved.
  Any non-None ``_grad`` on a moved tensor is moved as well. Other values are
  left untouched.

  Args:
    optim: a ``torch.optim`` optimizer whose state should be relocated.
    device: target device (anything accepted by ``Tensor.to``).
  """
  def _relocate(tensor):
    # Re-point .data (and the gradient's .data) so the same tensor objects
    # referenced by the optimizer now live on the target device.
    tensor.data = tensor.data.to(device)
    if tensor._grad is not None:
      tensor._grad.data = tensor._grad.data.to(device)

  for entry in optim.state.values():
    # Not sure there are any global tensors in the state dict
    if isinstance(entry, torch.Tensor):
      _relocate(entry)
      continue
    if not isinstance(entry, dict):
      continue
    for value in entry.values():
      if isinstance(value, torch.Tensor):
        _relocate(value)

def acceleration_loss(pred_acc, target_acc, non_kinematic_mask):
  """
  Compute the loss between predicted and target accelerations.

  The squared error is summed over the last dimension, which holds the
  acceleration components. It is then averaged over the particles selected
  by ``non_kinematic_mask``.

  Args:
    pred_acc: Predicted accelerations.
    target_acc: Target accelerations (same shape as ``pred_acc``).
    non_kinematic_mask: Per-particle mask that is truthy for the particles
      whose error should count (the NON-kinematic particles).

  Returns:
    Scalar tensor: sum of per-particle squared errors over masked-in
    particles divided by ``non_kinematic_mask.sum()``. Returns 0 (with zero
    gradients) instead of NaN when no particle is masked in.
  """
  loss = (pred_acc - target_acc) ** 2
  loss = loss.sum(dim=-1)
  loss = torch.where(non_kinematic_mask.bool(),
                     loss, torch.zeros_like(loss))
  num_non_kinematic = non_kinematic_mask.sum()
  # A batch with only kinematic particles would otherwise compute 0/0 = NaN.
  # The NaN would also propagate through backward() into the optimizer state.
  # Replacing a zero count with 1 yields loss 0 (the numerator is already 0).
  # Non-zero counts are left exactly as before.
  num_non_kinematic = torch.where(num_non_kinematic == 0,
                                  torch.ones_like(num_non_kinematic),
                                  num_non_kinematic)
  loss = loss.sum() / num_non_kinematic
  return loss

def save_model_and_train_state(rank, device, simulator, flags, step, epoch, optimizer,
                                train_loss, valid_loss, train_loss_hist, valid_loss_hist):
  """Write the model checkpoint and matching train state for ``step``.

  Only rank 0 or a CPU run writes anything; other ranks return immediately.
  Two files are written under ``flags["model_path"]``:
  - ``model-{step}.pt``, through the simulator's ``save`` method (on the
    ``.module`` of the wrapper when not on CPU);
  - ``train_state-{step}.pt``, through ``torch.save``. It holds the optimizer
    state, the current step/epoch/losses and the loss histories.

  Args:
    rank: local rank
    device: torch device type
    simulator: Trained simulator if not will undergo training.
    flags: flags
    step: step
    epoch: epoch
    optimizer: optimizer
    train_loss: training loss at current step
    valid_loss: validation loss at current step
    train_loss_hist: training loss history at each epoch
    valid_loss_hist: validation loss history at each epoch
  """
  on_cpu = device == torch.device("cpu")
  if rank != 0 and not on_cpu:
    return

  checkpoint_dir = flags["model_path"]

  # Off CPU the simulator is wrapped (DDP); the savable model sits on .module.
  savable = simulator if on_cpu else simulator.module
  savable.save(checkpoint_dir + 'model-' + str(step) + '.pt')

  progress = {
      "step": step,
      "epoch": epoch,
      "train_loss": train_loss,
      "valid_loss": valid_loss,
  }
  history = {"train": train_loss_hist, "valid": valid_loss_hist}
  train_state = dict(optimizer_state=optimizer.state_dict(),
                     global_train_state=progress,
                     loss_history=history)
  torch.save(train_state, f'{checkpoint_dir}train_state-{step}.pt')
      
def train(rank, flags, world_size, device):
  """Train the model.

  Setup:
  - Builds the simulator and an Adam optimizer.
  - Optionally resumes from a saved model/train-state pair, or from the
    newest pair when both file flags are "latest".

  Training loop:
  - Runs until ``flags["ntraining_steps"]`` steps have been taken.
  - Writes a checkpoint every ``flags["nsave_steps"]`` steps.
  - Writes a final checkpoint on exit, including after a KeyboardInterrupt.

  Args:
    rank: local rank (``main`` passes ``None`` for CPU training).
    flags: dict of run options (paths, learning-rate schedule, step counts,
      validation interval, ...).
    world_size: total number of ranks
    device: torch device type

  Raises:
    FileNotFoundError: if ``flags["model_file"]`` is set but the model or
      train-state file is missing under ``flags["model_path"]``.
    NotImplementedError: if examples have neither 2 nor 3 feature entries.
  """
  if device == torch.device("cuda"):
    distribute.setup(rank, world_size, device)
    device_id = rank
  else:
    device_id = device

  # Read metadata
  metadata = reading_utils.read_metadata(flags["data_path"], "train")

  # Get simulator and optimizer
  if device == torch.device("cuda"):
    serial_simulator = _get_simulator(metadata, flags["noise_std"], flags["noise_std"], rank)
    simulator = DDP(serial_simulator.to(rank), device_ids=[rank], output_device=rank)
    optimizer = torch.optim.Adam(simulator.parameters(), lr=flags["lr_init"]*world_size)
  else:
    simulator = _get_simulator(metadata, flags["noise_std"], flags["noise_std"], device)
    optimizer = torch.optim.Adam(simulator.parameters(), lr=flags["lr_init"] * world_size)

  # Initialize training state
  step = 0
  epoch = 0
  steps_per_epoch = 0

  # The final checkpoint save below reads `train_loss`. It must therefore be
  # bound even if no training step executes, e.g. when resuming an already
  # finished run or after a KeyboardInterrupt before the first step.
  train_loss = None
  valid_loss = None
  epoch_train_loss = 0
  epoch_valid_loss = None

  train_loss_hist = []
  valid_loss_hist = []

  # If model_path does exist and model_file and train_state_file exist continue training.
  if flags["model_file"] is not None:

    if flags["model_file"] == "latest" and flags["train_state_file"] == "latest":
      # find the latest model, assumes model and train_state files are in step.
      fnames = glob.glob(f'{flags["model_path"]}*model*pt')
      max_model_number = 0
      # Raw string so "\d" is a regex escape, not an invalid string escape.
      # The escaped "." matches only a literal dot.
      expr = re.compile(r".*model-(\d+)\.pt$")
      for fname in fnames:
        match = expr.search(fname)
        # The glob also matches names without a step number (e.g. "mymodel.pt").
        if match is None:
          continue
        max_model_number = max(max_model_number, int(match.group(1)))
      # reset names to point to the latest.
      flags["model_file"] = f"model-{max_model_number}.pt"
      flags["train_state_file"] = f"train_state-{max_model_number}.pt"

    if os.path.exists(flags["model_path"] + flags["model_file"]) and os.path.exists(flags["model_path"] + flags["train_state_file"]):
      # load model
      if device == torch.device("cuda"):
        simulator.module.load(flags["model_path"] + flags["model_file"])
      else:
        simulator.load(flags["model_path"] + flags["model_file"])

      # load train state
      train_state = torch.load(flags["model_path"] + flags["train_state_file"])

      # set optimizer state (the learning rate is restored from the state dict)
      optimizer = torch.optim.Adam(
        simulator.module.parameters() if device == torch.device("cuda") else simulator.parameters())
      optimizer.load_state_dict(train_state["optimizer_state"])
      optimizer_to(optimizer, device_id)

      # set global train state
      step = train_state["global_train_state"]["step"]
      epoch = train_state["global_train_state"]["epoch"]
      train_loss_hist = train_state["loss_history"]["train"]
      valid_loss_hist = train_state["loss_history"]["valid"]

    else:
      msg = f'Specified model_file {flags["model_path"] + flags["model_file"]} and train_state_file {flags["model_path"] + flags["train_state_file"]} not found.'
      raise FileNotFoundError(msg)

  simulator.train()
  simulator.to(device_id)

  # Get data loader
  get_data_loader = (
    distribute.get_data_distributed_dataloader_by_samples
    if device == torch.device("cuda")
    else data_loader.get_data_loader_by_samples
  )

  # Load training data
  dl = get_data_loader(
      path=f'{flags["data_path"]}train.npz',
      input_length_sequence=INPUT_SEQUENCE_LENGTH,
      batch_size=flags["batch_size"],
  )
  n_features = len(dl.dataset._data[0])

  # Load validation data
  if flags["validation_interval"] is not None:
    dl_valid = get_data_loader(
        path=f'{flags["data_path"]}valid.npz',
        input_length_sequence=INPUT_SEQUENCE_LENGTH,
        batch_size=flags["batch_size"],
    )
    if len(dl_valid.dataset._data[0]) != n_features:
      raise ValueError(
          f"`n_features` of `valid.npz` and `train.npz` should be the same"
      )

  print(f"rank = {rank}, cuda = {torch.cuda.is_available()}")

  try:
    while step < flags["ntraining_steps"]:
      if device == torch.device("cuda"):
        torch.distributed.barrier()

      for example in dl:
        steps_per_epoch += 1
        # ((position, particle_type, material_property, n_particles_per_example), labels) are in dl
        position = example[0][0].to(device_id)
        particle_type = example[0][1].to(device_id)
        if n_features == 3:  # if dl includes material_property
          material_property = example[0][2].to(device_id)
          n_particles_per_example = example[0][3].to(device_id)
        elif n_features == 2:
          n_particles_per_example = example[0][2].to(device_id)
        else:
          raise NotImplementedError
        labels = example[1].to(device_id)

        # TODO (jpv): Move noise addition to data_loader
        # Sample the noise to add to the inputs to the model during training.
        sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(position, noise_std_last_step=flags["noise_std"]).to(device_id)
        non_kinematic_mask = (particle_type != KINEMATIC_PARTICLE_ID).clone().detach().to(device_id)
        sampled_noise *= non_kinematic_mask.view(-1, 1, 1)

        # Get the predictions and target accelerations.
        # Under DDP the underlying simulator is reached through `.module`.
        device_or_rank = rank if device == torch.device("cuda") else device
        predict_accelerations = (
            simulator.module.predict_accelerations
            if device == torch.device("cuda")
            else simulator.predict_accelerations)
        pred_acc, target_acc = predict_accelerations(
            next_positions=labels.to(device_or_rank),
            position_sequence_noise=sampled_noise.to(device_or_rank),
            position_sequence=position.to(device_or_rank),
            nparticles_per_example=n_particles_per_example.to(device_or_rank),
            particle_types=particle_type.to(device_or_rank),
            material_property=material_property.to(device_or_rank) if n_features == 3 else None
        )

        # Validation: draw a validation batch only on steps that actually use it.
        # Previously a fresh iterator was built (and a batch loaded) every step.
        if (flags["validation_interval"] is not None
            and step > 0 and step % flags["validation_interval"] == 0):
          sampled_valid_example = next(iter(dl_valid))
          valid_loss = validation(
            simulator, sampled_valid_example, n_features, flags, rank, device_id)
          print(f"Validation loss at {step}: {valid_loss.item()}")

        # Calculate the loss and mask out loss on kinematic particles
        loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)

        train_loss = loss.item()
        epoch_train_loss += train_loss

        # Computes the gradient of loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update learning rate (exponential decay, scaled by world_size)
        lr_new = flags["lr_init"] * (flags["lr_decay"] ** (step/flags["lr_decay_steps"])) * world_size
        for param in optimizer.param_groups:
          param['lr'] = lr_new

        print(f'rank = {rank}, epoch = {epoch}, step = {step}/{flags["ntraining_steps"]}, loss = {train_loss}', flush=True)

        # Save model state
        if rank == 0 or device == torch.device("cpu"):
          if step % flags["nsave_steps"] == 0:
            save_model_and_train_state(rank, device, simulator, flags, step, epoch,
                                       optimizer, train_loss, valid_loss, train_loss_hist, valid_loss_hist)

        step += 1
        if step >= flags["ntraining_steps"]:
          break

      # Epoch level statistics
      # Training loss at epoch
      epoch_train_loss /= steps_per_epoch
      epoch_train_loss = torch.tensor([epoch_train_loss]).to(device_id)
      if device == torch.device("cuda"):
        torch.distributed.reduce(epoch_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
        epoch_train_loss /= world_size

      train_loss_hist.append((epoch, epoch_train_loss.item()))

      # Validation loss at epoch
      if flags["validation_interval"] is not None:
        sampled_valid_example = next(iter(dl_valid))
        epoch_valid_loss = validation(
                simulator, sampled_valid_example, n_features, flags, rank, device_id)
        if device == torch.device("cuda"):
          torch.distributed.reduce(epoch_valid_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
          epoch_valid_loss /= world_size

        valid_loss_hist.append((epoch, epoch_valid_loss.item()))

      # Print epoch statistics
      if rank == 0 or device == torch.device("cpu"):
        print(f'Epoch {epoch}, training loss: {epoch_train_loss.item()}')
        if flags["validation_interval"] is not None:
          print(f'Epoch {epoch}, validation loss: {epoch_valid_loss.item()}')

      # Reset epoch training loss
      epoch_train_loss = 0
      if steps_per_epoch >= len(dl):
        epoch += 1
      steps_per_epoch = 0

      if step >= flags["ntraining_steps"]:
        break

  except KeyboardInterrupt:
    pass

  # Save model state on exit (normal completion or keyboard interrupt)
  save_model_and_train_state(rank, device, simulator, flags, step, epoch, optimizer, train_loss, valid_loss, train_loss_hist, valid_loss_hist)

  # Tear down the process group only on the path that set it up.
  if device == torch.device("cuda"):
    distribute.cleanup()


def _get_simulator(
        metadata: dict,
        acc_noise_std: float,
        vel_noise_std: float,
        device: torch.device) -> learned_simulator.LearnedSimulator:
  """Instantiates the simulator.

  Args:
    metadata: Dataset metadata mapping.
      Required keys: 'acc_mean', 'acc_std', 'vel_mean', 'vel_std', 'dim',
      'default_connectivity_radius', 'bounds'.
      Optional keys: 'nnode_in', 'nedge_in', 'boundary_augment'.
    acc_noise_std: Acceleration noise std deviation. It is combined in
      quadrature with ``metadata['acc_std']`` for normalization.
    vel_noise_std: Velocity noise std deviation. It is combined in
      quadrature with ``metadata['vel_std']`` for normalization.
    device: PyTorch device 'cpu' or 'cuda'. The CUDA training path in this
      file passes an integer rank instead.

  Returns:
    A ``learned_simulator.LearnedSimulator`` configured from ``metadata``.
  """
  def _stats(mean_key, std_key, noise_std):
    # Normalization std is sqrt(data_std^2 + noise_std^2).
    return {
        'mean': torch.FloatTensor(metadata[mean_key]).to(device),
        'std': torch.sqrt(torch.FloatTensor(metadata[std_key])**2 +
                          noise_std**2).to(device),
    }

  # Normalization stats
  normalization_stats = {
      'acceleration': _stats('acc_mean', 'acc_std', acc_noise_std),
      'velocity': _stats('vel_mean', 'vel_std', vel_noise_std),
  }

  # Get necessary parameters for loading simulator.
  if "nnode_in" in metadata and "nedge_in" in metadata:
    nnode_in = metadata['nnode_in']
    nedge_in = metadata['nedge_in']
  else:
    # Given that there is no additional node feature (e.g., material_property) except for:
    # (position (dim), velocity (dim*6), particle_type (16)),
    nnode_in = 37 if metadata['dim'] == 3 else 30
    nedge_in = metadata['dim'] + 1

  # Init simulator.
  simulator = learned_simulator.LearnedSimulator(
      particle_dimensions=metadata['dim'],
      nnode_in=nnode_in,
      nedge_in=nedge_in,
      latent_dim=128,
      nmessage_passing_steps=10,
      nmlp_layers=2,
      mlp_hidden_dim=128,
      connectivity_radius=metadata['default_connectivity_radius'],
      boundaries=np.array(metadata['bounds']),
      normalization_stats=normalization_stats,
      nparticle_types=NUM_PARTICLE_TYPES,
      particle_type_embedding_size=16,
      boundary_clamp_limit=metadata.get("boundary_augment", 1.0),
      device=device)

  return simulator

def validation(
        simulator,
        example,
        n_features,
        flags,
        rank,
        device_id):
  """Return the acceleration loss of ``simulator`` on one validation batch.

  Args:
    simulator: the simulator, wrapped (accessed via ``.module``) when
      ``device_id`` is an int, bare otherwise.
    example: a batch ``((position, particle_type, [material_property,]
      n_particles_per_example), labels)``.
    n_features: 3 if the batch carries material_property, 2 if it does not.
    flags: dict of options; ``flags["noise_std"]`` sets the input noise.
    rank: rank used to place tensors when ``device_id`` is an int.
    device_id: int rank (CUDA/DDP path) or a torch device.

  Returns:
    Scalar loss tensor from ``acceleration_loss`` (predictions made without
    gradient tracking).

  Raises:
    NotImplementedError: if ``n_features`` is neither 2 nor 3.
  """
  features = example[0]
  position = features[0].to(device_id)
  particle_type = features[1].to(device_id)
  if n_features == 3:  # batch includes material_property
    material_property = features[2].to(device_id)
    n_particles_per_example = features[3].to(device_id)
  elif n_features == 2:
    material_property = None
    n_particles_per_example = features[2].to(device_id)
  else:
    raise NotImplementedError
  labels = example[1].to(device_id)

  # Random-walk noise on the input positions, zeroed for kinematic particles.
  sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(
      position, noise_std_last_step=flags["noise_std"]).to(device_id)
  non_kinematic_mask = (particle_type != KINEMATIC_PARTICLE_ID).clone().detach().to(device_id)
  sampled_noise *= non_kinematic_mask.view(-1, 1, 1)

  # An integer device_id marks the wrapped (DDP) case: place tensors by rank
  # and call through `.module`.
  if isinstance(device_id, int):
    target = rank
    predict_fn = simulator.module.predict_accelerations
  else:
    target = device_id
    predict_fn = simulator.predict_accelerations

  with torch.no_grad():
    pred_acc, target_acc = predict_fn(
        next_positions=labels.to(target),
        position_sequence_noise=sampled_noise.to(target),
        position_sequence=position.to(target),
        nparticles_per_example=n_particles_per_example.to(target),
        particle_types=particle_type.to(target),
        material_property=material_property.to(target) if n_features == 3 else None
    )

  return acceleration_loss(pred_acc, target_acc, non_kinematic_mask)


def main(_):
  """Train or evaluate the model depending on ``FLAGS.mode``.

  CUDA is used whenever it is available; in that case the distributed
  rendezvous address/port environment variables are set.

  - 'train': ensures ``FLAGS.model_path`` exists, then trains. With CUDA this
    spawns one process per GPU via ``distribute.spawn_train``. The GPU count
    is ``FLAGS.n_gpus`` when it fits, else every available GPU. Without CUDA
    it trains in-process.
  - 'valid' / 'rollout': optionally pins ``FLAGS.cuda_device_number``, then
    calls ``predict``.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  on_gpu = device == torch.device('cuda')
  if on_gpu:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29502"

  myflags = reading_utils.flags_to_dict(FLAGS)

  if FLAGS.mode == 'train':
    # Create the checkpoint directory on first use.
    if not os.path.exists(FLAGS.model_path):
      os.makedirs(FLAGS.model_path)

    if not on_gpu:
      # Single-process CPU training: no rank and a world size of one.
      train(None, myflags, 1, device)
      return

    available_gpus = torch.cuda.device_count()
    print(f"Available GPUs = {available_gpus}")

    # Honour --n_gpus only when it fits this machine; otherwise use all GPUs.
    if FLAGS.n_gpus is not None and FLAGS.n_gpus <= available_gpus:
      world_size = FLAGS.n_gpus
    else:
      world_size = available_gpus
      if FLAGS.n_gpus is not None:
        print(f"Warning: The number of GPUs specified ({FLAGS.n_gpus}) exceeds the available GPUs ({available_gpus})")

    print(f"Using {world_size}/{available_gpus} GPUs")
    distribute.spawn_train(train, myflags, world_size, device)
    return

  if FLAGS.mode in ['valid', 'rollout']:
    world_size = torch.cuda.device_count()
    wants_specific_gpu = FLAGS.cuda_device_number is not None
    if wants_specific_gpu and torch.cuda.is_available():
      device = torch.device(f'cuda:{int(FLAGS.cuda_device_number)}')
    print(f"device is {device} world size is {world_size}")
    predict(device)


if __name__ == "__main__":
  # Hand control to absl, which parses FLAGS before invoking main.
  app.run(main)
