import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

# Root directory prefixed to each relative image path listed in the msls_val
# .npy index files, used to locate the image on disk.
DATASET_ROOT = '/database/dkim/VPR_datasets/MSLS/'  # DK: absolute local path, adjust for your environment
# Directory that holds the 'msls_val/msls_val_*.npy' files (database image
# list, query image list, query indices, and the pIdx ground-truth array).
GT_ROOT = '/database/dkim/gsv_cities/datasets/'  # DK: absolute local path, adjust for your environment

class MSLS(Dataset):
    """MSLS validation split for visual place recognition evaluation.

    The dataset exposes reference (database) images first, followed by the
    query images selected by ``qIdx``. Indices ``[0, num_references)`` are
    references and ``[num_references, len(self))`` are queries.

    Args:
        input_transform: Optional callable applied to each loaded PIL image.
        path_return: If True, ``__getitem__`` returns the full image path
            (``dataset_root`` joined with the stored relative path) as the
            third element. Otherwise it returns the relative path exactly as
            stored in the index file.
        dataset_root: Directory the relative image paths are joined to.
            Defaults to the module-level ``DATASET_ROOT``.
        gt_root: Directory containing ``msls_val/msls_val_*.npy``. Defaults
            to the module-level ``GT_ROOT``.
    """

    def __init__(self, input_transform=None, path_return=False, *,
                 dataset_root=None, gt_root=None):
        self.input_transform = input_transform
        self.path_return = path_return
        # Fall back to the module constants so existing callers are unaffected.
        self.dataset_root = DATASET_ROOT if dataset_root is None else dataset_root
        self.gt_root = GT_ROOT if gt_root is None else gt_root

        split_dir = os.path.join(self.gt_root, 'msls_val')
        self.dbImages = np.load(os.path.join(split_dir, 'msls_val_dbImages.npy'))
        self.qIdx = np.load(os.path.join(split_dir, 'msls_val_qIdx.npy'))
        self.qImages = np.load(os.path.join(split_dir, 'msls_val_qImages.npy'))
        # allow_pickle=True is kept from the original, presumably because pIdx
        # is a ragged object array. Pickle loading can execute arbitrary code,
        # so only point gt_root at trusted files.
        self.ground_truth = np.load(
            os.path.join(split_dir, 'msls_val_pIdx.npy'), allow_pickle=True)

        # Select the query subset once and reuse it for both images and count.
        query_images = self.qImages[self.qIdx]
        # Reference images first, then query images.
        self.images = np.concatenate((self.dbImages, query_images))
        self.num_references = len(self.dbImages)
        self.num_queries = len(query_images)

    def __getitem__(self, index):
        """Return ``(image, index, path)`` for the item at ``index``.

        ``image`` is the PIL image, passed through ``input_transform`` if one
        was given. ``path`` is the full path when ``path_return`` is True,
        otherwise the relative path stored in ``self.images``.
        """
        rel_path = self.images[index]
        full_path = os.path.join(self.dataset_root, rel_path)

        img = Image.open(full_path)
        # Image.open is lazy and keeps the file open. Loading now lets Pillow
        # close single-frame files immediately, so no descriptor is left
        # dangling when no transform forces a load.
        img.load()

        if self.input_transform:
            img = self.input_transform(img)

        return img, index, full_path if self.path_return else rel_path

    def __len__(self):
        """Return the total item count (references followed by queries)."""
        return len(self.images)