import os

import numpy as np

from .muscle_arm import MuscleArm


class Arm750(MuscleArm):
    """Reaching environment built on the "arm750" muscle-arm model.

    Specialises ``MuscleArm`` with a fixed model type, a gravity setting, a
    few named target positions, a randomised reset and a distance-based
    reward.
    """

    def __init__(self, identifier=None):
        """Configure the arm750 model, gravity and the named targets.

        Args:
            identifier: Forwarded unchanged to ``MuscleArm.__init__``.
        """
        # Assigned before the base constructor runs; presumably
        # MuscleArm.__init__ reads them -- confirm in muscle_arm.py.
        self.model_type = "arm750"
        self.nq = 21
        # Disabled: location of a parameter file for this model.
        # path = os.path.dirname((os.path.dirname(os.path.abspath(__file__))))
        # self.o2m_params = os.path.join(path, 'param_files/arm750_o2m.pt')
        super(Arm750, self).__init__(identifier)
        self.set_gravity([0, 0, -10])
        # Named 3-component target positions (same layout as ``self.target``).
        self.interesting_targets = {
            "chest_height": [0.15, 0, 0.9],
            "above_shoulder": [0.05, -0.3, 1.1],
            "medium_height": [0.15, -0.5, 0.9],
        }
        # self.render_substep()

    def reset_model(self):
        """Reset the simulation state and return the first observation.

        Applies the muscle settings, randomises the initial joint state and,
        when ``self.random_goals`` is set, samples a new target. The target is
        then written into the last three entries of ``sim.data.qpos``.

        Returns:
            The value of ``self._get_obs()`` after the reset.
        """
        self.apply_muscle_settings()
        self.randomise_init_state()
        if self.random_goals:
            # Draw from the environment's own RNG (the one used for the
            # initial-state noise) rather than the global ``np.random`` so
            # that seeding the environment also fixes the targets.
            # Alternative range kept for reference:
            #   [0.35, -0.1, 1.1] .. [0.36, 0.1, 1.15]
            self.target = self.np_random.uniform(
                [0.35, -0.3, 0.75], [0.45, 0.0, 1.00]
            )
        # NOTE(review): the last three qpos entries presumably position the
        # target marker in the model -- confirm against the model file.
        self.sim.data.qpos[-3:] = self.target
        return self._get_obs()

    def viewer_setup(self):
        """Place the viewer camera (tracked body, distance, look-at, angles)."""
        self.viewer.cam.trackbodyid = 1
        self.viewer.cam.distance = self.model.stat.extent * 0.5
        self.viewer.cam.lookat[:] = [1, -0.5, 1.35]
        self.viewer.cam.elevation = -30
        self.viewer.cam.azimuth = 150

    def randomise_init_state(self):
        """
        Randomises initial joint positions and velocities slightly.

        Gaussian noise (std 0.01 for positions, 0.03 for velocities) is added
        to the first ``self.nq`` entries of the nominal state and the result
        is passed to ``self.set_state``. The nominal ``self.init_qpos`` /
        ``self.init_qvel`` arrays are left untouched.
        """
        # TODO remove uncontrollable joint from randomisation
        # Work on copies: assigning into a slice of ``self.init_qpos`` itself
        # would permanently shift the nominal state, so the noise would
        # accumulate from one reset to the next.
        qpos = np.copy(self.init_qpos)
        qvel = np.copy(self.init_qvel)

        # Draw sizes and order are kept as before (``model.nq`` samples each,
        # positions first) so the RNG stream is consumed identically.
        pos_noise = self.np_random.normal(0.0, 0.01, size=(self.model.nq))
        qpos[: self.nq] = qpos[: self.nq] + pos_noise[: self.nq]
        vel_noise = self.np_random.normal(0.0, 0.03, size=(self.model.nq))
        qvel[: self.nq] = qvel[: self.nq] + vel_noise[: self.nq]

        self.set_state(qpos, qvel)

    def _get_reward(self, ee_pos, action):
        """Return the scalar reward for the current step.

        Args:
            ee_pos: End-effector position, compared element-wise with
                ``self.target``.
            action: Action applied this step; its mean square is penalised.

        Returns:
            ``-1.0`` when ``self.sparse_reward`` is set (no activation cost is
            applied in that mode). Otherwise
            ``-0.1 * (d + log(d + 1e-8)) - 1e-4 * mean(action**2) - 2`` where
            ``d`` is the mean squared difference between ``ee_pos`` and
            ``self.target``.
        """
        if self.sparse_reward:
            # Constant reward; skip the dense terms, which are not used here.
            return -1.0

        lamb = 1e-4  # weight of the activation (action magnitude) penalty
        epsilon = 1e-4  # squared below; keeps the log finite when d == 0
        log_weight = 1.0
        rew_weight = 0.1

        d = np.mean(np.square(ee_pos - self.target))
        activ_cost = lamb * np.mean(np.square(action))
        return (
            -rew_weight * (d + log_weight * np.log(d + epsilon ** 2)) - activ_cost - 2
        )
