import logging
import os
import pickle
import shutil
import sys
import tempfile
from pathlib import Path
from typing import Optional
from zipfile import ZipFile, is_zipfile

import numpy
from torch.utils.data import Dataset

import av


class RAVDESSStillsDataset(Dataset):
    '''
    Torch dataset of still frames sampled from RAVDESS audio-video recordings.

    Each example is a ``(still, label)`` pair. ``still`` is the decoded frame
    resized to 320x180 and transposed to channels-first, i.e. shape
    ``(C, 180, 320)`` (C is presumably 3 for RGB frames). ``label`` is the
    0-based emotion index; see ``emotions``.
    '''

    # 0-based emotion index -> name. RAVDESS file names encode these as 1-based codes.
    emotions = {
        0: 'neutral',
        1: 'calm',
        2: 'happiness',
        3: 'sadness',
        4: 'anger',
        5: 'fear',
        6: 'disgust',
        7: 'surprise'
    }

    # File name, inside work_dir, of the pickle holding the precomputed examples.
    _CACHE_NAME = 'ravdess_stills_dataset_precompute.bin'

    def __init__(
        self,
        root_dir: str,
        work_dir: Optional[str] = None,
    ):
        '''
        Build the dataset, or restore it from a cache in ``work_dir``.

        RAVDESS file form: width 1280, height 720, 3 channels.

        On the first run, every ZIP archive found under ``root_dir`` is scanned.
        Full audio-video recordings (file names starting with ``01``) are
        extracted to ``work_dir/<archive stem>/<member path>``, and one frame
        every ``int(fps)`` frames (roughly one per second) is kept from each.
        A non-empty result is pickled to
        ``work_dir/ravdess_stills_dataset_precompute.bin``, and later runs load
        that file instead. The cache is keyed only on ``work_dir``. Delete it
        to rebuild after changing ``root_dir``.

        Args:
            root_dir: Directory with a tree of RAVDESS archives (ZIP archives). (expands ~ if any)
            work_dir: Directory where videos are extracted and the cache is
                stored (expands ~ if any). Defaults to the system temp
                directory, resolved when the dataset is created.

        Raises:
            ValueError: if an audio-video file name does not encode a valid
                emotion code.
        '''
        if work_dir is None:
            work_dir = tempfile.gettempdir()
        work_dir = os.path.expanduser(work_dir)
        ds_cache = os.path.join(work_dir, self._CACHE_NAME)
        if os.path.exists(ds_cache):
            logging.info("Restore dataset from cache")
            # SECURITY: unpickling can execute arbitrary code. The default work_dir
            # is the shared system temp directory; only trust this cache when no
            # other user can write to work_dir.
            with open(ds_cache, 'rb') as cache:
                self.examples = pickle.load(cache)
        else:
            logging.info("Prepare and cache dataset...")
            self.examples = self._build_examples(os.path.expanduser(root_dir), work_dir)
            logging.info(f"Dataset prepared in {work_dir}")
            if self.examples:
                self._write_cache(ds_cache)
            else:
                # Caching an empty list would make every later run "restore"
                # nothing, even after root_dir is corrected.
                logging.warning("No examples found; not writing a cache so the next run rebuilds")

        if len(self) > 0:
            logging.info(f"Dataset ready for use with Torch, cached in {ds_cache}")
        else:
            logging.error("Dataset is empty. Will not return any data")

    def _build_examples(self, root_dir: str, work_dir: str) -> list:
        '''Extract and sample every full audio-video recording in the ZIP archives under ``root_dir``.'''
        examples = []
        for dirpath, _dirnames, filenames in os.walk(root_dir):
            for filename in filenames:
                path = os.path.join(dirpath, filename)
                if not is_zipfile(path):
                    continue
                dest_base = os.path.join(work_dir, os.path.splitext(filename)[0])
                with ZipFile(path) as archive:
                    for video in archive.namelist():
                        name = os.path.basename(video)
                        # Directory entries have an empty basename and are skipped here too.
                        if not name.startswith('01'):
                            logging.debug(f'Skipping non-AV file ({video})')
                            continue
                        # Parse first so a malformed name fails before any extraction work.
                        class_code = self._name_to_emotion(name)
                        dest = self._extract_member(archive, video, dest_base)
                        if dest is not None:
                            examples.extend(self._sample_stills(dest, class_code))
        return examples

    @staticmethod
    def _extract_member(archive: ZipFile, member: str, dest_base: str) -> Optional[str]:
        '''Extract ``member`` under ``dest_base`` unless already present; return its path, or None if it would escape ``dest_base``.'''
        dest = os.path.join(dest_base, member)
        # Guard against "../" or absolute member names writing outside dest_base.
        base_abs = os.path.abspath(dest_base)
        try:
            inside = os.path.commonpath([base_abs, os.path.abspath(dest)]) == base_abs
        except ValueError:  # e.g. different drives on Windows
            inside = False
        if not inside:
            logging.warning(f"Skipping archive entry outside the work directory: {member}")
            return None

        Path(os.path.dirname(dest)).mkdir(parents=True, exist_ok=True)
        if os.path.exists(dest):
            logging.debug(f"Skipping already existing file: {dest}")
            return dest
        # Stream into a temporary name and rename at the end, so an interrupted
        # extraction never leaves a truncated file that looks complete.
        partial = dest + '.part'
        with archive.open(member, 'r') as src, open(partial, 'wb') as target:
            shutil.copyfileobj(src, target)
        os.replace(partial, dest)
        return dest

    @staticmethod
    def _sample_stills(video_path: str, class_code: int) -> list:
        '''Decode ``video_path`` and return ``(still, class_code)`` for one frame every ``int(fps)`` frames.'''
        stills = []
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            # int() truncates fractional rates (e.g. 30000/1001 -> 29). A missing
            # rate or one below 1 would otherwise crash or divide by zero.
            fps = int(stream.framerate or 0)
            if fps < 1:
                logging.warning(f"Unknown frame rate, skipping video: {video_path}")
                return stills
            for idx, frame in enumerate(container.decode(stream)):
                if idx % fps == 0:
                    img = frame.to_image().resize((320, 180))
                    # (H, W, C) -> (C, H, W): channels-first layout.
                    stills.append((numpy.array(img).transpose((2, 0, 1)), class_code))
        return stills

    def _write_cache(self, ds_cache: str) -> None:
        '''Pickle ``self.examples`` to ``ds_cache`` via a temporary file and an atomic rename.'''
        Path(os.path.dirname(ds_cache)).mkdir(parents=True, exist_ok=True)
        tmp = ds_cache + '.tmp'
        with open(tmp, 'wb') as cache:
            pickle.dump(self.examples, cache)
        # Only expose the cache once fully written; a truncated pickle would
        # otherwise break every later run.
        os.replace(tmp, ds_cache)

    def __len__(self) -> int:
        '''Number of (still, label) examples.'''
        return len(self.examples)

    def _name_to_emotion(self, name: str) -> int:
        '''
        Return the 0-based emotion index encoded in a RAVDESS file name.

        The third dash-separated field (``03`` in ``01-01-03-01-01-01-12.mp4``)
        holds the 1-based emotion code.

        Raises:
            ValueError: if the name has fewer than three fields, the third field
                is not an integer, or the code is outside 1..len(emotions).
        '''
        fields = name.split('-')
        if len(fields) < 3:
            raise ValueError(f"Invalid file name, cannot extract emotion class from {name}")
        try:
            code = int(fields[2]) - 1
        except ValueError as err:
            raise ValueError(f"Invalid file name, cannot extract emotion class from {name}") from err
        if not 0 <= code < len(self.emotions):
            raise ValueError(
                f"Wrong parsing of RAVDESS file name: {name}. "
                f"The 3rd block should be a number between 1 and {len(self.emotions)}"
            )
        return code

    def __getitem__(self, idx: int) -> dict:
        '''Return example ``idx`` as ``{"still": ..., "label": ...}``.'''
        still, label = self.examples[idx]
        return {
            "still": still,
            "label": label,
        }

    def class_count(self):
        '''Number of emotion classes.'''
        return len(self.emotions)

    def label_for(self, code):
        '''Emotion name for a 0-based class index; raises KeyError for unknown codes.'''
        return self.emotions[code]
