import collections
import json
import os
import pickle
import glob
import re
import sys
import statistics

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from absl import flags
from absl import app

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from gns import learned_simulator
from gns import noise_utils
from gns import reading_utils
from gns import data_loader
from gns import distribute

import matplotlib.pyplot as plt
from p2g_utils import run_simulation
import time
from scipy.stats import spearmanr

# Command-line flags (absl); parsed values are read through `FLAGS` below.
# In predict(), mode 'rollout' additionally pickles each example's rollout
# into `output_path`.
flags.DEFINE_enum(
    'mode', 'train', ['train', 'valid', 'rollout'],
    help='Train model, validation or rollout evaluation.')
flags.DEFINE_integer('batch_size', 1, help='The batch size.')
flags.DEFINE_float('noise_std', 6.7e-4, help='The std deviation of the noise.')
# Used as a plain string prefix (f"{data_path}{split}.npz"), so it should end
# with a path separator.
flags.DEFINE_string('data_path', None, help='The dataset directory.')
# Concatenated directly with `model_file` (model_path + model_file) when
# loading, so it should end with a path separator.
flags.DEFINE_string('model_path', 'models/', help=('The path for saving checkpoints of the model.'))
flags.DEFINE_string('output_path', 'rollouts/', help='The path for saving outputs (e.g. rollouts).')
flags.DEFINE_string('output_filename', 'rollout', help='Base name for saving the rollout')
# NOTE(review): the predict* functions in this chunk concatenate this value
# directly; "latest" handling is not visible here — confirm where it happens.
flags.DEFINE_string('model_file', None, help=('Model filename (.pt) to resume from. Can also use "latest" to default to newest file.'))
flags.DEFINE_string('train_state_file', 'train_state.pt', help=('Train state filename (.pt) to resume from. Can also use "latest" to default to newest file.'))

flags.DEFINE_integer('ntraining_steps', int(2E7), help='Number of training steps.')
flags.DEFINE_integer('validation_interval', None, help='Validation interval. Set `None` if validation loss is not needed')
flags.DEFINE_integer('nsave_steps', int(100000), help='Number of steps at which to save the model.')

# Learning rate parameters
flags.DEFINE_float('lr_init', 1e-4, help='Initial learning rate.')
flags.DEFINE_float('lr_decay', 0.1, help='Learning rate decay.')
flags.DEFINE_integer('lr_decay_steps', int(5e6), help='Learning rate decay steps.')

flags.DEFINE_integer("cuda_device_number", None, help="CUDA device (zero indexed), default is None so default CUDA device will be used.")
flags.DEFINE_integer("n_gpus", 1, help="The number of GPUs to utilize for training.")

FLAGS = flags.FLAGS

# (mean, std) pair; not used in this part of the file — presumably holds
# normalization statistics (TODO confirm against the rest of the module).
Stats = collections.namedtuple('Stats', ['mean', 'std'])

# Number of past frames fed to the simulator for each prediction.
INPUT_SEQUENCE_LENGTH = 6  # So we can calculate the last 5 velocities.
# Not referenced in this chunk; presumably used when building the simulator.
NUM_PARTICLE_TYPES = 9
# Particles of this type are not predicted: rollouts overwrite their
# positions with the ground-truth trajectory every step.
KINEMATIC_PARTICLE_ID = 3

def rollout(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.tensor,
        particle_types: torch.tensor,
        material_property: torch.tensor,
        n_particles_per_example: torch.tensor,
        nsteps: int,
        device: torch.device):
  """Roll out a trajectory by applying the model autoregressively.

  The first `INPUT_SEQUENCE_LENGTH` frames (sampled every `N` frames) seed the
  model; each prediction is then fed back as the newest input frame.
  Kinematic particles are overwritten with their ground-truth positions at
  every step. The error is measured on the grids that `run_simulation`
  returns for the predicted and ground-truth trajectories.

  Args:
    simulator: Learned simulator.
    position: Particle positions with shape (nparticles, timesteps, ndims);
      axis 1 is time.
    particle_types: Particle types with shape (nparticles).
    material_property: Per-particle material property, or None; forwarded to
      the simulator unchanged.
    n_particles_per_example: Number of particles in this example.
    nsteps: Ignored; the number of steps is derived from `position`.
    device: Device the kinematic mask is moved to.

  Returns:
    Tuple `(output_dict, loss, velocities, accs, rollout_time)`:
      output_dict: numpy arrays of the initial positions and the predicted /
        ground-truth rollouts (time, nparticles, ndims), plus particle types
        and material property.
      loss: Relative squared error between predicted and ground-truth grids.
      velocities, accs: Per-step velocity / acceleration simulator outputs.
      rollout_time: Wall-clock seconds spent in the prediction loop.
  """
  # Temporal stride between frames fed to the model (1 = every frame).
  N = 1
  initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N
  initial_positions = position[:, initial_indices]
  # Frame predicted directly by each step (stride N after the input window).
  ground_truth_indices = np.arange(initial_indices[-1] + N, position.shape[1], N)
  # Dense ground truth starting right after the input window; predictions
  # (including interpolated frames when N > 1) are compared against it.
  ground_truth_positions = position[:, initial_indices[-1] + 1:]

  # Loop invariants: coordinates taken from ground truth, and weights of the
  # N - 1 interpolated frames between two predictions (empty when N == 1).
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, position.shape[-1])
  interp_steps = np.linspace(0, 1, N + 1)[1:-1]

  current_positions = initial_positions
  predictions = []
  velocities = []
  accs = []

  nsteps = len(ground_truth_indices)
  start_time = time.time()
  for step in tqdm(range(nsteps), total=nsteps):
    # Next position with shape (nparticles, ndims).
    next_position, next_velocity, next_acc = simulator.predict_positions(
        current_positions,
        nparticles_per_example=[n_particles_per_example],
        particle_types=particle_types,
        material_property=material_property
    )
    if torch.isnan(next_position).any():
      print('NAN WARNING')
    # Kinematic particles follow the prescribed trajectory at the frame this
    # step predicts.
    next_position = torch.where(
        kinematic_mask, position[:, ground_truth_indices[step]], next_position)

    last_position = current_positions[:, -1]
    for alpha in interp_steps:
      predictions.append(last_position * (1 - alpha) + next_position * alpha)
    predictions.append(next_position)
    velocities.append(next_velocity)
    accs.append(next_acc)

    # Slide the input window: drop the oldest frame, append the prediction.
    current_positions = torch.cat(
        [current_positions[:, 1:], next_position[:, None, :]], dim=1)
  rollout_time = time.time() - start_time

  # Both with shape (time, nparticles, ndims).
  predictions = torch.stack(predictions)
  ground_truth_positions = ground_truth_positions[:, :len(predictions)].permute(1, 0, 2)

  if torch.isnan(predictions).any():
    print('NAN WARNING')

  grid_m, grid_m_gt = run_simulation(predictions, ground_truth_positions)
  # Relative squared error per grid entry; epsilon guards entries where the
  # ground-truth grid value is zero.
  epsilon = 1e-6
  loss = ((grid_m - grid_m_gt) ** 2) / (grid_m_gt ** 2 + epsilon)

  output_dict = {
      'initial_positions': initial_positions.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions.cpu().numpy(),
      'particle_types': particle_types.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss, velocities, accs, rollout_time


def rollout_with_gt(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.tensor,
        particle_types: torch.tensor,
        material_property: torch.tensor,
        n_particles_per_example: torch.tensor,
        nsteps: int,
        device: torch.device):
  """One-step predictions along a trajectory with ground-truth (teacher-forced) inputs.

  The model sees frames sampled every `N` frames. After each step the input
  window is advanced with the ground-truth frame that step predicted, rather
  than with the prediction. Intermediate frames between two stride-N frames
  are filled by linear interpolation.

  Args:
    simulator: Learned simulator (must provide `predict_positions_index`).
    position: Particle positions with shape (nparticles, timesteps, ndims);
      axis 1 is time.
    particle_types: Particle types with shape (nparticles).
    material_property: Per-particle material property, or None; forwarded to
      the simulator unchanged.
    n_particles_per_example: Number of particles in this example.
    nsteps: Ignored; the number of steps is derived from `position`.
    device: Device the kinematic mask is moved to.

  Returns:
    Tuple `(output_dict, loss, velocities, accs, edge_indexs)`:
      output_dict: numpy arrays of the initial positions and the predicted /
        ground-truth rollouts (time, nparticles, ndims), plus particle types
        and material property.
      loss: Relative squared position error, shape (time, nparticles, ndims).
      velocities, accs: One entry per predicted frame (interpolated or not).
      edge_indexs: Graph connectivity returned by the simulator for each step.
  """
  # Temporal stride between frames fed to the model.
  N = 2
  initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N
  initial_positions = position[:, initial_indices]
  # Frame predicted directly by each step (stride N after the input window).
  ground_truth_indices = np.arange(initial_indices[-1] + N, position.shape[1], N)
  # Dense ground truth starting right after the input window; predictions
  # (interpolated frames included) are compared against it.
  ground_truth_positions = position[:, initial_indices[-1] + 1:]

  # Finite-difference velocity/acceleration at the last input frame, used as
  # the start point when interpolating the first step.
  last = initial_indices[-1]
  current_velocity = position[:, last] - position[:, last - 1]
  current_acc = current_velocity - (position[:, last - 1] - position[:, last - 2])

  # Loop invariants: coordinates taken from ground truth, and weights of the
  # N - 1 interpolated frames between two predictions.
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, position.shape[-1])
  interp_steps = np.linspace(0, 1, N + 1)[1:-1]

  current_positions = initial_positions
  predictions = []
  velocities = []
  accs = []
  edge_indexs = []

  nsteps = len(ground_truth_indices)
  for step in tqdm(range(nsteps), total=nsteps):
    next_position, next_velocity, next_acc, edge_index = simulator.predict_positions_index(
        current_positions,
        nparticles_per_example=[n_particles_per_example],
        particle_types=particle_types,
        material_property=material_property
    )
    # Ground truth of the frame this step predicts. Using the dense
    # `ground_truth_positions[:, step]` here would pick frame
    # `last + 1 + step` and break the stride-N input window.
    target_frame = position[:, ground_truth_indices[step]]
    next_position = torch.where(kinematic_mask, target_frame, next_position)

    last_position = current_positions[:, -1]
    for alpha in interp_steps:
      predictions.append(last_position * (1 - alpha) + next_position * alpha)
      velocities.append(current_velocity * (1 - alpha) + next_velocity * alpha)
      accs.append(current_acc * (1 - alpha) + next_acc * alpha)

    predictions.append(next_position)
    velocities.append(next_velocity)
    accs.append(next_acc)
    edge_indexs.append(edge_index)

    current_velocity = next_velocity
    current_acc = next_acc
    # Teacher forcing: advance the window with the true frame, keeping stride N.
    current_positions = torch.cat(
        [current_positions[:, 1:], target_frame[:, None, :]], dim=1)

  # Both with shape (time, nparticles, ndims).
  predictions = torch.stack(predictions)
  ground_truth_positions = ground_truth_positions[:, :len(predictions)].permute(1, 0, 2)

  # Relative squared error; epsilon (as in `rollout`) avoids inf/NaN where a
  # ground-truth coordinate is exactly zero.
  epsilon = 1e-6
  loss = ((predictions - ground_truth_positions) ** 2) / (ground_truth_positions ** 2 + epsilon)

  output_dict = {
      'initial_positions': initial_positions.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions.cpu().numpy(),
      'particle_types': particle_types.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss, velocities, accs, edge_indexs


def rollout_N(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.tensor,
        particle_types: torch.tensor,
        material_property: torch.tensor,
        n_particles_per_example: torch.tensor,
        nsteps: int,
        device: torch.device):
  """Roll out N interleaved stride-N trajectories and merge them frame by frame.

  For every offset o in [0, N) the model is seeded with frames
  o, o+N, ..., o+(INPUT_SEQUENCE_LENGTH-1)*N and rolled out autoregressively
  on frames o + INPUT_SEQUENCE_LENGTH*N, o + (INPUT_SEQUENCE_LENGTH+1)*N, ...
  Together the offsets cover every frame from N*INPUT_SEQUENCE_LENGTH on.

  Args:
    simulator: Learned simulator.
    position: Particle positions with shape (nparticles, timesteps, ndims);
      axis 1 is time.
    particle_types: Particle types with shape (nparticles).
    material_property: Per-particle material property, or None; forwarded to
      the simulator unchanged.
    n_particles_per_example: Number of particles in this example.
    nsteps: Unused; the number of steps is derived from `position`.
    device: Device the kinematic mask and ground truth are moved to.

  Returns:
    Tuple `(output_dict, loss)`. Rollouts in `output_dict` have shape
    (time, nparticles, ndims). `loss` is the relative squared position error
    over frames N*INPUT_SEQUENCE_LENGTH onward.
  """
  N = 10
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, position.shape[-1])

  total_predictions = torch.zeros_like(position)
  total_predictions[:, :INPUT_SEQUENCE_LENGTH] = position[:, :INPUT_SEQUENCE_LENGTH]

  for offset in range(N):
    initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N + offset
    initial_indices = initial_indices[initial_indices < position.shape[1]]
    if len(initial_indices) < INPUT_SEQUENCE_LENGTH:
      continue

    current_positions = position[:, initial_indices]
    # The next frame of a stride-N sub-sequence is N frames later, which keeps
    # the model's input window at a constant stride.
    gt_indices = np.arange(initial_indices[-1] + N, position.shape[1], N)

    for target_idx in gt_indices:
      # predict_positions returns (position, velocity, acceleration); only the
      # position is needed here.
      next_position, _, _ = simulator.predict_positions(
          current_positions,
          nparticles_per_example=[n_particles_per_example],
          particle_types=particle_types,
          material_property=material_property
      )
      next_position = torch.where(
          kinematic_mask, position[:, target_idx].to(device), next_position)
      total_predictions[:, target_idx] = next_position.detach().to(total_predictions.device)
      # Slide the input window: drop the oldest frame, append the prediction.
      current_positions = torch.cat(
          [current_positions[:, 1:], next_position[:, None, :]], dim=1)

  # Both with shape (time, nparticles, ndims).
  predictions = total_predictions[:, N * INPUT_SEQUENCE_LENGTH:].permute(1, 0, 2)
  ground_truth_positions = position[:, N * INPUT_SEQUENCE_LENGTH:].permute(1, 0, 2)
  initial_positions = position[:, :N * INPUT_SEQUENCE_LENGTH]

  # Relative squared error; epsilon (as in `rollout`) avoids inf/NaN where a
  # ground-truth coordinate is exactly zero.
  epsilon = 1e-6
  loss = ((predictions - ground_truth_positions) ** 2) / (ground_truth_positions ** 2 + epsilon)

  output_dict = {
      'initial_positions': initial_positions.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions.cpu().numpy(),
      'particle_types': particle_types.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss



def predict(device: str):
  """Roll out every trajectory of the evaluation split and report the loss.

  Loads the checkpoint `FLAGS.model_path + FLAGS.model_file`. It then runs
  `rollout` on each trajectory of the `valid` split, or of `test` in
  'rollout' mode or when `valid.npz` is missing. It prints the per-example
  loss and timing. In 'rollout' mode it also pickles every example into
  `FLAGS.output_path`.

  Args:
    device: Torch device the simulator and tensors are moved to.

  Raises:
    Exception: If the model checkpoint does not exist.
    NotImplementedError: If dataset entries are not 2- or 3-tuples.
  """
  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)

  model_file_path = FLAGS.model_path + FLAGS.model_file
  if not os.path.exists(model_file_path):
    raise Exception(f"Model does not exist at {model_file_path}")
  simulator.load(model_file_path)
  simulator.to(device)
  simulator.eval()

  os.makedirs(FLAGS.output_path, exist_ok=True)

  # Use the `valid` set in eval mode when it exists, otherwise `test`.
  valid_exists = os.path.isfile(f"{FLAGS.data_path}valid.npz")
  split = 'test' if (FLAGS.mode == 'rollout' or not valid_exists) else 'valid'

  ds = data_loader.get_data_loader_by_trajectories(path=f"{FLAGS.data_path}{split}.npz")
  # Stored entries are (positions, particle_type[, material_property]).
  n_fields = len(ds.dataset._data[0])
  if n_fields not in (2, 3):
    raise NotImplementedError
  material_property_as_feature = n_fields == 3

  eval_loss = []
  total_time = []
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      if metadata['sequence_length'] is not None:
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        nsteps = positions.shape[1] - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      if material_property_as_feature:
        material_property = features[2].to(device)
        n_particles = int(features[3])
      else:
        material_property = None
        n_particles = int(features[2])
      n_particles_per_example = torch.tensor([n_particles], dtype=torch.int32).to(device)

      example_rollout, loss, _, _, rollout_time = rollout(simulator,
                                      positions,
                                      particle_type,
                                      material_property,
                                      n_particles_per_example,
                                      nsteps,
                                      device)
      total_time.append(rollout_time)
      print("rollout_time: {}".format(rollout_time))
      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      eval_loss.append(torch.flatten(loss))

      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = os.path.join(FLAGS.output_path, f'{FLAGS.output_filename}_ex{example_i}.pkl')
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

  if not total_time:
    print(f"No trajectories found in {split}.npz; nothing to report.")
    return
  print("Average_time: {}".format(sum(total_time) / len(total_time)))
  print("Mean loss on rollout prediction: {}".format(
      torch.mean(torch.cat(eval_loss))))


def predict_acc_cosine(device: str):
  """Relate rollout error to the consistency of predicted accelerations.

  For every trajectory of the evaluation split this runs `rollout` and
  computes the mean cosine similarity between consecutive per-particle
  accelerations. It writes the following into
  `loss_curve_acc_xy/<last component of FLAGS.output_path>/`:
    * `<i>.png`: min-max normalized loss and cosine similarity over time;
    * `<i>_correlations.png`: means of both series over 10-step windows;
    * `all_examples_correlations.png`: window means of all examples.
  It prints the running Spearman correlation, the median cosine similarity
  and the mean loss. In 'rollout' mode it also pickles every example into
  `FLAGS.output_path`.

  Args:
    device: Torch device the simulator and tensors are moved to.

  Raises:
    Exception: If the model checkpoint does not exist.
    NotImplementedError: If dataset entries are not 2- or 3-tuples.
  """
  def _min_max(values):
    # Scale to [0, 1]; a constant series maps to zeros instead of 0/0 = NaN.
    lo, hi = values.min(), values.max()
    if hi == lo:
      return values - lo
    return (values - lo) / (hi - lo)

  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)

  model_file_path = FLAGS.model_path + FLAGS.model_file
  if not os.path.exists(model_file_path):
    raise Exception(f"Model does not exist at {model_file_path}")
  simulator.load(model_file_path)
  simulator.to(device)
  simulator.eval()

  os.makedirs(FLAGS.output_path, exist_ok=True)
  # Defined before the loop so the summary plot below works for any dataset.
  path_name = FLAGS.output_path.strip('/').split('/')[-1]
  plot_dir = f"loss_curve_acc_xy/{path_name}"
  os.makedirs(plot_dir, exist_ok=True)

  # Use the `valid` set in eval mode when it exists, otherwise `test`.
  valid_exists = os.path.isfile(f"{FLAGS.data_path}valid.npz")
  split = 'test' if (FLAGS.mode == 'rollout' or not valid_exists) else 'valid'

  ds = data_loader.get_data_loader_by_trajectories(path=f"{FLAGS.data_path}{split}.npz")
  n_fields = len(ds.dataset._data[0])
  if n_fields not in (2, 3):
    raise NotImplementedError
  material_property_as_feature = n_fields == 3

  window_size = 10
  eval_loss = []
  all_x_vals = []
  all_y_vals = []
  cos_acc_all = []
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      if metadata['sequence_length'] is not None:
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        nsteps = positions.shape[1] - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      if material_property_as_feature:
        material_property = features[2].to(device)
        n_particles = int(features[3])
      else:
        material_property = None
        n_particles = int(features[2])
      n_particles_per_example = torch.tensor([n_particles], dtype=torch.int32).to(device)

      example_rollout, loss, _, acc, _ = rollout(simulator,
                                      positions,
                                      particle_type,
                                      material_property,
                                      n_particles_per_example,
                                      nsteps,
                                      device)

      # Mean (over particles) cosine similarity of acc[t] and acc[t-1].
      cos_acc = [
          torch.mean(torch.nn.functional.cosine_similarity(acc[t], acc[t - 1], dim=1)).item()
          for t in range(1, len(acc))
      ]
      cos_acc_all.extend(cos_acc)
      loss_per_timestep = loss.mean(dim=(1, 2)).cpu().detach()
      normalized_loss = _min_max(loss_per_timestep)

      plt.figure(figsize=(10, 6))
      plt.plot((normalized_loss[1:]).numpy(), label='MSE Loss', color='blue', alpha=0.7)
      plt.plot(cos_acc, label='Acceleration Cosine Similarity', color='orange', alpha=0.7)
      plt.xlabel("Time Step")
      plt.ylabel("Loss")
      plt.title("Loss over Time Steps")
      plt.grid(True)
      plt.legend()
      plt.savefig(f"{plot_dir}/{example_i}.png", dpi=300, bbox_inches="tight")
      plt.close()

      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      eval_loss.append(torch.flatten(loss))

      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = os.path.join(FLAGS.output_path, f'{FLAGS.output_filename}_ex{example_i}.pkl')
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

      # Loss at step t (t >= 1) pairs with the similarity of acc[t] and acc[t-1].
      loss_seq = normalized_loss[1:].detach().cpu()
      cos_seq = torch.tensor(cos_acc, dtype=loss_seq.dtype)
      length = len(loss_seq)
      mean_points = []
      for start in range(0, length, window_size):
        end = min(start + window_size, length)
        mean_points.append((cos_seq[start:end].mean().item(), loss_seq[start:end].mean().item()))
      if not mean_points:
        continue

      x_vals, y_vals = zip(*mean_points)
      all_x_vals.extend(x_vals)
      all_y_vals.extend(y_vals)
      plt.figure(figsize=(8, 6))
      plt.scatter(x_vals, y_vals, color='blue')
      plt.xlabel('Mean Cosine Accuracy (per window)')
      plt.ylabel('Mean Normalized Loss (per window)')
      plt.title('Scatter Plot of Cosine Accuracy vs Normalized Loss')
      plt.grid(True)
      plt.savefig(f"{plot_dir}/{example_i}_correlations.png", dpi=300, bbox_inches="tight")
      plt.close()

      # Spearman correlation over all window means collected so far.
      rho, pval = spearmanr(all_x_vals, all_y_vals)
      print(f"Spearman correlation coefficient: {rho:.4f}")
      print(f"P-value: {pval:.4e}")

  if not eval_loss:
    print(f"No trajectories found in {split}.npz; nothing to report.")
    return

  plt.figure(figsize=(10, 8))
  plt.scatter(all_x_vals, all_y_vals, color='blue', s=1)
  plt.xlabel('Mean Cosine Accuracy (per window)')
  plt.ylabel('Mean Normalized Loss (per window)')
  plt.title('Scatter Plot of Cosine Accuracy vs Normalized Loss (All examples)')
  plt.grid(True)
  plt.savefig(f"{plot_dir}/all_examples_correlations.png", dpi=300, bbox_inches="tight")
  plt.close()

  if cos_acc_all:
    print("Median acceleration cosine similarity:", statistics.median(cos_acc_all))
  print("Mean loss on rollout prediction: {}".format(
      torch.mean(torch.cat(eval_loss))))
  

def predict_velocity_divergence_KD(device: str):
  """Relate one-step rollout error to the divergence of predicted velocities.

  For every trajectory of the evaluation split this runs `rollout_with_gt`.
  For each predicted frame it estimates the velocity divergence at every
  particle. The estimate is a least-squares velocity gradient over the 5
  nearest particles (KD-tree), and the frame's score is the mean squared
  divergence. Plots of both min-max normalized series, and of their means
  over 10-frame windows, are written to
  `loss_curve_velocity_divergence_kd/<last component of FLAGS.output_path>/`.
  In 'rollout' mode every example is also pickled into `FLAGS.output_path`.

  Args:
    device: Torch device the simulator and tensors are moved to.

  Raises:
    Exception: If the model checkpoint does not exist.
    NotImplementedError: If dataset entries are not 2- or 3-tuples.
  """
  from scipy.spatial import KDTree

  def _min_max(values):
    # Scale to [0, 1]; a constant series maps to zeros instead of 0/0 = NaN.
    lo, hi = values.min(), values.max()
    if hi == lo:
      return values - lo
    return (values - lo) / (hi - lo)

  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)

  model_file_path = FLAGS.model_path + FLAGS.model_file
  if not os.path.exists(model_file_path):
    raise Exception(f"Model does not exist at {model_file_path}")
  simulator.load(model_file_path)
  simulator.to(device)
  simulator.eval()

  os.makedirs(FLAGS.output_path, exist_ok=True)
  path_name = FLAGS.output_path.strip('/').split('/')[-1]
  plot_dir = f"loss_curve_velocity_divergence_kd/{path_name}"
  os.makedirs(plot_dir, exist_ok=True)

  # Use the `valid` set in eval mode when it exists, otherwise `test`.
  valid_exists = os.path.isfile(f"{FLAGS.data_path}valid.npz")
  split = 'test' if (FLAGS.mode == 'rollout' or not valid_exists) else 'valid'

  ds = data_loader.get_data_loader_by_trajectories(path=f"{FLAGS.data_path}{split}.npz")
  n_fields = len(ds.dataset._data[0])
  if n_fields not in (2, 3):
    raise NotImplementedError
  material_property_as_feature = n_fields == 3

  window_size = 10
  eval_loss = []
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      if metadata['sequence_length'] is not None:
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        nsteps = positions.shape[1] - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      if material_property_as_feature:
        material_property = features[2].to(device)
        n_particles = int(features[3])
      else:
        material_property = None
        n_particles = int(features[2])
      n_particles_per_example = torch.tensor([n_particles], dtype=torch.int32).to(device)

      example_rollout, loss, velocities, _, _ = rollout_with_gt(simulator,
                                      positions,
                                      particle_type,
                                      material_property,
                                      n_particles_per_example,
                                      nsteps,
                                      device)

      predicted = example_rollout['predicted_rollout']
      num_frames, num_particles, _ = predicted.shape
      # Neighbourhood size (includes the particle itself); clamped so tiny
      # systems do not request more neighbours than exist.
      k_neighbors = min(5, num_particles)
      frame_losses = []
      for t in range(num_frames):
        pos = predicted[t]
        vel = velocities[t].cpu().numpy()
        # One batched query gives the neighbour indices of every particle.
        _, neighbor_idx = KDTree(pos).query(pos, k=k_neighbors)
        neighbor_idx = np.asarray(neighbor_idx).reshape(num_particles, -1)
        divergences = np.zeros(num_particles)
        if k_neighbors >= 2:
          for i, nbrs in enumerate(neighbor_idx):
            # Fit relative_vel ~= relative_pos @ grad; column c of `grad` is the
            # spatial gradient of velocity component c, so trace(grad) is the
            # divergence (du/dx + dv/dy [+ dw/dz]).
            grad, _, _, _ = np.linalg.lstsq(pos[nbrs] - pos[i], vel[nbrs] - vel[i], rcond=None)
            divergences[i] = np.trace(grad)
        frame_losses.append(np.mean(divergences ** 2))

      normalized_frame_loss = _min_max(np.array(frame_losses))
      loss_per_timestep = loss.mean(dim=(1, 2)).cpu().detach()
      normalized_loss = _min_max(loss_per_timestep)

      plt.figure(figsize=(10, 6))
      plt.plot(normalized_loss.numpy(), label='MSE Loss', color='blue', alpha=0.7)
      plt.plot(normalized_frame_loss, label='velocity divergence loss', color='orange', alpha=0.7)
      plt.xlabel("Time Step")
      plt.ylabel("Loss")
      plt.title("Loss over Time Steps")
      plt.grid(True)
      plt.legend()
      plt.savefig(f"{plot_dir}/{example_i}.png", dpi=300, bbox_inches="tight")
      plt.close()

      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      eval_loss.append(torch.flatten(loss))

      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = os.path.join(FLAGS.output_path, f'{FLAGS.output_filename}_ex{example_i}.pkl')
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

      # Both series have one value per predicted frame; keep them aligned
      # (unlike the cosine analysis, there is no t-1 pairing to shift for).
      loss_seq = normalized_loss.detach().cpu()
      div_seq = torch.tensor(normalized_frame_loss, dtype=loss_seq.dtype)
      length = len(loss_seq)
      mean_points = []
      for start in range(0, length, window_size):
        end = min(start + window_size, length)
        mean_points.append((div_seq[start:end].mean().item(), loss_seq[start:end].mean().item()))
      if not mean_points:
        continue

      x_vals, y_vals = zip(*mean_points)
      plt.figure(figsize=(8, 6))
      plt.scatter(x_vals, y_vals, color='blue')
      plt.xlabel('Mean velocity divergence KD (per window)')
      plt.ylabel('Mean Normalized Loss (per window)')
      plt.title('Scatter Plot of Velocity Divergence vs Normalized Loss')
      plt.grid(True)
      plt.savefig(f"{plot_dir}/{example_i}_correlations.png", dpi=300, bbox_inches="tight")
      plt.close()

  if not eval_loss:
    print(f"No trajectories found in {split}.npz; nothing to report.")
    return
  print("Mean loss on rollout prediction: {}".format(
      torch.mean(torch.cat(eval_loss))))


def _edge_velocity_divergence_loss(pos, vel, edge_index):
  """Return the mean squared velocity divergence of one frame.

  For each particle `i`, its neighbours are the receivers of the edges it
  sends (`edge_index[1][edge_index[0] == i]`). The velocity gradient is fitted
  by least squares to the neighbours' positions/velocities relative to `i`,
  and the divergence is the trace of that gradient (du/dx + dv/dy [+ dw/dz]).
  Particles with fewer than `dim + 1` neighbours get a divergence of 0.

  Args:
    pos: (nparticles, dim) array-like of particle positions.
    vel: (nparticles, dim) array-like of particle velocities.
    edge_index: (2, nedges) integer array of (sender, receiver) indices.

  Returns:
    Mean over particles of the squared divergence, as a float.
  """
  pos = np.asarray(pos)
  vel = np.asarray(vel)
  num_particles, dim = pos.shape
  senders = np.asarray(edge_index[0], dtype=np.int64)
  receivers = np.asarray(edge_index[1], dtype=np.int64)

  # Group receivers by sender in one pass instead of scanning every edge for
  # every particle. The stable sort keeps each group in original edge order,
  # i.e. exactly what the boolean mask `senders == i` would select.
  order = np.argsort(senders, kind='stable')
  counts = np.bincount(senders, minlength=num_particles)
  neighbor_groups = np.split(receivers[order], np.cumsum(counts)[:-1])

  divergences = np.zeros(num_particles)
  for i in range(num_particles):
    neighbors = neighbor_groups[i]
    # Need at least dim + 1 neighbours for a well-posed least-squares fit.
    if len(neighbors) < dim + 1:
      continue
    relative_pos = pos[neighbors] - pos[i]   # (k, dim)
    relative_vel = vel[neighbors] - vel[i]   # (k, dim)
    # grad[j, c] = d vel_c / d x_j; all velocity components solved at once.
    grad, _, _, _ = np.linalg.lstsq(relative_pos, relative_vel, rcond=None)
    divergences[i] = np.trace(grad)
  return float(np.mean(divergences ** 2))


def _edge_minmax_normalize(values):
  """Min-max scale a NumPy array or torch tensor to [0, 1].

  A constant series (max == min) maps to all zeros instead of 0/0 = NaN.
  """
  low = values.min()
  span = values.max() - low
  if span == 0:
    return values - low
  return (values - low) / span


def _edge_window_means(metric_seq, loss_seq, window_size):
  """Average two sequences over consecutive windows of `loss_seq`.

  Windows are `[start, min(start + window_size, len(loss_seq)))`; the last one
  may be shorter. Windows where either slice is empty are skipped (this
  happens when `metric_seq` is shorter than `loss_seq`).

  Returns:
    List of `(mean metric, mean loss)` float pairs, one per window.
  """
  length = len(loss_seq)
  points = []
  for start in range(0, length, window_size):
    end = min(start + window_size, length)
    loss_chunk = loss_seq[start:end]
    metric_chunk = metric_seq[start:end]
    if len(loss_chunk) == 0 or len(metric_chunk) == 0:
      continue
    points.append((metric_chunk.mean().item(), loss_chunk.mean().item()))
  return points


def predict_velocity_divergence_edge(device: str):
  """Roll out the model and compare its error with a velocity-divergence metric.

  For every trajectory of the evaluation split this runs `rollout_with_gt`,
  computes a per-frame velocity-divergence loss from the predicted positions,
  the returned velocities and the returned edge indices, and saves two plots
  under `loss_curve_velocity_divergence_edge/<output dir name>/`:
  the min-max normalized MSE / divergence curves over time, and a scatter plot
  of their per-window means. In `rollout` mode each example rollout is also
  pickled into `FLAGS.output_path`.

  Args:
    device: Torch device the simulator and inputs are moved to.

  Raises:
    Exception: if the model file does not exist.
    NotImplementedError: if dataset entries have an unexpected number of fields.
  """
  # Read metadata
  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)

  # Load simulator
  model_file = FLAGS.model_path + FLAGS.model_file
  if os.path.exists(model_file):
    simulator.load(model_file)
  else:
    raise Exception(f"Model does not exist at {model_file}")

  simulator.to(device)
  simulator.eval()

  # Output path
  os.makedirs(FLAGS.output_path, exist_ok=True)

  # Use the `valid` set for eval mode if it exists, otherwise `test`.
  # (This must be an f-string: a plain literal never names a real file.)
  valid_exists = os.path.isfile(f"{FLAGS.data_path}valid.npz")
  split = 'test' if (FLAGS.mode == 'rollout' or not valid_exists) else 'valid'

  # Get dataset
  ds = data_loader.get_data_loader_by_trajectories(path=f"{FLAGS.data_path}{split}.npz")
  # See if our dataset has material property as feature
  if len(ds.dataset._data[0]) == 3:  # `ds` has (positions, particle_type, material_property)
    material_property_as_feature = True
  elif len(ds.dataset._data[0]) == 2:  # `ds` only has (positions, particle_type)
    material_property_as_feature = False
  else:
    raise NotImplementedError

  path_name = FLAGS.output_path.strip('/').split('/')[-1]
  plot_dir = f"loss_curve_velocity_divergence_edge/{path_name}"
  os.makedirs(plot_dir, exist_ok=True)
  window_size = 10

  eval_loss = []
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      if metadata['sequence_length'] is not None:
        # If `sequence_length` is predefined in metadata,
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        # If no predefined `sequence_length`, then get the sequence length
        sequence_length = positions.shape[1]
        nsteps = sequence_length - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      if material_property_as_feature:
        material_property = features[2].to(device)
        n_particles_per_example = torch.tensor([int(features[3])], dtype=torch.int32).to(device)
      else:
        material_property = None
        n_particles_per_example = torch.tensor([int(features[2])], dtype=torch.int32).to(device)

      # Predict example rollout
      example_rollout, loss, velocities, accs, edge_indexs = rollout_with_gt(
          simulator,
          positions,
          particle_type,
          material_property,
          n_particles_per_example,
          nsteps,
          device)

      # Per-frame velocity-divergence loss of the predicted rollout.
      predicted = example_rollout['predicted_rollout']
      num_frames = predicted.shape[0]
      frame_losses = np.array([
          _edge_velocity_divergence_loss(
              predicted[t],
              velocities[t].cpu().numpy(),
              edge_indexs[t].cpu().numpy())
          for t in range(num_frames)
      ])
      normalized_frame_loss = _edge_minmax_normalize(frame_losses)

      # Per-timestep MSE averaged over the remaining loss axes, normalized.
      loss_per_timestep = loss.mean(dim=(1, 2)).cpu().detach()
      normalized_loss = _edge_minmax_normalize(loss_per_timestep)

      # Plot both normalized curves over time.
      plt.figure(figsize=(10, 6))
      plt.plot(normalized_loss.numpy(), label='MSE Loss', color='blue', alpha=0.7)
      plt.plot(normalized_frame_loss, label='velocity divergence loss', color='orange', alpha=0.7)
      plt.xlabel("Time Step")
      plt.ylabel("Loss")
      plt.title("Loss over Time Steps")
      plt.grid(True)
      plt.legend()
      plt.savefig(f"{plot_dir}/{example_i}.png", dpi=300, bbox_inches="tight")
      plt.close()

      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      eval_loss.append(torch.flatten(loss))

      # Save rollout in testing
      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = f'{FLAGS.output_filename}_ex{example_i}.pkl'
        filename = os.path.join(FLAGS.output_path, filename)
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

      # Scatter of per-window means: divergence metric vs. MSE loss.
      loss_seq = normalized_loss[1:].detach().cpu()
      div_seq = torch.tensor(normalized_frame_loss, dtype=loss_seq.dtype)
      mean_points = _edge_window_means(div_seq, loss_seq, window_size)
      if mean_points:
        x_vals, y_vals = zip(*mean_points)
        plt.figure(figsize=(8, 6))
        plt.scatter(x_vals, y_vals, color='blue')
        plt.xlabel('Mean velocity divergence edge (per window)')
        plt.ylabel('Mean Normalized Loss (per window)')
        plt.title('Scatter Plot of Velocity Divergence vs Normalized Loss')
        plt.grid(True)
        plt.savefig(f"{plot_dir}/{example_i}_correlations.png", dpi=300, bbox_inches="tight")
        plt.close()
      else:
        print(f"Example {example_i}: not enough steps for a correlation plot")

  if eval_loss:
    print("Mean loss on rollout prediction: {}".format(
        torch.mean(torch.cat(eval_loss))))
  else:
    print("No examples were evaluated.")


def optimizer_to(optim, device):
  """Move every tensor stored in an optimizer's state to `device`, in place.

  State entries may be tensors themselves or dicts whose values are tensors
  (e.g. Adam's moment buffers). Each such tensor's data, and its gradient
  data if it has one, is moved. Non-tensor values are left untouched.

  Args:
    optim: A torch optimizer.
    device: Target device (anything accepted by `Tensor.to`).
  """
  def _move(tensor):
    # Rebind `.data` so the same tensor object now lives on `device`.
    tensor.data = tensor.data.to(device)
    if tensor._grad is not None:
      tensor._grad.data = tensor._grad.data.to(device)

  for state in optim.state.values():
    if isinstance(state, torch.Tensor):
      _move(state)
    elif isinstance(state, dict):
      for value in state.values():
        if isinstance(value, torch.Tensor):
          _move(value)

def acceleration_loss(pred_acc, target_acc, non_kinematic_mask):
  """Compute the masked mean squared acceleration error.

  The squared error is summed over the last (coordinate) axis, zeroed where
  `non_kinematic_mask` is false, and averaged over the entries where it is
  true.

  Args:
    pred_acc: Predicted accelerations; the last axis holds the coordinates.
    target_acc: Target accelerations, same shape as `pred_acc`.
    non_kinematic_mask: Mask that is true for non-kinematic particles (the ones
      included in the loss), one entry per particle.

  Returns:
    Scalar loss tensor. If the mask selects no particle the loss is 0 instead
    of NaN (0 / 0), so an all-kinematic batch cannot poison the gradients.
  """
  loss = ((pred_acc - target_acc) ** 2).sum(dim=-1)
  loss = torch.where(non_kinematic_mask.bool(), loss, torch.zeros_like(loss))
  # Clamp the count to 1 so an all-kinematic batch yields 0 rather than NaN.
  num_non_kinematic = torch.clamp(non_kinematic_mask.sum(), min=1)
  return loss.sum() / num_non_kinematic

def save_model_and_train_state(rank, device, simulator, flags, step, epoch, optimizer,
                                train_loss, valid_loss, train_loss_hist, valid_loss_hist):
  """Write a model checkpoint and the matching training-state file.

  Files go to `flags["model_path"]` as `model-<step>.pt` (via the simulator's
  own `save`) and `train_state-<step>.pt` (via `torch.save`). Only rank 0, or
  the process running on CPU, writes anything.

  Args:
    rank: local rank
    device: torch device type
    simulator: the simulator; DDP-wrapped (model under `.module`) unless on CPU.
    flags: dict of flags; only `flags["model_path"]` is read.
    step: step
    epoch: epoch
    optimizer: optimizer whose `state_dict()` is saved
    train_loss: training loss at current step
    valid_loss: validation loss at current step
    train_loss_hist: training loss history at each epoch
    valid_loss_hist: validation loss history at each epoch
  """
  is_writer = rank == 0 or device == torch.device("cpu")
  if not is_writer:
    return

  on_cpu = device == torch.device("cpu")
  # Under DDP the underlying model is reachable through `.module`.
  model = simulator if on_cpu else simulator.module
  model.save(flags["model_path"] + 'model-' + str(step) + '.pt')

  train_state = {
      "optimizer_state": optimizer.state_dict(),
      "global_train_state": {
          "step": step,
          "epoch": epoch,
          "train_loss": train_loss,
          "valid_loss": valid_loss,
      },
      "loss_history": {"train": train_loss_hist, "valid": valid_loss_hist},
  }
  torch.save(train_state, f'{flags["model_path"]}train_state-{step}.pt')
      
def train(rank, flags, world_size, device):
  """Train the model.

  Builds the simulator (DDP-wrapped when `device` is CUDA). It optionally
  resumes model and optimizer state from `flags["model_file"]` and
  `flags["train_state_file"]` (both "latest" picks the highest saved step).
  It then trains until `flags["ntraining_steps"]` steps have run or the run is
  interrupted with Ctrl-C. A checkpoint is written every `flags["nsave_steps"]`
  steps and once more when training stops.

  Args:
    rank: local rank (None when training on CPU)
    flags: dict of command-line flags (from `reading_utils.flags_to_dict`)
    world_size: total number of ranks
    device: torch device type

  Raises:
    FileNotFoundError: if the requested model / train-state files are missing.
    ValueError: if train.npz and valid.npz have different feature counts.
    NotImplementedError: if examples have an unexpected number of features.
  """
  if device == torch.device("cuda"):
    distribute.setup(rank, world_size, device)
    device_id = rank
  else:
    device_id = device

  # Read metadata
  metadata = reading_utils.read_metadata(flags["data_path"], "train")

  # Get simulator and optimizer
  if device == torch.device("cuda"):
    serial_simulator = _get_simulator(metadata, flags["noise_std"], flags["noise_std"], rank)
    simulator = DDP(serial_simulator.to(rank), device_ids=[rank], output_device=rank)
    optimizer = torch.optim.Adam(simulator.parameters(), lr=flags["lr_init"]*world_size)
  else:
    simulator = _get_simulator(metadata, flags["noise_std"], flags["noise_std"], device)
    optimizer = torch.optim.Adam(simulator.parameters(), lr=flags["lr_init"] * world_size)

  # Initialize training state
  step = 0
  epoch = 0
  steps_per_epoch = 0

  # `train_loss` is read by the final checkpoint save; initialise it so an
  # interrupt before the first step (or no remaining steps) cannot raise NameError.
  train_loss = None
  valid_loss = None
  epoch_train_loss = 0
  epoch_valid_loss = None

  train_loss_hist = []
  valid_loss_hist = []

  # If model_path does exist and model_file and train_state_file exist continue training.
  if flags["model_file"] is not None:

    if flags["model_file"] == "latest" and flags["train_state_file"] == "latest":
      # find the latest model, assumes model and train_state files are in step.
      fnames = glob.glob(f'{flags["model_path"]}*model*pt')
      max_model_number = 0
      # Raw string with an escaped dot. Files matched by the glob but not named
      # `model-<step>.pt` are skipped instead of crashing on a None match.
      expr = re.compile(r".*model-(\d+)\.pt$")
      for fname in fnames:
        match = expr.search(fname)
        if match is None:
          continue
        max_model_number = max(max_model_number, int(match.group(1)))
      # reset names to point to the latest.
      flags["model_file"] = f"model-{max_model_number}.pt"
      flags["train_state_file"] = f"train_state-{max_model_number}.pt"

    if os.path.exists(flags["model_path"] + flags["model_file"]) and os.path.exists(flags["model_path"] + flags["train_state_file"]):
      # load model
      if device == torch.device("cuda"):
        simulator.module.load(flags["model_path"] + flags["model_file"])
      else:
        simulator.load(flags["model_path"] + flags["model_file"])

      # load train state
      train_state = torch.load(flags["model_path"] + flags["train_state_file"])

      # set optimizer state
      optimizer = torch.optim.Adam(
        simulator.module.parameters() if device == torch.device("cuda") else simulator.parameters())
      optimizer.load_state_dict(train_state["optimizer_state"])
      optimizer_to(optimizer, device_id)

      # set global train state
      step = train_state["global_train_state"]["step"]
      epoch = train_state["global_train_state"]["epoch"]
      train_loss_hist = train_state["loss_history"]["train"]
      valid_loss_hist = train_state["loss_history"]["valid"]

    else:
      msg = f'Specified model_file {flags["model_path"] + flags["model_file"]} and train_state_file {flags["model_path"] + flags["train_state_file"]} not found.'
      raise FileNotFoundError(msg)

  simulator.train()
  simulator.to(device_id)

  # Get data loader
  get_data_loader = (
    distribute.get_data_distributed_dataloader_by_samples
    if device == torch.device("cuda")
    else data_loader.get_data_loader_by_samples
  )

  # Load training data
  dl = get_data_loader(
      path=f'{flags["data_path"]}train.npz',
      input_length_sequence=INPUT_SEQUENCE_LENGTH,
      batch_size=flags["batch_size"],
  )
  n_features = len(dl.dataset._data[0])

  # Load validation data
  if flags["validation_interval"] is not None:
    dl_valid = get_data_loader(
        path=f'{flags["data_path"]}valid.npz',
        input_length_sequence=INPUT_SEQUENCE_LENGTH,
        batch_size=flags["batch_size"],
    )
    if len(dl_valid.dataset._data[0]) != n_features:
      raise ValueError(
          f"`n_features` of `valid.npz` and `train.npz` should be the same"
      )

  print(f"rank = {rank}, cuda = {torch.cuda.is_available()}")
  start_time = time.time()

  try:
    while step < flags["ntraining_steps"]:
      if device == torch.device("cuda"):
        torch.distributed.barrier()

      for example in dl:
        steps_per_epoch += 1
        # ((position, particle_type, material_property, n_particles_per_example), labels) are in dl
        position = example[0][0].to(device_id)
        particle_type = example[0][1].to(device_id)
        if n_features == 3:  # if dl includes material_property
          material_property = example[0][2].to(device_id)
          n_particles_per_example = example[0][3].to(device_id)
        elif n_features == 2:
          n_particles_per_example = example[0][2].to(device_id)
        else:
          raise NotImplementedError
        labels = example[1].to(device_id)

        # TODO (jpv): Move noise addition to data_loader
        # Sample the noise to add to the inputs to the model during training.
        sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(position, noise_std_last_step=flags["noise_std"]).to(device_id)
        non_kinematic_mask = (particle_type != KINEMATIC_PARTICLE_ID).clone().detach().to(device_id)
        sampled_noise *= non_kinematic_mask.view(-1, 1, 1)

        # Get the predictions and target accelerations
        device_or_rank = rank if device == torch.device("cuda") else device
        pred_acc, target_acc = (simulator.module.predict_accelerations if device == torch.device("cuda") else simulator.predict_accelerations)(
            next_positions=labels.to(device_or_rank),
            position_sequence_noise=sampled_noise.to(device_or_rank),
            position_sequence=position.to(device_or_rank),
            nparticles_per_example=n_particles_per_example.to(device_or_rank),
            particle_types=particle_type.to(device_or_rank),
            material_property=material_property.to(device_or_rank) if n_features == 3 else None
        )

        # Validation. Draw a validation batch only when it is actually used:
        # creating a fresh DataLoader iterator on every step is expensive.
        if (flags["validation_interval"] is not None
            and step > 0 and step % flags["validation_interval"] == 0):
          sampled_valid_example = next(iter(dl_valid))
          valid_loss = validation(
            simulator, sampled_valid_example, n_features, flags, rank, device_id)
          print(f"Validation loss at {step}: {valid_loss.item()}")

        # Calculate the loss and mask out loss on kinematic particles
        loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)

        train_loss = loss.item()
        epoch_train_loss += train_loss

        # Computes the gradient of loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update learning rate
        lr_new = flags["lr_init"] * (flags["lr_decay"] ** (step/flags["lr_decay_steps"])) * world_size
        for param in optimizer.param_groups:
          param['lr'] = lr_new

        print(f'rank = {rank}, epoch = {epoch}, step = {step}/{flags["ntraining_steps"]}, loss = {train_loss}', flush=True)

        # Save model state
        if rank == 0 or device == torch.device("cpu"):
          if step % flags["nsave_steps"] == 0:
            save_model_and_train_state(rank, device, simulator, flags, step, epoch,
                                       optimizer, train_loss, valid_loss, train_loss_hist, valid_loss_hist)

        step += 1
        if step >= flags["ntraining_steps"]:
          break

      # Epoch level statistics
      # Training loss at epoch
      epoch_train_loss /= steps_per_epoch
      epoch_train_loss = torch.tensor([epoch_train_loss]).to(device_id)
      if device == torch.device("cuda"):
        torch.distributed.reduce(epoch_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
        epoch_train_loss /= world_size

      train_loss_hist.append((epoch, epoch_train_loss.item()))

      # Validation loss at epoch
      if flags["validation_interval"] is not None:
        sampled_valid_example = next(iter(dl_valid))
        epoch_valid_loss = validation(
                simulator, sampled_valid_example, n_features, flags, rank, device_id)
        if device == torch.device("cuda"):
          torch.distributed.reduce(epoch_valid_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
          epoch_valid_loss /= world_size

        valid_loss_hist.append((epoch, epoch_valid_loss.item()))

      # Print epoch statistics
      if rank == 0 or device == torch.device("cpu"):
        print(f'Epoch {epoch}, training loss: {epoch_train_loss.item()}')
        if flags["validation_interval"] is not None:
          print(f'Epoch {epoch}, validation loss: {epoch_valid_loss.item()}')

      # Reset epoch training loss
      epoch_train_loss = 0
      if steps_per_epoch >= len(dl):
        epoch += 1
      steps_per_epoch = 0

      if step >= flags["ntraining_steps"]:
        break

  except KeyboardInterrupt:
    pass

  # Save model state on completion or keyboard interrupt
  total_time = time.time() - start_time
  print(f"Total training time: {total_time:.2f} s")

  save_model_and_train_state(rank, device, simulator, flags, step, epoch, optimizer, train_loss, valid_loss, train_loss_hist, valid_loss_hist)

  if torch.cuda.is_available():
    distribute.cleanup()


def _get_simulator(
        metadata: dict,
        acc_noise_std: float,
        vel_noise_std: float,
        device: torch.device) -> learned_simulator.LearnedSimulator:
  """Instantiates the simulator.

  Normalization std-devs combine the dataset statistics with the training
  noise: sqrt(std**2 + noise_std**2).

  Args:
    metadata: Parsed metadata dict (reads acc/vel mean and std, dim, bounds,
      default_connectivity_radius, and optionally nnode_in, nedge_in and
      boundary_augment).
    acc_noise_std: Acceleration noise std deviation.
    vel_noise_std: Velocity noise std deviation.
    device: PyTorch device, or an integer GPU rank (the DDP training path
      passes the rank).

  Returns:
    A `learned_simulator.LearnedSimulator` whose normalization stats live on
    `device`.
  """
  # Normalization stats
  normalization_stats = {
      'acceleration': {
          'mean': torch.FloatTensor(metadata['acc_mean']).to(device),
          'std': torch.sqrt(torch.FloatTensor(metadata['acc_std'])**2 +
                            acc_noise_std**2).to(device),
      },
      'velocity': {
          'mean': torch.FloatTensor(metadata['vel_mean']).to(device),
          'std': torch.sqrt(torch.FloatTensor(metadata['vel_std'])**2 +
                            vel_noise_std**2).to(device),
      },
  }

  # Get necessary parameters for loading simulator.
  if "nnode_in" in metadata and "nedge_in" in metadata:
    nnode_in = metadata['nnode_in']
    nedge_in = metadata['nedge_in']
  else:
    # Given that there is no additional node feature (e.g., material_property) except for:
    # (position (dim), velocity (dim*6), particle_type (16)),
    nnode_in = 37 if metadata['dim'] == 3 else 30
    nedge_in = metadata['dim'] + 1

  # Init simulator.
  simulator = learned_simulator.LearnedSimulator(
      particle_dimensions=metadata['dim'],
      nnode_in=nnode_in,
      nedge_in=nedge_in,
      latent_dim=128,
      nmessage_passing_steps=10,
      nmlp_layers=2,
      mlp_hidden_dim=128,
      connectivity_radius=metadata['default_connectivity_radius'],
      boundaries=np.array(metadata['bounds']),
      normalization_stats=normalization_stats,
      nparticle_types=NUM_PARTICLE_TYPES,
      particle_type_embedding_size=16,
      boundary_clamp_limit=metadata.get("boundary_augment", 1.0),
      device=device)

  return simulator

def validation(
        simulator,
        example,
        n_features,
        flags,
        rank,
        device_id):
  """Compute the acceleration loss of `simulator` on a single validation batch.

  Noise is sampled as in the training loop (zeroed for kinematic particles),
  and predictions are made under `torch.no_grad()`.

  Args:
    simulator: The simulator; DDP-wrapped (model under `.module`) when
      `device_id` is an int.
    example: One batch from the validation loader:
      `((position, particle_type, [material_property,] n_particles_per_example), labels)`.
    n_features: 3 if the batch includes material_property, 2 otherwise.
    flags: dict of flags; only `flags["noise_std"]` is read.
    rank: Target device used when `device_id` is an int.
    device_id: int GPU rank under DDP, otherwise a torch device.

  Returns:
    Scalar loss tensor from `acceleration_loss`.

  Raises:
    NotImplementedError: if `n_features` is neither 2 nor 3.
  """
  inputs = example[0]
  position = inputs[0].to(device_id)
  particle_type = inputs[1].to(device_id)
  if n_features == 3:  # batch carries material_property
    material_property = inputs[2].to(device_id)
    n_particles_per_example = inputs[3].to(device_id)
  elif n_features == 2:
    material_property = None
    n_particles_per_example = inputs[2].to(device_id)
  else:
    raise NotImplementedError
  labels = example[1].to(device_id)

  # Random-walk input noise, suppressed on kinematic particles.
  noise = noise_utils.get_random_walk_noise_for_position_sequence(
      position, noise_std_last_step=flags["noise_std"]).to(device_id)
  non_kinematic_mask = (particle_type != KINEMATIC_PARTICLE_ID).clone().detach().to(device_id)
  noise *= non_kinematic_mask.view(-1, 1, 1)

  # An integer device_id means the DDP path: unwrap and target `rank`.
  distributed = isinstance(device_id, int)
  target = rank if distributed else device_id
  predict_fn = simulator.module.predict_accelerations if distributed else simulator.predict_accelerations

  with torch.no_grad():
    pred_acc, target_acc = predict_fn(
        next_positions=labels.to(target),
        position_sequence_noise=noise.to(target),
        position_sequence=position.to(target),
        nparticles_per_example=n_particles_per_example.to(target),
        particle_types=particle_type.to(target),
        material_property=None if material_property is None else material_property.to(target),
    )

  return acceleration_loss(pred_acc, target_acc, non_kinematic_mask)


def main(_):
  """Train or evaluate the model, depending on `FLAGS.mode`.

  `train` spawns one process per GPU (capped by `FLAGS.n_gpus`) or trains on
  the CPU. `valid` and `rollout` call `predict` on one device, optionally the
  GPU selected by `FLAGS.cuda_device_number`.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  on_gpu = device == torch.device('cuda')
  if on_gpu:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29504"

  myflags = reading_utils.flags_to_dict(FLAGS)

  if FLAGS.mode == 'train':
    # Create the checkpoint directory on first use.
    if not os.path.exists(FLAGS.model_path):
      os.makedirs(FLAGS.model_path)

    if not on_gpu:
      # Single CPU process: no rank, world size of one.
      train(None, myflags, 1, device)
      return

    available_gpus = torch.cuda.device_count()
    print(f"Available GPUs = {available_gpus}")

    # Honour the requested GPU count only when it fits the hardware.
    requested = FLAGS.n_gpus
    if requested is not None and requested <= available_gpus:
      world_size = requested
    else:
      world_size = available_gpus
      if requested is not None:
        print(f"Warning: The number of GPUs specified ({FLAGS.n_gpus}) exceeds the available GPUs ({available_gpus})")

    print(f"Using {world_size}/{available_gpus} GPUs")
    distribute.spawn_train(train, myflags, world_size, device)

  elif FLAGS.mode in ['valid', 'rollout']:
    world_size = torch.cuda.device_count()
    if FLAGS.cuda_device_number is not None and torch.cuda.is_available():
      device = torch.device(f'cuda:{int(FLAGS.cuda_device_number)}')
    print(f"device is {device} world size is {world_size}")
    predict(device)


if __name__ == '__main__':
  # absl's app.run parses the command-line flags defined above, then calls main.
  app.run(main)
