import numpy as np
import matplotlib.pyplot as plt

from pybulletgym.envs.roboschool.robots.robot_bases import MJCFBasedRobot
from pybulletgym.envs.roboschool.envs.env_bases import BaseBulletEnv
from pybulletgym.envs.roboschool.scenes.scene_bases import SingleRobotEmptyScene

from tasks.task import Task


class Reacher(Task):
    """Risk-aware reacher task with a discrete 3x3 action set.

    Wraps a ``ReacherBulletEnv`` whose goal is ``target_positions[task_index]``.
    Each integer action maps to a pair of joint torque commands in
    {-1, 0, 1} x {-1, 0, 1}, perturbed by Gaussian noise and clipped to [-1, 1].
    While the fingertip is within ``fail_radius_large`` of any entry of
    ``risky_targets``, a step "fails" with probability ``fail_prob_base`` and
    its reward is replaced by ``-repair_cost``.

    Extra positional/keyword arguments (``*args``, ``**kwargs``) are accepted
    and ignored. ``fail_radius_small`` is stored but not read by this class.
    """

    def __init__(self, target_positions, task_index, *args,
                 repair_cost=3., fail_prob_base=0.035, fail_radius_large=0.06, fail_radius_small=0.005,
                 action_noise=0.03, view=False, risky_targets=None, include_target_in_state=False, **kwargs):
        """Build the task and its simulator.

        Args:
            target_positions: sequence of goal positions, one per task; the
                robot reads entries [0] and [1] of the selected one
                (presumably planar (x, y) — confirm).
            task_index: index into ``target_positions`` of this task's goal.
            repair_cost: magnitude of the negative reward on a failed step.
            fail_prob_base: failure probability while near a risky target.
            fail_radius_large: distance to a risky target below which
                failures can occur.
            fail_radius_small: stored only; not read by this class.
            action_noise: std-dev of the Gaussian noise added to each command.
            view: if True, ``initialize`` calls ``env.render('human')``.
            risky_targets: positions near which failures can occur. ``None``
                (the default) means there are none.
            include_target_in_state: if True, observations have the goal
                position appended.
        """
        self.target_positions = target_positions
        self.task_index = task_index
        self.target_pos = np.array(target_positions[task_index])
        # ``None`` sentinel instead of a mutable ``[]`` default shared by all instances.
        self.risky_targets = [] if risky_targets is None else risky_targets
        self.env = ReacherBulletEnv(self.target_pos)
        self.view = view
        self.include_target_in_state = include_target_in_state

        # risk
        self.repair_cost = repair_cost
        self.fail_prob_base = fail_prob_base
        self.fail_radius_large = fail_radius_large
        self.fail_radius_small = fail_radius_small
        self.action_noise = action_noise

        # Integer action -> (central-joint command, elbow-joint command), ordered
        # (-1,-1), (-1,0), (-1,1), (0,-1), ..., (1,1).
        levels = (-1., 0., 1.)
        self.action_dict = dict(enumerate((a1, a2) for a1 in levels for a2 in levels))

    def clone(self):
        """Return a new task with the same configuration and its own simulator.

        ``view`` is not propagated, so the clone does not render.
        """
        return Reacher(self.target_positions, self.task_index,
                       repair_cost=self.repair_cost, fail_prob_base=self.fail_prob_base,
                       fail_radius_large=self.fail_radius_large, fail_radius_small=self.fail_radius_small,
                       action_noise=self.action_noise, risky_targets=self.risky_targets,
                       include_target_in_state=self.include_target_in_state)

    def _observation(self):
        """Return the current observation, with the goal appended if configured."""
        if self.include_target_in_state:
            return np.concatenate([np.ravel(self.state), self.target_pos])
        return self.state

    def initialize(self):
        """Reset the simulator and return the initial observation."""
        if self.view:
            self.env.render('human')
        self.state = self.env.reset()
        return self._observation()

    def action_count(self):
        """Number of discrete actions (9)."""
        return len(self.action_dict)

    def fail_prob(self, delta):
        """Failure probability given distance ``delta`` to the nearest risky target."""
        if delta < self.fail_radius_large:
            return self.fail_prob_base
        else:
            return 0.0

    def _min_risky_distance(self, tip_pos):
        """Distance from ``tip_pos`` to the closest risky target, or ``inf`` if there are none."""
        # len() rather than truthiness so numpy arrays of targets also work.
        if len(self.risky_targets) == 0:
            return np.inf
        return np.min([np.linalg.norm(tip_pos - np.array(target_pos)) for target_pos in self.risky_targets])

    def transition(self, action):
        """Apply discrete ``action`` and return ``(obs, reward, done, info)``.

        ``done`` is always False. ``info`` is
        ``(fail, (d_to_target, (tip_x, tip_y)))``, where the fingertip position
        and its distance to the nearest risky target are measured *before*
        the action is applied. On failure the reward is ``-repair_cost``.
        """
        tip_pos = np.array(self.env.robot.fingertip.pose().xyz()[:2])
        d_to_target = self._min_risky_distance(tip_pos)

        noise = np.random.normal(loc=0., scale=self.action_noise, size=(2,))
        real_action = np.clip(np.array(self.action_dict[action]) + noise, -1., 1.)

        self.state, reward, _, _ = self.env.step(real_action)

        # rand() samples [0, 1), so ``<`` fires with probability exactly p
        # (and never when p == 0).
        fail = bool(np.random.rand() < self.fail_prob(d_to_target))
        if fail:
            reward = -self.repair_cost
        other = (d_to_target, (tip_pos[0], tip_pos[1]))

        return self._observation(), reward, False, (fail, other)

    # ===========================================================================
    # STATE ENCODING FOR DEEP LEARNING
    # ===========================================================================
    def encode(self, state):
        """Flatten ``state`` into a single-row 2-D array of shape (1, -1)."""
        return np.array(state).reshape((1, -1))

    def encode_dim(self):
        """Length of an encoded observation: 4 joint values, plus the goal if included."""
        if self.include_target_in_state:
            return 4 + self.target_pos.size
        else:
            return 4

    # ===========================================================================
    # SUCCESSOR FEATURES
    # ===========================================================================
    def features(self, state, action, next_state, noise):
        """Successor features for one transition.

        ``noise`` is the ``info`` tuple from ``transition``; ``noise[0]`` is
        the failure flag. On failure only the last feature is 1. Otherwise,
        feature i is ``1 - 4 * ||tip - target_positions[i]||`` in the plane.
        The fingertip position is read from the live simulator rather than
        from ``next_state`` — presumably this is called right after
        ``transition`` (confirm at call sites).
        """
        phi = np.zeros((len(self.target_positions) + 1,))
        if noise[0]:
            phi[-1] = 1.
            return phi
        tip_pos = np.array(self.env.robot.fingertip.pose().xyz()[:2])
        for index, target in enumerate(self.target_positions):
            phi[index] = 1. - 4. * np.linalg.norm(tip_pos - np.array(target))
        return phi

    def feature_dim(self):
        """Number of successor features: one per target plus a failure indicator."""
        return len(self.target_positions) + 1

    def get_w(self):
        """Column weight vector: 1 for this task's target, -repair_cost for failure."""
        w = np.zeros((len(self.target_positions) + 1, 1))
        w[self.task_index, 0] = 1.
        w[-1, 0] = -self.repair_cost
        return w


class ReacherTempViewer:
    """Live matplotlib line plot of one entry of ``task.state`` over a sliding window.

    NOTE(review): the default ``state_index=4`` is out of range for the Reacher
    state, which after each step is the 4-element array from
    ``ReacherRobot.calc_state`` (indices 0-3). With that default, ``update()``
    raises IndexError for a Reacher task, so pass an explicit index. The
    intended signal needs confirming; the class name suggests a copy from a
    temperature task.
    """

    def __init__(self, task, state_index=4, window=200):
        """Open a non-blocking figure.

        Args:
            task: object exposing a ``state`` sequence.
            state_index: which entry of ``task.state`` to plot.
            window: number of most recent samples kept; also the x-axis extent.
        """
        self.task = task
        self.state_index = state_index
        self.window = window
        self.temp_hist = []
        self.fig = plt.figure()
        self.ax = plt.gca()
        self.ax.set_ylim([0., 2.0])
        self.ax.set_xlim([0, window])
        self.data, = self.ax.plot(np.array([0.]), np.array([0.]), color='black')
        self.fig.canvas.draw()
        plt.show(block=False)

    def update(self):
        """Append the current sample, trim to the window, and redraw."""
        self.temp_hist.append(self.task.state[self.state_index])
        if len(self.temp_hist) > self.window:
            self.temp_hist = self.temp_hist[-self.window:]
        self.data.set_data(list(range(len(self.temp_hist))), self.temp_hist)
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
    

class ReacherBulletEnv(BaseBulletEnv):
    """PyBullet environment for a ``ReacherRobot`` aimed at a fixed target.

    The scene has gravity disabled, a 0.0165 s timestep and a frame skip of 1.
    The per-step reward is ``1 - 4 * ||fingertip - target||``, using full 3-D
    positions. Episodes never signal termination (``done`` is always False).
    """

    def __init__(self, target):
        robot = ReacherRobot(target)
        self.robot = robot
        BaseBulletEnv.__init__(self, robot)

    def create_single_player_scene(self, bullet_client):
        """Build the gravity-free, single-robot scene."""
        scene = SingleRobotEmptyScene(bullet_client, gravity=0.0, timestep=0.0165, frame_skip=1)
        return scene

    def step(self, a):
        """Apply torque command ``a``, advance one scene step, and return ``(state, reward, False, {})``."""
        assert not self.scene.multiplayer
        robot = self.robot
        robot.apply_action(a)
        self.scene.global_step()

        # calc_state also refreshes robot.to_target_vec as a side effect.
        state = robot.calc_state()

        tip_xyz = np.array(robot.fingertip.pose().xyz())
        goal_xyz = np.array(robot.target.pose().xyz())
        reward = 1. - 4. * np.linalg.norm(tip_xyz - goal_xyz)

        self.HUD(state, a, False)
        return state, reward, False, {}

    def camera_adjust(self):
        """Point the camera from (0.3, 0.3, 0.3) at the fingertip, with x and y halved."""
        x, y, z = self.robot.fingertip.pose().xyz()
        self.camera.move_and_look_at(0.3, 0.3, 0.3, 0.5 * x, 0.5 * y, z)


class ReacherRobot(MJCFBasedRobot):
    """Two-joint arm loaded from ``reacher.xml`` (root body ``body0``) with a repositionable target.

    Actions are 2-D torque commands (central joint, elbow joint); observations
    are 4-D: central angle, central velocity, elbow angle, elbow velocity.
    """
    TARG_LIMIT = 0.27

    def __init__(self, target):
        MJCFBasedRobot.__init__(self, 'reacher.xml', 'body0', action_dim=2, obs_dim=4)
        self.target_pos = target

    def robot_specific_reset(self, bullet_client):
        """Place the target, cache part/joint handles, and randomize the arm pose.

        The ``target_x``/``target_y`` joints are set from ``target_pos[0]`` and
        ``target_pos[1]``. The central joint angle is drawn uniformly between
        -3.14 and 3.14, and the elbow between -1.57 and 1.57, using
        ``self.np_random``.
        """
        joints = self.jdict
        joints["target_x"].reset_current_position(self.target_pos[0], 0)
        joints["target_y"].reset_current_position(self.target_pos[1], 0)

        self.fingertip = self.parts["fingertip"]
        self.target = self.parts["target"]
        self.central_joint = joints["joint0"]
        self.elbow_joint = joints["joint1"]

        half_turn = 3.14
        self.central_joint.reset_current_position(self.np_random.uniform(low=-half_turn, high=half_turn), 0)
        self.elbow_joint.reset_current_position(self.np_random.uniform(low=-half_turn / 2, high=half_turn / 2), 0)

    def apply_action(self, a):
        """Send ``0.05 * clip(a[i], -1, 1)`` as motor torque to the central (i=0) and elbow (i=1) joints."""
        assert np.isfinite(a).all()
        torque_scale = 0.05
        for joint, command in ((self.central_joint, a[0]), (self.elbow_joint, a[1])):
            joint.set_motor_torque(torque_scale * float(np.clip(command, -1, +1)))

    def calc_state(self):
        """Return ``[theta, theta_dot, gamma, gamma_dot]`` and refresh ``to_target_vec``.

        ``to_target_vec`` is the 3-D fingertip position minus the target position.
        """
        theta, self.theta_dot = self.central_joint.current_relative_position()
        self.gamma, self.gamma_dot = self.elbow_joint.current_relative_position()
        tip_xyz = np.array(self.fingertip.pose().xyz())
        goal_xyz = np.array(self.target.pose().xyz())
        self.to_target_vec = tip_xyz - goal_xyz
        return np.array([theta, self.theta_dot, self.gamma, self.gamma_dot])

