﻿# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch

try:
    import pyspng
except ImportError:
    pyspng = None

# from lrw_dataset import LRWDataset
from torch.utils.data import DataLoader

import os
import random

import torch
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class StatefulRandomHorizontalFlip:
    """Horizontal flip whose random decision is drawn once, at construction.

    Every call on the same instance makes the same flip / no-flip choice, so
    using one instance for all frames of a clip flips either every frame or
    none of them.

    Args:
        probability: Chance (in [0, 1]) that this instance flips its inputs.
    """

    def __init__(self, probability=0.5):
        self.probability = probability
        # Sampled a single time; reused by every __call__ on this instance.
        self.rand = random.random()

    def __call__(self, img):
        should_flip = self.rand < self.probability
        return F.hflip(img) if should_flip else img

    def __repr__(self):
        return '{}(probability={})'.format(type(self).__name__, self.probability)


def build_word_list(directory, num_words, seed):
    """Pick a reproducible subset of the entries of ``directory``.

    The entries are sorted, shuffled with a generator seeded by ``seed`` and
    truncated to the first ``num_words``.

    Args:
        directory: Directory whose entries (one per word) are listed.
        num_words: Number of entries to keep (``None`` keeps all of them).
        seed: Seed controlling which entries are chosen.

    Returns:
        List of entry names.
    """
    words = sorted(os.listdir(directory))
    # Use a private generator: reseeding the module-level ``random`` state here
    # would silently reset the stream other code (e.g. the random flip
    # augmentation) draws from. ``random.Random(seed)`` produces exactly the same
    # shuffle that ``random.seed(seed); random.shuffle(words)`` did.
    random.Random(seed).shuffle(words)
    return words[:num_words]


class LRWDataset(Dataset):
    """Lip Reading in the Wild (LRW) clips as a map-style torch Dataset.

    Layout read by ``build_file_list``::

        <path>/<WORD>/<mode>/<clip>.mp4

    Unless ``estimate_pose`` is set, ``<pose_path>/<mode>.txt`` must also exist,
    with one ``<clip file name>,<yaw>`` entry per line.

    Args:
        path: Root directory containing one sub-directory per word.
        pose_path: Directory holding the per-split yaw files.
        num_words: Number of word classes to keep (chosen from ``seed`` by
            ``build_word_list``).
        in_channels: 1 (grayscale) or 3 (RGB) output channels.
        mode: Split sub-directory / yaw file name, e.g. ``"train"``.
        augmentations: Enable a per-clip random horizontal flip (only honoured
            when ``mode == 'train'``).
        estimate_pose: If True, no yaw file is read; each sample carries frame
            14 as ``angle_frame`` and a yaw of 0.
        seed: Seed for the word subset.
        query: Optional ``(low, high)``; only clips with ``low <= yaw < high``
            are kept (ignored when ``estimate_pose`` is True).
        crop_size: Center-crop size applied to every frame.
        reshape_size: ``(H, W)`` frames are resized to after cropping.
    """

    # Number of frames read from every clip (indices 0 .. NUM_FRAMES - 1).
    NUM_FRAMES = 29

    def __init__(self,
                 path,
                 pose_path,
                 num_words=10,
                 in_channels=1,
                 mode="train",
                 augmentations=False,
                 estimate_pose=False,
                 seed=42,
                 query=None,
                 crop_size=(112, 112),
                 reshape_size=(112, 112)):
        self.pose_path = pose_path
        self.seed = seed
        self.num_words = num_words
        self.in_channels = in_channels
        self.query = query
        self.estimate_pose = estimate_pose
        self.crop_size = crop_size
        self.reshape_size = reshape_size
        # Augmentation is only ever applied to the training split.
        self.augmentation = augmentations if mode == 'train' else False
        # When a yaw file is loaded it also acts as a filter: clips missing from
        # it (or outside ``query``) are skipped by build_file_list.
        self.poses = None if estimate_pose else self.head_poses(mode, query)
        self.video_paths, self.files, self.labels, self.words = self.build_file_list(path, mode)

    def head_poses(self, mode, query):
        """Read ``<pose_path>/<mode>.txt`` into a ``{file name: yaw}`` dict.

        Each non-blank line must be ``<file name>,<yaw>``. If ``query`` is a
        ``(low, high)`` pair, only entries with ``low <= yaw < high`` are kept.
        """
        poses = {}
        with open(f'{self.pose_path}/{mode}.txt', "r") as yaw_file:
            for line in yaw_file:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                file, yaw = line.split(",")
                yaw = float(yaw)
                if query is None or query[0] <= yaw < query[1]:
                    poses[file] = yaw
        return poses

    def build_file_list(self, directory, mode):
        """Collect the ``.mp4`` clips of the selected words for one split.

        Returns:
            ``(paths, file_names, labels, words)``: ``paths[i]`` is the full
            path of clip ``i``, ``file_names[i]`` its base name and
            ``labels[i]`` the index of its word in ``words``.
        """
        words = build_word_list(directory, self.num_words, seed=self.seed)
        print(words)
        paths = []
        file_list = []
        labels = []
        for label, word in enumerate(words):
            dirpath = os.path.join(directory, word, mode)
            # os.listdir order is filesystem-dependent; sort so that sample
            # indices are reproducible across machines.
            for file in sorted(os.listdir(dirpath)):
                if not file.endswith(".mp4"):
                    continue
                if self.poses is not None and file not in self.poses:
                    continue
                file_list.append(file)
                paths.append(os.path.join(dirpath, file))
                labels.append(label)

        return paths, file_list, labels, words

    def build_tensor(self, frames):
        """Turn a decoded clip into a normalized ``(C, NUM_FRAMES, H, W)`` tensor.

        Args:
            frames: Decoded clip indexable by frame, each frame laid out
                ``(H, W, C)``; must contain at least ``NUM_FRAMES`` frames.

        Raises:
            ValueError: if ``in_channels`` is neither 1 nor 3.
        """
        if self.in_channels not in (1, 3):
            raise ValueError(f'in_channels must be 1 or 3, got {self.in_channels}')

        # Built per call on purpose: StatefulRandomHorizontalFlip draws its
        # decision once at construction, so a fresh instance per clip flips all
        # of that clip's frames consistently.
        if self.augmentation:
            augmentations = transforms.Compose([
                StatefulRandomHorizontalFlip(0.5),
            ])
        else:
            augmentations = transforms.Compose([])

        if self.in_channels == 1:
            transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.CenterCrop(self.crop_size),
                transforms.Resize(self.reshape_size),
                augmentations,
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),
                transforms.Normalize([0.4161, ], [0.1688, ]),
            ])
        else:
            transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.CenterCrop(self.crop_size),
                transforms.Resize(self.reshape_size),
                augmentations,
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

        temporal_volume = torch.empty(
            self.NUM_FRAMES, self.in_channels, self.reshape_size[0], self.reshape_size[1],
            dtype=torch.float32)
        for i in range(self.NUM_FRAMES):
            frame = frames[i].permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
            temporal_volume[i] = transform(frame)

        return temporal_volume.transpose(1, 0)  # (T, C, H, W) -> (C, T, H, W)

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return a sample dict.

        Keys: ``frames`` (output of ``build_tensor``), ``label`` (LongTensor of
        shape (1,)), ``word``, ``file``, ``yaw`` (FloatTensor of shape (1,); 0
        when ``estimate_pose``) and ``angle_frame`` (frame 14 permuted to
        ``(C, H, W)`` when ``estimate_pose``, else the int 0).
        """
        label = self.labels[idx]
        file = self.files[idx]
        video, _, _ = torchvision.io.read_video(self.video_paths[idx], pts_unit='sec')  # (Tensor[T, H, W, C])
        if self.estimate_pose:
            angle_frame = video[14].permute(2, 0, 1)
        else:
            angle_frame = 0
        frames = self.build_tensor(video)
        yaw = 0 if self.estimate_pose else self.poses[file]

        return {
            'frames': frames,
            'label': torch.LongTensor([label]),
            'word': self.words[label],
            'file': file,
            'yaw': torch.FloatTensor([yaw]),
            'angle_frame': angle_frame,
        }
    
if __name__ == '__main__':
    # Smoke test for LRWDataset: replace the placeholder paths with real ones.
    # NOTE: when this file is run as a script, this block executes before the
    # classes defined further down (Dataset, LRW) exist.
    dataset = LRWDataset('your dataset path',
                         'your meta path',
                         num_words=10,
                         in_channels=3,
                         mode="train",
                         augmentations=False,
                         estimate_pose=False,
                         seed=42,
                         query=None,
                         crop_size=(256,256),
                         reshape_size=(256,256))
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    for item in loader:
        for key, value in item.items():
            # Collated string fields ('word', 'file') are plain lists without a
            # shape; skip them explicitly instead of swallowing every exception.
            if hasattr(value, 'shape'):
                print(key, value.shape)
        break


#----------------------------------------------------------------------------

class Dataset(torch.utils.data.Dataset):
    """Base class exposing shape metadata for a dataset of clips.

    Subclasses implement ``_load_raw_image`` and ``_load_raw_labels``; items are
    ``(image, label)`` pairs. ``raw_shape`` is ``[N, T, C, H, W]`` (the
    properties below assert a 4-D ``TCHW`` per-item shape).
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw data: [N, T, C, H, W].
        max_size    = None,     # Not supported: must be None or >= N.
        use_labels  = False,    # Stored only; not consulted by this class.
        xflip       = False,    # Not supported: must be False.
        random_seed = 0,        # Unused; kept for interface compatibility.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._label_shape = None

        # Mapping from dataset index to raw index (currently the identity).
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            # Explicit raise instead of ``assert``: asserts are stripped under
            # ``python -O`` and the truncation path was never implemented.
            raise NotImplementedError(
                f'max_size={max_size} is not supported '
                f'(dataset has {self._raw_idx.size} items)')

        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            raise NotImplementedError('xflip is not supported')

    def close(self): # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        # Pickled copies drop the cached labels.
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        try:
            self.close()
        except Exception:
            # Best-effort cleanup: never raise from a finalizer.
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        raw_idx = int(self._raw_idx[idx])
        image = self._load_raw_image(raw_idx)
        label = self._load_raw_labels(raw_idx)
        return image, label

    def get_label(self, idx):
        return self._load_raw_labels(int(self._raw_idx[idx]))

    def get_details(self, idx):
        return {}

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        """Per-item shape ``[T, C, H, W]``."""
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 4 # TCHW
        return self.image_shape[1]

    @property
    def resolution(self):
        """Side length of the (square) frames."""
        assert len(self.image_shape) == 4 # TCHW
        assert self.image_shape[2] == self.image_shape[3]
        return self.image_shape[2]

    @property
    def label_shape(self):
        # NOTE(review): this returns the same value as ``image_shape``
        # (raw_shape[1:], i.e. [T, C, H, W]), not the shape of what
        # ``_load_raw_labels`` returns. Consequently ``label_dim`` equals T and
        # ``has_labels`` is True whenever any of T, C, H, W is non-zero. Kept
        # as-is because callers outside this file may depend on it; confirm
        # the intent.
        return list(self._raw_shape[1:])

    @property
    def label_dim(self):
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return False

class LRW(Dataset):
    """Adapter exposing ``LRWDataset`` clips through the ``Dataset`` base API.

    Items are ``(video, label)``: ``video`` is ``(T, C, H, W)`` with
    ``T <= vid_length`` and ``label`` is a LongTensor of shape ``(1,)``.
    """

    def __init__(self,
                vid_length = 16,
                path = None,                   # LRW root directory; None = placeholder below.
                resolution = None, # Expected frame size; raise IOError on mismatch. None = no check.
                vocab_size=10,
                in_channels = 3,
                pose_path = None,              # Directory with per-split yaw files; None = placeholder below.
                **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        # LRWDataset requires ``pose_path`` as its second positional argument;
        # previously it was omitted (TypeError) and ``path`` was ignored. The
        # placeholder strings remain the fallbacks for the smoke test.
        self._lrw = LRWDataset(path if path is not None else 'your dataset path',
                        pose_path if pose_path is not None else 'your meta path',
                        num_words=vocab_size,
                        in_channels=in_channels,
                        mode="train",
                        augmentations=False,
                        estimate_pose=False,
                        seed=42,
                        query=None,
                        crop_size=(256,256),
                        reshape_size=(256,256))
        self.vid_length = vid_length

        name = f'LRW_WORDS:{self._lrw.num_words}'
        # Decodes clip 0 once to learn the per-item shape [T, C, H, W].
        raw_shape = [len(self._lrw)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[3] != resolution or raw_shape[4] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    def close(self):
        pass

    def __getstate__(self):
        return dict(super().__getstate__())

    def _load_raw_image(self, raw_idx):
        """Return the first ``vid_length`` frames of clip ``raw_idx`` as TCHW."""
        video = self._lrw[raw_idx]['frames'] # CTHW
        video = video.transpose(1,0)[:self.vid_length] # CTHW => TCHW
        return video

    def _load_raw_labels(self, raw_idx):
        """Return the word index of clip ``raw_idx`` as a LongTensor of shape (1,).

        Reads the label list directly: indexing ``self._lrw`` would decode and
        transform the whole video just to obtain this integer.
        """
        return torch.LongTensor([self._lrw.labels[raw_idx]])

#----------------------------------------------------------------------------

if __name__ == '__main__':
    # Smoke test for the LRW wrapper: fetch a single batch, then dump the
    # dataset-level metadata properties one per line.
    dataset = LRW()
    loader = DataLoader(dataset, batch_size=2, shuffle=False)
    for images, labels in loader:
        print(images.shape)
        print(labels.shape)
        break
    for prop in ('name', 'image_shape', 'num_channels', 'resolution',
                 'label_shape', 'label_dim', 'has_labels', 'has_onehot_labels'):
        print(getattr(dataset, prop))
        