import genesis as gs
import os
import openvdb as vdb
import time  
import cv2
import random
import numpy as np
import math
import gc
import torch




########################## Configuration Parameters ##########################
# Simulation settings to sweep over; every entry is run once per video.
SIM_PARAMS = [
    dict(
        horizon=152,          # number of scene.step() calls, i.e. rendered frames
        dt=1e-3,              # simulation time step (smaller = more accurate)
        substeps=20,          # substeps per step (more = more accurate)
        name='high_precision2',
    ),
]

# Particle size handed to the SPH solver options.
PARTICLE_SIZE = 0.005

# Base liquid material settings; 'mu' is multiplied by a random factor per video.
FLUID_PARAMS = [
    dict(mu=0.01, gamma=0.02, name='mu001'),
]

# Destination folder for the rendered videos (created if it does not exist).
output_dir = "xxx"
os.makedirs(output_dir, exist_ok=True)


def generate_random_E():
    """Return a random viscosity multiplier.

    One of three bands -- [0.005, 0.05), [0.05, 0.5), [0.5, 1) -- is picked
    with equal probability, then a value is drawn uniformly from that band.
    Both draws use NumPy's global RNG (``np.random``), so seeding it makes
    the result reproducible.
    """
    bands = ((0.005, 0.05), (0.05, 0.5), (0.5, 1))
    band_index = np.random.choice(len(bands))
    low, high = bands[band_index]
    return np.random.uniform(low, high)



# Each SLURM array task renders its own contiguous slice of video indices.
split_id = int(os.environ['SLURM_ARRAY_TASK_ID'])
num_videos_per_split = 1000

# Frame rate used for both the RGB and the segmentation video.
VIDEO_FPS = 30

for video_i in range(split_id * num_videos_per_split, (split_id + 1) * num_videos_per_split):
    print(f"Starting simulation video{video_i}")
    for sim_params in SIM_PARAMS:
        for fluid_params in FLUID_PARAMS:
            print(f"\nStarting simulation: {sim_params['name']} - {fluid_params['name']}")

            # Wall-clock start of this whole run (init + build + simulate + encode).
            start_time = time.time()

            ########################## init ##########################
            # Genesis is initialised here and torn down with gs.destroy() at the end
            # of every run, so each video starts from a fresh engine.
            gs.init(backend=gs.gpu)

            ########################## create a scene ##########################
            scene = gs.Scene(
                sim_options=gs.options.SimOptions(
                    dt=sim_params['dt'],
                    substeps=sim_params['substeps'],
                ),
                sph_options=gs.options.SPHOptions(
                    lower_bound=(-0.5, -0.5, 0.0),
                    upper_bound=(0.5, 0.5, 1),
                    particle_size=PARTICLE_SIZE,
                ),
                vis_options=gs.options.VisOptions(
                    visualize_sph_boundary=True,
                    show_world_frame=False,
                    segmentation_level='entity',
                    world_frame_size=1.0,
                    show_link_frame=False,
                    show_cameras=False,
                    plane_reflection=False,
                ),
                show_viewer=False,
            )

            ########################## entities ##########################
            plane = scene.add_entity(
                morph=gs.morphs.Plane(),
            )

            # NOTE: constant despite its name; it only offsets the column's start height.
            random_height = 0.5
            random_color_1 = random.uniform(0, 1)
            random_color_2 = random.uniform(0, 1)
            # Multiplier applied to the base viscosity mu.
            random_E = generate_random_E()

            # Create water column
            liquid = scene.add_entity(
                material=gs.materials.SPH.Liquid(
                    mu=fluid_params['mu'] * random_E,
                    gamma=fluid_params['gamma'],
                ),
                morph=gs.morphs.Cylinder(
                    pos=(0.0, 0.0, 0.051 + 0.01 * random_height),
                    radius=0.05,
                    height=0.1,
                ),
                surface=gs.surfaces.Default(
                    color=(random_color_1, random_color_2, 0),
                    vis_mode="recon",
                ),
            )

            # Camera on a radius-1.5 arc (angle in [0, pi/2]) at height in [0.5, 1.5],
            # looking at a point jittered around (0, 0, 0.25).
            # The draw order below is kept stable so RNG streams are unchanged.
            random_cam_angle = random.uniform(0, 0.5 * math.pi)
            random_cam_height = random.uniform(-1, 1)
            random_cam_lookat_x = random.uniform(-1, 1)
            random_cam_lookat_y = random.uniform(-1, 1)
            random_cam_lookat_z = random.uniform(-1, 0.1)
            cam = scene.add_camera(
                res=(960, 640),
                pos=(1.5 * np.cos(random_cam_angle), 1.5 * np.sin(random_cam_angle), 1.0 + 0.5 * random_cam_height),
                lookat=(0.1 * random_cam_lookat_x, 0.1 * random_cam_lookat_y, 0.25 + 0.2 * random_cam_lookat_z),
                fov=30,
                GUI=False,
            )

            # Output file names encode the sampled parameters; both videos share one stem.
            file_stem = (
                f'viscosity_V_{random_E}_color_{random_color_1:.2f}-{random_color_2:.2f}-0'
                f'_camangle_{random_cam_angle:.2f}_camheight_{random_cam_height:.2f}'
                f'_camlookat_{random_cam_lookat_x:.2f}_{random_cam_lookat_y:.2f}_{random_cam_lookat_z:.2f}'
                f'_height_{random_height:.2f}'
            )
            rgb_filename = f'{file_stem}_rgb.mp4'
            seg_filename = f'{file_stem}_seg.mp4'
            rgb_path = os.path.join(output_dir, rgb_filename)
            seg_path = os.path.join(output_dir, seg_filename)

            ########################## build ##########################
            scene.build()

            # Start recording; the RGB video is written by cam.stop_recording() below.
            cam.start_recording()

            # Segmentation masks are streamed straight to disk instead of buffering
            # every full-resolution frame in memory. The writer is opened lazily so its
            # size matches the rendered frames.
            seg_writer = None

            sim_start_time = time.time()
            for i in range(sim_params['horizon']):
                scene.step()
                # RGB is still requested (presumably what start_recording() captures);
                # depth and normals are never used, so they are not rendered.
                _, _, seg, _ = cam.render(rgb=True, depth=False, segmentation=True, normal=False)

                if seg_writer is None:
                    h, w = seg.shape[:2]
                    seg_writer = cv2.VideoWriter(
                        seg_path,
                        cv2.VideoWriter_fourcc(*'mp4v'),
                        VIDEO_FPS,
                        (w, h),
                    )
                    if not seg_writer.isOpened():
                        raise RuntimeError(f"Could not open segmentation video writer for {seg_path}")

                # Binary mask: any segmentation id > 0 -> 255 (white), otherwise 0.
                # NOTE(review): whether the ground plane also receives a nonzero id (and
                # so appears in the mask) depends on Genesis's entity indexing -- confirm.
                mask = (seg > 0).astype(np.uint8) * 255
                seg_writer.write(cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR))

                # Print progress every 100 frames
                if i % 100 == 0:
                    elapsed = time.time() - sim_start_time
                    sim_fps = (i + 1) / elapsed
                    print(f"Progress: {i+1}/{sim_params['horizon']} frames "
                          f"({(i+1)/sim_params['horizon']*100:.1f}%) "
                          f"FPS: {sim_fps:.2f}")

            # Use camera native method to save RGB video
            cam.stop_recording(
                save_to_filename=rgb_path,
                fps=VIDEO_FPS,
            )
            if seg_writer is not None:
                seg_writer.release()

            # Reset the scene, drop references and tear Genesis down before the next
            # gs.init(). Garbage is collected *before* emptying the CUDA cache so that
            # tensors freed by the collector can actually be returned to the device.
            scene.reset()
            del scene, cam, plane, liquid
            gs.destroy()
            gc.collect()
            torch.cuda.empty_cache()

            print("Simulation completed!")
            print(f"RGB video saved to: {rgb_path}")
            print(f"Segmentation video saved to: {seg_path}")
            print(f"Total time: {time.time() - start_time:.2f}s")


print("All simulations completed!")