import abc
import copy
from typing import List

import torch
from omegaconf import OmegaConf
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from ltsgns_mp.envs.util.convert_to_single_data_trajectory import convert_traj_to_data, compute_point_cloud_padding_size
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import node_type_mask
from ltsgns_mp.util.own_types import ConfigDict


class EvalIterator:
    """
    Iterator over evaluation trajectories.

    At construction time every trajectory is converted once into a single ``Data`` object
    (one per trajectory / task). Iterating yields a deep copy of each of these objects,
    moved to ``device``. The whole trajectory is used for evaluation; no batching across
    trajectories is done here.

    Note: ``__iter__`` returns ``self`` and does not rewind. The internal cursor is only
    reset once it has been exhausted, so a loop that exits early leaves the next loop
    resuming where the previous one stopped.
    """

    def __init__(self, config, eval_trajs, context_size: int, device: str):
        """
        :param config: Config providing ``context_type``, ``anchor_index_as_feature`` and
            ``last_collider_as_feature``.
        :param eval_trajs: Either validation or test trajectories, depending on the config.
        :param context_size: Number of context steps; mesh context indices are
            ``0 .. context_size - 1`` and evaluation starts at ``context_size - 1``.
        :param device: Device every yielded ``Data`` object is moved to.
        """
        self.config = config
        self._eval_trajs = eval_trajs
        self.device = device
        self.context_size = context_size
        # Must be built before the conversion below, which reads self._indices.
        self._indices = self._build_indices(context_size)
        # Convert each trajectory to a single Data object once, up front.
        self._data_list = self.convert_trajs_to_data(eval_trajs)

        self._iterator = iter(self._data_list)

    def convert_trajs_to_data(self, trajs) -> List[Data]:
        """
        Convert trajectories into one ``Data`` object each.

        :param trajs: Trajectories to convert.
        :return: List with one ``Data`` object per trajectory, in input order. Each object
            stores its position in ``trajs`` under ``keys.TASK_INDICES``.
        """
        # A single padding size is computed over all trajectories and shared by each one.
        point_cloud_padding_size = compute_point_cloud_padding_size(trajs)
        context_type = self.config.context_type
        if context_type == "mixed":
            # "mixed" is a training-time setting; evaluation uses the point-cloud context.
            context_type = keys.POINT_CLOUD
        print("Preparing Eval Iterator...")
        data_list = []
        for task_idx, traj in enumerate(tqdm(trajs, desc="Preparing Eval Iterator", disable=True)):
            data_traj = convert_traj_to_data(traj, self._indices, point_cloud_padding_size, context_type,
                                             self.config.anchor_index_as_feature,
                                             self.config.last_collider_as_feature)
            data_traj[keys.TASK_INDICES] = task_idx
            data_list.append(data_traj)
        return data_list

    def _build_indices(self, context_size: int):
        """
        Build the fixed index specification used to convert evaluation trajectories.

        The first ``context_size`` mesh steps form the context. Evaluation starts at the
        last context step (``context_size - 1``) with step 1 and ``stop_idx=None``
        (presumably meaning "until the end" - confirm in ``convert_traj_to_data``). The
        anchor is the last context step. No point-cloud context indices are used.

        :param context_size: Number of context steps.
        :return: An ``OmegaConf`` config holding the index specification.
        """
        indices = {
            keys.MESH: {"indices": list(range(context_size))},
            keys.POINT_CLOUD: {"indices": []},  # point-cloud context indices are not supported yet
            keys.EVALUATION: {
                "start_idx": context_size - 1,
                "stop_idx": None,
                "step": 1
            },
            "anchor_idx": context_size - 1
        }
        return OmegaConf.create(indices)

    def __iter__(self):
        # Does not rewind; see the class docstring.
        return self

    def __next__(self):
        """
        Return a deep copy of the next converted trajectory, moved to ``self.device``.

        :raises StopIteration: once all trajectories have been yielded. The cursor is
            rewound first, so the next pass starts from the beginning again.
        """
        try:
            data = next(self._iterator)
        except StopIteration:
            # Rewind so this iterator can be reused for the next evaluation pass.
            self._iterator = iter(self._data_list)
            raise StopIteration from None
        # Copy before moving so the cached object in self._data_list is not altered by the
        # device transfer or by downstream in-place changes.
        batch = copy.deepcopy(data)
        # Use the returned object instead of relying on `.to` mutating in place.
        return batch.to(self.device)

    def __len__(self):
        """Number of evaluation trajectories (= tasks)."""
        return len(self._data_list)

    @property
    def num_tasks(self):
        """Number of evaluation tasks; one task per trajectory."""
        return len(self._data_list)
