import torch
import numpy as np


class MeanShift(object):
    """Shift a sample's features by a per-label offset.

    One offset row is stored per label; calling the transform adds the row
    selected by the sample's label to the sample's features.
    """

    def __init__(self, mean_shift, No_of_features, resample=True):
        """Build the per-label offset table as a float32 tensor.

        Args:
            mean_shift: Sequence of per-label offset rows. A row whose
                length differs from ``No_of_features`` is replaced by its
                first entry repeated ``No_of_features`` times.
            No_of_features: Number of entries every offset row must have.
            resample: Stored unchanged on the instance; this class does not
                read it itself.
        """
        # Build a new row list instead of assigning into ``mean_shift``: the
        # caller's object is left untouched and immutable inputs (tuples)
        # are accepted. Inputs that need no broadcasting are passed to
        # ``torch.tensor`` exactly as given.
        if any(len(row) != No_of_features for row in mean_shift):
            mean_shift = [
                row if len(row) == No_of_features
                else [float(row[0])] * No_of_features
                for row in mean_shift
            ]
        self.mean_shift = torch.tensor(mean_shift).to(torch.float32)
        self.resample = resample

    def __call__(self, xy_data):
        """Return the features shifted by the offset of the sample's label.

        Args:
            xy_data: Indexable sample; ``xy_data[0]`` is added to the offset
                row selected by ``xy_data[1]`` (presumably the features and
                an integer class label -- confirm against the dataset).

        Returns:
            ``xy_data[0] + self.mean_shift[xy_data[1]]``. Only the shifted
            features are returned, not the label.
        """
        # Pick the shift row indexed by the label.
        return xy_data[0] + self.mean_shift[xy_data[1]]



class Circle(object):
    """Place samples on one of two concentric circles.

    ``radius`` holds one ``[x_radius, y_radius]`` pair per label, e.g.
    ``[[0.3, 0.3], [0.5, 0.5]]``. A pair with equal entries gives a circle;
    the first entry is used in the x direction, the second in y.
    """

    def __init__(self, radius, resample=True):
        self.resample = resample
        self.radius = radius

    def __call__(self, xy_data):
        """Add a random point on the label's circle to ``xy_data[0]``.

        ``xy_data[1]`` selects the radius pair. The angle is drawn as a
        length-1 array, uniform over [-pi, pi), so both offsets are length-1
        arrays as well.
        """
        angle = np.pi * np.random.uniform(-1, 1, 1)
        radii = self.radius[xy_data[1]]
        offset = [np.cos(angle) * radii[0], np.sin(angle) * radii[1]]
        # What ``+`` does here is decided by the type of ``xy_data[0]``.
        return offset + xy_data[0]


class Moon(object):
    """Place samples on one of two interleaving half circles.

    ``moon`` holds one ``[x_radius, y_radius]`` pair per label. With
    ``[[1, 1], [1, 1]]`` both arcs have radius 1 and their midpoints are
    [0, 0] (label 0) and [1, -0.5] (label 1).
    """

    def __init__(self, moon, resample=True):
        self.resample = resample
        self.moon = moon

    def __call__(self, xy_data):
        """Add a random point on the label's half circle to ``xy_data[0]``.

        ``xy_data[1]`` is the label: it selects the radius pair, the angular
        range and the arc's midpoint.
        """
        label = xy_data[1]
        radii = self.moon[label]
        # Label 0 draws angles from [-pi, 0), label 1 from [0, pi).
        angle = np.pi * np.random.uniform((label - 1), label, 1)
        # The arc midpoint moves with the label: (label*rx, -0.5*label*ry).
        shift_x = label * radii[0] + np.cos(angle) * radii[0]
        shift_y = -0.5 * label * radii[1] + np.sin(angle) * radii[1]
        return [shift_x, shift_y] + xy_data[0]


class Triangle(object):
    """Place samples on one of two V-shaped curves.

    ``midpoint`` holds one ``[x, y]`` tip position per label, e.g.
    ``[[0, 0], [0, -0.5]]``. The x range is chosen so that the generated
    y offsets reach up to 0.5.
    """

    def __init__(self, midpoint, resample=True):
        self.resample = resample
        self.midpoint = midpoint

    def __call__(self, xy_data):
        """Add a random point on the label's V shape to ``xy_data[0]``.

        ``xy_data[1]`` selects the tip; x is drawn uniformly around it and
        y is the absolute distance from the tip's x plus the tip's y.
        """
        tip = self.midpoint[xy_data[1]]
        lower = tip[0] - 0.5 + tip[1]
        upper = tip[0] + 0.5 - tip[1]
        shift_x = np.random.uniform(lower, upper)
        shift_y = np.abs(shift_x - tip[0]) + tip[1]
        return [shift_x, shift_y] + xy_data[0]


class Scaling(object):
    """Multiply a sample's features by per-feature scale factors.

    """

    def __init__(self, scaling, No_of_features=2, resample=True):
        factors = scaling
        if len(factors) != No_of_features:
            # Length mismatch: repeat the first factor for every feature.
            first = float(factors[0])
            factors = first*np.ones(No_of_features)
        self.scaling = torch.tensor(factors).to(torch.float32)
        self.resample = resample

    def __call__(self, xy_data):
        """
        Args:
           xy_data: sliceable sample; every element except the last is
               multiplied by the scale factors, the last element is
               converted to ``int`` (presumably the label -- confirm
               against the dataset).

        Returns:
           tuple ``(scaled, label)`` with ``scaled`` the product of the
           factors and ``xy_data[:-1]`` and ``label`` the ``int`` of
           ``xy_data[-1]``.
        """
        head = xy_data[:-1]
        label = int(xy_data[-1])
        return self.scaling * head, label
