from typing import Optional, Tuple, Union, List

import torch

from torch_geometric import transforms
from torch_geometric.data import Batch, Data

from ltsgns_mp.util import keys


def get_one_hot_features_and_types(input_list: Union[List[int], List[List[int]]],
                                   device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Builds one-hot feature tensor indicating the edge/node type from a list of numbers per type or indices per type.
    Assumes that the types are ordered in the same way as the input list.

    Two input formats are accepted:
      * a list of counts, e.g. ``[2, 1]``: the first 2 elements get type 0, the next 1 gets type 1
        (types occupy consecutive, contiguous index ranges in list order);
      * a list of index lists, e.g. ``[[1], [0, 2]]``: element 1 gets type 0, elements 0 and 2 get type 1.
        All indices must lie in ``[0, total_num)`` where ``total_num`` is the total number of indices given;
        otherwise the tensor assignment raises.

    Args:
        input_list: List of numbers of nodes per type or list of lists of indices per type.
            Counts may be any integral type (Python ``int``, NumPy integers, ...).
        device: Device to compute on

    Returns: A tuple of two tensors containing the one-hot features and the type as number
        features: One-hot features Tensor of shape (total_num, num_types), e.g., "(0,0,1,0)" for a node of type 2
        types: Tensor of shape (total_num,) containing the type as number, e.g., "2" for a node of type 2.
            Note: uses torch's default floating-point dtype (not an integer dtype).
        An empty ``input_list`` yields tensors of shape (0, 0) and (0,).
    """
    import itertools
    import numbers

    if len(input_list) == 0:
        # No types at all: return empty tensors instead of failing on input_list[0].
        return torch.zeros(0, 0, device=device), torch.zeros(0, device=device)

    # numbers.Integral covers int (and bool) like before, plus NumPy integer scalars.
    if isinstance(input_list[0], numbers.Integral):  # the input is a list of counts
        counts = [int(count) for count in input_list]
        # Prefix sums give each type's contiguous [start, end) range in a single O(num_types) pass.
        ends = list(itertools.accumulate(counts))
        starts = [0] + ends[:-1]
        indices_per_type = [list(range(start, end)) for start, end in zip(starts, ends)]
    else:  # the input is a list of list of indices
        indices_per_type = input_list

    total_num = sum(len(indices) for indices in indices_per_type)
    num_types = len(indices_per_type)
    features = torch.zeros(total_num, num_types, device=device)
    types = torch.zeros(total_num, device=device)

    for type_idx, indices in enumerate(indices_per_type):
        if len(indices) == 0:
            # Nothing to mark for this type; skip indexing with an empty list.
            continue
        features[indices, type_idx] = 1
        types[indices] = type_idx
    return features, types


def add_second_order_dynamics(data_dict, timestep, raw_traj):
    """
    Stores the mesh of the previous timestep in ``data_dict`` under ``keys.PREV_MESH_POS``.

    Args:
        data_dict: Mapping for the current timestep; it is updated in place and must contain
            ``keys.MESH`` when ``timestep == 0``.
        timestep: Index of the current timestep within the trajectory.
        raw_traj: Trajectory mapping whose ``keys.MESH`` entry is indexable by timestep.

    Returns:
        The same ``data_dict`` object, with ``keys.PREV_MESH_POS`` set.

    For ``timestep > 0`` the previous mesh is taken from ``raw_traj[keys.MESH][timestep - 1]``.
    A torch tensor is stored as-is; anything else is converted with ``torch.tensor(..., dtype=torch.float32)``.
    """
    if timestep != 0:
        previous_mesh = raw_traj[keys.MESH][timestep - 1]
        if not isinstance(previous_mesh, torch.Tensor):
            previous_mesh = torch.tensor(previous_mesh, dtype=torch.float32)
        data_dict[keys.PREV_MESH_POS] = previous_mesh
        return data_dict

    # No earlier frame exists at timestep 0, so the current mesh doubles as the previous one.
    # NOTE(review): this stores the very same object under both keys (no copy).
    data_dict[keys.PREV_MESH_POS] = data_dict[keys.MESH]
    return data_dict

