import numpy as np
from noise import pnoise2

import rise


def generate_mountain_floor(
    x_values: int,
    y_values: int,
    slope: float = 0.1,
    center_radius: float = 3,
    max_height: float = 2,
) -> np.ndarray:
    """
    Generate a funnel-like mountain floor.

    The floor is flat (height 0) for every cell whose distance to the grid
    center ``(x_values // 2, y_values // 2)`` is at most ``center_radius``.
    Outside that disc the height grows linearly with the distance from the
    disc's edge and is capped at ``max_height``.

    Args:
        x_values (int): The width of the floor (number of columns in the array).
        y_values (int): The height of the floor (number of rows in the array).
        slope (float, optional): Height gained per grid cell of distance beyond
            ``center_radius``. Defaults to 0.1.
        center_radius (float, optional): Radius, in grid cells, of the flat
            central disc. Defaults to 3.
        max_height (float, optional): Upper bound applied to the sloped part
            of the floor. Defaults to 2.

    Returns:
        np.ndarray: A 2D array of shape ``(y_values, x_values)`` holding the
        floor heights.
    """
    floor = np.zeros((y_values, x_values))

    # Distance of every cell to the grid center, built by broadcasting a row
    # of column offsets against a column of row offsets.
    col_offsets = np.arange(x_values) - x_values // 2
    row_offsets = np.arange(y_values) - y_values // 2
    dist = np.sqrt(col_offsets[np.newaxis, :] ** 2 + row_offsets[:, np.newaxis] ** 2)

    # Only cells strictly outside the central disc are raised.
    outside = dist > center_radius
    capped = np.minimum(slope * (dist - center_radius), max_height)
    floor[outside] = capped[outside]
    return floor


def generate_noisy_mountain_floor(
    x_values: int,
    y_values: int,
    slope: float = 0.08,
    center_radius: float = 3,
    noise_amplitude: float = 0.02,
    max_height: float = 1.5,
) -> np.ndarray:
    """
    Generate a noisy, funnel-like mountain floor.

    Cells within ``center_radius`` of the grid center
    ``(x_values // 2, y_values // 2)`` stay at height 0 and receive no noise.
    Every other cell gets the linear funnel height (capped at ``max_height``)
    plus a uniform random offset in
    ``[-noise_amplitude, noise_amplitude)``. Because the noise is added after
    the cap, individual cells may exceed ``max_height`` by up to
    ``noise_amplitude``.

    The noise is drawn from NumPy's global random state (``np.random``), so
    results are reproducible after ``np.random.seed``.

    Args:
        x_values (int): The width of the floor (number of columns in the array).
        y_values (int): The height of the floor (number of rows in the array).
        slope (float, optional): Height gained per grid cell of distance beyond
            ``center_radius``. Defaults to 0.08.
        center_radius (float, optional): Radius, in grid cells, of the flat,
            noise-free central disc. Defaults to 3.
        noise_amplitude (float, optional): Scale applied to the uniform
            ``[-1, 1)`` noise sample of each cell. Defaults to 0.02.
        max_height (float, optional): Upper bound applied to the funnel height
            before noise is added. Defaults to 1.5.

    Returns:
        np.ndarray: A 2D array of shape ``(y_values, x_values)`` holding the
        floor heights.
    """
    floor = np.zeros((y_values, x_values))

    # Distance of every cell to the grid center.
    col_offsets = np.arange(x_values) - x_values // 2
    row_offsets = np.arange(y_values) - y_values // 2
    dist = np.sqrt(col_offsets[np.newaxis, :] ** 2 + row_offsets[:, np.newaxis] ** 2)

    outside = dist > center_radius
    capped = np.minimum(slope * (dist - center_radius), max_height)

    # Draw one sample per raised cell in a single call, then scatter the
    # samples column by column (x outer, y inner) through the transposed
    # view. That is the order in which a cell-by-cell loop over x then y
    # would consume the random stream, so seeded runs stay reproducible.
    noise = np.zeros((y_values, x_values))
    noise.T[outside.T] = np.random.uniform(-1, 1, size=int(np.count_nonzero(outside)))

    floor[outside] = capped[outside] + noise[outside] * noise_amplitude
    return floor


def generate_perlin_noise_floor(
    x_values: int,
    y_values: int,
    octaves: int = 6,
    persistence: float = 0.5,
    lacunarity: float = 2.0,
    seed: int = 42,
) -> np.ndarray:
    """
    Build a 2D height map by sampling ``pnoise2`` once per grid cell.

    Cell ``(row, col)`` is sampled at the normalized coordinates
    ``(col / x_values, row / y_values)``, so the whole grid covers the
    half-open unit square of noise space.

    Args:
        x_values (int): The width of the floor (number of columns in the array).
        y_values (int): The height of the floor (number of rows in the array).
        octaves (int, optional): Forwarded to ``pnoise2`` as ``octaves``
            (number of noise layers combined). Defaults to 6.
        persistence (float, optional): Forwarded to ``pnoise2`` as
            ``persistence`` (relative amplitude of successive octaves).
            Defaults to 0.5.
        lacunarity (float, optional): Forwarded to ``pnoise2`` as
            ``lacunarity`` (relative frequency of successive octaves).
            Defaults to 2.0.
        seed (int, optional): Forwarded to ``pnoise2`` as ``base``; the same
            seed yields the same floor. Defaults to 42.

    Returns:
        np.ndarray: A 2D array of shape ``(y_values, x_values)`` with one
        noise sample per cell.
    """
    heights = np.zeros((y_values, x_values))

    # Keyword arguments shared by every sample; the repeat periods are tied
    # to the grid dimensions.
    noise_options = {
        "octaves": octaves,
        "persistence": persistence,
        "lacunarity": lacunarity,
        "repeatx": x_values,
        "repeaty": y_values,
        "base": seed,
    }

    # np.ndindex walks the grid row by row (row outer, column inner).
    for row, col in np.ndindex(y_values, x_values):
        heights[row, col] = pnoise2(col / x_values, row / y_values, **noise_options)
    return heights


def generate_wave_floor(
    x_values: int,
    y_values: int,
    trapezoid_base_length: int = 10,
    trapezoid_top_length: int = 5,
    trapezoid_height: float = 0.3,
    x_direction: bool = True,
) -> np.ndarray:
    """
    Generate a wave-like floor with trapezoidal protrusions and trenches in the x or y direction.

    Along the wave axis the floor repeats a trapezoid every
    ``trapezoid_base_length`` cells, alternating between a protrusion
    (positive height) and a trench (negative height). The pattern is
    anchored so that the cell at the center of the wave axis sits at the
    foot of a trapezoid, i.e. at height 0. The profile is constant along
    the other axis.

    Args:
        x_values (int): The width of the floor (number of columns in the array).
        y_values (int): The height of the floor (number of rows in the array).
        trapezoid_base_length (int, optional): The base length of the trapezoids.
            Controls the distance between the two inclined sides of the trapezoid. Defaults to 10.
        trapezoid_top_length (int, optional): The top length of the trapezoids.
            Controls the flat portion on top of the trapezoid. Defaults to 5.
        trapezoid_height (float, optional): The height of the trapezoids.
            Controls how tall the trapezoids are from base to top. Defaults to 0.3.
        x_direction (bool, optional): If True, the wave is oriented in the x direction
            (heights vary from column to column). If False, the wave is oriented in
            the y direction (heights vary from row to row). Defaults to True.

    Returns:
        np.ndarray: A 2D numpy array of shape ``(y_values, x_values)`` representing
        the floor, with trapezoid-shaped protrusions and trenches in the specified
        direction.

    Raises:
        ValueError: If ``trapezoid_base_length`` is not positive.
    """
    if trapezoid_base_length <= 0:
        raise ValueError("trapezoid_base_length must be positive")

    floor = np.zeros((y_values, x_values))

    # The wave varies along exactly one axis. Build the 1D profile for that
    # axis, using its own length and center, so that the y-direction wave
    # keeps the (y_values, x_values) shape on non-square floors.
    axis_length = x_values if x_direction else y_values
    offsets = np.arange(axis_length) - axis_length // 2

    # Position inside the current trapezoid and sign of that trapezoid
    # (even periods are protrusions, odd periods are trenches).
    relative = offsets % trapezoid_base_length
    sign = np.where((offsets // trapezoid_base_length) % 2 == 0, 1.0, -1.0)

    half_base_length = trapezoid_base_length / 2
    half_top_length = trapezoid_top_length / 2
    ramp_length = half_base_length - half_top_length
    plateau_end = half_base_length + half_top_length

    ascending = relative < ramp_length
    descending = ~ascending & (relative >= plateau_end)

    # Start from the flat top and overwrite the two sloped sides. The masks
    # are empty whenever ramp_length is zero, so no division by zero occurs.
    profile = np.full(axis_length, float(trapezoid_height))
    profile[ascending] = relative[ascending] / ramp_length * trapezoid_height
    profile[descending] = (
        (trapezoid_base_length - relative[descending]) / ramp_length
    ) * trapezoid_height
    profile *= sign

    # Broadcast the profile across the axis on which the floor is constant.
    if x_direction:
        floor[:, :] = profile[np.newaxis, :]
    else:
        floor[:, :] = profile[:, np.newaxis]
    return floor


def _trapezoid_axis_profile(length, base_length, top_length, height):
    """Return per-index (heights, on_plateau, parity) of a 1D trapezoid wave centered on ``length // 2``."""
    offsets = np.arange(length) - length // 2
    relative = offsets % base_length
    parity = (offsets // base_length) % 2

    ramp_length = base_length / 2 - top_length / 2
    plateau_end = base_length / 2 + top_length / 2

    ascending = relative < ramp_length
    descending = ~ascending & (relative >= plateau_end)

    # Flat top by default; the sloped sides are overwritten. Both masks are
    # empty when ramp_length is zero, so no division by zero occurs.
    heights = np.full(length, float(height))
    heights[ascending] = relative[ascending] / ramp_length * height
    heights[descending] = (base_length - relative[descending]) / ramp_length * height

    # Closed interval: the plateau includes both of its end points.
    on_plateau = (relative >= ramp_length) & (relative <= plateau_end)
    return heights, on_plateau, parity


def generate_2d_wave_floor(
    x_values: int,
    y_values: int,
    trapezoid_base_length: int = 10,
    trapezoid_top_length: int = 5,
    trapezoid_height: float = 0.3,
) -> np.ndarray:
    """
    Generate a checkerboard of trapezoidal (frustum-shaped) bumps and pits.

    The floor is tiled with square cells of side ``trapezoid_base_length``,
    anchored so that the grid center ``(x_values // 2, y_values // 2)`` lies
    on a tile corner at height 0. Inside each tile the height is the smaller
    of the 1D trapezoid profiles along x and along y, which gives a flat top
    surrounded by slopes on all four sides. Tiles alternate in sign like a
    checkerboard: a tile is a protrusion when its x and y tile indices have
    the same parity and a trench otherwise.

    Args:
        x_values (int): The width of the floor (number of columns in the array).
        y_values (int): The height of the floor (number of rows in the array).
        trapezoid_base_length (int, optional): The base length of the trapezoids
            (tile side length), used for both axes. Defaults to 10.
        trapezoid_top_length (int, optional): The top length of the trapezoids.
            Controls the side length of the flat top. Defaults to 5.
        trapezoid_height (float, optional): The height of the trapezoids.
            Controls how tall the trapezoids are from base to top. Defaults to 0.3.

    Returns:
        np.ndarray: A 2D numpy array of shape ``(y_values, x_values)`` representing
        the floor.

    Raises:
        ValueError: If ``trapezoid_base_length`` is not positive.
    """
    if trapezoid_base_length <= 0:
        raise ValueError("trapezoid_base_length must be positive")

    floor = np.zeros((y_values, x_values))

    x_heights, x_on_plateau, x_parity = _trapezoid_axis_profile(
        x_values, trapezoid_base_length, trapezoid_top_length, trapezoid_height
    )
    y_heights, y_on_plateau, y_parity = _trapezoid_axis_profile(
        y_values, trapezoid_base_length, trapezoid_top_length, trapezoid_height
    )

    # A cell is only as high as the lower of its two axis profiles.
    floor[:, :] = np.minimum(x_heights[np.newaxis, :], y_heights[:, np.newaxis])

    # Cells on the plateau along both axes get the exact top height, free of
    # any rounding from the slope formula at the plateau's far edge.
    floor[y_on_plateau[:, np.newaxis] & x_on_plateau[np.newaxis, :]] = trapezoid_height

    # Checkerboard sign: +1 where the tile parities agree, -1 otherwise.
    floor *= np.where(x_parity[np.newaxis, :] == y_parity[:, np.newaxis], 1.0, -1.0)
    return floor


def scale_floor_height(min_height: float, max_height: float, floor: np.ndarray) -> np.ndarray:
    """
    Linearly rescale a floor so its values span ``[min_height, max_height]``.

    The smallest value of ``floor`` is mapped to ``min_height`` and the
    largest to ``max_height``. The input array is not modified.

    Args:
        min_height (float): Height assigned to the lowest cell of ``floor``.
        max_height (float): Height assigned to the highest cell of ``floor``.
        floor (np.ndarray): The floor heights to rescale. Must not be empty.

    Returns:
        np.ndarray: A new array with the same shape as ``floor``. If ``floor``
        is constant (no height range to stretch), every cell is set to
        ``min_height``.
    """
    min_floor, max_floor = floor.min(), floor.max()
    span = max_floor - min_floor

    # A constant floor would otherwise divide 0 by 0 and yield NaN everywhere.
    if span == 0:
        return np.full(floor.shape, float(min_height))

    return ((floor - min_floor) / span) * (max_height - min_height) + min_height


def clear_floor_center(
    floor: np.ndarray, clear_radius: float, smooth_radius: float
) -> np.ndarray:
    """
    Flatten a disc at the center of the floor and smooth the ring around it.

    Cells whose distance to the grid center ``(x_size // 2, y_size // 2)`` is
    at most ``clear_radius`` are set to 0. Cells in the surrounding ring, up
    to ``clear_radius + smooth_radius`` from the center, are replaced by the
    mean of their 3x3 neighborhood, computed after the center was cleared.
    Neighborhoods that reach past the border reuse the nearest edge value.

    Note:
        ``floor`` is modified in place; the returned array is the same object.

    Args:
        floor (np.ndarray): 2D array of floor heights to modify.
        clear_radius (float): Radius, in grid cells, of the disc set to 0.
        smooth_radius (float): Width, in grid cells, of the smoothed ring
            around the cleared disc.

    Returns:
        np.ndarray: The modified ``floor`` array.
    """
    y_size, x_size = floor.shape

    # Distance of every cell to the grid center.
    y_indices, x_indices = np.indices(floor.shape)
    dist = np.sqrt((x_indices - x_size // 2) ** 2 + (y_indices - y_size // 2) ** 2)

    clear_mask = dist <= clear_radius
    smooth_mask = (dist > clear_radius) & (dist <= clear_radius + smooth_radius)

    floor[clear_mask] = 0

    # Pad by one cell so every cell has a full 3x3 neighborhood. The nine
    # shifted windows of the padded array are the nine neighbors of each cell.
    padded_floor = np.pad(floor, pad_width=1, mode="edge")
    neighborhood_sum = sum(
        padded_floor[dy : dy + y_size, dx : dx + x_size]
        for dy in range(3)
        for dx in range(3)
    )

    # Only the ring takes the smoothed values; all other cells are untouched.
    floor[smooth_mask] = (neighborhood_sum / 9.0)[smooth_mask]

    return floor


def update_config_remove_floor(config: "rise.RS_Config") -> None:
    """
    Switch off the floor in a Rise configuration.

    Sets the floor data type of ``config`` to ``RSE_FLOOR_DATA_NONE`` and
    clears any floor data configuration. ``config`` is modified in place.
    """
    data_config = config.floor_config.data_config
    data_config.type = rise.RSE_FloorDataType.RSE_FLOOR_DATA_NONE
    data_config.config = None


def update_config_with_flat_floor(config: "rise.RS_Config") -> None:
    """
    Select the flat floor (at z = 0) in a Rise configuration.

    Sets the floor data type of ``config`` to ``RSE_FLOOR_DATA_FLAT`` and
    clears any floor data configuration. ``config`` is modified in place.
    """
    data_config = config.floor_config.data_config
    data_config.type = rise.RSE_FloorDataType.RSE_FLOOR_DATA_FLAT
    data_config.config = None


def update_config_with_array_floor(
    config: "rise.RS_Config", floor: np.ndarray, floor_size: float
) -> None:
    """
    Install a height-map floor into a Rise configuration.

    The heights in ``floor`` are mapped linearly onto the integer range
    0..65535 (0 for the lowest cell, 65535 for the highest) and appended to
    the array config row by row. The original minimum and maximum are stored
    as ``h_min`` / ``h_max``. ``config`` is modified in place.

    Args:
        config (rise.RS_Config): The configuration to update.
        floor (np.ndarray): 2D array of floor heights, indexed ``[y, x]``.
        floor_size (float): Value stored as both ``x_size`` and ``y_size``
            of the array floor.

    Raises:
        ValueError: If ``floor`` is not two-dimensional.
    """
    if floor.ndim != 2:
        raise ValueError("floor must be a 2D array")

    data_config = config.floor_config.data_config

    rows, cols = floor.shape
    array_config = rise.RS_ArrayFloorDataConfig()
    array_config.x_values = int(cols)
    array_config.y_values = int(rows)
    array_config.x_size = floor_size
    array_config.y_size = floor_size

    lowest = float(np.min(floor))
    highest = float(np.max(floor))
    # A perfectly flat floor has no range; widen it slightly so the
    # normalization below does not divide by zero.
    if highest == lowest:
        highest = lowest + 1e-6
    array_config.h_min = lowest
    array_config.h_max = highest

    # Map heights to [0, 1], then quantize to the full uint16 range.
    unit_heights = np.clip((floor - lowest) / (highest - lowest), 0.0, 1.0)
    quantized = np.round(unit_heights * 65535).astype(np.uint16)

    # Row-major (C order) flattening: all x values of row 0, then row 1, ...
    for height_entry in quantized.ravel(order="C").tolist():
        array_config.height.append(height_entry)

    data_config.type = rise.RSE_FloorDataType.RSE_FLOOR_DATA_ARRAY
    data_config.config = array_config
