from mpm_engine.mpm_solver import MPMSolver
import taichi as ti
import numpy as np
import os
import imageio
from tqdm import tqdm
import json

# Global RNG seed. The np.random draws made by this script (particle counts in
# run_simulation, water/sand ratios in init_simulaor) all derive from it.
SEED = 18
np.random.seed(SEED)

# Integer id per material name. Not referenced by the code visible in this
# file -- presumably kept for consumers of the saved data; confirm.
PARTICLE_TYPE = dict(Rigid=0, Droplet=1, Boundary=3, Water=5, Sand=6, Goop=7)

# Background grid and capacity handed to MPMSolver.
n_grid = 128
res = (n_grid, n_grid)
size = 1.0
max_num_particles = 1 << 20

# Solver step size; metadata reports dt * steps_per_frame as the frame dt.
dt = 2.5e-4
dim = 2

# Per-simulation particle count is drawn from [min_particles, max_particles].
min_particles, max_particles = 1000, 2000

# Particles are seeded inside, and walled in by, [lower_bound, upper_bound]
# on both axes.
lower_bound, upper_bound = 0.1, 0.9

# Filled by run_simulation: "simulation_trajectory_<id>" -> trajectory array.
save_data = {}

# Each simulation records total_frames frames, each after steps_per_frame
# solver steps.
total_frames = 500
steps_per_frame = 10


def init_simulaor(n_particles):
    """Build a fresh MPM solver populated with a random water/sand mixture.

    Re-initializes Taichi on the CPU backend, creates an ``MPMSolver`` from the
    module-level grid settings, applies gravity ``(0, -9.8)``, gives water a
    uniformly random 40-60 % share of ``n_particles`` (sand gets the rest),
    seeds both materials inside ``[lower_bound, upper_bound]`` and closes the
    domain with four ``surface_separate`` wall colliders.

    Args:
        n_particles: total number of particles to create.

    Returns:
        The configured ``MPMSolver`` instance.
    """
    # NOTE(review): debug=True turns on Taichi runtime checks, which slows the
    # solver down considerably; confirm it is wanted for bulk data generation.
    ti.init(arch=ti.cpu, default_fp=ti.f32, debug=True, log_level=ti.ERROR)
    solver = MPMSolver(
        res=res,
        size=size,
        max_num_particles=max_num_particles,
        quant=False,
        use_g2p2g=True,
        use_adaptive_dt=True,
        E_scale=2,
        poisson_ratio=0.25,
        unbounded=True,
        support_plasticity=True,
    )
    solver.set_gravity((0, -9.8))

    n_water = int(n_particles * np.random.uniform(0.4, 0.6))
    n_sand = n_particles - n_water
    region = dict(lower_bound=lower_bound, upper_bound=upper_bound)

    water_positions, _ = solver.initialize_particles_random_shape(
        n_water,
        material=solver.material_water,
        color=0x0066FF,
        **region,
    )
    # The water positions are handed over as ``existing_positions`` --
    # presumably so the sand shape avoids them; confirm in MPMSolver.
    solver.initialize_particles_random_shape(
        n_sand,
        material=solver.material_sand,
        color=0xda8d4a,
        existing_positions=water_positions,
        **region,
    )

    # (point on wall, inward normal): bottom, top, left, right.
    walls = (
        ((0, lower_bound), (0, 1)),
        ((0, upper_bound), (0, -1)),
        ((lower_bound, 0), (1, 0)),
        ((upper_bound, 0), (-1, 0)),
    )
    for wall_point, wall_normal in walls:
        solver.add_surface_collider(
            point=wall_point,
            normal=wall_normal,
            surface=solver.surface_separate,
        )

    return solver

def run_simulation(sim_id, output_dir="WaterSand_simulations", n_videos=10):
    """Simulate one random water/sand scene and record its trajectory.

    Draws a particle count uniformly from ``[min_particles, max_particles]``,
    builds a solver with ``init_simulaor`` and records ``total_frames`` frames,
    each taken after ``steps_per_frame`` solver steps of size ``dt``.
    Accelerations are finite differences of consecutive recorded velocities
    over ``steps_per_frame * dt``.

    On success, stores a 2-element object array ``[positions, particle_types]``
    in the module-level ``save_data`` under ``"simulation_trajectory_<sim_id>"``
    and, when ``sim_id < n_videos``, renders an mp4 into ``output_dir``.

    Args:
        sim_id: index of this simulation; used for the save key and video name.
        output_dir: directory receiving rendered videos.
        n_videos: only simulations with ``sim_id < n_videos`` are rendered
            (default 10, the previously hard-coded value).

    Returns:
        ``(velocities, accelerations)`` -- lists with one array per recorded
        frame -- or ``(None, None)`` if the solver produced a NaN/inf position
        or velocity, in which case nothing is stored.
    """
    n_particles = np.random.randint(min_particles, max_particles + 1)
    record_video = sim_id < n_videos
    frame_dt = steps_per_frame * dt

    mpm = init_simulaor(n_particles)
    try:
        init_particles_info = mpm.particle_info()
        prev_vel = init_particles_info['velocity']
        particle_types = init_particles_info['material']

        positions = []
        velocities = []
        accelerations = []
        # Full per-frame info dicts are only needed for rendering; keeping them
        # for every simulation would hold hundreds of frames in memory for nothing.
        particles_infos = []

        for _frame in tqdm(range(total_frames)):
            for _substep in range(steps_per_frame):
                mpm.step(dt)
            particles_info = mpm.particle_info()
            pos = particles_info['position']
            vel = particles_info['velocity']

            # Reject blown-up runs: inf is as poisonous as NaN for the saved
            # trajectories and for the normalization statistics.
            if not (np.isfinite(vel).all() and np.isfinite(pos).all()):
                return None, None

            acc = (vel - prev_vel) / frame_dt

            if record_video:
                particles_infos.append(particles_info)
            positions.append(pos)
            velocities.append(vel)
            accelerations.append(acc)
            prev_vel = vel.copy()
    finally:
        # Tear the Taichi runtime down on every exit path, including the
        # early return above, before the next simulation re-initializes it.
        del mpm
        ti.reset()

    trajectory = np.empty(2, dtype=object)
    trajectory[0] = np.array(positions)
    trajectory[1] = particle_types
    save_data[f'simulation_trajectory_{sim_id}'] = trajectory

    if record_video:
        save_video(sim_id, particles_infos, output_dir)

    return velocities, accelerations

def save_video(sim_id, particles_infos, output_dir, fps=60):
    """Render recorded particle frames to ``<output_dir>/sim_<sim_id>.mp4``.

    Each entry of ``particles_infos`` is drawn off-screen with ``ti.GUI`` as
    circles at its ``'position'`` values, colored by its ``'color'`` values, on
    a white background. The image is rotated 90 degrees, converted to 8-bit and
    appended to the video.

    Args:
        sim_id: simulation index, used in the window title and file name.
        particles_infos: sequence of dicts with at least ``'position'`` and
            ``'color'`` entries (as produced by ``MPMSolver.particle_info()``).
        output_dir: destination directory; created if missing.
        fps: frame rate of the written video (default 60).
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"sim_{sim_id}.mp4")

    gui = ti.GUI(f"Simulation {sim_id}", background_color=0xFFFFFF, show_gui=False)
    try:
        # The context manager closes (and finalizes) the file even if
        # rendering fails part-way through.
        with imageio.get_writer(output_path, fps=fps) as writer:
            for particles_info in particles_infos:
                gui.clear(0xFFFFFF)
                gui.circles(particles_info['position'], radius=1.5, color=particles_info['color'])
                frame = np.rot90(gui.get_image(), k=1)
                # Clip before the uint8 cast so out-of-range floats saturate
                # instead of wrapping around.
                frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
                writer.append_data(frame)
    finally:
        gui.close()

def update_running_stats(stats, samples):
    """Merge a batch of samples into running per-component (count, mean, M2).

    ``samples`` is anything ``np.asarray`` can stack (e.g. a list of per-frame
    ``(n_particles, dim)`` arrays); it is flattened to ``(-1, last_dim)`` and
    pooled per component. Batches are combined with Chan et al.'s parallel
    variance update, so statistics are exact over all samples without keeping
    them in memory. The population std is ``sqrt(M2 / count)``, which matches
    ``np.std`` with ``ddof=0``.

    Args:
        stats: previous ``(count, mean, m2)`` tuple, or ``None`` for no data yet.
        samples: new batch of samples.

    Returns:
        The updated ``(count, mean, m2)`` tuple. ``stats`` is returned
        unchanged when ``samples`` is empty.
    """
    batch = np.asarray(samples, dtype=np.float64)
    if batch.size == 0:
        return stats
    batch = batch.reshape(-1, batch.shape[-1])

    n_b = batch.shape[0]
    mean_b = batch.mean(axis=0)
    m2_b = np.square(batch - mean_b).sum(axis=0)
    if stats is None:
        return n_b, mean_b, m2_b

    n_a, mean_a, m2_a = stats
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * (n_b / n)
    m2 = m2_a + m2_b + np.square(delta) * (n_a * n_b / n)
    return n, mean, m2


if __name__ == "__main__":
    num_simulations = 1200
    # Split sizes; train gets the remainder (1000 of 1200).
    n_valid = 100
    n_test = 100
    n_train = num_simulations - n_valid - n_test
    # Abort instead of retrying forever if the solver keeps diverging.
    max_consecutive_failures = 100

    output_dir = 'WaterSand_simulations'
    os.makedirs(output_dir, exist_ok=True)

    # Streaming pooled statistics: per-frame arrays of every simulation are
    # no longer retained just to compute mean/std at the end.
    vel_running = None
    acc_running = None
    idx = 0
    consecutive_failures = 0
    while idx < num_simulations:
        print(f"Running simulation {idx}/{num_simulations}")
        velocities, accelerations = run_simulation(idx, output_dir)
        if velocities is None:
            print(f"Skipping simulation {idx} due to NaN values in velocities.")
            consecutive_failures += 1
            if consecutive_failures >= max_consecutive_failures:
                raise RuntimeError(
                    f"{consecutive_failures} consecutive simulations diverged; aborting."
                )
            continue
        consecutive_failures = 0
        vel_running = update_running_stats(vel_running, velocities)
        acc_running = update_running_stats(acc_running, accelerations)
        idx += 1

    # Pooled over every particle, frame and simulation. The previous code
    # reported std-of-per-sim-stds for velocity and mean-of-per-sim-stds for
    # acceleration, neither of which is the dataset std.
    vel_count, vel_mean, vel_m2 = vel_running
    acc_count, acc_mean, acc_m2 = acc_running
    metadata = {
        "bounds": [[lower_bound, upper_bound], [lower_bound, upper_bound]],
        "sequence_length": total_frames,
        "default_connectivity_radius": 0.015,
        "dim": dim,
        "dt": dt * steps_per_frame,
        "vel_mean": vel_mean.tolist(),
        "vel_std": np.sqrt(vel_m2 / vel_count).tolist(),
        "acc_mean": acc_mean.tolist(),
        "acc_std": np.sqrt(acc_m2 / acc_count).tolist(),
        "obstacles_configs": []
    }

    # save_data preserves insertion (= simulation index) order.
    all_sim_ids = list(save_data)
    train_ids = all_sim_ids[:n_train]
    test_ids = all_sim_ids[n_train:n_train + n_test]
    valid_ids = all_sim_ids[n_train + n_test:]

    train_data = {k: save_data[k] for k in train_ids}
    valid_data = {k: save_data[k] for k in valid_ids}
    test_data = {k: save_data[k] for k in test_ids}

    np.savez(os.path.join(output_dir, 'train.npz'), **train_data)
    np.savez(os.path.join(output_dir, 'valid.npz'), **valid_data)
    np.savez(os.path.join(output_dir, 'test.npz'), **test_data)

    with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    print("All simulations completed!")