import os
import logging
import traceback
import copy
import random
import time
from scipy import signal

import numpy as np
import torch.utils.data
import torchvision
from .video_transforms import VisualizeCrop

from PIL import Image
from decord import VideoReader, AudioReader

# Module-scoped logger; its name follows the dotted module path.
logger = logging.getLogger(name=__name__)


class SamplingDataset(torch.utils.data.Dataset):
    """
    Map-style dataset that decodes one video (and optionally its audio) per item.

    For every item, up to ``num_clips`` event times are drawn from the video.
    Around each event, ``num_frames`` frames spanning ``video_duration`` seconds
    and ``audio_duration`` seconds of audio are cut out and passed through
    ``transform``.

    * ``sample_type='random'`` draws events on a 0.1 s grid inside
      ``[overlap, video_len - overlap)`` with ``overlap = audio_duration / 2``.
      This keeps the audio window within the video, so events are never taken
      right at its start or end.
    * ``sample_type='middle'`` always uses the middle of the video (debugging).

    With ``no_aug_sample=True`` the transform must return an
    ``(augmented, non_augmented)`` pair per modality. The second element is
    stored under the ``no_aug_*`` keys, e.g. to keep un-augmented clips in a
    buffer.
    """
    # Upper bound on back-to-back load failures before giving up (see _get_video_audio).
    _MAX_CONSECUTIVE_FAILURES = 10
    _SAMPLE_TYPES = ('random', 'middle')

    def __init__(
            self,
            video_info,
            transform,
            num_clips: int = 1,
            num_frames: int = 4,
            video_duration: float = 1.0,
            audio_duration: float = 2.0,
            decode_audio: bool = True,
            sample_type: str = 'random',
            no_aug_sample: bool = False,
            audio_sample_rate: int = 16000,
    ) -> None:
        """
        Args:
            video_info: Sequence of ``(video_path, info_dict)`` pairs. ``info_dict``
                entries are copied into every returned sample. The audio track is
                read from ``<video_path without extension>.mp3``.
            transform: Callable applied to ``{'video_data': [PIL.Image, ...],
                'audio_data': float32 waveform}`` for each clip.
            num_clips: Maximum number of clips sampled per video.
            num_frames: Number of frames decoded per video clip.
            video_duration: Length in seconds of the video window around an event.
            audio_duration: Length in seconds of the audio window around an event.
            decode_audio: Whether to load the companion ``.mp3`` file.
            sample_type: ``'random'`` or ``'middle'``.
            no_aug_sample: Whether ``transform`` returns (augmented, non-augmented) pairs.
            audio_sample_rate: Rate (Hz) at which audio is decoded and indexed.

        Raises:
            ValueError: If ``sample_type`` is not supported.
        """
        super().__init__()
        if sample_type not in self._SAMPLE_TYPES:
            raise ValueError(f"sample_type must be one of {self._SAMPLE_TYPES}, got {sample_type!r}")

        self._video_info = video_info
        self._decode_audio = decode_audio
        self._transform = transform
        self._num_clips = num_clips
        self._num_frames = num_frames
        # Sample durations (seconds)
        self._video_duration = video_duration
        self._audio_duration = audio_duration
        self._audio_sample_rate = audio_sample_rate
        # How to pick event times from each video
        self.sample_type = sample_type
        self.no_aug_sample = no_aug_sample
        # Margin kept free at both ends of the video so the audio window fits.
        self.overlap = self._audio_duration / 2.0

    @property
    def num_videos(self):
        """
        Returns:
            Number of videos in dataset.
        """
        return len(self._video_info)

    def time_to_indices(self, video_reader, time):
        """
        Map timestamps (seconds) to the indices of the nearest frames.

        Args:
            video_reader: Object with ``__len__`` and ``get_frame_timestamp``
                (a decord ``VideoReader``).
            time: A scalar or a sequence of timestamps in seconds.

        Returns:
            np.ndarray of frame indices with the same shape as ``time``. Times past
            the last frame map to the last frame.
        """
        num_frames = len(video_reader)
        # Average the per-frame timestamp pair along the last axis.
        frame_times = video_reader.get_frame_timestamp(range(num_frames)).mean(-1)
        query = np.asarray(time)
        # First frame at or after each query, clamped to the last frame.
        after = np.minimum(np.searchsorted(frame_times, query), num_frames - 1)
        before = after - 1
        # Keep `after` unless the previous frame is strictly closer. Frame 0 has no
        # predecessor: `before == -1` wraps around but is never selected.
        use_after = np.bitwise_or(after == 0, frame_times[after] - query <= query - frame_times[before])
        return np.where(use_after, after, before)

    def load_video_clip(self, video, video_start, video_end):
        """
        Decode ``self._num_frames`` evenly spaced frames between two timestamps.

        Frames are taken from the start frame up to, but excluding, the end frame.

        Returns:
            List of RGB ``PIL.Image`` frames.
        """
        start, end = self.time_to_indices(video, [video_start, video_end])
        end = min(len(video) - 1, int(end))
        # Guarantee at least a one-frame span without going below frame 0.
        start = max(0, min(int(start), end - 1))
        # np.int was removed in NumPy 1.24; use an explicit integer dtype.
        frame_indices = np.linspace(start, end, self._num_frames, endpoint=False).astype(np.int64)
        frames = video.get_batch(frame_indices).asnumpy()
        return [Image.fromarray(frame).convert('RGB') for frame in frames]

    def load_audio_clip(self, audio, audio_start, audio_end):
        """
        Cut ``[audio_start, audio_end)`` seconds from the channel-averaged waveform.

        Both bounds are clamped at 0, so a window starting before the file begins
        yields a shorter clip rather than a wrapped-around slice.
        """
        rate = self._audio_sample_rate
        waveform = audio._array.mean(0)
        start = max(0, int(audio_start * rate))
        end = max(0, int(audio_end * rate))
        return waveform[start:end]

    def __getitem__(self, video_index):
        """Return the sample dict for ``video_index``, or a random replacement if it fails to load."""
        return self._get_video_audio(video_index, self._decode_audio, self.sample_type)

    def __len__(self):
        return len(self._video_info)

    def _get_video_audio(self, index, decode_audio, sample_type):
        """
        Load sample ``index``. If it fails, retry with randomly chosen other videos.

        Raises:
            ValueError: If ``sample_type`` is not supported. This is raised before
                any decoding and is never retried.
            RuntimeError: After ``_MAX_CONSECUTIVE_FAILURES`` failed attempts in a
                row. The last underlying error is chained as the cause.
        """
        if sample_type not in self._SAMPLE_TYPES:
            raise ValueError(f'Undefined sampling type {sample_type}')

        last_error = None
        for _ in range(self._MAX_CONSECUTIVE_FAILURES):
            try:
                return self._load_sample(index, decode_audio, sample_type)
            except Exception as err:  # unreadable / silent / too-short files: try another video
                last_error = err
                logger.debug("Failed to load sample %s (%s); retrying with a random video",
                             index, err, exc_info=True)
                index = random.randrange(len(self))
        raise RuntimeError(
            f"Failed to load a sample after {self._MAX_CONSECUTIVE_FAILURES} consecutive attempts"
        ) from last_error

    def _load_sample(self, index, decode_audio, sample_type):
        """Decode one video (and its audio) and build its sample dict; raises on any failure."""
        video_path, info_dict = self._video_info[index]
        sample_dict = {**info_dict, "video_data": []}
        if self.no_aug_sample:
            sample_dict["no_aug_video_data"] = []

        video = audio = None
        try:
            if decode_audio:  # Some clips contain no audio, so check it before decoding video.
                sample_dict["audio_data"] = []
                if self.no_aug_sample:
                    sample_dict["no_aug_audio_data"] = []
                audio_path = os.path.splitext(video_path)[0] + '.mp3'
                # Decode at the same rate load_audio_clip uses to convert seconds to samples.
                audio = AudioReader(audio_path, sample_rate=self._audio_sample_rate)
                if audio._array.mean() == 0:
                    raise ValueError(f'no audio in mp3 file {audio_path}')

            video = VideoReader(video_path)
            video_len = len(video) / video.get_avg_fps()

            if sample_type == 'random':
                events = np.arange(self.overlap, video_len - self.overlap, step=0.1)
            else:  # 'middle': deterministic, for debugging
                events = np.array([video_len / 2])
            if len(events) == 0:
                raise ValueError(f'video {video_path} is too short to sample an event')

            num_clips = min(self._num_clips, len(events))
            sampled_events = np.sort(np.random.choice(events, num_clips, replace=False))

            half_video = self._video_duration / 2.
            half_audio = self._audio_duration / 2.
            for event_time in sampled_events:
                clips = {'video_data': self.load_video_clip(video, event_time - half_video,
                                                            event_time + half_video)}
                if decode_audio:
                    audio_clip = self.load_audio_clip(audio, event_time - half_audio,
                                                      event_time + half_audio)
                    clips['audio_data'] = audio_clip.astype(np.float32)

                clips = self._transform(clips)

                if self.no_aug_sample:
                    # Index [0]/[1] assumes the transform returns (augmented, non-augmented) pairs.
                    sample_dict["video_data"].append(clips['video_data'][0])
                    sample_dict["no_aug_video_data"].append(clips['video_data'][1])
                    if decode_audio:
                        sample_dict["audio_data"].append(clips['audio_data'][0])
                        sample_dict["no_aug_audio_data"].append(clips['audio_data'][1])
                else:
                    sample_dict["video_data"].append(clips['video_data'])
                    if decode_audio:
                        sample_dict["audio_data"].append(clips['audio_data'])

            return sample_dict
        finally:
            # Drop decoder handles eagerly; a retained exception traceback would
            # otherwise keep this frame's readers alive.
            del video, audio

    @staticmethod
    def _pad_and_stack(items, min_shape, what, layout):
        """
        Stack same-rank tensors into a ``(len(items), *shape)`` tensor padded with -1.

        Each padded dim is the maximum over the non-None items, but at least the
        matching entry of ``min_shape``. ``None`` items become all -1 rows. If every
        item is None, ``items`` is returned unchanged because there is no shape to
        pad to.

        Raises:
            ValueError: If an item's rank differs from ``len(min_shape)``.
        """
        shapes = [item.shape for item in items if item is not None]
        if not shapes:
            return items
        for shape in shapes:
            if len(shape) != len(min_shape):
                raise ValueError(
                    f"Collate error, {what} should be in shape of {layout}, instead of given {shape}")
        padded = [max(floor, *(shape[dim] for shape in shapes)) for dim, floor in enumerate(min_shape)]
        out = torch.full((len(items), *padded), -1.0)
        for row, item in enumerate(items):
            if item is not None:
                out[(row, *(slice(0, size) for size in item.shape))] = item
        return out

    def collate(self, batch):
        """
        Collate samples from ``__getitem__`` into one padded batch.

        Every clip of every sample becomes its own batch entry, so the batch size
        is the total number of clips. Values whose key contains neither "video"
        nor "audio" are repeated for each clip of their sample. The video/audio
        lists inside ``batch`` are consumed (popped), last clip first.

        Video keys become ``(B, C, T, H, W)`` tensors and audio keys
        ``(B, C, H, W)`` tensors, padded with -1. Entries missing from a sample are
        all -1. Other keys containing "label" become ``(B,)`` long tensors, with -1
        for missing labels. If no sample has any clip, the per-key lists are
        returned empty.
        """
        keys = {key for sample in batch for key in sample}
        video_keys = {k for k in keys if "video" in k}
        audio_keys = {k for k in keys if "audio" in k}
        other_keys = keys - video_keys - audio_keys

        def take(sample, key):
            # Shared metadata is repeated; per-clip lists yield one clip (or None if absent/short).
            if key in other_keys:
                return sample.get(key)
            clips = sample.get(key)
            return clips.pop() if clips else None

        entries = []
        for sample in batch:
            while sample.get('video_data'):
                entries.append({key: take(sample, key) for key in keys})

        dict_batch = {key: [entry[key] for entry in entries] for key in keys}
        # If no event could be extracted, every video_data list was empty; nothing to pad.
        if not entries:
            return dict_batch

        for key in video_keys:
            dict_batch[key] = self._pad_and_stack(
                dict_batch[key], (3, self._num_frames, 0, 0), "a video", "(3, T, H, W)")
        for key in audio_keys:
            dict_batch[key] = self._pad_and_stack(
                dict_batch[key], (1, 0, 0), "an audio", "(1, H, W)")

        for key in other_keys:
            if "label" not in key:
                continue
            labels = torch.full((len(entries),), -1, dtype=torch.long)
            for row, value in enumerate(dict_batch[key]):
                if value is not None:
                    labels[row] = value
            dict_batch[key] = labels

        return dict_batch