from __future__ import annotations  # necessary import for the create_subtask return typehint to work

from typing import List, Tuple

import torch
from torch_geometric.data import Data


class Trajectory:
    """
    Ordered sequence of ``Data`` elements.

    The list passed to the constructor is stored by reference (it is not copied), so
    later mutations of that list are visible through this trajectory and vice versa.
    """

    def __init__(self, trajectory: List[Data]):
        # Stored as-is; no copy is made.
        self._data_list: List[Data] = trajectory

    @property
    def data_list(self) -> List[Data]:
        """The underlying list of Data elements (the same list object, not a copy)."""
        return self._data_list

    def get_subtrajectory(self, start_idx: int, end_idx: int, deepcopy: bool = False,
                          device: str | torch.device | None = None) -> Trajectory:
        """
        Returns a new trajectory containing the elements ``data_list[start_idx:end_idx]``.

        Args:
            start_idx: Index of the first element to include (Python slice semantics).
            end_idx: Index one past the last element to include (Python slice semantics).
            deepcopy: If True, every selected element is ``clone()``d, so the result does
                not share element objects with this trajectory. If False, the result holds
                references to the same Data objects as this trajectory.
            device: If given, the returned trajectory is moved to this device via ``to``.
                NOTE: with ``deepcopy=False`` the elements are shared, so moving them also
                affects the corresponding elements of this trajectory (``to`` relies on the
                elements being moved in place).

        Returns: The new sub-trajectory.
        """
        selected = self.data_list[start_idx:end_idx]
        if deepcopy:
            selected = [data.clone() for data in selected]

        traj = Trajectory(selected)
        if device is not None:
            traj.to(device)
        return traj

    def to(self, device: torch.device | str) -> None:
        """
        Move every element of this trajectory to the given device.

        Args:
            device: Target device, e.g. ``"cpu"`` or a ``torch.device``.

        Returns: None. The value returned by each ``data.to`` call is discarded, so this
        relies on ``Data.to`` modifying the element in place.
        """
        for data in self.data_list:
            data.to(device)

    def print_property(self, property_name: str) -> None:
        """
        Prints the given property of all datapoints in the trajectory. Useful for debugging purposes.
        Args:
            property_name: Name of the property to print (looked up as ``data[property_name]``).

        Returns: None

        """
        print(f"Printing {property_name} of all datapoints in the trajectory")
        for data in self.data_list:
            print(data[property_name])

    def __len__(self) -> int:
        return len(self.data_list)

    def __getitem__(self, item):
        """Index into the underlying list; a slice returns a plain list, not a Trajectory."""
        return self.data_list[item]

    def __repr__(self) -> str:
        name = type(self).__name__
        if not self.data_list:
            # self[0] would raise IndexError here; a repr must never raise.
            return f"{name} with 0 datapoints"
        return f"{name} with {len(self)} datapoints like: {self[0]}"
