import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
from copy import deepcopy

class PointObstaclesEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Goal-conditioned point environment with four patrolling "monster" bodies.

    Each step the agent's ``qpos[0]`` is advanced by an amount derived from
    the first action component, while ``qpos[4]``, ``qpos[6]``, ``qpos[8]``
    and ``qpos[10]`` are moved by a signed per-step delta that flips sign at
    fixed bounds.  If the ``torso`` body comes within a fixed radius of any
    ``monsterN`` body, ``qpos[:2]`` is reset to zero instead of advancing.

    Observations are dicts with the keys ``'observation'``,
    ``'achieved_goal'``, ``'desired_goal'`` and ``'extra_obs'``.
    """

    # Names of the patrolling bodies looked up via ``get_body_com``.
    _MONSTER_BODIES = ("monster1", "monster2", "monster3", "monster4")
    # A patrolling qpos entry reverses once it reaches this value ...
    _PATROL_UPPER = 5
    # ... and reverses again once it drops below this one.
    _PATROL_LOWER = -27
    # Torso-to-monster centre-of-mass distance below which a collision is declared.
    _COLLISION_RADIUS = 5
    # Displacement added to qpos[0] per step when the scaled action equals 1.
    _AGENT_STEP = 1.5
    # Upper clamp applied to qpos[0].
    _AGENT_X_MAX = 100
    # Distance under which the sparse reward counts the goal as reached.
    distance_threshold = 0.5

    def __init__(self):
        # NOTE(review): these attributes are assigned before the base
        # constructor on purpose -- gym's MujocoEnv.__init__ is understood to
        # call step() once while setting up the spaces (version dependent,
        # confirm against the installed gym), and step() reads them.
        self.count = 0
        self.mx = 0
        self.my = 20
        # change_flagN is 1 while patrol entry N is in its reversed phase.
        self.change_flag1 = 0
        self.change_flag2 = 0
        self.change_flag3 = 0
        self.change_flag4 = 0
        # dist_flagN is the signed per-step displacement of patrol entry N.
        self.dist_flag1 = 0.7
        self.dist_flag2 = 0.5
        self.dist_flag3 = 0.9
        self.dist_flag4 = 1.1
        self.monster_collided = False
        self.global_monster_collided = False
        self.realgoal = np.array([0, 1])
        self.goal = np.array([60., 10., 0.0])
        mujoco_env.MujocoEnv.__init__(self, 'point_obstacles.xml', 5)
        utils.EzPickle.__init__(self)
        self.randomizeCorrect()

    def randomizeCorrect(self):
        """Draw a fresh two-element ``realgoal`` with entries in {0, 1}.

        Per the original author's note: 0 = obstacle, 1 = no obstacle.
        """
        self.realgoal = np.array([
            self.np_random.choice([0, 1]),
            self.np_random.choice([0, 1]),
        ])

    @staticmethod
    def _advance_patrol(y, delta, reversed_flag):
        """Advance one patrolling coordinate by a single step.

        The direction is reversed when ``y`` has reached ``_PATROL_UPPER``
        (entering the reversed phase) and reversed back when ``y`` has
        dropped below ``_PATROL_LOWER`` (leaving it).

        Returns:
            Tuple ``(new_y, new_delta, new_reversed_flag)``.
        """
        if not reversed_flag and y >= PointObstaclesEnv._PATROL_UPPER:
            delta = -delta
            reversed_flag = 1
        if reversed_flag and y < PointObstaclesEnv._PATROL_LOWER:
            delta = -delta
            reversed_flag = 0
        return y + delta, delta, reversed_flag

    def _monster_coms(self):
        """Return the centre-of-mass arrays of the four monster bodies, in order."""
        return [self.get_body_com(name) for name in self._MONSTER_BODIES]

    @staticmethod
    def _rounded_ypos(coms):
        """Return a new 1-D array with element ``[1]`` of every COM, rounded to 3 decimals."""
        return np.round(np.array([com[1] for com in coms]), 3)

    def step(self, a):
        """Advance the environment by one step.

        Args:
            a: Action; only ``a[0]`` is used.  It is mapped through
                ``(a[0] + 1) / 2`` before scaling the agent's displacement
                (so -1 maps to 0 and +1 to 1).

        Returns:
            ``(ob, reward, done, info)`` where ``ob`` is the dict built by
            ``_get_obs``, ``done`` is always ``False`` and ``info`` carries
            ``'position'``, ``'pos_before'``, ``'pos_after'`` (x/y of the
            torso) and ``'monster_pos'`` (pre-step rounded monster y values).
        """
        forward_fraction = (a[0] + 1) / 2.0
        self.count += 1

        n_qpos = np.copy(self.data.qpos[:])

        # Move the four patrolling entries, flipping direction at the bounds.
        n_qpos[4], self.dist_flag1, self.change_flag1 = self._advance_patrol(
            n_qpos[4], self.dist_flag1, self.change_flag1)
        n_qpos[6], self.dist_flag2, self.change_flag2 = self._advance_patrol(
            n_qpos[6], self.dist_flag2, self.change_flag2)
        n_qpos[8], self.dist_flag3, self.change_flag3 = self._advance_patrol(
            n_qpos[8], self.dist_flag3, self.change_flag3)
        n_qpos[10], self.dist_flag4, self.change_flag4 = self._advance_patrol(
            n_qpos[10], self.dist_flag4, self.change_flag4)

        # Body positions as they are *before* the new state is applied.
        monster_coms = self._monster_coms()
        monster_pos = self._rounded_ypos(monster_coms)
        agent_com = self.get_body_com("torso")
        pos_before = np.copy(agent_com)

        collided = any(
            np.linalg.norm(agent_com - com) < self._COLLISION_RADIUS
            for com in monster_coms
        )
        if collided:
            # Collision with a monster: send the agent back to the origin.
            n_qpos[:2] = np.array([0, 0])
        else:
            n_qpos[0] += self._AGENT_STEP * forward_fraction

        if n_qpos[0] > self._AGENT_X_MAX:
            n_qpos[0] = self._AGENT_X_MAX

        self.set_state(n_qpos, self.data.qvel[:])
        self.do_simulation(np.zeros(2), self.frame_skip)
        done = False

        # Take an explicit copy of the post-step position and score that, so
        # the reward never depends on whether get_body_com returns a live
        # view or a stale copy of the simulator buffer.
        pos_after = np.copy(self.get_body_com("torso"))
        reward = self.compute_reward(pos_after, self.goal.copy())

        ob = self._get_obs()
        info = {
            'position': pos_after[:2].copy(),
            'pos_before': pos_before[:2].copy(),
            'pos_after': pos_after[:2].copy(),
            'monster_pos': monster_pos.copy()
        }
        return ob, reward, done, info

    def _get_obs(self):
        """Build the goal-conditioned observation dict from the current state.

        Returns:
            Dict with ``'observation'`` (torso x rounded to 3 decimals, shape
            ``(1,)``), ``'achieved_goal'`` (copy of the torso COM),
            ``'desired_goal'`` (copy of ``self.goal``) and ``'extra_obs'``
            (rounded y of the four monster COMs).
        """
        agent_com = self.get_body_com("torso")
        return {
            'observation': np.array([np.round(agent_com[0], 3)]),
            'achieved_goal': agent_com.copy(),
            'desired_goal': self.goal.copy(),
            'extra_obs': self._rounded_ypos(self._monster_coms())
        }

    def compute_reward(self, achieved_goal, goal, info=None, reward_type='dense'):
        """Compute the reward for ``achieved_goal`` with respect to ``goal``.

        Args:
            achieved_goal: Array of achieved goal(s).
            goal: Array of desired goal(s), same shape as ``achieved_goal``.
            info: Unused; accepted for API compatibility.
            reward_type: ``'sparse'`` yields ``-1.0`` when the distance exceeds
                ``distance_threshold`` and ``0.0`` otherwise; any other value
                yields the dense reward ``-distance / 100``.  The default is
                ``'dense'`` because the previous implementation always
                returned the dense reward regardless of this argument.

        Returns:
            The reward (scalar, or one value per row for batched inputs).
        """
        d = self.goal_distance(achieved_goal, goal)
        if reward_type == 'sparse':
            return -(d > self.distance_threshold).astype(np.float32)
        return -d / 100.

    def goal_distance(self, goal_a, goal_b):
        """Return the Euclidean distance between two goals along the last axis."""
        assert goal_a.shape == goal_b.shape, "goal shapes must match"
        return np.linalg.norm(goal_a - goal_b, axis=-1)

    def get_direction_taken(self, action, info=None):
        """Return the absolute x displacement of a step, divided by ``_AGENT_STEP``.

        Args:
            action: Unused.
            info: The ``info`` dict returned by ``step``; must provide
                two-element ``'pos_before'`` and ``'pos_after'`` entries.

        Raises:
            TypeError: If ``info`` is not supplied.
        """
        if info is None:
            raise TypeError(
                "get_direction_taken() requires the info dict returned by step()")
        x_before, _ = info['pos_before']
        x_after, _ = info['pos_after']
        return np.abs((x_after - x_before) / self._AGENT_STEP)

    def _reset_to_noisy_init(self):
        """Reset the counters and place the simulation at a perturbed initial state.

        NOTE(review): the patrol direction state (``dist_flagN`` /
        ``change_flagN``) is intentionally left untouched here, as in the
        original code, so it carries over between episodes -- confirm that
        this is desired.
        """
        self.count = 0
        self.mx = 0
        self.my = 20
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-.1, high=.1)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def setIndex(self, index):
        """Re-initialise the simulation state and return the new observation.

        ``index`` is accepted but not used.
        """
        return self._reset_to_noisy_init()

    def reset_model(self):
        """Clear the collision flags, re-initialise the state and return the observation."""
        self.monster_collided = False
        self.global_monster_collided = False
        return self._reset_to_noisy_init()

    def viewer_setup(self):
        """Set the viewer camera distance to 0.8 times the model extent."""
        self.viewer.cam.distance = self.model.stat.extent * 0.8
