import taichi as ti
import imageio  
import numpy as np
import json
import os
from math import sqrt, cos, sin, pi

import torch


# Integer material codes keyed by material name (only referenced by commented-out
# code in this file).
PARTICLE_TYPE={"Rigid":0, "Droplet": 1, "Boundary": 3, "Water": 5, "Sand": 6, "Goop": 7}

# Initialize Taichi
# ti.init(arch=ti.cpu, default_fp=ti.f32)  #ti.gpu
real = ti.f32
# kernel_profiler=True records per-kernel timings (adds some overhead).
ti.init(default_fp=real, arch=ti.gpu, flatten_if=True, kernel_profiler=True)

# Seeds NumPy only; torch and Taichi RNGs are not seeded here.
SEED = 42
np.random.seed(SEED)

# Simulation parameters
# max_particles = 4000  
# min_particles = 3900 1200
n_grid = 128  # grid nodes per axis; the domain is the unit square
dx = 1 / n_grid  # grid spacing
dt = 2.5e-4  # time step of a single substep
dim = 2
p_rho = 1
p_vol = (dx * 0.5) ** 2  # per-particle volume, (dx/2)^2 -- presumably ~4 particles per cell
p_mass = p_vol * p_rho
gravity = 9.8  # subtracted from grid y-velocity each substep (scaled by dt)

# Particle positions are clamped to [lower_bound, upper_bound] on each axis.
lower_bound = 0.1
upper_bound = 0.9
# Width, in grid cells, of the wall band where outward grid velocity is zeroed.
bound = int(0.115 * n_grid)
# Stiffness used in the (J - 1) pressure-like term of the P2G step.
E = 400

# Module-level accumulators; mostly referenced by commented-out code below.
save_data = {}
velocities = []
accelerations = []
total_frames = 1000
steps_per_frame = 10  # substeps per recorded frame

from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# Not referenced in the visible code.
grid_mass_list = []

class ParticleReducer:
    """Reduce a particle set to a fixed number of representatives with K-means.

    Positions and velocities are stacked column-wise, standardized, and
    clustered jointly; every cluster becomes one output particle.
    """

    def __init__(self, target_particles=256, random_state=42):
        # One K-means cluster per output particle.
        self.n_clusters = target_particles
        self.kmeans = KMeans(
            n_clusters=target_particles,
            init='k-means++',
            n_init=10,
            random_state=random_state,
        )
        # Standardize features so position and velocity columns weigh comparably.
        self.scaler = StandardScaler()

    def reduce_system(self, positions, velocities):
        """Cluster the particles and return the reduced system.

        Args:
            positions: array of shape (N, D).
            velocities: array with N rows; assumed to have D columns as well
                (TODO confirm), since results are written into a (K, D) array.

        Returns:
            ``(new_positions, new_velocities)``: the position part of the
            cluster centres mapped back to original units, and the mean
            velocity of each cluster's members (zero rows for empty clusters).
        """
        joint_scaled = self.scaler.fit_transform(np.hstack([positions, velocities]))
        self.kmeans.fit(joint_scaled)

        n_dims = positions.shape[1]
        centers = self.scaler.inverse_transform(self.kmeans.cluster_centers_)
        new_positions = centers[:, :n_dims]

        labels = self.kmeans.labels_
        new_velocities = np.zeros_like(new_positions)
        for cluster_id in range(self.n_clusters):
            members = labels == cluster_id
            if np.any(members):
                new_velocities[cluster_id] = np.mean(velocities[members], axis=0)

        return new_positions, new_velocities

@ti.data_oriented
class ParticleReducerTaichi:
    """K-means (Lloyd iterations) over 2-D particle positions, run with Taichi kernels.

    Fields are allocated once for up to ``max_points`` particles and
    ``target_particles`` centres; ``reduce_system`` can then be called with any
    N in ``[target_particles, max_points]``. The class must be
    ``@ti.data_oriented`` because it defines ``@ti.kernel`` methods.
    """

    def __init__(self, target_particles=256, max_iter=10, tol=1e-4, max_points=8192):
        self.K = target_particles
        self.max_iter = max_iter
        self.tol = tol  # early stop when no centre moves farther than this
        self.max_points = max_points

        # Number of valid leading entries in `positions` / `labels`; the rest is padding.
        self.N = ti.field(dtype=ti.i32, shape=())

        self.positions = ti.Vector.field(2, dtype=ti.f32, shape=(max_points,))
        self.centers = ti.Vector.field(2, dtype=ti.f32, shape=(self.K,))
        # Per-cluster position sums accumulated by `update_centers`.
        self.new_centers = ti.Vector.field(2, dtype=ti.f32, shape=(self.K,))
        self.counts = ti.field(dtype=ti.i32, shape=(self.K,))
        self.labels = ti.field(dtype=ti.i32, shape=(max_points,))

    def reduce_system(self, positions_np):
        """Cluster ``positions_np`` (shape (N, 2)) and return the K centres (K, 2).

        Centres are initialised from K distinct random particles. Iteration
        stops after ``max_iter`` rounds, or earlier once the largest centre
        shift drops below ``tol``.

        Raises:
            ValueError: if N exceeds ``max_points`` or is smaller than K.
        """
        positions_np = np.asarray(positions_np, dtype=np.float32)
        n_points = positions_np.shape[0]
        if n_points > self.max_points:
            raise ValueError(f"{n_points} points exceed max_points={self.max_points}")
        if n_points < self.K:
            raise ValueError(f"need at least K={self.K} points, got {n_points}")

        self.N[None] = n_points
        # from_numpy requires an exact shape match with the (max_points,) field,
        # so pad; kernels only read the first N entries.
        padded = np.zeros((self.max_points, 2), dtype=np.float32)
        padded[:n_points] = positions_np
        self.positions.from_numpy(padded)

        init_idx = np.random.choice(n_points, self.K, replace=False)
        self.centers.from_numpy(positions_np[init_idx])

        for _ in range(self.max_iter):
            previous = self.centers.to_numpy()
            self.counts.fill(0)
            self.new_centers.fill(0.0)
            self.assign_labels()
            self.update_centers()
            shift = np.linalg.norm(self.centers.to_numpy() - previous, axis=1).max()
            if shift < self.tol:
                break

        return self.centers.to_numpy()

    @ti.kernel
    def assign_labels(self):
        # Label each valid point with the index of its nearest centre (squared distance).
        for i in range(self.N[None]):
            min_dist = 1e10
            min_k = 0
            for k in range(self.K):
                d = (self.positions[i] - self.centers[k]).norm_sqr()
                if d < min_dist:
                    min_dist = d
                    min_k = k
            self.labels[i] = min_k

    @ti.kernel
    def update_centers(self):
        # Accumulate per-cluster sums/counts, then move non-empty centres to their mean.
        # Empty clusters keep their previous centre.
        for i in range(self.N[None]):
            k = self.labels[i]
            ti.atomic_add(self.new_centers[k], self.positions[i])
            ti.atomic_add(self.counts[k], 1)

        for k in range(self.K):
            if self.counts[k] > 0:
                self.centers[k] = self.new_centers[k] / self.counts[k]


def initialize_fields(n_particles):
    """Allocate the Taichi fields for one MPM particle set and its background grid.

    Args:
        n_particles: length of every per-particle field.

    Returns:
        Tuple ``(x, v, C, J, grid_v, grid_m)``:
            x      -- 2-vector f32 positions (with gradient storage);
            v      -- 2-vector f32 velocities;
            C      -- 2x2 f32 matrices;
            J      -- f32 scalars (updated multiplicatively by the substep kernels);
            grid_v -- (n_grid, n_grid) 2-vector f32 grid velocities;
            grid_m -- (n_grid, n_grid) f32 grid masses (with gradient storage).
    """
    grid_shape = (n_grid, n_grid)

    # Per-particle state; allocation order kept identical to the grid layout callers expect.
    pos_field = ti.Vector.field(2, dtype=ti.f32, shape=n_particles, needs_grad=True)
    vel_field = ti.Vector.field(2, dtype=ti.f32, shape=n_particles)
    affine_field = ti.Matrix.field(2, 2, dtype=ti.f32, shape=n_particles)
    j_field = ti.field(dtype=ti.f32, shape=n_particles)

    # Background grid.
    grid_vel_field = ti.Vector.field(2, dtype=ti.f32, shape=grid_shape)
    grid_mass_field = ti.field(dtype=ti.f32, shape=grid_shape, needs_grad=True)

    return pos_field, vel_field, affine_field, j_field, grid_vel_field, grid_mass_field

@ti.kernel
def substep_ori(x: ti.template(), v: ti.template(), C: ti.template(), J: ti.template(),
        grid_v: ti.template(), grid_m: ti.template(), n_particles: int):
    # One explicit MPM substep (P2G -> grid update -> G2P) for a single particle set,
    # using quadratic B-spline weights and an affine (C-matrix) momentum transfer.
    # `n_particles` is accepted but never used; the particle loops iterate over `x`.
    # Reset the grid.
    for i, j in grid_m:
        grid_v[i, j] = [0, 0]
        grid_m[i, j] = 0
    # P2G: scatter momentum (plus a pressure-like term driven by J - 1) and mass.
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)  # lower-left node of the 3x3 stencil
        fx = Xp - base  # particle offset from `base`, in cell units
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        stress = -dt * 4 * E * p_vol * (J[p] - 1) / dx**2  # isotropic; zero when J == 1
        affine = ti.Matrix([[stress, 0], [0, stress]]) + p_mass * C[p]
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            grid_v[base + offset] += weight * (p_mass * v[p] + affine @ dpos)
            grid_m[base + offset] += weight * p_mass
    # Grid update: momentum -> velocity, gravity, wall conditions.
    for i, j in grid_m:
        if grid_m[i, j] > 0:
            grid_v[i, j] /= grid_m[i, j]
        grid_v[i, j].y -= dt * gravity
        # Near each wall, zero the velocity component pointing into that wall.
        # NOTE(review): `i < bound` covers `bound` cells but `i > n_grid - bound`
        # covers `bound - 1`, so the two sides are not symmetric.
        if i < bound and grid_v[i, j].x < 0:
            grid_v[i, j].x = 0
        if i > n_grid - bound and grid_v[i, j].x > 0:
            grid_v[i, j].x = 0
        if j < bound and grid_v[i, j].y < 0:
            grid_v[i, j].y = 0
        if j > n_grid - bound and grid_v[i, j].y > 0:
            grid_v[i, j].y = 0
    # G2P: gather velocity and new C, advect, clamp position, update J.
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        new_v = ti.Vector.zero(float, 2)
        new_C = ti.Matrix.zero(float, 2, 2)
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            g_v = grid_v[base + offset]
            new_v += weight * g_v
            new_C += 4 * weight * g_v.outer_product(dpos) / dx**2
        v[p] = new_v
        x[p] += dt * v[p]
        # Clamp to [lower_bound, upper_bound]; presumably also keeps the 3x3
        # stencil inside the grid on the next substep.
        if x[p].x < lower_bound or x[p].x > upper_bound:
            x[p].x = ti.max(lower_bound, ti.min(upper_bound, x[p].x))
        if x[p].y < lower_bound or x[p].y > upper_bound:
            x[p].y = ti.max(lower_bound, ti.min(upper_bound, x[p].y))
        J[p] *= 1 + dt * new_C.trace()  # first-order update of J from trace(new_C)
        C[p] = new_C

@ti.kernel
def substep_initial(x: ti.template(), v: ti.template(), C: ti.template(), J: ti.template(),
        grid_v: ti.template(), grid_m: ti.template(), x_gt:ti.template(), v_gt:ti.template()):
    # Hand-off step between two particle sets.
    # The state (x, v, C, J) is scattered to the grid, and the resulting grid
    # velocity is gathered onto the second set (x_gt, v_gt), which is then advected
    # by one step. (x, v, C, J) are only read.
    #
    # The gather loop deliberately writes neither C nor J. Those fields are indexed
    # by the first set, while `p` in the last loop indexes the second set. Writing
    # C[p]/J[p] there would overwrite unrelated particles. When x_gt is larger than
    # x (as in run_simulation_mpm), it would also write past the end of the fields.
    for i, j in grid_m:
        grid_v[i, j] = [0, 0]
        grid_m[i, j] = 0
    # P2G from the first particle set.
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        stress = -dt * 4 * E * p_vol * (J[p] - 1) / dx**2
        affine = ti.Matrix([[stress, 0], [0, stress]]) + p_mass * C[p]
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            grid_v[base + offset] += weight * (p_mass * v[p] + affine @ dpos)
            grid_m[base + offset] += weight * p_mass
    # Grid update: momentum -> velocity, gravity, wall conditions.
    for i, j in grid_m:
        if grid_m[i, j] > 0:
            grid_v[i, j] /= grid_m[i, j]
        grid_v[i, j].y -= dt * gravity
        if i < bound and grid_v[i, j].x < 0:
            grid_v[i, j].x = 0
        if i > n_grid - bound and grid_v[i, j].x > 0:
            grid_v[i, j].x = 0
        if j < bound and grid_v[i, j].y < 0:
            grid_v[i, j].y = 0
        if j > n_grid - bound and grid_v[i, j].y > 0:
            grid_v[i, j].y = 0
    # G2P onto the second particle set: velocity and position only.
    for p in x_gt:
        Xp = x_gt[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        new_v = ti.Vector.zero(float, 2)
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            weight = w[i].x * w[j].y
            new_v += weight * grid_v[base + offset]
        v_gt[p] = new_v
        x_gt[p] += dt * v_gt[p]
        if x_gt[p].x < lower_bound or x_gt[p].x > upper_bound:
            x_gt[p].x = ti.max(lower_bound, ti.min(upper_bound, x_gt[p].x))
        if x_gt[p].y < lower_bound or x_gt[p].y > upper_bound:
            x_gt[p].y = ti.max(lower_bound, ti.min(upper_bound, x_gt[p].y))

@ti.kernel
def substep_mpm(x: ti.template(), v: ti.template(), C: ti.template(), J: ti.template(),
        grid_v: ti.template(), grid_m: ti.template()):
    # One explicit MPM substep (P2G -> grid update -> G2P) for a single particle set.
    # Same computation as `substep_ori`, without the unused `n_particles` argument.
    # Reset the grid.
    for i, j in grid_m:
        grid_v[i, j] = [0, 0]
        grid_m[i, j] = 0
    # P2G: quadratic B-spline weights over a 3x3 stencil; momentum includes the
    # affine term (pressure-like (J - 1) stress + p_mass * C).
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        stress = -dt * 4 * E * p_vol * (J[p] - 1) / dx**2
        affine = ti.Matrix([[stress, 0], [0, stress]]) + p_mass * C[p]
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            grid_v[base + offset] += weight * (p_mass * v[p] + affine @ dpos)
            grid_m[base + offset] += weight * p_mass
    # Grid update: momentum -> velocity, gravity, zero wall-ward velocity within
    # `bound` cells of each wall.
    for i, j in grid_m:
        if grid_m[i, j] > 0:
            grid_v[i, j] /= grid_m[i, j]
        grid_v[i, j].y -= dt * gravity
        if i < bound and grid_v[i, j].x < 0:
            grid_v[i, j].x = 0
        if i > n_grid - bound and grid_v[i, j].x > 0:
            grid_v[i, j].x = 0
        if j < bound and grid_v[i, j].y < 0:
            grid_v[i, j].y = 0
        if j > n_grid - bound and grid_v[i, j].y > 0:
            grid_v[i, j].y = 0
    # G2P: gather velocity and C, advect, clamp to [lower_bound, upper_bound], update J.
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        new_v = ti.Vector.zero(float, 2)
        new_C = ti.Matrix.zero(float, 2, 2)
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            g_v = grid_v[base + offset]
            new_v += weight * g_v
            new_C += 4 * weight * g_v.outer_product(dpos) / dx**2
        v[p] = new_v
        x[p] += dt * v[p]
        if x[p].x < lower_bound or x[p].x > upper_bound:
            x[p].x = ti.max(lower_bound, ti.min(upper_bound, x[p].x))
        if x[p].y < lower_bound or x[p].y > upper_bound:
            x[p].y = ti.max(lower_bound, ti.min(upper_bound, x[p].y))
        J[p] *= 1 + dt * new_C.trace()
        C[p] = new_C

@ti.kernel
def substep_mpm_doubleline(x: ti.template(), v: ti.template(), C: ti.template(), J: ti.template(),
        grid_v: ti.template(), grid_m: ti.template(), x_sparse: ti.template(), v_sparse: ti.template(), 
        C_sparse: ti.template(), J_sparse: ti.template()):
    # One MPM substep driven by the primary set (x, v, C, J).
    # Only the primary set is scattered to the grid. After the grid update, BOTH
    # the primary set and the secondary `*_sparse` set gather from that same grid
    # and are advected. The secondary set therefore follows the flow but never
    # influences it; its C_sparse/J_sparse are updated, but not read in this kernel.
    for i, j in grid_m:
        grid_v[i, j] = [0, 0]
        grid_m[i, j] = 0
    # P2G from the primary set only.
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        stress = -dt * 4 * E * p_vol * (J[p] - 1) / dx**2
        affine = ti.Matrix([[stress, 0], [0, stress]]) + p_mass * C[p]
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            grid_v[base + offset] += weight * (p_mass * v[p] + affine @ dpos)
            grid_m[base + offset] += weight * p_mass
    # Grid update: momentum -> velocity, gravity, wall conditions.
    for i, j in grid_m:
        if grid_m[i, j] > 0:
            grid_v[i, j] /= grid_m[i, j]
        grid_v[i, j].y -= dt * gravity
        if i < bound and grid_v[i, j].x < 0:
            grid_v[i, j].x = 0
        if i > n_grid - bound and grid_v[i, j].x > 0:
            grid_v[i, j].x = 0
        if j < bound and grid_v[i, j].y < 0:
            grid_v[i, j].y = 0
        if j > n_grid - bound and grid_v[i, j].y > 0:
            grid_v[i, j].y = 0
    # g2p-high resolution
    # (primary set: gather, advect, clamp, update J and C)
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        new_v = ti.Vector.zero(float, 2)
        new_C = ti.Matrix.zero(float, 2, 2)
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            g_v = grid_v[base + offset]
            new_v += weight * g_v
            new_C += 4 * weight * g_v.outer_product(dpos) / dx**2
        v[p] = new_v
        x[p] += dt * v[p]
        if x[p].x < lower_bound or x[p].x > upper_bound:
            x[p].x = ti.max(lower_bound, ti.min(upper_bound, x[p].x))
        if x[p].y < lower_bound or x[p].y > upper_bound:
            x[p].y = ti.max(lower_bound, ti.min(upper_bound, x[p].y))
        J[p] *= 1 + dt * new_C.trace()
        C[p] = new_C
    # g2p-low resolution
    # (secondary set: same gather from the primary-built grid)
    for p in x_sparse:
        Xp = x_sparse[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        new_v = ti.Vector.zero(float, 2)
        new_C = ti.Matrix.zero(float, 2, 2)
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            g_v = grid_v[base + offset]
            new_v += weight * g_v
            new_C += 4 * weight * g_v.outer_product(dpos) / dx**2
        v_sparse[p] = new_v
        x_sparse[p] += dt * v_sparse[p]
        if x_sparse[p].x < lower_bound or x_sparse[p].x > upper_bound:
            x_sparse[p].x = ti.max(lower_bound, ti.min(upper_bound, x_sparse[p].x))
        if x_sparse[p].y < lower_bound or x_sparse[p].y > upper_bound:
            x_sparse[p].y = ti.max(lower_bound, ti.min(upper_bound, x_sparse[p].y))
        J_sparse[p] *= 1 + dt * new_C.trace()
        C_sparse[p] = new_C

@ti.kernel
def substep(x: ti.template(), grid_m: ti.template()):
    # Particle-to-grid mass splat only (no velocity, no time integration).
    # Despite the name, this does not advance the simulation. It zeroes `grid_m`,
    # then deposits `p_mass` from every particle in `x` onto its 3x3 neighbourhood
    # with quadratic B-spline weights.
    for i, j in grid_m:
        grid_m[i, j] = 0
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)  # lower-left node of the 3x3 stencil
        fx = Xp - base  # offset from `base`, in cell units
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            weight = w[i].x * w[j].y
            grid_m[base + offset] += weight * p_mass

    

def initialize_particles(shape_type, n_particles):
    """Sample ``n_particles`` 2-D positions uniformly inside a randomly placed shape.

    The shape is sized and positioned so it lies inside
    ``[lower_bound, upper_bound]`` on both axes. A final clip guards against
    floating-point overshoot.

    Args:
        shape_type: 0 = axis-aligned rectangle, 1 = disc,
            2 = randomly rotated isosceles triangle.
        n_particles: number of points to sample.

    Returns:
        Array of shape (n_particles, 2).

    Raises:
        ValueError: if ``shape_type`` is not 0, 1 or 2.
    """
    available_size = upper_bound - lower_bound

    if shape_type == 0:
        max_shape_size = available_size * 0.6
        width = np.random.uniform(0.1, max_shape_size)
        height = np.random.uniform(0.1, max_shape_size)

        # Keep the centre far enough from the walls for the whole rectangle to fit.
        min_x = lower_bound + width/2
        max_x = upper_bound - width/2
        min_y = lower_bound + height/2
        max_y = upper_bound - height/2

        center_x = np.random.uniform(min_x, max_x)
        center_y = np.random.uniform(min_y, max_y)

        x_np = np.random.rand(n_particles, 2).astype(np.float32) * [width, height] + [center_x - width/2, center_y - height/2]

    elif shape_type == 1:
        max_radius = available_size * 0.3
        radius = np.random.uniform(0.05, max_radius)

        min_x = lower_bound + radius
        max_x = upper_bound - radius
        min_y = lower_bound + radius
        max_y = upper_bound - radius

        center = np.array([
            np.random.uniform(min_x, max_x),
            np.random.uniform(min_y, max_y)
        ])

        # sqrt of a uniform variate gives an area-uniform radius.
        angles = np.random.rand(n_particles) * 2 * np.pi
        rads = np.sqrt(np.random.rand(n_particles)) * radius
        x_np = np.column_stack([
            center[0] + rads * np.cos(angles),
            center[1] + rads * np.sin(angles)
        ])

    elif shape_type == 2:
        max_base_size = available_size * 0.4
        base_size = np.random.uniform(0.1, max_base_size)

        # The farthest vertices, (+-b/2, -b/2), lie b/sqrt(2) from the centre, so
        # the margin must be at least that for every rotation to stay in bounds.
        safe_margin = base_size * np.sqrt(2) / 2
        min_x = lower_bound + safe_margin
        max_x = upper_bound - safe_margin
        min_y = lower_bound + safe_margin
        max_y = upper_bound - safe_margin

        center = np.array([
            np.random.uniform(min_x, max_x),
            np.random.uniform(min_y, max_y)
        ])

        angle = np.random.uniform(0, 2*np.pi)
        rot_matrix = np.array([
            [np.cos(angle), -np.sin(angle)],
            [np.sin(angle), np.cos(angle)]
        ])

        vertices = np.array([
            [-base_size/2, -base_size/2],
            [base_size/2, -base_size/2],
            [0, base_size/2]
        ])

        vertices = vertices @ rot_matrix.T + center

        # Uniform barycentric sampling: reflect points from the upper half of the
        # unit square back into the lower-left triangle.
        u = np.random.rand(n_particles, 2).astype(np.float32)
        mask = u.sum(1) > 1
        u[mask] = 1 - u[mask]
        x_np = (u[:, 0:1] * vertices[0] +
                u[:, 1:2] * vertices[1] +
                (1 - u.sum(1, keepdims=True)) * vertices[2])

    else:
        raise ValueError(f"unknown shape_type {shape_type!r}; expected 0, 1 or 2")

    return np.clip(x_np, lower_bound, upper_bound)

def initialize_fields_values(C, J):
    """Reset per-particle state: every J entry to 1 and every C matrix to zero.

    ``Field.fill`` launches its own kernel, so it is called directly from
    Python scope here rather than from inside a ``@ti.kernel``.
    """
    for field, value in ((J, 1), (C, 0)):
        field.fill(value)

def initialize_fields_values_2(v, C, J):
    """Reset per-particle state: velocities v to zero, J to 1, and C to zero.

    ``Field.fill`` launches its own kernel, so it is called directly from
    Python scope here rather than from inside a ``@ti.kernel``.
    """
    for field, value in ((v, 0), (J, 1), (C, 0)):
        field.fill(value)

def run_simulation_mpm(next_position, next_velocity, next_position_dense, output_dir="thershold_gns_mpm", n_frames=50):
    """Run a two-resolution MPM rollout and render both particle sets to video.

    The sparse set (``next_position`` / ``next_velocity``) seeds the grid in one
    ``substep_initial`` call, which transfers its velocity onto the dense set
    (``next_position_dense``). From then on the dense set drives the simulation
    through ``substep_mpm_doubleline``. The sparse set is only advected through
    the dense grid. After the first frame, the sparse positions are replaced by
    ``sample_points`` drawn from the dense positions.

    Args:
        next_position: torch tensor of sparse positions, presumably (N_sparse, 2).
        next_velocity: torch tensor of the same shape, given as position change
            per frame; it is divided by ``dt * steps_per_frame`` to get velocity.
        next_position_dense: torch tensor of dense positions, presumably (N_dense, 2).
        output_dir: directory receiving ``sim_1000.mp4`` (dense) and
            ``sim_1001.mp4`` (sparse); created if missing.
        n_frames: number of recorded frames (each is ``steps_per_frame`` substeps).

    Returns:
        ``(positions, velocities, positions_sparse)``: per-frame lists of NumPy
        arrays. ``velocities`` are dense-set velocities rescaled to position
        change per frame.
    """
    next_velocity = next_velocity / (dt * steps_per_frame)
    n_particles = int(next_position.shape[0])
    n_particles_dense = int(next_position_dense.shape[0])

    x, v, C, J, grid_v, grid_m = initialize_fields(n_particles)
    x_gt, v_gt, C_gt, J_gt, grid_v_gt, grid_m_gt = initialize_fields(n_particles_dense)
    x.from_numpy(next_position.cpu().numpy())
    v.from_numpy(next_velocity.cpu().numpy())
    initialize_fields_values(C, J)
    x_gt.from_numpy(next_position_dense.cpu().numpy())
    initialize_fields_values_2(v_gt, C_gt, J_gt)

    # Sparse -> grid -> dense velocity hand-off.
    substep_initial(x, v, C, J, grid_v, grid_m, x_gt, v_gt)

    positions = []
    velocities = []
    positions_sparse = []
    for frame in range(n_frames):
        for _ in range(steps_per_frame):
            substep_mpm_doubleline(x_gt, v_gt, C_gt, J_gt, grid_v_gt, grid_m_gt, x, v, C, J)

        pos = x_gt.to_numpy()
        positions.append(pos)
        velocities.append(v_gt.to_numpy() * dt * steps_per_frame)

        if frame == 0:
            # Re-seed the sparse set from the dense state after the first frame.
            # NOTE(review): `sample_points` is not defined or imported in this file.
            # The sparse v/C/J are left as-is for the new positions.
            x_sparse = sample_points(torch.tensor(pos), n_particles)
            x.from_numpy(x_sparse.cpu().numpy())

        positions_sparse.append(x.to_numpy())

    os.makedirs(output_dir, exist_ok=True)
    save_video(1000, positions, output_dir)
    save_video(1001, positions_sparse, output_dir)
    return positions, velocities, positions_sparse

import time
def run_simulation_mpm_thershold(next_position, next_velocity, step=50, output_dir="Gns_mpm_test"):
    """Roll out single-resolution MPM from a given state and return per-frame snapshots.

    Args:
        next_position: torch tensor of initial positions, presumably (N, 2).
        next_velocity: torch tensor of the same shape, given as position change
            per frame; it is divided by ``dt * steps_per_frame`` to get velocity.
        step: number of recorded frames (each is ``steps_per_frame`` substeps).
            May be 0, in which case empty lists are returned.
        output_dir: currently unused (kept for interface compatibility).

    Returns:
        ``(positions, velocities)``: per-frame lists of NumPy arrays.
        Velocities are rescaled to position change per frame.
    """
    n_particles = int(next_position.shape[0])
    next_velocity = next_velocity / (dt * steps_per_frame)

    x, v, C, J, grid_v, grid_m = initialize_fields(n_particles)
    x.from_numpy(next_position.cpu().numpy())
    v.from_numpy(next_velocity.cpu().numpy())
    initialize_fields_values(C, J)

    positions = []
    velocities = []
    for _frame in range(step):
        for _substep in range(steps_per_frame):
            substep_ori(x, v, C, J, grid_v, grid_m, n_particles)
        positions.append(x.to_numpy())
        velocities.append(v.to_numpy() * dt * steps_per_frame)

    return positions, velocities



def _normalized_grid_mass(x_field, grid_m_field, positions):
    """Splat ``positions`` onto ``grid_m_field`` and return it as a torch tensor summing to 1."""
    x_field.from_numpy(positions.cpu().numpy())
    substep(x_field, grid_m_field)
    grid_mass = torch.tensor(grid_m_field.to_numpy())
    return grid_mass / grid_mass.sum()


def run_simulation(data, data_gt, output_dir="Water_drop_simulations"):
    """Compute per-frame normalized grid-mass fields for a rollout and its reference.

    Each frame of ``data`` is matched to a Taichi field by particle count. That
    count must be either ``data[0].shape[0]`` (sparse) or
    ``data_gt.shape[1]`` (dense). Frames with any other count are skipped with a
    warning. The matching ``data_gt[frame]`` is always splatted with the dense
    field.

    Args:
        data: sequence of torch tensors of positions, one per frame.
        data_gt: indexable by frame; each ``data_gt[frame]`` is a torch tensor
            of reference positions.
        output_dir: currently unused (kept for interface compatibility).

    Returns:
        ``(grid_ms, grid_m_gts)``: stacked tensors of shape
        (n_used_frames, n_grid, n_grid), each frame normalized to sum 1.

    Raises:
        ValueError: if no frame had a recognised particle count.
    """
    n_particles = int(data[0].shape[0])
    n_particles_dense = int(data_gt.shape[1])

    # Only the position and grid-mass fields are used here.
    x, _, _, _, _, grid_m = initialize_fields(n_particles)
    x_gt, _, _, _, _, grid_m_gt = initialize_fields(n_particles_dense)
    x_dense, _, _, _, _, grid_m_dense = initialize_fields(n_particles_dense)

    grid_ms = []
    grid_m_gts = []
    for frame in range(len(data)):
        frame_positions = data[frame]
        count = frame_positions.shape[0]
        if count == n_particles:
            x_pred, grid_pred = x, grid_m
        elif count == n_particles_dense:
            x_pred, grid_pred = x_dense, grid_m_dense
        else:
            print('particle number warning')
            continue

        grid_ms.append(_normalized_grid_mass(x_pred, grid_pred, frame_positions))
        grid_m_gts.append(_normalized_grid_mass(x_gt, grid_m_gt, data_gt[frame]))

    if not grid_ms:
        raise ValueError("run_simulation: no frame matched the expected particle counts")

    return torch.stack(grid_ms), torch.stack(grid_m_gts)


# Scalar accumulator summed into by `compute_y` (with gradient storage).
y = ti.field(dtype=ti.f32, shape=(), needs_grad=True)
# Scalar field with gradient storage; not referenced in the visible code.
grid = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

@ti.kernel
def compute_diff(grid_m: ti.template(), grid_m_gt: ti.template(), temp: ti.template()):
    # Element-wise squared difference of two grid fields, written into `temp`.
    for row, col in grid_m:
        delta = grid_m[row, col] - grid_m_gt[row, col]
        temp[row, col] = delta * delta
    
@ti.kernel
def compute_y(temp: ti.template()):
    # Sum every entry of the 2-D field `temp` into the module-level scalar `y`.
    # NOTE: `y` is not reset here, so repeated calls accumulate. Set
    # `y[None] = 0` beforehand if a fresh sum is wanted.
    for I in ti.ndrange(temp.shape[0], temp.shape[1]):
        ti.atomic_add(y[None], temp[I])

# Factory for an unplaced f32 scalar field (not used in the visible code).
scalar = lambda: ti.field(dtype=ti.f32)
# Grid-sized buffers filled by `init` and compared by `compute_loss`.
target = ti.field(dtype=ti.f32, shape=(n_grid, n_grid), needs_grad=True)
smoke = ti.field(dtype=ti.f32, shape=(n_grid, n_grid), needs_grad=True)
# Scalar loss accumulated by `compute_loss`; used as the ti.ad.Tape loss in
# run_simulation_single.
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)

@ti.kernel
def compute_loss():
    # Squared L2 distance between the module-level `target` and `smoke` fields,
    # accumulated into `loss[None]`. This kernel does not reset `loss`; in this
    # file it is run inside `ti.ad.Tape(loss)`.
    # Only the outer `i` loop is parallelized; the inner `j` loop runs serially.
    for i in range(n_grid):
        for j in range(n_grid):
            loss[None] += (target[i, j] - smoke[i, j])**2

@ti.kernel
def init(grid_m: ti.template(), grid_m_gt: ti.template()):
    # Copy two grid-mass fields into the module-level `target` / `smoke` buffers
    # that `compute_loss` compares.
    for row, col in ti.ndrange(n_grid, n_grid):
        target[row, col] = grid_m[row, col]
        smoke[row, col] = grid_m_gt[row, col]

# @ti.kernel
# def compute_y(grid_m: ti.template(), grid_m_gt:ti.template()):
#     y[None] = 0.0
#     for i, j in grid_m:
#         diff = grid_m[i, j] - grid_m_gt[i, j]
#         y[None] += diff * diff


def run_simulation_single(data, data_gt, output_dir="Water_drop_simulations"):
    """Return the squared-L2 loss between the grid-mass splats of two particle sets.

    Both sets are splatted onto the grid with ``substep``, copied into the
    module-level ``target``/``smoke`` buffers by ``init``, and compared by
    ``compute_loss`` inside ``ti.ad.Tape(loss)``.

    Args:
        data: positions, presumably a torch tensor of shape (N, 2).
        data_gt: reference positions, presumably (M, 2); M may differ from N.
        output_dir: currently unused (kept for interface compatibility).

    Returns:
        ``loss.to_torch()``, a tensor copy of the Taichi loss. It carries no
        autograd history back to ``data``.
    """
    def _to_numpy(points):
        # Accept torch tensors (detached, moved to CPU) as well as array-likes.
        if isinstance(points, torch.Tensor):
            return points.detach().cpu().numpy()
        return np.asarray(points)

    n_particles = int(data.shape[0])
    n_particles_gt = int(data_gt.shape[0])

    x, _, _, _, _, grid_m = initialize_fields(n_particles)
    x_gt, _, _, _, _, grid_m_gt = initialize_fields(n_particles_gt)

    # Copy positions into the Taichi fields. Assigning to `field.data` would
    # only set a Python attribute and leave the field contents untouched.
    x.from_numpy(_to_numpy(data))
    x_gt.from_numpy(_to_numpy(data_gt))

    substep(x_gt, grid_m_gt)
    substep(x, grid_m)

    init(grid_m, grid_m_gt)
    with ti.ad.Tape(loss):
        compute_loss()

    return loss.to_torch()

def save_video(sim_id, positions, output_dir):
    """Render per-frame particle positions to ``<output_dir>/sim_<sim_id>.mp4``.

    Args:
        sim_id: used in the file name and the (hidden) GUI title.
        positions: sequence of per-frame position arrays, presumably (N, 2) in
            [0, 1] coordinates; frames are stacked with ``np.array``.
        output_dir: existing directory for the video file.

    The off-screen GUI and the video writer are always closed, even if
    rendering fails part-way.
    """
    output_path = os.path.join(output_dir, f"sim_{sim_id}.mp4")

    gui = ti.GUI(f"Simulation {sim_id}", background_color=0xFFFFFF, show_gui=False)
    fps = 96
    writer = imageio.get_writer(output_path, fps=fps)
    try:
        for frame_positions in np.array(positions):
            gui.clear(0xFFFFFF)
            gui.circles(frame_positions, radius=1.5, color=0x068587)
            # get_image returns float values; rotate for display orientation, then quantize.
            image = np.rot90(gui.get_image(), k=1)
            writer.append_data((image * 255).astype(np.uint8))
    finally:
        writer.close()
        gui.close()

if __name__ == "__main__":
    # num_simulations = 12
    output_dir = 'Water_drop_xl_simulations_k-means_0.125'
    os.makedirs(output_dir, exist_ok=True)
    # Only referenced by the commented-out statistics code below.
    all_velocities = []
    all_accelerations = []

    Data_name = 'Water_drop_xl_simulations'
    split = 'train.npz'
    # NOTE(review): the dataset path is blank here and np.load('') will raise.
    # Fill it in (presumably built from Data_name and split -- confirm).
    npz_filename = f''

    # The NpzFile keeps the archive open; the context manager closes it.
    with np.load(npz_filename, allow_pickle=True) as data:
        num_simulations = len(data)

        for idx in range(num_simulations):
            trajectory_id = 1000 + idx if split == 'test.npz' else idx
            print(f"Running simulation {idx}/{num_simulations}")
            # Load the trajectory once; it serves as both prediction and reference here.
            trajectory = data[f'simulation_trajectory_{trajectory_id}']
            grid_ms, grid_m_gts = run_simulation(trajectory, trajectory, output_dir)
            # all_velocities.append(velocities)
            # all_accelerations.append(accelerations)


    # vel_stats = {
    #     "mean": np.mean(np.array([np.mean(np.array(v), axis=(0, 1)) for v in all_velocities]), axis=(0)),
    #     "std": np.std(np.array([np.std(np.array(v), axis=(0, 1)) for v in all_velocities]), axis=(0))
    # }
    
    # acc_stats = {
    #     "mean": np.mean(np.array([np.mean(np.array(a), axis=(0, 1)) for a in all_accelerations]), axis=(0)),
    #     "std": np.mean(np.array([np.std(np.array(a), axis=(0, 1)) for a in all_accelerations]), axis=(0))
    # }
    # metadata = {
    #     "bounds": [[lower_bound, upper_bound], [lower_bound, upper_bound]],
    #     "sequence_length": total_frames ,
    #     "default_connectivity_radius": 0.01,
    #     "dim": dim,
    #     "dt": dt*steps_per_frame,
    #     "vel_mean": vel_stats['mean'].tolist(),
    #     "vel_std": vel_stats['std'].tolist(),
    #     "acc_mean": acc_stats['mean'].tolist(),
    #     "acc_std": acc_stats['std'].tolist(),
    #     "obstacles_configs": [] # [[[,],[,]],], 
    # }
    
    # np.savez(os.path.join(output_dir, split), **save_data)
    # with open(os.path.join(output_dir, f'train.json'), 'w') as f:
    #     json.dump(metadata, f, indent=2)
    # print("All simulations completed!")