"""
Weird gain env: a 2-D navigation environment with state-dependent action gains,
for testing open- and closed-loop controllers in a setting that is easy to plot
and understand.

Viraj Mehta, 2022
"""

import gym
from gym import spaces
import numpy as np

# Target position; the reward is the negative L1 distance to this point.
GOAL = np.array((6, 9))


class NavigationEnv(gym.Env):
    """2-D point navigation with a state-dependent action gain.

    The state ``x`` is a 2-vector kept inside ``[-10, 10]`` per coordinate.
    Each ``step`` moves the state by ``B(x) @ action``, where ``B`` is the
    2x2 matrix returned by ``get_B``, then clips it to the observation box.
    The reward is the negative L1 distance from the new state to ``GOAL``.

    Args:
        uniform_start: if True (and ``shifted`` is False), ``reset`` samples
            the start state uniformly from the whole observation space.
        easy: if True, ``get_B`` returns a diagonal gain matrix; otherwise a
            coupled one whose x-gain changes sign with ``x[1]``.
        shifted: if True, ``reset`` samples the start state uniformly from
            ``[shifted_low, shifted_high]``; takes precedence over
            ``uniform_start``.
    """

    def __init__(self, uniform_start=False, easy=False, shifted=False):
        self.observation_space = spaces.Box(
            low=np.array([-10, -10]), high=np.array([10, 10])
        )
        self.action_space = spaces.Box(low=-np.ones(2), high=np.ones(2))
        self.x = None  # current state; set by reset()
        # NOTE(review): start_space_low/high are not read by this class;
        # reset() samples from p0_low/p0_high instead. Presumably kept for
        # external users -- confirm before removing.
        self.start_space_low = np.array([-10, -6])
        self.start_space_high = np.array([-5, -5])
        self.periodic_dimensions = []
        # step() never ends an episode (done is always False); the horizon is
        # presumably enforced by the caller.
        self.horizon = 30
        # Lowest reward inside the observation box is -35: the L1 distance
        # from the corner (-10, -10) to GOAL (6, 9) is 16 + 19.
        self.reward_bounds = [-35, 0]
        self.uniform_start = uniform_start
        # Default start distribution: uniform over [p0_low, p0_high].
        self.p0_high = np.array([-6, -6])
        self.p0_low = np.array([-8, -9])
        self.easy = easy
        # Start distribution used when shifted=True.
        self.shifted_low = np.array([1, 4])
        self.shifted_high = np.array([3, 7])
        self.shifted = shifted

    def reset(self, obs=None):
        """Set the start state and return it.

        Precedence: explicit ``obs``, then ``shifted``, then
        ``uniform_start``, then the default ``[p0_low, p0_high]`` box.

        Args:
            obs: optional explicit start state. It is copied into a new
                float array, so the environment never aliases the caller's
                object (and lists / integer arrays work in ``step``).

        Returns:
            The new state ``self.x``.
        """
        if obs is not None:
            self.x = np.array(obs, dtype=float)
        elif self.shifted:
            self.x = np.random.uniform(self.shifted_low, self.shifted_high)
        elif self.uniform_start:
            self.x = self.observation_space.sample()
        else:
            self.x = np.random.uniform(self.p0_low, self.p0_high)
        return self.x

    def get_B(self):
        """Return the 2x2 gain matrix ``B(x)`` that ``step`` applies to actions.

        Just some arbitrary continuous function from state to a 2x2 matrix.
        With ``easy`` the matrix is diagonal; for states inside the
        observation box its x-gain lies in about [3.2, 4.8] and its y-gain
        in about [-1.2, -0.5], so the y action is always inverted. Otherwise
        the matrix couples both axes and its x-gain ``4 * sin(x[1] / 10)``
        is zero at ``x[1] == 0`` and flips sign with ``x[1]``.
        """
        if self.easy:
            x_gain = np.sin(self.x[1] / 10) + 4
            y_gain = np.cos(self.x[0] / 10) * 1.5 - 2
            scaling = np.array([[x_gain, 0], [0, y_gain]])
        else:
            x_gain = np.sin(self.x[1] / 10) * 4
            y_gain = np.cos(self.x[0] / 10) * 4
            scaling = np.array([[x_gain, y_gain / 5], [x_gain / 7, y_gain]])
        return scaling

    def step(self, action):
        """Advance one step: ``x <- clip(x + B(x) @ action)``.

        Args:
            action: 2-vector multiplied by ``get_B()``; it is not clipped to
                ``action_space`` here.

        Returns:
            ``(obs, reward, done, info)`` where ``done`` is always False and
            ``info`` is an empty dict.
        """
        B = self.get_B()
        # Build a fresh array instead of updating self.x in place: the array
        # currently held in self.x was already returned to the caller (by
        # reset() or the previous step()), and an in-place ``+=`` would
        # silently rewrite that stored observation.
        self.x = np.clip(
            self.x + B @ action,
            self.observation_space.low,
            self.observation_space.high,
        )
        rew = _weird_gain_rew(self.x)
        return self.x, rew, False, {}



def _weird_gain_rew(x):
    """Return the negative L1 distance from ``x`` to ``GOAL``.

    The sum runs over the last axis, so a stack of states yields one reward
    per state.
    """
    l1_distance = np.sum(np.abs(x - GOAL), axis=-1)
    return -l1_distance


def weird_gain_reward(x, next_obs):
    """Reward for a transition, computed from ``next_obs`` alone.

    ``x`` is accepted for interface compatibility but ignored.
    """
    del x  # unused: the reward depends only on the resulting observation
    return _weird_gain_rew(next_obs)
