from torch.utils.data import Dataset

from mas_sat.graph.base import BaseGraph

class GraphDataset(Dataset):
    """
    Map-style dataset that builds each sample from an environment observation.

    Item ``idx`` is produced by resetting ``env`` to instance ``idx`` and
    converting the resulting observation back into its original graph via
    ``graph.from_observation(observation, dim, original=True)``.
    """

    def __init__(self, env, graph, dim) -> None:
        """
        Store the environment, graph converter and dimension.

        Args:
            env: Environment whose ``reset(idx=...)`` returns an
                ``(observation, info)`` pair and whose ``unwrapped.len()``
                reports the number of instances (presumably a Gym-style
                wrapper -- TODO confirm).
            graph: Object (or class) exposing
                ``from_observation(observation, dim, original=...)``;
                presumably a ``BaseGraph`` implementation -- confirm.
            dim: Forwarded unchanged to ``graph.from_observation``.
        """
        super().__init__()
        self.env = env
        self.graph = graph
        self.dim = dim

    def __getitem__(self, idx) -> "BaseGraph":
        """
        Return the original graph for environment instance ``idx``.

        The environment is closed after every lookup. This also happens when
        ``reset`` or the graph conversion raises, so a failing sample does not
        leave the environment open.

        Args:
            idx: Instance index forwarded to ``env.reset(idx=idx)``.

        Returns:
            Whatever ``graph.from_observation(..., original=True)`` returns.
        """
        try:
            observation, _ = self.env.reset(idx=idx)
            return self.graph.from_observation(observation, self.dim, original=True)
        finally:
            # NOTE(review): the next lookup calls reset() on the env closed
            # here; this assumes the env supports reset-after-close -- confirm.
            self.env.close()

    def __len__(self) -> int:
        """Return the number of instances reported by ``env.unwrapped.len()``."""
        return self.env.unwrapped.len()
