import collections
import os

from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
import numpy as np

from ostrichrl.tasks.ostrich import SUITE


class Physics(mujoco.Physics):
    """MuJoCo physics for the ostrich model, with observation accessors.

    Attributes assigned by this class:
        force_scale: weight of the normalized-force term in `tendon_states`.
            Only assigned when the instance does not already have it.
        init_height: constant 1.05. Not read by any method of this class
            while `height_diff` is pinned to zero.
        max_muscle, min_muscle, max_force, min_force: running per-actuator
            extremes, created lazily on the first `tendon_states` call.
    """

    def __init__(self, *args, **kwargs):
        # Set before delegating to the parent constructor, and only when the
        # attribute is absent, so an already-present value is preserved.
        if not hasattr(self, 'force_scale'):
            self.force_scale = 0.003
        super().__init__(*args, **kwargs)
        self.init_height = 1.05

    def height_diff(self):
        """Return the offset subtracted by the `*_height` accessors.

        Height compensation is disabled: the offset is always 0.0, so the
        height accessors report the z values exactly as stored in the
        simulation data.
        """
        return 0.0

    def merge_args(self, env_args):
        """Set every `key: value` pair of `env_args` as an instance attribute."""
        for k, v in env_args.items():
            setattr(self, k, v)

    def control_vector(self):
        """Return a copy of the control vector `data.ctrl`."""
        return self.data.ctrl.copy()

    def qpos_without_x(self):
        """Return a copy of `data.qpos` without its first and last entries.

        NOTE(review): the first entry is presumably the root x position (per
        the method name); what the last entry holds is not visible here.
        """
        return self.data.qpos.copy()[1:-1]

    def qvel(self):
        """Return `data.qvel` without its last entry, clipped to [-100, 100]."""
        return np.clip(self.data.qvel[:-1], -100, 100)

    def pelvis_height(self):
        """Return the z position of the `pelvis` geom minus `height_diff()`."""
        return self.named.data.geom_xpos['pelvis', 'z'].copy() - self.height_diff()

    def feet_height(self):
        """Return `[r_pes z, l_pes z]` body positions minus `height_diff()`."""
        return np.array([self.named.data.xpos['r_pes', 'z'],
                         self.named.data.xpos['l_pes', 'z']]) - self.height_diff()

    def head_height(self):
        """Return the z position of the `head` body minus `height_diff()`."""
        return self.named.data.xpos['head', 'z'].copy() - self.height_diff()

    def muscle_lengths(self):
        """Return a copy of `data.actuator_length`."""
        return self.data.actuator_length.copy()

    def muscle_velocities(self):
        """Return `data.actuator_velocity` clipped to [-100, 100]."""
        return np.clip(self.data.actuator_velocity, -100, 100)

    def tendon_states(self):
        """Return a per-actuator signal mixing normalized length and force.

        Lengths and forces are normalized against the running minima and
        maxima observed over all previous calls on this instance (updated
        with the current values before normalizing). The length term is
        mapped to roughly [-1, 1]; the force term is scaled by
        `force_scale` and added on top.
        """
        lce = self.muscle_lengths()
        f = self.muscle_forces()
        if not hasattr(self, "max_muscle"):
            # Array sizes are only known once the model is loaded, so the
            # running extremes are created on first use.
            self.max_muscle = np.zeros_like(lce)
            self.min_muscle = np.ones_like(lce) * 100.0
            self.max_force = - np.ones_like(f) * 100.0
            self.min_force = np.ones_like(f) * 100.0
        self.max_muscle = np.maximum(lce, self.max_muscle)
        self.min_muscle = np.minimum(lce, self.min_muscle)
        self.max_force = np.maximum(f, self.max_force)
        self.min_force = np.minimum(f, self.min_force)

        # The 0.1 in each denominator avoids division by zero while the
        # observed range is still degenerate.
        length_term = (
            (lce - self.min_muscle) / (self.max_muscle - self.min_muscle + 0.1)
        ) - 0.5
        force_term = (f - self.min_force) / (self.max_force - self.min_force + 0.1)
        return length_term * 2.0 + self.force_scale * force_term

    def muscle_activations(self):
        """Return `data.act` clipped to [-100, 100]."""
        return np.clip(self.data.act, -100, 100)

    def muscle_forces(self):
        """Return `data.actuator_force / 1000`, clipped to [-100, 100]."""
        return np.clip(self.data.actuator_force / 1000, -100, 100)

    def torso_angle(self):
        """Return `data.qpos[4]`.

        NOTE(review): the index is hard-coded; presumably it is the torso
        pitch joint of the ostrich model -- confirm against the XML.
        """
        return self.data.qpos[4]

    def horizontal_velocity(self):
        """Return the first component of the `torso_subtreelinvel` sensor."""
        return self.named.data.sensordata['torso_subtreelinvel'][0]


class RunObstacles(base.Task):
    """Running task: reward is horizontal velocity, with fall/goal termination."""

    def initialize_episode(self, physics):
        """Randomize the joint configuration and zero all velocities.

        Entries `qpos[6:-1]` are drawn uniformly between one fifth of the
        lower and one fifth of the upper limit of the matching rows of
        `jnt_range[6:-1]`.
        """
        joint_limits = physics.data.model.jnt_range[6:-1]
        lower = joint_limits[:, 0] / 5
        upper = joint_limits[:, 1] / 5
        physics.data.qpos[6:-1] = self.random.uniform(low=lower, high=upper)
        physics.data.qvel[:] = 0

    def after_step(self, physics):
        """Do nothing after a physics step."""
        return None

    def get_observation(self, physics):
        """Return an ordered dict of observations read from `physics`."""
        # (observation key, reader) pairs; readers are invoked in this order.
        readers = (
            ('head_height', physics.head_height),
            ('pelvis_height', physics.pelvis_height),
            ('feet_height', physics.feet_height),
            ('qpos', physics.qpos_without_x),
            ('qvel', physics.qvel),
            ('muscle_activations', physics.muscle_activations),
            ('muscle_forces', physics.muscle_forces),
            ('muscle_lengths', physics.muscle_lengths),
            ('muscle_velocities', physics.muscle_velocities),
            ('horizontal_velocity', physics.horizontal_velocity),
        )
        return collections.OrderedDict((key, read()) for key, read in readers)

    def get_reward(self, physics):
        """Return the current horizontal velocity as the reward."""
        reward = physics.horizontal_velocity()
        return reward

    def get_termination(self, physics):
        """Return 1 when a termination condition holds, otherwise None.

        Conditions, checked in order:
          * `qpos[0]` exceeds 10,
          * the torso angle is outside [-0.8, 0.8],
          * the head is below 0.5,
          * the head is below the pelvis.
        """
        if physics.data.qpos[0] > 10:
            return 1

        angle = physics.torso_angle()
        if abs(angle) > 0.8:
            return 1

        head = physics.head_height()
        if head < 0.5 or head < physics.pelvis_height():
            return 1

        return None

@SUITE.add('benchmarking')
def run_stepdown(environment_kwargs=None, random=None):
    """Build the step-down running environment.

    Args:
        environment_kwargs: optional dict of extra keyword arguments for
            `control.Environment`. An `obs_height` entry, if present, selects
            the model file `ostrich_stepdown_{obs_height}.xml` and is not
            forwarded to the environment. The caller's dict is never
            modified.
        random: forwarded to the `RunObstacles` constructor.

    Returns:
        A `control.Environment` created with `time_limit=25` and
        `control_timestep=0.025`.
    """
    task = RunObstacles(random=random)

    # Work on a private copy: consuming `obs_height` from the caller's dict
    # would make a second call with the same dict silently load the default
    # model instead of the requested one.
    environment_kwargs = dict(environment_kwargs or {})
    if 'obs_height' in environment_kwargs:
        height = environment_kwargs.pop('obs_height')
        model_file = f'ostrich_stepdown_{height}.xml'
    else:
        model_file = 'ostrich_stepdown.xml'

    # Joining components keeps the path relative to this module even when
    # `os.path.dirname(__file__)` is an empty string.
    path = os.path.join(
        os.path.dirname(__file__), os.pardir, os.pardir,
        'assets', 'models', 'ostrich', model_file)
    physics = Physics.from_xml_path(path)

    env = control.Environment(
        physics, task, time_limit=25, control_timestep=0.025,
        **environment_kwargs)

    return env
