import pandas as pd
import numpy as np
import gymnasium as gym
from sklearn.neighbors import KernelDensity
import torch

# This environment is a modified version of the pusher environment from gym
# Camera settings for the MuJoCo renderer (tracked body, distance, azimuth,
# elevation).  Currently not applied by Environment.  A distance of 4.0 was
# used previously.
DEFAULT_CAMERA_CONFIG = dict(
    trackbodyid=-1,
    distance=3.0,
    azimuth=135.0,
    elevation=-22.5,
)

class Environment(gym.Env):
    """Sparse-reward, multi-target variant of gymnasium's ``Pusher-v4``.

    On every reset the wrapped Pusher is put into a fixed configuration:
    all arm joints at 0, cylinder qpos at (-0.1, -0.1), goal qpos at (0, 0),
    and zero velocities.  The physics timestep is overridden to 0.05.

    Action, shape ``(5,)``, bounded to ``[-1, 1]`` by ``action_space``:
        ``a[0]``   exit flag; a value ``> 0.5`` ends the episode.
        ``a[1:]``  four joint commands.  They are multiplied by 2 and become
                   the first four entries of the wrapped env's 7-entry action;
                   the remaining three entries are 0.

    Observation, shape ``(19 + max_t - 1,)``, float32:
        - the first 19 entries of the wrapped env's observation; entries
          17:19 are used as the cylinder's (x, y);
        - a one-hot encoding of the step counter ``t`` over ``max_t - 1``
          slots, which is all zeros once ``t == max_t - 1``.

    Reward: 0 on non-terminal steps.  The terminal step pays 50000 if the
    cylinder lies strictly within ``radius`` of any of the three targets, and
    ``1e-10`` otherwise.  The reward is kept strictly positive because it is
    used as the reward for GFN training.
    """

    metadata = {"render_modes": ["rgb_array", "human"]}

    def __init__(self, radius=0.05, render_mode=None):
        """Build the wrapped Pusher env and reset it.

        Args:
            radius: Success radius around each target, in the units of the
                cylinder position read from the observation.
            render_mode: Passed to ``gym.make``.  Gymnasium fixes the render
                mode at construction time.
        """
        self.env = gym.make("Pusher-v4", render_mode=render_mode)
        self.env.unwrapped.model.opt.timestep = 0.05  # override the physics timestep
        # At most max_t step() calls per episode; the last one terminates.
        self.max_t = 10

        self.observation_space = gym.spaces.Box(
            shape=(19 + self.max_t - 1,),
            low=-float("inf"),
            high=float("inf"),
            dtype=np.float32,
        )
        self.action_space = gym.spaces.Box(shape=(5,), low=-1, high=1, dtype=np.float32)
        self.render_mode = render_mode
        # Same length as the slice kept by reset()/step(); overwritten by reset().
        self.state = np.zeros(19, dtype=np.float32)
        self.info = {}

        # Target (x, y) positions for the cylinder, which starts at about (0.35, -0.15).
        self.target_x, self.target_y = 0.45, -0.05
        self.target_x2, self.target_y2 = 0.25, -0.05
        self.target_x3, self.target_y3 = 0.35, -0.05

        self.radius = radius

        self.reset()

    def _targets(self):
        """Return the three target ``(x, y)`` pairs, read from the current attributes."""
        return (
            (self.target_x, self.target_y),
            (self.target_x2, self.target_y2),
            (self.target_x3, self.target_y3),
        )

    def reset(self, seed=None, options=None):
        """Reset to the fixed start configuration and clear the step counter.

        Args:
            seed: Seeds this env's ``np_random`` and is also forwarded to the
                wrapped env's ``reset``.
            options: Ignored; accepted for gymnasium API compatibility.

        Returns:
            ``(observation, info)``.
        """
        super().reset(seed=seed)
        self.t = 0

        # Reset the wrapped env first.  Any start state it picks is then
        # overwritten by set_state below.
        self.env.reset(seed=seed)
        unwrapped = self.env.unwrapped

        # qpos layout: arm joints, then cylinder (2), then goal (2).
        goal_pos = np.array([0, 0])
        cylinder_pos = np.array([-0.1, -0.1])
        # big arm angle, big arm slope, small arm angle, small arm slope, palm angle
        joint_pos = np.array([0, 0, 0, 0, 0, 0, 0])  # alternative: 0.75, 1.3, -.5, -1.0, 0, 0, 0

        # Work on a copy so the wrapped env's init_qpos is not modified in place.
        qpos = unwrapped.init_qpos.copy()
        qpos[0:-4] = joint_pos
        qpos[-4:-2] = cylinder_pos
        qpos[-2:] = goal_pos
        qvel = np.zeros_like(unwrapped.init_qvel)  # start at rest
        unwrapped.set_state(qpos, qvel)

        self.state = unwrapped._get_obs()[:19]
        return self._get_obs(), self.info

    # Ending state: the one before the time-out, or the one within `radius`
    # of one of the targets.
    def step(self, a):
        """Advance one decision step.

        The termination test runs on the state *before* ``a`` is executed.
        The episode ends without running the physics if any of these hold:
        the exit flag is set, the cylinder is already within ``radius`` of a
        target, or ``t + 1 >= max_t``.

        Args:
            a: Array-like of 5 numbers: ``[exit_flag, j0, j1, j2, j3]``.

        Returns:
            ``(obs, reward, terminated, truncated, info)``.  On termination
            ``info`` is ``{'augmented_rew': 1}``; otherwise it is ``{}``.
        """
        # Accept lists too: for a list, `a[1:] * 2` would repeat it, not scale it.
        a = np.asarray(a)
        exit_flag = a[0]

        x, y = self.state[17:19]  # the position of the cylinder
        in_target = any(
            np.sqrt((x - tx) ** 2 + (y - ty) ** 2) < self.radius
            for tx, ty in self._targets()
        )

        if exit_flag > 0.5 or in_target or self.t + 1 >= self.max_t:
            # The reward does not scale with the number of targets hit.  With
            # the default radius the target discs are disjoint anyway.
            final_reward = 50000 if in_target else 1e-10
            augmented_reward = 1
            return self._get_obs(), final_reward, True, False, {"augmented_rew": augmented_reward}

        # Double the four joint commands and pad with three zeros to form the
        # wrapped env's 7-entry action.
        a2 = np.concatenate([a[1:] * 2, np.zeros(3)])
        obs, _, terminated, _, _ = self.env.step(a2)

        if terminated:
            print("Error in the pusher env, please check!")

        self.t += 1
        self.state = obs[:19]
        return self._get_obs(), 0, False, False, {}

    def _get_obs(self):
        """Return ``state`` followed by a one-hot of ``t`` over ``max_t - 1`` slots, as float32."""
        t_encode = (np.arange(self.max_t - 1) == self.t).astype(np.float32)
        return np.concatenate([self.state, t_encode]).astype(np.float32)

    def _get_info(self):
        """Return the (shared) info dict."""
        return self.info

    def render(self, mode="human"):
        """Render the wrapped env.

        ``mode`` is accepted for backward compatibility and ignored.
        Gymnasium's ``render()`` takes no arguments; the mode is fixed by
        ``render_mode`` at construction.
        """
        return self.env.render()

    def close(self):
        """Close the wrapped env."""
        self.env.close()

    def get_state(self, obs):
        """Identity mapping from observation to state."""
        return obs

    def is_within_target(self, data, x, y):
        """Return a boolean mask of rows of ``data`` within ``radius`` of ``(x, y)``.

        Args:
            data: 2-D array whose columns 0 and 1 are x and y.
            x, y: Target position.
        """
        reward_dist = np.sqrt((data[:, 0] - x) ** 2 + (data[:, 1] - y) ** 2)
        return reward_dist < self.radius

    def get_error(self, samples):
        """Return the summed per-target success rate of ``samples``, each capped at 1/3.

        Despite the name, this is a success score, not a failure rate.  For
        each target k, compute the fraction of samples within ``radius`` of
        it, cap it at 1/3, and sum over the three targets.  The result lies
        in [0, 1] and reaches 1 only when every target receives at least a
        third of the samples.

        Args:
            samples: Sequence of observations.  Columns 17:19 are read as the
                cylinder (x, y).

        Returns:
            The capped sum, or 0.0 when ``samples`` is empty.
        """
        samples = np.array(samples)
        if len(samples) == 0:
            return 0.0
        positions = samples[:, 17:19]
        rates = np.array([
            self.is_within_target(positions, tx, ty).sum() / len(positions)
            for tx, ty in self._targets()
        ])
        return np.sum(rates.clip(max=1 / 3))
           
if __name__ == "__main__":
    from tqdm import trange

    # =========================================================================
    # Show sparsity: roll out random episodes and count how often the
    # cylinder finishes inside each target disc.
    np.random.seed(42)
    num_samples = 100000
    env = Environment()
    targets = [
        (env.target_x, env.target_y),
        (env.target_x2, env.target_y2),
        (env.target_x3, env.target_y3),
    ]
    rs = []  # terminal reward of every episode
    success = [0, 0, 0]  # number of episodes ending inside target 1, 2, 3
    try:
        for _ in trange(num_samples):
            done = False
            ep_t = 0
            while not done:
                # Joint commands ~ N(0, 0.5^2).  Any component outside [-1, 1]
                # is resampled until it falls inside (truncated normal).
                action = np.random.normal(size=(4,)) / 2
                out_of_range = (action > 1) | (action < -1)
                while np.any(out_of_range):
                    action[out_of_range] = np.random.normal(size=(int(out_of_range.sum()),)) / 2
                    out_of_range = (action > 1) | (action < -1)
                # Exit only on the final allowed step, where the env times out anyway.
                exit_flag = (ep_t == env.unwrapped.max_t - 1) * np.ones((1,))
                action = np.concatenate((exit_flag, action))
                s, r, done, _, _ = env.step(action)

                if done:
                    x, y = s[17:19]  # the position of the cylinder
                    for k, (tx, ty) in enumerate(targets):
                        if (x - tx) ** 2 + (y - ty) ** 2 < env.radius ** 2:
                            success[k] += 1

                ep_t += 1
            rs.append(r)
            env.reset()
    finally:
        env.close()
    # These are episode counts per target, not rates.
    print("Success rate: ", success[0], success[1], success[2])
    rs = np.array(rs)
    print(np.sum(rs > 1e-3))