# augmentation definition

import torch
import numpy as np


class Augmentation(object):
    """Random geometric augmentations for batches of 2-D point sets.

    Every method expects a tensor ``X`` of shape ``(B, N, 2)``: ``B`` batch
    entries, each holding ``N`` points with an x and a y coordinate.
    """

    def scale(self, X, scale_method="boundary"):
        """Translate and uniformly rescale each point set into the unit square.

        Each batch entry is shifted so that its minimum x and minimum y become
        0, then both coordinates are divided by the larger of the two extents
        (x-range or y-range). The aspect ratio is therefore preserved and the
        longer side spans exactly ``[0, 1]``.

        :param X: tensor of shape ``(B, N, 2)``.
        :param scale_method: currently unused; kept for interface compatibility.
        :return: ``(X_scaled, ratio)`` where ``ratio`` has shape ``(B,)`` and
            holds the divisor (largest extent) of each batch entry.
        """
        # Move the lower-left corner of every bounding box to the origin.
        X = X - X.min(dim=1, keepdim=True).values
        # After the shift the per-axis minimum is 0, so the per-axis maximum
        # is the extent; take the larger of the two axes.
        ratio = X.max(dim=1).values.max(dim=1).values
        # A degenerate entry (all points coincide) has ratio 0; divide by 1
        # there so the output is all zeros instead of NaN. The returned ratio
        # is left untouched.
        divisor = torch.where(ratio > 0, ratio, torch.ones_like(ratio))
        X = X / divisor.reshape(-1, 1, 1)
        return X, ratio

    @staticmethod
    def _random_angles(X):
        """Draw one angle in ``[0, 2*pi)`` per batch entry, on ``X``'s device."""
        return torch.rand((X.size(0), 1), device=X.device) * 2 * np.pi

    def rotate(self, X, scale_method="boundary"):
        """Rotate each point set by its own random angle, then rescale.

        :param X: tensor of shape ``(B, N, 2)``.
        :param scale_method: forwarded to :meth:`scale` (currently unused there).
        :return: ``(X_rotated, ratio)`` as returned by :meth:`scale`.
        """
        theta = self._random_angles(X)  # (B, 1), broadcasts over the N points
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        x, y = X[:, :, 0], X[:, :, 1]
        # Standard rotation matrix [[cos, -sin], [sin, cos]].
        X_out = torch.stack((x * cos_t - y * sin_t, x * sin_t + y * cos_t), dim=2)
        # No extra offset is needed here: scale() re-translates to the origin.
        return self.scale(X_out, scale_method)

    def reflect(self, X, scale_method="boundary"):
        """Reflect each point set across a random line through the origin, then rescale.

        The line makes an angle ``theta`` with the x-axis, which gives the
        reflection matrix ``[[cos 2t, sin 2t], [sin 2t, -cos 2t]]``.

        :param X: tensor of shape ``(B, N, 2)``.
        :param scale_method: forwarded to :meth:`scale` (currently unused there).
        :return: ``(X_reflected, ratio)`` as returned by :meth:`scale`.
        """
        double_theta = 2 * self._random_angles(X)  # (B, 1)
        cos_2t, sin_2t = torch.cos(double_theta), torch.sin(double_theta)
        x, y = X[:, :, 0], X[:, :, 1]
        X_out = torch.stack((x * cos_2t + y * sin_2t, x * sin_2t - y * cos_2t), dim=2)
        return self.scale(X_out, scale_method)

    @staticmethod
    def noise(X, noise):
        """Add uniform noise drawn from ``[-noise/2, noise/2)`` to every coordinate.

        :param X: input tensor.
        :param noise: total width of the uniform noise interval.
        :return: a new tensor; ``X`` itself is not modified.
        """
        return X + (torch.rand(X.size(), device=X.device) - 0.5) * noise

    def mixture(self, X, scale_method="boundary", noise=1e-5):
        """Apply a random rotation or a random reflection, then add noise.

        One coin flip per call decides the transform for the whole batch:
        50% random rotate, 50% random reflect. This is equal to a random
        rotation followed by a fixed reflection half of the time.

        :param X: tensor of shape ``(B, N, 2)``.
        :param scale_method: forwarded to the chosen transform (currently unused).
        :param noise: width of the uniform noise added after the transform.
        :return: ``(X_augmented, ratio)``; ``ratio`` comes from :meth:`scale`.
        """
        transform = self.rotate if np.random.rand() < 0.5 else self.reflect
        X_out, ratio = transform(X, scale_method)
        return self.noise(X_out, noise), ratio

    def aug_for_train(self, aug, x, repeat):
        """Augment a batch while keeping every ``repeat``-th entry unchanged.

        Entries at indices ``0, repeat, 2*repeat, ...`` are copied back from
        the input verbatim (not rotated, reflected or rescaled); all other
        entries are augmented.
        NOTE(review): presumably the batch holds ``repeat`` consecutive copies
        of each instance -- confirm against the caller.

        :param aug: augmentation name; only ``'mixture'`` is supported.
        :param x: tensor of shape ``(B, N, 2)``; it is not modified.
        :param repeat: stride between the entries that are kept unchanged.
        :return: the augmented batch, same shape as ``x``.
        :raises ValueError: if ``aug`` is not a supported augmentation name.
        """
        if aug != 'mixture':
            raise ValueError("unsupported augmentation: {!r}".format(aug))
        x_out, _ = self.mixture(x)
        # mixture() builds a fresh tensor and never writes to x, so the
        # original values can be copied straight from x.
        x_out[0::repeat] = x[0::repeat]
        return x_out
