from typing import List, Tuple, Dict

import numpy as np
import plotly.graph_objects as go
import torch
from omegaconf import OmegaConf
from torch_geometric.data import Data

from ltsgns_mp.recording.visualizations.util.get_plotly_camera_coords import show_figure_in_dash
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import node_type_mask, edge_type_mask, create_radius_edges
from ltsgns_mp.util.own_types import ConfigDict
from ltsgns_mp.util.util import to_numpy


def visualize_trajectory(eval_traj: Data,
                         predicted_traj: torch.Tensor,
                         visualization_config: ConfigDict,
                         ) -> go.Figure:
    """
    Visualize a trajectory of graphs using plotly.
    Args:
        eval_traj: Data object containing the ground truth trajectory, the context and optional point cloud data.
            Has the anchor graph at the current timestep with the edge information
        predicted_traj: Tensor of shape (num_timesteps, num_nodes, dim_world) containing the predicted positions of the
            mesh nodes.
        visualization_config: ConfigDict containing the visualization configuration. Reads num_frames, limits,
            camera, fps, draw_z_floor, manual_camera_debug and output_type.

    Returns: A plotly figure containing the visualization of the trajectory. For 2D trajectories, the z coordinate is
      set to 0; the camera is taken from visualization_config.camera.

    Raises:
        NotImplementedError: If visualization_config.output_type is "gif".
    """
    (predicted_mesh_vertices, gth_mesh_vertices, mesh_faces,
     collider_vertices, collider_faces) = get_meshes(eval_traj, predicted_traj)

    vertex_dict = {}
    if hasattr(eval_traj, keys.POINT_CLOUD):
        # has point cloud information! Copy first so that writing NaNs below cannot modify the tensor stored in
        # eval_traj in case to_numpy returns a view sharing its memory.
        point_cloud_positions = np.array(to_numpy(eval_traj[keys.CONTEXT_POINT_CLOUD_POSITIONS][0]), copy=True)
        # entries not selected by point_cloud_indices become NaN; the scatter trace drops NaN rows when plotting
        point_cloud_positions[~to_numpy(eval_traj.point_cloud_indices[0])] = np.nan
        predicted_mesh_vertices, gth_mesh_vertices, collider_vertices, point_cloud_positions = subsample_meshes(
            (predicted_mesh_vertices, gth_mesh_vertices, collider_vertices, point_cloud_positions),
            num_frames=visualization_config.num_frames)
        vertex_dict[keys.POINT_CLOUD] = point_cloud_positions
    else:
        predicted_mesh_vertices, gth_mesh_vertices, collider_vertices = subsample_meshes(
            (predicted_mesh_vertices, gth_mesh_vertices, collider_vertices),
            num_frames=visualization_config.num_frames)

    vertex_dict[keys.MESH] = predicted_mesh_vertices
    vertex_dict[keys.REFERENCE_MESH] = gth_mesh_vertices
    if collider_vertices is not None:
        vertex_dict[keys.COLLIDER] = collider_vertices

    face_dict = {
        keys.MESH: mesh_faces,
        keys.REFERENCE_MESH: mesh_faces,
    }
    # Small hack to not include Tissue Manipulation collider faces: there, the test collider is not loaded properly.
    # Collider faces are also skipped when there are no collider vertices they could refer to.
    if collider_faces is not None and collider_vertices is not None and collider_vertices.shape[1] > 1:
        face_dict[keys.COLLIDER] = collider_faces

    edge_dict = {}
    if keys.COLLIDER_MESH in eval_traj.edge_type_description:
        # presumably edge_type_mask returns a boolean mask, so this indexing yields a copy and the in-place shift
        # below leaves eval_traj.edge_index untouched
        shifted_collider_idx_edges = eval_traj.edge_index[:, edge_type_mask(eval_traj, keys.COLLIDER_MESH)]
        # move the collider edges to start with index 0 -> subtract number mesh nodes
        shifted_collider_idx_edges[0] -= vertex_dict[keys.MESH].shape[1]
        edge_dict[(keys.COLLIDER, keys.MESH)] = shifted_collider_idx_edges
    # NOTE: point cloud <-> mesh edges are currently not visualized. A radius-graph variant based on
    # create_radius_edges was prototyped here but is disabled.

    plotting_kwargs = {
        keys.MESH: {"color": "black", "facecolor": "orange", "opacity": 0.4, },
        keys.COLLIDER: {"color": "gray", "opacity": 1},
        keys.REFERENCE_MESH: {"color": "turquoise", "opacity": 0.4},
        keys.POINT_CLOUD: {"color": "purple", "opacity": 0.3},
    }
    # make the size of nodes bigger if there is no faces
    if face_dict[keys.MESH].shape[0] == 0:
        plotting_kwargs[keys.MESH]["size"] = 10
    if face_dict[keys.REFERENCE_MESH].shape[0] == 0:
        plotting_kwargs[keys.REFERENCE_MESH]["size"] = 10

    fig = get_figure(vertex_dict=vertex_dict,
                     face_dict=face_dict,
                     edge_dict=edge_dict,
                     plotting_kwargs=plotting_kwargs,
                     limits=visualization_config.limits,
                     camera_config=visualization_config.camera,
                     timesteps=len(predicted_mesh_vertices),
                     frame_duration=get_frame_duration(visualization_config),
                     draw_z_floor=visualization_config.draw_z_floor, )

    if visualization_config.manual_camera_debug:
        show_figure_in_dash(fig)

    if visualization_config.output_type == "gif":
        raise NotImplementedError(
            "Gif output is not implemented yet. Previous approaches were slow as hell and not pretty")
    return fig


def subsample_meshes(meshes, num_frames):
    """
    Subsample every non-None array in ``meshes`` to the same evenly spaced set of frames.
    Args:
        meshes: Sequence of arrays indexed by frame along their first axis, or None. All non-None entries must have
            the same length.
        num_frames: Number of frames to keep. If it is not positive, or not smaller than the number of available
            frames, all frames are kept.

    Returns: List with one entry per input, in input order. Arrays are subsampled, None entries are passed through.

    Raises:
        AssertionError: If every entry is None, or the non-None entries differ in length.
    """
    present_meshes = [mesh for mesh in meshes if mesh is not None]
    assert present_meshes, "All meshes are None"
    # check that all meshes are of the same length
    mesh_lengths = [len(mesh) for mesh in present_meshes]
    assert len(set(mesh_lengths)) == 1, f"Meshes have different lengths: {mesh_lengths}"

    num_available = mesh_lengths[0]
    # Chained comparison on purpose: combining the two tests with `&` would bind tighter than the comparisons.
    if 0 < num_frames < num_available:  # Subsample the trajectory
        frame_indices = np.linspace(0, num_available - 1, num_frames).astype(int)
    else:
        frame_indices = np.arange(num_available)
    return [None if mesh is None else mesh[frame_indices] for mesh in meshes]


def get_frame_duration(visualization_config):
    """
    Convert the configured frame rate into the display duration of a single frame.
    Args:
        visualization_config: Config object providing an ``fps`` attribute.

    Returns: Duration of one frame in milliseconds, truncated to an int.
    """
    milliseconds_per_second = 1000
    frame_duration_ms = milliseconds_per_second / visualization_config.fps
    return int(frame_duration_ms)


def get_meshes(eval_traj: Data, predicted_traj: torch.Tensor) \
        -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray | None, np.ndarray | None]:
    """
    Retrieve the mesh and collider data from the trajectory. The mesh topology is assumed to be constant over time,
    so a single face array is returned for the whole trajectory, while vertex positions are returned per timestep.
    Args:
        eval_traj: Data object forming the trajectory in the standard format of using an anchor index
        predicted_traj: Tensor of shape (num_timesteps, num_nodes, dim_world) containing the predicted positions of the
            mesh nodes.

    Returns: Tuple of (predicted mesh vertices, ground truth mesh vertices, mesh faces, collider vertices,
        collider faces) as numpy arrays. Mesh faces default to an empty (0, 3) array; the collider entries are None
        if eval_traj does not provide them.
    """
    # information about the mesh; without faces, an empty (0, 3) array is used
    mesh_faces = (to_numpy(eval_traj[keys.MESH_FACES])
                  if hasattr(eval_traj, keys.MESH_FACES)
                  else np.zeros([0, 3]))
    # index 0 removes the batch dim
    gth_mesh_vertices = to_numpy(eval_traj[keys.CONTEXT_NODE_POSITIONS][0])
    predicted_mesh_vertices = to_numpy(predicted_traj)

    # For some tasks, visual collider vertex positions are stored externally. They do not affect the simulation
    # of the mesh, but are still nice to visualize.
    collider_vertices = (to_numpy(eval_traj[keys.CONTEXT_COLLIDER_POSITIONS][0])
                         if hasattr(eval_traj, keys.CONTEXT_COLLIDER_POSITIONS)
                         else None)
    collider_faces = (to_numpy(eval_traj[keys.COLLIDER_FACES])
                      if hasattr(eval_traj, keys.COLLIDER_FACES)
                      else None)

    return predicted_mesh_vertices, gth_mesh_vertices, mesh_faces, collider_vertices, collider_faces


def _get_vertices(trajectory, keys: str | List[str]) -> np.ndarray:
    """
    Retrieve the positions of all nodes whose type matches any of the given node types, for every step.
    Note: the parameter ``keys`` shadows the module-level ``keys`` import inside this function.
    Args:
        trajectory: Indexable sequence of Data objects (one per step) with a ``pos`` attribute.
        keys: A single node type or a list of node types to select.

    Returns: Numpy array built from the selected positions of each step, stacked along a new leading step axis.
    """
    node_types = [keys] if isinstance(keys, str) else keys

    def _positions_at(step):
        data = trajectory[step]
        # a node is selected if it matches at least one of the requested types
        combined_mask = torch.stack([node_type_mask(data, node_type) for node_type in node_types]).any(dim=0)
        return to_numpy(data.pos[combined_mask])

    return np.array([_positions_at(step) for step in range(len(trajectory))])


def _get_edges(trajectory: List[Data], key: str) -> List[np.array] | None:
    """
    Retrieve the edges of the given type for every step of the trajectory.
    Args:
        trajectory: List of Data objects forming the trajectory.
        key: The edge type to retrieve. Its presence is checked on the first step only.

    Returns: List with one (2, num_edges) edge index array per step, or None if the first step has no such edge type.
    """
    if key not in trajectory[0].edge_type_description:
        return None

    def _edges_at(data):
        return to_numpy(data.edge_index[:, edge_type_mask(data, key)])

    return [_edges_at(trajectory[step]) for step in range(len(trajectory))]


def get_figure(vertex_dict: Dict[str, np.ndarray],
               face_dict: Dict[str, np.ndarray],
               edge_dict: Dict[Tuple[str, str], List[np.ndarray | None]],
               plotting_kwargs: Dict[str, Dict[str, str | float]],
               limits: ConfigDict,
               camera_config: ConfigDict,
               timesteps: int,
               frame_duration: int = 100,
               draw_z_floor: bool = False) -> go.Figure:
    """
    Create an animated plotly figure with one frame per timestep.
    Args:
        vertex_dict: Dictionary of {key: vertex_positions}, indexed by timestep first. Each entry has shape
            (num_timesteps, num_nodes, {2,3}); rows containing NaN are not drawn. May contain e.g. the predicted
            mesh, the reference mesh, the collider and a point cloud.
        face_dict: Dictionary of {key: faces}, each of shape (num_faces, 3) and independent of the timestep. Keys
            must match vertex_dict. Faces are drawn as outlines plus a Mesh3d surface; keys without an entry here,
            or with zero faces, are drawn as vertices only.
        edge_dict: Dictionary of {(source_key, target_key): edges}. Row 0 of the edges indexes
            vertex_dict[source_key], row 1 indexes vertex_dict[target_key]. Each value is either
            * a single (2, num_edges) array/tensor shared by all timesteps,
            * an array/tensor with a leading timestep dimension, or
            * a list/tuple with one (2, num_edges) entry per timestep; a None entry draws an empty trace.
        plotting_kwargs: Dictionary {key: kwargs} with plotting options per vertex_dict key, e.g. color, facecolor,
            opacity, size. Keys without an entry use the defaults of the trace helpers.
        limits: ConfigDict {xlim: [min, max], ylim: [min, max], zlim: [min, max]} with the axis ranges. xlim and
            ylim also span the optional floor.
        camera_config: ConfigDict describing the plotly scene camera.
        timesteps: Number of timesteps (frames) to visualize. Usually the length of the trajectory.
        frame_duration: Duration of each frame in milliseconds
        draw_z_floor: Whether to draw a floor at z=0

    Returns: Plotly figure
    """
    # Create the 3D figure
    sliders_dict = {
        "active": 0,
        "yanchor": "top",
        "xanchor": "left",
        "currentvalue": {
            "font": {"size": 20},
            "prefix": "Step: ",
            "visible": True,
            "xanchor": "right"
        },
        "transition": {"duration": frame_duration,
                       "easing": "cubic-in-out"},
        "pad": {"b": 10, "t": 50},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": []}

    frames = []
    for timestep in range(timesteps):
        traces = []

        # add the vertices, and maybe the faces of the objects if provided
        for key in vertex_dict:
            title = key.replace('_', ' ').title()
            kwargs = plotting_kwargs.get(key, {})
            timestep_vertices = vertex_dict[key][timestep]
            node_ids = [f'node_{i}'.lower() for i in range(len(timestep_vertices))]
            vertex_trace = _get_scatter_trace(vertex_positions=timestep_vertices,
                                              name=f"{title} Vertices",
                                              text=node_ids,
                                              textposition="top center",
                                              **kwargs
                                              )
            traces.append(vertex_trace)
            if key in face_dict:
                timestep_faces = face_dict[key]
                # check if faces exist
                if timestep_faces.shape[0] != 0:
                    edge_trace = _get_edge_trace_from_faces(vertex_positions=timestep_vertices,
                                                            faces=timestep_faces,
                                                            name=f"{title} Edges",
                                                            width=2,
                                                            **kwargs
                                                            )
                    traces.append(edge_trace)
                    face_trace = _get_face_trace(vertex_positions=timestep_vertices,
                                                 faces=timestep_faces,
                                                 name=f"{title}",
                                                 **kwargs
                                                 )

                    traces.append(face_trace)

        # add edges between different kinds of objects
        for (vertex_key1, vertex_key2), edge_list in edge_dict.items():
            title = f"{vertex_key1} {vertex_key2} Edges"
            edge_trace = _get_edge_trace_from_vertex_relation(in_vertices=vertex_dict[vertex_key1][timestep],
                                                              out_vertices=vertex_dict[vertex_key2][timestep],
                                                              edges=_select_timestep_edges(edge_list, timestep),
                                                              name=title,
                                                              color="darkslategray",
                                                              )
            traces.append(edge_trace)

        if draw_z_floor:
            # add a floor at z=0
            floor_trace = _get_face_trace(vertex_positions=np.array([[limits.xlim[0], limits.ylim[0], 0],
                                                                     [limits.xlim[1], limits.ylim[0], 0],
                                                                     [limits.xlim[1], limits.ylim[1], 0],
                                                                     [limits.xlim[0], limits.ylim[1], 0]]),
                                          faces=np.array([[0, 1, 2],
                                                          [0, 2, 3]]),
                                          name="Floor",
                                          color="black",
                                          opacity=0.5,
                                          )
            traces.append(floor_trace)

        frame_name = f"Frame {timestep}"
        frame = go.Frame(data=traces, name=frame_name)
        frames.append(frame)

        slider_step = _get_slider_step(frame_duration, frame_name, timestep)
        sliders_dict["steps"].append(slider_step)

    fig = _build_figure(frames, sliders_dict, frame_duration=frame_duration,
                        limits=limits, camera_config=camera_config)

    return fig


def _select_timestep_edges(edge_list, timestep: int):
    """
    Pick the edge index of one timestep from an edge_dict value of get_figure.
    Args:
        edge_list: A list/tuple with one entry per timestep, a single 2-dimensional (2, num_edges) array/tensor
            shared by all timesteps, or an array/tensor with a leading timestep dimension.
        timestep: The timestep to select.

    Returns: The (2, num_edges) edges as numpy array, or None if the entry for this timestep is None.
    """
    if isinstance(edge_list, (list, tuple)):
        # one entry per timestep, entries may be None
        timestep_edges = edge_list[timestep]
    elif len(edge_list.shape) == 2:
        # no time step given -> the same edges for every timestep
        timestep_edges = edge_list
    else:
        timestep_edges = edge_list[timestep]

    if timestep_edges is None or isinstance(timestep_edges, np.ndarray):
        return timestep_edges
    return to_numpy(timestep_edges)


def _get_slider_step(frame_duration, frame_name, timestep):
    slider_step = {"args": [
        [frame_name],  # need to have a correspondence here to tell which frames to animate
        {"frame": {"frame_duration": frame_duration,
                   "redraw": True  # must be set to True to update the plot
                   },
         "mode": "immediate",
         "transition": {"frame_duration": frame_duration}}
    ],
        "label": timestep,
        "method": "animate"}
    return slider_step


def _build_figure(frames, sliders_dict, frame_duration: int, limits: ConfigDict, camera_config: ConfigDict):
    """
    Assemble the animated figure from precomputed frames.
    The first frame's traces are shown initially. The scene uses fixed axis ranges, a cube aspect ratio and hidden
    axes, and a "Play" button is placed at the bottom left.
    Args:
        frames: Non-empty list of go.Frame objects; frames[0] provides the initially displayed data.
        sliders_dict: Plotly slider definition whose steps refer to the frame names.
        frame_duration: frame_duration of the animation in milliseconds. Converts to fps as 1000/frame_duration.
        limits: Dictionary of {xlim: [min, max], ylim: [min, max], zlim: [min, max]}
        camera_config: Scene camera configuration, converted to a plain container via OmegaConf.to_container.

    Returns: The go.Figure including its frames.
    """
    def _hidden_axis(axis_range):
        return dict(range=list(axis_range), showgrid=False, showticklabels=False)

    scene = dict(
        xaxis=_hidden_axis(limits.xlim),
        yaxis=_hidden_axis(limits.ylim),
        zaxis=_hidden_axis(limits.zlim),
        aspectmode='cube'  # equal aspect ratio for all axes
    )
    play_options = {"frame": {"duration": frame_duration,
                              "redraw": True  # must be true
                              },
                    "fromcurrent": True,
                    "transition": {"duration": frame_duration}}
    play_button = dict(label="Play", method="animate", args=[None, play_options])
    # x=0, y=0 positions the "Play" button
    play_menu = dict(type="buttons", x=0, y=0, buttons=[play_button])

    layout = go.Layout(scene=scene, sliders=[sliders_dict], updatemenus=[play_menu])
    figure = go.Figure(data=frames[0].data, layout=layout, frames=frames)
    figure.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)
    figure.update_layout(scene_camera=OmegaConf.to_container(camera_config))
    return figure


def _get_scatter_trace(vertex_positions: np.ndarray, showlegend=True, text=None, textposition="top center", name="Vertices", **kwargs):
    """
    Create a 3D scatter trace that draws the given vertices as markers.
    Rows containing NaN (padding / invalid entries) are dropped. 2D positions are drawn in the z=0 plane.
    Args:
        vertex_positions: Array of shape (num_vertices, 2) or (num_vertices, 3).
        showlegend: Whether the trace appears in the legend.
        text: Optional labels: a single string, or one label per row of vertex_positions. Per-row labels are
            filtered together with the dropped NaN rows so they stay aligned with their vertices.
        textposition: Plotly text position of the labels.
        name: Legend name of the trace.
        **kwargs: Plotting options; only "size" (default 2) and "color" (default "black") are used.

    Returns: go.Scatter3d marker trace.
    """
    size = kwargs.get("size", 2)
    color = kwargs.get("color", "black")

    # remove padding if there is any
    valid_rows = ~np.isnan(vertex_positions).any(axis=1)
    vertex_positions = vertex_positions[valid_rows]
    if text is not None and not isinstance(text, str) and len(text) == len(valid_rows):
        # keep per-vertex labels aligned with the remaining vertices
        text = [label for label, keep in zip(text, valid_rows) if keep]

    if vertex_positions.shape[-1] == 2:
        z = np.zeros_like(vertex_positions[:, 0])
    else:
        z = vertex_positions[:, 2]

    scatter_trace = go.Scatter3d(x=vertex_positions[:, 0],
                                 y=vertex_positions[:, 1],
                                 z=z,
                                 mode="markers",  # change to "markers+text" for text
                                 marker={"size": size, "color": color},
                                 name=name,
                                 text=text,
                                 textposition=textposition,
                                 showlegend=showlegend)
    return scatter_trace


def _get_edge_trace_from_faces(vertex_positions, faces, showlegend=True, name="Edges", **kwargs):
    """
    Returns a trace for the edges of the mesh faces.
    Each triangle is drawn as the closed polyline v0 -> v1 -> v2 -> v0, followed by a None separator so that plotly
    does not connect consecutive faces.
    Args:
        vertex_positions: Array of shape (num_vertices, 2) or (num_vertices, 3). 2D positions are drawn at z=0.
        faces: Array of shape (num_faces, 3) with vertex indices. If the smallest index is positive, all indices are
            shifted down so that it becomes 0 (see _rescale_indices).
        showlegend: Whether the trace appears in the legend.
        name: Legend name of the trace.
        **kwargs: Plotting options; only "color" (default "black") and "width" (default 2) are used.

    Returns: go.Scatter3d line trace.
    """
    color = kwargs.get("color", "black")
    width = kwargs.get("width", 2)

    faces = np.asarray(_rescale_indices(index_array=faces)).astype(int)
    num_faces = faces.shape[0]
    # corner order of the closed outline of each face, shape (num_faces, 4)
    outline_corners = faces[:, [0, 1, 2, 0]]

    def _outline_coordinates(per_vertex_values):
        # 5 entries per face: the 4 outline points and a None separator
        coordinates = np.full(shape=(num_faces, 5), fill_value=None, dtype=object)
        coordinates[:, :4] = per_vertex_values[outline_corners]
        return coordinates.reshape(-1)

    # check for 2D or 3D
    if vertex_positions.shape[-1] == 2:
        z_values = np.zeros_like(vertex_positions[:, 0])
    else:
        z_values = vertex_positions[:, 2]

    edge_trace = go.Scatter3d(x=_outline_coordinates(vertex_positions[:, 0]),
                              y=_outline_coordinates(vertex_positions[:, 1]),
                              z=_outline_coordinates(z_values),
                              mode="lines",
                              line=dict(color=color, width=width),
                              name=name,
                              showlegend=showlegend)
    return edge_trace


def _get_edge_trace_from_vertex_relation(in_vertices, out_vertices, edges, name: str, **kwargs):
    """
    Returns a line trace connecting vertices of one vertex set to vertices of another.
    Args:
        in_vertices: Positions of the source vertices, shape (num_in, 2) or (num_in, 3).
        out_vertices: Positions of the target vertices, shape (num_out, 2) or (num_out, 3).
        edges: Index array of shape (2, num_edges); row 0 indexes in_vertices, row 1 indexes out_vertices.
            If None, an empty trace with the same name and style is returned.
        name: Legend name of the trace.
        **kwargs: Plotting options; only "color" (default "black") is used.

    Returns: go.Scatter3d line trace of width 1.
    """
    line_color = kwargs.get("color", "black")

    if edges is None:
        # empty scatter
        x_values, y_values, z_values = [None], [None], [None]
    else:
        num_edges = edges.shape[1]
        sources, targets = edges[0], edges[1]
        # 2D vs 3D is decided by in_vertices only; 2D edges are drawn in the z=0 plane
        is_planar = in_vertices.shape[-1] == 2
        per_axis = []
        for axis in range(3):
            # layout per edge: source point, target point, None separator
            values = np.full(shape=3 * num_edges, fill_value=None)
            if axis == 2 and is_planar:
                values[0::3] = np.zeros_like(in_vertices[sources, 0])
                values[1::3] = np.zeros_like(out_vertices[targets, 0])
            else:
                values[0::3] = in_vertices[sources, axis]
                values[1::3] = out_vertices[targets, axis]
            per_axis.append(values)
        x_values, y_values, z_values = per_axis

    return go.Scatter3d(x=x_values,
                        y=y_values,
                        z=z_values,
                        mode="lines",
                        line=dict(color=line_color, width=1),
                        name=name,
                        showlegend=True)


def _get_face_trace(vertex_positions, faces, showlegend=True, name="Faces",
                    **kwargs):
    """
    Returns a Mesh3d trace filling the given triangular faces.
    Args:
        vertex_positions: Array of shape (num_vertices, 2) or (num_vertices, 3). 2D positions are drawn at z=0.
        faces: Array of shape (num_faces, 3) with vertex indices. If the smallest index is positive, all indices are
            shifted down so that it becomes 0 (see _rescale_indices).
        showlegend: Whether the trace appears in the legend.
        name: Legend name of the trace.
        **kwargs: Plotting options. The fill color is "facecolor", falling back to "color" and then "lightblue";
            "opacity" defaults to 0.5.

    Returns: go.Mesh3d trace.
    """
    fill_color = kwargs.get("facecolor", kwargs.get("color", "lightblue"))
    fill_opacity = kwargs.get("opacity", 0.5)

    is_planar = vertex_positions.shape[-1] == 2
    z_values = np.zeros_like(vertex_positions[:, 0]) if is_planar else vertex_positions[:, 2]
    triangles = _rescale_indices(index_array=faces)

    return go.Mesh3d(
        x=vertex_positions[:, 0],
        y=vertex_positions[:, 1],
        z=z_values,
        i=triangles[:, 0],
        j=triangles[:, 1],
        k=triangles[:, 2],
        color=fill_color,
        opacity=fill_opacity,
        showlegend=showlegend,
        name=name
    )


def _rescale_indices(index_array: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
    """
    Adjusts the indices of the index array such that the smallest index is 1.
    Indices will be in range to be in the range [0, index_entities - 1].
    Args:
        index_array: An array of integer indices of arbitrary shape

    Returns:

    """
    if np.prod(index_array.shape) == 0:  # works for np and torch
        # empty index :(
        return index_array
    # check that the indices of the faces are not shifted
    min_index = index_array.min()
    if min_index > 0:
        index_array = index_array - min_index
    return index_array


def plotly_figure_to_images(fig, image_format='png'):
    """
    Debug helper: write the figure to ``test.png`` in the current working directory.

    NOTE(review): despite its name, this does not return images. ``image_format`` is currently unused and the
    function returns None; it prints a message before and after writing.

    Args:
        fig: Plotly figure object (anything with a ``write_image(path)`` method).
        image_format: Intended image format (e.g. 'png', 'jpeg', 'svg'); currently ignored.

    Returns:
        None
    """
    target_path = "test.png"
    print("writing image")
    fig.write_image(target_path)
    print("finished writing image")
