import taichi as ti
import imageio  
import numpy as np
import json
import os
from math import sqrt, cos, sin, pi


# Integer codes used to tag particle materials; run_simulation stores the
# "Sand" code for every particle of a saved trajectory.
PARTICLE_TYPE={"Rigid":0, "Droplet": 1, "Boundary": 3, "Water": 5, "Sand": 6, "Goop": 7}

# Initialize Taichi
ti.init(arch=ti.gpu, default_fp=ti.f32)

# Seed NumPy's global RNG; all random scene generation below draws from it.
SEED = 42
np.random.seed(SEED)

# Simulation parameters
max_particles = 20000  # inclusive upper limit of the per-simulation particle count
min_particles = 8192  # lower limit of the per-simulation particle count
n_grid = 64  # number of grid nodes along each axis
dx, inv_dx = 1 / n_grid, float(n_grid)  # grid spacing and its reciprocal
dt = 2.5e-4  # time step of a single substep
dim = 3  # spatial dimension

# NOTE(review): the exponent is 2 although dim is 3 -- confirm this is intended.
p_vol = (dx * 0.5) ** 2
rho = 400
mass = p_vol * rho  # mass assigned to every particle
E, nu = 4.0e5, 0.25  # Young's modulus and Poisson's ratio
mu, lambda_ = E / (2 * (1 + nu)), E * nu / ((1 + nu) * (1 - 2 * nu))  # Lame parameters derived from E and nu
mu_b = 1.0  # friction coefficient used by the grid boundary handling in substep


# Coefficients forwarded to hardening() by substep.
h0, h1, h2, h3 = 35, 9, 0.2, 10
# NOTE(review): this rebinds the `pi` imported from `math` at the top of the file.
pi = 3.14159265358979

gravity = ti.Vector([0, -9.8, 0])
neighbour = (3,) * dim  # extents of the per-particle node stencil (3 nodes per axis)

# Particle positions are clamped to [lower_bound, upper_bound] on every axis.
lower_bound = 0.125
upper_bound = 0.875
bound = int(0.15 * n_grid)  # thickness, in grid nodes, of the wall layer handled in substep
# NOTE(review): E is rebound here after mu and lambda_ were already computed
# from the earlier value; no visible code reads E afterwards -- confirm intent.
E = 400

save_data = {}  # filled by run_simulation: one trajectory entry per simulation id
# NOTE(review): run_simulation uses local lists with these two names and the
# __main__ block rebinds them, so these module-level lists stay empty.
velocities = []
accelerations = []
total_frames = 300  # recorded frames per simulation
steps_per_frame = 10  # substeps executed between two recorded frames

def initialize_fields(n_particles):
    """Allocate the Taichi fields used by one simulation.

    Per-particle fields have shape ``n_particles``; grid fields have
    ``n_grid`` nodes along each of the ``dim`` axes.

    Returns:
        Tuple ``(x, v, C, F, phi, c_C0, vc, alpha, q, state, color,
        grid_v, grid_m, grid_f)``.
    """
    grid_shape = (n_grid,) * dim

    def particle_vector():
        return ti.Vector.field(dim, dtype=ti.f32, shape=n_particles)

    def particle_matrix():
        return ti.Matrix.field(dim, dim, dtype=ti.f32, shape=n_particles)

    def particle_scalar(dtype=ti.f32):
        return ti.field(dtype=dtype, shape=n_particles)

    # Particle state (allocation order matches the returned tuple).
    x = particle_vector()
    v = particle_vector()
    C = particle_matrix()
    F = particle_matrix()
    phi = particle_scalar()
    c_C0 = particle_scalar()
    vc = particle_scalar()
    alpha = particle_scalar()
    q = particle_scalar()
    state = particle_scalar(int)
    color = particle_scalar(int)

    # Grid state.
    grid_v = ti.Vector.field(dim, dtype=ti.f32, shape=grid_shape)
    grid_m = ti.field(dtype=ti.f32, shape=grid_shape)
    grid_f = ti.Vector.field(dim, dtype=ti.f32, shape=grid_shape)

    return (x, v, C, F, phi, c_C0, vc, alpha, q, state, color,
            grid_v, grid_m, grid_f)

    

# Smooth step from 1 (z <= 0) down to 0 (z >= 1) using the quintic
# 1 - 10 z^3 + 15 z^4 - 6 z^5 in between.
# The branches are mutually exclusive so the clamped values are no longer
# overwritten by the polynomial for z outside [0, 1].
@ti.func
def h_s(z):
    ret = 0.0
    if z < 0:
        ret = 1.0
    elif z > 1:
        ret = 0.0
    else:
        ret = 1 - 10 * (z ** 3) + 15 * (z ** 4) - 6 * (z ** 5)
    return ret

# Map the trial strain e0 of particle p to a projected strain and a
# hardening increment.
#   e0      : 3x3 strain matrix; substep passes the diagonal matrix of
#             log singular values of F[p].
#   vc      : per-particle scalar added to the strain as vc[p] / dim * I.
#   phi, c_C0, alpha : per-particle scalars used for the second isotropic
#             shift c_C0 * (1 - phi) / (dim * alpha) * I.
#   state   : per-particle int field; receives 0, 1 or 2 depending on the
#             branch taken below.
# Returns (new_e, delta_q).
@ti.func
def project(e0: ti.template(), p: int, vc: ti.template(), phi: ti.template(), c_C0: ti.template(), alpha: ti.template(), 
            dim: int, lambda_: float, mu: float, state: ti.template()):
    # Shifted strain: e0 plus the two isotropic terms described above.
    e = e0 + vc[p] / dim * ti.Matrix.identity(ti.f32, 3)  
    e += (c_C0[p] * (1.0 - phi[p])) / (dim * alpha[p]) * ti.Matrix.identity(float, 3)  
    # Deviatoric part of e (trace removed).
    ehat = e - e.trace() / dim * ti.Matrix.identity(ti.f32, 3)  
    
    
    # Norm over the diagonal entries only; off-diagonal entries are ignored.
    Fnorm = ti.sqrt(ehat[0, 0]**2 + ehat[1, 1]**2 + ehat[2, 2]**2)
    # Amount by which the deviatoric norm exceeds the alpha-scaled trace term.
    yp = Fnorm + (dim * lambda_ + 2 * mu) / (2 * mu) * e.trace() * alpha[p]
    
    new_e = ti.Matrix.zero(ti.f32, 3, 3)  
    delta_q = 0.0
    
    if Fnorm <= 0 or e.trace() > 0:
        # No deviatoric strain or positive trace: strain is reset to zero and
        # the whole diagonal norm of e is reported as the increment.
        new_e = ti.Matrix.zero(ti.f32, 3, 3)  
        delta_q = ti.sqrt(e[0, 0]**2 + e[1, 1]**2 + e[2, 2]**2)  
        state[p] = 0
    elif yp <= 0:
        # Nothing to project: the unshifted input e0 is returned as-is.
        new_e = e0
        delta_q = 0
        state[p] = 1
    else:
        # Remove yp from the deviatoric part along the direction of ehat.
        new_e = e - yp / Fnorm * ehat
        delta_q = yp
        state[p] = 2

    return new_e, delta_q

# Accumulate dq into q[p] and recompute alpha[p] from the updated q[p].
# The angle h0 + (h1 * q - h3) * exp(-h2 * q) is evaluated in degrees and
# converted to radians with the supplied pi before taking its sine.
@ti.func
def hardening(dq: float, p: int, q: ti.template(), alpha: ti.template(), h0: float, h1: float, h2: float, h3: float, pi: float):
    q[p] += dq
    angle_deg = h0 + (h1 * q[p] - h3) * ti.exp(-h2 * q[p])
    angle_rad = angle_deg / 180 * pi
    sine = ti.sin(angle_rad)
    alpha[p] = ti.sqrt(2 / 3) * (2 * sine) / (3 - sine)


# Advance the simulation by one time step dt:
#   1. clear the grid,
#   2. scatter particle momentum, mass and stress forces to the grid,
#   3. turn grid momentum into velocity, apply gravity/forces and the wall
#      friction rule,
#   4. gather velocities back to the particles, advect them, and run
#      project()/hardening() on the updated deformation gradient.
# NOTE(review): the n_particles argument is not read in the body; the
# particle loops iterate over the field x itself.
@ti.kernel
def substep(
    x: ti.template(),
    v: ti.template(),
    C: ti.template(),
    F: ti.template(),
    phi: ti.template(),
    c_C0: ti.template(),
    vc: ti.template(),
    alpha: ti.template(),
    q: ti.template(),
    state: ti.template(),
    grid_v: ti.template(),
    grid_m: ti.template(),
    grid_f: ti.template(),
    n_particles: int
):

    # 1. Reset all grid accumulators.
    for I in ti.grouped(grid_m):
        grid_v[I] = ti.zero(grid_v[I])
        grid_m[I] = 0
        grid_f[I] = ti.zero(grid_f[I])
    ti.loop_config(block_dim=n_grid)

    # 2. Particle-to-grid transfer.
    for p in x:
        Xp = x[p] / dx  # particle position in grid units
        base = int(Xp - 0.5)  # lowest node index of the 3x3x3 stencil
        fx = Xp - base  # offset of the particle from the base node, in grid units
        # Three per-axis quadratic interpolation weights (one per stencil node).
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        
        # Stress from the logarithm of the singular values of F[p].
        U, sig, V = ti.svd(F[p])
        inv_sig = sig.inverse()
        e = ti.Matrix([[ti.log(sig[0, 0]), 0, 0], [0, ti.log(sig[1, 1]), 0], [0, 0, ti.log(sig[2, 2])]])
        stress = U @ (2 * mu * inv_sig @ e + lambda_ * e.trace() * inv_sig) @ V.transpose()
        # NOTE(review): the factor here is 8 * inv_dx^2 while the gather step
        # below uses 4 * inv_dx for C -- confirm the two are meant to differ.
        stress = (-p_vol * 8 * inv_dx * inv_dx) * stress @ F[p].transpose()  # 8 in 3D
        
        affine = mass * C[p]
        for offset in ti.static(ti.grouped(ti.ndrange(*neighbour))): # transmit particle information to the surrounding grid
            dpos = (offset - fx) * dx
            weight = 1.0
            for i in ti.static(range(dim)):
                weight *= w[offset[i]][i]
            grid_v[base + offset] += weight * (mass * v[p] + affine @ dpos)   # project momentum onto the grid
            grid_m[base + offset] += weight * mass  
            grid_f[base + offset] += weight * stress @ dpos
        
    # Earlier, disabled version of the grid update (kept for reference):
    # for I in ti.grouped(grid_m):
    #     if grid_m[I] > 0:
    #         grid_v[I] /= grid_m[I]         # calculate velocity using momentum
    #     grid_v[I][1] -= dt * gravity      # apply gravity
    #     cond = (I < bound) & (grid_v[I] < 0) | (I > n_grid - bound) & (grid_v[I] > 0)  # boundary conditions
    #     grid_v[I] = ti.select(cond, 0, grid_v[I])
    # ti.loop_config(block_dim=n_grid)

    # 3. Grid update; only nodes that received mass are touched.
    for i, j, k in grid_m:
        if grid_m[i, j, k] > 0:
            # Momentum -> velocity, then add gravity and the scattered force.
            grid_v[i, j, k] = (1 / grid_m[i, j, k]) * grid_v[i, j, k]
            grid_v[i, j, k] += dt * (gravity + grid_f[i, j, k] / grid_m[i, j, k])

            # Boundary handling
            # Pick an inward wall normal when the node lies within `bound`
            # nodes of a wall and moves towards it. The checks are not
            # exclusive, so a later matching check replaces an earlier one.
            normal = ti.Vector.zero(ti.f32, 3)
            if i < bound and grid_v[i, j, k][0] < 0: normal = [1, 0, 0]
            if i > n_grid - bound and grid_v[i, j, k][0] > 0: normal = [-1, 0, 0]
            if j < bound and grid_v[i, j, k][1] < 0: normal = [0, 1, 0]
            if j > n_grid - bound and grid_v[i, j, k][1] > 0: normal = [0, -1, 0]
            if k < bound and grid_v[i, j, k][2] < 0: normal = [0, 0, 1]
            if k > n_grid - bound and grid_v[i, j, k][2] > 0: normal = [0, 0, -1]
            
            if not (normal[0] == 0 and normal[1] == 0 and normal[2] == 0):
                s = grid_v[i, j, k].dot(normal)  # normal velocity component (<= 0 means into the wall)
                if s <= 0:
                    v_normal = s * normal
                    v_tangent = grid_v[i, j, k] - v_normal
                    vt = v_tangent.norm()
                    # Drop the normal component and reduce the tangential speed
                    # by min(vt, -mu_b * s).
                    # NOTE(review): when vt <= 1e-12 the node velocity is left
                    # unchanged, including its into-the-wall component.
                    if vt > 1e-12: 
                        grid_v[i, j, k] = v_tangent - (vt if vt < -mu_b * s else -mu_b * s) * (v_tangent / vt)


    
    # 4. Grid-to-particle transfer and particle update.
    for p in x:
        Xp = x[p] / dx
        base = int(Xp - 0.5)
        fx = Xp - base
        w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
        
        new_v = ti.zero(v[p])         # new velocity
        new_C = ti.zero(C[p])  
        for i, j, k in ti.static(ti.ndrange(3, 3, 3)):
            dpos = ti.Vector([i, j, k]).cast(ti.f32) - fx  # node offset in grid units
            g_v = grid_v[base + ti.Vector([i, j, k])]
            weight = w[i][0] * w[j][1] * w[k][2]
            new_v += weight * g_v
            new_C += 4 * inv_dx * weight * g_v.outer_product(dpos)

        # Update the deformation gradient with the gathered velocity gradient.
        F[p] = (ti.Matrix.identity(ti.f32, 3) + dt * new_C) @ F[p]
        v[p], C[p] = new_v, new_C
        x[p] += dt * v[p]
        # Keep the particle inside [lower_bound, upper_bound] on every axis.
        for d in ti.static(range(3)):
            if x[p][d] < lower_bound or x[p][d] > upper_bound:
                x[p][d] = ti.max(lower_bound, ti.min(upper_bound, x[p][d]))
                
        # Project the log singular values of F[p], update the hardening state,
        # and rebuild F[p] from the projected values.
        U, sig, V = ti.svd(F[p])
        e = ti.Matrix([[ti.log(sig[0, 0]), 0, 0], [0, ti.log(sig[1, 1]), 0], [0, 0, ti.log(sig[2, 2])]])
        new_e, dq = project(e, p, vc, phi, c_C0, alpha, dim, lambda_, mu, state)
        hardening(dq, p, q, alpha, h0, h1, h2, h3, pi)
        new_F = U @ ti.Matrix([[ti.exp(new_e[0, 0]), 0, 0], [0, ti.exp(new_e[1, 1]), 0], [0, 0, ti.exp(new_e[2, 2])]]) @ V.transpose()
        # Accumulate the log-volume removed by the projection.
        vc[p] += -ti.log(new_F.determinant()) + ti.log(F[p].determinant())
        F[p] = new_F

def T(a):
    """Project an ``(N, 3)`` array of points to ``(N, 2)`` screen coordinates.

    The points are centred on 0.5, rotated by 28 degrees in the x-z plane,
    tilted by 32 degrees, and shifted back by 0.5. When ``dim`` is 2 the
    input is returned unchanged.
    """
    if dim == 2:
        return a

    yaw = np.radians(28)
    tilt = np.radians(32)
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    cos_tilt, sin_tilt = np.cos(tilt), np.sin(tilt)

    centred = a - 0.5
    px = centred[:, 0]
    py = centred[:, 1]
    pz = centred[:, 2]

    # Rotation in the x-z plane.
    rx = px * cos_yaw + pz * sin_yaw
    rz = pz * cos_yaw - px * sin_yaw

    # Screen axes: horizontal is the rotated x, vertical mixes y and rotated z.
    screen_u = rx
    screen_v = py * cos_tilt + rz * sin_tilt
    return np.array([screen_u, screen_v]).T + 0.5
    
    
def _direction(azimuth, polar):
    # Unit vector(s) for the given spherical angles (polar measured from +z).
    return np.array([
        np.sin(polar) * np.cos(azimuth),
        np.sin(polar) * np.sin(azimuth),
        np.cos(polar),
    ])


def _random_unit_vector():
    # Draw an azimuth in [0, 2*pi) and a polar angle in [0, pi), in that order.
    azimuth = np.random.uniform(0, 2 * np.pi)
    polar = np.random.uniform(0, np.pi)
    return _direction(azimuth, polar)


def _random_center(lo, hi, margin):
    # Draw x, y, z (in that order) so that a ball of radius `margin` around
    # the centre stays inside [lo, hi]^3.
    return np.array([np.random.uniform(lo + margin, hi - margin) for _ in range(3)])


def _sample_box(size, lo, hi):
    # Uniform points in a randomly sized, axis-aligned box inside [lo, hi]^3.
    max_size = (hi - lo) * 0.5
    width = np.random.uniform(0.1, max_size)
    height = np.random.uniform(0.1, max_size)
    depth = np.random.uniform(0.1, max_size)
    extents = np.array([width, height, depth])
    half = extents / 2

    center = np.array([
        np.random.uniform(lo + half[0], hi - half[0]),
        np.random.uniform(lo + half[1], hi - half[1]),
        np.random.uniform(lo + half[2], hi - half[2]),
    ])
    return np.random.rand(size, 3).astype(np.float32) * extents + (center - half)


def _sample_sphere(size, lo, hi):
    # Uniform points in a randomly sized ball inside [lo, hi]^3.
    radius = np.random.uniform(0.05, (hi - lo) * 0.25)
    center = _random_center(lo, hi, radius)

    azimuth = np.random.rand(size) * 2 * np.pi
    polar = np.arccos(2 * np.random.rand(size) - 1)
    # The cube root makes the radial distribution uniform in volume.
    dist = np.cbrt(np.random.rand(size)) * radius
    return center + (dist * _direction(azimuth, polar)).T


def _sample_cylinder(size, lo, hi):
    # Uniform points in a randomly sized and randomly oriented cylinder.
    available = hi - lo
    radius = np.random.uniform(0.05, available * 0.2)
    height = np.random.uniform(0.1, available * 0.4)

    # The axis direction is random, so the only orientation-independent bound
    # is the circumscribed sphere of the cylinder.
    margin = np.sqrt(radius ** 2 + (height / 2) ** 2)
    center = _random_center(lo, hi, margin)
    axis = _random_unit_vector()

    angles = np.random.rand(size) * 2 * np.pi
    # The square root makes the radial distribution uniform in area.
    radial = np.sqrt(np.random.rand(size)) * radius
    along = (np.random.rand(size) - 0.5) * height

    # Build an orthonormal pair perpendicular to the axis; the second form is
    # used when the axis is nearly parallel to z.
    if np.abs(axis[0]) > 0.1 or np.abs(axis[1]) > 0.1:
        perp1 = np.array([-axis[1], axis[0], 0])
    else:
        perp1 = np.array([0, -axis[2], axis[1]])
    perp1 = perp1 / np.linalg.norm(perp1)
    perp2 = np.cross(axis, perp1)

    in_plane = np.outer(np.cos(angles), perp1) + np.outer(np.sin(angles), perp2)
    return center + radial[:, None] * in_plane + np.outer(along, axis)


def _sample_tetrahedron(size, lo, hi):
    # Uniform points in a randomly sized and rotated regular tetrahedron.
    edge = np.random.uniform(0.1, (hi - lo) * 0.3)

    # All vertices lie on a sphere of this radius around the centroid, so it
    # is the margin that keeps every rotation of the shape inside the domain.
    circumradius = edge * np.sqrt(6) / 4
    center = _random_center(lo, hi, circumradius)

    vertices = np.array([
        [0, 0, 0],
        [edge, 0, 0],
        [edge / 2, edge * np.sqrt(3) / 2, 0],
        [edge / 2, edge * np.sqrt(3) / 6, edge * np.sqrt(6) / 3],
    ])
    # Rotate about the centroid so that `center` really is the shape's centre.
    vertices = vertices - vertices.mean(axis=0)

    axis = _random_unit_vector()
    rot_angle = np.random.uniform(0, 2 * np.pi)
    # Rodrigues' rotation formula for a unit axis.
    K = np.array([
        [0, -axis[2], axis[1]],
        [axis[2], 0, -axis[0]],
        [-axis[1], axis[0], 0],
    ])
    R = np.eye(3) + np.sin(rot_angle) * K + (1 - np.cos(rot_angle)) * (K @ K)
    vertices = vertices @ R.T + center

    # Rejection-sample barycentric weights: keep triples whose sum is <= 1.
    u = np.random.rand(size, 3)
    mask = u.sum(1) > 1
    while np.any(mask):
        u[mask] = np.random.rand(int(mask.sum()), 3)
        mask = u.sum(1) > 1

    weights = np.column_stack([u, 1 - u.sum(1)])
    return weights @ vertices


# Sampler for each shape id: 0 box, 1 sphere, 2 cylinder, 3 tetrahedron.
_SHAPE_SAMPLERS = (_sample_box, _sample_sphere, _sample_cylinder, _sample_tetrahedron)


def initialize_particles(shape_type, n_particles):
    """Generate random initial particle positions and velocities.

    The particles are split into 1-4 parts. Each part gets its own random
    shape (box, sphere, cylinder or tetrahedron), placed so that it lies
    inside ``[lower_bound, upper_bound]`` on every axis, and one shared
    unit-length velocity with a random direction.

    Args:
        shape_type: Accepted for backward compatibility and ignored; every
            part draws its own shape.
        n_particles: Total number of particles to generate.

    Returns:
        ``(x_np, v_np)``: two ``(n_particles, 3)`` float arrays holding the
        positions and velocities.
    """
    num_parts = np.random.randint(1, 5)
    # The multinomial counts already sum to n_particles.
    part_sizes = np.random.multinomial(n_particles, np.ones(num_parts) / num_parts)
    part_shapes = np.random.randint(0, 4, size=num_parts)
    part_vel_azimuth = np.random.uniform(0, 2 * np.pi, size=num_parts)
    part_vel_polar = np.random.uniform(0, np.pi, size=num_parts)

    x_np = np.zeros((n_particles, 3))
    v_np = np.zeros((n_particles, 3))

    start = 0
    for size, part_shape, azimuth, polar in zip(
            part_sizes, part_shapes, part_vel_azimuth, part_vel_polar):
        stop = start + size
        x_np[start:stop] = _SHAPE_SAMPLERS[part_shape](size, lower_bound, upper_bound)
        v_np[start:stop] = _direction(azimuth, polar)
        start = stop

    # Safety net only: the samplers above already keep the shapes in bounds.
    x_np = np.clip(x_np, lower_bound, upper_bound)

    return x_np, v_np

# Set the starting state of the first n_particles particles:
# F = identity, c_C0 = 0.1, alpha = 0.2.
@ti.kernel
def initialize_fields_values(F: ti.template(), c_C0: ti.template(), alpha: ti.template(), n_particles: int):
    for p in range(n_particles):
        alpha[p] = 0.2
        c_C0[p] = 0.1
        F[p] = ti.Matrix.identity(ti.f32, 3)

# Defined once at module level so the kernel object is not re-created (and
# re-compiled) on every update_color call.
@ti.kernel
def _update_color(color: ti.template(), phi: ti.template(), n_particles: int):
    for i in range(n_particles):
        t = phi[i]
        # Blend each channel linearly between its t = 0 and t = 1 value.
        r = int((0.521 * (1 - t) + 0.318 * t) * 0xFF)
        g = int((0.368 * (1 - t) + 0.223 * t) * 0xFF)
        b = int((0.259 * (1 - t) + 0.157 * t) * 0xFF)
        color[i] = (r << 16) + (g << 8) + b


def update_color(color, phi, n_particles):
    """Write a packed 0xRRGGBB colour for the first ``n_particles`` particles.

    Each colour is a linear blend driven by ``phi[i]``: RGB
    (0.521, 0.368, 0.259) at 0 and (0.318, 0.223, 0.157) at 1.
    """
    _update_color(color, phi, n_particles)
    

def run_simulation(sim_id, output_dir="Water_drop_simulations"):
    """Run one randomly initialised simulation and record its trajectory.

    The particle positions of every frame are stored, together with the
    particle-type array, in ``save_data`` under
    ``simulation_trajectory_{sim_id}``. For ``sim_id < 10`` a GIF is also
    written to ``output_dir``.

    Args:
        sim_id: Index of the simulation; used in the storage key and the
            video file name.
        output_dir: Directory for the optional video; created if missing.

    Returns:
        ``(velocities, accelerations, n_particles)`` where the first two are
        per-frame lists of ``(n_particles, 3)`` arrays, or
        ``(None, None, None)`` if a non-finite position or velocity appears.
    """
    n_particles = np.random.randint(min_particles, max_particles + 1)

    # NOTE(review): new Taichi fields are allocated on every call and are not
    # released here; confirm this is acceptable for long batch runs.
    x, v, C, F, phi, c_C0, vc, alpha, q, state, color, grid_v, grid_m, grid_f = initialize_fields(n_particles)

    # initialize_particles draws its own shape for every part, so this value
    # has no effect; the draw is kept so the random stream is unchanged.
    shape_type = np.random.randint(0, 4)
    x_np, v_np = initialize_particles(shape_type, n_particles)
    x.from_numpy(x_np.astype(np.float32))
    v.from_numpy(v_np.astype(np.float32))
    initialize_fields_values(F, c_C0, alpha, n_particles)

    # Data storage
    positions = []
    velocities = []
    accelerations = []
    particle_types = PARTICLE_TYPE['Sand'] * np.ones(n_particles)
    prev_vel = v.to_numpy()
    frame_dt = steps_per_frame * dt

    for _ in range(total_frames):
        for _ in range(steps_per_frame):
            substep(x, v, C, F, phi, c_C0, vc, alpha, q, state, grid_v, grid_m, grid_f, n_particles)
        pos = x.to_numpy()
        vel = v.to_numpy()

        # Abort on NaN *or* inf before anything from this frame is recorded.
        if not (np.isfinite(vel).all() and np.isfinite(pos).all()):
            return None, None, None

        positions.append(pos)
        velocities.append(vel)
        # Acceleration by finite difference between consecutive frames.
        accelerations.append((vel - prev_vel) / frame_dt)
        prev_vel = vel

    # Save data in requested format: [positions over time, particle types].
    trajectory = np.empty(2, dtype=object)
    trajectory[0] = np.array(positions)
    trajectory[1] = particle_types
    save_data[f'simulation_trajectory_{sim_id}'] = trajectory

    # Optional: Save video
    if sim_id < 10:  # Save videos for first 10 sims only to save space
        # The default output_dir is not created anywhere else.
        os.makedirs(output_dir, exist_ok=True)
        save_video(sim_id, positions, output_dir, color, phi, n_particles)

    return velocities, accelerations, n_particles

def save_video(sim_id, positions, output_dir, color, phi, n_particles):
    """Render the recorded frames of one simulation to ``sim_{sim_id}.gif``.

    Args:
        sim_id: Simulation index; used in the window title and file name.
        positions: Sequence of per-frame ``(n_particles, 3)`` position arrays.
        output_dir: Existing directory that receives the GIF.
        color: Taichi int field that receives the per-particle colours.
        phi: Taichi field passed to ``update_color`` to derive the colours.
        n_particles: Number of particles to colour.
    """
    output_path = os.path.join(output_dir, f"sim_{sim_id}.gif")

    gui = ti.GUI(f"Simulation {sim_id}", background_color=0xFFFFFF, show_gui=False)
    fps = 96

    # The colours do not change between frames, so compute and copy them once.
    update_color(color, phi, n_particles)
    colors = color.to_numpy()

    # The context manager closes the writer even if rendering a frame fails.
    with imageio.get_writer(output_path, fps=fps) as writer:
        for frame_positions in positions:
            gui.clear(0xFFFFFF)
            gui.circles(T(np.asarray(frame_positions)), radius=1.5, color=colors)
            frame = np.rot90(gui.get_image(), k=1)
            writer.append_data((frame * 255).astype(np.uint8))


if __name__ == "__main__":
    # Generate the dataset: run the simulations, collect normalisation
    # statistics, then write train/valid/test splits and the metadata file.
    num_simulations = 1200
    output_dir = 'Sand3d_simulations'
    os.makedirs(output_dir, exist_ok=True)

    # Only the per-simulation mean/std (one value per axis) are kept; holding
    # every velocity and acceleration frame of every simulation is not needed
    # for the statistics below.
    vel_means, vel_stds = [], []
    acc_means, acc_stds = [], []

    idx = 0
    while idx < num_simulations:
        print(f"Running simulation {idx}/{num_simulations}")
        velocities, accelerations, n_particles = run_simulation(idx, output_dir)
        if velocities is None:
            # The same index is retried; the RNG has advanced, so the next
            # attempt uses a different random scene.
            print(f"Skipping simulation {idx} due to NaN values in velocities.")
            continue

        vel = np.array(velocities)
        acc = np.array(accelerations)
        vel_means.append(np.mean(vel, axis=(0, 1)))
        vel_stds.append(np.std(vel, axis=(0, 1)))
        acc_means.append(np.mean(acc, axis=(0, 1)))
        acc_stds.append(np.std(acc, axis=(0, 1)))
        idx += 1

    # Dataset statistics are averages of the per-simulation statistics
    # (velocity and acceleration are treated the same way).
    vel_stats = {
        "mean": np.mean(np.array(vel_means), axis=0),
        "std": np.mean(np.array(vel_stds), axis=0),
    }
    acc_stats = {
        "mean": np.mean(np.array(acc_means), axis=0),
        "std": np.mean(np.array(acc_stds), axis=0),
    }
    metadata = {
        # One [lower, upper] pair per spatial axis.
        "bounds": [[lower_bound, upper_bound] for _ in range(dim)],
        "sequence_length": total_frames,
        "default_connectivity_radius": 0.015,
        "dim": dim,
        "dt": dt * steps_per_frame,
        "vel_mean": vel_stats['mean'].tolist(),
        "vel_std": vel_stats['std'].tolist(),
        "acc_mean": acc_stats['mean'].tolist(),
        "acc_std": acc_stats['std'].tolist(),
        "obstacles_configs": []  # [[[,],[,]],],
    }

    # Save to files: first 1000 trajectories train, next 100 test, rest valid.
    all_sim_ids = list(save_data.keys())
    train_ids = all_sim_ids[:1000]
    valid_ids = all_sim_ids[1100:]
    test_ids = all_sim_ids[1000:1100]

    train_data = {k: save_data[k] for k in train_ids}
    valid_data = {k: save_data[k] for k in valid_ids}
    test_data = {k: save_data[k] for k in test_ids}

    np.savez(os.path.join(output_dir, 'train.npz'), **train_data)
    np.savez(os.path.join(output_dir, 'valid.npz'), **valid_data)
    np.savez(os.path.join(output_dir, 'test.npz'), **test_data)
    with open(os.path.join(output_dir, f'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)
    print("All simulations completed!")