import abc
import os
import pickle
import warnings
from typing import List, Dict

import torch
from torch_geometric.data import Data
from tqdm import tqdm

from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict, ValueDict
from ltsgns_mp.envs.trajectory import Trajectory
from ltsgns_mp.envs.trajectory_collection import TrajectoryCollection
from ltsgns_mp.util.util import print_mem_usage_of_data_object


class AbstractDataloader(abc.ABC):
    """
    Base class for dataloaders that convert raw rollouts into trajectories of graphs.

    ``load`` drives the pipeline. Subclasses supply the data-specific hooks:
    ``_load_raw_data``, ``_get_rollout_length``, ``_select_and_normalize_attributes``,
    ``_build_data_dict`` and ``_build_graph``.
    """

    # Maximum number of training trajectories converted when ``config.eval_only`` is set.
    _EVAL_ONLY_MAX_TRAIN_TRAJECTORIES = 2

    def __init__(self, config: ConfigDict):
        """
        Args:
            config: Loader configuration. This class reads ``config.debug`` (stored as
                ``self.debug_config``; its ``max_tasks_per_split`` is used),
                ``config.eval_only`` and ``config.start_step``.
        """
        self.config = config
        self.debug_config = config.debug

    def load(self) -> Dict[str, TrajectoryCollection]:
        """
        Load and convert every split listed in ``split_modes``.

        Returns:
            A dictionary mapping each split name (``keys.TRAIN`` and ``keys.VAL``, plus
            ``keys.TEST`` in eval-only mode) to the ``TrajectoryCollection`` of that split.
        """
        return {mode: self._load_split(mode) for mode in self.split_modes}

    def _load_split(self, split: str) -> TrajectoryCollection:
        """
        Build all trajectories of a single split.

        For every raw trajectory, one graph is built per timestep in
        ``range(config.start_step, rollout_length)``.

        Args:
            split: Name of the split, as produced by ``split_modes``.

        Returns:
            A ``TrajectoryCollection`` with one ``Trajectory`` per converted raw trajectory.
        """
        rollout_data: List[ValueDict] = self._load_raw_data(split=split)

        # Optional cap on how many trajectories are converted (e.g. for debugging).
        max_trajectories = self.debug_config.max_tasks_per_split
        if self.config.eval_only and split == keys.TRAIN:
            # Evaluation only needs a couple of training examples, so skip the rest.
            eval_cap = self._EVAL_ONLY_MAX_TRAIN_TRAJECTORIES
            max_trajectories = eval_cap if max_trajectories is None else min(max_trajectories, eval_cap)

        start_step = self.config.start_step
        traj_list: List[Trajectory] = []
        for index, raw_traj in enumerate(tqdm(rollout_data, desc=f"Loading {split.title()} Data...")):
            if max_trajectories is not None and index >= max_trajectories:
                break
            # The rollout length is taken from the raw trajectory, before attribute selection.
            rollout_length: int = self._get_rollout_length(raw_traj=raw_traj)
            selected_traj: ValueDict = self._select_and_normalize_attributes(raw_traj=raw_traj)
            graphs: List[Data] = [
                self._build_graph(data_dict=self._build_data_dict(raw_traj=selected_traj, timestep=timestep))
                for timestep in range(start_step, rollout_length)
            ]
            traj_list.append(Trajectory(graphs))
        return TrajectoryCollection(traj_list)

    @property
    def split_modes(self):
        """
        Splits loaded by ``load``.

        Returns ``[keys.TRAIN, keys.TEST, keys.VAL]`` when ``config.eval_only`` is set,
        otherwise ``[keys.TRAIN, keys.VAL]``.
        """
        if self.config.eval_only:
            return [keys.TRAIN, keys.TEST, keys.VAL]
        else:
            return [keys.TRAIN, keys.VAL]

    ###########################################
    ####### Interfaces for data loading #######
    ###########################################

    @abc.abstractmethod
    def _get_rollout_length(self, raw_traj: ValueDict) -> int:
        """
        Return the rollout length of a raw (not yet attribute-selected) trajectory.

        ``_load_split`` builds graphs for timesteps ``start_step`` up to, but excluding,
        this value.
        Args:
            raw_traj: One raw trajectory as returned by ``_load_raw_data``.

        Returns:
            The exclusive upper bound on the timesteps to convert.
        """
        raise NotImplementedError("AbstractDataloader does not implement _get_rollout_length method")

    @abc.abstractmethod
    def _load_raw_data(self, split: str) -> List[ValueDict]:
        """
        Load the raw trajectories of a split.

        Args:
            split: Name of the split, as produced by ``split_modes``.

        Returns:
            A list with one ``ValueDict`` per raw trajectory.
        """
        raise NotImplementedError("AbstractDataloader does not implement _load_raw_data method")

    @abc.abstractmethod
    def _select_and_normalize_attributes(self, raw_traj: ValueDict) -> ValueDict:
        """
        Remove unused attributes (such as point cloud or poisson values) from a raw
        trajectory and normalize values if necessary (trajectory level).
        Args:
            raw_traj: One raw trajectory as returned by ``_load_raw_data``.

        Returns:
            The reduced / normalized trajectory that is passed to ``_build_data_dict``.
        """
        raise NotImplementedError("AbstractDataloader does not implement _select_and_normalize_attributes method")

    @abc.abstractmethod
    def _build_data_dict(self, raw_traj: ValueDict, timestep: int) -> ValueDict:
        """
        Collect the tensors of one timestep into a data dict (indexed with ``timestep``).
        Args:
            raw_traj: Trajectory as returned by ``_select_and_normalize_attributes``.
            timestep: Timestep to extract.

        Returns:
            The data dict that is passed to ``_build_graph``.
        """
        raise NotImplementedError("AbstractDataloader does not implement _build_data_dict method")

    @abc.abstractmethod
    def _build_graph(self, data_dict: ValueDict) -> Data:
        """
        Build a Data object from a data dict, i.e., actually build a graph from a
        dictionary of tensors.
        Args:
            data_dict: Dict produced by ``_build_data_dict``.

        Returns:
            The graph for that timestep.
        """
        raise NotImplementedError("AbstractDataloader does not implement _build_graph method")
