import taichi as ti
import numpy as np
import imageio
import json
import os
from math import sqrt, cos, sin, pi

# Initialise Taichi on the GPU backend with debug mode on and 32-bit default floats.
ti.init(arch=ti.gpu, debug=True, default_fp=ti.f32)
# Fixed NumPy seed: every np.random draw below (particle counts, shapes, lines)
# comes from this single seeded stream.
SEED = 42
np.random.seed(SEED)
# Integer codes stored per particle in the saved trajectories;
# run_simulation() uses 'Boundary' for obstacle particles and 'Sand' for the rest.
PARTICLE_TYPE={"Rigid":0, "Droplet": 1, "Boundary": 3, "Water": 5, "Sand": 6, "Goop": 7}


# run_simulation() draws the sand particle count uniformly from
# [min_particles, max_particles] (inclusive).
max_particles = 3300  
min_particles = 1000   
# Grid resolution per axis; the grid fields have shape (n_grid, n_grid).
n_grid = 128 
# Grid spacing and its inverse.
dx, inv_dx = 1 / n_grid, float(n_grid)
# Time step of one substep() call.
dt = 2.5e-4
# Acceleration added to every grid node with mass in substep().
gravity = ti.Vector([0, -9.8])
dim = 2  
# Particle positions are clamped to [lower_bound, upper_bound] on both axes.
lower_bound = 0.1
upper_bound = 0.9
# Number of grid cells from each edge that substep() treats as wall cells.
bound = int(0.115 * n_grid)

# Maps 'simulation_trajectory_<id>' -> object array [positions, particle_types];
# filled by run_simulation() and written to .npz files by the __main__ block.
save_data = {}
# NOTE(review): these two module-level lists are never appended to;
# run_simulation() uses local lists of the same names and the __main__ block
# rebinds the names to run_simulation()'s return values.
velocities = []
accelerations = []
# Each simulation records total_frames frames, with steps_per_frame substeps per frame.
total_frames = 400
steps_per_frame = 10

# Line obstacle parameters (the line geometry itself is randomized per simulation)
n_obstacle_particles = 300  # Total obstacle particles, shared across all lines
line_thickness = 0.008  # Collision distance around each segment used in substep()
# Most recently generated lines, kept for visualization and metadata
current_lines = []  # Will store tuples of (start, end, length)



# Per-particle volume, density and resulting mass.
p_vol = (dx * 0.5) ** 2
rho = 400
mass = p_vol * rho
# Young's modulus and Poisson's ratio, converted to the Lame parameters below.
E, nu = 4.0e5, 0.25  
mu, lambda_ = E / (2 * (1 + nu)), E * nu / ((1 + nu) * (1 - 2 * nu))  
# Friction coefficient applied to wall cells in substep().
mu_b = 1.0 


# Coefficients of the angle formula in hardening(); the resulting angle is in
# degrees (it is converted with / 180 * pi there).
h0, h1, h2, h3 = 35, 9, 0.2, 10
# NOTE: this rebinds the `pi` imported from math at the top of the file.
pi = 3.14159265358979

def initialize_fields(n_particles):
    """Allocate the Taichi fields for one simulation.

    Per-particle fields have shape ``n_particles``; grid fields have shape
    ``(n_grid, n_grid)``.

    Returns:
        Tuple ``(x, v, C, F, phi, c_C0, vc, alpha, q, state, color,
        grid_v, grid_m, grid_f)`` in exactly that order.
    """
    grid_shape = (n_grid, n_grid)

    def particle_vector():
        return ti.Vector.field(2, dtype=ti.f32, shape=n_particles)

    def particle_matrix():
        return ti.Matrix.field(2, 2, dtype=ti.f32, shape=n_particles)

    def particle_scalar(dtype=ti.f32):
        return ti.field(dtype=dtype, shape=n_particles)

    # Fields are created in the same order as they are returned.
    x = particle_vector()
    v = particle_vector()
    C = particle_matrix()
    F = particle_matrix()
    phi = particle_scalar()
    c_C0 = particle_scalar()
    vc = particle_scalar()
    alpha = particle_scalar()
    q = particle_scalar()
    state = particle_scalar(int)
    color = particle_scalar(int)

    grid_v = ti.Vector.field(2, dtype=ti.f32, shape=grid_shape)
    grid_m = ti.field(dtype=ti.f32, shape=grid_shape)
    grid_f = ti.Vector.field(2, dtype=ti.f32, shape=grid_shape)

    return (x, v, C, F, phi, c_C0, vc, alpha, q, state, color,
            grid_v, grid_m, grid_f)
    
@ti.func
def point_to_line_distance(p: ti.types.vector(2, ti.f32), 
                         a: ti.types.vector(2, ti.f32), 
                         b: ti.types.vector(2, ti.f32)) -> ti.f32:
    # Distance from point p to the segment a-b: project p onto the segment
    # direction, clamp the parameter to [0, 1] so the foot stays on the
    # segment, and measure the distance to that foot.
    # NOTE: a zero-length segment (a == b) divides by zero here.
    segment = b - a
    s = (p - a).dot(segment) / segment.norm_sqr()
    s = ti.max(0.0, ti.min(1.0, s))
    nearest = a + s * segment
    return (p - nearest).norm()

@ti.func
def handle_line_collision(pos: ti.types.vector(2, ti.f32),
                        vel: ti.types.vector(2, ti.f32),
                        line_start: ti.types.vector(2, ti.f32),
                        line_end: ti.types.vector(2, ti.f32),
                        thickness: ti.f32) -> ti.types.vector(2, ti.f32):
    # Resolve a collision between one particle and one thick line segment.
    #
    # pos, vel             : particle position and velocity
    # line_start, line_end : segment end points
    # thickness            : particles closer than this to the segment collide
    #
    # Returns the (possibly corrected) pair (pos, vel).
    # NOTE(review): the return annotation names a single vector although a
    # (pos, vel) pair is returned; it is left unchanged to keep the interface.
    # NOTE: a zero-length segment still divides by zero when computing t.
    ab = line_end - line_start
    ap = pos - line_start
    # Parameter of the closest point on the segment, clamped to [0, 1].
    t = ap.dot(ab) / ab.norm_sqr()
    t = ti.max(0.0, ti.min(1.0, t))
    closest = line_start + t * ab
    dist_vec = pos - closest
    distance = dist_vec.norm()

    if distance < thickness:
        # A particle lying exactly on the segment has no direction away from
        # it; normalising the zero vector would inject NaN into the position.
        # Fall back to the segment's perpendicular in that case.
        normal = ti.Vector([-ab[1], ab[0]]).normalized()
        if distance > 1e-8:
            normal = dist_vec / distance
        penetration = thickness - distance

        vel_normal = vel.dot(normal)
        if vel_normal < 0:
            # Moving into the line: cancel the inward normal velocity and
            # send 0.3 of it back outwards.
            vel -= (1.0 + 0.3) * vel_normal * normal

        # Push the particle out by 80% of the penetration depth.
        pos += normal * penetration * 0.8

    return pos, vel

@ti.func
def h_s(z):
    # Smooth step from 1 to 0: returns 1 for z < 0, 0 for z > 1 and the
    # quintic 1 - 10 z^3 + 15 z^4 - 6 z^5 in between (which equals 1 at z = 0
    # and 0 at z = 1, so the pieces join continuously).
    #
    # The branches are mutually exclusive; previously the polynomial was
    # evaluated unconditionally and overwrote the two clamped results.
    ret = 0.0
    if z < 0:
        ret = 1.0
    elif z > 1:
        ret = 0.0
    else:
        ret = 1 - 10 * (z ** 3) + 15 * (z ** 4) - 6 * (z ** 5)
    return ret

@ti.func
def project(e0: ti.template(), p: int, vc: ti.template(), phi: ti.template(), c_C0: ti.template(), alpha: ti.template(), 
            dim: int, lambda_: float, mu: float, state: ti.template()):
    # Project the diagonal strain matrix e0 of particle p and report how much
    # it was changed.
    #
    # Returns (projected strain, delta_q) and writes the selected case into
    # state[p]:
    #   0 -> strain reset to zero
    #   1 -> e0 returned unchanged
    #   2 -> strain moved back by yp along the deviatoric direction
    eye = ti.Matrix.identity(float, 2)

    # Shift e0 by the per-particle vc term and by the c_C0 / phi / alpha term.
    e = e0 + vc[p] / dim * eye
    e += (c_C0[p] * (1.0 - phi[p])) / (dim * alpha[p]) * eye

    tr = e.trace()
    ehat = e - tr / dim * eye  # trace-free part of e
    dev_norm = ti.sqrt(ehat[0, 0] ** 2 + ehat[1, 1] ** 2)
    yp = dev_norm + (dim * lambda_ + 2 * mu) / (2 * mu) * tr * alpha[p]

    projected = ti.Matrix.zero(float, 2, 2)
    delta_q = 0.0

    if dev_norm <= 0 or tr > 0:
        projected = ti.Matrix.zero(float, 2, 2)
        delta_q = ti.sqrt(e[0, 0] ** 2 + e[1, 1] ** 2)
        state[p] = 0
    elif yp <= 0:
        # Note: the un-shifted input e0 is returned here, not e.
        projected = e0
        delta_q = 0
        state[p] = 1
    else:
        projected = e - yp / dev_norm * ehat
        delta_q = yp
        state[p] = 2

    return projected, delta_q

@ti.func
def hardening(dq: float, p: int, q: ti.template(), alpha: ti.template(), h0: float, h1: float, h2: float, h3: float, pi: float):
    # Accumulate dq into q[p] and recompute alpha[p] from the updated q[p].
    q[p] += dq
    accumulated = q[p]
    # Angle in degrees as a function of the accumulated q, then in radians.
    angle_deg = h0 + (h1 * accumulated - h3) * ti.exp(-h2 * accumulated)
    angle_rad = angle_deg / 180 * pi
    s = ti.sin(angle_rad)
    alpha[p] = ti.sqrt(2 / 3) * (2 * s) / (3 - s)

# Advance the simulation by one time step of size dt.
#
# Stages, in order:
#   1. clear the grid fields,
#   2. scatter particle momentum, mass and stress to the grid,
#   3. update grid velocities (gravity, internal force, wall handling),
#   4. gather velocities back to the particles, move them, resolve collisions
#      with the obstacle lines, clamp positions to the domain, then project
#      the deformation gradient (project) and update alpha (hardening).
#
# x, v, C, F, phi, c_C0, vc, alpha, q, state : per-particle fields
# grid_v, grid_m, grid_f                      : (n_grid, n_grid) grid fields
# lines       : flat 1-D array, four consecutive values per segment
#               (start_x, start_y, end_x, end_y)
# n_particles : number of particles to process
@ti.kernel
def substep(x: ti.template(), v: ti.template(), C: ti.template(), F: ti.template(),
           phi: ti.template(), c_C0: ti.template(), vc: ti.template(), 
           alpha: ti.template(), q: ti.template(), state: ti.template(),
           grid_v: ti.template(), grid_m: ti.template(), grid_f: ti.template(),
           lines: ti.types.ndarray(), n_particles: ti.i32):

    # Stage 1: reset every grid node.
    for i, j in grid_m:
        grid_v[i, j] = [0.0, 0.0]
        grid_m[i, j] = 0.0
        grid_f[i, j] = [0.0, 0.0]


    # Stage 2: particle -> grid.
    for p in range(n_particles):
        # Lower-left cell of the particle's 3x3 stencil and the particle's
        # offset from it in grid units.
        base = (x[p] * inv_dx - 0.5).cast(ti.i32)
        fx = x[p] * inv_dx - base.cast(ti.f32)
        
        # Three per-axis interpolation weights, quadratic in fx.
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        
        # Stress built from the SVD of F using the logarithms of its
        # singular values.
        U, sig, V = ti.svd(F[p])
        inv_sig = sig.inverse()
        e = ti.Matrix([[ti.log(sig[0, 0]), 0.0], [0.0, ti.log(sig[1, 1])]])
        stress = U @ (2 * mu * inv_sig @ e + lambda_ * e.trace() * inv_sig) @ V.transpose()
        stress = (-p_vol * 4 * inv_dx * inv_dx) * stress @ F[p].transpose()
        
        # Local recomputation of the module-level `mass` (same expression).
        mass = p_vol * rho
        affine = mass * C[p]
        
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset.cast(ti.f32) - fx) * dx
            weight = w[i][0] * w[j][1]
            cell = base + offset
            # Keep the target cell inside the grid.
            cell.x = ti.max(0, ti.min(n_grid - 1, cell.x))
            cell.y = ti.max(0, ti.min(n_grid - 1, cell.y))
            
            # grid_v temporarily holds momentum; it is divided by mass in stage 3.
            grid_v[cell] += weight * (mass * v[p] + affine @ dpos)
            grid_m[cell] += weight * mass
            grid_f[cell] += weight * stress @ dpos

    # Stage 3: grid update.
    for i, j in grid_m:
        if grid_m[i, j] > 0:
            # Momentum -> velocity, then add gravity and the internal force.
            grid_v[i, j] = (1 / grid_m[i, j]) * grid_v[i, j]
            grid_v[i, j] += dt * (gravity + grid_f[i, j] / grid_m[i, j])

            # Pick the inward wall normal for cells within `bound` cells of an
            # edge whose velocity points towards that edge. A later test
            # overwrites an earlier one, so the y-walls win in corners.
            normal = ti.Vector.zero(ti.f32, 2)
            if i < bound and grid_v[i, j][0] < 0: normal = [1.0, 0.0]
            if i > n_grid - bound and grid_v[i, j][0] > 0: normal = [-1.0, 0.0]
            if j < bound and grid_v[i, j][1] < 0: normal = [0.0, 1.0]
            if j > n_grid - bound and grid_v[i, j][1] > 0: normal = [0.0, -1.0]
            
            if not (normal[0] == 0 and normal[1] == 0):
                s = grid_v[i, j].dot(normal)
                if s <= 0:
                    v_normal = s * normal
                    v_tangent = grid_v[i, j] - v_normal
                    vt = v_tangent.norm()
                    # Drop the normal component and shorten the tangential one
                    # by min(vt, -mu_b * s). When vt <= 1e-12 the velocity is
                    # left untouched, including its normal component.
                    if vt > 1e-12: 
                        grid_v[i, j] = v_tangent - (vt if vt < -mu_b * s else -mu_b * s) * (v_tangent / vt)

    # Stage 4: grid -> particle, advection, collisions and projection.
    for p in range(n_particles):
        base = (x[p] * inv_dx - 0.5).cast(ti.i32)
        fx = x[p] * inv_dx - base.cast(ti.f32)
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1.0) ** 2, 0.5 * (fx - 0.5) ** 2]
        
        new_v = ti.Vector.zero(ti.f32, 2)
        new_C = ti.Matrix.zero(ti.f32, 2, 2)
        
        for i, j in ti.static(ti.ndrange(3, 3)):
            # dpos is in grid units here (not multiplied by dx as in stage 2).
            dpos = ti.Vector([i, j]).cast(ti.f32) - fx
            cell = base + ti.Vector([i, j])
            cell.x = ti.max(0, ti.min(n_grid - 1, cell.x))
            cell.y = ti.max(0, ti.min(n_grid - 1, cell.y))
            g_v = grid_v[cell]
            weight = w[i][0] * w[j][1]
            new_v += weight * g_v
            new_C += 4 * inv_dx * weight * g_v.outer_product(dpos)
        
        # Update the deformation gradient, then velocity, C and position.
        F[p] = (ti.Matrix.identity(ti.f32, 2) + dt * new_C) @ F[p]
        v[p], C[p] = new_v, new_C
        x[p] += dt * v[p]

        # Resolve collisions against every obstacle segment in turn.
        num_lines = lines.shape[0] // 4
        for n in range(num_lines):
            idx = n * 4
            line_start = ti.Vector([lines[idx], lines[idx+1]])
            line_end = ti.Vector([lines[idx+2], lines[idx+3]])
            x[p], v[p] = handle_line_collision(x[p], v[p], line_start, line_end, line_thickness)

        # Clamp the position to the domain; the velocity is not modified here.
        if x[p].x < lower_bound or x[p].x > upper_bound:
            x[p].x = ti.max(lower_bound, ti.min(upper_bound, x[p].x))
        if x[p].y < lower_bound or x[p].y > upper_bound:
            x[p].y = ti.max(lower_bound, ti.min(upper_bound, x[p].y)) 

        # Project the log-singular-values of F, update alpha via hardening(),
        # rebuild F from the projected values and accumulate the change of
        # log(det F) into vc.
        U, sig, V = ti.svd(F[p])
        e = ti.Matrix([[ti.log(sig[0, 0]), 0.0], [0.0, ti.log(sig[1, 1])]])
        new_e, dq = project(e, p, vc, phi, c_C0, alpha, dim, lambda_, mu, state)
        hardening(dq, p, q, alpha, h0, h1, h2, h3, pi)
        new_F = U @ ti.Matrix([[ti.exp(new_e[0, 0]), 0.0], [0.0, ti.exp(new_e[1, 1])]]) @ V.transpose()
        vc[p] += -ti.log(new_F.determinant()) + ti.log(F[p].determinant())
        F[p] = new_F

# Reset the first n_particles entries of the given fields to their starting
# values: zero velocity, identity F, c_C0 = 0.1 and alpha = 0.2.
@ti.kernel
def initialize_fields_values( v: ti.template(), F: ti.template(),
                           c_C0: ti.template(), alpha: ti.template(), n_particles: ti.i32):
    for p in range(n_particles):
        v[p] = ti.Vector([0.0, 0.0])
        F[p] = ti.Matrix.identity(ti.f32, 2)
        alpha[p] = 0.2
        c_C0[p] = 0.1

def update_color(color: ti.template(), phi: ti.template(), n_particles: int):
    """Fill ``color`` with a packed 0xRRGGBB value per particle.

    Each channel is a linear blend between two fixed RGB triples, weighted
    by ``phi[i]`` (``phi == 0`` gives the first triple, ``phi == 1`` the
    second). Only the first ``n_particles`` entries are written.
    """
    @ti.kernel
    def _update_color(color: ti.template(), phi: ti.template(), n_particles: ti.i32):
        for idx in range(n_particles):
            blend = phi[idx]
            keep = 1 - blend
            red = ti.cast((0.521 * keep + 0.318 * blend) * 0xFF, ti.i32)
            green = ti.cast((0.368 * keep + 0.223 * blend) * 0xFF, ti.i32)
            blue = ti.cast((0.259 * keep + 0.157 * blend) * 0xFF, ti.i32)
            color[idx] = (red << 16) + (green << 8) + blue

    _update_color(color, phi, n_particles)
def generate_random_lines():
    """Draw one to three random obstacle segments.

    Start and end points have x in ``[lower_bound, upper_bound]`` and y in
    ``[lower_bound, lower_bound + 0.4]``. The end point is redrawn until the
    squared segment length is at least 0.04.

    Side effect: rebinds the module-level ``current_lines`` to a new list of
    ``(start, end, length)`` tuples describing the same segments.

    Returns:
        ``(lines, line_lengths)`` where ``lines`` is a flat float32 array
        ``[start_x, start_y, end_x, end_y, ...]`` and ``line_lengths`` is the
        list of segment lengths.
    """
    global current_lines
    current_lines = []

    y_top = lower_bound + 0.4
    flat_coords = []
    lengths = []

    # The segment count is drawn first, then the points of each segment.
    for _ in range(np.random.randint(1, 4)):
        sx = np.random.uniform(lower_bound, upper_bound)
        sy = np.random.uniform(lower_bound, y_top)
        ex = np.random.uniform(lower_bound, upper_bound)
        ey = np.random.uniform(lower_bound, y_top)

        # Reject end points that make the segment too short.
        while (ex - sx)**2 + (ey - sy)**2 < 0.04:
            ex = np.random.uniform(lower_bound, upper_bound)
            ey = np.random.uniform(lower_bound, y_top)

        seg_length = np.sqrt((ex - sx)**2 + (ey - sy)**2)
        flat_coords += [sx, sy, ex, ey]
        lengths.append(seg_length)
        current_lines.append(([sx, sy], [ex, ey], seg_length))

    return np.array(flat_coords, dtype=np.float32), lengths
def initialize_particles(shape_type, n_particles: int):
    """Sample initial particle positions inside a random shape.

    A fresh set of obstacle lines is generated first (via
    ``generate_random_lines()``, which also updates the module-level
    ``current_lines``), and the shape is placed so that its centre lies
    above the highest line end point where the domain allows it.

    Args:
        shape_type: 0 for an axis-aligned rectangle, 1 for a disc, 2 for a
            randomly rotated triangle.
        n_particles: number of positions to sample.

    Returns:
        Array of shape ``(n_particles, 2)`` with positions clipped to
        ``[lower_bound, upper_bound]``.

    Raises:
        ValueError: if ``shape_type`` is not 0, 1 or 2.
    """
    # Validate before consuming random numbers or touching current_lines;
    # an unknown value used to surface as an UnboundLocalError on x_np.
    if shape_type not in (0, 1, 2):
        raise ValueError(f"shape_type must be 0, 1 or 2, got {shape_type!r}")

    available_size = upper_bound - lower_bound

    lines, _ = generate_random_lines()

    # y-coordinates sit at the odd indices of the flat
    # [start_x, start_y, end_x, end_y, ...] array.
    max_line_y = max([0, *lines[1::2]])

    # Lowest allowed shape bottom: 0.05 above the highest line, kept at
    # least 0.2 below the top of the domain.
    min_particle_y = min(max(max_line_y + 0.05, lower_bound), upper_bound - 0.2)

    if shape_type == 0:
        # Rectangle with random width and height.
        max_width = available_size * 0.5
        max_height = available_size * 0.5
        width = np.random.uniform(0.1, max_width)
        height = np.random.uniform(0.1, max_height)

        min_x = lower_bound + width/2
        max_x = upper_bound - width/2
        min_y = min_particle_y + height/2
        max_y = min(upper_bound - height/2, min_y + 0.3)
        center_x = np.random.uniform(min_x, max_x)
        center_y = np.random.uniform(min_y, max_y)

        x_np = np.random.rand(n_particles, 2).astype(np.float32) * [width, height] + [center_x - width/2, center_y - height/2]

    elif shape_type == 1:
        # Disc with random radius.
        max_radius = available_size * 0.25
        radius = np.random.uniform(0.05, max_radius)

        min_x = lower_bound + radius
        max_x = upper_bound - radius
        min_y = min_particle_y + radius
        max_y = min(upper_bound - radius, min_y + 0.3)
        center = np.array([
            np.random.uniform(min_x, max_x),
            np.random.uniform(min_y, max_y)
        ])

        # sqrt of a uniform variate gives a uniform density over the disc.
        angles = np.random.rand(n_particles) * 2 * np.pi
        rads = np.sqrt(np.random.rand(n_particles)) * radius
        x_np = np.column_stack([
            center[0] + rads * np.cos(angles),
            center[1] + rads * np.sin(angles)
        ])

    else:
        # Triangle with random size and rotation.
        max_base_size = available_size * 0.35
        base_size = np.random.uniform(0.1, max_base_size)

        safe_margin = base_size * 0.6
        min_x = lower_bound + safe_margin
        max_x = upper_bound - safe_margin
        min_y = min_particle_y + safe_margin
        max_y = min(upper_bound - safe_margin, min_y + 0.3)
        center = np.array([
            np.random.uniform(min_x, max_x),
            np.random.uniform(min_y, max_y)
        ])

        angle = np.random.uniform(0, 2*np.pi)
        rot_matrix = np.array([
            [np.cos(angle), -np.sin(angle)],
            [np.sin(angle), np.cos(angle)]
        ])

        vertices = np.array([
            [-base_size/2, -base_size/2],
            [base_size/2, -base_size/2],
            [0, base_size/2]
        ])
        vertices = vertices @ rot_matrix.T + center

        # Sample barycentric weights: points drawn in the unit square are
        # reflected into the lower triangle so that all weights are >= 0.
        u = np.random.rand(n_particles, 2)
        mask = u.sum(1) > 1
        u[mask] = 1 - u[mask]
        x_np = (u[:, 0:1] * vertices[0] +
                u[:, 1:2] * vertices[1] +
                (1 - u.sum(1, keepdims=True)) * vertices[2])

    # NOTE(review): when the lines sit high and the shape is large, min_y can
    # exceed max_y and part of the shape is flattened onto the boundary by
    # this clip.
    x_np = np.clip(x_np, lower_bound, upper_bound)

    return x_np

def run_simulation(sim_id, output_dir="Water_drop_simulations"):
    """Run one randomized simulation and record its trajectory.

    Draws a particle count and an initial shape, simulates ``total_frames``
    frames of ``steps_per_frame`` substeps each, and stores the positions and
    particle types in ``save_data['simulation_trajectory_<sim_id>']``.
    A video is rendered for ``sim_id < 10``.

    Args:
        sim_id: index used for the ``save_data`` key and the video file name.
        output_dir: directory the optional video is written to.

    Returns:
        ``(velocities, accelerations, lines)``: two lists with one
        ``(n_obstacle_particles + n_particles, 2)`` array per frame (obstacle
        rows are zero) and the list of ``(start, end, length)`` obstacle
        lines. Returns ``(None, None, None)`` if a NaN appears; nothing is
        stored in ``save_data`` in that case.
    """
    n_particles = np.random.randint(min_particles, max_particles + 1)

    x, v, C, F, phi, c_C0, vc, alpha, q, state, color, grid_v, grid_m, grid_f = initialize_fields(n_particles)

    shape_type = np.random.randint(0, 3)
    x_np = initialize_particles(shape_type, n_particles)
    x.from_numpy(x_np)
    initialize_fields_values( v, F, c_C0, alpha, n_particles)

    # Use the obstacle lines that initialize_particles() generated (it leaves
    # them in the module-level current_lines). The particles were placed
    # relative to exactly these lines; drawing a second, unrelated set here
    # would make that placement meaningless.
    obstacle_lines = current_lines
    lines = np.array(
        [coord for start, end, _ in obstacle_lines for coord in (*start, *end)],
        dtype=np.float32,
    )
    line_lengths = [length for _, _, length in obstacle_lines]

    # Distribute the obstacle particles over the lines proportionally to
    # their length; any rounding remainder goes to the first line.
    total_length = sum(line_lengths)
    particles_per_line = [max(1, int(n_obstacle_particles * length / total_length)) for length in line_lengths]
    diff = n_obstacle_particles - sum(particles_per_line)
    if diff != 0:
        particles_per_line[0] += diff

    obstacle_positions = []
    for count, (start, end, _) in zip(particles_per_line, obstacle_lines):
        t = np.linspace(0, 1, count)
        obstacle_positions.append(np.outer(1 - t, start) + np.outer(t, end))
    obstacle_positions = np.concatenate(obstacle_positions)

    # Data storage
    positions = []
    velocities = []
    accelerations = []
    particle_types = np.concatenate([
        PARTICLE_TYPE['Boundary'] * np.ones(n_obstacle_particles),
        PARTICLE_TYPE['Sand'] * np.ones(n_particles)
    ])
    prev_vel = v.to_numpy()  # all zeros right after initialisation

    obstacle_zeros = np.zeros((n_obstacle_particles, 2))
    for frame in range(total_frames):
        for s in range(steps_per_frame):
            substep(x, v, C, F, phi, c_C0, vc, alpha, q, state,
                   grid_v, grid_m, grid_f, lines, n_particles)
        pos = x.to_numpy()
        vel = v.to_numpy()

        # Abort as soon as the state is corrupted so that no NaN reaches the
        # saved dataset.
        if np.isnan(vel).any() or np.isnan(pos).any():
            return None, None, None

        # Stored "velocity" is the displacement per recorded frame and
        # "acceleration" its frame-to-frame difference.
        vel = vel * dt* steps_per_frame
        acc = (vel - prev_vel)

        # Combine obstacle and sand particles (obstacles first, at rest)
        positions.append(np.concatenate([obstacle_positions, pos]))
        velocities.append(np.concatenate([obstacle_zeros, vel]))
        accelerations.append(np.concatenate([obstacle_zeros, acc]))
        prev_vel = vel.copy()

    # Save data as a 2-element object array: [positions, particle_types]
    trajectory = np.empty(2, dtype=object)
    trajectory[0] = np.array(positions)
    trajectory[1] = particle_types
    save_data[f'simulation_trajectory_{sim_id}'] = trajectory

    # Optional: Save video
    if sim_id < 10:  # Save videos for first 10 sims only to save space
        save_video(sim_id, positions, output_dir, color, phi, n_particles)

    return velocities, accelerations, obstacle_lines

def save_video(sim_id: int, positions: list, output_dir: str, 
              color: ti.template(), phi: ti.template(), n_particles: int):
    """Render a recorded trajectory to ``<output_dir>/sim_<sim_id>.mp4``.

    Args:
        sim_id: used in the window title and the output file name.
        positions: one array per frame; the first ``n_obstacle_particles``
            rows are obstacle particles, the remaining rows sand particles.
        output_dir: target directory, created if missing.
        color: per-particle integer field that receives the sand colours.
        phi: per-particle field the colours are derived from.
        n_particles: number of sand particles.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"sim_{sim_id}.mp4")

    gui = ti.GUI(f"Simulation {sim_id}", background_color=0xFFFFFF, show_gui=False)
    fps = 96

    update_color(color, phi, n_particles)
    color_np = color.to_numpy()

    # The context manager closes (and finalises) the video file even if
    # rendering a frame raises.
    with imageio.get_writer(output_path, fps=fps) as writer:
        for frame_pos in positions:
            gui.clear(0xFFFFFF)
            obstacle_pos = frame_pos[:n_obstacle_particles]
            sand_pos = frame_pos[n_obstacle_particles:]

            gui.circles(sand_pos, radius=1.5, color=color_np)
            gui.circles(obstacle_pos, radius=3, color=0x000000)

            # Rotate the GUI image by 90 degrees and convert it to 8-bit
            # before handing it to the writer.
            frame = gui.get_image()
            frame = np.rot90(frame, k=1)
            frame = (frame * 255).astype(np.uint8)
            writer.append_data(frame)

if __name__ == "__main__":
    # Dataset driver: run num_simulations successful simulations, then write
    # train/valid/test .npz files and a metadata.json into output_dir.

    num_simulations = 1200
    output_dir = 'SandRamps_simulations--'
    os.makedirs(output_dir, exist_ok=True)
    all_obstacles = []  # To store obstacle configurations

    # Per-simulation statistics of the sand particles (obstacle rows
    # excluded). Each simulation is reduced to four small vectors right away
    # instead of keeping every per-frame array of every simulation in memory.
    vel_means, vel_stds = [], []
    acc_means, acc_stds = [], []

    idx = 0
    while idx < num_simulations:
        print(f"Running simulation {idx}/{num_simulations}")
        sim_velocities, sim_accelerations, obstacles = run_simulation(idx, output_dir)
        if sim_velocities is None:
            # idx is not advanced: the same slot is retried with new random
            # draws.
            print(f"Skipping simulation {idx} due to NaN values in velocities.")
            continue

        sand_vel = np.asarray(sim_velocities)[:, n_obstacle_particles:, :]
        sand_acc = np.asarray(sim_accelerations)[:, n_obstacle_particles:, :]
        vel_means.append(sand_vel.mean(axis=(0, 1)))
        vel_stds.append(sand_vel.std(axis=(0, 1)))
        acc_means.append(sand_acc.mean(axis=(0, 1)))
        acc_stds.append(sand_acc.std(axis=(0, 1)))

        all_obstacles.append(obstacles)
        idx += 1

    # Both statistics are averages of the per-simulation values. The velocity
    # std previously took the std *of* the per-simulation stds, unlike the
    # acceleration std; both now use the mean.
    vel_stats = {
        "mean": np.mean(vel_means, axis=0),
        "std": np.mean(vel_stds, axis=0),
    }
    acc_stats = {
        "mean": np.mean(acc_means, axis=0),
        "std": np.mean(acc_stds, axis=0),
    }

    # Prepare obstacles metadata: per simulation a list of [start, end] pairs
    obstacles_metadata = [
        [[line[0], line[1]] for line in obs]
        for obs in all_obstacles
    ]

    metadata = {
        "bounds": [[lower_bound, upper_bound], [lower_bound, upper_bound]],
        "sequence_length": total_frames,
        "default_connectivity_radius": 0.015,
        "dim": dim,
        "dt": dt*steps_per_frame,
        "vel_mean": vel_stats['mean'].tolist(),
        "vel_std": vel_stats['std'].tolist(),
        "acc_mean": acc_stats['mean'].tolist(),
        "acc_std": acc_stats['std'].tolist(),
        "obstacles_configs": obstacles_metadata # [[[,],[,]],], 
    }

    # Split by insertion order: first 1000 train, next 100 test, rest valid.
    # NOTE: these boundaries assume num_simulations == 1200.
    all_sim_ids = list(save_data.keys())
    train_ids = all_sim_ids[:1000]
    valid_ids = all_sim_ids[1100:]
    test_ids = all_sim_ids[1000:1100]

    train_data = {k: save_data[k] for k in train_ids}
    valid_data = {k: save_data[k] for k in valid_ids}
    test_data = {k: save_data[k] for k in test_ids}

    np.savez(os.path.join(output_dir, 'train.npz'), **train_data)
    np.savez(os.path.join(output_dir, 'valid.npz'), **valid_data)
    np.savez(os.path.join(output_dir, 'test.npz'), **test_data)
    with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)
    print("All simulations completed!")

    