import pandas as pd
import pickle
import numpy as np
from PIL import Image
from pathlib import Path
from torchvision import transforms
from torch.utils.data import Dataset

class BaseDataset(Dataset):
    """Image dataset backed by a metadata DataFrame, with optional on-disk caching.

    Subclasses must populate ``self.image_list`` with entries whose first
    element is an image path and whose last element is the label;
    ``__getitem__`` reads the image from ``image_list[idx][0]`` and the label
    from ``image_list[idx][-1]``. The cache file name is derived from the
    ``'path'`` column of the ``idx``-th metadata row, so subclasses must keep
    ``metadata`` rows aligned with ``image_list``.

    Args:
        metadata: DataFrame with at least a ``'path'`` column.
        label_names: Stored as ``self.label_names`` for subclasses.
        conversion: Stored as ``self.conversion`` for subclasses.
        transform: Optional callable applied to the resized image on every
            access (after the cache, so random augmentations are not frozen).
        image_size: Side length that images are resized to before caching.
        cache: Whether to pickle resized images under ``cache_dir``.
        cache_dir: Directory for cache files; only used when ``cache`` is true.
    """

    def __init__(self, metadata, label_names, conversion, transform=None, image_size=224, cache=True, cache_dir=''):
        super().__init__()

        self.metadata = metadata
        self.label_names = label_names
        self.conversion = conversion
        self.transform = transform
        self.image_size = image_size
        self.cache = cache
        if cache:
            self.cache_dir = Path(cache_dir)

        self.n_files = len(metadata['path'])

    def get_cache_path(self, idx):
        """Return the ``.pkl`` cache file for metadata row ``idx``.

        The file name joins the last three components of the row's path with
        underscores and swaps the final suffix for ``.pkl``.
        """
        item = self.metadata.iloc[idx, :]
        path = Path(item['path'])
        return (self.cache_dir / '_'.join(path.parts[-3:])).with_suffix('.pkl')

    def ensure_3dim(self, img):
        """Return ``img`` as an RGB (3-channel) PIL image.

        Checks ``mode`` rather than ``size``: a PIL ``size`` is always
        ``(width, height)``, so it cannot tell grayscale from colour images.
        """
        if img.mode != 'RGB':
            img = img.convert('RGB')
        return img

    def _write_cache(self, cache_path, payload):
        """Pickle ``payload`` to ``cache_path``, creating parent directories."""
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = cache_path.with_name(cache_path.name + '.tmp')
        with tmp_path.open('wb') as fh:
            pickle.dump(payload, fh)
        # Rename into place so an interrupted write does not leave a
        # truncated .pkl behind that later loads would fail on.
        tmp_path.replace(cache_path)

    def __getitem__(self, idx):
        """Return ``{'labels': ..., 'x': image, 'idx': idx}`` for sample ``idx``.

        On a cache miss the image is opened, converted to RGB and resized to
        ``image_size`` x ``image_size``. If caching is enabled, the result is
        pickled together with its metadata row. ``self.transform`` is applied
        afterwards on every call.
        """
        cache_path = self.get_cache_path(idx) if self.cache else None

        if cache_path is not None and cache_path.is_file():
            # NOTE: unpickling can execute arbitrary code; cache_dir must only
            # contain files written by this class.
            with cache_path.open('rb') as fh:
                img, _ = pickle.load(fh)
        else:
            with Image.open(self.image_list[idx][0]) as raw:
                # Resize inside the ``with``: it produces a new, loaded image,
                # so closing the source file afterwards is safe.
                img = transforms.Resize([self.image_size, self.image_size])(self.ensure_3dim(raw))

            if cache_path is not None:
                item = self.metadata.iloc[idx].to_dict()
                self._write_cache(cache_path, (img, item))

        label = self.image_list[idx][-1]

        if self.transform is not None:  # apply image augmentations after caching
            img = self.transform(img)

        return {
            'labels': label,
            'x': img,
            'idx': idx,
        }

    def __len__(self):
        return self.n_files

class MultiClassDataset(BaseDataset):
    """Single-label dataset built from a ``{class_key: iterable_of_paths}`` mapping.

    After construction, ``image_list`` holds ``(path, class_key)`` pairs ordered
    by sorted class key. ``metadata`` is reordered (on a copy) so that its
    ``idx``-th row has the ``idx``-th path of ``image_list``.
    """

    def __init__(self, metadata, label_names, image_dict, conversion, transform=None, image_size=224, cache=True, cache_dir=''):
        super().__init__(metadata, label_names, conversion, transform, image_size, cache, cache_dir)

        self.image_dict = image_dict

        self.init_setup()

        ##### Re-indexing metadata to match image_list
        if hasattr(self, 'image_list'):
            order = [entry[0] for entry in self.image_list]
            # ``assign`` returns a new frame, so the caller's DataFrame is left
            # untouched. Assigning the column in place would make the caller's
            # 'path' column categorical and turn paths absent from ``order``
            # into NaN.
            self.metadata = self.metadata.assign(
                path=pd.Categorical(self.metadata['path'], categories=order)
            ).sort_values('path')

    def init_setup(self):
        """Flatten ``image_dict`` into ``image_list`` and assign global indices.

        Afterwards ``image_dict[key]`` is a list of ``[path, global_idx]`` and
        ``image_list[global_idx] == (path, key)``. Classes are visited in
        sorted key order.
        """
        # Built-in sum keeps n_files an int; np.sum([]) returns the float 0.0,
        # which len() rejects when image_dict is empty.
        self.n_files = sum(len(paths) for paths in self.image_dict.values())
        self.avail_classes = sorted(self.image_dict.keys())

        indexed_dict = {}
        image_list = []
        for key in self.avail_classes:
            indexed_dict[key] = []
            for path in self.image_dict[key]:
                indexed_dict[key].append([path, len(image_list)])
                image_list.append((path, key))

        self.image_dict = indexed_dict
        self.image_list = image_list

        self.image_paths = self.image_list

        self.is_init = True

class MultiLabelDataset(BaseDataset):
    """Multi-label dataset whose labels are one metadata column per class.

    ``conversion`` maps a label number to the metadata column holding that
    label. When omitted, it defaults to ``dict(enumerate(label_names))``,
    matching ``avail_classes == list(range(len(label_names)))``.
    """

    def __init__(self, metadata, label_names, conversion=None, transform=None, image_size=224, cache=True, cache_dir=''):
        super().__init__(metadata, label_names, conversion, transform, image_size, cache, cache_dir)

        self.init_setup()

    def init_setup(self):
        """Build ``image_dict`` (per-label row lists) and ``image_list``.

        ``image_dict[key]`` lists ``(path, row_position)`` for every row whose
        ``conversion[key]`` column is truthy. ``image_list[i]`` is
        ``(path, multi_hot_row)`` taken from the ``label_names`` columns.
        """
        # The default ``conversion=None`` previously crashed on ``.items()``.
        if self.conversion is None:
            self.conversion = dict(enumerate(self.label_names))

        self.n_files = len(self.metadata['path'])
        self.avail_classes = list(range(len(self.label_names)))

        paths = self.metadata['path'].values
        index_dict = {key: np.where(self.metadata[name].values)[0] for key, name in self.conversion.items()}
        self.image_dict = {key: list(zip(paths[index_list], index_list)) for key, index_list in index_dict.items()}  ## of form label #: (path, idx)
        self.image_list = list(zip(paths, self.metadata[self.label_names].values))  ## of form (path, multi-hot(label))

        self.image_paths = self.image_list

        self.is_init = True
