""" This code is shared for review purposes only. Do not copy, reproduce, share,
publish, or use for any purpose except to review our submission. Please
delete after the review process. The authors plan to publish the code
deanonymized and with a proper license upon publication of the paper. """

import abc
import itertools
import numpy as np
from keras.preprocessing.image import apply_affine_transform
# The code is adapted from https://github.com/izikgo/AnomalyDetectionTransformations/blob/master/transformations.py

def get_transformer(type_trans):
    """Build the transformer family selected by ``type_trans``.

    Args:
        type_trans: One of ``'complicated'`` (flip + x/y translation +
            rotation), ``'medium'`` (x/y translation + rotation) or
            ``'simple'`` (flip + rotation).

    Returns:
        A freshly constructed ``Transformer``, ``MediumTransformer`` or
        ``SimpleTransformer``. The two translating variants are built with
        a translation of 8 along both axes.

    Raises:
        ValueError: If ``type_trans`` is not one of the names above.
    """
    if type_trans == 'complicated':
        return Transformer(8, 8)
    if type_trans == 'medium':
        return MediumTransformer(8, 8)
    if type_trans == 'simple':
        return SimpleTransformer()
    # Fail loudly here instead of handing ``None`` back to the caller, which
    # would only surface later as an unrelated AttributeError.
    raise ValueError(
        "Unknown transformer type {!r}; expected 'complicated', 'medium' "
        "or 'simple'.".format(type_trans)
    )


def product_ind(*args):
    """Return the Cartesian product of the pools and the matching indices.

    Args:
        *args: Any number of iterables ("pools"). Each pool is materialised
            once, so one-shot iterators and generators are accepted too.

    Returns:
        A tuple ``(res, res_ind)`` of two equally long lists of lists.
        ``res[k]`` holds one element picked from every pool and
        ``res_ind[k]`` holds the position of each of those elements inside
        its pool. Combinations are ordered with the last pool varying
        fastest. With no pools both lists are ``[[]]``.
    """
    # Snapshot every pool so it can be both iterated and measured.
    pools = [list(pool) for pool in args]
    index_pools = [range(len(pool)) for pool in pools]

    res = [list(combo) for combo in itertools.product(*pools)]
    res_ind = [list(combo) for combo in itertools.product(*index_pools)]
    return res, res_ind


class AffineTransformation(object):
    """Callable that applies an optional flip, translation and rotation.

    The steps run in a fixed order: left/right flip, then translation, then
    rotation by multiples of 90 degrees. Steps whose parameter is falsy or
    zero are skipped.
    """

    def __init__(self, flip, tx, ty, k_90_rotate):
        """Store the transformation parameters.

        Args:
            flip: Truthy to flip the input with ``np.fliplr``.
            tx: Forwarded as ``tx`` to ``apply_affine_transform``.
            ty: Forwarded as ``ty`` to ``apply_affine_transform``.
            k_90_rotate: Number of 90-degree turns passed to ``np.rot90``.
        """
        self.flip, self.tx, self.ty = flip, tx, ty
        self.k_90_rotate = k_90_rotate

    def __call__(self, x):
        """Return the transformed version of ``x``.

        ``x`` is handed to the affine transform with ``channel_axis=2``, so
        it is presumably a channels-last image -- TODO confirm with callers.
        When no step is active the input object itself is returned.
        """
        out = np.fliplr(x) if self.flip else x

        needs_shift = not (self.tx == 0 and self.ty == 0)
        if needs_shift:
            # Pixels shifted in from outside are filled by reflection.
            out = apply_affine_transform(
                out,
                tx=self.tx,
                ty=self.ty,
                channel_axis=2,
                fill_mode='reflect',
            )

        if self.k_90_rotate != 0:
            out = np.rot90(out, self.k_90_rotate)
        return out


class AbstractTransformer(abc.ABC):
    """Base class holding a list of transformations and their index labels.

    Subclasses fill ``_transformation_list`` (callables) and
    ``_transformation_ind_list`` (one label entry per callable) inside
    ``_create_transformation_list``, which the constructor invokes.
    """

    def __init__(self):
        # Both lists start empty-handed; the subclass hook populates them.
        self._transformation_list = None
        self._transformation_ind_list = None
        self._create_transformation_list()

    @property
    def n_transforms(self):
        """Total number of transformations held by this transformer."""
        return len(self._transformation_list)

    @property
    @abc.abstractmethod
    def n_transforms_per_task(self):
        """Implemented by subclasses; the base implementation yields None."""

    @abc.abstractmethod
    def _create_transformation_list(self):
        """Populate the transformation lists; implemented by subclasses."""

    def transform_batch(self, x_batch, t_inds):
        """Apply the ``t_inds[i]``-th transformation to ``x_batch[i]``.

        Args:
            x_batch: Batch of samples; it must provide ``copy()`` and item
                assignment. The input itself is left untouched.
            t_inds: One transformation index per sample.

        Returns:
            A tuple of the transformed copy of the batch and an array built
            from the label entries of the selected transformations.
        """
        assert len(t_inds) == len(x_batch)

        result = x_batch.copy()
        labels = []
        for position, which in enumerate(t_inds):
            transform = self._transformation_list[which]
            result[position] = transform(result[position])
            labels.append(self._transformation_ind_list[which])
        return result, np.asarray(labels)


class Transformer(AbstractTransformer):
    """Transformer combining flip, x/y translation and 90-degree rotation."""

    def __init__(self, translation_x=8, translation_y=8):
        """Record the translation magnitudes, then build the transformations.

        The magnitudes must be set before the base constructor runs because
        it calls ``_create_transformation_list``, which reads them.
        """
        self.max_tx, self.max_ty = translation_x, translation_y
        super().__init__()

    @property
    def n_transforms_per_task(self):
        """Pool sizes in order: flip, x translation, y translation, rotation."""
        return [2, 3, 3, 4]

    def _create_transformation_list(self):
        """Create one transformation per combination of the four pools.

        Returns:
            The list of ``AffineTransformation`` objects, which is also
            stored on the instance together with the per-pool indices.
        """
        flips = (False, True)
        x_shifts = (0, -self.max_tx, self.max_tx)
        y_shifts = (0, -self.max_ty, self.max_ty)
        rotations = range(4)

        combos, combo_indices = product_ind(flips, x_shifts, y_shifts, rotations)
        built = [
            AffineTransformation(is_flip, tx, ty, k_rotate)
            for is_flip, tx, ty, k_rotate in combos
        ]

        self._transformation_list = built
        self._transformation_ind_list = combo_indices
        return built


class MediumTransformer(AbstractTransformer):
    """Transformer combining x/y translation and 90-degree rotation (no flip)."""

    def __init__(self, translation_x=8, translation_y=8):
        """Record the translation magnitudes, then build the transformations.

        The magnitudes must be set before the base constructor runs because
        it calls ``_create_transformation_list``, which reads them.
        """
        self.max_tx, self.max_ty = translation_x, translation_y
        super().__init__()

    @property
    def n_transforms_per_task(self):
        """Pool sizes in order: x translation, y translation, rotation."""
        return [3, 3, 4]

    def _create_transformation_list(self):
        """Create one transformation per combination of the three pools.

        Returns:
            The list of ``AffineTransformation`` objects, which is also
            stored on the instance together with the per-pool indices.
        """
        x_shifts = (0, -self.max_tx, self.max_tx)
        y_shifts = (0, -self.max_ty, self.max_ty)
        rotations = range(4)

        combos, combo_indices = product_ind(x_shifts, y_shifts, rotations)
        # Flipping is always disabled for this family.
        built = [
            AffineTransformation(False, tx, ty, k_rotate)
            for tx, ty, k_rotate in combos
        ]

        self._transformation_list = built
        self._transformation_ind_list = combo_indices
        return built


class SimpleTransformer(AbstractTransformer):
    """Transformer combining flip and 90-degree rotation (no translation)."""

    @property
    def n_transforms_per_task(self):
        """Pool sizes in order: flip, rotation."""
        return [2, 4]

    def _create_transformation_list(self):
        """Create one transformation per (flip, rotation) combination.

        Returns:
            The list of ``AffineTransformation`` objects, which is also
            stored on the instance together with the per-pool indices.
        """
        combos, combo_indices = product_ind((False, True), range(4))
        # Translation is fixed to zero on both axes for this family.
        built = [
            AffineTransformation(is_flip, 0, 0, k_rotate)
            for is_flip, k_rotate in combos
        ]

        self._transformation_list = built
        self._transformation_ind_list = combo_indices
        return built

