import numpy as np
import cv2 as cv

from albumentations import ToFloat, FromFloat, Compose, OneOf, \
    GaussNoise, MotionBlur, ShiftScaleRotate, CoarseDropout

class TemporalAugmentation:
    """Augment clips of video frames with albumentations pipelines.

    Two pipelines are built in ``__init__``:

    * ``scale_trans`` -- one of Gaussian noise, motion blur or coarse dropout.
      ``transform`` runs it frame by frame, so every frame draws its own
      random parameters.
    * ``geo_trans`` -- one of shift/scale or rotation. ``transform`` runs it on
      the whole clip in a single call (frames are passed as additional
      ``image`` targets) so the same geometric parameters are used for every
      frame of the clip.
    """

    def __init__(self, max_num_frames=32, channel_last=False, p_scale=0.95, p_geo=0.95, max_value=255.0):
        """Build the pixel-level and geometric pipelines.

        Args:
            max_num_frames: Largest clip length accepted by the clip-wise
                (``framewise=False``) path of ``apply``.
            channel_last: If False, ``apply`` permutes axes ``[0, 2, 3, 1]``
                before transforming and ``[0, 3, 1, 2]`` afterwards, i.e. the
                input is treated as channel-first.
            p_scale: Probability given to the ``Compose`` wrapping the
                pixel-level pipeline.
            p_geo: Probability given to the ``Compose`` wrapping the geometric
                pipeline.
            max_value: ``max_value`` handed to ``ToFloat`` and ``FromFloat``.
        """
        self.channel_last = channel_last
        # Remembered so apply() can reject clips that have more frames than
        # there are registered additional targets.
        self.max_num_frames = max_num_frames
        # apply() sends frame 0 as 'image' and frame i (i >= 1) as 'img_{i}',
        # so 'img_0' is registered but never used.
        additional_targets = {f'img_{i}': 'image' for i in range(max_num_frames)}

        # These run between ToFloat and FromFloat (see the Compose below), so
        # the noise variance and fill value refer to float-scaled pixels.
        scale_trans = [GaussNoise(p=1.0, var_limit=(0.001, 0.005)),
                       MotionBlur(p=1.0, blur_limit=(3, 7)),
                       CoarseDropout(p=1.0, max_holes=4, max_height=8, max_width=8, fill_value=0.0)]
        scale_trans = OneOf(scale_trans, p=1.0)

        # NOTE(review): the rotation transform does not set ``p`` while the
        # shift/scale one uses p=1.0. Inside OneOf the per-transform ``p``
        # acts as a relative selection weight, so rotation is picked less
        # often than shift/scale -- confirm this weighting is intended.
        geo_trans = [ShiftScaleRotate(shift_limit=(-0.0625, 0.0625),
                                     scale_limit=(-0.2, 0.2), rotate_limit=0,
                                     border_mode=cv.BORDER_CONSTANT, p=1.0),
                     ShiftScaleRotate(shift_limit=0, scale_limit=0, rotate_limit=(-90, 90),
                                      border_mode=cv.BORDER_CONSTANT)]
        geo_trans = OneOf(geo_trans, p=1.0)

        # NOTE(review): FromFloat is built without an explicit dtype, so the
        # output dtype is the library default -- confirm it matches what
        # downstream code expects.
        tofloat = ToFloat(max_value=max_value)
        fromfloat = FromFloat(max_value=max_value)

        self.scale_trans = Compose([tofloat, scale_trans, fromfloat], p=p_scale, additional_targets=additional_targets)
        self.geo_trans = Compose([tofloat, geo_trans, fromfloat], p=p_geo, additional_targets=additional_targets)

    def apply(self, frames, transform_fn, framewise=False):
        """Run ``transform_fn`` over one clip and return the stacked result.

        Args:
            frames: Array of frames indexed along axis 0. When
                ``channel_last`` is False it must be 4-D, because axes are
                permuted with ``[0, 2, 3, 1]``.
            transform_fn: Callable taking ``image=...`` (plus ``img_{i}=...``
                keyword arguments in clip-wise mode) and returning a mapping
                with the same keys.
            framewise: If True, call ``transform_fn`` once per frame;
                otherwise call it once with every frame of the clip.

        Returns:
            The transformed frames stacked along axis 0, in the input frame
            order and in the same axis layout as ``frames``. An empty clip is
            returned as an (empty) copy.

        Raises:
            ValueError: In clip-wise mode, if the clip has more frames than
                ``max_num_frames``.
        """
        arr = frames.copy()
        num_frames = len(arr)
        if num_frames == 0:
            # Nothing to transform; np.stack would fail on an empty list.
            return arr
        if not framewise and num_frames > self.max_num_frames:
            # Keys beyond the registered additional targets are not known to
            # the pipeline, so fail loudly instead of passing them through.
            raise ValueError(
                f'clip has {num_frames} frames but only '
                f'{self.max_num_frames} are supported (max_num_frames)'
            )
        if not self.channel_last:
            arr = np.transpose(arr, [0, 2, 3, 1])
        if framewise:
            transformed = [transform_fn(image=frame)['image'] for frame in arr]
        else:
            keys = ['image'] + [f'img_{i}' for i in range(1, num_frames)]
            result = transform_fn(**dict(zip(keys, arr)))
            # Read results back by key so the frame order never depends on
            # the ordering of the mapping returned by transform_fn.
            transformed = [result[key] for key in keys]
        out = np.stack(transformed)
        if not self.channel_last:
            out = np.transpose(out, [0, 3, 1, 2])
        return out

    def transform(self, frames):
        """Augment one clip: pixel-level per frame, then geometric per clip."""
        x = self.apply(frames, self.scale_trans, framewise=True)
        x = self.apply(x, self.geo_trans, framewise=False)
        return x

    def batch_transform(self, frames):
        """Augment every clip in ``frames`` and stack the results on axis 0.

        Each clip is passed to ``transform`` independently. An empty batch is
        returned as an empty array.
        """
        if len(frames) == 0:
            # np.stack raises on an empty list.
            return np.array(frames)
        return np.stack([self.transform(clip) for clip in frames], axis=0)