import numpy as np
import gym
import copy


class ParticleEnv(gym.Env):
  """Point-particle navigation task with two candidate goals.

  Each episode samples a particle position uniformly in ``[-delta, delta]^dims``
  and two goals uniformly in ``[-delta_goal, delta_goal]^dims``. An action is a
  displacement added to the position; the result is clipped back into the
  ``[-delta, delta]`` box. Landing within ``eps_goal`` (Euclidean distance) of a
  goal yields that goal's value (+1 or -1) as reward and ends the episode.

  Observations are the concatenation ``[state, goal1, goal2]`` (length
  ``3 * dims``). The class follows the old gym API: ``reset()`` returns only the
  observation and ``step()`` returns ``(obs, reward, done, info)``.

  Args:
    max_steps: Maximum number of steps per episode.
    dims: Dimensionality of the particle position and of each goal.
    delta: Half-width of the box the particle starts in and is confined to.
    delta_goal: Half-width of the box the goals are sampled from.
    discountinuities: If True, whichever goal has the larger norm is worth +1
      and the other -1; if False, goal1 is always -1 and goal2 always +1.
      (The misspelled name is kept for backward compatibility.)
    multimodal: If True, the scripted ``get_action`` policy targets goal1 or
      goal2 with roughly equal probability; if False it always targets goal2.
    eps_goal: Success radius around each goal.
  """

  def __init__(self, max_steps=50, dims=2, delta=1., delta_goal=0.1,
               discountinuities=True, multimodal=True, eps_goal=1e-2):
    self.max_steps = max_steps
    self.dims = dims
    # Use the caller's value (this was previously hard-coded to 1.0).
    self.delta = delta
    self.delta_goal = delta_goal
    self.eps_goal = eps_goal
    self.multimodal = multimodal
    self.discountinuities = discountinuities

  def _observation(self):
    """Return a fresh ``[state, goal1, goal2]`` array that does not alias internal state."""
    # np.concatenate always allocates a new array, so no explicit copy is needed.
    return np.concatenate([self.state, self.goal1, self.goal2], -1)

  def reset(self):
    """Sample a new start position and goal pair; return the initial observation."""
    self.state = np.random.uniform(-self.delta, self.delta, size=self.dims)
    self.goal1 = np.random.uniform(-self.delta_goal, self.delta_goal, size=self.dims)
    self.goal2 = np.random.uniform(-self.delta_goal, self.delta_goal, size=self.dims)
    # Goal followed by the scripted policy in ``get_action``. np.random.rand() is
    # evaluated first on purpose so the RNG stream is consumed identically
    # regardless of ``multimodal``.
    use_goal1 = np.random.rand() <= 0.5 and self.multimodal
    self.goal = (self.goal1 if use_goal1 else self.goal2).copy()
    if np.linalg.norm(self.goal1) > np.linalg.norm(self.goal2) and self.discountinuities:
      self.goal_values = np.array([1., -1.])
    else:
      self.goal_values = np.array([-1., 1.])
    self.total_steps = 0
    return self._observation()

  def compute_reward(self, state):
    """Return the reward for the particle being at ``state``.

    The reward is the sum of ``goal_values`` over every goal whose Euclidean
    distance to ``state`` is below ``eps_goal``. It is +1 or -1 when near
    exactly one goal. It is 0 when near neither goal, and also 0 when near both
    (the two values cancel).
    """
    distances = np.array([np.linalg.norm(state - g) for g in (self.goal1, self.goal2)])
    close = np.float32(distances < self.eps_goal)
    return close @ self.goal_values

  def get_action(self):
    """Scripted expert action: move toward ``self.goal``, each coordinate clipped to [-1, 1]."""
    return np.clip(self.goal - self.state, -1, 1)

  def step(self, action):
    """Apply ``action`` as a displacement and advance the episode by one step.

    Args:
      action: Array-like displacement with the same shape as the state, ``(dims,)``.

    Returns:
      ``(observation, reward, done, info)``. ``done`` is True once a non-zero
      reward is received or ``max_steps`` steps have been taken. ``info`` is
      always an empty dict.

    Raises:
      ValueError: if ``action`` does not have the same shape as the state.
    """
    action = np.asarray(action)
    # Validate before mutating any episode state; unlike ``assert``, this check
    # survives ``python -O``.
    if action.shape != self.state.shape:
      raise ValueError(
          f"action shape {action.shape} does not match state shape {self.state.shape}")
    self.total_steps += 1

    # ``+`` and ``np.clip`` both return new arrays, so ``self.state`` is never
    # mutated in place and the caller's ``action`` is left untouched.
    next_state = np.clip(self.state + action, -self.delta, self.delta)
    reward = self.compute_reward(next_state)
    # ``>=`` caps the episode at exactly ``max_steps`` steps; the former ``>``
    # allowed one extra step.
    done = bool(reward != 0) or self.total_steps >= self.max_steps
    self.state = next_state
    return self._observation(), reward, done, {}