import warnings
from typing import List, Optional, Dict

import torch
import torch_cluster
from torch_geometric import transforms
from torch_geometric.data import Batch, Data
from torch_geometric.utils import to_undirected, coalesce

from ltsgns_mp.envs.util.data_loader_util import get_one_hot_features_and_types
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict, ValueDict


def node_type_mask(graph_or_batch: Batch | Data, key: str, as_index: bool = False) -> torch.Tensor:
    """
    Select all nodes of one node type.

    Args:
        graph_or_batch: Graph or batch whose `node_type` entries are compared against the id of `key`.
          For a Batch, the id is looked up in the description of the first graph (`node_type_description[0]`),
          for a single Data object in `node_type_description` itself.
        key: Name of the node type. Its position in the node type description is used as the type id.
        as_index: If True, return the indices of the selected nodes instead of the mask itself

    Returns: A boolean mask over the nodes, or the indices at which this mask is True if `as_index` is set.

    Raises:
        ValueError: If `graph_or_batch` is neither a Batch nor a Data object.

    """
    # Batch has to be checked first, the order of the checks is kept from the original implementation
    if isinstance(graph_or_batch, Batch):
        description = graph_or_batch.node_type_description[0]
    elif isinstance(graph_or_batch, Data):
        description = graph_or_batch.node_type_description
    else:
        raise ValueError(f"Unknown type of batch: {type(graph_or_batch)}")

    mask: torch.Tensor = graph_or_batch.node_type == description.index(key)
    return torch.where(mask)[0] if as_index else mask


def edge_type_mask(graph_or_batch: Batch | Data, key: str) -> torch.Tensor:
    """
    Select all edges of one edge type.

    Args:
        graph_or_batch: Graph or batch whose `edge_type` entries are compared against the id of `key`.
          For a Batch, the id is looked up in the description of the first graph (`edge_type_description[0]`).
        key: Name of the edge type. Its position in the edge type description is used as the type id.

    Returns: A boolean mask over the edges.

    Raises:
        ValueError: If `graph_or_batch` is neither a Batch nor a Data object.

    """
    if isinstance(graph_or_batch, Batch):
        description = graph_or_batch.edge_type_description[0]
    elif isinstance(graph_or_batch, Data):
        description = graph_or_batch.edge_type_description
    else:
        raise ValueError(f"Unknown type of batch: {type(graph_or_batch)}")
    mask: torch.Tensor = graph_or_batch.edge_type == description.index(key)
    return mask


def unpack_node_features(graph: Batch, node_type: str = keys.MESH) -> torch.Tensor:
    """
    Extract the features of all nodes of a given type from a homogeneous graph.
    :param graph: Graph whose node features `x` are filtered
    :param node_type: Name of the node type to keep. Defaults to the mesh nodes
    :return: The rows of `graph.x` that belong to nodes of type `node_type`
    """
    selected_nodes = node_type_mask(graph, node_type)
    return graph.x[selected_nodes]


def add_distances_from_positions(data_or_batch: Batch | Data, add_euclidian_distance: bool) -> Batch | Data:
    """
    Append the relative node positions (and optionally their Euclidean norm) of every edge to the edge features.
    If the graph carries an `edge_type_description`, the names of the new features are appended to it in place.
    :param data_or_batch: Graph or batch with node positions `pos` and an `edge_index`
    :param add_euclidian_distance: If True, the Euclidean distance is appended after the per-axis distances
    :return: The transformed graph. A graph without edges is returned unchanged
    """
    edge_index = data_or_batch.edge_index
    if edge_index is None or edge_index.shape[1] == 0:
        # there are no edges, so there is nothing to attach the features to
        return data_or_batch

    if hasattr(data_or_batch, "edge_type_description"):
        # names of the features that the transforms below are going to append, in the same order
        new_feature_names = ["x_distance", "y_distance"]
        if data_or_batch.pos.shape[1] == 3:
            new_feature_names.append("z_distance")
        if add_euclidian_distance:
            new_feature_names.append("euclidian_distance")

        if isinstance(data_or_batch, Batch):
            # a batch stores one description list per graph
            descriptions = data_or_batch.edge_type_description
        else:
            descriptions = [data_or_batch.edge_type_description]
        for description in descriptions:
            description.extend(new_feature_names)

    # cat=True appends to already existing edge features, norm=False keeps the raw (unnormalized) values
    cartesian = transforms.Cartesian(norm=False, cat=True)
    if add_euclidian_distance:
        data_transform = transforms.Compose([cartesian, transforms.Distance(norm=False, cat=True)])
    else:
        data_transform = cartesian
    return data_transform(data_or_batch)


def add_gaussian_noise(batch: Batch, node_type_description: str, sigma: float,
                       device: torch.device | str) -> Batch:
    """
    Perturb the positions of all nodes of one type with zero-mean gaussian noise. The batch is modified in place.
    Args:
        batch: Batch whose `pos` entries are perturbed
        node_type_description: Name of the node type that receives the noise. It is looked up in the node type
          description of the first graph of the batch
        sigma: Standard deviation of the noise
        device: Device the sampled noise is moved to before it is added

    Returns: The batch with the added noise on the nodes.

    """
    type_id = batch.node_type_description[0].index(node_type_description)
    selected = torch.where(batch.node_type == type_id)[0]
    # one noise vector per selected node, with as many entries as there are position dimensions
    noise_shape = (selected.shape[0], batch.pos.shape[1])
    noise = torch.randn(*noise_shape).to(device) * sigma
    batch.pos[selected] += noise
    return batch

def add_history_gaussian_noise(batch: Batch, node_type_description: str, sigma: float,
                       device: torch.device | str) -> Batch:
    """
    Perturb the position history and the current position of all nodes of one type with a gaussian random walk.
    The batch is modified in place.
    Args:
        batch: Batch with current positions `pos` and past positions `history_pos`. The second dimension of
          `history_pos` is treated as the history axis
        node_type_description: Name of the node type that receives the noise. It is looked up in the node type
          description of the first graph of the batch
        sigma: Standard deviation of the accumulated noise on the current position
        device: Device the sampled noise is moved to before it is added

    Returns: The batch with the added noise on history and current positions.

    """
    type_id = batch.node_type_description[0].index(node_type_description)
    selected = torch.where(batch.node_type == type_id)[0]
    position_dim = batch.pos.shape[1]
    # the history steps plus the current position
    num_steps = batch["history_pos"].shape[1] + 1
    # scale the per-step noise such that the sum over all steps has a standard deviation of sigma
    step_sigma = sigma / torch.sqrt(torch.tensor(num_steps, dtype=torch.float32))
    step_noise = torch.randn(selected.shape[0], num_steps, position_dim).to(device) * step_sigma
    # accumulating the independent steps along the history axis yields a random walk
    walk = step_noise.cumsum(dim=1)
    # the last step of the walk belongs to the current position, everything before to the history
    batch["history_pos"][selected] += walk[:, :-1, :]
    batch.pos[selected] += walk[:, -1, :]
    return batch

def add_history_vel_feature(batch: Batch):
    """
    Append the velocities between consecutive history positions to the node features `x` of the batch.
    A velocity is the difference between a position and the position one history step before it, the newest
    velocity being `pos - history_pos[:, -1]`. The batch is modified in place.
    Args:
        batch: Batch with current positions `pos`, past positions `history_pos` and node features `x`

    Returns: The batch with the flattened velocities appended to `x` and "history_velocities" appended to every
      list in `x_description`. If the history is empty, the batch is returned unchanged.

    """
    newer_pos = batch.pos
    velocities = []
    # walk backwards through the history, starting with the step right before the current position
    for step in reversed(range(batch["history_pos"].shape[1])):
        older_pos = batch["history_pos"][:, step, :]
        velocities.append(newer_pos - older_pos)
        newer_pos = older_pos
    if not velocities:
        return batch

    # collected from newest to oldest, but stored from oldest to newest
    velocities.reverse()
    # [num_nodes, history_length, world_dim] -> [num_nodes, history_length * world_dim]
    stacked = torch.stack(velocities, dim=1)
    flat_velocities = stacked.view(stacked.shape[0], -1)
    batch.x = torch.cat((batch.x, flat_velocities), dim=1)

    num_new_features = flat_velocities.shape[1]
    for description in batch.x_description:
        description.extend(["history_velocities"] * num_new_features)
    return batch

def add_label(batch: Batch, second_order_dynamics: bool) -> Batch:
    """
    Compute the training label for the mesh nodes and store it in `batch["y"]`.
    Args:
        batch: Batch containing the next mesh positions and the (possibly noisy) current positions of all nodes.
          For second order dynamics, the noise-free current and previous mesh positions are read as well
        second_order_dynamics: If True, the label is a weighted sum (0.1 / 0.9) of a position-based and a
          velocity-based target. Otherwise, it is the difference between next and current mesh positions

    Returns: The batch with the label added.

    """
    noisy_pos = batch[keys.POSITIONS][node_type_mask(batch, keys.MESH)]
    next_pos = batch[keys.NEXT_MESH_POS]

    if not second_order_dynamics:
        # true x_{t+1} - noisy x_t
        batch["y"] = next_pos - noisy_pos
        return batch

    prev_pos = batch[keys.PREV_MESH_POS]
    # true x_{t+1} - true x_t - noisy x_t + true x_{t-1}
    vel_label = next_pos - batch[keys.CURRENT_MESH_POS] - noisy_pos + prev_pos
    # true x_{t+1} - 2*noisy x_t + true x_{t-1}
    pos_label = next_pos - 2 * noisy_pos + prev_pos
    batch["y"] = 0.1 * pos_label + 0.9 * vel_label
    return batch

def add_and_update_node_features(batch_or_graph: Batch | Data, second_order_dynamics: bool) -> Batch | Data:
    """
    Synchronize the position features of the mesh nodes with their current positions and, for second order
    dynamics, append a velocity feature to all nodes. The input is modified in place.
    Args:
        batch_or_graph: Graph or batch with node features `x` and their names in `x_description`
        second_order_dynamics: If True, the difference between the current and the previous mesh positions is
          appended as velocity feature. Nodes that are not mesh nodes get a velocity of zero

    Returns: The updated graph or batch.

    """
    if isinstance(batch_or_graph, Batch):
        # a batch stores one description per graph, the first one is used for the lookup
        feature_names = batch_or_graph.x_description[0]
    else:
        feature_names = batch_or_graph.x_description
    mesh_mask = node_type_mask(batch_or_graph, keys.MESH)

    # overwrite every mesh position feature that is present with the matching coordinate of the current positions
    for axis, suffix in enumerate(("_x_pos", "_y_pos", "_z_pos")):
        feature_name = keys.MESH + suffix
        if feature_name in feature_names:
            column = feature_names.index(feature_name)
            batch_or_graph.x[mesh_mask, column] = batch_or_graph[keys.POSITIONS][mesh_mask][:, axis]

    if not second_order_dynamics:
        # velocity features are only used for second order dynamics
        return batch_or_graph

    # velocity of the mesh nodes, computed from the (possibly noisy) current positions
    mesh_velocities = batch_or_graph[keys.POSITIONS][mesh_mask] - batch_or_graph[keys.PREV_MESH_POS]
    velocities = torch.zeros_like(batch_or_graph[keys.POSITIONS])
    velocities[mesh_mask] = mesh_velocities
    batch_or_graph.x = torch.cat((batch_or_graph.x, velocities), dim=1)

    velocity_names = [keys.VELOCITIES] * velocities.shape[1]
    if isinstance(batch_or_graph, Batch):
        for description in batch_or_graph.x_description:
            description.extend(velocity_names)
    else:
        batch_or_graph.x_description.extend(velocity_names)
    return batch_or_graph


def remove_edge_distances(batch: Batch) -> Batch | Data:
    """
    Remove the distance features of the old edges if there are any. The distances are found via
    edge_type_description, which is a list of (identical) lists of strings over the graphs in the batch.
    Only the description of the first graph is inspected, and an entry at position i of this description is
    taken to describe column i of `edge_attr`.
    Args:
        batch: Batch whose `edge_attr` is filtered in place

    Returns: The batch without the distance columns in `edge_attr`. A batch without edges is returned unchanged.

    """
    # check if batch has edges at all
    if batch.edge_index is None or batch.edge_index.shape[1] == 0:
        return batch

    distance_names = ("x_distance", "y_distance", "z_distance", "euclidian_distance")
    distance_columns = {column for column, name in enumerate(batch.edge_type_description[0])
                        if name in distance_names}
    if distance_columns:
        # Work with plain integers and an explicit long dtype. An empty selection would otherwise be created as
        # a float tensor, which can not be used for indexing.
        columns_to_keep = [column for column in range(batch.edge_attr.shape[1])
                           if column not in distance_columns]
        indices_to_keep = torch.tensor(columns_to_keep, dtype=torch.long, device=batch.edge_attr.device)
        batch.edge_attr = batch.edge_attr[:, indices_to_keep]
        # NOTE(review): edge_type_description is intentionally left untouched here, so it still lists the
        # removed distance names afterwards - confirm that all callers expect this.
    return batch


def build_edges(edge_index_dict: Dict[str, torch.Tensor]):
    """
    Merge the edges of all edge types into a single set of edges.

    Args:
        edge_index_dict: Dictionary of edge indices. Keys are edge types, values are edge indices shaped (2, num_edges)

    Returns: A tuple of edge_attr, edge_index, edge_type, edge_type_description.
      edge_attr and edge_type are computed by `get_one_hot_features_and_types` from the number of edges per type,
      edge_index is the concatenation of all edge indices in the order of the dictionary, cast to long, and
      edge_type_description is the list of the dictionary keys in the same order.

    """
    edges_per_type = [edges.shape[1] for edges in edge_index_dict.values()]
    edge_attr, edge_type = get_one_hot_features_and_types(edges_per_type)
    edge_index = torch.cat(tuple(edge_index_dict.values()), dim=1).long()
    return edge_attr, edge_index, edge_type, list(edge_index_dict.keys())


def create_knn_edges(k: int, source_nodes: torch.Tensor, source_shift: int, target_nodes: Optional[torch.Tensor] = None,
                     target_shift: int = 0, max_collider_mesh_k_radius=None) -> torch.Tensor:
    """
    Create edges between source and target nodes based on a k-nearest neighbor graph.
    :param k: number of neighbors
    :param source_nodes: Source nodes to create edges from
    :param source_shift: Shift indices of source nodes by this amount
    :param target_nodes: Target nodes to create edges to. If None, source and target nodes are assumed to be the same
    :param target_shift: Shift indices of target nodes by this amount. Defaults to 0
    :param max_collider_mesh_k_radius: If given, edges between nodes that are at least this far apart are dropped.
        Only has an effect if target_nodes are provided
    :return:  Edge indices between source and target nodes as a torch tensor of shape (2, num_edges)
    """
    if target_nodes is None:
        # only source nodes provided, so the neighbors are searched within the source nodes themselves
        knn_edges = torch_cluster.knn_graph(source_nodes, k=k)
        # append the reverse direction of every edge
        edges = torch.cat((knn_edges, knn_edges.flip(0)), dim=1)
        # both rows index the same node type, so both are shifted by the same offset
        edges += source_shift
        return edges

    # search the neighbors of every source node among the target nodes.
    # The first row indexes the source nodes, the second row the target nodes
    edges = torch_cluster.knn(x=target_nodes, y=source_nodes, k=k)
    edges[0, :] += source_shift
    edges[1, :] += target_shift

    if max_collider_mesh_k_radius is not None:
        # undo the shifts to look up the positions of both end points of every edge
        source_positions = source_nodes[edges[0] - source_shift]
        target_positions = target_nodes[edges[1] - target_shift]
        edge_lengths = torch.norm(source_positions - target_positions, dim=1)
        edges = edges[:, edge_lengths < max_collider_mesh_k_radius]
    return edges


def create_radius_edges(radius: float, source_nodes: torch.Tensor, source_shift: int,
                        target_nodes: Optional[torch.Tensor] = None, target_shift: int = 0,
                        max_num_neighbors: int = 100) -> torch.Tensor:
    """
    Create edges between all source and target nodes that lie within a given radius of each other.
    Args:
        radius: Radius to create edges within
        source_nodes: Source nodes to create edges from
        source_shift: Shift indices of source nodes by this amount
        target_nodes: Target nodes to create edges to. If None, source and target nodes are assumed to be the same
        target_shift: Shift indices of target nodes by this amount. Defaults to 0
        max_num_neighbors: Maximum number of neighbors/edges to consider per node. Defaults to 100. A very high radius
          combined with a low max_num_neighbors essentially corresponds to a k-nearest neighbor graph.

    Returns: Edge indices between source and target nodes as a torch tensor of shape (2, num_edges)

    """
    if target_nodes is None:
        # only source nodes provided, so the neighbors are searched within the source nodes themselves
        radius_edges = torch_cluster.radius_graph(source_nodes, r=radius,
                                                  max_num_neighbors=max_num_neighbors)
        # append the reverse direction of every edge
        edges = torch.cat((radius_edges, radius_edges.flip(0)), dim=1)
        # both rows index the same node type, so both are shifted by the same offset
        edges += source_shift
        return edges

    # search the neighbors of every source node among the target nodes.
    # The first row indexes the source nodes, the second row the target nodes
    edges = torch_cluster.radius(x=target_nodes,
                                 y=source_nodes,
                                 r=radius,
                                 max_num_neighbors=max_num_neighbors)
    edges[0, :] += source_shift
    edges[1, :] += target_shift
    return edges


def build_edges_from_data_dict(data_dict: ValueDict, pos_keys: List, num_nodes: List,
                               connectivity_setting: ConfigDict,
                               use_canonic_mesh_positions: bool) -> (torch.Tensor,
                                                                     torch.Tensor,
                                                                     torch.Tensor,
                                                                     List[str]):
    """
    Build all edges and their features for a single timestep.
    :param data_dict: Dict containing all the data for a single timestep in torch tensor format.
    :param pos_keys: List of keys for the positions of the nodes
    :param num_nodes: List of numbers of nodes per type
    :param connectivity_setting: ConfigDict with the connectivity settings
    :param use_canonic_mesh_positions: bool: if the distances between the initial mesh positions (including their
        Euclidean norm) should be appended to the edge features
    :return: edge_attr, edge_type, edge_index, edge_type_description
    """
    edge_index_dict = _get_edge_indices(data_dict, pos_keys,
                                        num_nodes=num_nodes, connectivity_setting=connectivity_setting)

    # one-hot encoding of the edge types as initial edge features
    edge_attr, edge_index, edge_type, edge_type_description = build_edges(edge_index_dict)

    if not use_canonic_mesh_positions:
        return edge_attr, edge_type, edge_index, edge_type_description

    # add mesh_coordinates to mesh edges. This essentially encodes the distances between nodes in mesh space, meaning
    # that we do not need to differentiate between mesh and world edges in the model
    edge_attr = _add_canonic_mesh_positions(edge_index_dict, edge_attr, edge_type,
                                            input_mesh_edge_index=data_dict[keys.MESH_EDGE_INDEX],
                                            initial_mesh_positions=data_dict[keys.INITIAL_MESH_POSITIONS],
                                            include_euclidean_distance=True)

    # names of the appended features, in the order in which they were appended
    canonic_names = ["canonic_x_distance", "canonic_y_distance"]
    if data_dict["mesh"].shape[1] == 3:
        canonic_names.append("canonic_z_distance")
    canonic_names.append("canonic_euclidean_distance")
    edge_type_description.extend(canonic_names)

    return edge_attr, edge_type, edge_index, edge_type_description


def _remove_duplicate_edges(candidate_edges: torch.Tensor, reference_edges: torch.Tensor) -> torch.Tensor:
    """
    Drop every edge of `candidate_edges` that also occurs in `reference_edges`.
    Args:
        candidate_edges: Edges to filter, one edge per row, shaped (num_candidates, 2)
        reference_edges: Edges to compare against, one edge per row, shaped (num_references, 2)

    Returns: The rows of `candidate_edges` that do not appear in `reference_edges`, in their original order

    """
    # compare every candidate with every reference edge via broadcasting. An edge is a duplicate if both of its
    # entries match those of some reference edge
    is_duplicate = torch.all(candidate_edges[:, None, :] == reference_edges, dim=2).any(dim=1)
    # select with the inverted mask directly, so the mask always lives on the same device as the edges
    return candidate_edges[~is_duplicate]


def _get_edge_indices(data_dict, pos_keys, num_nodes, connectivity_setting):
    """
    Create the edge indices of all edge types for a single timestep.
    Args:
        data_dict: Dict containing the node positions per node type and the mesh (and optionally collider)
          edge indices
        pos_keys: List of keys for the positions of the nodes. The order determines the index offset per node type
        num_nodes: List of numbers of nodes per type, in the same order as pos_keys
        connectivity_setting: ConfigDict with the connectivity settings

    Returns: Dictionary that maps the name of every created edge type to its edge indices shaped (2, num_edges).
      The indices refer to the nodes of all types stacked in the order of pos_keys.

    """
    # save the offset for each node type. This is used to shift the edge indices to the correct node type
    index_shift_dict = {pos_key: sum(num_nodes[0:i]) for i, pos_key in enumerate(pos_keys)}
    edge_index_dict = {}

    # mesh edges: read from mesh_edge_index. Reverse edges for undirected graph
    mesh_edges = torch.cat((data_dict[keys.MESH_EDGE_INDEX], data_dict[keys.MESH_EDGE_INDEX][[1, 0]]), dim=1)
    # remove duplicates, this happens if the mesh already included reverse edges
    mesh_edges = to_undirected(mesh_edges)
    mesh_edges += index_shift_dict[keys.MESH]
    edge_index_dict[keys.MESH_MESH] = mesh_edges

    if "mesh_mesh_k" in connectivity_setting and connectivity_setting.mesh_mesh_k > 0:
        # additional edges between mesh nodes that are close to each other in world space
        mesh_mesh_edges = create_knn_edges(k=connectivity_setting.mesh_mesh_k,
                                           source_nodes=data_dict[keys.MESH],
                                           source_shift=index_shift_dict[keys.MESH])
        # remove all edges that are already in the real mesh edges
        mesh_mesh_edges = _remove_duplicate_edges(mesh_mesh_edges.T, edge_index_dict[keys.MESH_MESH].T).T
        edge_index_dict[keys.WORLD_MESH] = mesh_mesh_edges

    if keys.COLLIDER in data_dict:
        if keys.COLLIDER_EDGE_INDEX in data_dict:  # add collider-collider edges: read from collider_edge_index
            collider_edges = torch.cat((data_dict[keys.COLLIDER_EDGE_INDEX],
                                        data_dict[keys.COLLIDER_EDGE_INDEX][[1, 0]]),
                                       dim=1)
            collider_edges += index_shift_dict[keys.COLLIDER]
            edge_index_dict[keys.COLLIDER_COLLIDER] = collider_edges

        collider_mesh_edges = get_collider_mesh_edges(collider_nodes=data_dict[keys.COLLIDER],
                                                      mesh_nodes=data_dict[keys.MESH],
                                                      index_shift_dict=index_shift_dict,
                                                      connectivity_setting=connectivity_setting)
        edge_index_dict[keys.COLLIDER_MESH] = collider_mesh_edges
        edge_index_dict[keys.MESH_COLLIDER] = torch.flip(collider_mesh_edges, dims=(0,))  # reverse edges

    if keys.POINT_CLOUD in pos_keys:
        # point_cloud point_cloud edges
        if connectivity_setting.point_cloud_point_cloud_radius > 0:
            point_cloud_point_cloud_edges = create_radius_edges(
                radius=connectivity_setting.point_cloud_point_cloud_radius,
                source_nodes=data_dict[keys.POINT_CLOUD],
                source_shift=index_shift_dict[keys.POINT_CLOUD])
            edge_index_dict[keys.POINT_CLOUD_POINT_CLOUD] = point_cloud_point_cloud_edges

        # point_cloud mesh edges
        if connectivity_setting.point_cloud_mesh_radius > 0:
            point_cloud_mesh_edges = create_radius_edges(radius=connectivity_setting.point_cloud_mesh_radius,
                                                         source_nodes=data_dict[keys.POINT_CLOUD],
                                                         source_shift=index_shift_dict[keys.POINT_CLOUD],
                                                         target_nodes=data_dict[keys.MESH],
                                                         target_shift=index_shift_dict[keys.MESH])
            edge_index_dict[keys.POINT_CLOUD_MESH] = point_cloud_mesh_edges
            edge_index_dict[keys.MESH_POINT_CLOUD] = torch.flip(point_cloud_mesh_edges, dims=(0,))  # reverse edges
    return edge_index_dict


def _add_canonic_mesh_positions(edge_index_dict: ValueDict, edge_attr: torch.Tensor, edge_type: torch.Tensor,
                                input_mesh_edge_index: torch.Tensor,
                                initial_mesh_positions: torch.Tensor,
                                include_euclidean_distance: bool = True) -> torch.Tensor:
    """
    Appends the distances between the initial mesh positions to the features of the mesh edges (in contrast to the
    world edges). All other edges get zeros for these features.
    Refer to MGN by Pfaff et al. 2020 for more details.
    Args:
        edge_index_dict: Dictionary containing all different edges. Used to find out which are the mesh-mesh edges.
        edge_attr: Current edge features
        edge_type: Tensor containing the edges types
        input_mesh_edge_index: Mesh edge index tensor
        initial_mesh_positions: Initial positions of the mesh nodes "mesh coordinates"
        include_euclidean_distance: If true, the Euclidean distance is added to the mesh positions

    Returns:
        edge_attr: updated edge features
    """
    # the id of an edge type is its position in the dictionary
    mesh_type_id = list(edge_index_dict.keys()).index(keys.MESH_MESH)
    mesh_edge_rows = torch.where(edge_type == mesh_type_id)[0]

    # rebuild the undirected mesh edges (without index shift) in the same way as for the mesh-mesh edges.
    # to_undirected also removes the duplicates that occur if the mesh already included reverse edges
    reversed_edges = input_mesh_edge_index[[1, 0]]
    undirected_mesh_edges = to_undirected(torch.cat((input_mesh_edge_index, reversed_edges), dim=1).long())

    # compute the distances on a temporary graph that only consists of the initial mesh
    canonic_graph = add_distances_from_positions(data_or_batch=Data(pos=initial_mesh_positions,
                                                                    edge_index=undirected_mesh_edges),
                                                 add_euclidian_distance=include_euclidean_distance)
    canonic_distances = canonic_graph.edge_attr

    # all edges that are not mesh-mesh edges keep a distance of 0
    padded_distances = torch.zeros(edge_attr.shape[0], canonic_distances.shape[1])
    padded_distances[mesh_edge_rows, :] = canonic_distances
    return torch.cat((edge_attr, padded_distances), dim=1)


def remove_duplicates_with_mesh_edges(mesh_edges: torch.Tensor, world_edges: torch.Tensor) -> torch.Tensor:
    """
    Removes all world edges (created using a nearest neighbor search) that already exist as mesh edges.
    (only MGN)
    Every edge is encoded as a single integer key, such that the comparison does not need a dense adjacency matrix
    and world edges that are listed more than once are removed reliably as well.
    Args:
        mesh_edges: edge list of the mesh edges, shaped (2, num_mesh_edges)
        world_edges: edge list of the world edges, shaped (2, num_world_edges)

    Returns:
        new_world_edges: updated world edges without duplicates as a long tensor of shape (2, num_edges).
          Every remaining edge is contained exactly once, and the edges are sorted by their first and then by their
          second entry
    """
    if world_edges.shape[1] == 0:
        return torch.empty((2, 0), dtype=torch.long, device=world_edges.device)

    world_edges = world_edges.long()
    mesh_edges = mesh_edges.long()

    # the number of nodes is only needed to build unique keys, so the largest index that occurs is sufficient
    num_nodes = int(world_edges.max()) + 1
    if mesh_edges.shape[1] > 0:
        num_nodes = max(num_nodes, int(mesh_edges.max()) + 1)

    # key = row * num_nodes + column, i.e., the position of the edge in a flattened adjacency matrix
    world_keys = world_edges[0] * num_nodes + world_edges[1]
    mesh_keys = mesh_edges[0] * num_nodes + mesh_edges[1]

    # unique() also sorts the keys, which corresponds to a row-major order of the edges
    remaining_keys = torch.unique(world_keys[~torch.isin(world_keys, mesh_keys)])
    new_world_edges = torch.stack((torch.div(remaining_keys, num_nodes, rounding_mode="floor"),
                                   remaining_keys % num_nodes),
                                  dim=0)
    return new_world_edges


def recompute_external_edges(graph: Data, env_config: ConfigDict, device: torch.device | str) -> Data:
    """
    Recompute edges from data_dict, try to mimic the function from the env to get the edges in the first place.
    These edges include everything that is not mesh-mesh or collider-collider edges. This includes
    * mesh-collider,
    * collider-mesh,
    * world-mesh edges,
    * point_cloud-mesh,
    * mesh-point_cloud, and
    * point_cloud-point_cloud edges.
    An edge type is only recomputed if it is listed in the edge_type_description of the graph. The mesh-mesh and
    collider-collider edges are taken over from the graph. The graph is modified in place.
    Args:
        graph: One timestep of the simulation
        env_config: ConfigDict containing the connectivity setting
        device: Device to compute on
    Returns: The graph with updated edge_attr, edge_type, edge_index and edge_type_description. If no connectivity
      setting is found, a warning is issued and the graph is returned unchanged.

    """
    try:
        connectivity_setting = env_config.connectivity_setting
    except ValueError:
        warnings.warn("No connectivity setting found. Skipping Edge recomputation.")
        return graph

    mesh_nodes = graph.pos[node_type_mask(graph, keys.MESH)]
    if keys.COLLIDER in graph.node_type_description:
        collider_nodes = graph.pos[node_type_mask(graph, keys.COLLIDER)]
    else:
        collider_nodes = None
    mesh_edge_mask = edge_type_mask(graph, keys.MESH_MESH)
    mesh_edges = graph.edge_index[:, mesh_edge_mask]

    # compute offsets for the small radius graph computations. The mesh nodes come first, followed by the collider
    index_shift_dict = {keys.MESH: 0,
                        keys.COLLIDER: mesh_nodes.shape[0]
                        }
    # NOTE(review): the edge type ids are given by the insertion order of this dictionary. Here, the world-mesh
    # edges are inserted after the collider edges, while _get_edge_indices inserts them right after the mesh-mesh
    # edges - confirm that this difference is intended for graphs that contain both.
    edge_index_dict = {keys.MESH_MESH: mesh_edges}

    if keys.COLLIDER_COLLIDER in graph.edge_type_description:
        collider_edge_mask = edge_type_mask(graph, keys.COLLIDER_COLLIDER)
        edge_index_dict[keys.COLLIDER_COLLIDER] = graph.edge_index[:, collider_edge_mask]

    if keys.MESH_COLLIDER in graph.edge_type_description:  # collider-mesh edges and mesh-collider edges
        collider_mesh_edges = get_collider_mesh_edges(collider_nodes, mesh_nodes, index_shift_dict,
                                                      connectivity_setting)

        edge_index_dict[keys.COLLIDER_MESH] = collider_mesh_edges
        edge_index_dict[keys.MESH_COLLIDER] = torch.flip(collider_mesh_edges, dims=(0,))  # reverse edges

    if keys.WORLD_MESH in graph.edge_type_description:  # mesh-mesh edges in world space
        world_mesh_edges = create_radius_edges(radius=connectivity_setting.world_mesh_radius,
                                               source_nodes=mesh_nodes,
                                               source_shift=index_shift_dict[keys.MESH])
        # create_radius_edges appends the reverse of every edge, so edges may be listed more than once.
        # Reduce them to unique edges first, such that the comparison with the mesh edges is well-defined
        world_mesh_edges = coalesce(world_mesh_edges, num_nodes=mesh_nodes.shape[0])
        # keep the result: the world edges must not contain edges that already exist in the mesh
        world_mesh_edges = remove_duplicates_with_mesh_edges(edge_index_dict[keys.MESH_MESH], world_mesh_edges)
        # we do not need to add the reverse edges, as the edges in the mesh are undirected
        edge_index_dict[keys.WORLD_MESH] = world_mesh_edges

    if keys.POINT_CLOUD in graph.node_type_description:
        point_cloud_nodes = graph.pos[node_type_mask(graph, keys.POINT_CLOUD)]
        # the point cloud nodes come after the mesh and (if present) the collider nodes
        num_collider_nodes = 0 if collider_nodes is None else collider_nodes.shape[0]
        index_shift_dict[keys.POINT_CLOUD] = mesh_nodes.shape[0] + num_collider_nodes

        if keys.POINT_CLOUD_POINT_CLOUD in graph.edge_type_description:  # point_cloud-point_cloud edges
            point_cloud_point_cloud_edges = create_radius_edges(
                radius=connectivity_setting.point_cloud_point_cloud_radius,
                source_nodes=point_cloud_nodes,
                source_shift=index_shift_dict[keys.POINT_CLOUD])
            edge_index_dict[keys.POINT_CLOUD_POINT_CLOUD] = point_cloud_point_cloud_edges

        if keys.POINT_CLOUD_MESH in graph.edge_type_description:  # point_cloud-mesh edges (and mesh-point_cloud edges)
            point_cloud_mesh_edges = create_radius_edges(radius=connectivity_setting.point_cloud_mesh_radius,
                                                         source_nodes=point_cloud_nodes,
                                                         source_shift=index_shift_dict[keys.POINT_CLOUD],
                                                         target_nodes=mesh_nodes,
                                                         target_shift=index_shift_dict[keys.MESH])
            edge_index_dict[keys.POINT_CLOUD_MESH] = point_cloud_mesh_edges
            edge_index_dict[keys.MESH_POINT_CLOUD] = torch.flip(point_cloud_mesh_edges, dims=(0,))  # reverse edges

    # re-build all edges from the edge_index_dict.
    # This includes one-hot encodings and canonic (i.e., initial) distances

    # build edge_attr (one-hot)
    edge_attr, edge_index, edge_type, edge_type_description = build_edges(edge_index_dict)

    # add new attributes to device
    edge_index = edge_index.to(device)
    edge_attr = edge_attr.to(device)
    edge_type = edge_type.to(device)

    if env_config.use_canonic_mesh_positions:
        # reuse the canonic positions from the original graph, i.e., the initial distances between the mesh nodes
        mesh_edge_attr = graph.edge_attr[mesh_edge_mask]
        canonic_indices = [i for i, s in enumerate(graph.edge_type_description) if s.startswith("canonic")]
        original_canonic_mesh_positions = mesh_edge_attr[:, canonic_indices]
        # concat zeros to the remaining edges, as these are always the not mesh-mesh edges.
        # This relies on the mesh-mesh edges being the first entry of the edge_index_dict
        zero_vector = torch.zeros(edge_attr.shape[0] - original_canonic_mesh_positions.shape[0],
                                  len(canonic_indices), device=device)
        original_canonic_mesh_positions = torch.cat((original_canonic_mesh_positions, zero_vector), dim=0)
        edge_attr = torch.cat((edge_attr, original_canonic_mesh_positions), dim=1)
        canonic_descriptions = [s for s in graph.edge_type_description if s.startswith("canonic")]
        edge_type_description = edge_type_description + canonic_descriptions

    graph.edge_attr = edge_attr
    graph.edge_type = edge_type
    graph.edge_index = edge_index
    graph.edge_type_description = edge_type_description
    return graph


def get_collider_mesh_edges(collider_nodes, mesh_nodes, index_shift_dict, connectivity_setting):
    """
    Create the edges from the collider nodes to the mesh nodes.
    Args:
        collider_nodes: Positions of the collider nodes, used as source nodes of the edges
        mesh_nodes: Positions of the mesh nodes, used as target nodes of the edges
        index_shift_dict: Dictionary containing the index offsets of the collider and the mesh nodes
        connectivity_setting: ConfigDict with the connectivity settings. `collider_mesh_edge_creation` selects
          between k-nearest neighbor edges, radius edges, or the union of both ("both")

    Returns: Edge indices between collider and mesh nodes as a torch tensor of shape (2, num_edges)

    Raises:
        ValueError: If the edge creation method is unknown.

    """
    method = connectivity_setting.collider_mesh_edge_creation
    if method not in (keys.KNN, keys.RADIUS, "both"):
        raise ValueError(f"Unknown edge creation method: {connectivity_setting.collider_mesh_edge_creation}")

    # source and target are the same for all edge creation methods
    endpoints = dict(source_nodes=collider_nodes,
                     source_shift=index_shift_dict[keys.COLLIDER],
                     target_nodes=mesh_nodes,
                     target_shift=index_shift_dict[keys.MESH])

    if method == keys.KNN:
        return create_knn_edges(k=connectivity_setting.collider_mesh_k, **endpoints)
    if method == keys.RADIUS:
        return create_radius_edges(radius=connectivity_setting.collider_mesh_radius, **endpoints)

    # "both": union of the (length-limited) k-nearest neighbor edges and the radius edges
    knn_edges = create_knn_edges(k=connectivity_setting.collider_mesh_k,
                                 max_collider_mesh_k_radius=connectivity_setting.max_collider_mesh_k_radius,
                                 **endpoints)
    radius_edges = create_radius_edges(radius=connectivity_setting.collider_mesh_radius, **endpoints)
    # coalesce removes the edges that are created by both methods
    return coalesce(torch.cat((knn_edges, radius_edges), dim=1))


if __name__ == "__main__":
    # Visual sanity check: connect a row of collider nodes to a small grid of mesh nodes and plot the result
    collider_nodes = torch.tensor([[float(x), 0, 0] for x in range(4)])
    mesh_nodes = torch.tensor([[x + 0.3, y, 0] for y in (1, 2) for x in range(4)])
    # no offsets, such that the edge indices can directly be used to index the node tensors for plotting
    index_shift_dict = {keys.MESH: 0, keys.COLLIDER: 0}
    connectivity_setting = ConfigDict({
        "collider_mesh_edge_creation": keys.KNN,
        "collider_mesh_radius": 1.5,
        "collider_mesh_k": 3
    })
    collider_mesh_edges = get_collider_mesh_edges(collider_nodes, mesh_nodes, index_shift_dict,
                                                  connectivity_setting)

    import matplotlib.pyplot as plt

    # collider nodes in red, mesh nodes in blue, edges in green
    plt.scatter(collider_nodes[:, 0], collider_nodes[:, 1], c="r")
    plt.scatter(mesh_nodes[:, 0], mesh_nodes[:, 1], c="b")
    for start, end in collider_mesh_edges.t().tolist():
        plt.plot([collider_nodes[start, 0], mesh_nodes[end, 0]],
                 [collider_nodes[start, 1], mesh_nodes[end, 1]], c="g")
    plt.show()
