import numpy as np
from typing import Tuple
from typing import List
from rise import *
import math


def _append_solid_body(
    structure_config: RS_StructureConfig,
    body_sid: int,
    origin: Tuple[float, float, float],
    voxel_dims: Tuple[int, int, int],
) -> None:
    """
    Append one completely filled box body to ``structure_config.bodies``.

    Every voxel of the body gets material reference index 0, segment bid 0
    and segment type 0. The body uses the identity relative orientation.

    Args:
        structure_config: Structure config that receives the new body.
        body_sid: Value stored in the new body's ``body_sid`` field.
        origin: (x, y, z) passed to ``RVec3rf`` as the body's relative origin position.
        voxel_dims: (x_voxels, y_voxels, z_voxels) extent of the body in voxels.
    """
    x_voxels, y_voxels, z_voxels = voxel_dims

    body = RS_StructureBodyConfig()
    body.body_sid = body_sid
    body.relative_origin_position = RVec3rf(*origin)
    body.relative_orientation = RQuat3rf(0, 0, 0, 1)
    body.x_voxels = x_voxels
    body.y_voxels = y_voxels
    body.z_voxels = z_voxels

    # One entry per voxel in each of the three parallel per-voxel arrays.
    for _ in range(x_voxels * y_voxels * z_voxels):
        body.material_reference_sid.append(0)
        body.segment_bid.append(0)
        body.segment_type.append(0)

    # The body is appended only after it is fully populated.
    structure_config.bodies.append(body)


def create_stair_voxel_terrain(
    center_size: float,
    stair_width: float,
    stair_height: float,
    num_stairs: int,
    voxel_size: float,
    material_name: str = "material_0",
) -> RS_StructureConfig:
    """
    Create stair voxel terrain for the environment. The staircase is shaped like an inverted pyramid
    with an NxN flat surface at the bottom center, surrounded by four sets of enclosed staircases on each side.

    The structure is named "stair_terrain", is fixed, and sits at the origin with identity
    orientation. Body 0 is the center platform (one voxel thick, origin at (0, 0, 0)). Each stair
    level then adds four slab bodies in the order north (+X), south (-X), east (+Y), west (-Y),
    with consecutive ``body_sid`` values. Level ``k`` (0-based) is placed at
    ``z = k * stair_height_voxels * voxel_size``.

    All four slabs of a level span the full outer side length of that level's ring, so the
    north/south slabs and the east/west slabs both cover the ring's corner regions
    (NOTE(review): this assumes ``relative_origin_position`` is the body's minimum corner,
    which is what the offsets below imply - confirm against the simulator).

    Args:
        center_size: Size of the center flat platform in meters, creates an NxN platform.
        stair_width: Width of each stair step in meters.
        stair_height: Height of each stair step in meters.
        num_stairs: Number of stair levels ascending from the center.
        voxel_size: Size of each voxel in meters.
        material_name: Name of the material to use (must exist in environment config, default: "material_0").

    Returns:
        structure_config: Structure config for the stair terrain.
    """
    # Convert meters to voxel counts (rounded to the nearest whole voxel).
    center_size_voxels = int(round(center_size / voxel_size))
    stair_width_voxels = int(round(stair_width / voxel_size))
    stair_height_voxels = int(round(stair_height / voxel_size))

    structure_config = RS_StructureConfig()
    structure_config.name = "stair_terrain"
    structure_config.is_fixed = True
    structure_config.voxel_size = voxel_size
    structure_config.origin_position = RVec3rf(0, 0, 0)
    structure_config.orientation = RQuat3rf(0, 0, 0, 1)

    # Add material reference once - all bodies will reference index 0
    structure_config.material_references.append(material_name)

    # Center platform (ground level), always body 0.
    body_sid = 0
    _append_solid_body(
        structure_config,
        body_sid,
        (0, 0, 0),
        (center_size_voxels, center_size_voxels, 1),
    )
    body_sid += 1

    # Each stair level forms a complete ring around the previous level.
    for stair_level in range(num_stairs):
        # Distance (in voxels) from the platform edge to the ring's outer edge.
        ring_reach_voxels = (stair_level + 1) * stair_width_voxels
        # Full outer side length of this ring.
        ring_side_voxels = center_size_voxels + 2 * ring_reach_voxels

        # Low-side coordinate: the ring's outer edge on the negative side.
        low_edge = -ring_reach_voxels * voxel_size
        # High-side coordinate: where the positive-side slab of this level starts.
        high_edge = (center_size_voxels + stair_level * stair_width_voxels) * voxel_size
        z_pos = stair_level * stair_height_voxels * voxel_size

        # Slabs running along Y (north/south) and along X (east/west).
        along_y = (stair_width_voxels, ring_side_voxels, stair_height_voxels)
        along_x = (ring_side_voxels, stair_width_voxels, stair_height_voxels)

        ring_slabs = (
            ((high_edge, low_edge, z_pos), along_y),  # north: positive X side
            ((low_edge, low_edge, z_pos), along_y),   # south: negative X side
            ((low_edge, high_edge, z_pos), along_x),  # east: positive Y side
            ((low_edge, low_edge, z_pos), along_x),   # west: negative Y side
        )
        for slab_origin, slab_dims in ring_slabs:
            _append_solid_body(structure_config, body_sid, slab_origin, slab_dims)
            body_sid += 1

    return structure_config


def create_random_terrain_areana(
    name: str,
    position: Tuple[float, float, float],
    size: Tuple[float, float],
    height_range: Tuple[float, float] = (0.0, 1.0),
    resolution: int = 32,
    seed: int = 42,
    orientation: Tuple[float, float, float, float] = (0, 0, 0, 1),
    material_name: str = "material_0",
    voxel_size: float = 0.01,
    is_fixed: bool = True,
) -> RS_StructureConfig:
    """
    Create a random height-map terrain arena using voxels.

    A ``resolution x resolution`` height map is generated from a sine/cosine product plus
    uniform noise, then sampled onto a single voxel body whose columns are filled from
    ``z = 0`` upwards.

    Notes:
        * This function calls ``np.random.seed(seed)``, i.e. it reseeds NumPy's global
          random number generator on every call.
        * The normalized height (wave + noise + 0.5) lies in [0, 1.3), so generated heights
          can exceed ``height_range[1]``; such columns are clipped to the body height
          ``ceil(height_range[1] / voxel_size)``.
        * Each column fills ``ceil(height / voxel_size) + 1`` voxels (before clipping).
        * The sampling index is ``int(v * (resolution - 1) / n_voxels)``, so the last
          height-map row/column is never read.
        * ``resolution == 1`` is supported and yields a single height sample for the
          whole terrain.

    Args:
        name: The name of the object.
        position: The position of the object in the global coordinate system (x, y, z).
        size: The size of the terrain (x_size, y_size).
        height_range: The range of the height (min_height, max_height).
        resolution: The resolution of the terrain (the number of points in each direction).
        seed: The seed of the random number generator.
        orientation: The direction of the object (x, y, z, w).
        material_name: The name of the material.
        voxel_size: The size of the voxel.
        is_fixed: Whether the object is fixed.

    Returns:
        The RS_StructureConfig object.
    """
    # Seed the (global) random number generator so the terrain is reproducible.
    np.random.seed(seed)

    # Body extent in voxels.
    x_voxels = int(np.ceil(size[0] / voxel_size))
    y_voxels = int(np.ceil(size[1] / voxel_size))
    max_height_voxels = int(np.ceil(height_range[1] / voxel_size))

    # Structure-level configuration.
    structure_config = RS_StructureConfig()
    structure_config.name = name
    structure_config.is_fixed = is_fixed
    structure_config.voxel_size = voxel_size
    structure_config.origin_position = RVec3rf(*position)
    structure_config.orientation = RQuat3rf(*orientation)
    structure_config.material_references.append(material_name)

    # Single body holding the whole terrain volume.
    body_config = RS_StructureBodyConfig()
    body_config.body_sid = 0
    body_config.relative_origin_position = RVec3rf(0, 0, 0)
    body_config.relative_orientation = RQuat3rf(0, 0, 0, 1)
    body_config.x_voxels = x_voxels
    body_config.y_voxels = y_voxels
    body_config.z_voxels = max_height_voxels

    # Start with every voxel marked empty.
    layer_stride = x_voxels * y_voxels
    for _ in range(layer_stride * max_height_voxels):
        body_config.material_reference_sid.append(RS_NULL_INDEX)
        body_config.segment_bid.append(RS_NULL_INDEX)
        body_config.segment_type.append(RS_NULL_INDEX)

    # Generate the height map: sine/cosine product plus uniform noise.
    # Guard the divisor so resolution == 1 no longer divides by zero.
    norm_divisor = max(resolution - 1, 1)
    height_span = height_range[1] - height_range[0]
    heightmap = np.zeros((resolution, resolution))
    for i in range(resolution):
        x_norm = i / norm_divisor
        wave_x = np.sin(x_norm * 4 * np.pi)
        for j in range(resolution):
            y_norm = j / norm_divisor

            # Base wave in [-0.5, 0.5].
            base_height = wave_x * np.cos(y_norm * 4 * np.pi) * 0.5

            # Uniform noise in [0, 0.3); one draw per sample, row-major order.
            noise = np.random.random() * 0.3

            # Shift to a non-negative factor and map onto the height range.
            heightmap[i, j] = (base_height + noise + 0.5) * height_span + height_range[0]

    # Height-map column index for each voxel y (independent of x, so computed once).
    hm_y_for_voxel = [int(y * (resolution - 1) / y_voxels) for y in range(y_voxels)]

    # Convert the height map to voxels.
    for x in range(x_voxels):
        hm_x = int(x * (resolution - 1) / x_voxels)
        for y in range(y_voxels):
            # Column height in voxels.
            height = int(np.ceil(heightmap[hm_x, hm_y_for_voxel[y]] / voxel_size))

            # Fill from the ground up to and including `height`, clipped to the body.
            column_base = x + y * x_voxels
            for z in range(min(height + 1, max_height_voxels)):
                idx = column_base + z * layer_stride
                body_config.material_reference_sid[idx] = 0
                body_config.segment_bid[idx] = 0
                # Original author's note: 0 represents soft body (TODO confirm).
                body_config.segment_type[idx] = 0

    # Attach the populated body to the structure.
    structure_config.bodies.append(body_config)

    return structure_config
