import collections
import os

from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.suite import fish
from dm_control.utils import rewards
from dm_control.utils import io as resources
import numpy as np

_TASKS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tasks')

_DEFAULT_TIME_LIMIT = 40
_CONTROL_TIMESTEP = .04
_JOINTS = ['tail1',
           'tail_twist',
           'tail2',
           'finright_roll',
           'finright_pitch',
           'finleft_roll',
           'finleft_pitch']


def get_model_and_assets():
    """Returns a tuple containing the model XML string and a dict of assets."""
    return resources.GetResource(os.path.join(_TASKS_DIR, 'fish.xml')), common.ASSETS


@fish.SUITE.add('custom')
def obstacles(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
  """Returns the Fish Obstacles task."""
  physics = fish.Physics.from_xml_string(*get_model_and_assets())
  task = Obstacles(random=random)
  environment_kwargs = environment_kwargs or {}
  return control.Environment(
      physics, task, control_timestep=_CONTROL_TIMESTEP, time_limit=time_limit,
      **environment_kwargs)


class Obstacles(fish.Swim):
  """A custom Fish Obstacles task."""

  def __init__(self, random=None):
    super().__init__(random=random)

  def in_wall(self, physics, name, min_distance=0.08):
    """Returns True if the given body is too close to a wall."""
    for wall in ['wall0', 'wall1', 'wall2', 'wall3']:
      l1_dist = np.min(np.abs(physics.named.data.geom_xpos[name][:2] - physics.named.data.geom_xpos[wall][:2]))
      if l1_dist < min_distance:
        return True
    return False

  def initialize_episode(self, physics):
    in_wall = True
    while in_wall:
        # Randomize fish position.
        quat = self.random.randn(4)
        physics.named.data.qpos['root'][3:7] = quat / np.linalg.norm(quat)
        for joint in _JOINTS:
            physics.named.data.qpos[joint] = self.random.uniform(-.2, .2)
        # Randomize target position.
        physics.named.model.geom_pos['target', 'x'] = self.random.uniform(-.4, .4)
        physics.named.model.geom_pos['target', 'y'] = self.random.uniform(-.4, .4)
        physics.named.model.geom_pos['target', 'z'] = self.random.uniform(.1, .3)
        # Make sure target is not too close to a wall.
        physics.after_reset()
        in_wall = self.in_wall(physics, 'target')
    base.Task.initialize_episode(self, physics)

  def get_reward(self, physics):
    radii = physics.named.model.geom_size[['mouth', 'target'], 0].sum()
    in_target = rewards.tolerance(np.linalg.norm(physics.mouth_to_target()),
                                  bounds=(0, radii), margin=2*radii)
    is_upright = 0.5 * (physics.upright() + 1)
    is_not_in_wall = 1. - self.in_wall(physics, 'torso', min_distance=0.06)
    return is_not_in_wall * (7*in_target + is_upright) / 8
