import taichi as ti
import imageio  
import numpy as np
import json
import os
from math import sqrt, cos, sin, pi


# Particle-type ids stored alongside each trajectory; this script only ever
# writes PARTICLE_TYPE["Water"] (see run_simulation).
PARTICLE_TYPE={"Rigid":0, "Droplet": 1, "Boundary": 3, "Water": 5, "Sand": 6, "Goop": 7}

# Initialize Taichi (GPU backend requested, 32-bit floats as the default float
# type). This runs at import time, before any field is allocated.
ti.init(arch=ti.gpu, default_fp=ti.f32)

# Seed NumPy's global RNG. All random initial conditions in this script
# (particle counts, part shapes/sizes/positions, velocity directions) are drawn
# from np.random, so they are reproducible for a given SEED.
SEED = 42
np.random.seed(SEED)

# Simulation parameters
max_particles = 1900  # inclusive upper limit for the per-simulation particle count
min_particles = 500  # lower limit for the per-simulation particle count
n_grid = 128  # background grid resolution per axis
dx = 1 / n_grid  # grid cell size; the grid covers [0, 1] on each axis
dt = 2.5e-4  # time step of a single substep
dim = 2  # spatial dimension (written to metadata.json)
p_rho = 1  # particle density
p_vol = (dx * 0.5) ** 2  # per-particle volume: a quarter of one cell's area
p_mass = p_vol * p_rho
gravity = 9.8  # subtracted (times dt) from every grid node's y-velocity each substep
lower_bound = 0.1  # particles are clamped to [lower_bound, upper_bound] on both axes
upper_bound = 0.9
bound = int(0.115 * n_grid)  # wall thickness in grid cells (= 14) where inward grid velocity is zeroed
E = 400  # stiffness of the pressure-like (J - 1) stress term in substep

# Finished trajectories keyed 'simulation_trajectory_{sim_id}'; filled by
# run_simulation and split into train/valid/test .npz files in __main__.
save_data = {}
# NOTE(review): nothing appends to these module-level lists; run_simulation
# keeps its own local lists with the same names.
velocities = []
accelerations = []
total_frames = 1000  # recorded frames per simulation
steps_per_frame = 10  # substeps between recorded frames (frame interval = dt * steps_per_frame)

def initialize_fields(n_particles):
    """Allocate the Taichi fields used by a single simulation run.

    Per-particle fields (``n_particles`` entries each):
      * position ``x`` and velocity ``v`` -- float32 2-vectors,
      * ``C`` -- float32 2x2 matrix,
      * ``J`` -- float32 scalar.
    Grid fields (``n_grid`` x ``n_grid`` nodes):
      * ``grid_v`` -- float32 2-vector per node,
      * ``grid_m`` -- float32 scalar per node.

    NOTE(review): this is called once per simulation, so every run allocates a
    fresh set of fields. Because ``substep`` takes them as ``ti.template()``
    arguments, Taichi may compile a new kernel instance per run. Consider
    allocating once at ``max_particles`` -- confirm before changing.

    Returns:
        Tuple ``(x, v, C, J, grid_v, grid_m)``.
    """
    per_particle = n_particles
    per_node = (n_grid, n_grid)

    pos = ti.Vector.field(2, dtype=ti.f32, shape=per_particle)
    vel = ti.Vector.field(2, dtype=ti.f32, shape=per_particle)
    affine_c = ti.Matrix.field(2, 2, dtype=ti.f32, shape=per_particle)
    jac = ti.field(dtype=ti.f32, shape=per_particle)

    node_v = ti.Vector.field(2, dtype=ti.f32, shape=per_node)
    node_m = ti.field(dtype=ti.f32, shape=per_node)

    return pos, vel, affine_c, jac, node_v, node_m



# One explicit MPM substep of length `dt`, with the same structure as Taichi's
# mpm88 example: particle-to-grid (P2G), grid update, grid-to-particle (G2P).
#
#   x, v    : per-particle position / velocity (2-vectors)
#   C       : per-particle 2x2 affine velocity matrix (APIC-style transfer)
#   J       : per-particle scalar; multiplied by (1 + dt * trace(C)) every
#             substep and drives a pressure-like stress proportional to (J - 1)
#   grid_v  : per-node momentum during P2G, per-node velocity after the update
#   grid_m  : per-node mass
#   n_particles is not read; the particle loops iterate over every entry of x.
#
# Each top-level `for` below is a separate parallel Taichi loop; the stages
# execute one after another.
@ti.kernel
def substep(x: ti.template(), v: ti.template(), C: ti.template(), J: ti.template(),
        grid_v: ti.template(), grid_m: ti.template(), n_particles: int):
    # Stage 0: clear the grid.
    for i, j in grid_m:
        grid_v[i, j] = [0, 0]
        grid_m[i, j] = 0
    # Stage 1 (P2G): scatter mass and momentum to the 3x3 nodes around each particle.
    for p in x:
        Xp = x[p] / dx
        # Lower-left node of the 3x3 stencil; fx is the particle's offset from
        # it in cell units (each component lies in [0.5, 1.5)).
        base = int(Xp - 0.5)
        fx = Xp - base
        # Quadratic B-spline weights for stencil offsets 0, 1, 2.
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        # Pressure-like stress from volume change (J - 1), pre-scaled by dt,
        # p_vol and the same 4 / dx**2 factor used for C in the G2P stage.
        stress = -dt * 4 * E * p_vol * (J[p] - 1) / dx**2
        affine = ti.Matrix([[stress, 0], [0, stress]]) + p_mass * C[p]
        for i, j in ti.static(ti.ndrange(3, 3)):  # unrolled at compile time
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            # `+=` on a field inside a parallel loop is an atomic add in Taichi.
            grid_v[base + offset] += weight * (p_mass * v[p] + affine @ dpos)
            grid_m[base + offset] += weight * p_mass
    # Stage 2 (grid update): momentum -> velocity, gravity, wall conditions.
    for i, j in grid_m:
        if grid_m[i, j] > 0:
            grid_v[i, j] /= grid_m[i, j]
        grid_v[i, j].y -= dt * gravity
        # Within `bound` cells of a wall, zero the velocity component pointing
        # into that wall; tangential motion and motion away from it are kept.
        if i < bound and grid_v[i, j].x < 0:
            grid_v[i, j].x = 0
        if i > n_grid - bound and grid_v[i, j].x > 0:
            grid_v[i, j].x = 0
        if j < bound and grid_v[i, j].y < 0:
            grid_v[i, j].y = 0
        if j > n_grid - bound and grid_v[i, j].y > 0:
            grid_v[i, j].y = 0
    # Stage 3 (G2P): gather velocity and affine matrix, then advect particles.
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        new_v = ti.Vector.zero(float, 2)
        new_C = ti.Matrix.zero(float, 2, 2)
        for i, j in ti.static(ti.ndrange(3, 3)):
            offset = ti.Vector([i, j])
            dpos = (offset - fx) * dx
            weight = w[i].x * w[j].y
            g_v = grid_v[base + offset]
            new_v += weight * g_v
            new_C += 4 * weight * g_v.outer_product(dpos) / dx**2
        v[p] = new_v
        x[p] += dt * v[p]
        # Hard clamp of the position into [lower_bound, upper_bound] on each
        # axis; v[p] itself is left unchanged.
        if x[p].x < lower_bound or x[p].x > upper_bound:
            x[p].x = ti.max(lower_bound, ti.min(upper_bound, x[p].x))
        if x[p].y < lower_bound or x[p].y > upper_bound:
            x[p].y = ti.max(lower_bound, ti.min(upper_bound, x[p].y))
        J[p] *= 1 + dt * new_C.trace()
        C[p] = new_C

def _sample_rectangle(size, lower, upper):
    """Uniformly sample `size` points in a random axis-aligned rectangle inside [lower, upper]^2."""
    available_size = upper - lower
    width = np.random.uniform(0.1, available_size * 0.5)
    height = np.random.uniform(0.1, available_size * 0.5)
    # Keep the whole rectangle inside the domain.
    center_x = np.random.uniform(lower + width / 2, upper - width / 2)
    center_y = np.random.uniform(lower + height / 2, upper - height / 2)
    return (np.random.rand(size, 2).astype(np.float32) * [width, height]
            + [center_x - width / 2, center_y - height / 2])


def _sample_disc(size, lower, upper):
    """Uniformly sample `size` points in a random disc inside [lower, upper]^2."""
    radius = np.random.uniform(0.05, (upper - lower) * 0.25)
    center = np.array([
        np.random.uniform(lower + radius, upper - radius),
        np.random.uniform(lower + radius, upper - radius),
    ])
    angles = np.random.rand(size) * 2 * np.pi
    # sqrt of a uniform variate gives uniform density over the disc's area.
    rads = np.sqrt(np.random.rand(size)) * radius
    return np.column_stack([
        center[0] + rads * np.cos(angles),
        center[1] + rads * np.sin(angles),
    ])


def _sample_triangle(size, lower, upper):
    """Uniformly sample `size` points in a randomly rotated triangle inside [lower, upper]^2."""
    base_size = np.random.uniform(0.1, (upper - lower) * 0.3)
    # The base corners (+-b/2, -b/2) are the vertices farthest from the rotation
    # centre, at distance b/sqrt(2) ~= 0.707*b, whatever the rotation angle.
    # A smaller margin (previously 0.6*b) lets rotated triangles poke out of the
    # domain; the final clip then flattens those particles onto the wall.
    safe_margin = base_size / np.sqrt(2)
    center = np.array([
        np.random.uniform(lower + safe_margin, upper - safe_margin),
        np.random.uniform(lower + safe_margin, upper - safe_margin),
    ])
    angle = np.random.uniform(0, 2 * np.pi)
    rot_matrix = np.array([
        [np.cos(angle), -np.sin(angle)],
        [np.sin(angle), np.cos(angle)],
    ])
    half = base_size / 2
    vertices = np.array([[-half, -half], [half, -half], [0, half]]) @ rot_matrix.T + center
    # Uniform barycentric sampling: points from the upper half of the unit
    # square are reflected back into the lower-left triangle.
    u = np.random.rand(size, 2).astype(np.float32)
    mask = u.sum(1) > 1
    u[mask] = 1 - u[mask]
    return (u[:, 0:1] * vertices[0]
            + u[:, 1:2] * vertices[1]
            + (1 - u.sum(1, keepdims=True)) * vertices[2])


def initialize_particles(shape_type, n_particles, lower=None, upper=None):
    """Sample initial positions and velocities for a randomly composed fluid body.

    The particles are split (multinomially) into 1-4 parts. Each part gets a
    random shape (axis-aligned rectangle, disc, or rotated triangle), sized and
    placed so that it lies entirely inside ``[lower, upper]^2``. All particles
    in a part share one unit-speed velocity in a random direction. All
    randomness comes from NumPy's global RNG.

    Args:
        shape_type: Ignored; each part's shape is drawn at random. Kept so
            existing callers keep working.
        n_particles: Total number of particles to generate.
        lower: Lower domain bound on both axes (default: ``lower_bound``).
        upper: Upper domain bound on both axes (default: ``upper_bound``).
            ``upper - lower`` must exceed 1/3 so every shape's minimum size fits.

    Returns:
        ``(x_np, v_np)``: float64 arrays of shape ``(n_particles, 2)`` with the
        positions (clipped to ``[lower, upper]``) and velocities.
    """
    if lower is None:
        lower = lower_bound
    if upper is None:
        upper = upper_bound
    samplers = (_sample_rectangle, _sample_disc, _sample_triangle)  # 0=rect, 1=circle, 2=triangle

    num_parts = np.random.randint(1, 5)
    # Multinomial counts always sum to exactly n_particles.
    part_sizes = np.random.multinomial(n_particles, np.ones(num_parts) / num_parts)
    part_shapes = np.random.randint(0, 3, size=num_parts)
    part_vel_dirs = np.random.uniform(0, 2 * np.pi, size=num_parts)

    x_np = np.zeros((n_particles, 2))
    v_np = np.zeros((n_particles, 2))

    start = 0
    for size, part_shape, vel_dir in zip(part_sizes, part_shapes, part_vel_dirs):
        end = start + size
        x_np[start:end] = samplers[part_shape](size, lower, upper)
        v_np[start:end] = [np.cos(vel_dir), np.sin(vel_dir)]
        start = end

    # Safety net against floating-point round-off at the walls.
    return np.clip(x_np, lower, upper), v_np

@ti.kernel
def initialize_fields_values(C: ti.template(), J: ti.template()):
    # Reset per-particle state before a run: the affine matrices C start at
    # zero, and J starts at 1, where substep's (J - 1) stress term vanishes.
    C.fill(0)
    J.fill(1)

def run_simulation(sim_id, output_dir="Water_drop_simulations"):
    """Run one randomly initialized water simulation and record its trajectory.

    The recorded positions are stored in the module-level ``save_data`` under
    ``'simulation_trajectory_{sim_id}'``. The value is a 2-element object array
    ``[positions (total_frames, n_particles, 2), particle_types (n_particles,)]``,
    where ``particle_types`` is a float array filled with
    ``PARTICLE_TYPE['Water']``. For ``sim_id < 10`` an MP4 is also written to
    ``output_dir``.

    Args:
        sim_id: Index used in the ``save_data`` key and the video file name.
        output_dir: Directory receiving the optional video.

    Returns:
        ``(velocities, accelerations)``: lists holding one ``(n_particles, 2)``
        array per recorded frame. Accelerations are finite differences between
        consecutive recorded frames.
        Returns ``(None, None)`` if the velocity field becomes non-finite
        (NaN or inf); in that case nothing is saved.
    """
    n_particles = np.random.randint(min_particles, max_particles + 1)

    # Drawn and passed on so the global NumPy RNG sequence is unchanged.
    # NOTE(review): initialize_particles picks a shape per part itself and
    # does not use this value.
    shape_type = np.random.randint(0, 3)  # 0=rect, 1=circle, 2=triangle
    x, v, C, J, grid_v, grid_m = initialize_fields(n_particles)
    x_np, v_np = initialize_particles(shape_type, n_particles)
    x.from_numpy(x_np)
    v.from_numpy(v_np)
    initialize_fields_values(C, J)

    # Recorded frames are `steps_per_frame` substeps apart. Finite differences
    # between them must therefore divide by the frame interval, not by `dt`
    # (this is also the "dt" written to the metadata).
    frame_dt = dt * steps_per_frame

    positions = []
    velocities = []
    accelerations = []
    particle_types = PARTICLE_TYPE['Water'] * np.ones(n_particles)
    prev_vel = v.to_numpy()

    for _frame in range(total_frames):
        for _step in range(steps_per_frame):
            substep(x, v, C, J, grid_v, grid_m, n_particles)
        pos = x.to_numpy()
        vel = v.to_numpy()

        # Abort on blow-up; inf is treated like NaN.
        if not np.isfinite(vel).all():
            return None, None

        positions.append(pos)
        velocities.append(vel)
        accelerations.append((vel - prev_vel) / frame_dt)
        prev_vel = vel  # to_numpy() returns a fresh array, so no copy is needed

    # Save data in the requested format: an object array [positions, types].
    trajectory = np.empty(2, dtype=object)
    trajectory[0] = np.array(positions)
    trajectory[1] = particle_types
    save_data[f'simulation_trajectory_{sim_id}'] = trajectory

    # Optional: save videos for the first 10 simulations only, to save space.
    if sim_id < 10:
        save_video(sim_id, positions, output_dir)

    return velocities, accelerations

def save_video(sim_id, positions, output_dir):
    """Render a recorded trajectory to ``<output_dir>/sim_<sim_id>.mp4`` at 96 fps.

    Args:
        sim_id: Simulation index used in the canvas title and the file name.
        positions: Per-frame particle positions; anything ``np.array`` can stack
            into ``(frames, n_particles, 2)``.
        output_dir: Existing directory that receives the video.
    """
    output_path = os.path.join(output_dir, f"sim_{sim_id}.mp4")

    # Off-screen canvas: frames are read back with get_image() instead of being shown.
    gui = ti.GUI(f"Simulation {sim_id}", background_color=0xFFFFFF, show_gui=False)
    fps = 96
    frames = np.array(positions)
    # The context manager finalizes and closes the video file even if rendering
    # raises part-way, instead of leaking an open writer.
    with imageio.get_writer(output_path, fps=fps) as writer:
        for frame_positions in frames:
            gui.clear(0xFFFFFF)
            gui.circles(frame_positions, radius=1.5, color=0x068587)
            # Rotate the GUI buffer into image orientation (as before), then
            # scale to 8-bit. The scaling assumes get_image() values lie in
            # [0, 1] -- NOTE(review) confirm.
            image = np.rot90(gui.get_image(), k=1)
            writer.append_data((image * 255).astype(np.uint8))

if __name__ == "__main__":
    num_simulations = 1200
    output_dir = 'Water_simulations'
    os.makedirs(output_dir, exist_ok=True)
    # Dataset split, in completion order: the first n_train runs go to train,
    # the next n_test to test, and the remainder to valid.
    n_train = 1000
    n_test = 100

    # A run that blows up is retried with fresh random draws, so `idx` only
    # advances on success.
    idx = 0
    while idx < num_simulations:
        print(f"Running simulation {idx}/{num_simulations}")
        # Only success/failure matters here. The per-frame velocity and
        # acceleration histories are intentionally NOT accumulated across runs.
        # Nothing below reads them (the statistics in `metadata` are written
        # as 0.0), yet keeping them would retain two extra arrays the size of
        # every stored trajectory for the whole job.
        velocities, _ = run_simulation(idx, output_dir)
        if velocities is None:
            print(f"Skipping simulation {idx} due to NaN values in velocities.")
            continue
        idx += 1

    metadata = {
        "bounds": [[lower_bound, upper_bound], [lower_bound, upper_bound]],
        "sequence_length": total_frames,
        "default_connectivity_radius": 0.015,
        "dim": dim,
        # Time between recorded frames.
        "dt": dt*steps_per_frame,
        # NOTE(review): placeholder statistics, not computed from the data.
        "vel_mean": 0.0,
        "vel_std": 0.0,
        "acc_mean": 0.0,
        "acc_std": 0.0,
        "obstacles_configs": [],  # format: [[[,],[,]],]
    }

    # Save to files.
    all_sim_ids = list(save_data.keys())
    splits = {
        'train': all_sim_ids[:n_train],
        'valid': all_sim_ids[n_train + n_test:],
        'test': all_sim_ids[n_train:n_train + n_test],
    }
    for split_name, split_ids in splits.items():
        np.savez(os.path.join(output_dir, f'{split_name}.npz'),
                 **{k: save_data[k] for k in split_ids})
    with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)
    print("All simulations completed!")