import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T


class CustomDataset(Dataset):
    """Image-classification dataset backed by a whitespace-separated list file.

    Each non-blank line of ``txt_file`` has the form
    ``<relative_image_path> <integer_label>``; any further tokens on a line are
    ignored. Image paths are resolved relative to ``root_dir``.

    Args:
        txt_file: Path to the list file.
        root_dir: Directory prepended to every image path from the list file.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``.

    Raises:
        ValueError: If a non-blank line has fewer than two tokens or its
            label token is not an integer.
    """

    def __init__(self, txt_file, root_dir, transform=None):
        self.image_list = []  # image paths, relative to root_dir
        self.id_list = []     # integer class labels, parallel to image_list
        self.root_dir = root_dir
        self.transform = transform
        with open(txt_file, 'r') as f:
            for lineno, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    # Blank / whitespace-only lines (e.g. an extra trailing
                    # newline) describe no sample; skip instead of crashing.
                    continue
                if len(parts) < 2:
                    raise ValueError(
                        f"{txt_file}:{lineno}: expected '<image> <label>', "
                        f"got {line.strip()!r}"
                    )
                try:
                    label = int(parts[1])
                except ValueError as err:
                    raise ValueError(
                        f"{txt_file}:{lineno}: label {parts[1]!r} is not an integer"
                    ) from err
                self.image_list.append(parts[0])
                self.id_list.append(label)

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.id_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded from disk and converted to RGB. If a transform was
        supplied, it is applied to that RGB image before returning.
        """
        img_path = os.path.join(self.root_dir, self.image_list[idx])
        label = self.id_list[idx]
        # The context manager releases the file handle; convert() returns a
        # new, fully loaded image that remains valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, label
    

def galaxy_sdss(args, scale_size=256, target_size=224, *,
                root_dir="./Astro_DataSet/galaxy-sdss/",
                train_txt="./Dataset/galaxy_sdss_train.txt",
                test_txt="./Dataset/galaxy_sdss_test.txt"):
    """Build the Galaxy-SDSS train/test datasets and report the class count.

    Training images are random-resized-cropped to ``target_size`` and randomly
    flipped horizontally. Test images are resized to ``scale_size`` and
    center-cropped to ``target_size``. Both splits are converted to tensors and
    normalized with the same per-channel mean/std.

    Args:
        args: Accepted for call-site compatibility; not used by this function.
        scale_size: Size passed to ``T.Resize`` for test images before cropping.
        target_size: Output crop size for both splits.
        root_dir: Directory that the image paths in the list files are
            relative to.
        train_txt: List file (``<image> <label>`` per line) for the train split.
        test_txt: List file (``<image> <label>`` per line) for the test split.

    Returns:
        Tuple ``(train_data, test_data, num_class)`` with ``num_class == 5``.
    """
    # One normalization shared by both splits (previously duplicated inline).
    normalize = T.Normalize(mean=[0.1689, 0.1536, 0.1516],
                            std=[0.1284, 0.0963, 0.1051])
    # The original code passed the integer 3, which is PIL's BICUBIC constant;
    # torchvision documents the InterpolationMode enum for this parameter.
    bicubic = T.InterpolationMode.BICUBIC

    train_transforms = T.Compose([
        T.RandomResizedCrop(target_size, interpolation=bicubic),
        T.RandomHorizontalFlip(0.5),
        T.ToTensor(),
        normalize,
    ])

    test_transforms = T.Compose([
        T.Resize(scale_size, interpolation=bicubic),
        T.CenterCrop(target_size),
        T.ToTensor(),
        normalize,
    ])

    train_data = CustomDataset(txt_file=train_txt, root_dir=root_dir, transform=train_transforms)
    test_data = CustomDataset(txt_file=test_txt, root_dir=root_dir, transform=test_transforms)
    num_class = 5

    return train_data, test_data, num_class
