from pickle import load

from config import setup_logger
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

logger = setup_logger()


class MapDataset(Dataset):
    """Map-style dataset yielding ``(location, length)`` pairs from two pickles.

    Two pickle files are read from ``folderpath``:

    * ``{mode}_location.pickle`` -> stored as ``self.locations``
    * ``{mode}_length.pickle``   -> stored as ``self.length_list``

    Item ``i`` is ``(self.locations[i], self.length_list[i])``, so both
    objects must be indexable and have the same length. (Presumably these are
    per-instance coordinates and their reference lengths; TODO confirm
    against the data-generation script.)

    Args:
        folderpath: Directory holding the pickle files. It is formatted into
            the path with an f-string, so a ``pathlib.Path`` also works and a
            trailing separator is harmless.
        mode: Split name used as the file prefix (e.g. ``"train"``).

    Raises:
        FileNotFoundError: If either pickle file does not exist.
        ValueError: If the two pickled objects have different lengths.

    Warning:
        ``pickle.load`` can execute arbitrary code while unpickling. Only point
        this at data from a trusted source.
    """

    def __init__(self, folderpath: str, mode: str) -> None:
        super().__init__()

        self.locations = self._load_pickle(f"{folderpath}/{mode}_location.pickle")
        self.length_list = self._load_pickle(f"{folderpath}/{mode}_length.pickle")

        # The two files are parallel arrays indexed together in __getitem__.
        # Fail fast on a mismatch instead of raising IndexError mid-epoch
        # (shorter length_list) or silently dropping targets (longer one).
        n_locations = len(self.locations)
        n_lengths = len(self.length_list)
        if n_locations != n_lengths:
            raise ValueError(
                f"{mode!r} split in {str(folderpath)!r} is inconsistent: "
                f"{n_locations} locations but {n_lengths} lengths"
            )

    @staticmethod
    def _load_pickle(path: str):
        """Open ``path`` in binary mode and return the unpickled object."""
        with open(path, "rb") as f:
            return load(f)

    def __len__(self) -> int:
        """Return the number of samples (equal for both pickled sequences)."""
        return len(self.locations)

    def __getitem__(self, index: int):
        """Return the ``(location, length)`` pair stored at ``index``."""
        return self.locations[index], self.length_list[index]


if __name__ == "__main__":
    # Ad-hoc smoke check: build the training split and drop into the debugger
    # on every shuffled batch of 10 so its contents can be inspected by hand.
    train_dataset = MapDataset(
        folderpath="map_data/ja9847.tsp_100node_30num/",
        mode="train",
    )
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=10,
        shuffle=True,
    )

    for batch_locations, batch_lengths in train_loader:
        breakpoint()
