import os
from PIL import Image
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split


class EuroSATDataset(Dataset):
    """EuroSAT image-classification dataset with a reproducible per-class split.

    Images are read from ``<root>/2750/<class_name>/<image_file>``. Each class
    folder is split on its own, so class proportions are kept. The split uses
    scikit-learn's ``train_test_split`` with a fixed seed, so it is the same on
    every run as long as the directory contents do not change.

    Args:
        root: Directory that contains the extracted ``2750`` folder.
        split: ``'train'`` selects the training portion of every class. Any
            other value selects the held-out portion.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
        test_size: Held-out fraction (float) or count (int) per class. It is
            forwarded to ``train_test_split``. Defaults to ``0.2``.
        random_state: Seed forwarded to ``train_test_split``. Defaults to
            ``42``.

    Attributes:
        classes: Sorted class-folder names. A class's position is its label.
        class_to_idx: Mapping from class name to integer label.
        data: Image paths relative to ``<root>/2750``.
        targets: Integer labels, aligned index-by-index with ``data``.
    """

    def __init__(self, root, split='train', transform=None, *,
                 test_size=0.2, random_state=42):
        self.root = root
        self.transform = transform
        self.split = split

        self.data_root = os.path.join(self.root, "2750")
        self.classes = self._list_class_dirs(self.data_root)
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        self.data = []
        self.targets = []

        for label, cls_name in enumerate(self.classes):
            class_dir = os.path.join(self.data_root, cls_name)
            # Sorting makes the index -> file mapping, and therefore the
            # seeded split, independent of filesystem listing order.
            images = sorted(os.listdir(class_dir))
            train_idx, test_idx = train_test_split(
                list(range(len(images))),
                test_size=test_size,
                random_state=random_state,
            )
            selected = train_idx if split == 'train' else test_idx

            for j in selected:
                self.data.append(os.path.join(cls_name, images[j]))
                self.targets.append(label)

    @staticmethod
    def _list_class_dirs(data_root):
        """Return the sorted names of the sub-directories of ``data_root``.

        Non-directory entries (for example ``.DS_Store``) are skipped.
        Otherwise they would be treated as classes, and listing them as
        folders would raise ``NotADirectoryError``.
        """
        return sorted(
            entry for entry in os.listdir(data_root)
            if os.path.isdir(os.path.join(data_root, entry))
        )

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        ``image`` is the file converted to RGB and then passed through
        ``transform`` if one was given. ``target`` is the integer class label.
        """
        img_path = os.path.join(self.data_root, self.data[index])
        # The context manager closes the file handle deterministically.
        # convert() produces an in-memory image that outlives the file.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(img)

        return img, target
