import collections
import json
import os
import pickle
import glob
import re
import sys

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm

from absl import flags
from absl import app

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from gns_mpm_control import learned_simulator
from gns_mpm_control import noise_utils
from gns_mpm_control import reading_utils
from gns_mpm_control import data_loader
from gns_mpm_control import distribute

from Gnn_ddim import learned_simulator as learned_simulator_control
from mpm_engine.phy_solver import MPM_Simulator

import matplotlib.pyplot as plt
from p2g_utils_3d import run_simulation, run_simulation_mpm
from p2g_utils_sand3d import run_simulation_mpm_thershold
import time

# ---------------------------------------------------------------------------
# Command-line flags (absl). These calls register flags in absl's global
# registry at import time; values are readable on `FLAGS` only after absl has
# parsed argv (presumably via `app.run` in an entry point not shown here).
# ---------------------------------------------------------------------------
flags.DEFINE_enum(
    'mode', 'train', ['train', 'valid', 'rollout'],
    help='Train model, validation or rollout evaluation.')
flags.DEFINE_integer('batch_size', 1, help='The batch size.')
flags.DEFINE_float('noise_std', 6.7e-4, help='The std deviation of the noise.')
# NOTE(review): `data_path_dense` reuses the help text of `data_path`;
# presumably it points at the dense counterpart of the dataset — confirm and
# give it a distinct help string.
flags.DEFINE_string('data_path', None, help='The dataset directory.')
flags.DEFINE_string('data_path_dense', None, help='The dataset directory.')
# NOTE(review): `model_path_control` reuses the help text of `model_path`;
# presumably it locates checkpoints of the control model — confirm.
flags.DEFINE_string('model_path', 'models/', help=('The path for saving checkpoints of the model.'))
flags.DEFINE_string('model_path_control', 'models/', help=('The path for saving checkpoints of the model.'))
flags.DEFINE_string('output_path', 'rollouts/', help='The path for saving outputs (e.g. rollouts).')
# NOTE(review): `image_path` reuses both the default and the help text of
# `output_path` — confirm what it is meant to hold.
flags.DEFINE_string('image_path', 'rollouts/', help='The path for saving outputs (e.g. rollouts).')
flags.DEFINE_string('output_filename', 'rollout', help='Base name for saving the rollout')
flags.DEFINE_string('model_file', None, help=('Model filename (.pt) to resume from. Can also use "latest" to default to newest file.'))
flags.DEFINE_string('train_state_file', 'train_state.pt', help=('Train state filename (.pt) to resume from. Can also use "latest" to default to newest file.'))

flags.DEFINE_integer('ntraining_steps', int(2E7), help='Number of training steps.')
flags.DEFINE_integer('validation_interval', None, help='Validation interval. Set `None` if validation loss is not needed')
flags.DEFINE_integer('nsave_steps', int(100000), help='Number of steps at which to save the model.')

# Learning rate parameters
flags.DEFINE_float('lr_init', 1e-4, help='Initial learning rate.')
flags.DEFINE_float('lr_decay', 0.1, help='Learning rate decay.')
flags.DEFINE_integer('lr_decay_steps', int(5e6), help='Learning rate decay steps.')

flags.DEFINE_integer("cuda_device_number", None, help="CUDA device (zero indexed), default is None so default CUDA device will be used.")
flags.DEFINE_integer("n_gpus", 1, help="The number of GPUs to utilize for training.")

FLAGS = flags.FLAGS

# (mean, std) record type. Not referenced in the visible part of this file;
# presumably used for normalization statistics elsewhere.
Stats = collections.namedtuple('Stats', ['mean', 'std'])

# Number of past frames in the input window of the learned simulators. The
# rollouts below sample these frames with a temporal stride N
# (`np.arange(0, INPUT_SEQUENCE_LENGTH) * N`).
INPUT_SEQUENCE_LENGTH = 6  # So we can calculate the last 5 velocities.
# Not referenced in the visible part of this file.
NUM_PARTICLE_TYPES = 9
# Particles with this type id have their predicted positions overwritten with
# ground truth at every rollout step (see the `kinematic_mask` logic).
KINEMATIC_PARTICLE_ID = 3

from bolt.upsample_2d import random_sampling, octree_initial
from bolt.utils import knn_smoothing, bilateral_smoothing

# from p2g_utils import run_simulation

from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from utils.draw_control_plot import create_control_figure_3d

def plot_figure(first_positions, new_pos, index, output_dir="./thershold_gns_mpm"):
    """Save a side-by-side 2-D scatter plot of two point sets as a PNG.

    The left panel ("Original Points", blue) shows ``first_positions`` and the
    right panel ("Sampled Points", red) shows ``new_pos``. Both panels plot
    column 0 as x and column 1 as y, with both axes limited to [0, 1].

    Args:
        first_positions: array-like indexed as ``[:, 0]`` / ``[:, 1]``.
        new_pos: array-like indexed as ``[:, 0]`` / ``[:, 1]``.
        index: value inserted into the output file name.
        output_dir: directory the PNG is written to. It is created if it is
            missing. The default is the directory that was previously
            hard-coded.

    The figure is created per call and always closed. Previously, drawing on
    the implicit current figure accumulated points across calls and leaked
    figure memory.
    """
    fig, (ax_orig, ax_new) = plt.subplots(1, 2)
    try:
        panels = (
            (ax_orig, first_positions, 'blue', 'Original Points'),
            (ax_new, new_pos, 'red', 'Sampled Points'),
        )
        for ax, pts, color, title in panels:
            ax.scatter(pts[:, 0], pts[:, 1], c=color, s=1)
            ax.set_title(title)
            ax.set_xlabel('X Coordinate')
            ax.set_ylabel('Y Coordinate')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
        # Lay out after all limits/labels are final.
        fig.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        output_file_path = os.path.join(output_dir, f"scatter_plot_left_right_{index}.png")
        fig.savefig(output_file_path)
    finally:
        plt.close(fig)

from torch_cluster import fps

def sample_points(x, num_samples):
    """Downsample a point set with farthest point sampling (``torch_cluster.fps``).

    x: (n, dim) tensor of points, treated as a single batch.
    num_samples: requested number of points. It is passed to ``fps`` as the
        ratio ``num_samples / n``, and ``fps`` decides the exact count from
        that ratio.
    Returns: the rows of ``x`` selected by ``fps``.
    """
    n_points = x.shape[0]
    # Every point belongs to batch 0 (same device as x).
    single_batch = x.new_zeros(x.size(0), dtype=torch.long)
    selected = fps(x, single_batch, ratio=num_samples / n_points)
    return x[selected]

def grid_sample_2d(points, grid_size=0.05):
    """Voxel-grid downsampling: replace all points in a grid cell by their mean.

    Args:
        points: floating-point tensor of shape (N, D); 2-D points were the
            original use case, but any D works.
        grid_size: edge length of the square/cubic cells. A point falls in
            cell ``floor(points / grid_size)``.

    Returns:
        Tensor of shape (U, D), one centroid per occupied cell. Rows are
        ordered like ``torch.unique(..., dim=0)`` orders the cell coordinates
        (sorted). An empty input returns an empty (0, D) tensor.

    The per-cell means are computed in one scatter pass (``index_add_`` +
    ``bincount``). This replaces a Python loop that built one boolean mask per
    cell (O(U * N)).
    """
    if points.shape[0] == 0:
        return points.new_empty((0,) + tuple(points.shape[1:]))

    coords = (points / grid_size).floor().int()
    unique_coords, cell_index = torch.unique(coords, dim=0, return_inverse=True)
    n_cells = unique_coords.shape[0]

    sums = torch.zeros((n_cells,) + tuple(points.shape[1:]),
                       dtype=points.dtype, device=points.device)
    sums.index_add_(0, cell_index, points)
    counts = torch.bincount(cell_index, minlength=n_cells).to(points.dtype)
    return sums / counts.unsqueeze(1)

def random_knn(
        position_sparse: torch.Tensor,
        n_initial_sample: int,
        n_final_sample: int):
  """Densify a sparse point set with random sampling followed by kNN smoothing.

  Args:
    position_sparse: sparse particle positions (torch tensor). The tensor is
      detached from autograd and moved to the CPU before conversion to NumPy,
      so tensors that require grad are accepted.
    n_initial_sample: number of points already present. It is used only to
      compute how many new points to draw.
    n_final_sample: desired total point count. ``n_final_sample -
      n_initial_sample`` new points are requested from ``random_sampling``.

  Returns:
    The result of ``knn_smoothing(position_sparse_np, new_pts, k=10)``.
    Callers in this file wrap it in ``torch.tensor``, so it is presumably a
    NumPy array — confirm against ``bolt.utils``.
  """
  sparse_np = position_sparse.detach().cpu().numpy()
  new_pts = random_sampling(sparse_np, n_final_sample - n_initial_sample)
  return knn_smoothing(sparse_np, new_pts, k=10)

# Particle type ids used to recognize the simulated material.
MATERIAL_TYPE_ID = {"water": 5, "sand": 6, "boundary": 3}


def check_material(particle_types):
    """Return the material name present in ``particle_types``.

    Water takes precedence over sand: ``"water"`` is returned if any particle
    has the water id, otherwise ``"sand"`` if any particle has the sand id.
    If neither id is present, ``None`` is returned.
    """
    for candidate in ("water", "sand"):
        if torch.any(particle_types == MATERIAL_TYPE_ID[candidate]):
            return candidate
    return None
# def acc_vel_diff(positions):
#   current_velocity = positions[-1] - positions[-2]
#   pre_velocity = positions[-2] - positions[-3]
#   current_acc = current_velocity - pre_velocity

#   return current_acc, current_velocity
from torchvision import transforms
from PIL import Image
def make_transformations(resolution, type, ori_resolution=None):
    """Build a torchvision transform pipeline for a target image resolution.

    Args:
        resolution: target resolution, a list of int, ``[h, w]``.
        type: one of ``"random_crop"``, ``"resize_center_crop"`` or
            ``"align2_256"``. The name shadows the builtin but is kept for
            keyword-call compatibility.
        ori_resolution: original ``[h, w]``. It is only used by
            ``"resize_center_crop"`` with a non-square target, to resize while
            keeping the original aspect ratio before center-cropping.

    Returns:
        A torchvision transform. Every mode except ``"random_crop"`` ends
        with ``ToTensor()``.

    Raises:
        NotImplementedError: for an unknown ``type``.
    """
    if type == "random_crop":
        # Previously ``transforms.RandomCropss`` (typo) -> AttributeError.
        return transforms.RandomCrop(resolution)

    if type == "resize_center_crop":
        if resolution[0] == resolution[1]:
            return _resize_center_crop_to_tensor(resolution[0], resolution[0])
        if ori_resolution is not None:
            # Resize while keeping the original aspect ratio (scale so both
            # sides cover the target), then center-crop to the target.
            resize_ratio = max(resolution[0] / ori_resolution[0],
                               resolution[1] / ori_resolution[1])
            resolution_after_resize = [int(ori_resolution[0] * resize_ratio),
                                       int(ori_resolution[1] * resize_ratio)]
            return _resize_center_crop_to_tensor(resolution_after_resize, resolution)
        # Directly resize to the target resolution.
        return transforms.Compose([
            transforms.Resize(resolution),
            transforms.ToTensor(),
        ])

    if type == "align2_256":
        if resolution[0] == resolution[1]:
            return _resize_center_crop_to_tensor(resolution[0], resolution[0])
        return _resize_center_crop_to_tensor(max(resolution), resolution)

    raise NotImplementedError(f"Unknown transformation type: {type!r}")


def _resize_center_crop_to_tensor(resize_size, crop_size):
    """Compose ``Resize(resize_size) -> CenterCrop(crop_size) -> ToTensor()``."""
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
    ])

def rollout(
        simulator: learned_simulator.LearnedSimulator,
        simulator_control: learned_simulator_control.LearnedSimulator_control,
        position: torch.Tensor,
        position_dense: torch.Tensor,
        particle_types: torch.Tensor,
        particle_types_dense: torch.Tensor,
        material_property: torch.Tensor,
        n_particles_per_example: torch.Tensor,
        n_particles_per_example_dense: int,
        nsteps: int,
        device: torch.device,
        gns_steps: int = 50,
        first_mpm_steps: int = 100,
        control_steps: int = 100):
  """
  Roll out a trajectory in four phases: learned GNS -> MPM -> learned control -> MPM.

  1. GNS: ``simulator`` predicts every N-th frame (N = 2) from a strided
     input window. Skipped frames are filled by linear interpolation. Runs
     while the frame counter is below ``gns_steps`` (must be >= 1).
  2. MPM: ``run_simulation_mpm_thershold`` advances ``first_mpm_steps``
     frames from the last GNS frame.
  3. Control: ``simulator_control`` predicts accelerations, which are clipped
     per axis and applied in an ``MPM_Simulator`` (10 substeps per frame) for
     ``control_steps`` frames.
  4. MPM: ``run_simulation_mpm_thershold`` runs the frames that remain up to
     ``nsteps``.

  Args:
    simulator: learned GNS simulator (``predict_positions``).
    simulator_control: learned control model (``predict_positions_light``,
      ``dt``).
    position: sparse positions indexed as (nparticles, timesteps, ndims).
    position_dense: dense positions with the same indexing, used as ground
      truth for the loss.
    particle_types: per-particle type ids of the sparse set.
    particle_types_dense: per-particle type ids of the dense set (only copied
      into the output dict).
    material_property: forwarded to the learned simulators; may be None.
    n_particles_per_example: particle count of the sparse example.
    n_particles_per_example_dense: unused here (kept for interface
      compatibility).
    nsteps: ignored; recomputed from ``position.shape[1]``.
    device: torch device for every tensor created here.
    gns_steps, first_mpm_steps, control_steps: phase lengths in frames. The
      defaults reproduce the previously hard-coded 50 / 100 / 100.

  Returns:
    (output_dict, loss, velocitys, accs, rollout_time, count_nan)
  """
  N = 2  # temporal stride of the learned simulator
  rollout_length = position.shape[1] - N * 6

  initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N
  initial_positions = position[:, initial_indices]
  initial_positions_dense = position_dense[:, initial_indices]

  ground_truth_start = initial_indices[-1] + N
  ground_truth_indices = np.arange(ground_truth_start, position.shape[1], N)
  gt_initial = initial_indices[-1] + 1
  ground_truth_positions = position[:, gt_initial:]
  ground_truth_positions_dense = position_dense[:, gt_initial:]

  # Control (DDIM) initialisation.
  material_type = check_material(particle_types)
  cotrol_time_step = torch.tensor([INPUT_SEQUENCE_LENGTH - 1]).to(device).float()
  batch = torch.zeros(position.shape[0], dtype=torch.int64, device=device)
  # Per-axis clipping bounds for the predicted control accelerations.
  max_acc = torch.tensor([2045.0929, 1716.3162, 5247.1494], device=device)
  min_acc = torch.tensor([-5331.2910, -1660.3076, -1601.4119], device=device)

  # Kinematic particles take their position from ground truth every frame.
  # The mask is loop-invariant, so it is built once.
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, position.shape[-1])
  # Interpolation fractions for the N - 1 skipped frames between predictions.
  interp_steps = np.linspace(0, 1, N + 1)[1:-1]

  current_positions = initial_positions
  predictions = []
  velocitys = []
  accs = []
  count_nan = 0
  time_step_gns = 0
  time_step_mpm = 0

  nsteps = len(ground_truth_indices) * N
  # `step` is the absolute frame counter across all phases.
  step = 0

  # ---- Phase 1: learned GNS ------------------------------------------------
  while step < gns_steps:
    next_position, next_velocity, next_acc = simulator.predict_positions(
        current_positions,
        nparticles_per_example=[n_particles_per_example],
        particle_types=particle_types,
        material_property=material_property
    )
    next_position = torch.where(
        kinematic_mask, ground_truth_positions[:, step], next_position)
    for alpha in interp_steps:
      interp_pos = current_positions[:, -1] * (1 - alpha) + next_position * alpha
      predictions.append(interp_pos)
    predictions.append(next_position)
    velocitys.append(next_velocity)
    accs.append(next_acc)

    step += N
    time_step_gns += 1
    # Slide the input window: drop the oldest frame, append the prediction.
    current_positions = torch.cat(
        [current_positions[:, 1:], next_position[:, None, :]], dim=1)

  # ---- Phase 2: first MPM segment ------------------------------------------
  print('Go taichi simulate')
  # One-frame finite-difference velocity at the hand-over point.
  next_velocity = next_position - interp_pos
  mpm_step = first_mpm_steps
  future_step50, future_velocities = run_simulation_mpm_thershold(next_position, next_velocity, mpm_step)
  predictions += [torch.tensor(arr, device=device) for arr in future_step50]
  future_step50 = np.array(future_step50)
  # Rebuild a strided input window from the last MPM frames.
  end_idx = future_step50.shape[0]
  indices = torch.arange(end_idx - 1, end_idx - 1 - N * INPUT_SEQUENCE_LENGTH, step=-N)
  indices = indices.flip(0).cpu().numpy()
  current_positions = torch.tensor(future_step50[indices]).to(device).permute(1, 0, 2)
  accs = []
  step += mpm_step
  if step > nsteps:
    time_step_mpm += (mpm_step - (step - nsteps))
  else:
    time_step_mpm += mpm_step

  # ---- Phase 3: learned control inside an MPM simulator --------------------
  if current_positions.shape[-1] == 2:
    sim = MPM_Simulator(
          max_particles=2000,
          min_particles=100,
          n_grid=128,
          dt=2.5e-4,
          gravity=9.8,
          dim=2,
          lower_bound=0.1,
          upper_bound=0.9,
          bound=0.115,
          material_type=material_type,
          obstacle=None
      )
  else:
    sim = MPM_Simulator(
          max_particles=2285,
          min_particles=100,
          n_grid=64,
          dt=2e-4,
          gravity=9.8,
          dim=3,
          lower_bound=0.125,
          upper_bound=0.875,
          bound=0.15,
          material_type=material_type,
    )
  # NOTE(review): the initial MPM velocity comes from the last two frames of
  # the *initial* input window, not from the state reached after phases 1-2
  # (`future_velocities` is unused). Confirm this is intended.
  v = (initial_positions[:, -1, :] - initial_positions[:, -2, :]) / simulator_control.dt
  sim.init_particles(x=current_positions[:, -1, :].cpu().numpy(), v=v.cpu().numpy())

  target_positions = np.array([0.5, 0.5, 0.5])
  control_img = create_control_figure_3d(current_positions[:, -1, :].cpu().numpy(), target_positions, 0.125, 0.875, 'control_arrow')
  # `ctrl_step` is local to this phase. The absolute frame index is
  # `step + ctrl_step`. Previously the loop variable shadowed `step`, which
  # selected the wrong ground-truth frame for kinematic particles.
  for ctrl_step in range(control_steps):
    acc, norm_acc = simulator_control.predict_positions_light(
        current_positions,
        nparticles_per_example=[n_particles_per_example],
        particle_types=particle_types,
        material_property=material_property,
        control_img=control_img,
        cotrol_time_step=cotrol_time_step,
        batch=batch
    )
    cotrol_time_step = cotrol_time_step + 1
    acc_clipped = torch.clamp(acc, min=min_acc, max=max_acc)
    sim.control_accel.from_numpy(acc_clipped.squeeze().cpu().numpy())
    for _ in range(10):
      sim.substep(apply_control=1, force_control=0)
    next_position = torch.from_numpy(sim.x.to_numpy()).to(device)
    next_velocity = sim.v.to_torch().to(device) * simulator_control.dt
    next_position = torch.where(
        kinematic_mask, ground_truth_positions[:, step + ctrl_step], next_position)
    predictions.append(next_position)
    current_positions = torch.cat(
        [current_positions[:, 1:], next_position[:, None, :]], dim=1)
  step += control_steps

  # ---- Phase 4: MPM for the remaining frames -------------------------------
  print('Go taichi simulate')
  mpm_step = nsteps - len(predictions)
  future_step50, future_velocities = run_simulation_mpm_thershold(next_position, next_velocity, mpm_step)
  predictions += [torch.tensor(arr, device=device) for arr in future_step50]
  step += mpm_step
  if step > nsteps:
    time_step_mpm += (mpm_step - (step - nsteps))
  else:
    time_step_mpm += mpm_step

  # Modelled runtime from fixed per-step costs of the GNS and MPM phases.
  # The constants are presumably measured per-step timings — confirm units.
  rollout_time = (time_step_gns * 0.8044 + time_step_mpm * 0.986969) / 1000

  # Predictions with shape (time, nnodes, dim).
  ground_truth_positions_dense = ground_truth_positions_dense[:, :rollout_length].permute(1, 0, 2)
  predictions = torch.stack(predictions[:rollout_length])

  if torch.isnan(predictions).any():
    loss = torch.tensor(0.0)
    count_nan += 1
  else:
    grid_m, grid_m_gt = run_simulation(predictions, ground_truth_positions_dense)
    # Relative squared error on the grid; epsilon guards against dividing by
    # zero where grid_m_gt is 0.
    epsilon = 1e-6
    loss = ((grid_m - grid_m_gt) ** 2) / ((grid_m_gt) ** 2 + epsilon)

  output_dict = {
      'initial_positions': initial_positions_dense.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions_dense.cpu().numpy(),
      'particle_types': particle_types_dense.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss, velocitys, accs, rollout_time, count_nan

def rollout_mpm(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.Tensor,
        position_dense: torch.Tensor,
        particle_types: torch.Tensor,
        particle_types_dense: torch.Tensor,
        material_property: torch.Tensor,
        n_particles_per_example: torch.Tensor,
        n_particles_per_example_dense: int,
        nsteps: int,
        device: torch.device):
  """
  Roll out with the learned GNS, switching to MPM when accelerations turn erratic.

  The GNS predicts every N-th frame (N = 2), and skipped frames are linearly
  interpolated. Predicted positions are clamped to [0.1, 0.9]. After each
  prediction, the mean cosine similarity between the last two predicted
  acceleration fields is recorded.

  When more than 10 such values exist and the mean of the last 10 is below
  0.9, the following happens:
    - the sparse state is densified (``random_knn``);
    - ``run_simulation_mpm`` is run;
    - its dense frames go into ``predictions`` and its sparse frames into
      ``predictions_vis``;
    - the frame counter advances by 50;
    - the GNS resumes from a strided window of the sparse MPM output.

  Args:
    simulator: learned GNS simulator (``predict_positions``).
    position: sparse positions indexed as (nparticles, timesteps, ndims).
    position_dense: dense positions with the same indexing, used as ground
      truth for the loss.
    particle_types: per-particle type ids of the sparse set.
    particle_types_dense: per-particle type ids of the dense set (only copied
      into the output dict).
    material_property: forwarded to the simulator; may be None.
    n_particles_per_example: particle count of the sparse example.
    n_particles_per_example_dense: target particle count for densification.
    nsteps: ignored; recomputed from ``position.shape[1]``.
    device: torch device for every tensor created here.

  Returns:
    (output_dict, loss, velocitys, accs, rollout_time, count_nan) where
    ``rollout_time`` is wall-clock seconds of the rollout loop.
  """
  N = 2  # temporal stride of the learned simulator
  rollout_length = position.shape[1] - N * 6

  initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N
  initial_positions = position[:, initial_indices]
  initial_positions_dense = position_dense[:, initial_indices]

  ground_truth_start = initial_indices[-1] + N
  ground_truth_indices = np.arange(ground_truth_start, position.shape[1], N)
  gt_initial = initial_indices[-1] + 1
  ground_truth_positions = position[:, gt_initial:]
  ground_truth_positions_dense = position_dense[:, gt_initial:]

  # Loop-invariant: kinematic particles take their position from ground truth.
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, position.shape[-1])
  # Interpolation fractions for the N - 1 skipped frames between predictions.
  interp_steps = np.linspace(0, 1, N + 1)[1:-1]

  current_positions = initial_positions
  predictions = []       # frames for the loss (sparse GNS frames + dense MPM frames)
  predictions_vis = []   # sparse frames only, stacked for visualisation
  velocitys = []
  accs = []
  cos_acc = []
  count_nan = 0

  nsteps = len(ground_truth_indices) * N
  step = 0
  start_time = time.time()
  while step < nsteps:
    next_position, next_velocity, next_acc = simulator.predict_positions(
        current_positions,
        nparticles_per_example=[n_particles_per_example],
        particle_types=particle_types,
        material_property=material_property
    )
    # Hard-coded domain bounds.
    next_position = torch.clamp(next_position, min=0.1, max=0.9)
    next_position = torch.where(
        kinematic_mask, ground_truth_positions[:, step], next_position)
    for alpha in interp_steps:
      interp_pos = current_positions[:, -1] * (1 - alpha) + next_position * alpha
      predictions.append(interp_pos)
      predictions_vis.append(interp_pos)
    predictions_vis.append(next_position)
    predictions.append(next_position)
    velocitys.append(next_velocity)
    accs.append(next_acc)
    step += N

    if len(accs) > 1:
      similarity = torch.nn.functional.cosine_similarity(
          accs[-1], accs[-2], dim=1
      )
      cos_acc.append(torch.mean(similarity).item())

    if len(cos_acc) > 10:
      mean_cos = torch.tensor(cos_acc)[-10:].mean().item()
      if mean_cos < 0.9:
        # Accelerations have become inconsistent: hand over to MPM.
        print('Go taichi simulate')
        next_velocity = next_position - interp_pos
        next_position_dense = random_knn(next_position, int(n_particles_per_example), n_particles_per_example_dense)
        next_position_dense = torch.tensor(next_position_dense, device=device)
        future_step50, future_velocities, future_step50_sparse = run_simulation_mpm(next_position, next_velocity, next_position_dense)
        predictions += [torch.tensor(arr, device=device) for arr in future_step50]
        predictions_vis += [torch.tensor(arr, device=device) for arr in future_step50_sparse]
        future_step50_sparse = np.array(future_step50_sparse)
        # Rebuild a strided input window from the last MPM frames.
        end_idx = len(future_step50)
        indices = torch.arange(end_idx - 1, end_idx - 1 - N * INPUT_SEQUENCE_LENGTH, step=-N)
        indices = indices.flip(0).cpu().numpy()
        current_positions = torch.tensor(future_step50_sparse[indices]).to(device).permute(1, 0, 2)

        cos_acc = []
        accs = []
        step += 50
        continue

    # Slide the input window: drop the oldest frame, append the prediction.
    current_positions = torch.cat(
        [current_positions[:, 1:], next_position[:, None, :]], dim=1)

  rollout_time = time.time() - start_time

  ground_truth_positions_dense = ground_truth_positions_dense[:, :rollout_length].permute(1, 0, 2)
  # Kept as a list: it may mix sparse GNS frames and dense MPM frames.
  predictions = predictions[:rollout_length]
  predictions_vis = torch.stack(predictions_vis)

  grid_m, grid_m_gt = run_simulation(predictions, ground_truth_positions_dense)
  # Relative squared error on the grid; epsilon guards against dividing by
  # zero where grid_m_gt is 0.
  epsilon = 1e-6
  loss = ((grid_m - grid_m_gt) ** 2) / ((grid_m_gt) ** 2 + epsilon)

  if torch.isnan(loss).any():
    loss = torch.tensor(0.0)
    count_nan += 1

  output_dict = {
      'initial_positions': initial_positions.permute(1, 0, 2).cpu().numpy(),
      'initial_positions_dense': initial_positions_dense.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions_vis.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions_dense.cpu().numpy(),
      'particle_types': particle_types_dense.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss, velocitys, accs, rollout_time, count_nan

def rollout_with_gt(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.Tensor,
        particle_types: torch.Tensor,
        material_property: torch.Tensor,
        n_particles_per_example: torch.Tensor,
        nsteps: int,
        device: torch.device):
  """
  One-step (teacher-forced) evaluation of the learned simulator.

  At every step the model predicts the next frame from a window of
  INPUT_SEQUENCE_LENGTH frames. The window is then advanced with the
  *ground-truth* frame rather than the prediction, so errors do not compound.

  Args:
    simulator: learned simulator (``predict_positions_index``).
    position: positions indexed as (nparticles, timesteps, ndims).
    particle_types: per-particle type ids.
    material_property: forwarded to the simulator; may be None.
    n_particles_per_example: particle count of the example.
    nsteps: ignored; recomputed as the number of ground-truth frames.
    device: torch device for the kinematic mask.

  Returns:
    (output_dict, loss, velocitys, accs, edge_indexs). ``loss`` is the
    element-wise relative squared error ``(pred - gt)^2 / (gt^2 + 1e-6)``.
    The epsilon matches the other rollouts in this file and avoids inf/NaN
    where a ground-truth coordinate is exactly 0.
  """
  N = 1  # temporal stride; with N = 1 every frame is predicted
  initial_indices = np.arange(0, INPUT_SEQUENCE_LENGTH) * N
  initial_positions = position[:, initial_indices]

  ground_truth_start = initial_indices[-1] + N
  ground_truth_indices = np.arange(ground_truth_start, position.shape[1], N)
  gt_initial = initial_indices[-1] + 1
  ground_truth_positions = position[:, gt_initial:]

  # Loop-invariant: kinematic particles take their position from ground truth.
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, position.shape[-1])
  # Interpolation fractions for skipped frames (empty when N == 1).
  interp_steps = np.linspace(0, 1, N + 1)[1:-1]

  current_positions = initial_positions
  predictions = []
  velocitys = []
  accs = []
  edge_indexs = []

  nsteps = len(ground_truth_indices)
  for step in tqdm(range(nsteps), total=nsteps):
    next_position, next_velocity, next_acc, edge_index = simulator.predict_positions_index(
        current_positions,
        nparticles_per_example=[n_particles_per_example],
        particle_types=particle_types,
        material_property=material_property
    )
    next_position = torch.where(
        kinematic_mask, ground_truth_positions[:, step], next_position)
    for alpha in interp_steps:
      predictions.append(current_positions[:, -1] * (1 - alpha) + next_position * alpha)

    predictions.append(next_position)
    velocitys.append(next_velocity)
    accs.append(next_acc)
    edge_indexs.append(edge_index)
    # Teacher forcing: advance the window with the ground-truth frame.
    current_positions = torch.cat(
        [current_positions[:, 1:], ground_truth_positions[:, step:step + 1, :]], dim=1)

  # Predictions with shape (time, nnodes, dim).
  ground_truth_positions = ground_truth_positions[:, :len(predictions)].permute(1, 0, 2)
  predictions = torch.stack(predictions)

  epsilon = 1e-6
  loss = ((predictions - ground_truth_positions) ** 2) / ((ground_truth_positions) ** 2 + epsilon)

  output_dict = {
      'initial_positions': initial_positions.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions.cpu().numpy(),
      'particle_types': particle_types.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss, velocitys, accs, edge_indexs


def rollout_N(
        simulator: learned_simulator.LearnedSimulator,
        position: torch.tensor,
        particle_types: torch.tensor,
        material_property: torch.tensor,
        n_particles_per_example: torch.tensor,
        nsteps: int,
        device: torch.device,
        stride: int = 10):
  """Roll out a trajectory as `stride` interleaved, temporally sub-sampled sequences.

  The trajectory is split into `stride` phases. For phase `offset`, the model
  is fed frames `offset, offset + stride, ..., offset + (L - 1) * stride`
  (L = INPUT_SEQUENCE_LENGTH). It then repeatedly predicts the next frame of
  that same stride-`stride` sequence, i.e. frames `offset + L * stride`,
  `offset + (L + 1) * stride`, ..., feeding each prediction back as input.
  Interleaving all phases yields a prediction for every frame from
  `L * stride` onwards.

  Args:
    simulator: Learned simulator.
    position: Particle positions. Indexed as `position[:, t]`, so time is
      axis 1: (nparticles, timesteps, ndims).
    particle_types: Particle types with shape (nparticles).
    material_property: Friction angle normalized by tan() with shape
      (nparticles), or None.
    n_particles_per_example: Number of particles in the example.
    nsteps: Number of steps. Unused; the horizon is `position.shape[1]`.
    device: torch device.
    stride: Temporal sub-sampling factor (previously hard-coded to 10).

  Returns:
    (output_dict, loss): `output_dict` holds the initial, predicted and
    ground-truth positions as numpy arrays with time first. `loss` is the
    elementwise relative squared error `(pred - gt)**2 / gt**2` with shape
    (time, nparticles, ndims); it is inf/nan wherever a ground-truth
    coordinate is exactly 0.
  """
  N = stride
  seq_len = INPUT_SEQUENCE_LENGTH
  n_frames = position.shape[1]

  total_predictions = torch.zeros_like(position)
  total_predictions[:, :seq_len] = position[:, :seq_len]

  # Kinematic particles follow the prescribed trajectory; the mask does not
  # depend on the phase or step, so it is built once.
  kinematic_mask = (particle_types == KINEMATIC_PARTICLE_ID).clone().detach().to(device)
  kinematic_mask = kinematic_mask.bool()[:, None].expand(-1, position.shape[-1])

  for offset in range(N):
    initial_indices = np.arange(0, seq_len) * N + offset
    initial_indices = initial_indices[initial_indices < n_frames]
    # Not enough frames to build a full input window for this phase.
    if len(initial_indices) < seq_len:
      continue

    current_positions = position[:, initial_indices]
    # The next frame of a stride-N sequence lies one full stride after the
    # last input frame. (Previously `+ 1`, which stored each prediction at,
    # and compared it with, the wrong timestep.)
    gt_indices = np.arange(initial_indices[-1] + N, n_frames, N)

    for target_idx in gt_indices:
      # Get next position with shape (nnodes, dim)
      next_position, _ = simulator.predict_positions(
          current_positions,
          nparticles_per_example=[n_particles_per_example],
          particle_types=particle_types,
          material_property=material_property
      )

      # Update kinematic particles from prescribed trajectory.
      next_position_ground_truth = position[:, target_idx].to(device)
      next_position = torch.where(
          kinematic_mask, next_position_ground_truth, next_position)

      total_predictions[:, target_idx] = next_position.detach().to(total_predictions.device)

      # Drop the oldest frame of the window and append the prediction.
      current_positions = torch.cat(
          [current_positions[:, 1:], next_position[:, None, :]], dim=1)

  # Only frames from N * seq_len onwards are predicted by every phase;
  # transpose to (time, nnodes, dim).
  start = N * seq_len
  predictions = total_predictions[:, start:].permute(1, 0, 2)
  ground_truth_positions = position[:, start:].permute(1, 0, 2)
  initial_positions = position[:, :start]

  # Relative squared error (not RMSE).
  loss = ((predictions - ground_truth_positions) ** 2) / ((ground_truth_positions) ** 2)

  output_dict = {
      'initial_positions': initial_positions.permute(1, 0, 2).cpu().numpy(),
      'predicted_rollout': predictions.cpu().numpy(),
      'ground_truth_rollout': ground_truth_positions.cpu().numpy(),
      'particle_types': particle_types.cpu().numpy(),
      'material_property': material_property.cpu().numpy() if material_property is not None else None
  }

  return output_dict, loss



def predict(device: str):
  """Roll out every trajectory of a held-out split with the simulator and control model.

  Loads both checkpoints and selects the split: `test` in `rollout` mode or
  when `<data_path>valid.npz` is absent, `valid` otherwise. For each example
  it appends the mean loss and rollout time to `loss_log.txt` /
  `time_log.txt` under `FLAGS.output_path`. In `rollout` mode it also pickles
  each rollout there.

  Args:
    device: torch device used for both simulators and the data.

  Raises:
    Exception: if either checkpoint file does not exist.
    NotImplementedError: if dataset records have an unexpected length.
  """
  # Read metadata
  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)
  simulator_control = _get_simulator_control(
      metadata, FLAGS.noise_std, FLAGS.noise_std, mode='test', device=device)

  # Load both simulators.
  model_file = FLAGS.model_path + FLAGS.model_file
  if os.path.exists(model_file):
    simulator.load(model_file)
  else:
    raise Exception(f"Model does not exist at {model_file}")

  control_model_file = FLAGS.model_path_control + FLAGS.model_file
  if os.path.exists(control_model_file):
    simulator_control.load(control_model_file)
  else:
    # Report the path that was actually checked (previously the main model path).
    raise Exception(f"Model does not exist at {control_model_file}")

  simulator.to(device)
  simulator.eval()
  simulator_control.to(device)
  simulator_control.eval()

  # Output path
  if not os.path.exists(FLAGS.output_path):
    os.makedirs(FLAGS.output_path)

  # Use the `valid` set for eval mode if it exists, otherwise `test`.
  # (The path previously lacked the f-prefix, so `valid` was never found.)
  split = 'test' if (FLAGS.mode == 'rollout' or (not os.path.isfile(f"{FLAGS.data_path}valid.npz"))) else 'valid'

  # Get dataset
  ds = data_loader.get_data_loader_by_trajectories(path=f"{FLAGS.data_path}{split}.npz", path_dense=f"{FLAGS.data_path_dense}{split}.npz")
  # See if our dataset has material property as feature
  if len(ds.dataset._data[0]) == 3:  # `ds` has (positions, particle_type, material_property)
    material_property_as_feature = True
  elif len(ds.dataset._data[0]) == 2:  # `ds` only has (positions, particle_type)
    material_property_as_feature = False
  else:
    raise NotImplementedError

  eval_loss = []
  total_time = []
  count_nan_total = 0
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      positions_dense = features[3].to(device)
      if metadata['sequence_length'] is not None:
        # If `sequence_length` is predefined in metadata,
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        # If no predefined `sequence_length`, then get the sequence length
        sequence_length = positions.shape[1]
        nsteps = sequence_length - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      particle_type_dense = features[4].to(device)
      if material_property_as_feature:
        material_property = features[2].to(device)
        # NOTE(review): features[3] is also read as `positions_dense` above;
        # confirm the loader's tuple layout for the material-property case.
        n_particles_per_example = torch.tensor([int(features[3])], dtype=torch.int32).to(device)
      else:
        material_property = None
        n_particles_per_example = torch.tensor([int(features[2])], dtype=torch.int32).to(device)

      # Predict example rollout
      n_particles_per_example_dense = features[5]
      example_rollout, loss, _, _, rollout_time, count_nan = rollout(simulator,
                                      simulator_control,
                                      positions,
                                      positions_dense,
                                      particle_type,
                                      particle_type_dense,
                                      material_property,
                                      n_particles_per_example,
                                      n_particles_per_example_dense,
                                      nsteps,
                                      device)

      count_nan_total += count_nan
      total_time.append(rollout_time)
      print("rollout_time: {}".format(rollout_time))
      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      eval_loss.append(torch.flatten(loss))
      loss_txt = FLAGS.output_path + "loss_log.txt"
      with open(loss_txt, "a") as f:
        f.write(f"{loss.mean().item()}\n")
      time_txt = FLAGS.output_path + "time_log.txt"
      with open(time_txt, "a") as f:
        f.write(f"{rollout_time}\n")

      # Save rollout in testing
      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = f'{FLAGS.output_filename}_ex{example_i}.pkl'
        filename = os.path.join(FLAGS.output_path, filename)
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

  print("Nan times: {}".format(count_nan_total))
  # Avoid ZeroDivisionError / torch.cat([]) when the split is empty.
  if not eval_loss:
    print("No examples were processed.")
    return
  print("Average_time: {}".format(sum(total_time)/len(total_time)))
  print("Mean loss on rollout prediction: {}".format(
      torch.mean(torch.cat(eval_loss))))


def predict_acc_cosine(device: str):
  """Roll out each example and plot rollout error against acceleration smoothness.

  For every trajectory in the selected split it writes two plots under
  `loss_curve_acc_rollout_xy/<last component of FLAGS.output_path>/`:
    * `<i>.png`: the min-max normalized per-timestep loss, alongside the mean
      cosine similarity between consecutive entries of the acceleration
      sequence returned by `rollout`.
    * `<i>_correlations.png`: a scatter of both quantities averaged over
      windows of 10 timesteps.
  In `rollout` mode each rollout is also pickled to `FLAGS.output_path`.

  Args:
    device: torch device used for the simulator and the data.

  Raises:
    Exception: if the model checkpoint does not exist.
    NotImplementedError: if dataset records have an unexpected length.
  """
  # Read metadata
  metadata = reading_utils.read_metadata(FLAGS.data_path, "rollout")
  simulator = _get_simulator(metadata, FLAGS.noise_std, FLAGS.noise_std, device)

  # Load simulator
  if os.path.exists(FLAGS.model_path + FLAGS.model_file):
    simulator.load(FLAGS.model_path + FLAGS.model_file)
  else:
    raise Exception(f"Model does not exist at {FLAGS.model_path + FLAGS.model_file}")

  simulator.to(device)
  simulator.eval()

  # Output path
  if not os.path.exists(FLAGS.output_path):
    os.makedirs(FLAGS.output_path)

  # Use the `valid` set for eval mode if it exists, otherwise `test`.
  # (The path previously lacked the f-prefix; a stray `breakpoint()` that
  # halted every run has also been removed.)
  split = 'test' if (FLAGS.mode == 'rollout' or (not os.path.isfile(f"{FLAGS.data_path}valid.npz"))) else 'valid'
  # Get dataset
  ds = data_loader.get_data_loader_by_trajectories(path=f"{FLAGS.data_path}{split}.npz", path_dense=f"{FLAGS.data_path_dense}{split}.npz")
  # See if our dataset has material property as feature
  if len(ds.dataset._data[0]) == 3:  # `ds` has (positions, particle_type, material_property)
    material_property_as_feature = True
  elif len(ds.dataset._data[0]) == 2:  # `ds` only has (positions, particle_type)
    material_property_as_feature = False
  else:
    raise NotImplementedError

  eval_loss = []
  with torch.no_grad():
    for example_i, features in enumerate(ds):
      print(f"processing example number {example_i}")
      positions = features[0].to(device)
      positions_dense = features[3].to(device)
      if metadata['sequence_length'] is not None:
        # If `sequence_length` is predefined in metadata,
        nsteps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH
      else:
        # If no predefined `sequence_length`, then get the sequence length
        sequence_length = positions.shape[1]
        nsteps = sequence_length - INPUT_SEQUENCE_LENGTH
      particle_type = features[1].to(device)
      particle_type_dense = features[4].to(device)
      if material_property_as_feature:
        material_property = features[2].to(device)
        # NOTE(review): features[3] is also read as `positions_dense` above;
        # confirm the loader's tuple layout for the material-property case.
        n_particles_per_example = torch.tensor([int(features[3])], dtype=torch.int32).to(device)
      else:
        material_property = None
        n_particles_per_example = torch.tensor([int(features[2])], dtype=torch.int32).to(device)

      n_particles_per_example_dense = features[5]
      # Predict example rollout.
      # NOTE(review): `predict` calls `rollout` with a control model and
      # unpacks six values; this call passes none and unpacks four. Confirm
      # it matches the current `rollout` signature.
      example_rollout, loss, velocity, acc = rollout(simulator,
                                      positions,
                                      positions_dense,
                                      particle_type,
                                      particle_type_dense,
                                      material_property,
                                      n_particles_per_example,
                                      n_particles_per_example_dense,
                                      nsteps,
                                      device)

      # Cosine similarity between consecutive accelerations, averaged over particles.
      cos_acc = []
      for t in range(1, len(acc)):
        similarity = torch.nn.functional.cosine_similarity(
            acc[t], acc[t-1], dim=1
        )
        cos_acc.append(torch.mean(similarity).item())

      # Per-timestep loss, min-max normalized to [0, 1]. A constant loss
      # would otherwise divide by zero, so it maps to all zeros.
      loss_per_timestep = loss.mean(dim=(1, 2)).cpu().detach()
      loss_range = loss_per_timestep.max() - loss_per_timestep.min()
      if loss_range > 0:
        normalized_loss = (loss_per_timestep - loss_per_timestep.min()) / loss_range
      else:
        normalized_loss = torch.zeros_like(loss_per_timestep)

      plt.figure(figsize=(10, 6))
      plt.plot((normalized_loss[1:]).numpy(), label='MSE Loss', color='blue', alpha=0.7)
      plt.plot(cos_acc, label='Acceleration Cosine Similarity', color='orange', alpha=0.7)
      plt.xlabel("Time Step")
      plt.ylabel("Loss")
      plt.title("Loss over Time Steps")
      plt.grid(True)
      plt.legend()
      path_name = FLAGS.output_path.strip('/').split('/')[-1]

      os.makedirs(f"loss_curve_acc_rollout_xy/{path_name}", exist_ok=True)
      plt.savefig(f"loss_curve_acc_rollout_xy/{path_name}/{example_i}.png", dpi=300, bbox_inches="tight")
      plt.close()

      example_rollout['metadata'] = metadata
      print("Predicting example {} loss: {}".format(example_i, loss.mean()))
      eval_loss.append(torch.flatten(loss))

      # Save rollout in testing
      if FLAGS.mode == 'rollout':
        example_rollout['loss'] = loss.mean()
        filename = f'{FLAGS.output_filename}_ex{example_i}.pkl'
        filename = os.path.join(FLAGS.output_path, filename)
        with open(filename, 'wb') as f:
          pickle.dump(example_rollout, f)

      # Window-averaged (cosine similarity, normalized loss) pairs.
      loss_seq = normalized_loss[1:].detach().cpu()
      cos_seq = torch.tensor(cos_acc, dtype=loss_seq.dtype)

      length = len(loss_seq)
      window_size = 10
      mean_points = []
      for start in range(0, length, window_size):
        end = min(start + window_size, length)
        loss_chunk = loss_seq[start:end]
        cos_chunk = cos_seq[start:end]
        if len(loss_chunk) >= 1:
          mean_points.append((cos_chunk.mean().item(), loss_chunk.mean().item()))

      # Too few timesteps for any window: `zip(*[])` would fail to unpack.
      if not mean_points:
        continue
      x_vals, y_vals = zip(*mean_points)
      plt.figure(figsize=(8, 6))
      plt.scatter(x_vals, y_vals, color='blue')
      plt.xlabel('Mean Cosine Accuracy (per window)')
      plt.ylabel('Mean Normalized Loss (per window)')
      plt.title('Scatter Plot of Cosine Accuracy vs Normalized Loss')
      plt.grid(True)
      plt.savefig(f"loss_curve_acc_rollout_xy/{path_name}/{example_i}_correlations.png", dpi=300, bbox_inches="tight")
      plt.close()

  if not eval_loss:
    print("No examples were processed.")
    return
  print("Mean loss on rollout prediction: {}".format(
      torch.mean(torch.cat(eval_loss))))

def optimizer_to(optim, device):
  """Move every tensor in an optimizer's per-parameter state to `device`, in place.

  Used after `optimizer.load_state_dict` so the restored moments live on the
  same device as the parameters. Each state entry is either a tensor or a
  dict of values; only tensor values are moved, and a tensor's gradient is
  moved along with it when one is present.

  Args:
    optim: torch optimizer whose `.state` is updated in place.
    device: target device (a `torch.device`, device string, or CUDA index).
  """
  def _move(tensor):
    # Rebind `.data` so existing references to the tensor object stay valid.
    # Uses the public `.grad` attribute (previously the private `_grad`).
    tensor.data = tensor.data.to(device)
    if tensor.grad is not None:
      tensor.grad.data = tensor.grad.data.to(device)

  for state in optim.state.values():
    # Not sure there are any global tensors in the state dict.
    if isinstance(state, torch.Tensor):
      _move(state)
    elif isinstance(state, dict):
      for value in state.values():
        if isinstance(value, torch.Tensor):
          _move(value)

def acceleration_loss(pred_acc, target_acc, non_kinematic_mask):
  """
  Mean per-particle squared acceleration error over non-kinematic particles.

  Args:
    pred_acc: Predicted accelerations; the last axis is summed, so presumably
      (nparticles, dim). TODO confirm.
    target_acc: Target accelerations, same shape as `pred_acc`.
    non_kinematic_mask: Per-particle mask that is truthy for the particles
      counted in the loss (the non-kinematic ones).

  Returns:
    Scalar tensor: the sum of squared L2 errors over the selected particles,
    divided by the number of selected particles. If no particle is selected,
    the result is 0 (previously 0/0 = NaN, which poisoned the gradients).
  """
  loss = ((pred_acc - target_acc) ** 2).sum(dim=-1)
  # `where` (not multiplication) so non-finite values on masked-out
  # particles cannot leak into the sum.
  loss = torch.where(non_kinematic_mask.bool(),
                     loss, torch.zeros_like(loss))
  num_non_kinematic = non_kinematic_mask.sum().clamp(min=1)
  return loss.sum() / num_non_kinematic

def save_model_and_train_state(rank, device, simulator, flags, step, epoch, optimizer,
                                train_loss, valid_loss, train_loss_hist, valid_loss_hist):
  """Checkpoint model weights and training state for `step`.

  Writes only from rank 0, or from any rank when running on CPU. Produces
  `<model_path>model-<step>.pt` via the simulator's `save` method and
  `<model_path>train_state-<step>.pt` via `torch.save`.

  Args:
    rank: local rank
    device: torch device type; on non-CPU devices the simulator is expected
      to expose the wrapped model as `.module`.
    simulator: simulator (or wrapper) to save.
    flags: dict providing "model_path".
    step: current training step, used in both filenames.
    epoch: current epoch.
    optimizer: optimizer whose `state_dict()` is saved.
    train_loss: training loss at current step
    valid_loss: validation loss at current step
    train_loss_hist: training loss history at each epoch
    valid_loss_hist: validation loss history at each epoch
  """
  on_cpu = device == torch.device("cpu")
  if not (rank == 0 or on_cpu):
    return

  model_dir = flags["model_path"]
  saver = simulator if on_cpu else simulator.module
  saver.save(model_dir + 'model-' + str(step) + '.pt')

  checkpoint = {
      "optimizer_state": optimizer.state_dict(),
      "global_train_state": {
          "step": step,
          "epoch": epoch,
          "train_loss": train_loss,
          "valid_loss": valid_loss,
      },
      "loss_history": {"train": train_loss_hist, "valid": valid_loss_hist},
  }
  torch.save(checkpoint, f'{model_dir}train_state-{step}.pt')
      
def train(rank, flags, world_size, device):
  """Train the learned simulator, using DDP when `device` is CUDA.

  Builds the simulator and an Adam optimizer (base learning rate scaled by
  `world_size`). When `flags["model_file"]` is set, it resumes from a
  checkpoint; "latest" for both model and train-state files selects the
  highest-numbered `model-<step>.pt`. It then trains until
  `flags["ntraining_steps"]` steps, checkpointing every `flags["nsave_steps"]`
  steps and once more on exit (including Ctrl-C).

  Args:
    rank: local rank
    flags: dict of configuration values (data_path, model_path, lr_init, ...).
    world_size: total number of ranks
    device: torch device type

  Raises:
    FileNotFoundError: if the requested checkpoint files do not exist.
    ValueError: if train.npz and valid.npz have different feature counts.
  """
  if device == torch.device("cuda"):
    distribute.setup(rank, world_size, device)
    device_id = rank
  else:
    device_id = device

  # Read metadata
  metadata = reading_utils.read_metadata(flags["data_path"], "train")

  # Get simulator and optimizer
  if device == torch.device("cuda"):
    serial_simulator = _get_simulator(metadata, flags["noise_std"], flags["noise_std"], rank)
    simulator = DDP(serial_simulator.to(rank), device_ids=[rank], output_device=rank)
    optimizer = torch.optim.Adam(simulator.parameters(), lr=flags["lr_init"]*world_size)
  else:
    simulator = _get_simulator(metadata, flags["noise_std"], flags["noise_std"], device)
    optimizer = torch.optim.Adam(simulator.parameters(), lr=flags["lr_init"] * world_size)

  # Initialize training state
  step = 0
  epoch = 0
  steps_per_epoch = 0

  # `train_loss` must be bound for the final checkpoint even when no training
  # step runs (resumed at/after ntraining_steps, or interrupted before the
  # first step); previously that path raised NameError.
  train_loss = None
  valid_loss = None
  epoch_train_loss = 0
  epoch_valid_loss = None

  train_loss_hist = []
  valid_loss_hist = []

  # If model_path does exist and model_file and train_state_file exist continue training.
  if flags["model_file"] is not None:

    if flags["model_file"] == "latest" and flags["train_state_file"] == "latest":
      # Find the latest model; assumes model and train_state files are saved in step.
      expr = re.compile(r".*model-(\d+)\.pt")
      max_model_number = 0
      for fname in glob.glob(f'{flags["model_path"]}*model*pt'):
        match = expr.search(fname)
        # Skip files that match the glob but not the naming scheme
        # (previously an AttributeError on `None.groups()`).
        if match is None:
          continue
        max_model_number = max(max_model_number, int(match.group(1)))
      # reset names to point to the latest.
      flags["model_file"] = f"model-{max_model_number}.pt"
      flags["train_state_file"] = f"train_state-{max_model_number}.pt"

    model_file = flags["model_path"] + flags["model_file"]
    train_state_file = flags["model_path"] + flags["train_state_file"]
    if os.path.exists(model_file) and os.path.exists(train_state_file):
      # load model
      if device == torch.device("cuda"):
        simulator.module.load(model_file)
      else:
        simulator.load(model_file)

      # load train state
      train_state = torch.load(train_state_file)

      # Rebuild the optimizer over the unwrapped parameters; the learning rate
      # and moments are restored from the saved state dict.
      optimizer = torch.optim.Adam(
        simulator.module.parameters() if device == torch.device("cuda") else simulator.parameters())
      optimizer.load_state_dict(train_state["optimizer_state"])
      optimizer_to(optimizer, device_id)

      # set global train state
      step = train_state["global_train_state"]["step"]
      epoch = train_state["global_train_state"]["epoch"]
      train_loss_hist = train_state["loss_history"]["train"]
      valid_loss_hist = train_state["loss_history"]["valid"]

    else:
      msg = f'Specified model_file {model_file} and train_state_file {train_state_file} not found.'
      raise FileNotFoundError(msg)

  simulator.train()
  simulator.to(device_id)

  # Get data loader
  get_data_loader = (
    distribute.get_data_distributed_dataloader_by_samples
    if device == torch.device("cuda")
    else data_loader.get_data_loader_by_samples
  )

  # Load training data
  dl = get_data_loader(
      path=f'{flags["data_path"]}train.npz',
      input_length_sequence=INPUT_SEQUENCE_LENGTH,
      batch_size=flags["batch_size"],
  )
  n_features = len(dl.dataset._data[0])

  # Load validation data
  if flags["validation_interval"] is not None:
      dl_valid = get_data_loader(
          path=f'{flags["data_path"]}valid.npz',
          input_length_sequence=INPUT_SEQUENCE_LENGTH,
          batch_size=flags["batch_size"],
      )
      if len(dl_valid.dataset._data[0]) != n_features:
          raise ValueError(
              f"`n_features` of `valid.npz` and `train.npz` should be the same"
          )

  print(f"rank = {rank}, cuda = {torch.cuda.is_available()}")

  try:
    while step < flags["ntraining_steps"]:
      if device == torch.device("cuda"):
        torch.distributed.barrier()

      for example in dl:
        steps_per_epoch += 1
        # ((position, particle_type, material_property, n_particles_per_example), labels) are in dl
        position = example[0][0].to(device_id)
        particle_type = example[0][1].to(device_id)
        if n_features == 3:  # if dl includes material_property
          material_property = example[0][2].to(device_id)
          n_particles_per_example = example[0][3].to(device_id)
        elif n_features == 2:
          n_particles_per_example = example[0][2].to(device_id)
        else:
          raise NotImplementedError
        labels = example[1].to(device_id)

        # TODO (jpv): Move noise addition to data_loader
        # Sample the noise to add to the inputs to the model during training.
        sampled_noise = noise_utils.get_random_walk_noise_for_position_sequence(position, noise_std_last_step=flags["noise_std"]).to(device_id)
        non_kinematic_mask = (particle_type != KINEMATIC_PARTICLE_ID).clone().detach().to(device_id)
        sampled_noise *= non_kinematic_mask.view(-1, 1, 1)

        # Get the predictions and target accelerations
        device_or_rank = rank if device == torch.device("cuda") else device
        pred_acc, target_acc = (simulator.module.predict_accelerations if device == torch.device("cuda") else simulator.predict_accelerations)(
            next_positions=labels.to(device_or_rank),
            position_sequence_noise=sampled_noise.to(device_or_rank),
            position_sequence=position.to(device_or_rank),
            nparticles_per_example=n_particles_per_example.to(device_or_rank),
            particle_types=particle_type.to(device_or_rank),
            material_property=material_property.to(device_or_rank) if n_features == 3 else None
        )

        # Validation. Only draw a validation batch when it is used; building
        # a fresh iterator on every step was wasted work.
        if (flags["validation_interval"] is not None
            and step > 0 and step % flags["validation_interval"] == 0):
          sampled_valid_example = next(iter(dl_valid))
          valid_loss = validation(
            simulator, sampled_valid_example, n_features, flags, rank, device_id)
          print(f"Validation loss at {step}: {valid_loss.item()}")

        # Calculate the loss and mask out loss on kinematic particles
        loss = acceleration_loss(pred_acc, target_acc, non_kinematic_mask)

        train_loss = loss.item()
        epoch_train_loss += train_loss

        # Computes the gradient of loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update learning rate (exponential decay, scaled by world size).
        lr_new = flags["lr_init"] * (flags["lr_decay"] ** (step/flags["lr_decay_steps"])) * world_size
        for param in optimizer.param_groups:
          param['lr'] = lr_new

        print(f'rank = {rank}, epoch = {epoch}, step = {step}/{flags["ntraining_steps"]}, loss = {train_loss}', flush=True)

        # Save model state
        if rank == 0 or device == torch.device("cpu"):
          if step % flags["nsave_steps"] == 0:
            save_model_and_train_state(rank, device, simulator, flags, step, epoch,
                                       optimizer, train_loss, valid_loss, train_loss_hist, valid_loss_hist)

        step += 1
        if step >= flags["ntraining_steps"]:
            break

      # Epoch level statistics
      # Training loss at epoch
      epoch_train_loss /= steps_per_epoch
      epoch_train_loss = torch.tensor([epoch_train_loss]).to(device_id)
      if device == torch.device("cuda"):
        torch.distributed.reduce(epoch_train_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
        epoch_train_loss /= world_size

      train_loss_hist.append((epoch, epoch_train_loss.item()))

      # Validation loss at epoch
      if flags["validation_interval"] is not None:
        sampled_valid_example = next(iter(dl_valid))
        epoch_valid_loss = validation(
                simulator, sampled_valid_example, n_features, flags, rank, device_id)
        if device == torch.device("cuda"):
          torch.distributed.reduce(epoch_valid_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
          epoch_valid_loss /= world_size

        valid_loss_hist.append((epoch, epoch_valid_loss.item()))

      # Print epoch statistics
      if rank == 0 or device == torch.device("cpu"):
        print(f'Epoch {epoch}, training loss: {epoch_train_loss.item()}')
        if flags["validation_interval"] is not None:
          print(f'Epoch {epoch}, validation loss: {epoch_valid_loss.item()}')

      # Reset epoch training loss; only count the epoch as finished when the
      # whole loader was consumed (not when we broke out early).
      epoch_train_loss = 0
      if steps_per_epoch >= len(dl):
        epoch += 1
      steps_per_epoch = 0

      if step >= flags["ntraining_steps"]:
        break

  except KeyboardInterrupt:
    pass

  # Save model state on exit, including after a keyboard interrupt.
  save_model_and_train_state(rank, device, simulator, flags, step, epoch, optimizer, train_loss, valid_loss, train_loss_hist, valid_loss_hist)

  # Tear down the process group only when it was set up above.
  if device == torch.device("cuda"):
    distribute.cleanup()


def _get_simulator(
        metadata: dict,
        acc_noise_std: float,
        vel_noise_std: float,
        device: torch.device) -> learned_simulator.LearnedSimulator:
  """Instantiate the base learned simulator from dataset metadata.

  The normalization std of acceleration and velocity is widened by the
  training noise: sqrt(std**2 + noise_std**2).

  Args:
    metadata: Parsed metadata dict (acc/vel mean and std, dim, bounds,
      default_connectivity_radius, and optionally nnode_in/nedge_in and
      boundary_augment).
    acc_noise_std: Acceleration noise std deviation.
    vel_noise_std: Velocity noise std deviation.
    device: PyTorch device 'cpu' or 'cuda' (or a CUDA rank index).

  Returns:
    A `learned_simulator.LearnedSimulator`.
  """
  # Normalization stats
  normalization_stats = {
      'acceleration': {
          'mean': torch.FloatTensor(metadata['acc_mean']).to(device),
          'std': torch.sqrt(torch.FloatTensor(metadata['acc_std'])**2 +
                            acc_noise_std**2).to(device),
      },
      'velocity': {
          'mean': torch.FloatTensor(metadata['vel_mean']).to(device),
          'std': torch.sqrt(torch.FloatTensor(metadata['vel_std'])**2 +
                            vel_noise_std**2).to(device),
      },
  }

  # Get necessary parameters for loading simulator.
  if "nnode_in" in metadata and "nedge_in" in metadata:
    nnode_in = metadata['nnode_in']
    nedge_in = metadata['nedge_in']
  else:
    # Given that there is no additional node feature (e.g., material_property) except for:
    # (position (dim), velocity (dim*6), particle_type (16)),
    nnode_in = 37 if metadata['dim'] == 3 else 30
    nedge_in = metadata['dim'] + 1

  # Init simulator.
  simulator = learned_simulator.LearnedSimulator(
      particle_dimensions=metadata['dim'],
      nnode_in=nnode_in,
      nedge_in=nedge_in,
      latent_dim=128,
      nmessage_passing_steps=10,
      nmlp_layers=2,
      mlp_hidden_dim=128,
      connectivity_radius=metadata['default_connectivity_radius'],
      boundaries=np.array(metadata['bounds']),
      normalization_stats=normalization_stats,
      nparticle_types=NUM_PARTICLE_TYPES,
      particle_type_embedding_size=16,
      boundary_clamp_limit=metadata.get("boundary_augment", 1.0),
      device=device)

  return simulator

def _get_simulator_control(
        metadata: dict,
        acc_noise_std: float,
        vel_noise_std: float,
        mode: str,
        device: torch.device) -> learned_simulator_control.LearnedSimulator_control:
  """Instantiate the control simulator (`LearnedSimulator_control`) from metadata.

  The normalization is computed the same way as for the base simulator: the
  acceleration/velocity std is widened to sqrt(std**2 + noise_std**2).

  Args:
    metadata: Parsed metadata dict (acc/vel mean and std, dim, bounds, dt,
      default_connectivity_radius, and optionally nnode_in/nedge_in and
      boundary_augment).
    acc_noise_std: Acceleration noise std deviation.
    vel_noise_std: Velocity noise std deviation.
    mode: Passed through to `LearnedSimulator_control` (`predict` uses 'test').
    device: PyTorch device 'cpu' or 'cuda'.

  Returns:
    A `learned_simulator_control.LearnedSimulator_control`.
  """
  # Normalization stats
  normalization_stats = {
      'acceleration': {
          'mean': torch.FloatTensor(metadata['acc_mean']).to(device),
          'std': torch.sqrt(torch.FloatTensor(metadata['acc_std'])**2 +
                            acc_noise_std**2).to(device),
      },
      'velocity': {
          'mean': torch.FloatTensor(metadata['vel_mean']).to(device),
          'std': torch.sqrt(torch.FloatTensor(metadata['vel_std'])**2 +
                            vel_noise_std**2).to(device),
      },
  }

  # Get necessary parameters for loading simulator.
  if "nnode_in" in metadata and "nedge_in" in metadata:
    nnode_in = metadata['nnode_in']
    nedge_in = metadata['nedge_in']
  else:
    # The control model uses more node features than the base simulator
    # (37 / 30): 40+16 in 3D, 32+16 in 2D. The original note attributes the
    # extra inputs to acceleration features. NOTE(review): confirm the breakdown.
    nnode_in = 40+16 if metadata['dim'] == 3 else 32+16
    nedge_in = metadata['dim'] + 1

  # Init simulator.
  simulator_control = learned_simulator_control.LearnedSimulator_control(
      particle_dimensions=metadata['dim'],
      nnode_in=nnode_in,
      nedge_in=nedge_in,
      latent_dim=128,
      nmessage_passing_steps=10,
      nmlp_layers=2,
      mlp_hidden_dim=128,
      connectivity_radius=metadata['default_connectivity_radius'],
      boundaries=np.array(metadata['bounds']),
      normalization_stats=normalization_stats,
      nparticle_types=NUM_PARTICLE_TYPES,
      particle_type_embedding_size=16,
      boundary_clamp_limit=metadata.get("boundary_augment", 1.0),
      channels=[32, 64, 128, 128],
      nums_rb=2,
      cin=192,
      dt=metadata['dt'],
      mode=mode,
      device=device)

  return simulator_control

def validation(
        simulator,
        example,
        n_features,
        flags,
        rank,
        device_id):
  """Compute the acceleration loss of `simulator` on one validation example.

  Args:
    simulator: The learned simulator. When `device_id` is an int, the model
      is reached through `simulator.module` (i.e. it is assumed to be wrapped,
      e.g. by DDP).
    example: Pair whose first item holds the inputs and whose second item
      holds the label positions. With `n_features == 3` the inputs are
      `(position, particle_type, material_property, n_particles_per_example)`;
      with `n_features == 2` they are
      `(position, particle_type, n_particles_per_example)`.
    n_features: 2 or 3; 3 means the inputs also carry a material property.
    flags: Dict of run flags; only `flags["noise_std"]` is read here.
    rank: Device used for the forward pass when `device_id` is an int.
    device_id: Either an int (treated as the distributed case) or a device
      to which all tensors are moved.

  Returns:
    The value of `acceleration_loss` for the prediction, masked to
    non-kinematic particles.

  Raises:
    NotImplementedError: If `n_features` is neither 2 nor 3.
  """
  inputs = example[0]
  position = inputs[0].to(device_id)
  particle_type = inputs[1].to(device_id)

  if n_features == 3:
    # Inputs carry an extra material_property entry before the counts.
    material_property = inputs[2].to(device_id)
    n_particles_per_example = inputs[3].to(device_id)
  elif n_features == 2:
    material_property = None
    n_particles_per_example = inputs[2].to(device_id)
  else:
    raise NotImplementedError

  targets = example[1].to(device_id)

  # Perturb the input history with random-walk noise, then zero the noise
  # for kinematic particles so only free particles are perturbed.
  noise = noise_utils.get_random_walk_noise_for_position_sequence(
    position, noise_std_last_step=flags["noise_std"]).to(device_id)
  non_kinematic_mask = (particle_type != KINEMATIC_PARTICLE_ID).clone().detach().to(device_id)
  noise *= non_kinematic_mask.view(-1, 1, 1)

  # An int device_id marks the distributed path: unwrap the model and run on
  # `rank`; otherwise use the simulator directly on `device_id`.
  distributed = isinstance(device_id, int)
  target_device = rank if distributed else device_id
  model = simulator.module if distributed else simulator

  with torch.no_grad():
    pred_acc, target_acc = model.predict_accelerations(
        next_positions=targets.to(target_device),
        position_sequence_noise=noise.to(target_device),
        position_sequence=position.to(target_device),
        nparticles_per_example=n_particles_per_example.to(target_device),
        particle_types=particle_type.to(target_device),
        material_property=(None if material_property is None
                           else material_property.to(target_device))
    )

  return acceleration_loss(pred_acc, target_acc, non_kinematic_mask)


def _select_world_size(requested, available):
  """Return how many GPUs to use for training.

  Uses `requested` when it is set and does not exceed `available`; otherwise
  falls back to `available`, printing a warning if an explicit request was
  larger than what is available.
  """
  if requested is None or requested > available:
    if requested is not None:
      print(f"Warning: The number of GPUs specified ({requested}) exceeds the available GPUs ({available})")
    return available
  return requested


def main(_):
  """Train or evaluate the model, depending on `FLAGS.mode`.

  - 'train': ensures `FLAGS.model_path` exists, then either launches GPU
    training through `distribute.spawn_train` with the selected world size,
    or runs `train` in-process on the CPU.
  - 'valid' / 'rollout': optionally pins the device to
    `FLAGS.cuda_device_number` (when CUDA is available) and calls `predict`.

  Args:
    _: Leftover command-line arguments passed by `absl.app.run`; unused.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  if device == torch.device('cuda'):
    # Rendezvous address/port for torch.distributed; presumably consumed by
    # the process-group setup inside `distribute` — confirm there.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29502"

  myflags = reading_utils.flags_to_dict(FLAGS)

  if FLAGS.mode == 'train':
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(FLAGS.model_path, exist_ok=True)

    if device == torch.device('cuda'):
      available_gpus = torch.cuda.device_count()
      print(f"Available GPUs = {available_gpus}")
      world_size = _select_world_size(FLAGS.n_gpus, available_gpus)
      print(f"Using {world_size}/{available_gpus} GPUs")
      distribute.spawn_train(train, myflags, world_size, device)
    else:
      # CPU training runs in this process with no distributed rank.
      rank = None
      world_size = 1
      train(rank, myflags, world_size, device)

  elif FLAGS.mode in ['valid', 'rollout']:
    # world_size is only reported here; evaluation runs on a single device.
    world_size = torch.cuda.device_count()
    if FLAGS.cuda_device_number is not None and torch.cuda.is_available():
      device = torch.device(f'cuda:{int(FLAGS.cuda_device_number)}')
    print(f"device is {device} world size is {world_size}")
    predict(device)

if __name__ == '__main__':
  # absl parses the command-line flags before invoking main(argv).
  app.run(main=main)
