import numpy as np
import os
import pickle
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
import torchvision
from torchvision.datasets.utils import list_dir
from torchvision.datasets.vision import VisionDataset

try:
    from VideoClips import VideoClips
except:
    from .VideoClips import VideoClips


def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    The comparison is done on the lowercased filename, so the extensions
    themselves are expected to be lowercase.

    Args:
        filename (string): path to a file
        extensions (string or iterable of strings): extension(s) to consider
            (lowercase). A single string, a tuple, or any other iterable of
            strings (e.g. a list) is accepted.
    Returns:
        bool: True if the filename ends with one of given extensions
    """
    # str.endswith only accepts a str or a tuple of str; normalise any other
    # iterable (list, set, ...) to a tuple but leave a plain string untouched
    # so it is not split into single characters.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)
    return filename.lower().endswith(suffixes)


def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None, valid_names=None):
    """Collects ``(path, class_index)`` pairs for the files found under each class folder.

    For every class name in ``class_to_idx`` (visited in sorted order) the
    folder ``directory/<class name>`` is walked recursively, following
    symlinks. Class names without a matching folder are skipped silently.

    Args:
        directory (string): root folder holding one sub-folder per class;
            a leading ``~`` is expanded.
        class_to_idx (dict): maps class (folder) name to its integer index.
        extensions (tuple of strings, optional): accepted file extensions.
            Exactly one of ``extensions`` and ``is_valid_file`` must be given.
        is_valid_file (callable, optional): takes a file path and returns
            True when the file should be kept.
        valid_names (container of strings, optional): when given, only files
            whose full path is contained in it are kept.
    Returns:
        list: ``(path, class_index)`` tuples, ordered by class name, then by
        walk order, then by file name.
    Raises:
        ValueError: if ``extensions`` and ``is_valid_file`` are both None or
            both provided.
    """
    instances = []
    directory = os.path.expanduser(directory)
    # Exactly one of the two selection mechanisms must be supplied.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x):
            return has_file_allowed_extension(x, extensions)

    for target_class in sorted(class_to_idx):
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        class_index = class_to_idx[target_class]
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if not is_valid_file(path):
                    continue
                # Previously both branches of the valid_names check appended
                # the item, so the filter never excluded anything.
                if valid_names is not None and path not in valid_names:
                    continue
                instances.append((path, class_index))
    return instances


class Kinetics400(VisionDataset):
    """
    `Kinetics-400 <https://deepmind.com/research/open-source/open-source-datasets/kinetics/>`_
    dataset.
    Kinetics-400 is an action recognition video dataset.
    This dataset consider every video as a collection of video clips of fixed size, specified
    by ``frames_per_clip``, where the step in frames between each clip is given by
    ``step_between_clips``.
    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
    elements will come from video 1, and the next three elements from video 2.
    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
    frames in a video might be present.
    Internally, it uses a VideoClips object to handle clip creation.
    Args:
        root (string): Directory containing the ``train_avi-480p`` and ``val_avi-480p``
            split folders; the split folder becomes the dataset root.
        train (bool): use ``train_avi-480p`` when True, ``val_avi-480p`` otherwise.
        frames_per_clip (int): number of frames in a clip
        step_between_clips (int): number of frames between each clip
        frame_rate (optional): forwarded unchanged to VideoClips.
        extensions (tuple of strings): file extensions accepted when scanning class folders.
        transform (callable, optional): A function/transform that  takes in a TxHxWxC video
            and returns a transformed version.
        sound_only_classes (bool): when True, restrict the classes to a fixed built-in list
            and skip loading the video-length pickle; when False, every sub-folder of the
            split folder is a class.
        _precomputed_metadata, num_workers, _video_width, _video_height,
        _video_min_dimension, _audio_samples, _audio_channels: forwarded unchanged
            to VideoClips.
        get_audio (bool): forwarded to ``VideoClips.get_clip``; also selects the layout of
            the returned tuple (see below).
    Returns:
        With ``get_audio=True``: ``(video, audio, label, idx, video_idx)``.
        With ``get_audio=False``: ``(video, label, idx, video_idx)``.
        ``video`` and ``audio`` come from ``VideoClips.get_clip`` (``video`` after
        ``transform``), ``label`` is the class index of the source video, ``idx`` is the
        requested clip index and ``video_idx`` the index of the source video in ``samples``.
    """

    def __init__(
        self, 
        root='/datasets01_101/kinetics/070618/', 
        train=True,
        frames_per_clip=30, 
        step_between_clips=1, 
        frame_rate=None,
        extensions=('avi',), 
        transform=None,
        sound_only_classes=False, 
        _precomputed_metadata=None,
        num_workers=1,
        get_audio=False,
        _video_width=0, 
        _video_height=0,
        _video_min_dimension=0, 
        _audio_samples=0, 
        _audio_channels=0
    ):
        split = 'train' if train else 'val'
        root = os.path.join(root, f'{split}_avi-480p')
        super(Kinetics400, self).__init__(root)

        # Contents of the video-length pickle; stays None when it is not loaded
        # (sound-only mode or file missing).
        self.vid2len = None
        # Paths handed to make_dataset as valid_names; None means "no name filter".
        valid_names = None
        if not sound_only_classes:
            # Path is relative to the current working directory.
            vid2len_file = f'datasets/data/kinetics_vid_len_{split}.pkl'
            if os.path.exists(vid2len_file):
                # NOTE(review): pickle.load can execute arbitrary code; only
                # use this with a trusted file.
                with open(vid2len_file, 'rb') as handle:
                    self.vid2len = pickle.load(handle)
                # make_dataset expands '~' in the directory it scans, so expand
                # it here too to keep the two sets of paths comparable.
                expanded_root = os.path.expanduser(root)
                valid_names = {os.path.join(expanded_root, key) for key in self.vid2len}
            else:
                # Previously this case crashed with AttributeError because
                # self.vid2len was never assigned.
                print(f"Video length file not found: {vid2len_file}; not filtering by video name")

        if sound_only_classes:
            # NOTE(review): "tapping guitar" and "tapping pen" use spaces while
            # every other entry uses underscores - confirm they match the
            # on-disk folder names, since make_dataset silently skips classes
            # that have no directory.
            classes = ["blowing_nose", "blowing_out_candles", "bowling", "chopping_wood", 
                "dribbling_basketball",  "laughing", "mowing_lawn", "playing_accordion", 
                "playing_bagpipes", "playing_bass_guitar", "playing_clarinet", "playing_drums", 
                "playing_guitar", "playing_harmonica", "playing_keyboard", "playing_organ", 
                "playing_piano", "playing_saxophone", "playing_trombone", "playing_trumpet", 
                "playing_violin", "playing_xylophone", "ripping_paper", "shoveling_snow", 
                "shuffling_cards", "singing", "stomping_grapes", "strumming_guitar", 
                "tap_dancing", "tapping guitar", "tapping pen", "tickling"]
            print(f"Length of sound classes: {len(classes)}")
        else:
            classes = list(sorted(list_dir(root)))
        class_to_idx = {name: index for index, name in enumerate(classes)}

        # valid_names is None in sound-only mode, so one call covers both modes.
        self.samples = make_dataset(
            self.root,
            class_to_idx,
            extensions,
            is_valid_file=None,
            valid_names=valid_names,
        )
        self.classes = classes
        self.get_audio = get_audio

        # No further filtering is done in this constructor; the count below is
        # simply the number of files returned by make_dataset.
        print(f"Number of files before filter: {len(self.samples)}")
        video_list = [x[0] for x in self.samples]
        
        self.video_clips = VideoClips(
            video_list,
            frames_per_clip,
            step_between_clips,
            frame_rate,
            _precomputed_metadata,
            num_workers=num_workers,
            _video_width=_video_width,
            _video_height=_video_height,
            _video_min_dimension=_video_min_dimension,
            _audio_samples=_audio_samples,
            _audio_channels=_audio_channels,
        )
        self.transform = transform

    @property
    def metadata(self):
        """Metadata exposed by the underlying VideoClips object."""
        return self.video_clips.metadata

    def __len__(self):
        """Number of clips (not videos) reported by the VideoClips object."""
        return self.video_clips.num_clips()

    def __getitem__(self, idx):
        """Fetches clip ``idx`` and the class index of the video it belongs to.

        Returns:
            tuple: ``(video, audio, label, idx, video_idx)`` when ``get_audio``
            is set, otherwise ``(video, label, idx, video_idx)``.
        """
        video, audio, info, video_idx = self.video_clips.get_clip(idx, get_audio=self.get_audio)

        # samples is index-aligned with the video list given to VideoClips.
        label = self.samples[video_idx][1]

        if self.transform is not None:
            video = self.transform(video)
        
        if self.get_audio:
            return video, audio, label, idx, video_idx
        else:
            return video, label, idx, video_idx


if __name__ == '__main__':
    # Manual smoke test: build the sound-only training split, pull a single
    # batch through a DataLoader and print the tensor sizes and elapsed time.

    def collate_fn(batch):
        # Keeps items 0, 2, 3 and 4 of every sample (dropping item 1) before
        # default collation. Currently unused: the DataLoader below is built
        # with collate_fn=None.
        batch = [(d[0], d[2], d[3], d[4]) for d in batch]
        return default_collate(batch)

    import torch

    def to_normalized_float_tensor(vid):
        # Move the last axis to the front and rescale values by 1/255.
        return vid.permute(3, 0, 1, 2).to(torch.float32) / 255
    
    def resize(vid, size, interpolation='bilinear'):
        # NOTE: using bilinear interpolation because we don't work on minibatches
        # at this level
        scale = None
        if isinstance(size, int):
            # An int size means "scale so the shorter of the last two dims equals size".
            scale = float(size) / min(vid.shape[-2:])
            size = None
        return torch.nn.functional.interpolate(
            vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False)

    class Resize(object):
        """Callable wrapper around ``resize`` with a fixed target size."""

        def __init__(self, size):
            self.size = size

        def __call__(self, vid):
            return resize(vid, self.size)
    

    class ToFloatTensorInZeroOne(object):
        """Callable wrapper around ``to_normalized_float_tensor``."""

        def __call__(self, vid):
            return to_normalized_float_tensor(vid)

    transform_train = torchvision.transforms.Compose([
        ToFloatTensorInZeroOne(),
        Resize((112, 112))
    ])

    import time
    # NOTE(review): named val_* but built with train=True - confirm intent.
    val_dataset = Kinetics400(
        train=True,
        sound_only_classes=True,
        get_audio=True,
        transform=transform_train,
    )
    val_loader = DataLoader(
        val_dataset, 
        batch_size=32, 
        num_workers=0, 
        collate_fn=None,
    )

    tic = time.time()
    for batch_idx, batch in enumerate(val_loader):
        if batch is not None:
            # Distinct names so the batch counter is not overwritten by the
            # per-clip index tensor coming out of the dataset.
            video, audio, label, clip_idx, video_idx = batch
            print(batch_idx, video.size(), audio.size(), label.size(), clip_idx.size(), video_idx.size(), time.time() - tic)
            print(f'Batch time (s): {time.time() - tic}')
            break