from pathlib import Path
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

# NOTE(review): this header was written for the Nordland dataset, whereas the code
# below loads SVOX (see SVOXDataset) -- TODO confirm the download source for SVOX.
# The Nordland dataset can be downloaded from https://surfdrive.surf.nl/files/index.php/s/sbZRXzYe3l0v67W
# This link is shared and maintained by the authors of VPR-Bench: https://github.com/MubarizZaffar/VPR-Bench
# The folders named `ref` and `query` should reside in the DATASET_ROOT path.
# The image names and ground truth are hardcoded for faster evaluation;
# performance is exactly the same as when using VPR-Bench.
# Directory whose existence is verified at import time (see below).
# In this file it is only used for that check; SVOXDataset opens image
# paths exactly as stored in the .npy files, without joining DATASET_ROOT.
DATASET_ROOT = '/database/dkim/VPR_datasets/'  #DK
# Directory containing the `SVOX/` folder with the precomputed
# `<split>_dbImages.npy`, `<split>_qImages.npy` and `<split>_gt.npy` files.
GT_ROOT = '/database/dkim/gsv_cities/datasets/' #DK

# Import-time sanity check: fail fast if the dataset root is missing.
path_obj = Path(DATASET_ROOT)
if not path_obj.exists():
    raise Exception(f'Please make sure the path {DATASET_ROOT} to SVOX dataset is correct')

# NOTE(review): this check can never fire. `Path.joinpath(...)` returns a
# Path object, which is always truthy, so both `not ...` operands are False.
# An `.exists()` call on each joined path is presumably what was intended.
# Before enabling it, confirm that SVOX really uses `ref`/`query` subfolders
# under DATASET_ROOT -- otherwise importing this module would start raising.
if not path_obj.joinpath('ref') or not path_obj.joinpath('query'):
    raise Exception(f'Please make sure the directories query and ref are situated in the directory {DATASET_ROOT}')

class SVOXDataset(Dataset):
    """SVOX evaluation split exposed as one flat, indexable list of images.

    Reference (database) image paths are stored first, followed by the query
    image paths. Indices ``[0, num_references)`` therefore address reference
    images and the next ``num_queries`` indices address query images.

    Image paths are opened exactly as stored in the ``.npy`` files; no root
    directory is prepended to them.

    Args:
        which_ds: Name of the SVOX split. It is validated case-insensitively
            against ``_VALID_SPLITS`` but used verbatim (case preserved) to
            build the ``.npy`` file names.
            NOTE(review): on a case-sensitive filesystem ``'svox_night'``
            passes validation yet looks for ``svox_night_*.npy`` -- confirm
            which spelling the files on disk actually use.
        input_transform: Optional callable applied to every opened PIL image
            before it is returned.
        gt_root: Directory containing the ``SVOX/`` folder with the
            ``<which_ds>_dbImages.npy``, ``<which_ds>_qImages.npy`` and
            ``<which_ds>_gt.npy`` files. Defaults to the module-level
            ``GT_ROOT``.

    Raises:
        AssertionError: If ``which_ds`` is not a known SVOX split.
    """

    _VALID_SPLITS = ('svox_night', 'svox_overcast', 'svox_rain', 'svox_snow', 'svox_sun')

    def __init__(self, which_ds='SVOX_night', input_transform=None, gt_root=None):
        # Kept as an assert so callers still see the same exception type.
        assert which_ds.lower() in self._VALID_SPLITS, (
            f'Unknown SVOX split {which_ds!r}; expected one of {self._VALID_SPLITS}'
        )
        print(f'Loading {which_ds} dataset')

        self.input_transform = input_transform

        # GT_ROOT is only consulted when the caller does not override it.
        svox_dir = Path(GT_ROOT if gt_root is None else gt_root) / 'SVOX'

        # reference images names
        self.dbImages = np.load(svox_dir / f'{which_ds}_dbImages.npy')

        # query images names
        self.qImages = np.load(svox_dir / f'{which_ds}_qImages.npy')

        # ground truth (presumably, per query, the matching reference
        # indices -- TODO confirm against the evaluation code).
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # ground-truth files that come from a trusted source.
        self.ground_truth = np.load(svox_dir / f'{which_ds}_gt.npy', allow_pickle=True)

        # reference images then query images
        self.images = np.concatenate((self.dbImages, self.qImages))

        self.num_references = len(self.dbImages)
        self.num_queries = len(self.qImages)

    def __getitem__(self, index):
        """Return ``(image, index, image_path)`` for the item at ``index``.

        The image is the PIL image opened from ``self.images[index]``, passed
        through ``input_transform`` when one was supplied.
        """
        img_path = self.images[index]
        # Image.open is lazy and keeps the file open. Decode inside the
        # context manager so the handle is released deterministically, even
        # when no transform forces a load; the pixel data remains usable.
        with Image.open(img_path) as img:
            img.load()
        if self.input_transform is not None:
            img = self.input_transform(img)

        return img, index, img_path

    def __len__(self):
        """Return the total number of images (references plus queries)."""
        return len(self.images)