import numpy as np
import pickle
# from .mazerunner import MazeRunnerEnv
from copy import deepcopy
import os
import torch
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
# from PIL import Image


import os
# from gym import utils
# from gym.envs.robotics import rotations, robot_env, fetch_env
# import gym.envs.robotics.utils as robot_utils
from gymnasium import utils
from gymnasium_robotics.utils import rotations
from gymnasium_robotics.envs import robot_env
from gymnasium_robotics.envs.fetch import fetch_env
import gymnasium_robotics.utils.mujoco_py_utils as robot_utils
from raDT.constants import *


# Path of the XML model file handed to the environment constructor as `model_path`.
# Built with os.path.join so the path separator is correct on Windows.
# NOTE(review): ENV_PATH is not defined in this file; it presumably comes from the
# star import of raDT.constants above -- confirm there.
MODEL_XML_PATH = os.path.join(ENV_PATH, 'assets', 'fetch', 'slide_obstacle.xml')

def goal_distance(goal_a, goal_b):
    """Return the Euclidean distance between two goals.

    Both arguments must be arrays of identical shape; the norm is taken over
    the last axis, so batched inputs yield one distance per leading index.
    """
    assert goal_a.shape == goal_b.shape
    difference = goal_a - goal_b
    return np.linalg.norm(difference, axis=-1)

class FetchSlideObstacleEnv(fetch_env.MujocoPyFetchEnv, utils.EzPickle):
    """Fetch slide task with an extra obstacle body and a binary cost signal.

    On top of the parent Fetch environment this class:
      * loads the model at ``MODEL_XML_PATH`` and randomizes the ``obstacle0``
        joint on every reset,
      * appends obstacle position, rotation and relative positions to the
        observation vector,
      * returns an additional ``cost`` value from :meth:`step`.
    """

    def __init__(self, reward_type='sparse'):
        """Create the environment.

        Args:
            reward_type: forwarded unchanged to the parent constructor.
        """
        initial_qpos = {
            'robot0:slide0': 0.05,
            'robot0:slide1': 0.48,
            'robot0:slide2': 0.0,
            'object0:joint': [1.7, 1.1, 0.41, 1., 0., 0., 0.],
            'obstacle0:joint': [1.7, 1, 0.41, 1., 0., 0., 0.]
        }
        fetch_env.MujocoPyFetchEnv.__init__(
            self, model_path=MODEL_XML_PATH, has_object=True, block_gripper=True, n_substeps=20,
            gripper_extra_height=-0.02, target_in_the_air=False, target_offset=np.array([0.4, 0.0, 0.0]),
            obj_range=0.1, target_range=0.3, distance_threshold=0.05,
            initial_qpos=initial_qpos, reward_type=reward_type)
        # EzPickle rebuilds the object from the arguments recorded here, so the
        # constructor argument must be passed along or a pickled/copied env
        # silently falls back to the default reward type.
        utils.EzPickle.__init__(self, reward_type=reward_type)

        model = self.sim.model
        obstacle_geom_id = model.geom_name2id('obstacle0')
        object_geom_id = model.geom_name2id('object0')
        self.pos_obstacle = model.geom_pos[obstacle_geom_id]
        self.size_object = model.geom_size[object_geom_id]
        self.size_obstacle = model.geom_size[obstacle_geom_id]
        # Extra margin `k` passed to compute_cost() from step().
        self.cost_threshold = 0.05

    def _reset_sim(self):
        """Restore the initial simulator state and re-place object and obstacle.

        The object (when present) is rejection-sampled around the initial
        gripper xy position; the obstacle is then rejection-sampled so that it
        keeps clear of the gripper and of the object. Always returns True.
        """
        self.sim.set_state(self.initial_state)
        gripper_xy = self.initial_gripper_xpos[:2]

        # Randomize start position of object.
        object_xpos = None
        if self.has_object:
            object_xpos = gripper_xy
            # The first test is always true (distance 0), so at least one sample is drawn.
            while np.linalg.norm(object_xpos - gripper_xy) < 0.1:
                object_xpos = gripper_xy + self.np_random.uniform(-self.obj_range, self.obj_range, size=2)
            object_qpos = self.sim.data.get_joint_qpos('object0:joint')
            assert object_qpos.shape == (7,)
            object_qpos[:2] = object_xpos

            self.sim.data.set_joint_qpos('object0:joint', object_qpos)
            self.object_qpos = object_qpos

        # Randomize the obstacle. The object-related constraints (at least 0.1
        # away from the object, and at least 0.1 further along x than the
        # object) only apply when an object exists; the original evaluated the
        # x-offset test unconditionally and failed on an unbound `object_xpos`
        # when has_object was False.
        obstacle_xpos = gripper_xy
        while (
            np.linalg.norm(obstacle_xpos - gripper_xy) < 0.1
            or (
                self.has_object
                and (
                    np.linalg.norm(object_xpos - obstacle_xpos) < 0.1
                    or (obstacle_xpos[0] - object_xpos[0]) < 0.1
                )
            )
        ):
            obstacle_xpos = gripper_xy + self.np_random.uniform(-0.15, 0.2, size=2)
        obstacle_qpos = self.sim.data.get_joint_qpos('obstacle0:joint')
        assert obstacle_qpos.shape == (7,)
        obstacle_qpos[:2] = obstacle_xpos
        obstacle_qpos[2] = self.height_offset
        self.sim.data.set_joint_qpos('obstacle0:joint', obstacle_qpos)

        self.sim.forward()
        return True

    def compute_cost(self, obs, k, info):
        """Return 1 when the object is inside the zone around the obstacle, else 0.

        Args:
            obs: flat observation vector. Indices 3/4 are read as the object
                x/y and indices 25/26 as the obstacle x/y, which matches the
                layout built by ``_get_obs`` when ``has_object`` is True
                (assuming 3-vectors for positions/rotations/velocities and a
                2-element gripper state).
            k: extra margin added on every side of the zone.
            info: unused; kept for signature compatibility.

        Returns:
            A 0-d integer ``np.ndarray`` holding 0 or 1.
        """
        object_x, object_y = obs[3], obs[4]
        obstacle_x, obstacle_y = obs[25], obs[26]
        # The zone is centred 0.06 further along x than the obstacle position.
        # NOTE(review): MuJoCo geom sizes are commonly half-extents and they are
        # halved again here -- confirm that is intended.
        x_upper = obstacle_x + 0.06 + (self.size_obstacle[0]) / 2 + self.size_object[0] / 2 + k
        x_lower = obstacle_x + 0.06 - (self.size_obstacle[0]) / 2 - self.size_object[0] / 2 - k
        y_upper = obstacle_y + self.size_obstacle[1] / 2 + self.size_object[1] / 2 + k
        y_lower = obstacle_y - self.size_obstacle[1] / 2 - self.size_object[1] / 2 - k
        in_zone = (
            (object_x < x_upper) and (object_x > x_lower)
            and (object_y < y_upper) and (object_y > y_lower)
        )
        return np.array(1 if in_zone else 0)

    def step(self, action):
        """Advance the simulation by one step.

        Returns:
            ``(obs, reward, cost, done, info)``. Note this is a project-specific
            tuple (cost in third position), not the standard Gymnasium
            ``(obs, reward, terminated, truncated, info)``; ``done`` is always
            False here.
        """
        action = np.clip(action, self.action_space.low, self.action_space.high)
        self._set_action(action)
        self.sim.step()
        self._step_callback()
        obs = self._get_obs()

        done = False
        info = {
            'is_success': self._is_success(obs['achieved_goal'], self.goal),
        }
        reward = self.compute_reward(obs['achieved_goal'], self.goal, info)
        cost = self.compute_cost(obs['observation'], self.cost_threshold, info)
        return obs, reward, cost, done, info

    def _sample_goal(self):
        """Sample a goal position and return a copy of it.

        With an object, goals are rejection-sampled until the goal is at least
        0.1 from the initial gripper xy and from the object, at least 0.2 from
        the obstacle, and on the opposite side of the obstacle (along y) from
        the object. Without an object a single uniform sample around the
        initial gripper position is returned.
        """
        if self.has_object:
            object_xpos = self.sim.data.get_joint_qpos('object0:joint')[:2]
            obstacle_xpos = self.sim.data.get_joint_qpos('obstacle0:joint')[:2]
            obstacle_y = obstacle_xpos[1]
            # Sign of this offset tells on which y-side of the obstacle the object sits.
            object_side_y = object_xpos[1] - obstacle_y

            goal = self.initial_gripper_xpos[:3]
            # The first test is true for the initial value (distance 0), so the
            # loop body always runs and `goal` becomes a fresh array before it
            # is modified in place below.
            while (
                np.linalg.norm(goal[:2] - self.initial_gripper_xpos[:2]) < 0.1
                or np.linalg.norm(goal[:2] - object_xpos) < 0.1
                or np.linalg.norm(obstacle_xpos - goal[:2]) < 0.2
                or object_side_y * (goal[1] - obstacle_y) > 0
            ):
                goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
                goal += self.target_offset
            goal[2] = self.height_offset
            if self.target_in_the_air and self.np_random.uniform() < 0.5:
                goal[2] += self.np_random.uniform(0, 0.45)
        else:
            goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
        return goal.copy()

    def _get_obs(self):
        """Build the observation dict.

        Returns:
            A dict with ``'observation'`` (gripper/object features followed by
            obstacle position, obstacle Euler rotation, obstacle-minus-gripper
            and obstacle-minus-object offsets), ``'achieved_goal'`` (object
            position, or gripper position when there is no object) and
            ``'desired_goal'`` (a copy of ``self.goal``).
        """
        # positions
        grip_pos = self.sim.data.get_site_xpos('robot0:grip')
        dt = self.sim.nsubsteps * self.sim.model.opt.timestep
        grip_velp = self.sim.data.get_site_xvelp('robot0:grip') * dt
        robot_qpos, robot_qvel = robot_utils.robot_get_obs(self.sim)
        if self.has_object:
            object_pos = self.sim.data.get_site_xpos('object0')
            # rotations
            object_rot = rotations.mat2euler(self.sim.data.get_site_xmat('object0'))
            # velocities
            object_velp = self.sim.data.get_site_xvelp('object0') * dt
            object_velr = self.sim.data.get_site_xvelr('object0') * dt
            # object state relative to the gripper
            object_rel_pos = object_pos - grip_pos
            object_velp -= grip_velp
        else:
            object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0)
        gripper_state = robot_qpos[-2:]
        gripper_vel = robot_qvel[-2:] * dt  # change to a scalar if the gripper is made symmetric

        if not self.has_object:
            achieved_goal = grip_pos.copy()
        else:
            achieved_goal = np.squeeze(object_pos.copy())
        obs = np.concatenate([
            grip_pos, object_pos.ravel(), object_rel_pos.ravel(), gripper_state, object_rot.ravel(),
            object_velp.ravel(), object_velr.ravel(), grip_velp, gripper_vel,
        ])

        obstacle_pos = self.sim.data.get_site_xpos('obstacle0')
        obstacle_rot = rotations.mat2euler(self.sim.data.get_site_xmat('obstacle0'))
        # obstacle relative to the gripper
        obstacle_grip_rel_pos = obstacle_pos - grip_pos
        # obstacle relative to the object; without an object there is nothing to
        # subtract (object_pos is zero-length and cannot broadcast), so emit the
        # same empty placeholder used for the other object features.
        if self.has_object:
            obstacle_obj_rel_pos = obstacle_pos - object_pos
        else:
            obstacle_obj_rel_pos = np.zeros(0)
        obs = np.concatenate([
            obs, obstacle_pos.ravel(), obstacle_rot.ravel(),
            obstacle_grip_rel_pos.ravel(), obstacle_obj_rel_pos.ravel(),
        ])

        return {
            'observation': obs.copy(),
            'achieved_goal': achieved_goal.copy(),
            'desired_goal': self.goal.copy(),
        }

# make envs for eval with the tasks in trajs
def get_env_list(trajs, device, max_ep_len=50):
    """Create one evaluation environment and one info dict per trajectory.

    Each environment gets its ``goal`` set to the trajectory's ``'goal_pos'``
    and ``prompt_dim`` set to 3. The state dimension reported in the info dict
    is read from the trajectory's ``'observations'`` array, the action
    dimension from the environment's action space.

    Returns:
        ``(infos, env_list)`` -- two lists aligned with ``trajs``.
    """
    infos = []
    env_list = []
    for traj in trajs:
        task_goal = traj['goal_pos']
        env = FetchSlideObstacleEnv()
        env.goal = task_goal
        env.prompt_dim = 3
        infos.append({
            'max_ep_len': max_ep_len,
            'state_dim': traj["observations"].shape[1],
            'act_dim': env.action_space.shape[0],
            'device': device,
            'prompt_dim': env.prompt_dim,
            'discrete_action': False,
        })
        env_list.append(env)
    return infos, env_list

# load training dataset, and envs of some training/test tasks
def get_train_test_dataset_envs(dataset_path, device, max_ep_len=50, n_train_env=50, n_test_env=50, **kwargs):
    """Load a pickled trajectory list and build train/test evaluation environments.

    The first ``n_test_env`` trajectories are held out as unseen test tasks;
    everything after them is the offline training data, and the first
    ``n_train_env`` of those are additionally used as seen tasks for
    evaluation.

    Args:
        dataset_path: path of the pickle file containing the trajectory list.
        device: stored in every info dict (see ``get_env_list``).
        max_ep_len: stored in every info dict.
        n_train_env: number of training trajectories used for seen-task eval.
        n_test_env: number of leading trajectories held out for testing.
        **kwargs: accepted and ignored.

    Returns:
        ``(info, env_list, val_trajectories_list, test_info, test_env_list,
        test_trajectories_list, trajectories_list)``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load datasets from a trusted source.
    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)

    test_trajectories_list = trajectories[:n_test_env]  # unseen tasks to eval
    trajectories_list = trajectories[n_test_env:]  # offline data for training
    val_trajectories_list = trajectories_list[:n_train_env]  # seen tasks to eval

    info, env_list = get_env_list(val_trajectories_list, device, max_ep_len)
    test_info, test_env_list = get_env_list(test_trajectories_list, device, max_ep_len)
    return info, env_list, val_trajectories_list, test_info, test_env_list, test_trajectories_list, trajectories_list


# given a trajectory, sample a goal-sequence as prompt
def get_prompt(trajectory, max_prompt_length=5, prompt_length=None, device=None, use_optimal_prompt=False):
    """Sample a left-padded goal sequence from a trajectory to use as a prompt.

    The prompt always ends with the final step of the source sequence (the
    task goal); any earlier entries are distinct intermediate steps drawn at
    random and kept in temporal order. Unused leading slots are filled with
    step 0 and masked out.

    Args:
        trajectory: dict with ``'timesteps'`` (an integer step count) and
            ``'next_observations'`` (components 3:6 of each row are used), and
            optionally ``'optimal_prompts'`` (rows are used as-is).
        max_prompt_length: width of the returned prompt and mask.
        prompt_length: number of real entries; sampled uniformly from
            ``[1, max_prompt_length]`` when None. Values outside
            ``[1, max_prompt_length]`` or larger than the number of usable
            steps are clamped.
        device: torch device for the returned tensors.
        use_optimal_prompt: take entries from ``'optimal_prompts'`` when that
            key is present.

    Returns:
        ``(prompt, mask)``: a float32 tensor of shape
        ``(1, max_prompt_length, 3)`` and a mask tensor of shape
        ``(1, max_prompt_length)`` with ones on the real entries.

    Raises:
        ValueError: if ``max_prompt_length`` is smaller than 1.
    """
    if max_prompt_length < 1:
        raise ValueError('max_prompt_length must be at least 1')
    if prompt_length is None:
        # sample a prompt length between [1,max_prompt_length] for training
        prompt_length = np.random.randint(1, max_prompt_length + 1)
    # The mask is max_prompt_length wide, so the real part must fit inside it
    # and contain at least the final goal.
    prompt_length = max(1, min(prompt_length, max_prompt_length))

    from_optimal = use_optimal_prompt and 'optimal_prompts' in trajectory
    if from_optimal:
        last_index = trajectory['optimal_prompts'].shape[0] - 1
    else:
        last_index = trajectory['timesteps'] - 1

    # Index 0 is reserved for padding and the last index is always included, so
    # at most `last_index` distinct entries exist. Never drop below 1: a length
    # of 0 would leave the mask and the prompt with different widths (or ask
    # np.random.choice for a negative number of samples).
    prompt_length = max(1, min(prompt_length, last_index))

    goal_timesteps = []
    if prompt_length > 1:
        # sample (len-1) intermediate goals strictly between step 0 and the last step
        goal_range = np.arange(1, last_index)
        goal_timesteps = np.random.choice(goal_range, prompt_length - 1, replace=False).tolist()
        goal_timesteps.sort()
    goal_timesteps.append(last_index)  # the last goal in the prompt is the task-goal

    # padding to the left (with the initial position)
    n_pad = max_prompt_length - prompt_length
    goal_timesteps = [0] * n_pad + goal_timesteps
    mask = np.concatenate([np.zeros((1, n_pad)), np.ones((1, prompt_length))], axis=1)

    if from_optimal:
        prompt = [trajectory['optimal_prompts'][t] for t in goal_timesteps]
    else:
        prompt = [trajectory['next_observations'][t][3:6] for t in goal_timesteps]
    prompt = np.array(prompt).reshape(1, -1, 3)

    prompt = torch.from_numpy(prompt).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(device=device)
    return prompt, mask


# compute the upperbound return for each task specified in the trajectory 
# (to compare with different methods). Input a list of trajectories, return a list 
# of return (float) of the task in each trajectory.

# def get_oracle_returns(trajectories, reward_step_penalty=-0.1):
#     ret = []
#     for t in trajectories:
#         l = len(t['optimal_prompts'])
#         ret.append(1+l*reward_step_penalty)
#     return ret


# def get_avoid_prompt_old(trajectory, max_avoid_prompt_length=10, prompt_length=None, device=None, use_optimal_prompt=False):
#     if prompt_length is None:
#         # sample a prompt length between [1,max_prompt_length] for training
#         if len(trajectory["avoid_states"].shape) == 2:
#             prompt_length = min(trajectory["avoid_states"].shape[0], max_avoid_prompt_length)
#         else:
#             prompt_length = 0


#     # padding to the left
#     goal_timesteps = [0]*(max_avoid_prompt_length-prompt_length) + list(np.arange(prompt_length)) # pad with the initial position
#     mask = np.concatenate([np.zeros((1, max_avoid_prompt_length - prompt_length)), np.ones((1, prompt_length))], axis=1)
#     #print(goal_timesteps, mask)

#     prompt = []
#     if prompt_length > 0:
#         for t in goal_timesteps:
#             prompt.append(trajectory["avoid_states"][t,:])
#     else:
#         for t in goal_timesteps:
#             prompt.append(np.array([0, 0, 0]))
#     prompt = np.array(prompt).reshape(1,-1,3)

#     prompt = torch.from_numpy(prompt).to(dtype=torch.float32, device=device)
#     mask = torch.from_numpy(mask).to(device=device)
#     #print(prompt, goal_timesteps, mask)
#     return prompt, mask

# def get_avoid_prompt_old2(trajectory, max_avoid_prompt_length=10, prompt_length=None, device=None, use_optimal_prompt=False):
#     if prompt_length is None:
#         # sample a prompt length between [1,max_prompt_length] for training
#         prompt_length = np.random.randint(1,max_avoid_prompt_length+1)


#     # padding to the left
#     goal_timesteps = [0]*(max_avoid_prompt_length-prompt_length) + list(np.arange(prompt_length)) # pad with the initial position
#     mask = np.concatenate([np.zeros((1, max_avoid_prompt_length - prompt_length)), np.ones((1, prompt_length))], axis=1)
#     #print(goal_timesteps, mask)

#     observations = trajectory['observations']
#     start_pos = observations[0][3:6]
#     end_pos = observations[-1][3:6]

#     prompt = []
#     for t in goal_timesteps:
#         random_point_along_axis = start_pos + (start_pos - end_pos) * np.random.rand()
#         random_size = np.random.rand() * 0.1
#         center_x, center_y, center_z = random_point_along_axis[:3]
#         # corners = [[center_x + a, center_y + b, center_z + c] for a in displacements for b in displacements for c in displacements]
#         avoid_box = np.array([center_x - random_size, center_y - random_size, center_z - random_size, center_x + random_size, center_y + random_size, center_z + random_size])
#         prompt.append(avoid_box)
    
#     prompt = np.array(prompt).reshape(1,-1,6)

#     prompt = torch.from_numpy(prompt).to(dtype=torch.float32, device=device)
#     mask = torch.from_numpy(mask).to(device=device)
#     #print(prompt, goal_timesteps, mask)
#     return prompt, mask

def get_avoid_prompt(trajectory, max_avoid_prompt_length=10, prompt_length=None, device=None, use_optimal_prompt=False):
    """Build a left-padded prompt from the trajectory's ``'avoid_states'``.

    Args:
        trajectory: dict with ``'avoid_states'`` (rows of width 6 are used when
            the array is 2-D) and ``'success'``.
        max_avoid_prompt_length: width of the returned prompt and mask.
        prompt_length: number of leading ``avoid_states`` rows to use. When
            None it is the number of rows capped at
            ``max_avoid_prompt_length``, or 0 if ``avoid_states`` is not 2-D.
        device: torch device for the returned tensors.
        use_optimal_prompt: unused; kept for signature compatibility.

    Returns:
        ``(prompt, mask, success)``: a float32 tensor of shape
        ``(1, max_avoid_prompt_length, 6)``, a mask tensor of shape
        ``(1, max_avoid_prompt_length)`` with ones on the real entries, and
        ``trajectory['success']`` unchanged.
    """
    avoid_dim = 6  # width of one avoid-state row, as fixed by the reshape below

    if prompt_length is None:
        # use every available avoid state, capped at the maximum prompt length
        avoid_states = trajectory["avoid_states"]
        if len(avoid_states.shape) == 2:
            prompt_length = min(avoid_states.shape[0], max_avoid_prompt_length)
        else:
            prompt_length = 0

    # padding to the left
    n_pad = max_avoid_prompt_length - prompt_length
    mask = np.concatenate([np.zeros((1, n_pad)), np.ones((1, prompt_length))], axis=1)

    if prompt_length > 0:
        # pad with the first avoid state, then the real rows in order
        row_indices = [0] * n_pad + list(range(prompt_length))
        prompt = [trajectory["avoid_states"][t, :] for t in row_indices]
    else:
        # No avoid states: fill every slot with zeros of the full row width.
        # (3-wide zero rows here would be folded into half as many 6-wide rows
        # by the reshape and no longer line up with the mask.)
        prompt = [np.zeros(avoid_dim) for _ in range(n_pad)]
    prompt = np.array(prompt).reshape(1, -1, avoid_dim)

    prompt = torch.from_numpy(prompt).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(device=device)
    success = trajectory["success"]
    return prompt, mask, success


# Script entry point. The body is only an inert string literal (an old snippet
# that loaded a mazerunner pickle and printed the keys of its first
# trajectory), so running this module directly currently does nothing.
if __name__=='__main__':
    '''
    with open('envs/mazerunner/mazerunner-d15-g1-astar.pkl', 'rb') as f:
        trajectories = pickle.load(f)

    main_keys = ['observations', 'next_observations', 'actions', 'rewards', 'terminals']
    info_keys = ['timesteps', 'maze', 'goal_pos']
    print(len(trajectories))

    for k in main_keys:
        print(k, type(trajectories[0][k]), trajectories[0][k].dtype, trajectories[0][k].shape)

    for k in info_keys:
        print(k, trajectories[0][k])
    '''

    # import torch
    # device=torch.device('cuda')
    # info, env_list, val_trajectories_list, test_info, test_env_list, test_trajectories_list, trajectories_list = \
    #     get_train_test_dataset_envs('mazerunner-d30-g4-4-t500-multigoal-astar.pkl', device, max_ep_len=500)
    # #for i in range(100):
    # #    get_prompt(trajectories_list[0], device=device)
    # #variant={'K':50, 'batch_size': 16, 'max_prompt_len': 5}
    # #fn = get_prompt_batch(trajectories_list, info[0], variant)
    # #prompt, batch = fn()

    # for env, traj in zip(test_env_list, test_trajectories_list):
    #     env.reset()
    #     prompts = (traj['optimal_prompts']*env.maze_dim).astype(int)
    #     for p in prompts:
    #         env.pos = p            
    #         env.render()
    
