import genesis as gs
import os
import cv2
import numpy as np
import ipdb
import json
import gc
import torch
import random
import math


# Directory that receives every generated video / velocity file.
output_dir = 'xxx'
os.makedirs(output_dir, exist_ok=True)


# Base material parameters; the simulation loop scales E up from this base.
parameter_sets = [
    {'E': 1e3, 'nu': 0.3},
]

params = parameter_sets[0]

# Young's modulus bounds and nominal dataset size.
# NOTE(review): these three names are not read anywhere else in the visible
# script; the loop encodes the E range as params['E'] * (1 + r * 99) instead.
MAX_E = 1e5
MIN_E = 1e3
MUN_VIDEOS = 10000

# Index of this job within a SLURM job array. Job k renders the video indices
# [k * num_videos_per_split, (k + 1) * num_videos_per_split). Falls back to 0
# so the script can also be launched directly, outside SLURM.
split_id = int(os.environ.get('SLURM_ARRAY_TASK_ID', '0'))
num_videos_per_split = 2000

def _write_mask_video(path, masks, fps):
    """Write a list of single-channel uint8 masks to ``path`` as an mp4v video.

    Args:
        path: Destination file path.
        masks: Non-empty list of 2-D uint8 arrays of identical shape.
        fps: Frames per second of the output video.

    Raises:
        ValueError: If ``masks`` is empty.
        RuntimeError: If OpenCV cannot open a writer for ``path``.
    """
    if not masks:
        raise ValueError("no segmentation frames to write")
    h, w = masks[0].shape[:2]
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    # VideoWriter does not raise when the codec/path is unusable; it simply
    # produces no output, so check explicitly instead of failing silently.
    if not writer.isOpened():
        raise RuntimeError(f"could not open video writer for {path}")
    try:
        for mask in masks:
            # Expand the 1-channel mask to 3 channels for the BGR writer.
            writer.write(cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR))
    finally:
        writer.release()


# Start simulation: each iteration builds a fresh Genesis scene, drops one
# elastic object onto a plane, and writes an RGB video, a binary segmentation
# video and a per-frame mean z-velocity JSON file into output_dir.
for video_i in range(split_id * num_videos_per_split, (split_id + 1) * num_videos_per_split):
    print(f"start simulation video{video_i}")

    ########################## init ##########################
    gs.init()

    ########################## create a scene ##########################
    scene = gs.Scene(
        sim_options=gs.options.SimOptions(
            dt=3e-3,
            substeps=10,
        ),
        mpm_options=gs.options.MPMOptions(
            lower_bound=(-0.5, -1.0, 0.0),
            upper_bound=(0.5, 1.0, 1),
        ),
        vis_options=gs.options.VisOptions(
            visualize_mpm_boundary=True,
            show_world_frame=False,
            segmentation_level='entity',
            world_frame_size=1.0,
            show_link_frame=False,
            show_cameras=False,
            plane_reflection=False,
        ),
        viewer_options=gs.options.ViewerOptions(
            camera_fov=60,
            res=(960, 640),
        ),
        show_viewer=False,
    )

    ########################## entities ##########################
    plane = scene.add_entity(
        morph=gs.morphs.Plane(),
    )

    random_shape = random.choice([0])
    random_height = random.uniform(0, 1)
    random_color_1 = random.uniform(0, 1)
    random_color_2 = random.uniform(0, 1)
    random_color_3 = random.uniform(0, 1)
    random_E = random.uniform(0, 1)

    # With params['E'] = 1e3 and random_E in [0, 1], E lies in [1e3, 1e5].
    # The normalized random_E (not E itself) is what goes into the filename.
    young_modulus = params['E'] * (1 + random_E * 99)

    if random_shape == 0:
        obj_elastic = scene.add_entity(
            material=gs.materials.MPM.Elastic(
                E=young_modulus,
                nu=params['nu'],
            ),
            morph=gs.morphs.Sphere(
                pos=(0.0, 0.0, 0.25 + 0.15 * random_height),
                radius=0.1,
            ),
            surface=gs.surfaces.Default(
                # Use all three sampled channels so the colour encoded in the
                # filename matches the colour that is actually rendered.
                color=(random_color_1, random_color_2, random_color_3),
                vis_mode="visual",
            ),
        )
    else:
        # Explicit error rather than `assert`, which is stripped under -O.
        raise ValueError(f"unsupported shape id: {random_shape}")

    # create the camera on a 1.5-radius arc with randomized height / look-at
    random_cam_angle = random.uniform(0, 0.5 * math.pi)
    random_cam_height = random.uniform(-1, 1)
    random_cam_lookat_x = random.uniform(-1, 1)
    random_cam_lookat_y = random.uniform(-1, 1)
    random_cam_lookat_z = random.uniform(-1, 0.1)
    cam = scene.add_camera(
        res=(960, 640),
        pos=(1.5 * np.cos(random_cam_angle), 1.5 * np.sin(random_cam_angle), 1.0 + 0.5 * random_cam_height),
        lookat=(0.1 * random_cam_lookat_x, 0.1 * random_cam_lookat_y, 0.25 + 0.2 * random_cam_lookat_z),
        fov=60,
        GUI=False,
    )

    ########################## build ##########################
    scene.build()

    # Start RGB recording; RGB frames are saved by cam.stop_recording below.
    cam.start_recording()

    # Binary (0/255) uint8 masks, one per simulated frame.
    seg_masks = []

    # Simulate for 200 frames
    horizon = 200
    frame_i_vel = {}
    for i in range(horizon):
        scene.step()
        # RGB feeds the active recording; depth and normals are never used,
        # so they are not rendered.
        _, _, seg, _ = cam.render(rgb=True, depth=False, segmentation=True, normal=False)
        # Convert to the final uint8 mask immediately instead of buffering the
        # raw segmentation map, to keep per-video memory small.
        # NOTE(review): `> 0` treats every index > 0 as foreground — confirm
        # which index the plane/background receive in Genesis.
        seg_masks.append((seg > 0).astype(np.uint8) * 255)
        velocities = obj_elastic.get_state().vel
        # NOTE(review): assumes vel is (n_particles, 3) so column 2 is the
        # z-velocity — confirm for the installed Genesis version.
        frame_i_vel[i] = velocities[:, 2].mean().item()

    # Set fps value
    fps = 30

    # Shared filename stem encoding the sampled parameters; the three output
    # names differ only in their suffix.
    stem = (
        f'bounce_ball_E_{random_E}'
        f'_color_{random_color_1:.2f}-{random_color_2:.2f}-{random_color_3:.2f}'
        f'_camangle_{random_cam_angle:.2f}_camheight_{random_cam_height:.2f}'
        f'_camlookat_{random_cam_lookat_x:.2f}_{random_cam_lookat_y:.2f}_{random_cam_lookat_z:.2f}'
        f'_height_{random_height:.2f}_shape_{random_shape}'
    )
    rgb_filename = f'{stem}_rgb.mp4'
    seg_filename = f'{stem}_seg.mp4'
    vel_filename = f'{stem}_vel.json'

    # save the velocities (JSON turns the integer frame keys into strings)
    with open(os.path.join(output_dir, vel_filename), 'w', encoding='utf-8') as dump_f:
        json.dump(frame_i_vel, dump_f)

    # Use camera native method to save RGB video
    cam.stop_recording(
        save_to_filename=os.path.join(output_dir, rgb_filename),
        fps=fps,
    )

    # Use OpenCV to save segmentation video
    _write_mask_video(os.path.join(output_dir, seg_filename), seg_masks, fps)

    # Release per-video resources. Drop every reference to scene-owned objects
    # and GPU tensors first, then collect garbage so reference cycles are
    # broken *before* asking PyTorch to return cached blocks — empty_cache()
    # can only release memory that no live tensor still occupies.
    seg_masks.clear()
    scene.reset()
    del scene, cam, obj_elastic, plane, velocities, seg
    gs.destroy()  # Reset Genesis engine; gs.init() runs again next iteration
    gc.collect()
    torch.cuda.empty_cache()

    print("Simulation completed!")
    print(f"RGB video saved to: {os.path.join(output_dir, rgb_filename)}")
    print(f"Segmentation video saved to: {os.path.join(output_dir, seg_filename)}")

print("All simulations completed!")
