import logging
import os
from pathlib import Path
import pickle
import shutil
import sys
import tempfile
from typing import Optional, Sequence
from zipfile import ZipFile, is_zipfile

import numpy as np
from torch.utils.data import Dataset

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from loaders.ravdess_stills import RAVDESSStillsDataset
from utils.video import VideoClips


class RAVDESSVideoDataset(RAVDESSStillsDataset):
    '''RAVDESS dataset of short audio-visual clips labelled with an emotion.

    The audio-video files found in the RAVDESS ZIP archives under ``root_dir``
    are extracted into ``work_dir`` and split into fixed-length clips with
    ``VideoClips``. The prepared dataset is pickled into ``work_dir`` so later
    runs can skip the extraction and the clip computation.
    '''

    # Emotion from sound requires at least 200ms (TODO retrieve ref)
    #   After trials, there are many samples that get clipped to a little
    #   less than 200ms when the value is used to slice the streams.
    #   So I arbitrarily use 300ms to be above 200ms and guarantee
    #   we get enough data for emotion recognition on audio.
    #   (Also because it matches well RAVDESS FPS)
    MIN_CLIP_DURATION = 0.3  # seconds.

    RAVDESS_FPS = 30

    # File name, inside work_dir, of the pickled prepared dataset.
    CACHE_FILENAME = 'ravdess_dataset_precompute.bin'

    def __init__(
        self,
        root_dir: str,
        work_dir: str = tempfile.gettempdir(),
        vsize: Sequence[int] = (86, 48),
    ):
        '''
        RAVDESS file form
        width = 1280
        height = 720
        channels = 3

        This dataset relies on a sample utility (and apparently temporary) from
        [Torchvision](https://github.com/pytorch/vision/blob/master/torchvision/datasets/video_utils.py).

        Args:
            root_dir: Directory with a tree of RAVDESS archives (ZIP archives). (expands ~ if any)
            work_dir: Directory where videos are extracted and the prepared
                dataset is cached. It expands ~ and is created if missing.
            vsize: Two-item target video size. It is passed as a tuple to
                ``VideoClips.get_clip`` (presumably (width, height); TODO confirm).
        '''
        # A tuple default, copied to a tuple, avoids the shared-mutable-default pitfall.
        self.target_vsize = tuple(vsize)

        root_dir = os.path.abspath(os.path.expanduser(root_dir))
        work_dir = os.path.expanduser(work_dir)
        # Without this, writing the cache fails when no archive created work_dir.
        Path(work_dir).mkdir(parents=True, exist_ok=True)

        ds_cache = os.path.join(work_dir, self.CACHE_FILENAME)
        if not self._restore_cache(ds_cache, root_dir):
            logging.info("Prepare and cache dataset...")
            videos, self.classes = self._unpack_av_videos(root_dir, work_dir)
            logging.info("Dataset unpacked in %s", work_dir)
            # Use round() rather than int(): a float product could land a hair
            # below the intended integer and lose a frame when truncated.
            frames_per_clip = round(self.MIN_CLIP_DURATION * self.RAVDESS_FPS)
            self.dataset = VideoClips(videos, clip_length_in_frames=frames_per_clip)
            # Never cache an empty result. Otherwise a wrong root_dir on a first
            # run would keep restoring an empty dataset forever.
            if len(self) > 0:
                self._store_cache(ds_cache, root_dir)

        if len(self) > 0:
            logging.info("Dataset ready for use with Torch, cached in %s", ds_cache)
        else:
            logging.error("Dataset is empty. Will not return any data")

    def _restore_cache(self, ds_cache: str, root_dir: str) -> bool:
        '''Load ``dataset`` and ``classes`` from ``ds_cache``.

        Returns:
            True when the cache was restored. False when there is no cache, or
            when it was built from a different ``root_dir``.
        '''
        if not os.path.exists(ds_cache):
            return False
        logging.info("Restore dataset from cache")
        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
        # file. The default work_dir is the shared system temp directory, so only
        # point work_dir at a directory other users cannot write to.
        with open(ds_cache, 'rb') as cache:
            data = pickle.load(cache)
        cached_root = data.get('root_dir')
        # Caches written before root_dir was recorded lack the key. They are
        # accepted as before.
        if cached_root is not None and cached_root != root_dir:
            logging.warning(
                "Cache %s was built from %s, not %s; rebuilding",
                ds_cache, cached_root, root_dir,
            )
            return False
        self.dataset = data['dataset']
        self.classes = data['classes']
        return True

    def _store_cache(self, ds_cache: str, root_dir: str) -> None:
        '''Pickle ``dataset``, ``classes`` and ``root_dir`` into ``ds_cache``.'''
        partial = ds_cache + '.part'
        with open(partial, 'wb') as cache:
            pickle.dump(
                {
                    'dataset': self.dataset,
                    'classes': self.classes,
                    'root_dir': root_dir,
                },
                cache,
            )
        # An atomic rename means an interrupted dump cannot leave a truncated
        # cache that pickle.load would then fail on at every start.
        os.replace(partial, ds_cache)

    def _unpack_av_videos(self, root_dir: str, work_dir: str):
        '''Extract the audio-video members of every ZIP archive under ``root_dir``.

        Each archive ``<name>.zip`` is unpacked below ``work_dir/<name>``.
        Members already present on disk are not extracted again.

        Returns:
            A pair ``(videos, classes)``. ``videos`` holds the paths of the
            extracted files. ``classes`` is index-aligned with it and holds the
            label ``_name_to_emotion`` derives from each file name.
        '''
        videos = []
        classes = []
        for parent, _, filenames in os.walk(root_dir):
            for archive_name in filenames:
                path = os.path.join(parent, archive_name)
                if not is_zipfile(path):
                    continue
                base, _ = os.path.splitext(archive_name)
                dest_base = os.path.join(work_dir, base)
                with ZipFile(path) as archive:
                    for video in archive.namelist():
                        filename = os.path.basename(video)
                        # This also skips directory entries, whose basename is empty.
                        if not filename.startswith('01'):
                            logging.debug('Skipping non-AV file (%s)', video)
                            continue
                        dest = os.path.join(dest_base, video)
                        # Guard against "zip slip": member names containing ".."
                        # or absolute paths must not escape dest_base.
                        if not self._is_inside(dest_base, dest):
                            logging.warning(
                                "Skipping archive member outside of %s: %s",
                                dest_base, video,
                            )
                            continue
                        Path(os.path.dirname(dest)).mkdir(parents=True, exist_ok=True)
                        if os.path.exists(dest):
                            logging.debug("Skipping already existing file: %s", dest)
                        else:
                            self._copy_member(archive, video, dest)
                        videos.append(dest)
                        classes.append(self._name_to_emotion(filename))
        return videos, classes

    @staticmethod
    def _copy_member(archive: ZipFile, member: str, dest: str) -> None:
        '''Stream ``member`` of ``archive`` to ``dest``.

        The data is written to ``dest + '.part'`` and then renamed into place.
        An interrupted run therefore leaves only the ``.part`` file, which is
        overwritten next time. It never leaves a truncated ``dest`` that would
        later be skipped as already extracted.
        '''
        partial = dest + '.part'
        # copyfileobj copies in chunks instead of loading the whole video in memory.
        with archive.open(member, 'r') as src, open(partial, 'wb') as target:
            shutil.copyfileobj(src, target)
        os.replace(partial, dest)

    @staticmethod
    def _is_inside(base: str, path: str) -> bool:
        '''Return True when ``path`` resolves to ``base`` or somewhere below it.'''
        real_base = os.path.realpath(base)
        try:
            return os.path.commonpath([real_base, os.path.realpath(path)]) == real_base
        except ValueError:
            # Raised, for example, for paths on different Windows drives.
            return False

    def __len__(self) -> int:
        '''Number of clips exposed by the underlying ``VideoClips``.'''
        return self.dataset.num_clips()

    def __getitem__(self, idx: int) -> Optional[dict]:
        '''Return clip ``idx`` as a dict with "video", "audio" and "label" keys.

        Returns None when the clip has no video or no audio stream.
        NOTE(review): PyTorch's default collate cannot batch None items, so a
        DataLoader over this dataset needs a collate_fn that filters them.
        '''
        # With torchvision's VideoClips the 4th value is the index of the source
        # video. self.classes is index-aligned with those videos. The `vsize`
        # keyword is specific to the project's VideoClips (TODO confirm).
        video, audio, _, video_idx = self.dataset.get_clip(idx, vsize=self.target_vsize)
        if video is None or audio is None:
            return None
        return {
            "video": video,
            "audio": audio,
            "label": self.classes[video_idx],
        }


if __name__ == '__main__':
    # Debug entry point: dump the frames of a few clips as JPEG files, so the
    # extraction, resizing and labelling can be checked by eye.
    from PIL import Image

    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} RAVDESS_ROOT_DIR')
    ds = RAVDESSVideoDataset(sys.argv[1])
    debug_dir = './debug'
    Path(debug_dir).mkdir(parents=True, exist_ok=True)
    # Inspect only the last (up to) 9 clips of the first half of the dataset.
    stop = len(ds) // 2
    for item in range(max(0, stop - 9), stop):
        data = ds[item]
        if data is None:
            # __getitem__ returns None for clips missing video or audio.
            print(f'#{item}: missing video or audio, skipped')
            continue
        print(f'#{item}: V {data["video"].shape}; A {data["audio"].shape}')
        # Dimension 1 of the video tensor is iterated as frames. Each selected
        # frame is assumed to be channels-first uint8 (C, H, W), which PIL needs
        # as (H, W, C). TODO confirm against VideoClips.get_clip.
        for frame_idx in range(data["video"].shape[1]):
            frame = data["video"].select(1, frame_idx)
            out_path = os.path.join(debug_dir, f'{item}-{frame_idx}-{data["label"]}.jpg')
            Image.fromarray(frame.numpy().transpose((1, 2, 0))).save(out_path)
