from __future__ import annotations

from isaaclab.managers import SceneEntityCfg

import torch
from typing import TYPE_CHECKING, Sequence

import carb

from isaaclab.markers import VisualizationMarkers
from isaaclab.markers.config import CUBOID_MARKER_CFG
from isaaclab.terrains import TerrainImporter
import isaaclab.sim as sim_utils
from isaaclab.assets.articulation import Articulation
from isaaclab.managers import CommandTerm
from isaaclab.assets import RigidObject
from isaaclab.utils.math import quat_rotate, quat_from_angle_axis
from isaaclab.markers.visualization_markers import VisualizationMarkersCfg
import omni.usd.commands

import itertools


if TYPE_CHECKING:
    from isaaclab.envs import ManagerBasedRLEnv
    from exploration.mdp.commands.commands_cfg import RLECommandCfg
    from exploration.config.anymal_d.exploration_base import P4RLExplorationEnvCfg


# Marker configuration with a single thin magenta cylinder oriented along its local X axis.
# NOTE: the height is a placeholder length; it should be tuned to match the edge sizes
# of the command space that the marker is meant to visualize.
CYLINDER_MARKER_CFG = VisualizationMarkersCfg(
    markers=dict(
        cylinder=sim_utils.CylinderCfg(
            axis="X",
            radius=0.02,
            height=1.0,
            visual_material=sim_utils.PreviewSurfaceCfg(
                diffuse_color=(1.0, 0.2, 1.0),
            ),
        ),
    ),
)


class RLECommand(CommandTerm):
    """Command term that samples random latent vectors for RLE exploration.

    Every environment holds one latent command of size ``cfg.dim_latent_space``.
    When the command manager resamples an environment, a new vector is drawn from a
    standard normal distribution and then either

    * normalized to unit length (``cfg.normalize_vector=True``; ``cfg.std_sampling``
      is ignored in this mode), or
    * scaled by ``cfg.std_sampling`` (``cfg.normalize_vector=False``).

    The command stays constant between resamples: no per-step update is applied and
    no metrics are tracked.
    """

    cfg: RLECommandCfg
    """Configuration for the command generator."""

    def __init__(self, cfg: RLECommandCfg, env: ManagerBasedRLEnv):
        """Initialize the command generator class.

        Args:
            cfg: The configuration parameters for the command generator.
            env: The environment object.

        Raises:
            ValueError: If ``cfg.dim_latent_space`` is not positive.
        """
        if cfg.dim_latent_space <= 0:
            raise ValueError(f"dim_latent_space must be positive, got {cfg.dim_latent_space}.")

        # cache the sampling parameters on the instance before the base class is set up
        self.dim_latent_space = cfg.dim_latent_space
        self.std_sampling = cfg.std_sampling
        self.normalize_vector = cfg.normalize_vector

        # initialize the base class (provides `num_envs` and `device` used below)
        super().__init__(cfg, env)

        # one latent command per environment; all zeros until the first resample
        self.rle_commands = torch.zeros(self.num_envs, self.dim_latent_space, device=self.device)

    def __str__(self) -> str:
        msg = "RLECommand:\n"
        msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
        msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
        if self.normalize_vector:
            msg += "\tSampling: unit-norm Gaussian direction\n"
        else:
            msg += f"\tSampling: Gaussian with std {self.std_sampling}\n"
        return msg

    # override the two abstract methods from CommandTerm
    def _update_command(self):
        """No-op: the latent command is only changed when it is resampled."""

    def _update_metrics(self):
        """No-op: this command term does not track any metrics."""

    """
    Properties
    """

    @property
    def command(self) -> torch.Tensor:
        """The current latent RLE command. Shape is (num_envs, dim_latent_space)."""
        return self.rle_commands

    """
    Implementation specific functions.
    """

    def _resample_command(self, env_ids: Sequence[int] | slice):
        """Draw new latent commands for the given environments.

        Args:
            env_ids: Indices of the environments to resample. An index sequence/tensor
                or a ``slice`` over the environment dimension is accepted.
        """
        # number of environments selected; slices have no len(), so resolve them explicitly
        if isinstance(env_ids, slice):
            num_selected = len(range(*env_ids.indices(self.num_envs)))
        else:
            num_selected = len(env_ids)
        if num_selected == 0:
            return

        # sample a random command from the current command space
        r = torch.randn(num_selected, self.dim_latent_space, device=self.device)
        if self.normalize_vector:
            # project onto the unit hypersphere; `normalize` clamps the norm with a small eps,
            # so an (extremely unlikely) all-zero draw yields zeros instead of NaNs
            self.rle_commands[env_ids] = torch.nn.functional.normalize(r, dim=1)
        else:
            self.rle_commands[env_ids] = r * self.std_sampling
