import os
import os.path as osp
import random
from pathlib import Path
from typing import Callable, Optional

import numpy as np
import pandas as pd
import scipy.io as scio
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

SUN_PATH = 'path/to/your/dataset/'


class SUNDataset(Dataset):
    """SUN397 scene dataset built from the text files in ``splits/``.

    Layout read under ``image_dir``:
      * ``splits/<image_split>``: one image path per line.
      * ``splits/ClassName.txt``: one class directory per line. A class's label
        is its 0-based index among the non-blank lines of this file.
      * ``SUN397/``: the image root that split paths are joined onto.

    After optional per-class subsampling (``ratio``), the kept rows are stored
    in ``self.sampled``, a DataFrame with columns ``path`` and ``label``.
    """

    def __init__(self, image_dir: str = SUN_PATH, image_split: str = 'Training_01.txt',
                 ratio: float = 1, transform: Optional[Callable] = None):
        """
        Args:
            image_dir (str): dataset root containing ``splits/`` and ``SUN397/``.
            image_split (str): split file inside ``splits/``, e.g.
                ``Training_01.txt`` or ``Testing_01.txt``.
            ratio (float): fraction of images kept per class. It is forwarded
                to ``DataFrameGroupBy.sample(frac=...)``; 1 keeps every image.
            transform (callable, optional): applied to each RGB PIL image in
                ``__getitem__``.

        Raises:
            KeyError: if a split entry's parent directory is not listed in
                ``ClassName.txt``.
        """
        self.image_dir = image_dir
        self.transform = transform

        rel_paths = self._read_lines(osp.join(image_dir, 'splits', image_split))
        self.label_name = self._read_lines(osp.join(image_dir, 'splits', 'ClassName.txt'))
        mapping = {name: idx for idx, name in enumerate(self.label_name)}

        # as_posix() keeps the '/'-separated form used by ClassName.txt. On
        # Windows, str(Path(...)) would produce backslashes and miss every key.
        self.label_list = [mapping[Path(p).parent.as_posix()] for p in rel_paths]
        # A leading '/' would make osp.join discard everything before it, so
        # strip it before joining under SUN397/.
        self.image_list = [osp.join(image_dir, 'SUN397', p.lstrip('/')) for p in rel_paths]

        # Few-shot subsampling: keep `ratio` of every class, reproducibly.
        df = pd.DataFrame({'path': self.image_list, 'label': self.label_list})
        self.sampled = df.groupby('label').sample(frac=ratio, random_state=42)

    @staticmethod
    def _read_lines(path):
        """Return the stripped, non-blank lines of the UTF-8 text file at ``path``."""
        with open(path, encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            return [line for line in stripped if line]

    def __len__(self):
        """Number of rows kept after subsampling."""
        return self.sampled.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th sampled row.

        The image is loaded with PIL and converted to RGB. It is then passed
        through ``self.transform`` when one was given. An out-of-range ``idx``
        raises IndexError from ``DataFrame.iloc``.
        """
        file_path = self.sampled.iloc[idx, 0]
        label = self.sampled.iloc[idx, 1]
        # The context manager guarantees the file handle is released.
        with Image.open(file_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label




from enum import Enum
class DATASPLIT(Enum):
    """Split codes matched against the per-image split field of DTD's imdb.mat."""
    TRAIN = 1
    VAL = 2
    TEST = 3


DTD_PATH = 'path/to/your/dataset/'


class DTDDataset(Dataset):
    """Describable Textures Dataset (DTD) indexed by ``imdb/imdb.mat``.

    Image names, split codes and class ids are read positionally from the
    ``images`` struct in ``imdb.mat``. Images are loaded from
    ``<image_dir>/images/<name>``. After optional per-class subsampling
    (``ratio``), the kept rows are stored in ``self.sampled``, a DataFrame with
    columns ``path`` and ``label``.
    """

    def __init__(self, image_dir: str = DTD_PATH, image_split='TRAIN',
                 ratio: float = 1, transform: Optional[Callable] = None):
        """
        Args:
            image_dir (str): dataset root containing ``imdb/imdb.mat`` and
                ``images/``.
            image_split (str | DATASPLIT): ``'TRAIN'``, ``'VAL'`` or ``'TEST'``
                (case-insensitive), or a ``DATASPLIT`` member.
            ratio (float): fraction of images kept per class. It is forwarded
                to ``DataFrameGroupBy.sample(frac=...)``; 1 keeps every image.
            transform (callable, optional): applied to each RGB PIL image in
                ``__getitem__``.

        Raises:
            ValueError: if ``image_split`` does not name a ``DATASPLIT`` member.
        """
        self.image_dir = image_dir
        self.transform = transform
        # Validate before touching the disk so a typo fails fast and clearly.
        split = self._resolve_split(image_split)

        images = scio.loadmat(osp.join(image_dir, 'imdb', 'imdb.mat'))['images'][0][0]
        # Positional fields of the `images` struct:
        # [1] file names (each a 1-element array), [2] split code, [3] class id.
        names = np.array([entry[0] for entry in images[1][0]])
        split_codes = images[2][0]
        class_ids = images[3][0]
        in_split = split_codes == split.value

        self.image_list = [osp.join(image_dir, 'images', name) for name in names[in_split]]
        # Shift to 0-based labels (ids are presumably MATLAB-style 1-based).
        self.label_list = class_ids[in_split] - 1

        # Few-shot subsampling: keep `ratio` of every class, reproducibly.
        df = pd.DataFrame({'path': self.image_list, 'label': self.label_list})
        self.sampled = df.groupby('label').sample(frac=ratio, random_state=42)

    @staticmethod
    def _resolve_split(image_split):
        """Map a split name (any case) or a ``DATASPLIT`` member to a ``DATASPLIT``."""
        if isinstance(image_split, DATASPLIT):
            return image_split
        try:
            return DATASPLIT[str(image_split).upper()]
        except KeyError:
            valid = ', '.join(DATASPLIT.__members__)
            raise ValueError(
                f"image_split must be one of {valid}, got {image_split!r}"
            ) from None

    def __len__(self):
        """Number of rows kept after subsampling."""
        return self.sampled.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th sampled row.

        The image is loaded with PIL and converted to RGB. It is then passed
        through ``self.transform`` when one was given. An out-of-range ``idx``
        raises IndexError from ``DataFrame.iloc``.
        """
        file_path = self.sampled.iloc[idx, 0]
        label = self.sampled.iloc[idx, 1]
        # The context manager guarantees the file handle is released.
        with Image.open(file_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label


if __name__ == '__main__':
    # Smoke test: load the default (TRAIN) DTD split. Each image is resized to
    # 224x224, converted to a tensor and normalized with mean/std 0.5, then its
    # shape is printed with its label.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    dtd_train = DTDDataset(transform=preprocess)
    # The dataset defines no __iter__, so iteration uses the sequence protocol.
    # It stops when __getitem__ raises IndexError (pandas .iloc past the end).
    for img, lbl in dtd_train:
        print(img.shape, lbl)