import copy

import torch
import numpy as np
from PIL import Image
import PIL
from torchvision import transforms as T
from pytorchvideo import transforms as vT
from data.transforms import video as vT2


class MultiScaleCropFlipColorJitter:
    """Clip transform: spatial stage, tensor conversion, normalization, subsampling.

    With ``augment=True`` the spatial stage is ``vT2.RandomResizedCrop`` ->
    ``vT2.RandomHorizontalFlip`` -> ``vT2.ColorJitter``; otherwise it is
    ``vT2.Resize`` -> ``vT2.CenterCrop``. Both variants finish with
    ``vT2.ClipToTensor``, ``vT.Normalize`` and
    ``vT.UniformTemporalSubsample(num_frames)``.

    Args:
        num_frames: Value passed to ``vT.UniformTemporalSubsample``.
        crop: Crop size. Any iterable is converted to a tuple; a non-iterable
            value (e.g. a single int) is kept and forwarded unchanged.
        color: Positional arguments unpacked into ``vT2.ColorJitter``
            (presumably brightness/contrast/saturation/hue -- TODO confirm
            against ``data.transforms.video``).
        min_area: Lower bound of the ``scale`` range given to
            ``vT2.RandomResizedCrop``.
        augment: Selects the random pipeline (truthy) or the deterministic
            resize + center-crop pipeline (falsy).
    """

    def __init__(self,
                 num_frames=8,
                 crop=(224, 224),
                 color=(0.4, 0.4, 0.4, 0.1),
                 min_area=0.08,
                 augment=True
                 ):
        from collections.abc import Iterable
        if isinstance(crop, Iterable):
            crop = tuple(crop)
        self.crop = crop
        self.augment = augment
        self.num_frames = num_frames

        if augment:
            transforms = [
                vT2.RandomResizedCrop(crop, scale=(min_area, 1.)),
                vT2.RandomHorizontalFlip(),
                vT2.ColorJitter(*color),
            ]
        else:
            transforms = [
                # The helper accepts both tuple and scalar crops; indexing
                # ``crop[0]`` directly raised TypeError for an int crop.
                vT2.Resize(self._eval_resize_size(crop)),
                vT2.CenterCrop(crop),
            ]

        transforms += [
            vT2.ClipToTensor(),
            # Same constants as the ImageNet default mean/std.
            vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            vT.UniformTemporalSubsample(num_frames),
        ]
        self.t = T.Compose(transforms)

    @staticmethod
    def _eval_resize_size(crop):
        """Return the pre-crop resize size: first crop side divided by 0.875.

        ``crop`` may be a tuple (its first element is used) or a scalar.
        For example, 224 maps to 256.
        """
        side = crop[0] if isinstance(crop, tuple) else crop
        return int(side / 0.875)

    def __call__(self, x):
        """Apply the composed pipeline to ``x`` and return the result."""
        return self.t(x)


class ResizeCropFlip:
    """Clip transform: rescale, crop (and flip when augmenting), then tensorize.

    Args:
        num_frames: Value passed to ``vT.UniformTemporalSubsample``.
        min_size: ``min_size`` for ``vT2.RandomShortSideScale`` when
            augmenting; otherwise the size given to ``vT2.Resize``.
        max_size: ``max_size`` for ``vT2.RandomShortSideScale`` (only used
            when augmenting).
        crop_size: Size given to ``vT2.RandomCrop`` / ``vT2.CenterCrop``.
        augment: Selects the random pipeline (truthy) or the deterministic
            resize + center-crop pipeline (falsy).
    """

    def __init__(self, num_frames=8, min_size=256, max_size=360, crop_size=224, augment=True):
        if not augment:
            # Deterministic spatial stage.
            spatial_stage = [vT2.Resize(min_size), vT2.CenterCrop(crop_size)]
        else:
            # Randomized spatial stage.
            spatial_stage = [
                vT2.RandomShortSideScale(min_size=min_size, max_size=max_size),
                vT2.RandomCrop(crop_size),
                vT2.RandomHorizontalFlip(),
            ]

        # Stage shared by both modes: tensor conversion, normalization,
        # temporal subsampling.
        tensor_stage = [
            vT2.ClipToTensor(),
            vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            vT.UniformTemporalSubsample(num_frames),
        ]
        self.t = T.Compose(spatial_stage + tensor_stage)

    def __call__(self, x):
        """Apply the composed pipeline to ``x`` and return the result."""
        pipeline = self.t
        return pipeline(x)


class VideoMAE_transform:
    """Clip transform: one spatial op, tensor conversion, normalization.

    Args:
        input_size: Size forwarded to ``vT2.GroupMultiScaleCrop`` (training)
            or ``vT2.Resize`` (otherwise).
        training: When truthy, use ``vT2.GroupMultiScaleCrop`` with scales
            ``[1, .875, .75]``; otherwise use ``vT2.Resize``.
    """

    def __init__(self, input_size, training=False):
        self.input_mean = [0.485, 0.456, 0.406]  # IMAGENET_DEFAULT_MEAN
        self.input_std = [0.229, 0.224, 0.225]  # IMAGENET_DEFAULT_STD

        spatial_op = (
            vT2.GroupMultiScaleCrop(input_size, [1, .875, .75])
            if training
            else vT2.Resize(input_size)
        )
        # Normalize receives fresh copies so the public attributes above are
        # not aliased by the transform.
        self.t = T.Compose([
            spatial_op,
            vT2.ClipToTensor(),
            vT.Normalize(mean=list(self.input_mean), std=list(self.input_std)),
        ])

    def __call__(self, x):
        """Apply the composed pipeline to ``x`` and return the result."""
        pipeline = self.t
        return pipeline(x)


class VideoMAE_transform_buffer:
    """Apply the training-style VideoMAE pipeline to every clip in a batch.

    Args:
        input_size: Size forwarded to ``vT2.GroupMultiScaleCrop``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, input_size, **kwargs):
        pipeline = [
            vT2.GroupMultiScaleCrop(input_size, [1, .875, .75]),
            vT2.ClipToTensor(),
            vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.t = T.Compose(pipeline)

    def __call__(self, x):
        """Transform each element of ``x``, stack, and swap dims 1 and 2.

        The result is B x T x C x H x W, assuming each transformed clip is
        C x T x H x W (TODO confirm against ``vT2.ClipToTensor``).
        """
        per_clip = [self.t(clip) for clip in x]
        batch = torch.stack(per_clip)
        return batch.transpose(1, 2)  # B x T x C x H x W


class CAV_Video_transform:
    """Image pipeline: bicubic resize, tensor conversion, normalization.

    Args:
        input_size: Size forwarded to ``T.Resize``.
        training: Accepted for interface compatibility; not used.
    """

    def __init__(self, input_size, training=False):
        self.input_mean = [0.485, 0.456, 0.406]  # IMAGENET_DEFAULT_MEAN
        self.input_std = [0.229, 0.224, 0.225]  # IMAGENET_DEFAULT_STD

        resize = T.Resize(input_size, interpolation=PIL.Image.BICUBIC)
        to_tensor = T.ToTensor()
        normalize = T.Normalize(
            mean=[0.4850, 0.4560, 0.4060],
            std=[0.2290, 0.2240, 0.2250],
        )
        self.t = T.Compose([resize, to_tensor, normalize])

    def __call__(self, x):
        """Apply the composed pipeline to ``x`` and return the result."""
        pipeline = self.t
        return pipeline(x)


class CAV_Video_transform_buffer:
    """Pass-through transform: returns its input unchanged."""

    def __init__(self, **kwargs):
        """Accept and ignore any keyword arguments."""

    def __call__(self, x):
        """Return ``x`` itself, unmodified."""
        return x





class VisualizeCrop:
    """
    Clip transform for tensorboard visualization.

    Center-crops, converts to a tensor and temporally subsamples; no
    normalization is applied.

    Args:
        num_frames: Value passed to ``vT.UniformTemporalSubsample``.
        input_size: Size forwarded to ``vT2.CenterCrop``.
    """

    def __init__(self, num_frames=8, input_size=224):
        center_crop = vT2.CenterCrop(input_size)
        to_tensor = vT2.ClipToTensor()
        subsample = vT.UniformTemporalSubsample(num_frames)
        self.t = T.Compose([center_crop, to_tensor, subsample])

    def __call__(self, x):
        """Apply the composed pipeline to ``x`` and return the result."""
        pipeline = self.t
        return pipeline(x)


class VisualizeCropImage:
    """Image pipeline: bicubic resize, center crop, tensor conversion.

    No normalization is applied.

    Args:
        input_size: Size forwarded to both ``T.Resize`` and ``T.CenterCrop``.
    """

    def __init__(self, input_size=224):
        resize = T.Resize(input_size, interpolation=PIL.Image.BICUBIC)
        center_crop = T.CenterCrop(input_size)
        to_tensor = T.ToTensor()
        self.t = T.Compose([resize, center_crop, to_tensor])

    def __call__(self, x):
        """Apply the composed pipeline to ``x`` and return the result."""
        pipeline = self.t
        return pipeline(x)




class Attention_vis_transform:
    """Clip transform: center crop, tensor conversion, normalization.

    Args:
        input_size: Size forwarded to ``vT2.CenterCrop``.
        training: Accepted for interface compatibility; not used.
    """

    def __init__(self, input_size, training=False):
        self.input_mean = [0.485, 0.456, 0.406]  # IMAGENET_DEFAULT_MEAN
        self.input_std = [0.229, 0.224, 0.225]  # IMAGENET_DEFAULT_STD

        center_crop = vT2.CenterCrop(input_size)
        to_tensor = vT2.ClipToTensor()
        # Normalize receives fresh copies so the public attributes above are
        # not aliased by the transform.
        normalize = vT.Normalize(mean=list(self.input_mean), std=list(self.input_std))
        self.t = T.Compose([center_crop, to_tensor, normalize])

    def __call__(self, x):
        """Apply the composed pipeline to ``x`` and return the result."""
        pipeline = self.t
        return pipeline(x)



class Attention_vis_transform_high_resol:
    """Clip transform: center crop and tensor conversion, no normalization.

    Args:
        input_size: Size forwarded to ``vT2.CenterCrop``.
        training: Accepted for interface compatibility; not used.
    """

    def __init__(self, input_size, training=False):
        center_crop = vT2.CenterCrop(input_size)
        to_tensor = vT2.ClipToTensor()
        self.t = T.Compose([center_crop, to_tensor])

    def __call__(self, x):
        """Apply the composed pipeline to ``x`` and return the result."""
        pipeline = self.t
        return pipeline(x)