import copy

import numpy as np
import torch
from omegaconf import OmegaConf
from torch_geometric.data import Data

from ltsgns_mp.envs.trajectory_collection import TrajectoryCollection
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import node_type_mask
from ltsgns_mp.util.own_types import ConfigDict

"""
This file contains the function convert_traj_to_data, which is used in the LTSGNS_MP algorithm.
From a trajectory object (basically a list of graphs), it creates a single data object, consisting of
 - the anchor graph, the graph at the anchor index which is input to the GNN
 - the anchor index in time
 - the context node positions, the node positions of all graphs in the trajectory, used for ground truth
 - the context collider positions, the collider positions of all graphs in the trajectory, used for ground truth
 - the context point cloud positions, the point cloud positions of all graphs in the trajectory, used for ground truth
 - the evaluation indices, where the evaluation for the test loss is performed
 - the mesh indices, where the mesh is used for training, these are the real context indices
 - the point cloud indices, where the point cloud is used for training, these are the real context indices for PC training
 
If the data object is used for training (via the trajectory iterator), the evaluation indices are ignored.
The mesh and point cloud indices indicate which context steps are available to the LTSGNS_MP algorithm. MGN does not
use the context.

"""


def build_idx_config(config, traj, context_type):
    """
    Build the idx_config for the trajectory by sampling from the current bounds in ``config``.

    :param config: sampling config. Reads min/max_mesh_context_size, max_context_shift and
        max_sporadic_mesh_contexts (via get_mesh_context_indices), min/max_last_point_cloud,
        min/max_point_cloud_context_step and anchor_index_mode.
    :param traj: the trajectory; only its length is used.
    :param context_type: keys.MESH or keys.POINT_CLOUD. Selects which context size the anchor index is placed in.
    :return: an OmegaConf config with the "indices" for keys.MESH, keys.POINT_CLOUD and keys.EVALUATION,
        plus the sampled "anchor_idx".
    :raises ValueError: for an unknown context type or anchor index mode, or if a point cloud context is
        requested while config.min_last_point_cloud is None.
    """
    context_shift, mesh_context_size, mesh_indices = get_mesh_context_indices(config, traj)

    if config.min_last_point_cloud is not None:
        point_cloud_context_size = int(
            np.random.uniform(config.min_last_point_cloud, config.max_last_point_cloud + 1))
        point_cloud_step = int(
            np.random.uniform(config.min_point_cloud_context_step, config.max_point_cloud_context_step + 1))
        point_cloud_indices = list(range(context_shift, point_cloud_context_size + context_shift, point_cloud_step))
    else:
        # no point cloud context configured
        point_cloud_context_size = None
        point_cloud_indices = []

    # evaluation not really used here, but for consistency
    evaluation_indices = list(range(len(traj)))

    # the anchor index is placed relative to the context of the requested type
    if context_type == keys.MESH:
        context_size = mesh_context_size
    elif context_type == keys.POINT_CLOUD:
        if point_cloud_context_size is None:
            # previously this surfaced as an UnboundLocalError
            raise ValueError("Point cloud context requested, but config.min_last_point_cloud is None")
        context_size = point_cloud_context_size
    else:
        raise ValueError("Invalid context type {}".format(context_type))

    if config.anchor_index_mode == "first_context":
        anchor_index = context_shift
    elif config.anchor_index_mode == "last_context":
        anchor_index = context_shift + context_size - 1
    elif config.anchor_index_mode == "random":
        anchor_index = int(np.random.choice(range(context_shift, context_shift + context_size)))
    else:
        raise ValueError("Invalid anchor index mode {}".format(config.anchor_index_mode))

    idx_config = {
        keys.MESH: {
            "indices": mesh_indices,
        },
        keys.POINT_CLOUD: {
            "indices": point_cloud_indices,
        },
        keys.EVALUATION: {
            "indices": evaluation_indices,
        },
        "anchor_idx": anchor_index,
    }
    # convert to omegaconf
    return OmegaConf.create(idx_config)


def get_mesh_context_indices(config, traj):
    """
    Sample the mesh context indices for a trajectory.

    The context is a contiguous run of ``mesh_context_size`` steps starting at a random ``context_shift``,
    followed by up to ``config.max_sporadic_mesh_contexts`` randomly chosen later steps. The number of sporadic
    steps is capped at the number of available candidate steps.

    :param config: reads min_mesh_context_size, max_mesh_context_size, max_context_shift and
        max_sporadic_mesh_contexts.
    :param traj: the trajectory; only its length is used.
    :return: tuple (context_shift, mesh_context_size, mesh_indices). mesh_indices is a list of python ints:
        the contiguous run first, then the sorted sporadic steps.
    """
    # need to subtract 1, since the anchor index is not part of the context (but the vel target of the prev time step is)
    mesh_context_size = int(
        np.random.uniform(config.min_mesh_context_size, config.max_mesh_context_size + 1)) - 1

    # the first indices are the context indices
    context_shift = int(np.random.uniform(0, config.max_context_shift + 1))
    mesh_indices = list(range(context_shift, mesh_context_size + context_shift))

    # have some random later contexts. Candidates skip the step directly after the contiguous run and the last step.
    sporadic_candidates = range(context_shift + mesh_context_size + 1, len(traj) - 1)
    num_sporadic_mesh_contexts = int(np.random.uniform(0, config.max_sporadic_mesh_contexts + 1))
    # sampling without replacement cannot draw more steps than there are candidates
    num_sporadic_mesh_contexts = min(num_sporadic_mesh_contexts, len(sporadic_candidates))
    sporadic_mesh_indices = np.random.choice(sporadic_candidates, num_sporadic_mesh_contexts, replace=False)
    # cast to standard int and sort
    mesh_indices.extend(sorted(int(x) for x in sporadic_mesh_indices))
    return context_shift, mesh_context_size, mesh_indices


def _pad_point_clouds(point_clouds, padding_size: int | None) -> torch.Tensor:
    """
    NaN-pad a list of point clouds (first dimension = number of points) to one tensor of shape
    (len(point_clouds), padding_size, last_dim). If padding_size is None, the largest point cloud in the list is used.
    """
    if padding_size is None:
        padding_size = max(point_cloud.shape[0] for point_cloud in point_clouds)
    padded = torch.full((len(point_clouds), padding_size, point_clouds[0].shape[-1]), float("nan"))
    for idx, point_cloud in enumerate(point_clouds):
        padded[idx, :point_cloud.shape[0]] = point_cloud
    return padded


def convert_traj_to_data(traj, idx_config: ConfigDict, point_cloud_padding_size: int | None, context_type: str,
                         anchor_index_as_feature: bool, last_collider_as_feature: bool = False) -> Data:
    """
    Converts a trajectory to a single data object. The additional information of the rest of the trajectory is
    stored in the data object in additional fields that can be accessed by the keys in the keys.py file.
    Used for evaluation of the trained models.

    Args:
        traj: A full trajectory, i.e., a list of graphs
        idx_config: A config dict containing the anchor index ("anchor_idx") and, for keys.MESH, keys.POINT_CLOUD
            and keys.EVALUATION, the steps (see add_indices) that are marked as context / evaluation steps
        point_cloud_padding_size: The size to which the point clouds are padded. Can be None if no point clouds are
            present; if point clouds are present and it is None, the largest point cloud of this trajectory is used
        context_type: The type of context used for training. Can be "mesh" or "point_cloud"
        anchor_index_as_feature: If true, the anchor index divided by the trajectory length is appended as an
            x feature to every node
        last_collider_as_feature: If true, the collider position at the last step relative to its position at the
            anchor step is appended as x features to the collider nodes (zeros for all other nodes)

    Returns:
        A deep copy of the graph at the anchor index, extended by the context fields described above.
    """
    anchor_idx = idx_config.anchor_idx
    anchor_graph = copy.deepcopy(traj[anchor_idx])
    anchor_graph[keys.ANCHOR_INDICES] = anchor_idx

    # positions of the mesh nodes over time, with a leading batch dimension
    node_positions = torch.stack([graph[keys.POSITIONS][node_type_mask(graph, key=keys.MESH)] for graph in traj])
    anchor_graph[keys.CONTEXT_NODE_POSITIONS] = node_positions.unsqueeze(0)

    if keys.COLLIDER in traj[0].node_type_description:
        # positions of the collider over time, with a leading batch dimension
        collider_positions = torch.stack([graph[keys.POSITIONS][node_type_mask(graph, key=keys.COLLIDER)]
                                          for graph in traj]).unsqueeze(0)
    else:
        collider_positions = None
    anchor_graph[keys.CONTEXT_COLLIDER_POSITIONS] = collider_positions

    # use `in` rather than Data.keys(): keys is a property in older PyG versions and a method in newer ones
    if "visual_collider_vertices" in anchor_graph:
        anchor_graph["visual_collider_vertices"] = torch.stack([graph["visual_collider_vertices"] for graph in traj])

    # if point cloud positions are available, nan-pad them to a common size and add them
    if keys.POINT_CLOUD in traj[0]:
        point_clouds = [graph[keys.POINT_CLOUD] for graph in traj]
        padded_point_clouds = _pad_point_clouds(point_clouds, point_cloud_padding_size)
        anchor_graph[keys.CONTEXT_POINT_CLOUD_POSITIONS] = padded_point_clouds.unsqueeze(0)  # add batch dimension

    # boolean masks over time: which steps are mesh context, point cloud context and evaluation steps
    for idx_type in [keys.MESH, keys.POINT_CLOUD, keys.EVALUATION]:
        add_indices(anchor_graph, idx_type, len(traj), idx_config)
    anchor_graph[keys.CONTEXT_TYPE] = context_type

    if anchor_index_as_feature:
        # the same normalized anchor index for every node
        num_nodes = anchor_graph.x.shape[0]
        anchor_index_feature = torch.full((num_nodes, 1), anchor_idx / len(traj), dtype=torch.float)
        anchor_graph.x = torch.cat([anchor_graph.x, anchor_index_feature], dim=1)
        anchor_graph.x_description += [keys.ANCHOR_INDICES] * anchor_index_feature.shape[1]

    if last_collider_as_feature and collider_positions is not None:
        last_collider = collider_positions[:, -1]
        rel_last_collider = last_collider - collider_positions[:, anchor_idx]
        # remove batch dim, since we add it to the x features
        rel_last_collider = rel_last_collider[0]
        # only collider nodes receive the offset; all other nodes get zeros
        collider_features = torch.zeros((anchor_graph.x.shape[0], rel_last_collider.shape[1]))
        collider_features[node_type_mask(anchor_graph, key=keys.COLLIDER)] = rel_last_collider
        anchor_graph.x = torch.cat([anchor_graph.x, collider_features], dim=1)
        anchor_graph.x_description += ["rel_last_collider"] * collider_features.shape[1]

    return anchor_graph


def add_indices(anchor_graph, idx_type, traj_length, idx_config: ConfigDict):
    idx_config = idx_config[idx_type]
    empty_indices = torch.zeros(traj_length, dtype=torch.bool)
    if "indices" in idx_config:
        empty_indices[idx_config.indices] = True
    elif "start_idx" in idx_config:
        step = idx_config.step if "step" in idx_config else 1
        if "stop_idx" not in idx_config or idx_config["stop_idx"] is None:
            empty_indices[idx_config.start_idx::step] = True
        empty_indices[idx_config.start_idx:idx_config.stop_idx:step] = True
    else:
        raise ValueError("Invalid config for {} indices".format(idx_type))
    # add batch dimension
    empty_indices = empty_indices.unsqueeze(0)
    anchor_graph[f"{idx_type}_indices"] = empty_indices


def compute_point_cloud_padding_size(trajs: TrajectoryCollection) -> int | None:
    """
    Computes the maximum point cloud size over all trajectories. Used for padding the point clouds to the same size.
    Whether point clouds are present is decided by the first graph of the collection.

    :param trajs: Collection of trajectories
    :return: The maximum number of points of any point cloud over all trajectories, or None if the collection
        contains no graphs or no point clouds are present
    """
    all_graphs = [graph for traj in trajs for graph in traj]
    # an empty collection previously raised an IndexError on trajs[0][0]
    if not all_graphs or keys.POINT_CLOUD not in all_graphs[0]:
        return None
    return max(graph[keys.POINT_CLOUD].shape[0] for graph in all_graphs)
