import os
import numpy as np
# import tensorflow as tf
from gym import utils
from gym.envs.mujoco import mujoco_env
from gym.envs.registration import EnvSpec


class CrippleAntMOPOEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Ant environment in which the actuators of selected legs are disabled.

    A "crippled" leg has its two action entries multiplied by zero before the
    simulation is stepped.  Unless ``nocripple`` is set, a new leg (or pair of
    legs) is drawn from ``cripple_set`` on every ``reset_model`` call.

    Args:
        cripple_set: Leg indices (keys of ``cripple_dict``) that may be
            crippled.  Defaults to ``[0, 1, 2, 3]``.
        extreme_set: ``[0]`` to cripple one leg per reset, ``[1]`` to cripple
            two distinct legs per reset.  Defaults to ``[0]``.
        goal: Heading angle in radians used for the ``reward_angle`` info
            entry.
        nocripple: If true, no leg is crippled on reset.
    """

    # Leg index -> (top-half geom, bottom-half geom) rows in the model's geom
    # arrays that are recoloured/shrunk by ``set_crippled_joint``.
    _LEG_GEOM_IDS = {
        0: (3, 4),
        1: (6, 7),
        2: (9, 10),
        3: (12, 13),
    }

    def __init__(self, cripple_set=None, extreme_set=None, goal=30.0/180*np.pi, nocripple=False):
        # ``None`` sentinels avoid sharing one mutable default list between
        # every instance of the class.
        if cripple_set is None:
            cripple_set = [0, 1, 2, 3]
        if extreme_set is None:
            extreme_set = [0]

        # ``step`` treats a ``None`` mask as "no masking"; the real mask can
        # only be built once ``action_space`` exists, i.e. after the base
        # constructor has run.
        self.cripple_mask = None
        dir_path = os.path.dirname(os.path.realpath(__file__))
        self._goal = goal
        self.nocripple = nocripple
        mujoco_env.MujocoEnv.__init__(self, "%s/assets/ant.xml" % dir_path, 5)
        self.spec = EnvSpec('Ant-v2')

        self.n_possible_cripple = 4
        # The mask is always action-shaped so that ``step`` works even when it
        # is called before the first reset.
        self.cripple_mask = np.ones(self.action_space.shape)
        self.cripple_set = cripple_set
        self.extreme_set = extreme_set

        # Leg index -> action indices that are zeroed when the leg is crippled.
        self.cripple_dict = {
            0: [2, 3],  # front L
            1: [4, 5],  # front R
            2: [6, 7],  # back L
            3: [0, 1],  # back R
        }
        # -1 is the "no crippled leg" value also accepted by
        # ``set_crippled_joint``; it keeps ``get_sim_parameters`` usable
        # before the first reset and when ``nocripple`` is set.
        self.crippled_joint = np.array([-1])

        # Pristine copies of the geometry so repeated calls to
        # ``set_crippled_joint`` always start from the unmodified model.
        self._init_geom_rgba = self.model.geom_rgba.copy()
        self._init_geom_contype = self.model.geom_contype.copy()
        self._init_geom_size = self.model.geom_size.copy()
        self._init_geom_pos = self.model.geom_pos.copy()

        # Record *all* constructor arguments so that a pickled/copied env is
        # rebuilt with the same goal and nocripple setting.
        utils.EzPickle.__init__(self, cripple_set, extreme_set, goal, nocripple)

    def _set_observation_space(self, observation):
        """Set the observation space and record the preprocessed obs width."""
        space = super()._set_observation_space(observation)
        proc_observation = self.obs_preproc(observation[None])
        self.proc_observation_space_dims = proc_observation.shape[-1]
        return space

    def step(self, a):
        """Advance the simulation by one control step.

        Args:
            a: Action vector; entries of crippled legs are zeroed before use.

        Returns:
            Tuple ``(observation, reward, done, info)``.  ``done`` is always
            ``False``; ``info`` holds the individual reward terms and the
            torso x/y position after the step.
        """
        xy_position_before = self.get_body_com("torso")[:2].copy()
        if self.cripple_mask is not None:
            a = self.cripple_mask * a
        self.do_simulation(a, self.frame_skip)
        xy_position_after = self.get_body_com("torso")[:2].copy()

        xy_velocity = (xy_position_after - xy_position_before) / self.dt
        goal_direction = np.array([np.cos(self._goal), np.sin(self._goal)])

        forward_reward = xy_velocity[0]
        # Velocity component along the goal heading.
        angle_reward = np.dot(xy_velocity, goal_direction)
        # The control cost is computed on the masked action.
        ctrl_cost = .5 * np.square(a).sum()
        contact_cost = 0.5 * 1e-3 * np.sum(
            np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
        survive_reward = 1.0
        # NOTE(review): the returned reward uses the x velocity only; the
        # goal-direction term is reported in ``info`` but not added here.
        # Confirm this is intended.
        reward = forward_reward - ctrl_cost - contact_cost + survive_reward

        done = False
        ob = self._get_obs()
        info = dict(
            reward_ctrl=-ctrl_cost,
            reward_contact=-contact_cost,
            reward_survive=survive_reward,
            reward_forward=forward_reward,
            reward_angle=angle_reward,
            x_position=xy_position_after[0],
            y_position=xy_position_after[1],
        )
        return ob, reward, done, info

    def seed(self, seed=None):
        """Seed the environment, remembering the seed (0 if none given)."""
        self._seed = 0 if seed is None else seed
        return super().seed(seed)

    def _get_obs(self):
        """Return ``qpos[2:]`` followed by ``qvel`` as one flat vector."""
        return np.concatenate(
            [
                self.sim.data.qpos.flat[2:],
                self.sim.data.qvel.flat,
            ]
        )

    def obs_preproc(self, obs):
        """Return the observation unchanged."""
        return obs

    def obs_postproc(self, obs, pred):
        """Turn a predicted delta into a next observation (``obs + pred``)."""
        return obs + pred

    def targ_proc(self, obs, next_obs):
        """Return the training target, the observation delta."""
        return next_obs - obs

    def reset_model(self):
        """Reset to a perturbed initial state and, unless disabled, re-cripple."""
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-0.1, high=0.1
        )
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
        self.set_state(qpos, qvel)
        self.xposbefore = self.get_body_com("torso")[0]

        if not self.nocripple:
            self.change_env()
        return self._get_obs()

    def reward(self, obs, act, next_obs):
        """Compute a reward from a pair of consecutive observations.

        Args:
            obs: Observation(s) before the step; the last axis is indexed.
            act: Unused.
            next_obs: Observation(s) after the step.

        Returns:
            ``(next_obs[..., -3] - obs[..., -3]) / dt + 0.05``.
        """
        reward_ctrl = 0.0
        # NOTE(review): ``_get_obs`` returns qpos[2:] followed by qvel, so
        # index -3 is the third-from-last qvel entry rather than a torso
        # position.  Confirm this is the intended quantity.
        reward_run = (next_obs[..., -3] - obs[..., -3]) / self.dt

        reward_contact = 0.0
        reward_survive = 0.05
        return reward_run + reward_ctrl + reward_contact + reward_survive

    def tf_reward_fn(self):
        """Return a callable ``(obs, act, next_obs) -> reward``.

        The callable performs the same computation as ``reward``.
        """
        def _thunk(obs, act, next_obs):
            return self.reward(obs, act, next_obs)

        return _thunk

    def set_crippled_joint(self, value):
        """Cripple one specific leg and update the model geometry to show it.

        Args:
            value: Leg index (a key of ``cripple_dict``).  Any other value,
                e.g. ``-1``, leaves every leg enabled and restores the
                original geometry.
        """
        # Match with ``==`` (not a dict lookup) so the same inputs as before
        # are accepted, including unhashable one-element arrays.
        leg = next((key for key in self.cripple_dict if value == key), None)

        mask = np.ones(self.action_space.shape)
        if leg is not None:
            mask[self.cripple_dict[leg]] = 0
        self.cripple_mask = mask

        self.crippled_leg = value
        # Keep ``get_sim_parameters`` in agreement with the active mask.
        self.crippled_joint = np.array([value]).reshape(-1)

        # Always start from the pristine geometry so that a previously
        # crippled leg is restored.
        geom_rgba = self._init_geom_rgba.copy()
        geom_size = self._init_geom_size.copy()
        geom_pos = self._init_geom_pos.copy()

        if leg is not None:
            top_geom, bottom_geom = self._LEG_GEOM_IDS[leg]
            for geom in (top_geom, bottom_geom):
                # Paint the geom red and halve its first two size entries.
                geom_rgba[geom, :3] = (1, 0, 0)
                geom_size[geom, :2] /= 2
            # Make the removed leg not affect anything: fold the bottom half
            # onto the top half's position.
            geom_pos[bottom_geom, :] = geom_pos[top_geom, :]

        self.model.geom_rgba[:] = geom_rgba
        self.model.geom_size[:] = geom_size
        self.model.geom_pos[:] = geom_pos

    def change_env(self):
        """Draw new crippled leg(s) from ``cripple_set`` and rebuild the mask.

        Raises:
            ValueError: If ``extreme_set`` is neither ``[0]`` nor ``[1]``.
        """
        if self.extreme_set == [0]:
            self.crippled_joint = np.array([self.np_random.choice(self.cripple_set)])
        elif self.extreme_set == [1]:
            self.crippled_joint = self.np_random.choice(
                self.cripple_set, 2, replace=False
            )
        else:
            raise ValueError(self.extreme_set)

        disabled_actions = [
            action_idx
            for leg in self.crippled_joint
            for action_idx in self.cripple_dict[leg]
        ]
        mask = np.ones(self.action_space.shape)
        mask[disabled_actions] = 0
        self.cripple_mask = mask

    def viewer_setup(self):
        """Place the camera at half the model extent."""
        self.viewer.cam.distance = self.model.stat.extent * 0.5

    def get_sim_parameters(self):
        """Return the currently crippled leg index/indices as a 1-D array."""
        return np.array([self.crippled_joint]).reshape(-1)

    def num_modifiable_parameters(self):
        """Return the number of modifiable simulation parameters (1)."""
        return 1

    def log_diagnostics(self, paths, prefix):
        """No-op diagnostics hook."""
        return
