import genesis as gs
import os
import cv2
import numpy as np
import json
import gc
import torch
import random
import math
import pdb



# Create output directory. 'xxx' looks like a placeholder; set the
# FRICTION_SIM_OUTPUT_DIR environment variable to write elsewhere without
# editing this file (the default is unchanged).
output_dir = os.environ.get('FRICTION_SIM_OUTPUT_DIR', 'xxx')
os.makedirs(output_dir, exist_ok=True)

NUM_VIDEOS = 2500  # number of create_simulation() runs performed by the main loop


# Define available shapes for the sliding object.
# NOTE: create_simulation currently only builds "box" (index 0); selecting any
# other entry makes it raise.
AVAILABLE_SHAPES = [
    "box",        # Standard cube/box
    "sphere",     # Sphere
    "cylinder",   # Cylinder
    "capsule",    # Capsule (cylinder with hemispherical ends)
    "cone"        # Cone
]

def _position_entry(entity, frame_idx, phase):
    """Build the position-log entry for one frame of ``entity``.

    On success returns ``{'phase': ..., 'position': [...]}``. When the entity has
    no ``get_pos`` method, or reading the position raises, returns a dict that
    also carries ``'frame'`` plus an ``'info'`` / ``'error'`` message, so that
    one bad read never aborts a run.
    """
    try:
        if not hasattr(entity, 'get_pos'):
            return {'frame': frame_idx, 'phase': phase, 'info': 'no get_pos method'}
        position = entity.get_pos()
        # get_pos may return a torch tensor (possibly on the GPU); move it to host numpy.
        if isinstance(position, torch.Tensor):
            position = position.cpu().numpy()
        return {
            'phase': phase,
            'position': position.tolist() if isinstance(position, np.ndarray) else str(position),
        }
    except Exception as e:  # deliberate best-effort logging
        return {'frame': frame_idx, 'phase': phase, 'error': str(e)}


def _capture_frame(scene, cam, sliding_obj, frame_idx, phase, seg_masks, position_data):
    """Step the scene once, render it, and log the sliding object's position.

    RGB must still be rendered because the camera is recording (create_simulation
    saves the RGB video through ``cam.stop_recording``). The RGB array itself is
    not retained here. The segmentation map is stored as a uint8 0/255 mask,
    which is exactly what the segmentation video writes, to keep memory low.

    NOTE(review): the ground plane is itself an entity. If Genesis gives it a
    non-zero segmentation id, the mask includes the floor as well as the box;
    confirm against Genesis' id assignment.
    """
    scene.step()
    # Depth and normals are never used, so only RGB + segmentation are requested.
    _rgb, _depth, seg, _normal = cam.render(rgb=True, depth=False, segmentation=True, normal=False)
    seg_masks.append((seg > 0).astype(np.uint8) * 255)
    position_data[frame_idx] = _position_entry(sliding_obj, frame_idx, phase)


def _save_segmentation_video(masks, path, fps):
    """Write uint8 0/255 ``masks`` to ``path`` as a 3-channel mp4v video.

    Failures are reported and swallowed, as before, so the rest of the run's
    outputs are still kept. The writer is always released.
    """
    if not masks:
        print("No segmentation frames collected, cannot create video")
        return
    writer = None
    try:
        h, w = masks[0].shape[:2]
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
        # Without this check an unavailable codec/path silently produces no file.
        if not writer.isOpened():
            raise IOError(f"cv2.VideoWriter could not open {path}")
        for mask in masks:
            writer.write(cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR))
    except Exception as e:
        print(f"Failed to save segmentation video: {e}")
        return
    finally:
        if writer is not None:
            writer.release()
    print(f"Segmentation video saved to: {path}")


def create_simulation(video_id):
    """
    Creates and runs a friction coefficient simulation experiment.

    One call owns one Genesis session. It builds a scene with a ground plane and
    one sliding box, steps it for a stabilization phase, then sets a random
    horizontal velocity along x or y and keeps stepping. Four files are written
    into ``output_dir``: an RGB video, a binary segmentation video, per-frame
    positions (JSON), and the sampled parameters (JSON). Genesis is torn down in
    a ``finally`` block, so a failed run does not leave it initialized for the
    next call's ``gs.init()``.

    Args:
        video_id: Video ID used to generate unique filename

    Raises:
        ValueError: if the selected shape is not "box" (the only one implemented).
    """
    print(f"Starting simulation video{video_id}")

    ########################## Initialize Genesis ##########################
    gs.init()

    # Pre-bind everything the cleanup deletes, so `finally` works wherever a failure happens.
    scene = cam = sliding_obj = table = None
    seg_masks = []
    try:
        ########################## Create Scene ##########################
        scene = gs.Scene(
            sim_options=gs.options.SimOptions(
                dt=3e-3,
                substeps=10,
                gravity=(0, 0, -9.8),
            ),
            vis_options=gs.options.VisOptions(
                visualize_mpm_boundary=True,
                show_world_frame=False,
                segmentation_level='entity',
                world_frame_size=1.0,
                show_link_frame=False,
                show_cameras=False,
                plane_reflection=False,
            ),
            viewer_options=gs.options.ViewerOptions(
                camera_fov=45,
                res=(960, 640),
            ),
            rigid_options=gs.options.RigidOptions(
                box_box_detection=True,
            ),
            show_viewer=False,
        )

        ########################## Random Parameter Settings ##########################
        # NOTE: keep the random draws below in this order; reordering them changes
        # which parameters a given RNG state produces.
        shape_index = 0  # only "box" is implemented below
        random_shape = AVAILABLE_SHAPES[shape_index]

        random_size = 0.1  # Fixed size
        random_color = (random.uniform(0, 1), random.uniform(0, 1), 0)

        # Slide along exactly one axis (x or y) with a random signed speed in [-1, 1].
        if random.uniform(0, 1) < 0.5:
            vel_x = random.uniform(-1.0, 1.0)
            vel_y = 0
        else:
            vel_y = random.uniform(-1.0, 1.0)
            vel_x = 0
        init_lin_vel = (vel_x, vel_y, 0)
        init_ang_vel = (0, 0, 0)  # Initial angular velocity, generally set to 0

        # Internal parameters - friction coefficient range
        obj_friction = random.uniform(0.01, 0.2)  # Object friction coefficient
        table_friction = 0.01  # Table friction coefficient

        # Recorded in the metadata only. It is never passed to the entity, so the
        # simulated mass is whatever Genesis derives for the Rigid material.
        obj_mass = 2.0

        # Add table: a plane at the origin (planes are typically fixed by default)
        table = scene.add_entity(
            morph=gs.morphs.Plane(
                pos=(0.0, 0.0, 0.0),
            ),
            material=gs.materials.Rigid(
                friction=table_friction,  # Set friction coefficient in material
            ),
        )

        random_init_x = random.uniform(-0.1, 0.1)
        random_init_y = random.uniform(-0.1, 0.1)

        # Add sliding object, resting on the plane (centre at half its height).
        # An explicit raise (not assert) so the check survives `python -O`.
        if random_shape != "box":
            raise ValueError(f"Unsupported shape {random_shape!r}; only 'box' is implemented")
        sliding_obj = scene.add_entity(
            morph=gs.morphs.Box(
                pos=(random_init_x, random_init_y, random_size / 2.0),
                size=(random_size, random_size, random_size),
            ),
            material=gs.materials.Rigid(
                friction=obj_friction,
            ),
            surface=gs.surfaces.Default(
                color=random_color,
                vis_mode="visual",
            ),
        )

        # Create camera: radius 1.5 around the z axis in the +x/+y quadrant,
        # height 0.5..1.5, looking at a point jittered around the origin.
        random_cam_angle = random.uniform(0, 0.5 * math.pi)
        random_cam_height = random.uniform(-1, 1)
        random_cam_lookat_x = random.uniform(-1, 1)
        random_cam_lookat_y = random.uniform(-1, 1)
        random_cam_lookat_z = random.uniform(-1, 0.1)
        cam = scene.add_camera(
            res=(960, 640),
            pos=(1.5 * np.cos(random_cam_angle), 1.5 * np.sin(random_cam_angle), 1.0 + 0.5 * random_cam_height),
            lookat=(0.1 * random_cam_lookat_x, 0.1 * random_cam_lookat_y, 0.1 + 0.2 * random_cam_lookat_z),
            fov=45,
            GUI=False,
        )

        ########################## Build Scene ##########################
        scene.build()

        position_data = {}

        # Simulation parameters
        stabilization_frames = 20  # Stabilization for 20 frames
        simulation_frames = 180  # Formal simulation for 180 frames
        total_frames = stabilization_frames + simulation_frames
        print(f"Starting simulation, total frames: {total_frames} (stabilization: {stabilization_frames}, simulation: {simulation_frames})")

        # Start RGB recording; every cam.render(rgb=True) below contributes a frame.
        cam.start_recording()

        # First let the box stabilize on the plane (these frames are recorded too)
        print(f"First stabilize box for {stabilization_frames} frames...")
        for i in range(stabilization_frames):
            _capture_frame(scene, cam, sliding_obj, i, 'stabilization', seg_masks, position_data)
            if i % 5 == 0:
                print(f"Stabilization phase: frame {i}")

        print("Box has stabilized, starting to apply initial velocity...")
        print(f"Setting initial linear velocity: {init_lin_vel}, angular velocity: {init_ang_vel}")

        if hasattr(sliding_obj, 'set_velocity'):
            sliding_obj.set_velocity(linear=init_lin_vel, angular=init_ang_vel)
            print("Successfully set velocity using set_velocity method")
        elif hasattr(sliding_obj, 'set_dofs_velocity'):
            # Concatenate (not add) linear and angular parts into one 6-element DOF vector.
            combined_velocity = list(init_lin_vel) + list(init_ang_vel)
            sliding_obj.set_dofs_velocity(combined_velocity)
            print(f"Successfully set velocity using set_dofs_velocity method: {combined_velocity}")
        else:
            raise AttributeError("Object has no set_velocity or set_dofs_velocity method")

        # Run simulation and collect data; frame indices continue after stabilization.
        print("Starting formal simulation...")
        for i in range(simulation_frames):
            _capture_frame(scene, cam, sliding_obj, stabilization_frames + i, 'simulation',
                           seg_masks, position_data)
            if i % 20 == 0:
                print(f"Completed formal simulation frame: {i}/{simulation_frames}")

        # Each captured frame produced exactly one RGB render and one segmentation mask.
        print(f"Collected {len(seg_masks)} RGB frames and {len(seg_masks)} segmentation frames in total")

        fps = 30

        # Generate filenames
        base_filename = f'friction_sim_{random_shape}_objFriction_{obj_friction:.2f}_tableFriction_{table_friction:.2f}_size_{random_size:.2f}_dir_{vel_x:.2f}_{vel_y:.2f}_id_{video_id}'
        rgb_filename = f'{base_filename}_rgb.mp4'
        seg_filename = f'{base_filename}_seg.mp4'
        pos_filename = f'{base_filename}_pos.json'
        meta_filename = f'{base_filename}_meta.json'

        # Save metadata
        metadata = {
            'shape': random_shape,
            'size': random_size,
            'mass_reference': obj_mass,
            'color': [float(c) for c in random_color],
            'obj_friction': obj_friction,
            'table_friction': table_friction,
            'initial_linear_velocity': init_lin_vel,
            'initial_angular_velocity': init_ang_vel,
            'initial_x': random_init_x,
            'initial_y': random_init_y,
            'cam_angle': random_cam_angle,
            'cam_height': random_cam_height,
            'cam_lookat_x': random_cam_lookat_x,
            'cam_lookat_y': random_cam_lookat_y,
            'cam_lookat_z': random_cam_lookat_z,
        }
        meta_path = os.path.join(output_dir, meta_filename)
        with open(meta_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"Metadata saved to: {meta_path}")

        # Save position data (int frame keys become strings in JSON)
        pos_path = os.path.join(output_dir, pos_filename)
        with open(pos_path, 'w') as f:
            json.dump(position_data, f, indent=2)
        print(f"Position data saved to: {pos_path}")

        # Use camera native method to save RGB video
        cam.stop_recording(
            save_to_filename=os.path.join(output_dir, rgb_filename),
            fps=fps
        )

        # Use OpenCV to save segmentation video
        _save_segmentation_video(seg_masks, os.path.join(output_dir, seg_filename), fps)

        scene.reset()
    finally:
        # Always release GPU memory and the Genesis session, even after a failure.
        seg_masks.clear()
        del scene, cam, sliding_obj, table
        gs.destroy()
        torch.cuda.empty_cache()
        gc.collect()

    print(f"Completed simulation video{video_id}!")
    print(f"Data saved to: {os.path.join(output_dir, base_filename)}")

# Main program entry
if __name__ == "__main__":
    # Generate the dataset sequentially; every call runs its own Genesis session
    # and writes its own set of output files.
    video_ids = range(NUM_VIDEOS)
    for current_id in video_ids:
        create_simulation(current_id)

    print("All simulations completed!")