import taichi as ti
import numpy as np
import imageio
import json
import os
from math import sqrt, cos, sin, pi

@ti.data_oriented
class MPM_Simulator:
    def __init__(
        self,
        max_particles,
        min_particles,
        n_grid,
        dt,
        gravity,
        dim,
        lower_bound,
        upper_bound,
        bound,
        material_type='water',
        obstacle=None,
        n_obstacle_particles=300,
        line_thickness=0.008,
    ):
        """Set up simulation constants, material parameters and obstacles.

        Args:
            max_particles: Upper limit (inclusive) for the random particle
                count drawn by ``init_particles``.
            min_particles: Lower limit for the random particle count.
            n_grid: Number of grid cells per axis; ``dx`` is ``1 / n_grid``.
            dt: Time step used by ``substep``.
            gravity: Gravity magnitude; stored as a vector pointing along -y.
            dim: ``3`` selects a 3-component gravity vector, any other value
                a 2-component one.
            lower_bound: Lower limit particle coordinates are clamped to.
            upper_bound: Upper limit particle coordinates are clamped to.
            bound: Boundary thickness as a fraction of the grid; stored as
                ``int(bound * n_grid)`` cells.
            material_type: ``'water'`` or ``'sand'`` (case-insensitive).
            obstacle: ``None`` for no obstacles; a list or ``numpy`` array of
                line segments ``[x0, y0, x1, y1]`` (any shape that reshapes
                to ``(-1, 4)``); any other value generates random segments
                through ``init_obstacles``.
            n_obstacle_particles: Number of marker points sampled along the
                obstacle segments.
            line_thickness: Collision thickness of every obstacle segment.

        Raises:
            ValueError: If ``material_type`` is unknown, or if the obstacle
                segments are empty, have zero total length, or outnumber
                ``n_obstacle_particles``.
        """
        # Common simulation parameters
        self.max_particles = max_particles
        self.min_particles = min_particles
        self.n_particles = None  # assigned by init_particles()
        self.n_grid = n_grid
        self.dx, self.inv_dx = 1 / self.n_grid, float(self.n_grid)
        self.dt = dt
        if dim == 3:
            self.gravity = ti.Vector([0, -gravity, 0])
        else:
            self.gravity = ti.Vector([0, -gravity])
        self.dim = dim
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.bound = int(bound * self.n_grid)

        # Random constant force components, applied by substep() when its
        # force_control flag is set.
        self.force_x = np.random.uniform(-20, 20)
        self.force_y = np.random.uniform(-20, 20)
        self.force_z = np.random.uniform(-20, 20) if dim == 3 else 0

        # Material specific parameters
        self.material_type = material_type.lower()
        if self.material_type == 'water':
            self._init_water_params()
        elif self.material_type == 'sand':
            self._init_sand_params()
        else:
            raise ValueError("Material type must be 'water' or 'sand'")

        self.line_thickness = line_thickness

        if obstacle is None:
            # No obstacles: keep a one-element placeholder so the `lines`
            # field always exists; num_lines == 0 marks it as unused.
            self.num_lines = ti.field(dtype=ti.i32, shape=())
            self.num_lines[None] = 0
            self.lines = ti.Vector.field(4, dtype=ti.f32, shape=1)
            return

        self.n_obstacle_particles = n_obstacle_particles
        if isinstance(obstacle, (list, np.ndarray)):
            if isinstance(obstacle, list):
                lines_np = np.array(obstacle, dtype=np.float32)
            else:
                lines_np = obstacle
            # Always normalise to (n_lines, 4) so that a single flat segment
            # [x0, y0, x1, y1] is accepted as well.
            lines_np = lines_np.reshape(-1, 4)
            self.line_lengths = [
                sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)
                for x0, y0, x1, y1 in lines_np
            ]
        else:
            lines_np, self.line_lengths = self.init_obstacles()

        obstacle_positions = self._sample_obstacle_particles(lines_np)

        self.num_lines = ti.field(dtype=ti.i32, shape=())
        self.num_lines[None] = lines_np.shape[0]
        self.lines = ti.Vector.field(4, dtype=ti.f32, shape=self.num_lines[None])
        self.lines.from_numpy(lines_np)
        # NOTE(review): the sampled positions have two columns while the
        # field has `dim` components, so this presumably only works for
        # dim == 2 - confirm before combining obstacles with 3D runs.
        self.ob_x = ti.Vector.field(self.dim, dtype=ti.f32, shape=self.n_obstacle_particles)
        self.ob_x.from_numpy(obstacle_positions)

    def _sample_obstacle_particles(self, lines_np):
        """Spread ``n_obstacle_particles`` marker points over the segments.

        Each segment receives a share proportional to its length (at least
        one point); the rounding remainder is assigned to the first segment.
        If that would leave the first segment without points, the surplus is
        removed from the segments holding the most points instead.

        Args:
            lines_np: Array of shape ``(n_lines, 4)`` with rows
                ``[x0, y0, x1, y1]``, matching ``self.line_lengths``.

        Returns:
            Array of shape ``(n_obstacle_particles, 2)`` with the points.

        Raises:
            ValueError: If there are no segments, their total length is not
                positive, or there are fewer points than segments.
        """
        n_total = self.n_obstacle_particles
        n_lines = len(self.line_lengths)
        total_length = sum(self.line_lengths)
        if n_lines == 0 or total_length <= 0:
            raise ValueError("Obstacle segments must have a positive total length")
        if n_total < n_lines:
            raise ValueError(
                "n_obstacle_particles must be at least the number of obstacle segments"
            )

        counts = [
            max(1, int(n_total * length / total_length))
            for length in self.line_lengths
        ]
        remainder = n_total - sum(counts)
        if counts[0] + remainder >= 1:
            counts[0] += remainder
        else:
            # The minimum of one point per segment overshot the budget.
            # sum(counts) > n_total >= n_lines guarantees a count above 1.
            while remainder < 0:
                fullest = max(range(n_lines), key=counts.__getitem__)
                counts[fullest] -= 1
                remainder += 1

        samples = []
        for line, count in zip(lines_np, counts):
            t = np.linspace(0, 1, count)
            samples.append(np.outer(1 - t, line[:2]) + np.outer(t, line[2:4]))
        return np.concatenate(samples)

    # @ti.kernel
    # def _update_lines_field(self, lines_np: ti.types.ndarray()):
    #     num_lines = self.num_lines[None]
    #     self.lines.from_numpy(lines_np)
        # for i in range(num_lines):
        #     self.lines[i] = ti.Vector([lines_np[i,0], lines_np[i,1], lines_np[i,2], lines_np[i,3]])


    def init_particles(self, x=None, v=None, random_v=False):
        """Allocate the particle fields and fill in the initial state.

        Args:
            x: Optional array of initial positions; its first dimension
                determines the particle count. When omitted, a random count
                in ``[min_particles, max_particles]`` and a random shape
                are generated instead.
            v: Optional array of initial velocities; only used together
                with ``x``.
            random_v: Forwarded to ``init_particles_shape`` when particles
                are generated randomly.
        """
        if x is None:
            # Fully random initialisation: count first, then the shape type.
            self.n_particles = np.random.randint(self.min_particles, self.max_particles + 1)
            self.initialize_fields()
            self.init_particles_shape(np.random.randint(0, 3), random_v)
            return

        # Caller-provided state.
        self.n_particles = x.shape[0]
        self.initialize_fields()
        self.x.from_numpy(x)
        if v is not None:
            self.v.from_numpy(v)

        
    def init_particles_shape(self, shape_type, random_v):
        """Fill ``self.x`` with random points inside a randomly placed shape.

        Args:
            shape_type: ``0`` for an axis-aligned box, ``1`` for a sphere
                (3D) or disc (2D), ``2`` for a randomly rotated tetrahedron
                (3D) or triangle (2D).
            random_v: When true, every particle also receives the same
                random velocity with a speed drawn from ``[0.05, 0.3]``.

        Raises:
            ValueError: If ``shape_type`` is not 0, 1 or 2.
        """
        n = self.n_particles
        dim = self.dim
        lo, hi = self.lower_bound, self.upper_bound
        available_size = hi - lo

        if shape_type == 0:  # Box
            max_shape_size = available_size * 0.8
            extents = np.array(
                [np.random.uniform(0.1, max_shape_size) for _ in range(dim)]
            )
            # Keep the centre far enough from the walls for the box to fit.
            center = np.array(
                [np.random.uniform(lo + e / 2, hi - e / 2) for e in extents]
            )
            x_np = np.random.rand(n, dim).astype(np.float32) * extents + (center - extents / 2)

        elif shape_type == 1:  # Sphere (3D) or Circle (2D)
            max_radius = available_size * 0.4
            radius = np.random.uniform(0.05, max_radius)
            center = np.array(
                [np.random.uniform(lo + radius, hi - radius) for _ in range(dim)]
            )

            if dim == 3:
                # Uniform sampling in a ball: uniform direction, r ~ U^(1/3).
                angles = np.random.rand(n, 2)
                theta = angles[:, 0] * 2 * np.pi
                phi = np.arccos(2 * angles[:, 1] - 1)
                rads = np.random.rand(n) ** (1 / 3) * radius
                x_np = np.column_stack([
                    center[0] + rads * np.sin(phi) * np.cos(theta),
                    center[1] + rads * np.sin(phi) * np.sin(theta),
                    center[2] + rads * np.cos(phi),
                ])
            else:
                # Uniform sampling in a disc: r ~ sqrt(U).
                angles = np.random.rand(n) * 2 * np.pi
                rads = np.sqrt(np.random.rand(n)) * radius
                x_np = np.column_stack([
                    center[0] + rads * np.cos(angles),
                    center[1] + rads * np.sin(angles),
                ])

        elif shape_type == 2:  # Tetrahedron (3D) or Triangle (2D)
            max_base_size = available_size * 0.6
            base_size = np.random.uniform(0.1, max_base_size)

            safe_margin = base_size * 0.6
            center = np.array(
                [np.random.uniform(lo + safe_margin, hi - safe_margin) for _ in range(dim)]
            )

            if dim == 3:
                angle_x, angle_y, angle_z = (
                    np.random.uniform(0, 2 * np.pi) for _ in range(3)
                )

                # Rotation matrices about the x, y and z axes
                Rx = np.array([
                    [1, 0, 0],
                    [0, np.cos(angle_x), -np.sin(angle_x)],
                    [0, np.sin(angle_x), np.cos(angle_x)],
                ])
                Ry = np.array([
                    [np.cos(angle_y), 0, np.sin(angle_y)],
                    [0, 1, 0],
                    [-np.sin(angle_y), 0, np.cos(angle_y)],
                ])
                Rz = np.array([
                    [np.cos(angle_z), -np.sin(angle_z), 0],
                    [np.sin(angle_z), np.cos(angle_z), 0],
                    [0, 0, 1],
                ])
                rot_matrix = Rz @ Ry @ Rx

                # Tetrahedron vertices
                vertices = np.array([
                    [1, 1, 1],
                    [-1, -1, 1],
                    [-1, 1, -1],
                    [1, -1, -1],
                ]) * (base_size / 2)
                vertices = vertices @ rot_matrix.T + center

                # Uniform barycentric weights. Reflecting u -> 1 - u (the 2D
                # trick) does not work with three coordinates: a sum in
                # (1, 2) stays above 1 and yields a negative fourth weight,
                # i.e. points outside the tetrahedron. Dirichlet(1, 1, 1, 1)
                # is uniform on the simplex.
                weights = np.random.dirichlet(np.ones(4), size=n)
                x_np = weights @ vertices
            else:
                angle = np.random.uniform(0, 2 * np.pi)
                rot_matrix = np.array([
                    [np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)],
                ])

                vertices = np.array([
                    [-base_size / 2, -base_size / 2],
                    [base_size / 2, -base_size / 2],
                    [0, base_size / 2],
                ])
                vertices = vertices @ rot_matrix.T + center

                # Sample the unit square and fold the upper half back onto
                # the triangle.
                u = np.random.rand(n, 2).astype(np.float32)
                mask = u.sum(1) > 1
                u[mask] = 1 - u[mask]
                x_np = (u[:, 0:1] * vertices[0] +
                        u[:, 1:2] * vertices[1] +
                        (1 - u.sum(1, keepdims=True)) * vertices[2])
        else:
            raise ValueError("shape_type must be 0 (box), 1 (sphere) or 2 (tetrahedron)")

        # Rotated shapes may poke through the walls; pull those points back.
        x_np = np.clip(x_np, lo, hi)

        self.x.from_numpy(x_np)
        if random_v:
            random_speed = np.random.uniform(0.05, 0.3)
            if dim == 3:
                # Random direction in 3D
                theta = np.random.uniform(0, 2 * np.pi)
                phi = np.arccos(2 * np.random.uniform() - 1)
                vx = random_speed * np.sin(phi) * np.cos(theta)
                vy = random_speed * np.sin(phi) * np.sin(theta)
                vz = random_speed * np.cos(phi)
                v_np = np.full((n, dim), [vx, vy, vz])
            else:
                random_angle = np.random.uniform(0, 2 * np.pi)
                vx = random_speed * np.cos(random_angle)
                vy = random_speed * np.sin(random_angle)
                v_np = np.full((n, dim), [vx, vy])
            self.v.from_numpy(v_np)

    def init_obstacles(self):
        """Generate between one and three random obstacle line segments.

        The x coordinates are drawn from ``[lower_bound, upper_bound]`` and
        the y coordinates from ``[lower_bound, lower_bound + 0.4]``. The end
        point is redrawn until the squared segment length is at least 0.04.

        Returns:
            A tuple ``(lines_np, line_lengths)``: a float32 array of shape
            ``(n_lines, 4)`` with rows ``[start_x, start_y, end_x, end_y]``
            and a list with the length of every segment.
        """
        x_lo, x_hi = self.lower_bound, self.upper_bound
        y_lo, y_hi = self.lower_bound, self.lower_bound + 0.4

        segments = []
        segment_lengths = []
        for _ in range(np.random.randint(1, 4)):
            x0 = np.random.uniform(x_lo, x_hi)
            y0 = np.random.uniform(y_lo, y_hi)
            x1 = np.random.uniform(x_lo, x_hi)
            y1 = np.random.uniform(y_lo, y_hi)

            # Redraw the end point until the segment is long enough.
            length_sq = (x1 - x0)**2 + (y1 - y0)**2
            while length_sq < 0.04:
                x1 = np.random.uniform(x_lo, x_hi)
                y1 = np.random.uniform(y_lo, y_hi)
                length_sq = (x1 - x0)**2 + (y1 - y0)**2

            segments.append([x0, y0, x1, y1])
            segment_lengths.append(np.sqrt(length_sq))

        return np.array(segments, dtype=np.float32), segment_lengths
        
    def _init_water_params(self):
        """Set the material constants used when simulating water."""
        # Particle volume and mass
        self.rho = 1  # Water density
        self.p_vol = (self.dx * 0.5) ** 2
        self.mass = self.p_vol * self.rho

        # Elastic constants: Young's modulus, Poisson's ratio and the Lame
        # parameters derived from them
        self.E = 400
        self.nu = 0.25
        one_plus_nu = 1 + self.nu
        self.mu = self.E / (2 * one_plus_nu)
        self.lambda_ = self.E * self.nu / (one_plus_nu * (1 - 2 * self.nu))

        self.mu_b = 0.0  # Friction coefficient

        # Hardening model parameters
        self.h0 = 35
        self.h1 = 9
        self.h2 = 0.2
        self.h3 = 10
        self.pi = 3.14159265358979
        
    def _init_sand_params(self):
        """Set the material constants used when simulating sand."""
        # Particle volume and mass
        self.rho = 400  # Sand density
        self.p_vol = (self.dx * 0.5) ** 2
        self.mass = self.p_vol * self.rho

        # Elastic constants: Young's modulus, Poisson's ratio and the Lame
        # parameters derived from them
        self.E = 4.0e5
        self.nu = 0.25
        one_plus_nu = 1 + self.nu
        self.mu = self.E / (2 * one_plus_nu)
        self.lambda_ = self.E * self.nu / (one_plus_nu * (1 - 2 * self.nu))

        self.mu_b = 1.0  # Friction coefficient

        # Hardening model parameters
        self.h0 = 35
        self.h1 = 9
        self.h2 = 0.2
        self.h3 = 10
        self.pi = 3.14159265358979
        

    def initialize_fields(self):
        """Allocate all fields, then set the start values for the material."""
        self._initialize_fields()
        if self.material_type != 'water':
            self._initialize_sand_fields_values()
            return
        self._initialize_water_fields_values()

 
    @ti.kernel
    def _initialize_water_fields_values(self):
        """Set the water start state: J to 1, C and v to zero."""
        self.J.fill(1)
        self.C.fill(0)
        self.v.fill(0)

    def _initialize_fields(self):
        """Allocate every particle and grid field for the current setup.

        Particle fields hold ``n_particles`` entries; grid fields have
        ``n_grid`` cells along each of the ``dim`` axes.
        """
        count = self.n_particles
        d = self.dim
        real = ti.f32
        grid_shape = (self.n_grid, ) * d

        # Particle fields shared by both materials
        self.x = ti.Vector.field(d, dtype=real, shape=count)
        self.v = ti.Vector.field(d, dtype=real, shape=count)
        self.C = ti.Matrix.field(d, d, dtype=real, shape=count)
        self.J = ti.field(dtype=real, shape=count)

        # Particle fields read and written by the sand model and colouring
        self.F = ti.Matrix.field(d, d, dtype=real, shape=count)
        self.phi = ti.field(dtype=real, shape=count)
        self.c_C0 = ti.field(dtype=real, shape=count)
        self.vc = ti.field(dtype=real, shape=count)
        self.alpha = ti.field(dtype=real, shape=count)
        self.q = ti.field(dtype=real, shape=count)
        self.state = ti.field(dtype=int, shape=count)
        self.color = ti.field(dtype=int, shape=count)

        # Background grid: velocity/momentum, mass and force
        self.grid_v = ti.Vector.field(d, dtype=real, shape=grid_shape)
        self.grid_m = ti.field(dtype=real, shape=grid_shape)
        self.grid_f = ti.Vector.field(d, dtype=real, shape=grid_shape)

        # Per-particle acceleration added in substep() when apply_control is set
        self.control_accel = ti.Vector.field(d, dtype=real, shape=count)

        
    @ti.kernel
    def _initialize_sand_fields_values(self):
        """Set the sand start state: F to identity, v to zero, c_C0 to 0.1, alpha to 0.2."""
        self.v.fill(0)
        self.F.fill(ti.Matrix.identity(ti.f32, self.dim))
        self.alpha.fill(0.2)
        self.c_C0.fill(0.1)

        
    @ti.func
    def point_to_line_distance(self, p, a, b):
        """Return the distance from point p to the line segment a-b.

        The projection parameter is clamped to [0, 1], so points beyond an
        end of the segment are measured against that end point. A segment
        with coinciding end points is treated as the single point a.
        """
        ab = b - a
        ap = p - a
        # Guard the denominator: a zero-length segment would otherwise give
        # 0 / 0 and turn the result into NaN.
        t = ap.dot(ab) / ti.max(ab.norm_sqr(), 1e-12)
        t = ti.max(0.0, ti.min(1.0, t))
        projection = a + t * ab
        return (p - projection).norm()

    
    @ti.func
    def handle_line_collision(self, pos, vel, line_start, line_end, thickness):
        """Resolve a collision between a particle and a thick line segment.

        When the particle is closer than ``thickness`` to the segment, the
        velocity component pointing towards the segment is reflected with a
        restitution of 0.3 and the position is pushed outwards by 80 % of
        the penetration depth.

        Returns:
            The (possibly corrected) position and velocity.
        """
        ab = line_end - line_start
        ap = pos - line_start
        # Guard the denominator so a zero-length segment acts as a point
        # instead of producing NaN.
        t = ap.dot(ab) / ti.max(ab.norm_sqr(), 1e-12)
        t = ti.max(0.0, ti.min(1.0, t))
        closest = line_start + t * ab
        dist_vec = pos - closest
        distance = dist_vec.norm()

        if distance < thickness:
            # A particle lying exactly on the segment has no defined normal;
            # the guarded division yields a zero normal (no correction)
            # rather than NaN.
            normal = dist_vec / ti.max(distance, 1e-8)
            penetration = thickness - distance

            # Velocity correction (bounce)
            vel_normal = vel.dot(normal)
            if vel_normal < 0:  # Only handle velocity toward the line
                vel -= (1.0 + 0.3) * vel_normal * normal  # 0.3 is restitution

            # Position correction (prevent penetration)
            pos += normal * penetration * 0.8  # 0.8 is correction factor

        return pos, vel

    
    @ti.kernel
    def substep(self, apply_control: ti.i32, force_control: ti.i32):
        """
        Advance the simulation by one time step (handles both water and sand).

        The step resets the grid, transfers particle mass and momentum to the
        grid (P2G), updates grid velocities with gravity, optional forces and
        boundary friction, and transfers the result back to the particles
        (G2P), followed by obstacle collisions, position clamping and the
        material update.

        For water, the particle fields x, v, C and J and the grid fields
        grid_v and grid_m are used. For sand, F, grid_f and the plasticity
        fields are used as well.

        Args:
            apply_control: When non-zero, control_accel * dt is added to each
                particle velocity before P2G.
            force_control: When non-zero, force_x / force_y (and force_z in
                3D) times dt are added to the grid velocities.
        """
        # Reset grid
        for I in ti.grouped(self.grid_m):
            self.grid_v[I] = ti.Vector.zero(ti.f32, self.dim)
            self.grid_m[I] = 0.0
            if ti.static(self.material_type == 'sand'):
                self.grid_f[I] = ti.Vector.zero(ti.f32, self.dim)
        ti.loop_config(block_dim=self.n_grid)

        # Particle to grid (P2G)
        for p in range(self.n_particles):
            Xp = self.x[p] * self.inv_dx
            base = (Xp - 0.5).cast(ti.i32)
            fx = Xp - base.cast(ti.f32)
            
            # Quadratic interpolation weights for the 3 neighbouring nodes per axis
            w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
            affine = ti.Matrix.zero(ti.f32, self.dim, self.dim)
            stress = ti.Matrix.zero(ti.f32, self.dim, self.dim)
            if ti.static(self.material_type == 'water'):
                # Isotropic pressure from the volume ratio J, folded into the affine term
                stress_value = -self.dt * 4 * self.E * self.p_vol * (self.J[p] - 1) / self.dx**2
                affine = ti.Matrix.identity(ti.f32, self.dim) * stress_value + self.mass * self.C[p]
            else:  # sand
                # Stress from the logarithm of the singular values of F
                U, sig, V = ti.svd(self.F[p])
                inv_sig = sig.inverse()
                e = ti.Matrix.zero(ti.f32, self.dim, self.dim)
                if ti.static(self.dim == 2):
                    e[0, 0] = ti.log(sig[0, 0])
                    e[1, 1] = ti.log(sig[1, 1])
                else:  # 3D
                    e[0, 0] = ti.log(sig[0, 0])
                    e[1, 1] = ti.log(sig[1, 1])
                    e[2, 2] = ti.log(sig[2, 2])
                stress = U @ (2 * self.mu * inv_sig @ e + self.lambda_ * e.trace() * inv_sig) @ V.transpose()
                stress = (-self.p_vol * 4 * self.inv_dx * self.inv_dx) * stress @ self.F[p].transpose()
                affine = self.mass * self.C[p]

            if apply_control:
                self.v[p] += self.control_accel[p] * self.dt

            if ti.static(self.dim == 2):
                for i, j in ti.static(ti.ndrange(3, 3)):
                    offset = ti.Vector([i, j])
                    dpos = (offset.cast(ti.f32) - fx) * self.dx
                    weight = w[i][0] * w[j][1]
                    cell = base + offset
                    # Clamp so particles near the domain edge never index outside the grid
                    cell.x = ti.max(0, ti.min(self.n_grid - 1, cell.x))
                    cell.y = ti.max(0, ti.min(self.n_grid - 1, cell.y))
                    
                    self.grid_v[cell] += weight * (self.mass * self.v[p] + affine @ dpos)
                    self.grid_m[cell] += weight * self.mass
                    if ti.static(self.material_type == 'sand'):
                        self.grid_f[cell] += weight * stress @ dpos
            else:  # 3D
                for i, j, k in ti.static(ti.ndrange(3, 3, 3)):
                    offset = ti.Vector([i, j, k])
                    dpos = (offset.cast(ti.f32) - fx) * self.dx
                    weight = w[i][0] * w[j][1] * w[k][2]
                    cell = base + offset
                    cell.x = ti.max(0, ti.min(self.n_grid - 1, cell.x))
                    cell.y = ti.max(0, ti.min(self.n_grid - 1, cell.y))
                    cell.z = ti.max(0, ti.min(self.n_grid - 1, cell.z))
                    
                    self.grid_v[cell] += weight * (self.mass * self.v[p] + affine @ dpos)
                    self.grid_m[cell] += weight * self.mass
                    if ti.static(self.material_type == 'sand'):
                        self.grid_f[cell] += weight * stress @ dpos
        
        # Update grid momentum
        for I in ti.grouped(self.grid_m):
            if self.grid_m[I] > 0:
                # Momentum -> velocity
                self.grid_v[I] = (1 / self.grid_m[I]) * self.grid_v[I]
                
                if ti.static(self.material_type == 'water'):
                    self.grid_v[I].y += self.dt * self.gravity.y
                    if force_control:
                        self.grid_v[I].x += self.dt * self.force_x
                        self.grid_v[I].y += self.dt * self.force_y
                        if ti.static(self.dim == 3):
                            self.grid_v[I].z += self.dt * self.force_z
                else:  # sand
                    self.grid_v[I] += self.dt * (self.gravity + self.grid_f[I] / self.grid_m[I])
                    if force_control:
                        self.grid_v[I].x += self.dt * self.force_x
                        self.grid_v[I].y += self.dt * self.force_y
                        if ti.static(self.dim == 3):
                            self.grid_v[I].z += self.dt * self.force_z
            # Boundary conditions: pick the inward wall normal for nodes inside
            # the boundary layer whose velocity points out of the domain.
            # When several walls apply, the last matching test wins.
            normal = ti.Vector.zero(ti.f32, self.dim)
            if ti.static(self.dim == 2):
                if I.x < self.bound and self.grid_v[I][0] < 0: normal = [1.0, 0.0]
                if I.x > self.n_grid - self.bound and self.grid_v[I][0] > 0: normal = [-1.0, 0.0]
                if I.y < self.bound and self.grid_v[I][1] < 0: normal = [0.0, 1.0]
                if I.y > self.n_grid - self.bound and self.grid_v[I][1] > 0: normal = [0.0, -1.0]
            else:  # 3D
                if I.x < self.bound and self.grid_v[I][0] < 0: normal = [1.0, 0.0, 0.0]
                if I.x > self.n_grid - self.bound and self.grid_v[I][0] > 0: normal = [-1.0, 0.0, 0.0]
                if I.y < self.bound and self.grid_v[I][1] < 0: normal = [0.0, 1.0, 0.0]
                if I.y > self.n_grid - self.bound and self.grid_v[I][1] > 0: normal = [0.0, -1.0, 0.0]
                if I.z < self.bound and self.grid_v[I][2] < 0: normal = [0.0, 0.0, 1.0]
                if I.z > self.n_grid - self.bound and self.grid_v[I][2] > 0: normal = [0.0, 0.0, -1.0]

            # Remove the wall-normal velocity and reduce the tangential part
            # by friction (at most mu_b * |normal speed|, never reversing it).
            if ti.static(self.dim == 2):
                if not (normal[0] == 0 and normal[1] == 0):
                    s = self.grid_v[I].dot(normal)
                    if s <= 0:
                        v_normal = s * normal
                        v_tangent = self.grid_v[I] - v_normal
                        vt = v_tangent.norm()
                        if vt > 1e-12: 
                            self.grid_v[I] = v_tangent - (vt if vt < -self.mu_b * s else -self.mu_b * s) * (v_tangent / vt)
                        else:
                            # Purely normal motion: the wall-normal part must
                            # still be removed; only the (negligible)
                            # tangential part remains.
                            self.grid_v[I] = v_tangent
            else:
                if not (normal[0] == 0 and normal[1] == 0 and normal[2] == 0):
                    s = self.grid_v[I].dot(normal)
                    if s <= 0:
                        v_normal = s * normal
                        v_tangent = self.grid_v[I] - v_normal
                        vt = v_tangent.norm()
                        if vt > 1e-12: 
                            self.grid_v[I] = v_tangent - (vt if vt < -self.mu_b * s else -self.mu_b * s) * (v_tangent / vt)
                        else:
                            # Purely normal motion: see the 2D branch above.
                            self.grid_v[I] = v_tangent
        ti.loop_config(block_dim=self.n_grid)

        # Grid to particle (G2P)
        for p in range(self.n_particles):
            Xp = self.x[p] * self.inv_dx
            base = (Xp - 0.5).cast(ti.i32)
            fx = Xp - base.cast(ti.f32)
            w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1.0) ** 2, 0.5 * (fx - 0.5) ** 2]
            
            new_v = ti.Vector.zero(ti.f32, self.dim)
            new_C = ti.Matrix.zero(ti.f32, self.dim, self.dim)
            
            if ti.static(self.dim == 2):
                for i, j in ti.static(ti.ndrange(3, 3)):
                    dpos = ti.Vector([i, j]).cast(ti.f32) - fx
                    cell = base + ti.Vector([i, j])
                    cell.x = ti.max(0, ti.min(self.n_grid - 1, cell.x))
                    cell.y = ti.max(0, ti.min(self.n_grid - 1, cell.y))
                    g_v = self.grid_v[cell]
                    weight = w[i][0] * w[j][1]
                    new_v += weight * g_v
                    new_C += 4 * self.inv_dx * weight * g_v.outer_product(dpos)
            else:  # 3D
                for i, j, k in ti.static(ti.ndrange(3, 3, 3)):
                    dpos = ti.Vector([i, j, k]).cast(ti.f32) - fx
                    cell = base + ti.Vector([i, j, k])
                    cell.x = ti.max(0, ti.min(self.n_grid - 1, cell.x))
                    cell.y = ti.max(0, ti.min(self.n_grid - 1, cell.y))
                    cell.z = ti.max(0, ti.min(self.n_grid - 1, cell.z))
                    g_v = self.grid_v[cell]
                    weight = w[i][0] * w[j][1] * w[k][2]
                    new_v += weight * g_v
                    new_C += 4 * self.inv_dx * weight * g_v.outer_product(dpos)

            if ti.static(self.material_type == 'sand'):
                # Trial update of the deformation gradient (projected below)
                self.F[p] = (ti.Matrix.identity(ti.f32, self.dim) + self.dt * new_C) @ self.F[p]
            
            self.v[p], self.C[p] = new_v, new_C
            self.x[p] += self.dt * self.v[p]

            # Handle line collisions (obstacles are only processed in 2D)
            if ti.static(self.dim == 2):
                if self.num_lines[None] > 0:
                    num_lines = self.num_lines[None]
                    for n in range(num_lines):
                        line_start = self.lines[n][:2]
                        line_end = self.lines[n][2:4]
                        self.x[p], self.v[p] = self.handle_line_collision(self.x[p], self.v[p], line_start, line_end, self.line_thickness)
            # Boundary checks: clamp positions to [lower_bound, upper_bound]
            for d in ti.static(range(self.dim)):
                if self.x[p][d] < self.lower_bound or self.x[p][d] > self.upper_bound:
                    self.x[p][d] = ti.max(self.lower_bound, ti.min(self.upper_bound, self.x[p][d]))

            if ti.static(self.material_type == 'sand'):
                # Plastic projection for sand
                U, sig, V = ti.svd(self.F[p])
                e = ti.Matrix.zero(ti.f32, self.dim, self.dim)
                if ti.static(self.dim == 2):
                    e = ti.Matrix([[ti.log(sig[0, 0]), 0.0], [0.0, ti.log(sig[1, 1])]])
                else:
                    e = ti.Matrix([[ti.log(sig[0, 0]), 0, 0], [0, ti.log(sig[1, 1]), 0], [0, 0, ti.log(sig[2, 2])]])
                new_e, dq = self.project(e, p)
                self.hardening(dq, p)
                new_F = ti.Matrix.zero(ti.f32, self.dim, self.dim)
                if ti.static(self.dim == 2):
                    new_F = U @ ti.Matrix([[ti.exp(new_e[0, 0]), 0.0], [0.0, ti.exp(new_e[1, 1])]]) @ V.transpose()
                else:
                    new_F = U @ ti.Matrix([[ti.exp(new_e[0, 0]), 0.0, 0.0], 
                                        [0.0, ti.exp(new_e[1, 1]), 0.0], 
                                        [0.0, 0.0, ti.exp(new_e[2, 2])]]) @ V.transpose()   
                # Accumulate the change in log-volume caused by the projection
                self.vc[p] += -ti.log(new_F.determinant()) + ti.log(self.F[p].determinant())
                self.F[p] = new_F
            else:
                # Water-specific updates: evolve the volume ratio
                self.J[p] *= 1 + self.dt * new_C.trace()
    
    
    @ti.func
    def project(self, e0, p):
        """Plasticity projection for sand.

        Args:
            e0: Diagonal matrix holding the logarithms of the singular
                values of the trial deformation gradient.
            p: Particle index.

        Returns:
            The projected log-strain matrix and the hardening increment.
            ``self.state[p]`` is set to 0 (strain reset to zero), 1 (strain
            left unchanged) or 2 (strain projected).
        """
        identity = ti.Matrix.identity(ti.f32, self.dim)
        e = e0 + self.vc[p] / self.dim * identity
        e += (self.c_C0[p] * (1.0 - self.phi[p])) / (self.dim * self.alpha[p]) * identity
        ehat = e - e.trace() / self.dim * identity

        # Sum the squared diagonal entries over every axis. The loop is
        # unrolled at compile time, so 3D runs include the [2, 2] entries
        # and 2D runs never reference them.
        ehat_sq = 0.0
        e_sq = 0.0
        for d in ti.static(range(self.dim)):
            ehat_sq += ehat[d, d] ** 2
            e_sq += e[d, d] ** 2

        Fnorm = ti.sqrt(ehat_sq)
        yp = Fnorm + (self.dim * self.lambda_ + 2 * self.mu) / (2 * self.mu) * e.trace() * self.alpha[p]

        new_e = ti.Matrix.zero(ti.f32, self.dim, self.dim)
        delta_q = 0.0

        if Fnorm <= 0 or e.trace() > 0:
            # Strain is reset to zero (new_e keeps its zero initial value)
            delta_q = ti.sqrt(e_sq)
            self.state[p] = 0
        elif yp <= 0:
            # Strain is returned unchanged
            new_e = e0
            delta_q = 0.0
            self.state[p] = 1
        else:
            # Strain is moved back along ehat by yp
            new_e = e - yp / Fnorm * ehat
            delta_q = yp
            self.state[p] = 2

        return new_e, delta_q
    
    
    @ti.func
    def hardening(self, dq, p):
        """Accumulate dq into q[p] and recompute alpha[p] from the new q."""
        self.q[p] += dq
        q_total = self.q[p]
        # Angle in degrees from the h0..h3 parameters, converted to radians
        angle_deg = self.h0 + (self.h1 * q_total - self.h3) * ti.exp(-self.h2 * q_total)
        angle_rad = angle_deg / 180 * self.pi
        s = ti.sin(angle_rad)
        self.alpha[p] = ti.sqrt(2 / 3) * (2 * s) / (3 - s)
    
    
    @ti.kernel
    def update_color(self):
        """Write a packed 0xRRGGBB colour per particle, blended by phi.

        phi = 0 gives (0.521, 0.368, 0.259) and phi = 1 gives
        (0.318, 0.223, 0.157), with linear interpolation in between.
        """
        for idx in range(self.n_particles):
            blend = self.phi[idx]
            red = ti.cast((0.521 * (1 - blend) + 0.318 * blend) * 0xFF, ti.i32)
            green = ti.cast((0.368 * (1 - blend) + 0.223 * blend) * 0xFF, ti.i32)
            blue = ti.cast((0.259 * (1 - blend) + 0.157 * blend) * 0xFF, ti.i32)
            self.color[idx] = (red << 16) + (green << 8) + blue


    def T(self, a):
        """Map particle positions to 2D coordinates for drawing.

        In 2D the input is returned unchanged. In 3D the points are shifted
        by -0.5, rotated in the x-z plane by 28 degrees, and the vertical
        screen coordinate is formed from y and the rotated z using a 32
        degree angle; the result is shifted back by +0.5.

        Args:
            a: Array of positions with one row per particle.

        Returns:
            ``a`` itself for 2D, otherwise an array of shape ``(n, 2)``.
        """
        if self.dim == 2:
            return a

        phi = np.radians(28)
        theta = np.radians(32)
        cos_phi, sin_phi = np.cos(phi), np.sin(phi)
        cos_theta, sin_theta = np.cos(theta), np.sin(theta)

        centered = a - 0.5
        x0, y0, z0 = centered[:, 0], centered[:, 1], centered[:, 2]

        # Rotation in the x-z plane
        x_rot = x0 * cos_phi + z0 * sin_phi
        z_rot = z0 * cos_phi - x0 * sin_phi

        # Screen coordinates
        screen_u = x_rot
        screen_v = y0 * cos_theta + z_rot * sin_theta
        return np.stack((screen_u, screen_v), axis=1) + 0.5

    def save_video(self, positions, output_path, fps):
        """Render a sequence of particle positions to a video file.

        Args:
            positions: Iterable of per-frame position arrays accepted by
                ``T`` (one row per particle).
            output_path: Destination path handed to ``imageio.get_writer``.
            fps: Frame rate of the video; ``None`` selects 30.
        """
        # Honour the caller's frame rate (it used to be overwritten by 30).
        if fps is None:
            fps = 30

        gui = ti.GUI("Simulation", background_color=0xFFFFFF, show_gui=False)
        try:
            writer = imageio.get_writer(output_path, fps=fps)
            try:
                is_sand = self.material_type == 'sand'
                if is_sand:
                    self.update_color()

                # Obstacles and colours do not change between frames, so
                # fetch them from the fields only once.
                obstacle_points = None
                if self.num_lines[None] > 0:
                    obstacle_points = self.T(self.ob_x.to_numpy())
                sand_colors = self.color.to_numpy() if is_sand else None

                for frame_pos in positions:
                    gui.clear(0xFFFFFF)
                    if obstacle_points is not None:
                        gui.circles(obstacle_points, radius=3, color=0x000000)
                    if is_sand:
                        gui.circles(self.T(frame_pos), radius=1.5, color=sand_colors)
                    elif self.material_type == 'water':
                        gui.circles(self.T(frame_pos), radius=1.5, color=0x0066FF)
                    frame = gui.get_image()
                    # Rotate the GUI image by 90 degrees and convert it to
                    # 8-bit before handing it to the writer.
                    frame = np.rot90(frame, k=1)
                    frame = (frame * 255).astype(np.uint8)
                    writer.append_data(frame)
            finally:
                # Close the file even if rendering fails part-way through.
                writer.close()
        finally:
            gui.close()
        print(f"Video saved to {output_path}")


    @ti.kernel
    def get_control_force_magnitude(self) -> ti.f32:
        """Return the mean Euclidean norm of ``control_accel`` over all particles."""
        norm_sum = 0.0
        for idx in self.control_accel:
            norm_sum += self.control_accel[idx].norm()
        return norm_sum / self.n_particles