import numpy as np
import pickle
# from .mazerunner import MazeRunnerEnv
from copy import deepcopy
import os
import torch
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
# from PIL import Image


import os
# from gym import utils
# from gym.envs.robotics import rotations, robot_env, fetch_env
# import gym.envs.robotics.utils as robot_utils
from gymnasium import utils
from gymnasium_robotics.utils import rotations
from gymnasium_robotics.envs import robot_env
from gymnasium_robotics.envs.fetch import fetch_env
import gymnasium_robotics.utils.mujoco_py_utils as robot_utils

from raDT.baselines.envs.reach_obstacle import FetchReachObstacleEnv
from raDT.constants import *

# Location of the obstacle scene XML: <ENV_PATH>/assets/fetch/slide_obstacle.xml.
# Built with os.path.join so the path separator is also correct on Windows.
# NOTE(review): ENV_PATH is not defined in the code shown here; presumably it is
# provided by the `from raDT.constants import *` star import -- confirm.
MODEL_XML_PATH = os.path.join(ENV_PATH, 'assets', 'fetch', 'slide_obstacle.xml')

def goal_distance(goal_a, goal_b):
    """Return the Euclidean distance between two goals.

    Both arguments must expose the same ``.shape``. The norm is taken over
    the last axis, so batched inputs yield one distance per leading index.
    """
    assert goal_a.shape == goal_b.shape
    displacement = goal_a - goal_b
    return np.linalg.norm(displacement, axis=-1)


# class FetchReachObstacleEnv(fetch_env.MujocoPyFetchEnv, utils.EzPickle):
#     def __init__(self, reward_type='sparse'):
#         initial_qpos = {
#             'robot0:slide0': 0.4049,
#             'robot0:slide1': 0.48,
#             'robot0:slide2': 0.0,
#         }

#         # get_obs = self._get_obs
#         # self._get_obs = super()._get_obs
#         fetch_env.MujocoPyFetchEnv.__init__(
#             self, model_path=MODEL_XML_PATH, has_object=False, block_gripper=True, n_substeps=20,
#             gripper_extra_height=0.2, target_in_the_air=True, target_offset=0.0,
#             obj_range=0.15, target_range=0.2, distance_threshold=0.05,
#             initial_qpos=initial_qpos, reward_type=reward_type)
#         # self._get_obs = get_obs
#         utils.EzPickle.__init__(self)

#         self.pos_obstacle = self.sim.model.geom_pos[self.sim.model.geom_name2id('obstacle0')]
#         self.size_obstacle = self.sim.model.geom_size[self.sim.model.geom_name2id('obstacle0')]

#         self.cost_threshold = 0.06

#     def _reset_sim(self):
#         self.sim.set_state(self.initial_state)
#         self.initial_gripper_xpos[:3] = [1.34193226,0.74910037,0.555]

#         # Randomize start position of object.
#         if self.has_object:
#             object_xpos = self.initial_gripper_xpos[:2]
            
#             while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1:
#                 object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(-self.obj_range,self.obj_range, size=2)
#             object_qpos = self.sim.data.get_joint_qpos('object0:joint')
#             assert object_qpos.shape == (7,)
#             object_qpos[:2] = object_xpos

#             self.sim.data.set_joint_qpos('object0:joint', object_qpos)
#             self.object_qpos = object_qpos

#         obstacle_xpos = self.initial_gripper_xpos[:2]
#         while np.linalg.norm(obstacle_xpos - self.initial_gripper_xpos[:2]) < 0.1:
#             obstacle_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(-self.obj_range,self.obj_range, size=2)
#         obstacle_qpos = self.sim.data.get_site_xpos('obstacle0')
#         obstacle_qpos[:2] = obstacle_xpos
#         obstacle_qpos[2] = 0.42 + self.np_random.uniform(0.03, 0.2)
#         self.sim.model.body_pos[self.sim.model.body_name2id('obstacle0')] = obstacle_qpos

#         self.sim.forward()
#         return True
    
#     def _sample_goal(self):
#         if self.has_object:
#             object_xpos = self.sim.data.get_joint_qpos('object0:joint')[:2]
#             obstacle_xpos = self.sim.data.get_body_xpos('obstacle0')[:2]
#             # print(1,obstacle_xpos)
#             obstacle_xpos_x = obstacle_xpos[0]
#             obstacle_xpos_y = obstacle_xpos[1]
#             obstacle_xpos_z = self.sim.data.get_body_xpos('obstacle0')[2]
#             # print(1,obstacle_xpos_z)
#             object_xpos_x = object_xpos[0]
#             object_xpos_y = object_xpos[1]
#             object_xpos_z = self.sim.data.get_joint_qpos('object0:joint')[2]
#             # print(1,object_xpos_z)
#             # assert 0
#             vector_from_obstacle_to_object_x = (object_xpos_x - obstacle_xpos_x)
#             vector_from_obstacle_to_object_y = (object_xpos_y - obstacle_xpos_y)
#             vector_from_obstacle_to_object_z = (object_xpos_z - obstacle_xpos_z)

#             goal = self.initial_gripper_xpos[:3]
#             while np.linalg.norm(goal[:2] - self.initial_gripper_xpos[:2]) < 0.05 or np.linalg.norm(goal[:2] - object_xpos) < 0.05 or \
#                 np.linalg.norm(obstacle_xpos - goal[:2]) < 0.1 or vector_from_obstacle_to_object_x*(goal[0] - obstacle_xpos_x) > 0 or vector_from_obstacle_to_object_y*(goal[1] - obstacle_xpos_y)>0:
#             # while vector_from_obstacle_to_object_x*(goal[0] - obstacle_xpos_x) > 0 or vector_from_obstacle_to_object_y*(goal[1] - obstacle_xpos_y)>0:
#                 goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
#             goal += self.target_offset
#             goal[2] = 0.42
#             if self.target_in_the_air and self.np_random.uniform() < 0.5:
#                 #while vector_from_obstacle_to_object_z*(goal[2] - obstacle_xpos_z) > 0 :
#                 goal[2] += self.np_random.uniform(0, 0.45)
#         else:
#             obstacle_xpos = self.sim.data.get_body_xpos('obstacle0')[:2]
#             obstacle_xpos_x = obstacle_xpos[0]
#             obstacle_xpos_y = obstacle_xpos[1]
#             vector_from_obstacle_to_object_x = (self.initial_gripper_xpos[0] - obstacle_xpos_x)
#             vector_from_obstacle_to_object_y = (self.initial_gripper_xpos[1] - obstacle_xpos_y)

#             goal = self.initial_gripper_xpos[:3]
#             while np.linalg.norm(obstacle_xpos - goal[:2]) < 0.07 or vector_from_obstacle_to_object_x*(goal[0] - obstacle_xpos_x) > 0 or vector_from_obstacle_to_object_y*(goal[1] - obstacle_xpos_y)>0:
#                 goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
#             goal[2] = self.initial_gripper_xpos[2] + self.np_random.uniform(-0.15, 0.15, size=1)
#         return goal.copy()
    
#     def compute_cost(self, achieved_goal, goal, obs, k, info):
#         # Compute distance between goal and the achieved goal.
#         d = goal_distance(achieved_goal, goal)
#         c = ((obs[0] < obs[10] + (self.size_obstacle[0]) / 2 + k) and (obs[0] > obs[10] - (self.size_obstacle[0]) / 2 - k)\
#              and (obs[1] < obs[11] + self.size_obstacle[1] / 2 + k) and (obs[1] > obs[11] - self.size_obstacle[1] / 2 - k)\
#              and (obs[2] < obs[12] + 0.03 + self.size_obstacle[2] / 2 + k) and (obs[2] > obs[12] + 0.03 - self.size_obstacle[2] / 2 - k)and ((d > self.distance_threshold).astype(np.float32) == 1))
#         # d1 = goal_distance(achieved_goal, goal[1])
#         if c:
#             cost = 1
#         else:
#             cost = 0
#         return np.array(cost)
    
#     def step(self, action):
#         action = np.clip(action, self.action_space.low, self.action_space.high)
#         self._set_action(action)
#         self.sim.step()
#         self._step_callback()
#         obs = self._get_obs()

#         done = False
#         info = {
#             'is_success': self._is_success(obs['achieved_goal'], self.goal),
#         }
#         reward = self.compute_reward(obs['achieved_goal'], self.goal, info)
#         cost = self.compute_cost(obs['achieved_goal'], self.goal, obs['observation'], self.cost_threshold, info)
#         return obs, reward, cost, done, info
    
#     def _get_obs(self):
#         # positions
#         grip_pos = self.sim.data.get_site_xpos('robot0:grip')
#         dt = self.sim.nsubsteps * self.sim.model.opt.timestep
#         grip_velp = self.sim.data.get_site_xvelp('robot0:grip') * dt
#         robot_qpos, robot_qvel = robot_utils.robot_get_obs(self.sim)
#         if self.has_object:
#             object_pos = self.sim.data.get_site_xpos('object0')
#             # rotations
#             object_rot = rotations.mat2euler(self.sim.data.get_site_xmat('object0'))
#             # velocities
#             object_velp = self.sim.data.get_site_xvelp('object0') * dt
#             object_velr = self.sim.data.get_site_xvelr('object0') * dt
#             # gripper state
#             object_rel_pos = object_pos - grip_pos
#             object_velp -= grip_velp
#         else:
#             object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0)
#         gripper_state = robot_qpos[-2:]
#         gripper_vel = robot_qvel[-2:] * dt  # change to a scalar if the gripper is made symmetric
        
#         # print(grip_pos)

#         if not self.has_object:
#             achieved_goal = grip_pos.copy()
#         else:
#             achieved_goal = np.squeeze(object_pos.copy())
#         obs = np.concatenate([
#             grip_pos, object_pos.ravel(), object_rel_pos.ravel(), gripper_state, object_rot.ravel(),
#             object_velp.ravel(), object_velr.ravel(), grip_velp, gripper_vel,
#         ])

#         obstacle_pos = self.sim.data.get_body_xpos('obstacle0')
#         obstacle_pos[2] = obstacle_pos[2] - 0.03
#         # print(1, obstacle_pos)
#         obstacle_rot = rotations.mat2euler(self.sim.data.get_body_xmat('obstacle0'))
#         # gripper state
#         obstacle_grip_rel_pos = obstacle_pos - grip_pos
#         # object state
#         obstacle_obj_rel_pos = np.zeros(0)
#         # obs = np.concatenate([obs, obstacle_pos.ravel(
#         # ), obstacle_grip_rel_pos.ravel(), obstacle_obj_rel_pos.ravel()])
#         obs = np.concatenate([obs, obstacle_pos.ravel(), obstacle_rot.ravel(), obstacle_grip_rel_pos.ravel(), obstacle_obj_rel_pos.ravel()])


#         return {
#             'observation': obs.copy(),
#             'achieved_goal': achieved_goal.copy(),
#             'desired_goal': self.goal.copy(),
#         }

###########################################

# '''
# For online evaluation: MazeRunner env with a fixed maze and goal across episodes
# '''
# class MazeRunnerEvalEnv(MazeRunnerEnv):
#     def __init__(self, maze_dim, min_num_goals, max_num_goals, maze, goal_pos):
#         self.env_name = 'mazerunner'
#         self.maze = deepcopy(maze)
#         self.goal_positions = deepcopy(goal_pos)
#         self.prompt_dim = 2 # goal: (x,y)
#         super().__init__(maze_dim, min_num_goals, max_num_goals)
#         self.discrete_action = True if hasattr(self.action_space, 'n') else False
    
#     def reset(self, *args, **kwargs):
#         self.start = (self.maze_dim - 2, self.maze_dim // 2)
#         empty_locations = [x for x in zip(*np.where(self.maze == 0))]
#         empty_locations.remove(self.start)
#         self.active_goal_idx = 0
#         self.pos = self.start
#         self._enforce_reset = False
#         self._plotting = False
#         self._goal_render_texts = [None for _ in range(len(self.goal_positions))]
#         self.episode_visited_goals = [] # save prompt goals in episode experience, for online adaptation
#         return self._get_obs()

#     def step(self, act):
#         obs, rew, terminated, truncated, info = super().step(act)
#         self.episode_visited_goals.append(np.asarray(obs[0:2])) # save prompt goals in episode experience, for online adaptation
#         return obs, rew, terminated, info

#     def render_with_prompt(self, prompt, return_rgb=False):
#         if not self._plotting:
#             self.start_plotting()
#             plt.ion()
#             self._plotting = True

#         plt.tight_layout()
#         background = np.ones((self.maze_dim, self.maze_dim, 3), dtype=np.uint8)
#         maze_img = (
#             background * abs(np.expand_dims(self.maze, -1) - 1) * 255
#         )  # zero out (white) where there is a valid path

#         prompt_color = [240, 3, 252]
#         for p in prompt:
#             maze_img[p[0], p[1], :] = prompt_color

#         maze_img[self.pos[0], self.pos[1], :] = [110, 110, 110]  # grey
#         plt.imshow(maze_img)

#         for i, p in enumerate(prompt):
#             plt.text(
#                 p[1], p[0], str(i), ha="center", va="center"
#             )

#         plt.draw()
#         plt.pause(0.1)

#         if return_rgb:
#             canvas = FigureCanvasAgg(plt.gcf())
#             canvas.draw()
#             w, h = canvas.get_width_height()
#             buf = np.fromstring(canvas.tostring_argb(), dtype=np.uint8)
#             buf.shape = (w,h,4)
#             buf = np.roll(buf, 3, axis=2)
#             image = Image.frombytes("RGBA", (w,h), buf.tostring())
#             image = np.asarray(image)[:,:,:3]
#             return image

#     def get_task_prompt(self):
#         return [np.array(p)/self.maze_dim for p in self.goal_positions]

# make envs for eval with the tasks in trajs
def get_env_list(trajs, device, max_ep_len=50, **kwargs):
    """Build one evaluation environment (plus an info dict) per trajectory.

    For every trajectory a fresh ``FetchReachObstacleEnv`` is created, its
    ``goal`` is set to the trajectory's ``'goal_pos'`` and its ``prompt_dim``
    is set to 3. If ``bsa_box_size`` is present in ``kwargs`` it is forwarded
    to the environment constructor; other keyword arguments are ignored.

    Returns:
        ``(infos, env_list)`` -- two lists aligned with ``trajs``. Each info
        dict holds ``max_ep_len``, ``state_dim`` (second dimension of the
        trajectory's ``'observations'``), ``act_dim`` (first dimension of the
        env's action space), ``device``, ``prompt_dim`` and
        ``discrete_action`` (always ``False``).
    """
    # Only this one option is passed on to the environment constructor.
    if "bsa_box_size" in kwargs:
        env_kwargs = {"bsa_box_size": kwargs["bsa_box_size"]}
    else:
        env_kwargs = {}

    infos = []
    env_list = []
    for traj in trajs:
        task_goal = traj['goal_pos']
        env = FetchReachObstacleEnv(**env_kwargs)
        env.goal = task_goal
        env.prompt_dim = 3

        info = {
            'max_ep_len': max_ep_len,
            # state size is taken from the recorded data, not from the env
            'state_dim': traj["observations"].shape[1],
            'act_dim': env.action_space.shape[0],
            'device': device,
            'prompt_dim': env.prompt_dim,
            'discrete_action': False,
        }
        infos.append(info)
        env_list.append(env)
    return infos, env_list

# load training dataset, and envs of some training/test tasks
def get_train_test_dataset_envs(dataset_path, device, max_ep_len=50, n_train_env=50, n_test_env=50, **kwargs):
    """Load a pickled trajectory dataset and build train/test evaluation envs.

    The loaded sequence is split by position:

    * the first ``n_test_env`` trajectories are held out as unseen test tasks;
    * everything after them is the offline training data;
    * the first ``n_train_env`` trajectories of the training data double as
      "seen" tasks for evaluation.

    Args:
        dataset_path: path of the pickle file containing the trajectories.
        device: stored in every info dict produced by ``get_env_list``.
        max_ep_len: stored in every info dict produced by ``get_env_list``.
        n_train_env: number of training trajectories to build envs for.
        n_test_env: number of leading trajectories held out for testing.
        **kwargs: forwarded unchanged to ``get_env_list``.

    Returns:
        ``(info, env_list, val_trajectories_list, test_info, test_env_list,
        test_trajectories_list, trajectories_list)``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load datasets that come from a trusted source.
    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)

    test_trajectories_list = trajectories[0:n_test_env]  # unseen tasks to eval
    trajectories_list = trajectories[n_test_env:]  # offline data for training
    val_trajectories_list = trajectories_list[0:n_train_env]  # seen tasks to eval

    info, env_list = get_env_list(val_trajectories_list, device, max_ep_len, **kwargs)
    test_info, test_env_list = get_env_list(test_trajectories_list, device, max_ep_len, **kwargs)
    return info, env_list, val_trajectories_list, test_info, test_env_list, test_trajectories_list, trajectories_list


# given a trajectory, sample a goal-sequence as prompt
def get_prompt(trajectory, max_prompt_length=5, prompt_length=None, device=None, use_optimal_prompt=False):
    """Sample a goal sequence from ``trajectory`` to use as a prompt.

    The prompt consists of ``prompt_length - 1`` randomly chosen intermediate
    entries (in increasing index order) followed by the last entry, which is
    the task goal. It is left-padded with the entry at index 0 up to
    ``max_prompt_length``; the mask is 0 on padding and 1 on real goals.

    Args:
        trajectory: mapping with ``'timesteps'`` and ``'next_observations'``
            (the first three components of each row are used), and optionally
            ``'optimal_prompts'`` whose rows are used as-is.
        max_prompt_length: number of prompt slots in the output.
        prompt_length: number of real goals; when ``None`` it is sampled
            uniformly from ``[1, max_prompt_length]``. It is clamped to the
            number of goals the trajectory can supply and to at least 1, so
            the task goal is always present.
        device: torch device for the returned tensors.
        use_optimal_prompt: draw goals from ``trajectory['optimal_prompts']``
            when that key exists.

    Returns:
        ``(prompt, mask)``: a float32 tensor of shape
        ``(1, max_prompt_length, 3)`` and a mask tensor of shape
        ``(1, max_prompt_length)``.

    Raises:
        ValueError: if the effective ``prompt_length`` exceeds
            ``max_prompt_length``.
    """
    if prompt_length is None:
        # sample a prompt length between [1, max_prompt_length] for training
        prompt_length = np.random.randint(1, max_prompt_length + 1)

    from_optimal = use_optimal_prompt and 'optimal_prompts' in trajectory
    if from_optimal:
        last_idx = trajectory['optimal_prompts'].shape[0] - 1
    else:
        last_idx = trajectory['timesteps'] - 1

    # The prompt cannot hold more goals than the trajectory offers, but it
    # always holds at least the final goal. Without the lower bound a
    # single-entry trajectory gave prompt_length == 0, which broke sampling
    # or produced a prompt one row longer than its mask.
    prompt_length = max(1, min(prompt_length, last_idx))
    if prompt_length > max_prompt_length:
        raise ValueError(
            f"prompt_length ({prompt_length}) exceeds max_prompt_length ({max_prompt_length})"
        )

    goal_timesteps = []
    if prompt_length > 1:
        # intermediate goals: indices strictly between the first and last entry
        goal_range = np.arange(1, last_idx)
        goal_timesteps = sorted(
            np.random.choice(goal_range, prompt_length - 1, replace=False).tolist()
        )
    goal_timesteps.append(last_idx)  # the last goal in the prompt is the task-goal

    # padding to the left, using the entry at index 0 as filler
    num_pad = max_prompt_length - prompt_length
    goal_timesteps = [0] * num_pad + goal_timesteps
    mask = np.concatenate([np.zeros((1, num_pad)), np.ones((1, prompt_length))], axis=1)

    if from_optimal:
        prompt = [trajectory['optimal_prompts'][t] for t in goal_timesteps]
    else:
        prompt = [trajectory['next_observations'][t][0:3] for t in goal_timesteps]
    prompt = np.array(prompt).reshape(1, -1, 3)

    prompt = torch.from_numpy(prompt).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(device=device)
    return prompt, mask


# compute the upperbound return for each task specified in the trajectory 
# (to compare with different methods). Input a list of trajectories, return a list 
# of return (float) of the task in each trajectory.

# def get_oracle_returns(trajectories, reward_step_penalty=-0.1):
#     ret = []
#     for t in trajectories:
#         l = len(t['optimal_prompts'])
#         ret.append(1+l*reward_step_penalty)
#     return ret


def get_avoid_prompt_old(trajectory, max_avoid_prompt_length=10, prompt_length=None, device=None, use_optimal_prompt=False):
    """Build a left-padded prompt of 3-component avoid states (legacy variant).

    When ``prompt_length`` is ``None`` it is derived from
    ``trajectory["avoid_states"]``: the number of rows (capped at
    ``max_avoid_prompt_length``) if that array is 2-D, otherwise 0.
    The first ``prompt_length`` rows are used; padding slots repeat row 0.
    With a prompt length of 0 every slot is a zero vector.

    ``use_optimal_prompt`` is accepted but not used.

    Returns:
        ``(prompt, mask)``: a float32 tensor reshaped to ``(1, -1, 3)`` and a
        mask of shape ``(1, max_avoid_prompt_length)`` that is 0 on padding
        and 1 on real entries.
    """
    if prompt_length is None:
        states_shape = trajectory["avoid_states"].shape
        if len(states_shape) == 2:
            prompt_length = min(states_shape[0], max_avoid_prompt_length)
        else:
            prompt_length = 0

    # left padding: padded slots point at row 0
    n_pad = max_avoid_prompt_length - prompt_length
    row_indices = [0] * n_pad + list(np.arange(prompt_length))
    padding_part = np.zeros((1, n_pad))
    valid_part = np.ones((1, prompt_length))
    mask = np.concatenate((padding_part, valid_part), axis=1)

    if prompt_length > 0:
        rows = [trajectory["avoid_states"][idx, :] for idx in row_indices]
    else:
        rows = [np.array([0, 0, 0]) for _ in row_indices]
    prompt_np = np.array(rows).reshape(1, -1, 3)

    prompt = torch.from_numpy(prompt_np).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(device=device)
    return prompt, mask

def get_avoid_prompt_old2(trajectory, max_avoid_prompt_length=10, prompt_length=None, device=None, use_optimal_prompt=False):
    """Build a prompt of randomly generated axis-aligned avoid boxes (legacy).

    One box is generated for every prompt slot, padding slots included. Each
    box is ``[x_min, y_min, z_min, x_max, y_max, z_max]``: a cube whose
    half-extent is drawn uniformly from ``[0, 0.1)`` and whose centre is
    ``start + (start - end) * u`` with ``u`` uniform in ``[0, 1)``, where
    ``start``/``end`` are the first three components of the first/last row of
    ``trajectory['observations']``. As written, the centre therefore lies on
    the ray leaving the start position in the direction away from the end.

    When ``prompt_length`` is ``None`` it is sampled uniformly from
    ``[1, max_avoid_prompt_length]``. ``use_optimal_prompt`` is accepted but
    not used.

    Returns:
        ``(prompt, mask)``: a float32 tensor reshaped to ``(1, -1, 6)`` and a
        mask of shape ``(1, max_avoid_prompt_length)`` that is 0 on padding
        and 1 on the last ``prompt_length`` slots.
    """
    if prompt_length is None:
        prompt_length = np.random.randint(1, max_avoid_prompt_length + 1)

    # left padding; the slot list only determines how many boxes are drawn
    n_pad = max_avoid_prompt_length - prompt_length
    slots = [0] * n_pad + list(np.arange(prompt_length))
    mask = np.concatenate((np.zeros((1, n_pad)), np.ones((1, prompt_length))), axis=1)

    obs = trajectory['observations']
    first_pos = obs[0][:3]
    last_pos = obs[-1][:3]

    boxes = []
    for _ in slots:
        # draw order matters for reproducibility: centre first, then size
        centre = first_pos + (first_pos - last_pos) * np.random.rand()
        half_extent = np.random.rand() * 0.1
        cx, cy, cz = centre[:3]
        lower_corner = [cx - half_extent, cy - half_extent, cz - half_extent]
        upper_corner = [cx + half_extent, cy + half_extent, cz + half_extent]
        boxes.append(np.array(lower_corner + upper_corner))

    prompt_np = np.array(boxes).reshape(1, -1, 6)

    prompt = torch.from_numpy(prompt_np).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(device=device)
    return prompt, mask

def get_avoid_prompt(trajectory, max_avoid_prompt_length=10, prompt_length=None, device=None, use_optimal_prompt=False):
    """Build a left-padded prompt of 6-component avoid states for a trajectory.

    When ``prompt_length`` is ``None`` it is derived from
    ``trajectory["avoid_states"]``: the number of rows (capped at
    ``max_avoid_prompt_length``) if that array is 2-D, otherwise 0.
    The first ``prompt_length`` rows are used; padding slots repeat row 0.
    With a prompt length of 0 every slot is a 6-component zero vector, so the
    prompt always has one row per mask column.

    Args:
        trajectory: mapping with ``"avoid_states"`` (array whose rows have six
            components when 2-D) and ``"success"``.
        max_avoid_prompt_length: number of prompt slots in the output.
        prompt_length: number of real avoid states to include, or ``None``.
        device: torch device for the returned tensors.
        use_optimal_prompt: accepted for signature compatibility; unused.

    Returns:
        ``(prompt, mask, success)``: a float32 tensor of shape
        ``(1, max_avoid_prompt_length, 6)``, a mask of shape
        ``(1, max_avoid_prompt_length)`` that is 0 on padding and 1 on real
        entries, and ``trajectory["success"]`` passed through unchanged.
    """
    avoid_dim = 6  # components per avoid state; must match the reshape below

    if prompt_length is None:
        avoid_shape = trajectory["avoid_states"].shape
        if len(avoid_shape) == 2:
            prompt_length = min(avoid_shape[0], max_avoid_prompt_length)
        else:
            prompt_length = 0

    # padding to the left; padded slots point at row 0
    num_pad = max_avoid_prompt_length - prompt_length
    goal_timesteps = [0] * num_pad + list(range(prompt_length))
    mask = np.concatenate([np.zeros((1, num_pad)), np.ones((1, prompt_length))], axis=1)

    if prompt_length > 0:
        prompt = [trajectory["avoid_states"][t, :] for t in goal_timesteps]
    else:
        # Fill with vectors of the full width. The previous 3-component filler
        # was reshaped to width 6, halving the number of rows.
        prompt = [np.zeros(avoid_dim) for _ in goal_timesteps]
    prompt = np.array(prompt).reshape(1, -1, avoid_dim)

    prompt = torch.from_numpy(prompt).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(device=device)
    success = trajectory["success"]
    return prompt, mask, success


if __name__=='__main__':
    # Nothing is executed when this module is run as a script: the only
    # statement in this block is the string literal below, which holds a
    # disabled dataset-inspection snippet.
    '''
    with open('envs/mazerunner/mazerunner-d15-g1-astar.pkl', 'rb') as f:
        trajectories = pickle.load(f)

    main_keys = ['observations', 'next_observations', 'actions', 'rewards', 'terminals']
    info_keys = ['timesteps', 'maze', 'goal_pos']
    print(len(trajectories))

    for k in main_keys:
        print(k, type(trajectories[0][k]), trajectories[0][k].dtype, trajectories[0][k].shape)

    for k in info_keys:
        print(k, trajectories[0][k])
    '''

    # import torch
    # device=torch.device('cuda')
    # info, env_list, val_trajectories_list, test_info, test_env_list, test_trajectories_list, trajectories_list = \
    #     get_train_test_dataset_envs('mazerunner-d30-g4-4-t500-multigoal-astar.pkl', device, max_ep_len=500)
    # #for i in range(100):
    # #    get_prompt(trajectories_list[0], device=device)
    # #variant={'K':50, 'batch_size': 16, 'max_prompt_len': 5}
    # #fn = get_prompt_batch(trajectories_list, info[0], variant)
    # #prompt, batch = fn()

    # for env, traj in zip(test_env_list, test_trajectories_list):
    #     env.reset()
    #     prompts = (traj['optimal_prompts']*env.maze_dim).astype(int)
    #     for p in prompts:
    #         env.pos = p            
    #         env.render()
    
