import numpy as np
from gym.utils import seeding

####################################################################################################################################
# REWARD WRAPPERS
####################################################################################################################################

def compute_simple_reward(goal, state, bounds):
    """
    Negative sum of absolute errors (L1 distance) between goal and state
    Input: goal and state (np arrays of matching shape); bounds is accepted so that
           every reward function shares one signature, but it is not used here
    Output: scalar reward, 0 when state equals goal and negative otherwise
    """
    abs_deviation = np.absolute(goal - state)
    total_deviation = np.sum(abs_deviation)
    return -total_deviation

def compute_square_reward(goal, state, bounds):
    """
    Negative sum of squared errors between goal and state
    Input: goal and state (np arrays of matching shape); bounds is accepted so that
           every reward function shares one signature, but it is not used here
    Output: scalar reward, 0 when state equals goal and negative otherwise
    """
    deviation = goal - state
    squared_deviation = deviation * deviation
    return -np.sum(squared_deviation)

def compute_max1_positive_reward(goal, state, bounds, tolerance=0.05):
    """
    Log-scaled reward with a maximum of 1, taken as a product over state dimensions

    Per dimension, the absolute error and the bounds span are both divided by the
    tolerance and clipped from below at 1, and the reward is
        1 - log(scaled error) / log(scaled span)
    so an error at or below the tolerance earns 1, an error equal to the span earns 0
    and a larger error goes negative. The per-dimension values are multiplied together.

    Input: goal and state (np arrays of matching shape),
           bounds as a pair (min, max) of np arrays,
           tolerance (positive float, or array broadcastable against the state):
           errors no larger than this count as a perfect match
    Output: scalar reward (float)
    Raises: ValueError if tolerance is not strictly positive

    NOTE: if the bounds span in some dimension is no larger than the tolerance, the
    log of the scaled span is 0 and the division yields nan or -inf.
    """
    if np.any(np.asarray(tolerance) <= 0):
        raise ValueError("tolerance must be strictly positive")

    # Clipping at 1 keeps both logs non-negative
    scaled_error = np.clip(np.abs(goal - state)/tolerance, 1., np.inf)
    scaled_span = np.clip(np.abs(bounds[1] - bounds[0])/tolerance, 1., np.inf)

    rel_rew = 1. - np.log(scaled_error)/np.log(scaled_span)
    return np.prod(rel_rew)

def compute_max1_scaled_mse_reward(goal, state, bounds):
    """
    Reward with a maximum of 1 based on the root-mean-square error, scaled by bounds
    Input: goal and state (np arrays of matching shape),
           bounds as a pair (min, max) of np arrays
    Output: 1 - rmse(goal - state) / rmse(max - min); equals 1 when state matches goal
    Note: despite the name, the root of the mean squared error is used
    """
    def _root_mean_square(values):
        return np.sqrt(np.mean(values**2))

    lower, upper = bounds[0], bounds[1]
    actual = _root_mean_square(goal - state)
    worst_case = _root_mean_square(upper - lower)
    return 1. - actual/worst_case

def compute_max1_scaled_mae_reward(goal, state, bounds):
    """
    Reward with a maximum of 1 based on the mean absolute error, scaled by bounds
    Input: goal and state (np arrays of matching shape),
           bounds as a pair (min, max) of np arrays
    Output: 1 - mean|goal - state| / mean|max - min|; equals 1 when state matches goal
    """
    def _mean_abs(values):
        return np.mean(abs(values))

    lower, upper = bounds[0], bounds[1]
    actual = _mean_abs(goal - state)
    worst_case = _mean_abs(upper - lower)
    return 1. - actual/worst_case
####################################################################################################################################


####################################################################################################################################
# BASE CLASS
####################################################################################################################################

class ShapeGen(object):
    """
    Shape Gen base class
    - Components should include:
     .get_state(phase): returns a 2D np array [x,y] appropriate for the input phase, plus a demo flag
     .get_bounds(): get the min and max components of state space as a pair of np arrays [xmin, ymin], [xmax, ymax]
     .sample_demo_phase(): get a phase corresponding to a region with demonstration (if one exists) - else returns None
     .set_rand_seed(seed): set the random seed for shapes requiring random behavior
    """

    # Default reward function, called as reward_fn(goal, state, bounds).
    # It is wrapped in staticmethod because a plain function stored on the class would be
    # bound on access through an instance and receive the instance as an extra first argument.
    reward_fn = staticmethod(compute_simple_reward)

    def get_state(self, phase):
        """
        Given a phase, return (x,y) coordinates of the corresponding point on the path
        return a 2-d numpy array and a boolean to indicate if state comes from demo
        """

        raise NotImplementedError

    def get_true_state(self, phase):
        """
        Return the true state for the input phase
        Default simply forwards to get_state, so the output is the same (state, demo flag) pair
        """
        return self.get_state(phase)

    def get_RL_reward(self, phase, state):
        """
        Get the RL reward signal for this environment
        Input: Phase in radians (float), current state (np array [x,y] coordinates)
        Output: scalar reward (float)
        """
        current_desired, _ = self.get_state(phase)
        return self.reward_fn(current_desired, state, self.get_bounds())

    def get_goal_and_reward(self, phase, state):
        """
        Combination of get_state and get_RL_reward
        Input: Phase in radians (float), current state (np array [x,y] coordinates)
        Output: Goal state (np array [x,y] coordinates), scalar reward (float),
                boolean demo flag as returned by get_state
        """
        current_desired, demo = self.get_state(phase)
        return current_desired, self.reward_fn(current_desired, state, self.get_bounds()), demo

    def get_bounds(self):
        """
        Return the bounds of the state space as a pair of np arrays [xmin, ymin], [xmax, ymax]
        """
        raise NotImplementedError

    def sample_demo_phase(self):
        """
        Get a phase corresponding to a region with demonstration (if one exists)
        Default returns None unless otherwise implemented
        """
        return None

    def set_rand_seed(self, seed):
        """
        Set random seed for shapes requiring random behavior
        """
        raise NotImplementedError

####################################################################################################################################


####################################################################################################################################
# BASIC SHAPES
####################################################################################################################################

class CirclePath(ShapeGen):
    """
    Path generator tracing a circle of the given radius centred on the origin
    """
    def __init__(self, radius=5., reward_fn=compute_simple_reward, **kwargs):
        # Any additional keyword arguments are accepted and ignored
        self.reward_fn = reward_fn
        self.radius = radius

    def get_state(self, phase):
        """
        Point on the circle at the given angle
        Input: Phase in radians (float)
        Output: np array [x, y] coordinates, and a demo flag that is always False
        """
        unit_point = np.array([np.cos(phase), np.sin(phase)])
        return self.radius*unit_point, False

    def get_bounds(self):
        """
        Bounds of the state space: 1.25 times the radius in every direction
        Outputs: np arrays [xmin, ymin], [xmax, ymax]
        """
        upper = self.radius*np.array([1.25, 1.25])
        return -upper, upper

    def set_rand_seed(self, seed):
        """
        Nothing to seed: this shape has no random behavior
        """
        return None


class SquarePath(ShapeGen):
    """
    Path generator for a simple square path centred on the origin
    """
    def __init__(self, side_len=10., reward_fn=compute_simple_reward, **kwargs):
        # Any additional keyword arguments are accepted and ignored
        self.side_len = side_len
        self.reward_fn = reward_fn

    def get_state(self, phase):
        """
        Returns the position along a square for input phase in radians
        Input: Phase in radians (float); any value is wrapped into a single period
        Output: np array [x, y] coordinates, and a demo flag that is always False
        """

        # 0 - 90 degrees -> top
        # 90 - 180 degrees -> right
        # 180 - 270 degrees -> bottom
        # 270 - 360 degrees -> left

        # Wrap into one period. With a positive divisor % never returns a negative
        # value, so this also covers phases below -2*pi, which a single "+= 2*pi"
        # would leave negative and therefore off the path.
        phase = phase % (2*np.pi)

        if phase < np.pi/2:
            state_x = 2 * phase/np.pi
            state = self.side_len * np.array([state_x, 1.])
        elif phase < np.pi:
            state_y = 1. - 2 * (phase - np.pi/2)/np.pi
            state = self.side_len * np.array([1., state_y])
        elif phase < 1.5 * np.pi:
            state_x = 1. - 2 * (phase - np.pi)/np.pi
            state = self.side_len * np.array([state_x, 0.])
        else:
            state_y = 2 * (phase - 1.5*np.pi)/np.pi
            state = self.side_len * np.array([0., state_y])

        # The edges above span [0, side_len]; shift so the square is centred on the origin
        state -= self.side_len * np.array([0.5,0.5])

        return state, False

    def get_bounds(self):
        """
        Returns the bounds of the state space (0.625 times the side length in every direction)
        Outputs: np arrays [xmin, ymin], [xmax, ymax]
        """
        return self.side_len*np.array([-.625,-.625]), self.side_len*np.array([.625,.625])

    def set_rand_seed(self, seed):
        """
        Set random seed for shapes requiring random behavior
        Nothing to do here: this shape has no random behavior
        """
        pass


class TrianglePath(ShapeGen):
    """
    Path generator for a simple triangle path
    """
    def __init__(self, height=5., reward_fn=compute_simple_reward, **kwargs):
        # Any additional keyword arguments are accepted and ignored
        self.height = height
        self.reward_fn = reward_fn

    def get_state(self, phase):
        """
        Returns the position along a triangle for input phase in radians
        Input: Phase in radians (float); any value is wrapped into a single period
        Output: np array [x, y] coordinates, and a demo flag that is always False
        """

        # 0 - 120 degrees -> right
        # 120 - 240 degrees -> base
        # 240 - 360 degrees -> left

        # Wrap into one period. With a positive divisor % never returns a negative
        # value, so negative phases land on the path instead of falling through to
        # the last branch with an out-of-range value.
        phase = phase % (2*np.pi)

        if phase < 2*np.pi/3:
            state_x = 0.75 * phase/np.pi
            state_y = 1. - 1.5 * phase/np.pi
            state = self.height * np.array([state_x, state_y])
        elif phase < 4*np.pi/3:
            state_x = 0.5 - 1.5 * (phase - 2*np.pi/3)/np.pi
            state = self.height * np.array([state_x, 0.])
        else:
            state_x = -0.5 + 0.75 * (phase - 4*np.pi/3)/np.pi
            state_y = 1.5 * (phase - 4*np.pi/3)/np.pi
            state = self.height * np.array([state_x, state_y])

        # The edges above span y in [0, height]; shift down so the path is centred vertically
        state -= self.height * np.array([0, 0.5])

        return state, False

    def get_bounds(self):
        """
        Returns the bounds of the state space (0.625 times the height in every direction)
        Outputs: np arrays [xmin, ymin], [xmax, ymax]
        """
        return self.height*np.array([-.625,-.625]), self.height*np.array([.625,.625])

    def set_rand_seed(self, seed):
        """
        Set random seed for shapes requiring random behavior
        Nothing to do here: this shape has no random behavior
        """
        pass

####################################################################################################################################

####################################################################################################################################
