from mpm_engine.mpm_solver import MPMSolver
import taichi as ti
import numpy as np
import os
import imageio
from tqdm import tqdm
import json

# Module-level Taichi runtime. NOTE: init_simulaor() calls ti.init() again with
# arch=ti.cpu, which replaces this runtime before any solver is constructed.
ti.init(arch=ti.gpu, default_fp=ti.f32, debug=True)
SEED = 18
np.random.seed(SEED)  # seeds NumPy's global RNG so runs are reproducible
# Material id table. init_simulaor() passes 5 (Water) and 6 (Sand) to
# MPMSolver.add_particles, and run_simulation_mpm_thershold() selects those ids.
PARTICLE_TYPE = {"Rigid":0, "Droplet":1, "Boundary":3, "Water":5, "Sand":6, "Goop":7}


n_grid = 128
res = (n_grid, n_grid)  # grid resolution passed to MPMSolver(res=...)
size = 1.0  # domain size passed to MPMSolver(size=...)
max_num_particles = 2**20  # particle capacity passed to MPMSolver

dt = 2.5e-4  # solver substep; one saved frame spans dt * steps_per_frame
dim = 2  # spatial dimension, written to metadata.json
# Simulation parameters
# NOTE(review): max_particles / min_particles are only referenced by commented-out
# code in the visible file.
max_particles = 4000
min_particles = 1200
# Coordinates of the wall colliders on both axes; also written to metadata.json
# as the domain bounds.
lower_bound = 0.1
upper_bound = 0.9

# Trajectories written to <output_dir>/<split> by the __main__ block.
# NOTE(review): no reachable code in the visible file adds entries to it.
save_data = {}
total_frames = 500  # written to metadata.json as "sequence_length"
steps_per_frame = 10  # solver substeps of size dt per saved frame

from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

class ParticleReducer:
    """Coarse-grain a particle set with k-means clustering.

    Particles are clustered jointly on (position, velocity) after each feature
    is standardized, and every cluster is replaced by one representative
    particle.
    """

    def __init__(self, target_particles=256, random_state=42):
        """
        Args:
            target_particles: number of clusters, i.e. output particles.
            random_state: seed forwarded to KMeans for reproducible clustering.
        """
        self.n_clusters = target_particles
        self.kmeans = KMeans(
            n_clusters=target_particles,
            init='k-means++',
            n_init=10,
            random_state=random_state
        )
        self.scaler = StandardScaler()

    def reduce_system(self, positions, velocities):
        """Reduce a particle system to at most ``n_clusters`` particles.

        Args:
            positions: array-like of shape (n_particles, n_dims).
            velocities: array-like of shape (n_particles, n_dims), row-aligned
                with ``positions``.

        Returns:
            (new_positions, new_velocities). ``new_positions`` are the position
            part of the k-means cluster centres, mapped back to the original
            units. ``new_velocities`` are the mean velocity of the particles
            assigned to each cluster; an empty cluster gets a zero velocity.
            When there are no more particles than clusters, nothing needs
            reducing, and copies of the inputs are returned unchanged.
        """
        positions = np.asarray(positions)
        velocities = np.asarray(velocities)

        # KMeans raises when n_samples < n_clusters. With n_samples == n_clusters,
        # every particle would become its own cluster anyway, so skip the fit.
        if positions.shape[0] <= self.n_clusters:
            return positions.copy(), velocities.copy()

        features = np.hstack([positions, velocities])
        scaled_features = self.scaler.fit_transform(features)
        self.kmeans.fit(scaled_features)

        # StandardScaler is a per-feature affine map, so inverse_transform brings
        # the centres back to physical units.
        cluster_centers = self.scaler.inverse_transform(self.kmeans.cluster_centers_)
        n_dims = positions.shape[1]
        new_positions = cluster_centers[:, :n_dims]

        # Mean velocity per cluster via a single scatter-add instead of a Python
        # loop over clusters. Empty clusters keep a zero velocity.
        labels = self.kmeans.labels_
        counts = np.bincount(labels, minlength=self.n_clusters)
        velocity_sums = np.zeros(new_positions.shape)
        np.add.at(velocity_sums, labels, velocities)
        new_velocities = np.zeros_like(new_positions)
        occupied = counts > 0
        new_velocities[occupied] = velocity_sums[occupied] / counts[occupied, None]

        return new_positions, new_velocities

def init_simulaor(new_pos_water, new_vel_water, new_pos_sand, new_vel_sand):
    """Build a fresh water + sand MPMSolver seeded with the given particles.

    Taichi is re-initialised on the CPU backend first, so each call starts from
    a new runtime. Four separating surface colliders are then placed at
    ``lower_bound`` / ``upper_bound`` on both axes, with normals pointing into
    the box.

    Args:
        new_pos_water, new_vel_water: positions and velocities of the water particles.
        new_pos_sand, new_vel_sand: positions and velocities of the sand particles.

    Returns:
        The configured MPMSolver instance.
    """
    ti.init(arch=ti.cpu, default_fp=ti.f32, debug=True, log_level=ti.ERROR)

    solver = MPMSolver(
        res=res,
        size=size,
        max_num_particles=max_num_particles,
        quant=False,
        use_g2p2g=True,
        use_adaptive_dt=True,
        E_scale=2,
        poisson_ratio=0.25,
        unbounded=True,
        support_plasticity=True,
    )
    solver.set_gravity((0, -9.8))

    solver.add_particles(new_pos_water, PARTICLE_TYPE["Water"], color=0x0066FF, velocity=new_vel_water)
    solver.add_particles(new_pos_sand, PARTICLE_TYPE["Sand"], color=0xda8d4a, velocity=new_vel_sand)

    # (point, normal) for the floor, ceiling, left wall and right wall, in that order.
    walls = (
        ((0, lower_bound), (0, 1)),
        ((0, upper_bound), (0, -1)),
        ((lower_bound, 0), (1, 0)),
        ((upper_bound, 0), (-1, 0)),
    )
    for point, normal in walls:
        solver.add_surface_collider(point=point, normal=normal, surface=solver.surface_separate)

    return solver

def run_simulation_mpm_thershold(next_position, next_velocity, types, step=50, output_dir="WaterSand_simulations"):
    """Roll the MPM solver forward from a given water/sand state.

    Args:
        next_position: per-particle positions, row-aligned with ``types``.
        next_velocity: per-particle velocities in displacement-per-frame units.
            They are divided by ``dt * steps_per_frame`` before being handed to
            the solver.
        types: per-particle material ids; must contain both water (5) and sand
            (6). Presumably a torch tensor, since ``.cpu().numpy()`` is called
            on it — confirm against the caller.
        step: number of output frames. Each frame is ``steps_per_frame`` solver
            substeps of size ``dt``.
        output_dir: unused; kept for interface compatibility.

    Returns:
        (positions, velocities): lists with one array per frame. Velocities are
        converted back to displacement-per-frame units. Returns ``(None, None)``
        if the solver produces NaN velocities.

    Raises:
        KeyError: if ``types`` contains no water or no sand particles.
    """
    # Build the boolean mask for each material once. A missing material raises
    # KeyError on lookup.
    masks = {t: types == t for t in np.unique(types.cpu().numpy())}
    water_mask = masks[PARTICLE_TYPE["Water"]]
    sand_mask = masks[PARTICLE_TYPE["Sand"]]

    mpm = init_simulaor(
        next_position[water_mask],
        next_velocity[water_mask] / (dt * steps_per_frame),
        next_position[sand_mask],
        next_velocity[sand_mask] / (dt * steps_per_frame),
    )

    positions = []
    velocities = []
    for _ in range(step):
        for _ in range(steps_per_frame):
            mpm.step(dt)
        particles_info = mpm.particle_info()
        # Back to displacement-per-frame units, matching the input convention.
        vel = particles_info['velocity'] * dt * steps_per_frame
        if np.isnan(vel).any():
            del mpm
            return None, None
        positions.append(particles_info['position'])
        velocities.append(vel)

    del mpm
    return positions, velocities

def save_video(sim_id, particles_infos, output_dir, fps=60):
    """Render recorded particle frames to ``<output_dir>/sim_<sim_id>.mp4``.

    Args:
        sim_id: identifier used in the window title and the output filename.
        particles_infos: iterable of per-frame dicts with 'position' and 'color'
            entries.
        output_dir: directory the video is written to (not created here).
        fps: frame rate of the output video; defaults to the previously
            hard-coded 60.
    """
    output_path = os.path.join(output_dir, f"sim_{sim_id}.mp4")

    gui = ti.GUI(f"Simulation {sim_id}", background_color=0xFFFFFF, show_gui=False)
    try:
        # The writer is used as a context manager, so the file is finalized
        # even if rendering fails part-way through.
        with imageio.get_writer(output_path, fps=fps) as writer:
            for particles_info in particles_infos:
                gui.clear(0xFFFFFF)
                gui.circles(particles_info['position'], radius=1.5, color=particles_info['color'])
                # NOTE(review): get_image() presumably returns an x-major float
                # image in [0, 1]. rot90 turns it into row-major image order.
                frame = np.rot90(gui.get_image(), k=1)
                # Clip before the cast so values slightly outside [0, 1] cannot
                # wrap around in uint8.
                writer.append_data((np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8))
    finally:
        gui.close()

if __name__ == "__main__":
    output_dir = 'WaterSand_simulations_xl_k_means_1-1.75'
    os.makedirs(output_dir, exist_ok=True)
    all_velocities = []
    all_accelerations = []

    Data_name = 'WaterSand_simulations_xl'
    split = 'test.npz'
    # NOTE(review): this was an empty string, so np.load always failed.
    # This assumes the source dataset lives at <Data_name>/<split> — confirm the layout.
    npz_filename = os.path.join(Data_name, split)
    # Keys in the test split start at simulation_trajectory_1000; other splits start at 0.
    trajectory_offset = 1000 if split == 'test.npz' else 0

    # np.load on an .npz keeps the archive open; the context manager closes it.
    with np.load(npz_filename, allow_pickle=True) as data:
        num_simulations = len(data)
        for idx in range(num_simulations):
            trajectory_id = trajectory_offset + idx
            print(f"Running simulation {idx}/{num_simulations}")
            # NOTE(review): run_simulation is not defined in the visible code. The
            # solver driver above is run_simulation_mpm_thershold, and its
            # signature differs.
            velocities, accelerations = run_simulation(idx, data[f'simulation_trajectory_{trajectory_id}'], output_dir)
            if velocities is None:
                with open(os.path.join(output_dir, 'test_no.txt'), "a") as f:
                    f.write(f"{idx},")
                print(f"Skipping simulation {idx} due to NaN values in velocities.")
                continue
            all_velocities.append(velocities)
            all_accelerations.append(accelerations)

    # Per-component statistics. Each simulation is reduced over frames and
    # particles (axes 0 and 1), then the per-simulation values are averaged.
    # The velocity std now averages per-simulation stds, like the acceleration
    # std. Previously it took the std of the stds, which measures spread
    # between simulations instead.
    vel_stats = {
        "mean": np.mean(np.array([np.mean(np.array(v), axis=(0, 1)) for v in all_velocities]), axis=(0)),
        "std": np.mean(np.array([np.std(np.array(v), axis=(0, 1)) for v in all_velocities]), axis=(0))
    }
    acc_stats = {
        "mean": np.mean(np.array([np.mean(np.array(a), axis=(0, 1)) for a in all_accelerations]), axis=(0)),
        "std": np.mean(np.array([np.std(np.array(a), axis=(0, 1)) for a in all_accelerations]), axis=(0))
    }
    metadata = {
        "bounds": [[lower_bound, upper_bound], [lower_bound, upper_bound]],
        "sequence_length": total_frames,
        "default_connectivity_radius": 0.015,
        "dim": dim,
        "dt": dt*steps_per_frame,
        "vel_mean": vel_stats['mean'].tolist(),
        "vel_std": vel_stats['std'].tolist(),
        "acc_mean": acc_stats['mean'].tolist(),
        "acc_std": acc_stats['std'].tolist(),
        "obstacles_configs": []
    }

    # NOTE(review): no reachable code in the visible file adds entries to
    # save_data, so this archive may be written empty.
    np.savez(os.path.join(output_dir, split), **save_data)
    with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    print("All simulations completed!")