""" Most of the code here is ported DIRECTLY from the source code of torchvision.transforms
    from PyTorch 1.4.0. Modifications are made as necessary to adapt the transformations to
    the temporal domain, but the logic of the transformations themselves is mostly based on
    torchvision.transforms.functional, which mostly works with PIL Images. At this time a
    stable OpenCV backend for torchvision isn't officially part of PyTorch.

    General API:

    Each class inherits from 'object' and must define __call__.

    __call__: apply transform to video sequence. Useful when you have access to the
        entire sequence.
    get_transform: returns a transformation function randomly selected from the transformation
        family specified that can be applied image-by-image. Useful when you must load frame
        by frame but want to apply the same function. Deprecated.

"""

from torchvision.transforms import functional as F
from torchvision.transforms.transforms import _pil_interpolation_to_str, Lambda
from skimage.filters import gaussian
from scipy.ndimage import zoom

import torch
import random
import math
from PIL import Image, ImageFilter
import numpy as np
import numbers
from itertools import chain
import os
import cv2
import datasets.video_corrupt as video_corrupt
import utils
import warnings


import shlex
import random
import string
import subprocess
from collections import OrderedDict
import time

# Seed NumPy's global RNG at import time. This only affects ``np.random``; the
# transforms visible in this part of the file sample from Python's ``random``
# module, which is not seeded here.
np.random.seed(42)


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
        params (list, optional): one entry per transform, holding the extra
            arguments for that transform's call. A ``dict`` is expanded as
            keyword arguments, any other object with ``__len__`` (list,
            tuple, ...) is expanded as positional arguments, and anything else
            is passed as a single positional argument. When omitted (or
            empty), every transform is called with the video only.

    Raises:
        ValueError: if ``transforms`` is empty.
    """

    def __init__(self, transforms, params=None):
        if not len(transforms):
            raise ValueError("Cannot compose an empty set of transformations!")
        self.transforms = transforms
        # Default: an empty kwargs dict per transform, i.e. no extra arguments.
        self.params = params if params else [{} for _ in transforms]

    def get_transform(self, *args):
        """Return a per-image function chaining each transform's ``get_transform(*args)``."""
        per_image = [t.get_transform(*args) for t in self.transforms]

        def transform_(img):
            for fn in per_image:
                img = fn(img)
            return img
        return transform_

    def __call__(self, video, time_dim=1):
        """Apply every transform in order and return the result.

        ``time_dim`` is accepted for interface compatibility; it is not
        forwarded to the transforms.
        """
        for fn, params in zip(self.transforms, self.params):
            if isinstance(params, dict):
                # Splatting a dict with ``*`` would pass its *keys* as
                # positional arguments; keyword expansion is what a
                # per-transform parameter dict means.
                video = fn(video, **params)
            elif hasattr(params, '__len__'):
                video = fn(video, *params)
            else:
                video = fn(video, params)
        return video

    def __repr__(self):
        lines = [self.__class__.__name__ + '(']
        lines.extend('    {0}'.format(t) for t in self.transforms)
        lines.append(')')
        return '\n'.join(lines)


class VideoTransform(object):
    """Base class for the video transforms in this module.

    Standardization/type-checking/value-checking code should go in the superclass.
    ``__call__`` returns everything in the required format: c, t, [w, h]

    Args:
        flip_prob (float): stored as ``self.flip_prob`` for subclasses that
            apply a random horizontal flip.
    """

    def __init__(self, flip_prob=0.5):
        self.nc = 3  # deal with this magic number later, preferably when they fix torchvision
        self.flip_prob = flip_prob

    def __call__(self, video, time_dim):
        """Validate ``video`` and move its channel dimension to the front.

        Args:
            video (torch.Tensor): tensor with at least 3 dimensions.
            time_dim (int): accepted for interface compatibility; not used here.

        Returns:
            torch.Tensor: ``video`` unchanged if dim 0 already has size
            ``self.nc``; otherwise a permuted view in which the first
            dimension of size ``self.nc`` comes first and the remaining
            dimensions keep their relative order.

        Raises:
            ValueError: if ``video`` is not a tensor, has fewer than 3
                dimensions, or has no dimension of size ``self.nc``.
        """
        # isinstance (not an exact type match) so Tensor subclasses are accepted too.
        if not isinstance(video, torch.Tensor):
            raise ValueError("Video tensor must be of type torch.Tensor but got {} instead".format(type(video)))
        if len(video.size()) < 3:
            raise  ValueError("Video tensor must have at least 3 dimensions (4 for multi-channel): one temporal dimension and two spatial dimensions, with an optional channel dimension. Got a torch.Tensor of shape {}".format(video.size()))
        if video.size(0) != self.nc:
            try:
                # NOTE(review): this picks the FIRST dimension whose size equals
                # ``nc``; a clip with exactly ``nc`` frames would be matched on
                # its time dimension instead.
                cdim = video.size().index(self.nc)
            except ValueError as e:
                raise ValueError("Passed in {} channels but image has shape {}. Original error follows --".format(self.nc, video.size()), e) from e
            dim_order = (cdim,) + tuple(chain(range(0, cdim), range(cdim + 1, video.ndimension())))
            video = video.permute(dim_order)
        return video

class MultiScaleOrTenCrop(VideoTransform):
    """Apply either ``MultiScaleRandomCrop`` or ``HaraTransform``, chosen with equal probability.

    Args:
        target_width, target_height: output size forwarded to both children.
        interpolation: PIL interpolation mode forwarded to ``MultiScaleRandomCrop``.
        key_points (int): value assigned to each child's ``key_points``
            attribute, which the children read when called to decide how often
            crop parameters are re-sampled along the clip.
    """

    def __init__(self, target_width, target_height, interpolation=Image.BICUBIC, key_points=1):
        super(MultiScaleOrTenCrop, self).__init__()
        self.key_points = key_points
        # ``interpolation`` must go by keyword: the third positional parameter
        # of MultiScaleRandomCrop is ``flip_prob``.
        self.multi_scale_random_crop = MultiScaleRandomCrop(
            target_width, target_height, interpolation=interpolation)
        self.ten_crop = HaraTransform(target_width, target_height)
        # Both children read ``self.key_points`` inside ``__call__``; assign it
        # here so the attribute is always defined on them.
        self.multi_scale_random_crop.key_points = key_points
        self.ten_crop.key_points = key_points

    def __call__(self, video, time_dim=1, meta=None):
        """Run one of the two child transforms on ``video`` and return its output."""
        if random.random() < 0.5:
            chosen = self.multi_scale_random_crop
        else:
            chosen = self.ten_crop
        return chosen(video, time_dim, meta)

class MultiScaleRandomCrop(VideoTransform):
    """Random crop at a randomly chosen scale, resized to a fixed target size.

    Args:
        target_width, target_height: output size passed to ``F.resized_crop``.
        flip_prob (float): probability of horizontally flipping the whole clip.
        interpolation: PIL interpolation mode used when resizing.
        key_points (int): number of points along the clip at which crop
            parameters are re-sampled; with the default of 1 a single crop is
            used for every frame.

    Raises:
        ValueError: if ``key_points`` is smaller than 1.
    """

    def __init__(self, target_width, target_height, flip_prob=0.5, interpolation=Image.BICUBIC, key_points=1):
        super(MultiScaleRandomCrop, self).__init__(flip_prob)
        if key_points < 1:
            raise ValueError("key_points must be >= 1 but got {}".format(key_points))
        self.scales = [1., 1. / math.pow(2., 0.25), 1. / math.sqrt(2), 1. / math.pow(2., 0.72), 1. / 2.]
        self.width = target_width
        self.height = target_height
        self.interpolation = interpolation
        self.key_points = key_points

    def __call__(self, video, time_dim=1, meta=None):
        """Crop, resize and optionally flip every frame of ``video``.

        ``meta`` is accepted for interface compatibility and is not used.
        Returns the processed frames stacked along ``time_dim``.
        """
        video = super(MultiScaleRandomCrop, self).__call__(video, time_dim)
        t = video.size(1)
        # Each frame becomes a PIL image whose (height, width) are the last two
        # tensor dims, so crop boxes must be bounded by these source dims.
        rows, cols = video.size()[-2:]
        # Re-sample crop parameters every ``interval`` frames (never 0).
        interval = max(1, math.ceil(t / self.key_points))
        flip = (random.random() < self.flip_prob)
        seq = []
        for n, img in enumerate(torch.split(video, 1, dim=1)):
            if n % interval == 0:
                scale = random.choice(self.scales)
                new_size = [rows, cols]
                longest = int(np.argmax(new_size))
                new_size[longest] = max(1, int(new_size[longest] * scale))
                random.shuffle(new_size)
                # The shuffle may swap the two extents; clamp so the crop box
                # always fits inside the source frame.
                crop_h = min(new_size[0], rows)
                crop_w = min(new_size[1], cols)
                i = random.randint(0, rows - crop_h)
                j = random.randint(0, cols - crop_w)

            img = F.to_pil_image(img.squeeze(1))
            # NOTE(review): torchvision documents ``size`` as (h, w); this file
            # passes (width, height) throughout — equivalent for square
            # targets, confirm for non-square ones.
            img = F.resized_crop(img, i, j, crop_h, crop_w, (self.width, self.height), self.interpolation)
            if flip:
                img = F.hflip(img)
            seq.append(F.to_tensor(img))
        return torch.stack(seq, dim=time_dim)



class HaraTransform(VideoTransform):
    """Pick one of torchvision's five crops at a random scale, resize, and maybe flip.

    Args:
        target_width, target_height: output size passed to ``F.resize``.
        key_points (int): number of points along the clip at which crop
            parameters are re-sampled in ``__call__``; with the default of 1 a
            single crop is used for every frame.

    Raises:
        ValueError: if ``key_points`` is smaller than 1.
    """

    def __init__(self, target_width, target_height, key_points=1):
        super(HaraTransform, self).__init__()
        if key_points < 1:
            raise ValueError("key_points must be >= 1 but got {}".format(key_points))
        self.scales = [1., 1. / math.pow(2., 0.25), 1. / math.sqrt(2), 1. / math.pow(2., 0.72), 1. / 2.]
        self.width = target_width
        self.height = target_height
        self.key_points = key_points

    def get_transform(self, width, height):
        """
            Parameters to fix:
            - which crop
            - whether to flip

            Returns a per-image ``Compose`` whose output is permuted to
            channel-last order. ``width`` and ``height`` are not used.
        """
        crop_idx = random.randint(0, 4)
        scale = random.choice(self.scales)

        def transform_(img):
            if not isinstance(img, Image.Image):
                img = F.to_pil_image(img)
            size = img.size
            new_size = list(size)
            # Integer extent: crop boxes are pixel coordinates.
            new_size[np.argmax(new_size)] = max(1, int(min(size) * scale))
            crop = F.five_crop(img, new_size)[crop_idx]
            crop = F.resize(crop, (self.width, self.height))
            return F.to_tensor(crop)

        if random.random() > 0.5:
            return Compose([transform_, Lambda(lambda x: x.permute(1, 2, 0))])
        else:
            # why so many conversions, you ask? Inflexible APIs, incompatible types/channel orders, and this is what you get!
            return Compose([transform_, F.to_pil_image, F.hflip, F.to_tensor, Lambda(lambda x: x.permute(1, 2, 0))])

    def __call__(self, video, time_dim=1, meta=None):
        """Crop, resize and optionally flip every frame of ``video``.

        ``meta`` is accepted for interface compatibility and is not used.
        Returns the processed frames stacked along ``time_dim``.
        """
        video = super(HaraTransform, self).__call__(video, time_dim)
        t = video.size(1)
        # Each frame becomes a PIL image whose (height, width) are the last two
        # tensor dims.
        rows, cols = video.size()[-2:]
        # Re-sample crop parameters every ``interval`` frames (never 0).
        interval = max(1, math.ceil(t / self.key_points))
        flip = (random.random() < self.flip_prob)
        seq = []
        for n, img in enumerate(torch.split(video, 1, dim=1)):
            if n % interval == 0:
                scale = random.choice(self.scales)
                new_size = [rows, cols]
                longest = int(np.argmax(new_size))
                new_size[longest] = max(1, int(new_size[longest] * scale))
                random.shuffle(new_size)
                # five_crop rejects crops larger than the image; the shuffle
                # can swap the extents, so clamp each one to the frame.
                new_size = [min(new_size[0], rows), min(new_size[1], cols)]
                crop_idx = random.randint(0, 4)

            img = F.to_pil_image(img.squeeze(1))
            crop = F.five_crop(img, new_size)[crop_idx]
            crop = F.resize(crop, (self.width, self.height))
            if flip:
                crop = F.hflip(crop)
            seq.append(F.to_tensor(crop))
        return torch.stack(seq, dim=time_dim)

    def __repr__(self):
        format_string = self.__class__.__name__ + '(FiveCrop -> SelectCrop -> Resize -> (Flip, p=0.5))'
        return format_string


class IFrameWhack(VideoTransform):
    """Replace a clip with frames decoded from a corrupted copy of its source file.

    The source file is re-encoded with libx264 via ffmpeg, handed to
    ``video_corrupt.whack_mpeg_iframes`` to produce the corrupted copy, and
    the clip is then re-read from that copy with OpenCV.

    Args:
        corrupt_mode: forwarded as ``mode`` to ``whack_mpeg_iframes``.
        corrupt_prob: forwarded as ``p`` to ``whack_mpeg_iframes``; 0 disables
            the transform.
        cache_size (int): maximum number of corrupted files kept on disk by
            this instance before the oldest one is deleted.
        temp_dir (str): directory for the intermediate re-encoded file.
    """

    def __init__(self, corrupt_mode='random', corrupt_prob=1., cache_size=4, temp_dir='video-augmentation-experiments/src/ffmpeg_cache'):
        super(IFrameWhack, self).__init__()
        self.corrupt_mode = corrupt_mode
        self.corrupt_prob = corrupt_prob
        # Per-instance random token used in temp and output file names.
        self.temp_file = ''.join(random.SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(16))
        self.temp_dir = temp_dir
        self.cache = OrderedDict()
        self.cache_size = cache_size

    def get_corrupt_path(self, path):
        """Map a source path to the path of its corrupted copy.

        The third-from-last path component is replaced by ``corrupt`` and the
        file name is prefixed with this instance's token; the result always
        starts with ``/``.
        """
        dirs = path.split("/")
        new_path = "/" + os.path.join(*dirs[:-3], "corrupt", dirs[-2], self.temp_file + "_CORRUPTED_" + dirs[-1])
        return new_path

    def make_corrupt_copy(self, old_path, path):
        """Create the corrupted copy of ``old_path`` unless it already exists.

        Returns the path of the corrupted file (``.avi`` replaced by ``.mp4``).
        """
        out_dir = os.path.dirname(path)
        if out_dir:
            # exist_ok makes this safe when another process creates the
            # directory at the same time.
            os.makedirs(out_dir, exist_ok=True)
        path = path.replace('.avi', '.mp4')
        if os.path.isfile(path):
            return path
        tmp = os.path.join(self.temp_dir, self.temp_file + ".mp4")
        # Argument list (no shell) so paths with spaces or shell
        # metacharacters are passed through verbatim.
        cmd = ['/usr/bin/ffmpeg', '-loglevel', 'quiet', '-i', old_path,
               '-vcodec', 'libx264', '-strict', '-2', tmp]
        subprocess.call(cmd)
        assert os.path.isfile(tmp), "failed on command " + ' '.join(shlex.quote(arg) for arg in cmd)
        try:
            video_corrupt.whack_mpeg_iframes(tmp, path, mode=self.corrupt_mode, p=self.corrupt_prob)
        finally:
            # Never leave the intermediate file behind, even if corruption fails.
            os.unlink(tmp)
        return path

    def __call__(self, video, time_dim=1, meta=None):
        """Return the clip re-read from the corrupted copy of its source file.

        Args:
            video (torch.Tensor): 4-D clip; only its size is used (number of
                frames and spatial dims of the output).
            time_dim: accepted for interface compatibility; not used.
            meta: ``(path, start)`` — source file path and index of the clip's
                first frame.

        Returns:
            torch.Tensor: frames stacked as (channels, time, rows, cols), or
            ``video`` unchanged when ``corrupt_prob`` is 0.
        """
        if self.corrupt_prob == 0:
            return video  # bypass
        _, t, rows, cols = video.size()
        path, start = meta
        start = int(start)

        corrupt_path = self.make_corrupt_copy(path, self.get_corrupt_path(path))
        self.cache[corrupt_path] = True
        self.cache.move_to_end(corrupt_path)
        if len(self.cache) > self.cache_size:
            old_file, _ = self.cache.popitem(last=False)
            try:
                os.unlink(old_file)
            except FileNotFoundError:
                pass  # already gone; nothing left to evict
        assert os.path.isfile(corrupt_path), "No such file " + corrupt_path

        seq = []
        # silence stderr to avoid 109238741028347 ffmpeg error messages
        with utils.stderr_suppress():
            cap = utils.safe_capture(corrupt_path)
            try:
                cap.set(cv2.CAP_PROP_POS_FRAMES, float(start))
                for _ in range(t):
                    ret, img = cap.read()
                    tries = 0
                    # On a failed read, loop back to the clip start and retry.
                    while not ret and tries < 10:
                        cap.set(cv2.CAP_PROP_POS_FRAMES, float(start))
                        tries += 1
                        ret, img = cap.read()
                    if not ret:
                        warnings.warn("On " + corrupt_path + "; frame " + str(cap.get(cv2.CAP_PROP_POS_FRAMES))
                                      + " (start " + str(start) + "): could not read a frame")
                    # cv2.resize takes dsize as (width, height) == (cols, rows);
                    # this keeps the output spatial dims equal to the input's.
                    img = cv2.resize(img, (cols, rows))
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    seq.append(F.to_tensor(img))
            finally:
                cap.release()
        seq = torch.stack(seq, dim=0)
        seq = seq.permute(1, 0, 2, 3)  # CTWH
        return seq

    def __del__(self):
        # ``cache`` may be missing if __init__ failed part-way; files may
        # already have been removed. A destructor must not raise.
        for fname in getattr(self, 'cache', ()):
            try:
                os.unlink(fname)
            except OSError:
                pass

    def __repr__(self):
        s = self.__class__.__name__
        s += "(corrupt_mode={}".format(self.corrupt_mode)
        s += ", corrupt_prob={}".format(self.corrupt_prob)
        s += ", temp_dir={}".format(self.temp_dir)
        s += ", temp_file={}".format(self.temp_file)
        s += ")"
        return s

class RandomBitCorruption(VideoTransform):
    """Apply one randomly chosen corruption from a pool of ``BitCorrupt``/``NetworkCorrupt`` transforms.

    Args:
        cache_size (int): forwarded to every ``BitCorrupt``.
        temp_dir (str): forwarded to every ``BitCorrupt``.
        contiguous_probs (list): one ``BitCorrupt(corrupt_mode='contiguous')``
            is created per value.
        random_probs (list): one ``BitCorrupt(corrupt_mode='random')`` is
            created per value.
        network_probs (list): one ``NetworkCorrupt`` is created per value.
        bit_corrupt_levels (str, optional): ``'low'`` or ``'high'`` replaces
            all three probability lists with fixed presets; any other value
            keeps the lists passed in.
    """

    def __init__(self, cache_size=10, temp_dir='video-augmentation-experiments/src/ffmpeg_cache', contiguous_probs=[0.01, 0.1, 0.25, 0.5, 0.75, 0.9], random_probs=[1e-8, 5e-8, 1e-7, 5e-7, 1e-6, 5e-6, 1e-5, 5e-5, 1e-4], network_probs=[1e-4, 1e-3, 0.01, 0.1, 0.2], bit_corrupt_levels=None):
        super(RandomBitCorruption, self).__init__()
        self.temp_dir = temp_dir
        self.cache_size = cache_size
        # Copy so the shared default lists in the signature are never aliased.
        self.contiguous_probs = list(contiguous_probs)
        self.random_probs = list(random_probs)
        self.network_probs = list(network_probs)
        self.bit_corrupt_levels = bit_corrupt_levels
        if bit_corrupt_levels == 'low':
            self.contiguous_probs = [0.01, 0.1]
            self.random_probs = [1e-8, 5e-8, 1e-7, 5e-7, 1e-6]
            self.network_probs = [1e-4, 1e-3, 1e-2]
        elif bit_corrupt_levels == 'high':
            self.contiguous_probs = [0.25, 0.5, 0.75, 0.9]
            self.random_probs = [5e-6, 1e-5, 5e-5, 1e-4]
            self.network_probs = [0.1, 0.2]

        # Everything is passed by keyword: a leading positional argument would
        # collide with ``corrupt_mode`` / ``corrupt_prob`` and raise TypeError.
        self.corruptions = []
        for p in self.contiguous_probs:
            self.corruptions.append(BitCorrupt(corrupt_mode='contiguous', corrupt_prob=p, temp_dir=self.temp_dir, cache_size=self.cache_size))
        for p in self.random_probs:
            self.corruptions.append(BitCorrupt(corrupt_mode='random', corrupt_prob=p, temp_dir=self.temp_dir, cache_size=self.cache_size))
        for p in self.network_probs:
            self.corruptions.append(NetworkCorrupt(corrupt_prob=p))

    def __call__(self, video, time_dim=1, meta=None):
        """Corrupt ``video`` with a random member of the pool.

        The chosen corruptor is retried up to 20 times while it returns an
        empty tensor; if every attempt is empty the input is returned as is.
        """
        corruptor = random.choice(self.corruptions)
        corrupted_vid = torch.Tensor(0)
        attempts = 0
        while len(corrupted_vid) == 0 and attempts < 20:  # ensure augmented corrupted video comes out readable
            if hasattr(corruptor, 'cache') and len(corruptor.cache):
                # Drop the most recent cache entry before each attempt.
                # NOTE(review): presumably this forces a fresh corruption on
                # retry; the popped entry's file is not deleted here — confirm.
                corruptor.cache.popitem(last=True)
            corrupted_vid = corruptor(video, time_dim=time_dim, meta=meta)
            attempts += 1
        if not len(corrupted_vid):
            return video
        return corrupted_vid

    def __repr__(self):
        s = self.__class__.__name__
        s += "(contiguous_probs={},".format(self.contiguous_probs)
        s += " random_probs={},".format(self.random_probs)
        s += " network_probs={}".format(self.network_probs)
        s += ")"
        return s

 
class NetworkCorrupt(VideoTransform):
    """Replace a clip with frames from a pre-generated network-corrupted copy.

    The corrupted copy is looked up under
    ``network_corruptions/loss_uplink_<corrupt_prob>_ver<k>/`` followed by the
    last three components of the source path, with ``k`` drawn from 0..5.

    Args:
        corrupt_prob: loss level used to build the directory name.
        max_frames (int): number of frames in the returned clip; shorter
            reads are padded by repeating the frames read.
    """

    def __init__(self, corrupt_prob=0.1, max_frames=16):
        super(NetworkCorrupt, self).__init__()
        self.corrupt_prob = corrupt_prob
        self.max_frames = max_frames

    def __call__(self, video, time_dim=1, meta=None):
        """Return the corresponding clip read from the corrupted file.

        Args:
            video (torch.Tensor): 4-D clip; only its size is used.
            time_dim: accepted for interface compatibility; not used.
            meta: ``(path, start)`` — source file path and index of the clip's
                first frame.

        Returns:
            torch.Tensor: (channels, max_frames, rows, cols) on success;
            ``video`` unchanged if no corrupted copy exists or no frame was
            collected; an empty tensor if a frame could not be decoded.
        """
        path, start = meta
        start = int(start)
        dirs = path.split('/')
        corruption_name = 'loss_uplink_{}_ver{}'.format(self.corrupt_prob, random.randrange(0, 6))
        target_path = os.path.join('network_corruptions/', corruption_name, *dirs[-3:])
        if not os.path.isfile(target_path):
            return video
        _, t, rows, cols = video.size()
        seq = []
        with utils.stderr_suppress():
            # Read the corrupted copy, not the original source file.
            cap = utils.safe_capture(target_path)
            try:
                # Start at the clip's first frame so the corrupted frames line
                # up with the clip being replaced.
                cap.set(cv2.CAP_PROP_POS_FRAMES, float(start))
                # here's the dangerous part --  corrupted frames means that CAP_PROP_FRAME_COUNT isn't necessarily the # of readable frames
                for _ in range(t):
                    ret, img = cap.read()
                    if not ret:  # loop back to the clip start and try once more
                        cap.set(cv2.CAP_PROP_POS_FRAMES, float(start))
                        ret, img = cap.read()
                    if not ret or img is None:
                        return torch.Tensor(0)
                    # cv2.resize takes dsize as (width, height) == (cols, rows);
                    # this keeps the output spatial dims equal to the input's.
                    img = cv2.resize(img, (cols, rows))
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    seq.append(F.to_tensor(img))
            finally:
                cap.release()
        seq = seq[:self.max_frames]
        if len(seq) == 0:
            return video
        # Pad by repetition until the clip has exactly ``max_frames`` frames.
        while len(seq) < self.max_frames:
            seq = (seq + seq)[:self.max_frames]
        # we need to adjust the size in the transform fn bc transforms should be size-invariant
        seq = torch.stack(seq, dim=0)
        seq = seq.permute(1, 0, 2, 3)  # CTWH
        return seq

    def __repr__(self):
        s = self.__class__.__name__
        s += "(corrupt_prob={}".format(self.corrupt_prob)
        s += ")"
        return s


class BitCorrupt(VideoTransform):
    """Replace a clip with frames decoded from a bit-corrupted copy of its source file.

    A temporary copy of the source is produced with ``utils.gen_temp``,
    corrupted with ``video_corrupt.flip`` and re-read with OpenCV.

    Args:
        corrupt_mode: forwarded as ``mode`` to ``video_corrupt.flip``.
        corrupt_prob: forwarded as ``p`` to ``video_corrupt.flip``; 0 disables
            the transform.
        cache_size (int): maximum number of entries kept in ``self.cache``.
        temp_dir (str): directory for the intermediate file.
    """

    def __init__(self, corrupt_mode='random', corrupt_prob=1., cache_size=10, temp_dir='video-augmentation-experiments/src/ffmpeg_cache'):
        super(BitCorrupt, self).__init__()
        self.corrupt_mode = corrupt_mode
        self.corrupt_prob = corrupt_prob
        self.temp_dir = temp_dir
        self.random_fname = self._random_token()
        self.cache = OrderedDict()
        self.cache_size = cache_size

    @staticmethod
    def _random_token(length=16):
        """Return a random alphanumeric token for file names.

        Uses the OS entropy source so tokens differ across processes without
        reseeding (and thereby disturbing) the global ``random`` generator.
        """
        rng = random.SystemRandom()
        alphabet = string.ascii_letters + string.digits
        return ''.join(rng.choice(alphabet) for _ in range(length))

    def get_corrupt_path(self, path):
        """Map a source path to the path of its corrupted copy.

        The third-from-last path component is replaced by ``corrupt`` and the
        file name is prefixed with this instance's token; the result always
        starts with ``/``.
        """
        dirs = path.split("/")
        new_path = "/" + os.path.join(*dirs[:-3], "corrupt", dirs[-2], self.random_fname + "_CORRUPTED_" + dirs[-1])
        return new_path

    def make_corrupt_copy(self, old_path, path):
        """Create (or reuse) the corrupted copy of ``old_path``.

        Returns:
            tuple: ``(path, corruption_info)`` where ``path`` has ``.avi``
            replaced by ``.mp4`` and ``corruption_info`` is what
            ``video_corrupt.flip`` returned (taken from the cache on reuse).
        """
        # TODO: mutex here
        out_dir = os.path.dirname(path)
        if out_dir:
            # exist_ok makes this safe when another process creates the
            # directory at the same time.
            os.makedirs(out_dir, exist_ok=True)
        path = path.replace('.avi', '.mp4')
        if os.path.isfile(path) and path in self.cache:
            return path, self.cache[path]
        tmp = os.path.join(self.temp_dir, self._random_token() + ".mp4")
        utils.gen_temp(old_path, tmp)
        try:
            corruption_info = video_corrupt.flip(tmp, path, 'h264', mode=self.corrupt_mode, p=self.corrupt_prob)
        finally:
            # Never leave the intermediate file behind, even if corruption fails.
            if os.path.exists(tmp):
                os.unlink(tmp)
        return path, corruption_info

    def check_unreadable(self, path):
        """Return True if the cached file at ``path`` cannot be used.

        A file counts as readable only if ``ffprobe`` exits with 0 and
        ``video_corrupt.get_codec`` reports ``h264``. The verdict is memoised
        in the cache entry under the ``'unreadable'`` key.
        """
        assert path in self.cache, self.cache
        info = self.cache[path]
        if 'unreadable' in info:
            return info['unreadable']
        probe_ok = subprocess.call(['ffprobe', '-loglevel', 'quiet', path]) == 0
        info['unreadable'] = not (probe_ok and video_corrupt.get_codec(path) == 'h264')
        return info['unreadable']

    def __call__(self, video, time_dim=1, meta=None):
        """Return the clip re-read from the corrupted copy of its source file.

        Args:
            video (torch.Tensor): 4-D clip; only its size is used.
            time_dim: accepted for interface compatibility; not used.
            meta: ``(path, start)`` — source file path and index of the clip's
                first frame.

        Returns:
            torch.Tensor: (channels, time, rows, cols) on success; ``video``
            unchanged when ``corrupt_prob`` is 0; an empty tensor when the
            corrupted file or one of its frames cannot be read.
        """
        path, start = meta
        if self.corrupt_prob == 0:
            # Record a "nothing corrupted" entry for the source path; no file
            # is created, so eviction here never deletes anything.
            self.cache[path] = {'corrupted_iframe':False, 'locations': []}
            self.cache.move_to_end(path)
            if len(self.cache) > self.cache_size:
                self.cache.popitem(last=False)
            return video
        _, t, rows, cols = video.size()
        start = int(start)
        corrupt_path, corruption_info = self.make_corrupt_copy(path, self.get_corrupt_path(path))
        self.cache[corrupt_path] = corruption_info
        self.cache.move_to_end(corrupt_path)
        if len(self.cache) > self.cache_size:
            old_file, _ = self.cache.popitem(last=False)
            try:
                os.unlink(old_file)
            except FileNotFoundError:
                pass  # already gone; nothing left to evict
        if self.check_unreadable(corrupt_path):
            return torch.Tensor(0)
        assert os.path.isfile(corrupt_path), "No such file " + corrupt_path
        seq = []
        # silence stderr to avoid 109238741028347 ffmpeg error messages
        with utils.stderr_suppress():
            cap = utils.safe_capture(corrupt_path)
            try:
                cap.set(cv2.CAP_PROP_POS_FRAMES, float(start))
                # here's the dangerous part --  corrupted frames means that CAP_PROP_FRAME_COUNT isn't necessarily the # of readable frames
                for _ in range(t):
                    ret, img = cap.read()
                    if not ret:  # loop back to the clip start and try once more
                        cap.set(cv2.CAP_PROP_POS_FRAMES, float(start))
                        ret, img = cap.read()
                    if not ret or img is None:
                        # if this still doesn't work, this means that this frame is unrecoverable.
                        return torch.Tensor(0)
                    # cv2.resize takes dsize as (width, height) == (cols, rows);
                    # this keeps the output spatial dims equal to the input's.
                    img = cv2.resize(img, (cols, rows))
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    seq.append(F.to_tensor(img))
            finally:
                cap.release()
        seq = torch.stack(seq, dim=0)
        seq = seq.permute(1, 0, 2, 3)  # CTWH
        return seq

    def __repr__(self):
        s = self.__class__.__name__
        s += "(corrupt_mode={}".format(self.corrupt_mode)
        s += ", corrupt_prob={}".format(self.corrupt_prob)
        s += ", temp_dir={}".format(self.temp_dir)
        s += ")"
        return s

class CenterCrop(VideoTransform):
    """Resize every frame to ``size``, center-crop it, and flip the clip with probability ``p``.

    Args:
        size (int or 2-sequence): target size. An int is expanded to
            ``(size, size)``; a tuple or list is used as given.
        p (float): probability of horizontally flipping the whole clip.
    """

    def __init__(self, size, p=0.5):
        super(CenterCrop, self).__init__()
        if isinstance(size, (tuple, list)):
            # Lists are accepted too; wrapping one as (list, list) would not
            # be a valid size for the resize below.
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.p = p

    def __call__(self, vid, time_dim=1, meta=None):
        """Process ``vid`` frame by frame and stack the results along ``time_dim``.

        ``meta`` is accepted for interface compatibility and is not used.
        """
        vid = super(CenterCrop, self).__call__(vid, time_dim)
        # One flip decision for the whole clip keeps frames consistent.
        flip = (random.random() < self.p)
        seq = []
        for frame in torch.split(vid, 1, dim=1):
            img = F.to_pil_image(frame.squeeze(1))
            if flip:
                img = F.hflip(img)
            img = F.resize(img, self.size)
            img = F.center_crop(img, self.size)
            seq.append(F.to_tensor(img))
        return torch.stack(seq, dim=time_dim)

