import os
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

# Root directory under which the Eynsham .npy files (image lists and ground
# truth) are looked up (GT_ROOT).
GT_ROOT = '/home/dkim/VPR/salad/datasets/'

class EynshamDataset(Dataset):
    """Eynsham evaluation dataset: database images followed by query images.

    Three ``.npy`` files are read from ``<gt_root>/Eynsham/``:

    * ``Eynsham_dbImages.npy`` -- entries for the database (reference) images
    * ``Eynsham_qImages.npy``  -- entries for the query images
    * ``Eynsham_gt.npy``       -- ground truth, loaded with ``allow_pickle=True``

    The entries of the two image arrays are passed unchanged to
    ``PIL.Image.open``, so they must be paths that resolve from the current
    working directory (or absolute paths).

    Attributes:
        input_transform: Optional callable applied to every loaded image.
        dbImages: Array loaded from ``Eynsham_dbImages.npy``.
        qImages: Array loaded from ``Eynsham_qImages.npy``.
        ground_truth: Array loaded from ``Eynsham_gt.npy``.
            NOTE(review): presumably the matching database indices for each
            query -- confirm against the evaluation code.
        images: ``dbImages`` concatenated with ``qImages`` (database first).
        num_references: Number of database images.
        num_queries: Number of query images.
    """

    def __init__(self, input_transform=None, *, gt_root=None):
        """Load the image lists and the ground truth.

        Args:
            input_transform: Optional callable applied to each RGB PIL image
                in ``__getitem__``. ``None`` returns the PIL image as is.
            gt_root: Directory containing the ``Eynsham/`` folder with the
                ``.npy`` files. ``None`` (default) uses the module-level
                ``GT_ROOT``.
        """
        super().__init__()
        self.input_transform = input_transform

        # Resolved at call time so the module default stays the fallback.
        root = GT_ROOT if gt_root is None else gt_root

        # Load the .npy files (database images, query images, ground truth).
        self.dbImages = np.load(os.path.join(root, 'Eynsham/Eynsham_dbImages.npy'))
        self.qImages = np.load(os.path.join(root, 'Eynsham/Eynsham_qImages.npy'))
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only point
        # this at ground-truth files from a trusted source.
        self.ground_truth = np.load(
            os.path.join(root, 'Eynsham/Eynsham_gt.npy'), allow_pickle=True
        )

        # Full image list: database images first, then query images.
        self.images = np.concatenate((self.dbImages, self.qImages))

        self.num_references = len(self.dbImages)
        self.num_queries = len(self.qImages)

    def __getitem__(self, index):
        """Return ``(image, index, image_path)`` for the given position.

        Indices ``[0, num_references)`` address database images; the
        following ``num_queries`` indices address query images.
        """
        img_path = self.images[index]
        # The context manager closes the file handle deterministically;
        # convert() yields an independent, fully loaded image that remains
        # valid after the source file is closed.
        with Image.open(img_path) as source:
            img = source.convert('RGB')
        if self.input_transform is not None:
            img = self.input_transform(img)
        return img, index, img_path

    def __len__(self):
        """Return the total number of images (database + query)."""
        return len(self.images)
