"""Map parser for the elevation grid salamander world.

This module generates a Webots project directory for the given map
file.
"""
import argparse
import inspect
import json
import random
import shutil
from collections import deque
from dataclasses import dataclass
from itertools import product
from pathlib import Path
from typing import NewType

import networkx as nx

import salamander_env


# Distinct type for the single-character cells of a parsed map grid
Char = NewType("Char", str)
# Map character marking a land cell; every other character is treated as
# water when the heights are computed
LAND_CHAR = "L"
# Map character marking the cell where the salamander is placed
SALAMANDER_CHAR = "S"
# Map character marking the cell where the goal is placed
GOAL_CHAR = "G"
# Directory containing the salamander_env module. It is copied to create a
# new project and holds the "elevation_grid_templates" directory.
salamander_module_dir = Path(inspect.getfile(salamander_env)).parent

# Template parsing magic tokens
# Each token is a placeholder that get_world_file replaces in the world
# template text with a generated value.
SALAMANDER_PROTO_PATH_TOKEN = "SALAMANDER_PROTO_PATH"
ELEVATION_GRID_HEIGHT_TOKEN = "ELEVATION_GRID_HEIGHT"
ELEVATION_GRID_XDIMENSION_TOKEN = "ELEVATION_GRID_XDIMENSION"
ELEVATION_GRID_YDIMENSION_TOKEN = "ELEVATION_GRID_YDIMENSION"
ELEVATION_GRID_XSPACING_TOKEN = "ELEVATION_GRID_XSPACING"
ELEVATION_GRID_YSPACING_TOKEN = "ELEVATION_GRID_YSPACING"
ELEVATION_GRID_THICKNESS_TOKEN = "ELEVATION_GRID_THICKNESS"
ELEVATION_GRID_TRANSLATION_TOKEN = "ELEVATION_GRID_TRANSLATION"
FLUID_TRANSLATION_TOKEN = "FLUID_TRANSLATION"
FLUID_SIZE_TOKEN = "FLUID_SIZE"
LEFT_WALL_TRANSLATION_TOKEN = "LEFT_WALL_TRANSLATION"
RIGHT_WALL_TRANSLATION_TOKEN = "RIGHT_WALL_TRANSLATION"
TOP_WALL_TRANSLATION_TOKEN = "TOP_WALL_TRANSLATION"
BOTTOM_WALL_TRANSLATION_TOKEN = "BOTTOM_WALL_TRANSLATION"
LR_WALL_SIZE_TOKEN = "LR_WALL_SIZE"
TB_WALL_SIZE_TOKEN = "TB_WALL_SIZE"
SALAMANDER_TRANSLATION_TOKEN = "SALAMANDER_TRANSLATION"
SALAMANDER_ROTATION_TOKEN = "SALAMANDER_ROTATION"
GOAL_TRANSLATION_TOKEN = "GOAL_TRANSLATION"
GOAL_RADIUS_TOKEN = "GOAL_RADIUS"


@dataclass
class ElevationGrid:
    """Values used to fill in the elevation grid of a world template.

    See https://www.cyberbotics.com/doc/reference/elevationgrid
    """
    # Heights of all grid cells, flattened row by row
    height: list[float]
    # Number of columns of the map grid (length of a grid row)
    x_dimension: int
    # Number of rows of the map grid
    y_dimension: int
    # Spacing values written to the world file; multiplied by
    # (dimension - 1) they give the map size along each axis
    x_spacing: float
    y_spacing: float
    # Thickness value written to the world file
    thickness: float
    # (column, row) index of the goal cell in the map grid
    goal_position: tuple[int, int]
    # (column, row) index of the salamander cell in the map grid
    salamander_position: tuple[int, int]


def get_softened_grid(
    grid: list[list[Char]],
    zero_distance: int,  # distance from land to reach maximum water depth
    max_depth: float,
) -> list[list[float]]:
    """Turn the text grid from a map file into a 'softened' array, where
    changes in height are gradual.

    Land cells get height zero. Every other cell gets a negative height
    that grows linearly with its distance to the closest land cell and
    saturates at ``-max_depth`` once that distance reaches
    ``zero_distance``. Distances are counted in steps between horizontally
    or vertically adjacent cells. When the grid contains no land at all,
    every cell is at ``-max_depth``.

    Args:
        grid: Rows of map characters. All rows are expected to have the
            same length; a row shorter than the longest one raises
            ``IndexError``.
        zero_distance: Distance from land at which the maximum depth is
            reached. Must not be zero (it is used as a divisor).
        max_depth: Depth of the water far away from land.

    Returns:
        One list of heights per grid row. An empty grid gives ``[]``.
    """
    n_rows = len(grid)
    n_cols = max(map(len, grid), default=0)

    # Breadth-first search started from all land cells at once: the first
    # time a cell is reached gives its distance to the closest land cell.
    # This visits every cell once instead of computing the distance
    # between every pair of cells.
    distance_to_land: dict[tuple[int, int], int] = {}
    frontier: deque[tuple[int, int]] = deque()
    for i in range(n_rows):
        for j in range(n_cols):
            if grid[i][j] == LAND_CHAR:
                distance_to_land[(i, j)] = 0
                frontier.append((i, j))

    while frontier:
        i, j = frontier.popleft()
        next_distance = distance_to_land[(i, j)] + 1
        for neighbor in ((i-1, j), (i+1, j), (i, j-1), (i, j+1)):
            (ni, nj) = neighbor
            if not (0 <= ni < n_rows and 0 <= nj < n_cols):
                continue
            if neighbor in distance_to_land:
                continue
            distance_to_land[neighbor] = next_distance
            frontier.append(neighbor)

    # "Soften" heights by using the distance of water cells to land cells
    # to interpolate between zero height and maximum depth
    soft_grid = []
    for i in range(n_rows):
        soft_row = []
        for j in range(n_cols):
            # Cells are only missing from the search result when there is
            # no land at all; they are then treated as being far from land
            distance = distance_to_land.get((i, j), zero_distance)
            depth = min(zero_distance, distance)/zero_distance
            soft_row.append(-depth*max_depth)
        soft_grid.append(soft_row)

    return soft_grid


def get_elevation_grid(
    x_spacing: float,
    y_spacing: float,
    thickness: float,
    grid: list[list[Char]],
    zero_distance: int,
    max_depth: float,
) -> ElevationGrid:
    """Build the elevation grid description of a text map.

    See https://www.cyberbotics.com/doc/reference/elevationgrid

    Args:
        x_spacing: Stored unchanged in the result.
        y_spacing: Stored unchanged in the result.
        thickness: Stored unchanged in the result.
        grid: Rows of map characters; must be rectangular and contain a
            salamander cell and a goal cell. If one of them appears more
            than once, the last occurrence (row by row) is used.
        zero_distance: Passed to get_softened_grid.
        max_depth: Passed to get_softened_grid.

    Returns:
        The elevation grid with the flattened heights, the grid dimensions
        and the (column, row) positions of the salamander and the goal.

    Raises:
        ValueError: If the grid is empty, its rows differ in length, or
            the salamander or goal cell is missing.
    """
    # Reject malformed grids with an explicit error instead of failing
    # later with an index error
    if not grid or not grid[0]:
        raise ValueError("The map grid is empty.")
    y_dimension = len(grid)
    x_dimension = len(grid[0])
    if any(len(row) != x_dimension for row in grid):
        raise ValueError("All rows of the map grid must have the same length.")

    # Get elevation grid
    height_grid = get_softened_grid(
        grid,
        zero_distance=zero_distance,
        max_depth=max_depth,
    )

    # Turn elevation grid into height list. The rows are concatenated, so
    # a grid with 2 rows (y_dimension) of 5 cells (x_dimension) is stored
    # in the following order
    #
    #  0 1 2 3 4
    #  5 6 7 8 9
    #
    height = [
        height_grid[i][j]
        for i in range(y_dimension)
        for j in range(x_dimension)
    ]

    # Find object positions
    salamander_position: None | tuple[int, int] = None
    goal_position: None | tuple[int, int] = None
    for i, row in enumerate(grid):
        for j, cell in enumerate(row):
            # Positions are stored as (column, row), i.e. x first
            if cell == SALAMANDER_CHAR:
                salamander_position = (j, i)
            if cell == GOAL_CHAR:
                goal_position = (j, i)
    # Explicit checks rather than asserts, which are skipped under -O
    if salamander_position is None:
        raise ValueError(
            f"The map has no salamander cell ('{SALAMANDER_CHAR}')."
        )
    if goal_position is None:
        raise ValueError(f"The map has no goal cell ('{GOAL_CHAR}').")

    # Assemble elevation grid
    return ElevationGrid(
        height=height,
        x_dimension=x_dimension,
        y_dimension=y_dimension,
        x_spacing=x_spacing,
        y_spacing=y_spacing,
        thickness=thickness,
        salamander_position=salamander_position,
        goal_position=goal_position,
    )


@dataclass
class Map:
    """Description of a map, with the same fields as the map JSON files."""
    # Text rows of the map, one character per cell
    map: list[str]
    # Spacing and thickness values forwarded to the elevation grid
    x_spacing: float
    y_spacing: float
    thickness: float
    # Distance from land (in cells) at which the water reaches max_depth
    zero_distance: int
    # Depth of the water at zero_distance or more away from land
    max_depth: float
    # Rotation angle of the salamander around the (0, 0, 1) axis
    salamander_rotation_rad: float


def webots_float_list(x: list[float]) -> str:
    """Format numbers for webots as a single space-separated string.

    Each value is rounded to 4 decimal places before being converted to
    text. An empty list gives an empty string.
    """
    return " ".join(str(round(value, 4)) for value in x)


def get_world_file(
    salamander_proto_path: Path,
    elevation_grid: ElevationGrid,
    salamander_rotation_rad: float,
    world_template_wbt_path: Path,
) -> str:
    """Insert the given data into the world webots file.

    The template is read from ``world_template_wbt_path`` and every magic
    token found in it is replaced with the value computed for it.

    Args:
        salamander_proto_path: Path of the salamander proto file; it is
            resolved and inserted between double quotes.
        elevation_grid: Heights, dimensions and object positions.
        salamander_rotation_rad: Rotation angle of the salamander around
            the (0, 0, 1) axis.
        world_template_wbt_path: Path of the world template to fill in.

    Returns:
        The text of the world file.
    """
    # Open the world template
    with open(world_template_wbt_path, "rt") as fp:
        wbt = fp.read()

    # Compute map dimensions
    # A height grid with n vertices on an axis defines a surface of size n-1
    x_size = (elevation_grid.x_dimension-1)*elevation_grid.x_spacing
    y_size = (elevation_grid.y_dimension-1)*elevation_grid.y_spacing

    # The elevation grid is shifted by half of the map size along x and y
    x_translation = -x_size/2
    y_translation = -y_size/2

    def get_world_xy(x: int, y: int) -> tuple[float, float]:
        """Helper function to transform text-map coordinates to world
        coordinates."""
        # The map has n cells per axis but the surface only spans n-1
        # spacings, so the cell indices have to be rescaled
        x_ratio = (elevation_grid.x_dimension-1)/elevation_grid.x_dimension
        y_ratio = (elevation_grid.y_dimension-1)/elevation_grid.y_dimension
        world_x = elevation_grid.x_spacing*(x*x_ratio+0.5)-x_size/2
        world_y = elevation_grid.y_spacing*(y*y_ratio+0.5)-y_size/2
        return world_x, world_y

    salamander_x, salamander_y = get_world_xy(
        elevation_grid.salamander_position[0],
        elevation_grid.salamander_position[1],
    )
    goal_x, goal_y = get_world_xy(
        elevation_grid.goal_position[0],
        elevation_grid.goal_position[1],
    )
    goal_radius = elevation_grid.x_spacing/2

    replacements = [
        # Elevation grid
        (
            ELEVATION_GRID_HEIGHT_TOKEN,
            f"[ {webots_float_list(elevation_grid.height)} ]",
        ),
        (
            ELEVATION_GRID_TRANSLATION_TOKEN,
            webots_float_list([x_translation, y_translation, 0.1]),
        ),
        (ELEVATION_GRID_XDIMENSION_TOKEN, str(elevation_grid.x_dimension)),
        (ELEVATION_GRID_YDIMENSION_TOKEN, str(elevation_grid.y_dimension)),
        (ELEVATION_GRID_XSPACING_TOKEN, str(elevation_grid.x_spacing)),
        (ELEVATION_GRID_YSPACING_TOKEN, str(elevation_grid.y_spacing)),
        (ELEVATION_GRID_THICKNESS_TOKEN, str(elevation_grid.thickness)),
        # Fluid: covers the whole map, with a fixed height and z position
        (FLUID_TRANSLATION_TOKEN, webots_float_list([0.0, 0.0, -0.125])),
        (FLUID_SIZE_TOKEN, webots_float_list([x_size, y_size, 0.25])),
        # Walls: one on each side of the map
        (
            LEFT_WALL_TRANSLATION_TOKEN,
            webots_float_list([x_translation, 0.0, 0.0]),
        ),
        (
            RIGHT_WALL_TRANSLATION_TOKEN,
            webots_float_list([-x_translation, 0.0, 0.0]),
        ),
        (
            TOP_WALL_TRANSLATION_TOKEN,
            webots_float_list([0.0, -y_translation, 0.0]),
        ),
        (
            BOTTOM_WALL_TRANSLATION_TOKEN,
            webots_float_list([0.0, y_translation, 0.0]),
        ),
        (LR_WALL_SIZE_TOKEN, webots_float_list([0.02, y_size, 0.7])),
        # NOTE(review): the long side is second here as well, so the
        # template presumably rotates the top and bottom walls - confirm
        (TB_WALL_SIZE_TOKEN, webots_float_list([0.02, x_size, 0.7])),
        # Salamander
        (
            SALAMANDER_TRANSLATION_TOKEN,
            webots_float_list([salamander_x, salamander_y, 0.2]),
        ),
        (
            SALAMANDER_ROTATION_TOKEN,
            webots_float_list([0, 0, 1, salamander_rotation_rad]),
        ),
        # Goal
        (GOAL_TRANSLATION_TOKEN, webots_float_list([goal_x, goal_y, 0.0])),
        (GOAL_RADIUS_TOKEN, str(goal_radius)),
        # The proto path goes last: it is the only inserted value that is
        # free text, so inserting it earlier would let the replacements
        # above rewrite a path that happens to contain a token
        (
            SALAMANDER_PROTO_PATH_TOKEN,
            f'"{str(salamander_proto_path.resolve())}"',
        ),
    ]
    for token, value in replacements:
        wbt = wbt.replace(token, value)

    return wbt


def write_webots_project(
    input_map: Map,
    output_dir: Path,
) -> None:
    """Turn the given map into a webots project directory.

    The salamander module directory is copied to ``output_dir``, which must
    not exist yet, and the following world files are then overwritten with
    worlds generated from the map:
    - world/salamander.wbt
    - world/salamandertcp.wbt

    Args:
        input_map: Map to turn into a webots world.
        output_dir: Directory to create.
    """
    # Get elevation grid. This is done before anything is written, so that
    # a map which cannot be converted does not leave a copied directory
    # behind (which would make the next attempt fail, as the output
    # directory must not exist).
    # Each row is reversed, i.e. the map is mirrored left to right.
    grid = [
        [Char(char) for char in reversed(line)]
        for line in input_map.map
    ]
    elevation_grid = get_elevation_grid(
        x_spacing=input_map.x_spacing,
        y_spacing=input_map.y_spacing,
        thickness=input_map.thickness,
        grid=grid,
        zero_distance=input_map.zero_distance,
        max_depth=input_map.max_depth,
    )

    # Copy the original directory
    shutil.copytree(salamander_module_dir, output_dir, dirs_exist_ok=False)

    # Overwrite the world file and the TCP world file. The templates are
    # read from the module directory; the proto files referenced by the
    # worlds are the ones in the output directory.
    template_dir = salamander_module_dir/"elevation_grid_templates"
    world_dir = output_dir/"world"
    world_variants = [
        ("Salamander.proto", "salamander.wbt"),
        ("Salamandertcp.proto", "salamandertcp.wbt"),
    ]
    for proto_name, world_name in world_variants:
        world_file = get_world_file(
            salamander_proto_path=world_dir/proto_name,
            elevation_grid=elevation_grid,
            salamander_rotation_rad=input_map.salamander_rotation_rad,
            world_template_wbt_path=template_dir/world_name,
        )
        with open(world_dir/world_name, "wt") as fp:
            fp.write(world_file)


def get_example_map_paths() -> list[Path]:
    """Return the paths of the example map JSON files.

    The files are the ``*.json`` files of the "elevation_grid_templates"
    directory in the salamander module directory. They are returned sorted,
    because the order in which a directory is listed is not guaranteed and
    callers that pick a map with a seeded random generator need a stable
    order.
    """
    template_dir = salamander_module_dir/"elevation_grid_templates"
    return sorted(template_dir.glob("*.json"))


def get_map(map_path: Path) -> Map:
    """Load a map from the JSON file at ``map_path``.

    The JSON object must contain one entry for each field of Map; a missing
    entry raises ``KeyError``. Other entries are ignored.
    """
    with open(map_path) as json_file:
        raw_map = json.load(json_file)

    field_names = (
        "map",
        "x_spacing",
        "y_spacing",
        "thickness",
        "zero_distance",
        "max_depth",
        "salamander_rotation_rad",
    )
    return Map(**{name: raw_map[name] for name in field_names})


def _get_random_map_strs(seed: str) -> list[str]:
    """Helper function to generate a bottom-up map."""
    pieces = list()

    pieces.append("LL    ")
    pieces.append("LL  LL")
    pieces.append("    LL")
    pieces.append("  LL  ")
    pieces.append("LLLL  ")
    pieces.append("  LLLL")
    pieces.append("      ")
    pieces.append("LLLLLL")

    start = [
        "      ",
        "  G   ",
    ]
    end = [
        "  S   ",
        "      ",
    ]

    _random = random.Random(seed)
    map_len = _random.randint(5, 10)

    map_strs = list()
    map_strs.extend(start)
    for _ in range(map_len):
        piece = _random.choice(pieces)
        repeat_n = _random.randint(1, 3)
        for _ in range(repeat_n):
            map_strs.append(piece)
    map_strs.extend(end)

    return map_strs


def generate_random_map(seed: str) -> Map:
    """Generate a random bottom-up map.

    Only the text rows depend on the seed; all other map parameters are
    fixed.
    """
    fixed_parameters = dict(
        x_spacing=1.1,
        y_spacing=1.1,
        thickness=0.1,
        zero_distance=1.0,
        max_depth=0.25,
        salamander_rotation_rad=-1.5707,
    )
    return Map(map=_get_random_map_strs(seed), **fixed_parameters)


def get_random_map(seed: str) -> Map:
    """Return a map chosen from the seed.

    With equal probability the map is either one of the example maps or a
    randomly generated one. The same seed always gives the same map.
    """
    # Choose a random map template.
    # If we end up needing more maps, we can easily generate more:
    # - map augmentations (rotate, remove/add/duplicate random rows/cols)
    # - map generation (generate a random grid)
    _random = random.Random(seed)

    coin_flip = _random.choice([True, False])

    if coin_flip:
        # Sorted so that the choice depends only on the seed and not on
        # the order in which the directory happens to be listed
        map_paths = sorted(get_example_map_paths())
        # Keep drawing from the same generator. A second generator created
        # from the same seed would replay the draw that decided the coin
        # flip, which ties the chosen index to it and makes some of the
        # example maps impossible to pick.
        map_path = _random.choice(map_paths)
        chosen_map = get_map(map_path)
    else:
        chosen_map = generate_random_map(str(_random.random()))
    return chosen_map


# Command line entry point: turn a map JSON file into a webots project
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='Salamander map parser',
        description='Turn map files into Webots projects.',
    )
    parser.add_argument(
        '--map_json',
        type=Path,
        required=True,
        help="JSON file with the map.",
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help=(
            "Non-existing output directory which will contain the Webots files."
        ),
    )
    args = parser.parse_args()

    # Parse JSON file with the same helper used for the example maps
    input_map = get_map(args.map_json)

    write_webots_project(
        input_map=input_map,
        output_dir=args.output_dir,
    )
