import os
import pickle

from torch.utils.data import Dataset


class ShapesDataset(Dataset):
    '''
    Dataset backed by a pickled shape archive.

    The archive unpickles to a ``(dataset, label_map)`` pair. ``dataset`` is
    indexed by sample number and every sample is a mapping with the keys
    ``'video'``, ``'audio'`` and ``'label'``. ``label_map`` is indexed by a
    numeric label and yields its name.

    Two access modes are supported:

    * clip mode (default): one item per sample, returned as a dict with the
      keys ``'video'``, ``'audio'`` and ``'label'``.
    * stills mode (``return_stills=True``): one item per video frame, returned
      as a dict with the keys ``'image'`` and ``'label'``.

    Every item is built once and memoised in ``self.cache``; later lookups
    return the very same dict object.
    '''

    def __init__(
        self,
        archive:str,
        return_stills:bool=False,
        return_video_shape:tuple=(0, 1, 2, 3),
        return_audio_shape:tuple=(0, 1, 2),
    ):
        '''

        Args:
            archive: File containing the shape dataset, generated by evaluation/bin/gen_shape_dataset. (expands ~ if any)
            return_stills: Split videos into still frames. Useful when training a simple classifier.
            return_video_shape: Axis order passed to ``transpose`` for video samples. Defaults to their "natural" shape.
            return_audio_shape: Axis order passed to ``transpose`` for audio samples. Defaults to their "natural" shape.

        Raises:
            FileNotFoundError: If ``archive`` does not exist.
        '''
        super().__init__()
        self.return_stills = return_stills
        self.video_shape = tuple(return_video_shape)
        self.audio_shape = tuple(return_audio_shape)

        # The documented contract is that '~' is expanded; do it before any
        # filesystem access so '~/data/shapes.pkl' actually resolves.
        archive = os.path.expanduser(archive)
        if not os.path.exists(archive):
            raise FileNotFoundError('Invalid shape archive. I cannot continue.')

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load archives that come from a trusted source.
        with open(archive, 'rb') as data:
            self.dataset, self.label_map = pickle.load(data)

        self.cache = {}
        if self.return_stills:
            # The frame count of the first sample is used for every sample.
            # NOTE(review): assumes all videos have the same number of frames
            # -- confirm against the archive generator.
            self.offset = len(self.dataset[0]['video']) if len(self.dataset) else 0
            self.length = len(self.dataset) * self.offset
        else:
            self.length = len(self.dataset)


    def __len__(self) -> int:
        '''Number of items: frames in stills mode, samples otherwise.'''
        return self.length


    def __getitem__(self, idx:int) -> dict:
        '''
        Return the item at position ``idx``.

        Args:
            idx: Flat frame index in stills mode, sample index otherwise.

        Returns:
            ``{'image', 'label'}`` in stills mode, ``{'video', 'audio', 'label'}``
            otherwise. The dict is cached and shared between calls.
        '''
        if self.return_stills:
            return self._get_still(idx)
        return self._get_clip(idx)


    def _get_still(self, idx:int) -> dict:
        '''Return (and cache) the single frame addressed by the flat index ``idx``.'''
        if not self.offset:
            # Nothing to address: empty dataset or zero-frame videos.
            raise IndexError('ShapesDataset index out of range')

        # Flat index -> (sample number, frame number within that sample).
        img_set_idx = idx // self.offset
        img_idx = idx - img_set_idx * self.offset

        frames = self.cache.setdefault(img_set_idx, {})
        if img_idx in frames:
            return frames[img_idx]

        known = self.dataset[img_set_idx]
        # Copies keep the cached frame independent from the loaded archive.
        data = {
            'image': known['video'][img_idx].copy(),
            'label': known['label'][img_idx].copy(),
        }
        frames[img_idx] = data
        return data


    def _get_clip(self, idx:int) -> dict:
        '''Return (and cache) the whole sample ``idx`` with the requested axis order.'''
        if idx in self.cache:
            return self.cache[idx]

        original = self.dataset[idx]
        data = {
            'video': original['video'].transpose(self.video_shape),
            'audio': original['audio'].transpose(self.audio_shape),
            'label': original['label'],
        }
        self.cache[idx] = data
        return data


    def label_for(self, num:int) -> str:
        '''Return the entry of the archive's label map for numeric label ``num``.'''
        return self.label_map[num]


    def class_count(self) -> int:
        '''Return the number of entries in the archive's label map.'''
        return len(self.label_map)
