# Copyright (c) 2018-2021, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#	list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#	this list of conditions and the following disclaimer in the documentation
#	and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#	contributors may be used to endorse or promote products derived from
#	this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Dict, Any, Tuple

import gym
from gym import spaces

from isaacgym import gymtorch, gymapi
from isaacgym.torch_utils import to_torch
from isaacgym.gymutil import get_property_setter_map, get_property_getter_map, get_default_setter_args, apply_random_samples, check_buckets, generate_random_samples

import torch
import numpy as np
import operator, random
from copy import deepcopy

import sys

import abc
from abc import ABC



class Env(ABC):
	"""Base environment: resolves devices, rendering and env-count settings from a config dict."""

	def __init__(self, config: Dict[str, Any], sim_device: str, graphics_device_id: int,  headless: bool):
		"""Initialise the env.

		Args:
			config: the configuration dictionary. Reads ``config["sim"]["use_gpu_pipeline"]``,
				``config["env"]["numEnvs"]`` and the optional keys ``"rl_device"`` (default
				'cuda:0'), ``"enableCameraSensors"`` (default False) and
				``config["env"]["controlFrequencyInv"]`` (default 1).
				NOTE: ``config["sim"]["use_gpu_pipeline"]`` is set to False in place when a
				GPU pipeline is requested with a non-GPU ``sim_device``.
			sim_device: the device to simulate physics on, e.g. 'cuda:0' or 'cpu'. Without a
				":<id>" suffix the device ID defaults to 0.
			graphics_device_id: the device ID to render with.
			headless: if True, run without a viewer. If camera sensors are also disabled,
				``graphics_device_id`` is set to -1.
		"""

		self.cfg = config
		split_device = sim_device.split(":")
		self.device_type = split_device[0]
		self.device_id = int(split_device[1]) if len(split_device) > 1 else 0

		# The pipeline device stays on the CPU unless the GPU pipeline is requested on a
		# CUDA/GPU sim device; any other combination falls back to the CPU pipeline.
		self.device = "cpu"
		if config["sim"]["use_gpu_pipeline"]:
			if self.device_type.lower() in ("cuda", "gpu"):
				self.device = f"cuda:{self.device_id}"
			else:
				print("GPU Pipeline can only be used with GPU simulation. Forcing CPU Pipeline.")
				config["sim"]["use_gpu_pipeline"] = False

		self.rl_device = config.get("rl_device", "cuda:0")

		# Rendering
		# if training in a headless mode
		self.headless = headless

		# With no viewer and no camera sensors there is nothing to render; -1 is presumably
		# Isaac Gym's "no graphics device" value — confirm against the gym docs.
		enable_camera_sensors = config.get("enableCameraSensors", False)
		self.graphics_device_id = graphics_device_id
		if not enable_camera_sensors and self.headless:
			self.graphics_device_id = -1

		self.num_environments = config["env"]["numEnvs"]
		self.control_freq_inv = config["env"].get("controlFrequencyInv", 1)

class VecTask(Env):

	def __init__(self, config, sim_device, graphics_device_id, headless):
		"""Set up the vectorised task on top of `Env`.

		Args:
			config: environment configuration dictionary; must provide the
				"physics_engine" name and the "sim" section.
			sim_device: physics simulation device, e.g. 'cuda:0' or 'cpu'.
			graphics_device_id: ID of the device used for rendering.
			headless: if True, no viewer is rendered.

		Raises:
			ValueError: if ``config["physics_engine"]`` is neither "physx" nor "flex".
		"""
		super().__init__(config, sim_device, graphics_device_id, headless)

		engine_name = self.cfg["physics_engine"]
		self.sim_params = self.__parse_sim_params(engine_name, self.cfg["sim"])

		# Map the config name onto the gymapi engine enum.
		if engine_name == "physx":
			self.physics_engine = gymapi.SIM_PHYSX
		elif engine_name == "flex":
			self.physics_engine = gymapi.SIM_FLEX
		else:
			raise ValueError(f"Invalid physics engine backend: {engine_name}")

		# optimization flags for pytorch JIT: turn off the profiling executor/mode
		torch._C._jit_set_profiling_mode(False)
		torch._C._jit_set_profiling_executor(False)

		self.gym = gymapi.acquire_gym()

		# Bookkeeping consumed by apply_randomizations().
		self.first_randomization = True
		self.original_props = {}
		self.dr_randomizations = {}
		self.actor_params_generator = None
		self.extern_actor_params = {}
		# Frame counters; -1 means "no step/randomization recorded yet".
		self.last_step = -1
		self.last_rand_step = -1

	def set_viewer(self):
		"""Create the viewer (unless headless) and register its keyboard shortcuts.

		Always resets ``self.enable_viewer_sync`` to True and ``self.viewer`` to None. When
		not headless, creates a viewer for ``self.sim``, subscribes ESC to the "QUIT"
		action and V to "toggle_viewer_sync" (both handled in `render`), and points the
		camera at the origin from above the ground plane for the sim's up axis.
		"""

		# todo: read from config
		self.enable_viewer_sync = True
		self.viewer = None

		if self.headless:
			return

		self.viewer = self.gym.create_viewer(self.sim, gymapi.CameraProperties())

		# subscribe to keyboard shortcuts
		self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_ESCAPE, "QUIT")
		self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_V, "toggle_viewer_sync")

		# Set the camera position based on the up axis. The "height" component must lie on
		# the up axis, so the Y-up position is the Z-up one with y and z swapped (otherwise
		# a Y-up camera would sit at y = -3, below the ground).
		sim_params = self.gym.get_sim_params(self.sim)
		if sim_params.up_axis == gymapi.UP_AXIS_Z:
			cam_pos = gymapi.Vec3(-3.0, -3.0, 2.0)
		else:
			cam_pos = gymapi.Vec3(-3.0, 2.0, -3.0)
		cam_target = gymapi.Vec3(0.0, 0.0, 0.0)

		self.gym.viewer_camera_look_at(self.viewer, None, cam_pos, cam_target)

	def allocate_buffers(self):
		"""Hook for allocating the observation, state, etc. buffers.

		Environment classes that inherit from this one fill these buffers, and `step`
		and related functions read them. This base implementation allocates nothing.
		"""
		return None

	#
	def set_sim_params_up_axis(self, sim_params: gymapi.SimParams, axis: str) -> int:
		"""Configure ``sim_params`` for the given up axis and return that axis' index.

		Only ``'z'`` modifies ``sim_params``: the up axis becomes Z and gravity is set to
		(0, 0, -9.81). Any other value leaves ``sim_params`` untouched.

		Args:
			sim_params: sim params object, updated in place.
			axis: name of the up axis.
		Returns:
			2 when ``axis == 'z'``, otherwise 1.
		"""
		if axis != 'z':
			return 1

		sim_params.up_axis = gymapi.UP_AXIS_Z
		sim_params.gravity.x = 0
		sim_params.gravity.y = 0
		sim_params.gravity.z = -9.81
		return 2

	def create_sim(self, compute_device: int, graphics_device: int, physics_engine, sim_params: gymapi.SimParams):
		"""Create an Isaac Gym sim object.

		Args:
			compute_device: ID of compute device to use.
			graphics_device: ID of graphics device to use.
			physics_engine: physics engine to use (`gymapi.SIM_PHYSX` or `gymapi.SIM_FLEX`)
			sim_params: sim params to use.
		Returns:
			the Isaac Gym sim object.
		Raises:
			SystemExit: with exit status 1 when ``gym.create_sim`` returns None.
		"""
		sim = self.gym.create_sim(compute_device, graphics_device, physics_engine, sim_params)
		if sim is None:
			print("*** Failed to create sim")
			# Use sys.exit rather than the `site`-provided quit() (meant for the REPL and
			# missing under `python -S`), and report failure with a non-zero status.
			sys.exit(1)

		return sim

	def get_state(self):
		"""Return the privileged state buffer (for asymmetric training).

		Values are clipped to ``[-self.clip_obs, self.clip_obs]`` and moved to ``self.rl_device``.
		"""
		bound = self.clip_obs
		clipped_states = torch.clamp(self.states_buf, min=-bound, max=bound)
		return clipped_states.to(self.rl_device)

	def render(self):
		"""Render one viewer frame and handle viewer keyboard actions.

		Does nothing when ``self.viewer`` is falsy (it stays None in headless runs).
		Calls ``sys.exit()`` when the viewer window has been closed or a "QUIT" action
		arrives; a "toggle_viewer_sync" action flips ``self.enable_viewer_sync``. With sync
		enabled the graphics are stepped, drawn and throttled to real time; otherwise the
		viewer events are only polled.
		"""
		if not self.viewer:
			return

		# The user closed the viewer window.
		if self.gym.query_viewer_has_closed(self.viewer):
			sys.exit()

		# Only actions with a positive value are acted upon (presumably key-down events).
		for event in self.gym.query_viewer_action_events(self.viewer):
			is_pressed = event.value > 0
			if event.action == "QUIT" and is_pressed:
				sys.exit()
			elif event.action == "toggle_viewer_sync" and is_pressed:
				self.enable_viewer_sync = not self.enable_viewer_sync

		# Results are fetched explicitly whenever the pipeline device is not the CPU.
		if self.device != 'cpu':
			self.gym.fetch_results(self.sim, True)

		if not self.enable_viewer_sync:
			self.gym.poll_viewer_events(self.viewer)
			return

		self.gym.step_graphics(self.sim)
		self.gym.draw_viewer(self.viewer, self.sim, True)
		# Block until dt has elapsed in real time, keeping the physics in step with the rendering rate.
		self.gym.sync_frame_time(self.sim)

	def __parse_sim_params(self, physics_engine: str, config_sim: Dict[str, Any]) -> gymapi.SimParams:
		"""Build a `gymapi.SimParams` from the "sim" section of the config.

		Args:
			physics_engine: which physics engine to use, "physx" or "flex". It selects
				which engine-specific sub-dict of ``config_sim`` is applied; any value
				other than "physx" is treated as flex.
			config_sim: dict of sim configuration parameters ("up_axis", "dt",
				"use_gpu_pipeline" and "gravity" are required).
		Returns:
			IsaacGym SimParams object with the configured settings.
		Raises:
			ValueError: if ``config_sim["up_axis"]`` is neither "z" nor "y".
		"""
		up_axis = config_sim["up_axis"]
		if up_axis not in ["z", "y"]:
			msg = f"Invalid physics up-axis: {up_axis}"
			print(msg)
			raise ValueError(msg)

		sim_params = gymapi.SimParams()

		# General stepping parameters (optional keys fall back to their defaults).
		sim_params.dt = config_sim["dt"]
		sim_params.num_client_threads = config_sim.get("num_client_threads", 0)
		sim_params.use_gpu_pipeline = config_sim["use_gpu_pipeline"]
		sim_params.substeps = config_sim.get("substeps", 2)

		sim_params.up_axis = gymapi.UP_AXIS_Z if up_axis == "z" else gymapi.UP_AXIS_Y
		sim_params.gravity = gymapi.Vec3(*config_sim["gravity"])

		# Engine-specific options are copied verbatim onto the matching sub-struct;
		# contact_collection is the one option that is first wrapped in its enum type.
		if physics_engine == "physx":
			for opt, value in config_sim.get("physx", {}).items():
				if opt == "contact_collection":
					value = gymapi.ContactCollection(value)
				setattr(sim_params.physx, opt, value)
		else:
			for opt, value in config_sim.get("flex", {}).items():
				setattr(sim_params.flex, opt, value)

		return sim_params

	"""
	Domain Randomization methods
	"""

	def get_actor_params_info(self, dr_params: Dict[str, Any], env):
		"""Flatten the randomizable actor properties of ``env`` into parallel lists.

		For each actor in ``dr_params["actor_params"]``, each property (except "color",
		which is randomized without a range) and each randomized attribute contribute one
		entry per scalar value. Entries are named ``<prop>_<prop_idx>_<attr>``, with an
		extra ``_<attr_idx>`` suffix when the property is a numpy array.

		Args:
			dr_params: domain-randomization config.
			env: env handle whose actors are queried.
		Returns:
			None if ``dr_params`` has no "actor_params"; otherwise the tuple
			``(params, names, lows, highs)`` of equally long lists. Bounds are the
			configured 'range' when the distribution name contains 'uniform', and
			(-inf, inf) otherwise.
		"""
		if "actor_params" not in dr_params:
			return None

		params, names, lows, highs = [], [], [], []
		getters = get_property_getter_map(self.gym)

		def _record(value, label, bounds):
			# Append one flattened entry to all four parallel lists.
			params.append(value)
			names.append(label)
			lows.append(bounds[0])
			highs.append(bounds[1])

		for actor, actor_properties in dr_params["actor_params"].items():
			handle = self.gym.find_actor_handle(env, actor)
			for prop_name, prop_attrs in actor_properties.items():
				if prop_name == 'color':
					continue  # colors are sampled at random, so there is no range to report
				props = getters[prop_name](env, handle)
				if not isinstance(props, list):
					props = [props]
				for prop_idx, prop in enumerate(props):
					for attr, attr_randomization_params in prop_attrs.items():
						base_name = f"{prop_name}_{prop_idx}_{attr}"
						bounds = attr_randomization_params['range']
						if 'uniform' not in attr_randomization_params['distribution']:
							bounds = (-float('inf'), float('inf'))
						if isinstance(prop, np.ndarray):
							column = prop[attr]
							for attr_idx in range(column.shape[0]):
								_record(column[attr_idx], f"{base_name}_{attr_idx}", bounds)
						else:
							_record(getattr(prop, attr), base_name, bounds)
		return params, names, lows, highs

	def apply_randomizations(self, dr_params):
		"""Apply domain randomizations to the environment.

		Note that randomizations can currently only be applied on resets, due to current
		PhysX limitations.

		What is randomized on a given call:
		  - on the first call: everything, for every env;
		  - afterwards, the non-environment parameters ("observations", "actions",
		    "sim_params") only once at least ``dr_params["frequency"]`` frames (default 1)
		    have passed since they were last randomized, and actor properties only for
		    envs that are set in ``self.reset_buf`` and whose ``self.randomize_buf``
		    counter has reached that frequency (their counter is then reset to 0).

		"observations"/"actions" are not applied to any tensor here: their distribution
		parameters and a ``noise_lambda(tensor)`` callable are stored in
		``self.dr_randomizations[name]`` for later use.

		Args:
			dr_params: parameters for domain randomization to use.

		Raises:
			ValueError: if a 'scale' randomization uses an operation other than
				'scaling' or 'additive'.
			Exception: if an external actor-params sample was not fully consumed.
		"""

		# If we don't have a randomization frequency, randomize every step
		rand_freq = dr_params.get("frequency", 1)

		# First, determine what to randomize:
		#   - non-environment parameters when >= frequency steps have passed since the last non-environment randomization
		#   - physical environments in the reset buffer, which have exceeded the randomization frequency threshold
		#   - on the first call, randomize everything
		self.last_step = self.gym.get_frame_count(self.sim)
		if self.first_randomization:
			do_nonenv_randomize = True
			# NOTE(review): `num_envs` is not assigned in this file (Env sets `num_environments`);
			# presumably a subclass provides it — confirm.
			env_ids = list(range(self.num_envs))
		else:
			do_nonenv_randomize = (self.last_step - self.last_rand_step) >= rand_freq
			rand_envs = torch.where(self.randomize_buf >= rand_freq, torch.ones_like(self.randomize_buf), torch.zeros_like(self.randomize_buf))
			rand_envs = torch.logical_and(rand_envs, self.reset_buf)
			env_ids = torch.nonzero(rand_envs, as_tuple=False).squeeze(-1).tolist()
			self.randomize_buf[rand_envs] = 0

		if do_nonenv_randomize:
			self.last_rand_step = self.last_step

		param_setters_map = get_property_setter_map(self.gym)
		param_setter_defaults_map = get_default_setter_args(self.gym)
		param_getters_map = get_property_getter_map(self.gym)

		# On first iteration, check the number of buckets
		if self.first_randomization:
			check_buckets(self.gym, self.envs, dr_params)

		for nonphysical_param in ["observations", "actions"]:
			if nonphysical_param in dr_params and do_nonenv_randomize:
				dist = dr_params[nonphysical_param]["distribution"]
				op_type = dr_params[nonphysical_param]["operation"]
				sched_type = dr_params[nonphysical_param]["schedule"] if "schedule" in dr_params[nonphysical_param] else None
				sched_step = dr_params[nonphysical_param]["schedule_steps"] if "schedule" in dr_params[nonphysical_param] else None
				op = operator.add if op_type == 'additive' else operator.mul

				# Fraction of the configured noise to apply at the current frame.
				if sched_type == 'linear':
					sched_scaling = 1.0 / sched_step * min(self.last_step, sched_step)
				elif sched_type == 'constant':
					sched_scaling = 0 if self.last_step < sched_step else 1
				else:
					sched_scaling = 1

				if dist == 'gaussian':
					mu, var = dr_params[nonphysical_param]["range"]
					mu_corr, var_corr = dr_params[nonphysical_param].get("range_correlated", [0., 0.])

					if op_type == 'additive':
						mu *= sched_scaling
						var *= sched_scaling
						mu_corr *= sched_scaling
						var_corr *= sched_scaling
					elif op_type == 'scaling':
						var = var * sched_scaling  # scale up var over time
						mu = mu * sched_scaling + 1.0 * (1.0 - sched_scaling)  # linearly interpolate

						var_corr = var_corr * sched_scaling  # scale up var over time
						mu_corr = mu_corr * sched_scaling + 1.0 * (1.0 - sched_scaling)  # linearly interpolate

					# `param_name` and `op` are bound as defaults on purpose: a plain closure over
					# `op` would see its value from the LAST loop iteration, so "observations"
					# noise would silently use the "actions" operation.
					# Note 'var' multiplies unit-normal noise, i.e. it acts as a standard deviation.
					def noise_lambda(tensor, param_name=nonphysical_param, op=op):
						params = self.dr_randomizations[param_name]
						# The correlated sample is cached in this entry; a fresh entry (and thus
						# a fresh sample) is created on the next non-environment randomization.
						corr = params.get('corr', None)
						if corr is None:
							corr = torch.randn_like(tensor)
							params['corr'] = corr
						corr = corr * params['var_corr'] + params['mu_corr']
						return op(tensor, corr + torch.randn_like(tensor) * params['var'] + params['mu'])

					self.dr_randomizations[nonphysical_param] = {'mu': mu, 'var': var, 'mu_corr': mu_corr, 'var_corr': var_corr, 'noise_lambda': noise_lambda}

				elif dist == 'uniform':
					lo, hi = dr_params[nonphysical_param]["range"]
					lo_corr, hi_corr = dr_params[nonphysical_param].get("range_correlated", [0., 0.])

					if op_type == 'additive':
						lo *= sched_scaling
						hi *= sched_scaling
						lo_corr *= sched_scaling
						hi_corr *= sched_scaling
					elif op_type == 'scaling':
						lo = lo * sched_scaling + 1.0 * (1.0 - sched_scaling)
						hi = hi * sched_scaling + 1.0 * (1.0 - sched_scaling)
						lo_corr = lo_corr * sched_scaling + 1.0 * (1.0 - sched_scaling)
						hi_corr = hi_corr * sched_scaling + 1.0 * (1.0 - sched_scaling)

					# Defaults bind the current `param_name`/`op` (see the gaussian branch).
					def noise_lambda(tensor, param_name=nonphysical_param, op=op):
						params = self.dr_randomizations[param_name]
						corr = params.get('corr', None)
						if corr is None:
							# Uniform [0, 1) so the affine map below lands in [lo_corr, hi_corr).
							corr = torch.rand_like(tensor)
							params['corr'] = corr
						corr = corr * (params['hi_corr'] - params['lo_corr']) + params['lo_corr']
						return op(tensor, corr + torch.rand_like(tensor) * (params['hi'] - params['lo']) + params['lo'])

					self.dr_randomizations[nonphysical_param] = {'lo': lo, 'hi': hi, 'lo_corr': lo_corr, 'hi_corr': hi_corr, 'noise_lambda': noise_lambda}

		if "sim_params" in dr_params and do_nonenv_randomize:
			prop_attrs = dr_params["sim_params"]
			prop = self.gym.get_sim_params(self.sim)

			# Snapshot the unrandomized values once, so later samples start from them.
			if self.first_randomization:
				self.original_props["sim_params"] = {
					attr: getattr(prop, attr) for attr in dir(prop)}

			for attr, attr_randomization_params in prop_attrs.items():
				apply_random_samples(
					prop, self.original_props["sim_params"], attr, attr_randomization_params, self.last_step)

			self.gym.set_sim_params(self.sim, prop)

		# If self.actor_params_generator is initialized: use it to
		# sample actor simulation params. This gives users the
		# freedom to generate samples from arbitrary distributions,
		# e.g. use full-covariance distributions instead of the DR's
		# default of treating each simulation parameter independently.
		extern_offsets = {}
		if self.actor_params_generator is not None:
			for env_id in env_ids:
				self.extern_actor_params[env_id] = self.actor_params_generator.sample()
				extern_offsets[env_id] = 0

		for actor, actor_properties in dr_params.get("actor_params", {}).items():
			for env_id in env_ids:
				env = self.envs[env_id]
				handle = self.gym.find_actor_handle(env, actor)
				# Only populated when an actor_params_generator is set; None otherwise
				# (indexing directly raised KeyError in the default, generator-less case).
				extern_sample = self.extern_actor_params.get(env_id)

				for prop_name, prop_attrs in actor_properties.items():
					if prop_name == 'color':
						num_bodies = self.gym.get_actor_rigid_body_count(env, handle)
						for n in range(num_bodies):
							self.gym.set_rigid_body_color(
								env, handle, n, gymapi.MESH_VISUAL,
								gymapi.Vec3(random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)))
						continue
					# NOTE(review): `self.sim_initialized` is not assigned in this file; presumably
					# a subclass sets it around sim creation — confirm.
					if prop_name == 'scale':
						setup_only = prop_attrs.get('setup_only', False)
						if (setup_only and not self.sim_initialized) or not setup_only:
							attr_randomization_params = prop_attrs
							sample = generate_random_samples(attr_randomization_params, 1, self.last_step, None)
							og_scale = 1
							scale_op = attr_randomization_params['operation']
							if scale_op == 'scaling':
								new_scale = og_scale * sample
							elif scale_op == 'additive':
								new_scale = og_scale + sample
							else:
								raise ValueError(f"Unsupported 'scale' randomization operation: {scale_op}")
							self.gym.set_actor_scale(env, handle, new_scale)
						continue

					# NOTE(review): original_props is keyed by prop_name only (not by actor/env),
					# so actors sharing a property name share one snapshot — confirm intended.
					# NOTE(review): `get_attr_val_from_sample` is neither defined nor imported in
					# this file; it is only reached when actor_params_generator is set — confirm
					# it is in scope, otherwise that path raises NameError.
					prop = param_getters_map[prop_name](env, handle)
					set_random_properties = True
					if isinstance(prop, list):
						if self.first_randomization:
							self.original_props[prop_name] = [
								{attr: getattr(p, attr) for attr in dir(p)} for p in prop]
						for p, og_p in zip(prop, self.original_props[prop_name]):
							for attr, attr_randomization_params in prop_attrs.items():
								setup_only = attr_randomization_params.get('setup_only', False)
								if (setup_only and not self.sim_initialized) or not setup_only:
									smpl = None
									if self.actor_params_generator is not None:
										smpl, extern_offsets[env_id] = get_attr_val_from_sample(
											extern_sample, extern_offsets[env_id], p, attr)
									apply_random_samples(
										p, og_p, attr, attr_randomization_params,
										self.last_step, smpl)
								else:
									set_random_properties = False
					else:
						if self.first_randomization:
							self.original_props[prop_name] = deepcopy(prop)
						for attr, attr_randomization_params in prop_attrs.items():
							setup_only = attr_randomization_params.get('setup_only', False)
							if (setup_only and not self.sim_initialized) or not setup_only:
								smpl = None
								if self.actor_params_generator is not None:
									smpl, extern_offsets[env_id] = get_attr_val_from_sample(
										extern_sample, extern_offsets[env_id], prop, attr)
								apply_random_samples(
									prop, self.original_props[prop_name], attr,
									attr_randomization_params, self.last_step, smpl)
							else:
								set_random_properties = False

					# Skip the write-back if any setup_only attribute was left untouched.
					if set_random_properties:
						setter = param_setters_map[prop_name]
						default_args = param_setter_defaults_map[prop_name]
						setter(env, handle, prop, *default_args)

		if self.actor_params_generator is not None:
			for env_id in env_ids:  # check that we used all dims in sample
				if extern_offsets[env_id] > 0:
					extern_sample = self.extern_actor_params[env_id]
					if extern_offsets[env_id] != extern_sample.shape[0]:
						print('env_id', env_id,
							'extern_offset', extern_offsets[env_id],
							'vs extern_sample.shape', extern_sample.shape)
						raise Exception("Invalid extern_sample size")

		self.first_randomization = False

