from pathlib import Path
import numpy as np
from PIL import Image
from torch.utils.data import Dataset


# Root directory under which the image files live; AmstertimeDataset reads
# images from the ``Amstertime/`` subfolder of this path.
DATASET_ROOT = '/database/dkim/VPR_datasets/'
# Root directory of the ``.npy`` files holding image names and ground truth.
GT_ROOT = '/home/dkim/VPR/salad/datasets/'

# Fail fast at import time when the dataset root is missing.
path_obj = Path(DATASET_ROOT)
if not path_obj.exists():
    raise Exception(f'Please make sure the path {DATASET_ROOT} to Amstertime dataset is correct')

# NOTE(review): ``Path.joinpath`` returns a Path object, which is always
# truthy, so this condition never fires and the check is effectively a no-op.
# It is left unchanged on purpose: images are loaded from
# ``DATASET_ROOT/Amstertime/...``, so testing ``DATASET_ROOT/ref`` with
# ``.exists()`` could reject a working layout at import time. Confirm where
# ``ref`` and ``query`` actually live before turning this into a real check.
if not path_obj.joinpath('ref') or not path_obj.joinpath('query'):
    raise Exception(f'Please make sure the directories query and ref are situated in the directory {DATASET_ROOT}')

# PYRA
class AmstertimeDataset(Dataset):
    """Dataset over the Amstertime reference and query images.

    Image names and ground truth are read from ``.npy`` files located in
    ``GT_ROOT/Amstertime``. The image files themselves are opened on demand
    from ``DATASET_ROOT/Amstertime/`` by ``__getitem__``.

    Items are ordered references first, then queries, so indices
    ``[0, num_references)`` are references and the remaining
    ``num_queries`` indices are queries.

    Attributes:
        input_transform: Optional callable applied to every loaded PIL image.
        dbImages: Array of reference image names.
        qImages: Array of query image names.
        ground_truth: Array loaded from ``Amstertime_gt.npy`` (presumably the
            matching references for each query -- confirm against the
            evaluation code).
        images: ``dbImages`` followed by ``qImages``.
        num_references: Number of reference images.
        num_queries: Number of query images.
    """

    def __init__(self, input_transform=None):
        """Load the image-name lists and the ground truth.

        Args:
            input_transform: Optional callable applied to each PIL image
                before it is returned by ``__getitem__``.
        """
        self.input_transform = input_transform

        gt_dir = Path(GT_ROOT) / "Amstertime"

        # reference images names
        self.dbImages = np.load(gt_dir / "Amstertime_dbImages.npy")

        # query images names
        self.qImages = np.load(gt_dir / "Amstertime_qImages.npy")

        # ground truth
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # ground-truth files that come from a trusted source.
        self.ground_truth = np.load(gt_dir / "Amstertime_gt.npy", allow_pickle=True)

        # reference images then query images
        self.images = np.concatenate((self.dbImages, self.qImages))
        self.num_references = len(self.dbImages)
        self.num_queries = len(self.qImages)

    def __getitem__(self, index):
        """Load one image.

        Args:
            index: Position in ``self.images`` (references first, then
                queries).

        Returns:
            A tuple ``(img, index, img_path)`` where ``img`` is the RGB image
            (after ``input_transform`` when one was given) and ``img_path``
            is the path string the image was read from.

        Raises:
            ValueError: If the path contains neither ``ref/`` nor ``query/``.
        """
        # Build the image path: DATASET_ROOT joined with the relative path
        # stored in the .npy file.
        img_path = DATASET_ROOT + 'Amstertime/' + self.images[index]
        path_text = str(img_path)
        if 'ref/' not in path_text and 'query/' not in path_text:
            raise ValueError(f'Please make sure the directories query and ref are situated in the directory {DATASET_ROOT}')

        # Convert every image (reference as well as query) to RGB so that all
        # items have the same number of channels regardless of how the file
        # is stored; the context manager releases the file handle promptly.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.input_transform:
            img = self.input_transform(img)

        return img, index, img_path

    def __len__(self):
        """Return the total number of images (references plus queries)."""
        return len(self.images)