from collections import OrderedDict

import time
import os
import yaml
from pathlib import Path
from copy import deepcopy

import isaacgym

import numpy as np
import torch

import matplotlib.pyplot as plt
from moviepy.editor import ImageSequenceClip

from gym import spaces
from gym.core import GoalEnv

from stable_baselines3.common.vec_env import VecEnv

from isaac_panda_push_env import IsaacPandaPush
from utils import load_pretrained_rep_model, check_config, get_camera_ray, get_dlp_rep

class SB3VecEnvAdapter(VecEnv):
    """
    Thin ``VecEnv`` base whose interface methods are no-op stubs.

    Subclasses (e.g. the goal-conditioned wrapper below) provide the behaviour they
    actually need, such as ``reset``/``step``/``env_method``. The stub signatures follow
    the SB3 ``VecEnv`` interface, so SB3 utilities that call them with their standard
    arguments do not fail with a ``TypeError``.
    """

    def __init__(self, num_envs: int, observation_space: spaces.Space, action_space: spaces.Space):
        super().__init__(num_envs, observation_space, action_space)

    def step_async(self, actions):
        pass

    def step_wait(self):
        pass

    def get_attr(self, attr_name, indices=None):
        pass

    def set_attr(self, attr_name, value, indices=None):
        pass

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        pass

    def seed(self, seed=None):
        # VecEnv.seed takes an optional seed
        pass

    def env_is_wrapped(self, wrapper_class=None, indices=None):
        """
        Report that no sub-environment is wrapped.

        Args:
            wrapper_class: wrapper type being queried (ignored).
            indices: None (all envs), a single int, or an iterable of env indices.

        Returns:
            list[bool]: one ``False`` per selected env. SB3 helpers index the result,
            e.g. ``env.env_is_wrapped(Monitor)[0]``, so it must be a list.
        """
        if indices is None:
            n_selected = self.num_envs
        elif isinstance(indices, int):
            n_selected = 1
        else:
            n_selected = len(indices)
        return [False] * n_selected

    def render(self, mode="human"):
        pass


class IsaacPandaPushGoalSB3Wrapper(GoalEnv, SB3VecEnvAdapter):
    def __init__(self, env, obs_mode, n_views, latent_rep_model, reward_cfg, **kwargs):
        """
        Wrap a vectorised IsaacPandaPush env as a goal-conditioned SB3 VecEnv.

        Args:
            env: underlying vectorised env. It is read for ``device``, ``num_envs``,
                ``observation_space``, ``action_space``, ``act_space``, ``cube_size``,
                ``max_episode_length`` and ``reset()``.
            obs_mode (str): 'state', 'state_unstruct', 'dlp', '3d_slot' or '3d_block';
                any other value is handled as 'raw' (uint8 images).
            n_views (int): number of camera views taken from the env's image stack.
            latent_rep_model: optional torch module for the latent modes; moved to
                ``env.device`` when given.
            reward_cfg (dict): read with ``.get``: 'reward_scale' (default 1.0),
                'dist_threshold' (default sqrt(2) * env.cube_size), 'ori_threshold'
                (default 0.3), 'only_ori_reward' (default False).
            **kwargs: accepted and ignored.
        """
        self.env = env
        self.device = env.device

        super().__init__(env.num_envs, env.observation_space, env.action_space)

        # gym bookkeeping attributes
        self.name = "PandaPush"
        self.spec = None
        self.metadata = None

        self.obs_mode = obs_mode
        self.n_views = n_views

        cfg = reward_cfg
        self.reward_scale = cfg.get("reward_scale", 1.0)
        self.dist_threshold = cfg.get("dist_threshold", np.sqrt(2) * env.cube_size)
        self.ori_threshold = cfg.get("ori_threshold", 0.3)
        self.only_ori_reward = cfg.get("only_ori_reward", False)
        self.reward_range = [-1 * self.reward_scale, 0 * self.reward_scale]

        self.horizon = env.max_episode_length

        if latent_rep_model is None:
            self.latent_rep_model = None
        else:
            self.latent_rep_model = latent_rep_model.to(self.device)

        # Draw one observation to size the goal spaces. In the latent modes the state
        # observation is still extracted, although the "observation" key is disabled.
        sample_dict = env.reset()
        mode = self.obs_mode
        is_image_goal = False
        if mode == 'state':
            goal_sample = self._get_state_obs(sample_dict)
        elif mode == 'state_unstruct':
            goal_sample = self._get_state_unstruct_obs(sample_dict)
        elif mode == 'dlp':
            self._get_state_obs(sample_dict)
            goal_sample = self._get_latent_obs(sample_dict)
        elif mode in ('3d_slot', '3d_block'):
            self._get_state_obs(sample_dict)
            goal_sample = self._get_mv_latent_obs(sample_dict)
        else:  # "raw"
            goal_sample = self._get_image_obs(sample_dict)
            is_image_goal = True

        if is_image_goal:
            goal_low, goal_high, goal_dtype = 0, 255, np.uint8
        else:
            goal_low, goal_high, goal_dtype = -np.inf, np.inf, np.float32
        goal_shape = goal_sample.shape[1:]  # drop the env dimension

        def make_goal_box():
            return spaces.Box(low=goal_low, high=goal_high, shape=goal_shape, dtype=goal_dtype)

        # the "observation" entry is intentionally left out to accelerate code
        self.observation_space = spaces.Dict({
            "desired_goal": make_goal_box(),
            "achieved_goal": make_goal_box(),
        })

        # goal placeholders, filled in by reset()
        self.goal = None
        self.goal_pos = {}
        self.goal_image = None

        # policy controls only the first 3 action dims (vertical movements / closed gripper only)
        act_space = self.env.act_space
        self.action_space = spaces.Box(low=act_space.low[:3], high=act_space.high[:3])

    def reset(self):
        """
        Extends env reset method to return Goal Environment observation instead of normal OrderedDict.

        Returns:
            dict: GoalEnv observation after reset occurs
        """
        goal_obs_dict = self.get_random_goal()  # resets env

        obs_dict = self.env.reset()

        # extract observation and achieved goal
        if self.obs_mode == 'state':
            observation = self._get_state_obs(obs_dict)
            achieved_goal = observation
            self.goal = self._get_state_obs(goal_obs_dict)

        elif self.obs_mode == 'state_unstruct':
            observation = self._get_state_unstruct_obs(obs_dict)
            achieved_goal = observation
            self.goal = self._get_state_unstruct_obs(goal_obs_dict)

        elif self.obs_mode == 'dlp':
            observation = self._get_state_obs(obs_dict)
            achieved_goal = self._get_latent_obs(obs_dict)  # [n_views, *(latent_dims)]
            self.goal = self._get_latent_obs(goal_obs_dict)  # [n_views, *(latent_dims)]
                
        elif self.obs_mode in ['3d_slot', '3d_block']:
            observation = self._get_state_obs(obs_dict)
            achieved_goal = self._get_mv_latent_obs(obs_dict)  # [*(latent_dims)]
            self.goal = self._get_mv_latent_obs(goal_obs_dict)  # [*(latent_dims)]

        else:  # obs_mode == 'raw'
            observation = self._get_image_obs(obs_dict)  # [n_views, 3, h, w]
            achieved_goal = observation
            self.goal = self._get_image_obs(goal_obs_dict)  # [n_views, 3, h, w]

        # set goal info
        goal_observation = self._get_state_obs(goal_obs_dict)
        self.goal_pos = goal_observation[:, 1:, :-(self.num_objects+1)]
        self.goal_image = self._get_image_obs(goal_obs_dict)

        # create GoalEnv observation
        obs = {
            # "observation": observation,  # commented out to accelerate code
            "desired_goal": self.goal,
            "achieved_goal": achieved_goal
        }

        return obs

    def step(self, action):
        """
        Extends env step() function call to:
            - return goal environment observation instead of normal observation
            - compute reward based on goal and current state

        Args:
            action (torch.tensor): action to take in environment

        Returns:
            4-tuple:
                - observations based on obs_mode
                - reward from the environment
                - whether the current episode is completed or not
                - misc information
        """
        # modify action to fit env action space and allow vertical movements with closed gripper only
        action_xyz = torch.tensor(action, device=self.device, dtype=torch.float32)
        action_rest = torch.tensor([0, 0, 0, -1], device=self.device).unsqueeze(0).expand(self.num_envs, -1)
        action = torch.cat([action_xyz, action_rest], dim=-1)

        # take policy step
        obs_dict, _, episode_done, info = self.env.step(action)

        # extract observation
        if self.obs_mode == 'state':
            observation = self._get_state_obs(obs_dict)
            achieved_goal = observation

        elif self.obs_mode == 'state_unstruct':
            observation = self._get_state_unstruct_obs(obs_dict)
            achieved_goal = observation

        elif self.obs_mode == 'dlp':
            observation = self._get_state_obs(obs_dict)
            achieved_goal = self._get_latent_obs(obs_dict)

        elif self.obs_mode in ['3d_slot', '3d_block']:
            observation = self._get_state_obs(obs_dict)
            achieved_goal = self._get_mv_latent_obs(obs_dict)

        else:  # obs_mode == 'raw'
            observation = self._get_image_obs(obs_dict)
            achieved_goal = observation

        # create GoalEnv observation
        obs = {
            # "observation": observation,  # commented out to accelerate code
            "desired_goal": self.goal,
            "achieved_goal": achieved_goal
        }

        # save info
        vec_info = {
            "position": self._get_state_obs(obs_dict)[:, 1:, :-(self.num_objects+1)],
            "image": self._get_image_obs(obs_dict),
            "goal_pos": self.goal_pos,
            "goal_image": self.goal_image,
        }

        if self.obs_mode in ['3d_slot', '3d_block']:
            vec_info["ext"] = self._get_extra_obs("ext", obs_dict)
            vec_info["int"] = self._get_extra_obs("int", obs_dict)

        # get reward
        reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], vec_info)

        # add goal reaching info
        goal_frac_reached, avg_obj_dist, max_obj_dist, ori_dist = self.check_success(obs["achieved_goal"], obs["desired_goal"], vec_info)
        vec_info["goal_success_frac"] = goal_frac_reached
        vec_info["avg_obj_dist"] = avg_obj_dist
        vec_info["max_obj_dist"] = max_obj_dist

        # set done flag
        done = episode_done.cpu().numpy()  # shouldn't get done signal even if reached goal

        # add info for HerReplayBuffer use (ignoring done due to episode termination)
        vec_info["TimeLimit.truncated"] = episode_done.cpu().numpy()

        # convert info to tuple of dicts for SB3 compatibility
        info = tuple([{key: vec_info[key][i] for key in vec_info} for i in range(self.num_envs)])

        return obs, reward, done, info

    def compute_reward(self, achieved_goal, desired_goal, info={}):
        """
        Reward function for the goal conditioned task: negative distance from goal averaged over objects

        Args:
            achieved_goal: current state representation
            desired_goal:  goal state representation
            info: contains additional information

        Returns:
            goal conditioned reward
        """
        if type(info) == dict:
            a_goal, d_goal = info['position'].copy(), info['goal_pos'].copy()

        else:  # numpy array of dicts from HER replay buffer
            a_goal = np.array([info[i]['position'] for i in range(len(info))])
            d_goal = np.array([info[i]['goal_pos'] for i in range(len(info))])

        # normalize xy positions by table scale
        table_diag_len = (np.linalg.norm([(self.table_dims[0]) / 2, (self.table_dims[1]) / 2]))
        a_goal[..., :2] /= table_diag_len
        d_goal[..., :2] /= table_diag_len
        # calculate per object distance
        dist = np.linalg.norm(a_goal - d_goal, axis=-1)
        # calculate reward
        reward = -np.mean(dist, axis=-1)
        # rescale reward
        reward = reward * self.reward_scale

        return reward

    def check_success(self, achieved_goal, desired_goal, info={}):
        """
        Checks goal reaching success
        Args:
            achieved_goal: current state representation
            desired_goal:  goal state representation
            info: contains additional information

        Returns:
            fraction of goals reached
        """
        a_goal, d_goal = info['position'], info['goal_pos']

        ori_dist = None
        dist = np.linalg.norm(a_goal[..., :2] - d_goal[..., :2], axis=-1)
        obj_goal_reached = dist < self.dist_threshold

        goal_frac_reached = np.mean(obj_goal_reached, axis=-1)
        avg_obj_dist = np.mean(dist, axis=-1)
        max_obj_dist = np.max(dist, axis=-1)

        return goal_frac_reached, avg_obj_dist, max_obj_dist, ori_dist

    def get_random_goal(self):
        # pre reset
        self.env.goal_reset = True
        # reset
        self.env.reset()
        # post reset
        self.env.goal_reset = False
        # move arm one step back
        action = torch.tensor([-1, 0, 0, 0, 0, 0, -1], device=self.device).unsqueeze(0).expand(self.env.num_envs, -1)
        self.env.step(action)
        # get goal
        goal_obs_dict, _, _, _ = self.env.step(action)
        goal_obs_dict = deepcopy(goal_obs_dict)

        return goal_obs_dict

    def _get_state_obs(self, obs_dict):
        """
           Gets simulation state from environment, reshapes to add an entity dimension
           and concatenates 1-hot features to each entity (eef + objects)

           Args:
               obs_dict (OrderedDict): ordered dictionary of observations

           Returns:
               np.ndarray: [num_envs, num_entities, state_dim + 1-hot_identifier]
       """
        num_envs = self.num_envs
        num_entities = self.num_objects + 1

        obs = obs_dict["obs"].reshape(num_envs, num_entities, -1)
        one_hot_id = torch.eye(num_entities, device=self.device).unsqueeze(0).expand(num_envs, -1, -1)
        obs = torch.cat([obs[..., :2], one_hot_id], dim=-1)

        return obs.cpu().numpy().squeeze()

    def _get_state_unstruct_obs(self, obs_dict):
        """
           Gets simulation state from environment, concatenates states from all entities

           Args:
               obs_dict (OrderedDict): ordered dictionary of observations

           Returns:
               np.ndarray: [num_envs, num_entities * state_dim]
       """
        obs = obs_dict["obs"][..., :2].reshape(self.num_envs, -1)
        return obs.cpu().numpy().squeeze()

    def _get_image_obs(self, obs_dict):
        """
        Gets multiview image observations

        Args:
            obs_dict (OrderedDict): ordered dictionary of observations

        Returns:
            np.array: [num_envs, num_views, channels, height, width]
        """
        obs = obs_dict["media"][:, :self.n_views]
        return obs.cpu().numpy()

    def _get_extra_obs(self, extra, obs_dict):
        """
        Gets multiview image observations

        Args:
            obs_dict (OrderedDict): ordered dictionary of observations

        Returns:
            np.array: [num_envs, num_views, channels, height, width]
        """
        obs = obs_dict[extra][:, :self.n_views]
        return obs.cpu().numpy()

    def _get_latent_obs(self, obs_dict):
        """
        Gets multiview latent representations

        Args:
            obs_dict (OrderedDict): ordered dictionary of observations

        Returns:
            np.array: [num_envs, num_views, num_entities, feature_dim]
        """
        image_obs = obs_dict["media"][:, :self.n_views]
        obs = self._image_to_latent_rep(image_obs)
        return obs.cpu().numpy()

    def _get_mv_latent_obs(self, obs_dict):
        """
        Jointly encode all camera views of every env into one latent representation.

        For each env and view, the camera extrinsic is shifted by that env's grid offset.
        ``get_camera_ray`` then builds a camera position and per-pixel rays from it. Images,
        camera positions and rays are passed together to ``_image_to_latent_rep``.

        Args:
            obs_dict (OrderedDict): ordered dictionary of observations; must contain
                "media", "ext" and "int".

        Returns:
            np.array: [num_envs, num_entities, feature_dim]. In '3d_block' mode the first
            entity returned by the model is dropped.
        """
        # NOTE(review): these constants must match the env/camera setup — envs assumed laid
        # out on a grid with 4 envs per row and 3.0 spacing, and a 128 resolution passed for
        # both image-size arguments of get_camera_ray. Confirm against env creation.
        envs_per_row = 4
        env_spacing = 3.0
        ray_resolution = 128

        image_obs = obs_dict["media"][:, :self.n_views]
        ext_obs = obs_dict["ext"][:, :self.n_views]
        int_obs = obs_dict["int"][:, :self.n_views]
        device = image_obs.device

        ray_list = []
        cam_pos_list = []
        for b in range(self.num_envs):
            # grid offset of env b (same for every view, so computed once per env)
            row, col = divmod(b, envs_per_row)
            offset = np.array([env_spacing * col, env_spacing * row, 0.0])

            per_view_ray = []
            per_view_cam_pos = []
            for v in range(self.n_views):
                ext = ext_obs[b, v].cpu().numpy()    # [3, 4]
                intr = int_obs[b, v].cpu().numpy()   # [3, 3]

                R = ext[:3, :3]
                t = ext[:3, 3]
                cam_pos_world = -R.T @ t

                # re-express the camera translation relative to this env's offset
                cam_pos_local = cam_pos_world - offset
                ext_corrected = ext.copy()
                ext_corrected[:3, 3] = -R @ cam_pos_local

                cam_pos, ray = get_camera_ray(ext_corrected, intr, ray_resolution, ray_resolution)

                per_view_ray.append(torch.tensor(ray, device=device))
                per_view_cam_pos.append(torch.tensor(cam_pos, device=device))

            ray_list.append(torch.stack(per_view_ray, dim=0))          # [n_views, H*W, 3]
            cam_pos_list.append(torch.stack(per_view_cam_pos, dim=0))  # [n_views, 3]

        ray = torch.stack(ray_list, dim=0)
        cam_pos = torch.stack(cam_pos_list, dim=0)
        obs = self._image_to_latent_rep(image_obs, cam_pos, ray)

        if self.obs_mode == '3d_slot':
            return obs.cpu().numpy()

        return obs[..., 1:, :].cpu().numpy()

    def _image_to_latent_rep(self, image_obs, cam_pos=None, ray=None):
        orig_obs_shape = image_obs.shape

        if len(orig_obs_shape) == 4 and self.obs_mode != '3d_block' and self.obs_mode != '3d_slot':  # no batch dim
            image_obs = image_obs.unsqueeze(0)

        if self.obs_mode == 'dlp':
            latent_obs = [self._extract_dlp_features(image_obs[:, i]) for i in range(self.n_views)]
            latent_obs = torch.cat([view.unsqueeze(1) for view in latent_obs], dim=1)
        elif self.obs_mode == '3d_slot':
            latent_obs = [self._extract_slot_features(image_obs[:, i]) for i in range(self.n_views)]
        elif self.obs_mode == '3d_block':
            latent_obs = self._extract_slot_features(image_obs, cam_pos, ray)
        else:
            raise NotImplementedError

        if len(orig_obs_shape) == 4:  # no batch dim
            latent_obs = latent_obs.squeeze(0)
            
        return latent_obs

    def _extract_dlp_features(self, image):
        """
        Encode a batch of uint8 images with the DLP model (deterministically, no grad).

        Args:
            image (torch.Tensor): uint8 images; rescaled to [0, 1] floats before encoding.

        Returns:
            Features produced by ``get_dlp_rep`` from the model's ``encode_all`` output.
        """
        scaled = image.to(torch.float32) / 255

        with torch.no_grad():
            return get_dlp_rep(self.latent_rep_model.encode_all(scaled, deterministic=True))

    def _extract_slot_features(self, image, cam_pos, ray):
        normalized_image = image.to(torch.float32) / 255

        with torch.no_grad():
            slots = self.latent_rep_model(normalized_image, cam_pos, ray)

        return slots
        
    def _get_ori_aware_goal(self, goal):
        pos, ori = goal[..., :2], goal[..., 2:]
        pos_list = []
        # for i in range(self.num_objects):
        pos_list.append(pos)
        pos_list.append(np.concatenate([pos[..., 0:1] + 4 * 0.03 * np.cos(ori),
                                        pos[..., 1:] + 4 * 0.03 * np.sin(ori)], axis=-1))
        pos_list.append(np.concatenate([pos[..., 0:1] + 4 * 0.03 * np.cos(ori + np.pi/2),
                                        pos[..., 1:] + 4 * 0.03 * np.sin(ori + np.pi/2)], axis=-1))

        ori_aware_goal = np.concatenate(pos_list, axis=-2)
        return ori_aware_goal

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        if method_name == "compute_reward":
            return self.compute_reward(*method_args)
        else:
            raise NotImplementedError(f"Method {method_name} is not implemented in this env")

    @property
    def num_objects(self) -> int:
        """Get the (maximum) number of objects in the environment."""
        return self.env.num_objects

    @property
    def cur_num_objects(self) -> int:
        """Get the current number of objects in the environment."""
        return self.env.cur_num_objects

    @property
    def num_colors(self) -> int:
        """Get the (maximum) number of objects in the environment."""
        return self.env.num_colors

    @property
    def max_episode_len(self) -> int:
        """Get the number of objects in the environment."""
        return self.env.max_episode_length

    @property
    def table_dims(self) -> list:
        """Table dimensions"""
        return self.env.table_dims