from scipy import rand
from torch.utils.data import Dataset
import torch
import torchvision.transforms as transforms
import data_generate.transformations as transfm

from PIL import Image
from tqdm import tqdm
from itertools import chain
import pandas as pd
import glob
import random
import time
import os

class FewShotImageDataset(Dataset):
    """Dataset that eagerly loads every image of a few-shot task into memory.

    Each entry of ``task_list`` is a class directory; the position of the
    directory in the list becomes the integer class label.  All images found
    under the directories are opened, optionally transformed and stored in a
    pandas DataFrame (``self.df``) with the columns ``cls_name``, ``cls_lbl``,
    ``img_path`` and ``cuda_tensor``.

    Args:
        task_list: sequence of class-directory paths.
        img_lvl: number of ``/*`` glob levels between a class directory and
            its image files (1 means the images sit directly in the directory).
        transform: optional callable applied to every PIL image.  May be a
            ``Compose``-like object exposing ``.transforms`` or a single
            transform.
        device: device on which the label tensor is created in
            ``__getitem__``.
        task_name: name of the task (stored; see ``_MAX_IMGS_PER_CLASS``).
        verbose: label shown in the progress bar; ``None`` disables the bar.

    Attributes:
        relabel: not set by ``__init__``.  ``relbl_df`` expects callers to
            assign a ``(column_name, values)`` pair before invoking it.

    NOTE(review): despite its name, the ``cuda_tensor`` column holds whatever
    ``transform`` returns; this class never moves it to ``device`` itself.
    """

    # Per-task cap on the number of images per class.  Kept for reference
    # only: the random sub-sampling step that consumed it is disabled, so the
    # cap is currently NOT applied when building the DataFrame.
    _MAX_IMGS_PER_CLASS = {
        'cu_birds': 60,
        'texture': 120,
        'aircraft': 100,
        'fungi': 150,
    }

    def __init__(self, task_list, img_lvl=1, transform=None, device=None, task_name=None, verbose='dataset'):
        self.task_list = task_list
        self.img_lvl = img_lvl
        self.transform = transform
        self.device = device
        self.task_name = task_name
        self.verbose = verbose

        self.df = self.generate_task_df()

    def __len__(self):
        """Return the number of images kept in the DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row whose index label is ``index``.

        The image is returned exactly as stored; the label is wrapped in a
        tensor created on ``self.device``.
        """
        image = self.df.loc[index, 'cuda_tensor']
        label = torch.tensor(self.df.loc[index, 'cls_lbl'], device=self.device)

        return image, label

    def generate_task_df(self):
        """Load all images of the task and return them as a DataFrame.

        Images whose transformed tensor contains NaN values are skipped.

        Returns:
            pandas.DataFrame with one row per kept image (empty when no image
            was found).
        """
        rows = []

        # The transform pipeline does not change between images, so decide
        # once whether grayscale inputs must stay single-channel.
        keep_grayscale = self._has_grayscale_transform()

        class_iter = enumerate(self.task_list)
        if self.verbose is not None:
            class_iter = tqdm(class_iter, desc='Generating {}'.format(self.verbose),
                              total=len(self.task_list))

        for idx, classdir in class_iter:
            # glob returns files in filesystem order; sort for reproducibility
            img_path_list = sorted(glob.glob(classdir + '/*' * self.img_lvl))

            for img_path in img_path_list:
                image = self._load_image(img_path, keep_grayscale)

                # the NaN check only makes sense once the image is a tensor
                if torch.is_tensor(image) and torch.isnan(image).any():
                    continue

                path_split = self.split_path(img_path)

                rows.append({
                    'cls_name': path_split[-2],
                    'cls_lbl': idx,
                    'img_path': img_path,
                    'cuda_tensor': image
                })

        return pd.DataFrame(rows)

    def _has_grayscale_transform(self):
        """Return True when the configured transform contains a ``Grayscale`` step."""
        if self.transform is None:
            return False
        # Compose-like objects expose their steps via ``.transforms``; a bare
        # transform is treated as a one-step pipeline.
        steps = getattr(self.transform, 'transforms', [self.transform])
        return any(step.__class__.__name__ == 'Grayscale' for step in steps)

    def _load_image(self, img_path, keep_grayscale):
        """Open ``img_path``, normalise its colour mode and apply the transform."""
        # the context manager releases the file handle once the pixels are read
        with Image.open(img_path) as image:
            is_gray = image.mode in ('1', 'L')
            # Grayscale images become RGB unless the pipeline handles grayscale
            # itself; any other non-RGB mode (e.g. four-channel images) is
            # always converted to RGB.
            if (is_gray and not keep_grayscale) or (not is_gray and image.mode != 'RGB'):
                image = image.convert(mode='RGB')
            else:
                # Image.open is lazy: read the pixel data before the file closes
                image.load()
            if self.transform is not None:
                image = self.transform(image)
        return image

    def relbl_df(self):
        """Rewrite ``cls_lbl`` according to ``self.relabel``.

        ``self.relabel`` must be a ``(column_name, values)`` pair: every row
        whose ``column_name`` equals ``values[i]`` receives the label ``i``.

        Raises:
            AttributeError: if ``self.relabel`` has not been set.
        """
        relabel = getattr(self, 'relabel', None)
        if relabel is None:
            raise AttributeError(
                "'relabel' must be set to a (column_name, values) pair before calling relbl_df()")
        if self.df.empty:
            return

        column, values = relabel[0], relabel[1]
        # Match against a snapshot so that earlier assignments cannot influence
        # later matches when the relabel column is 'cls_lbl' itself.
        source = self.df[column].copy()
        for new_lbl, value in enumerate(values):
            self.df.loc[source == value, 'cls_lbl'] = new_lbl

    def split_path(self, path):
        """Split ``path`` into its components.

        The path is normalised first; empty components (from a leading
        separator) and ``'.'`` are dropped.
        """
        # normpath emits the OS-specific separator, so split on os.sep
        parts = os.path.normpath(path).split(os.sep)
        return [part for part in parts if part not in ('', '.')]