# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab.envs import ManagerBasedRLEnv
from isaaclab.managers import SceneEntityCfg
from isaaclab.utils.math import quat_rotate_inverse, quat_rotate  # , wrap_to_pi, yaw_quat
from isaaclab.assets import RigidObject
# from omni.isaac.contrib_tasks.pedipulation.assets.legged_robot import LeggedRobots
from isaaclab.terrains import TerrainImporter
from isaaclab.managers.command_manager import CommandManager
from isaaclab.managers.observation_manager import ObservationManager
from isaaclab.sensors import ContactSensor
from isaaclab.assets import Articulation
import torch.nn as nn
from typing import TYPE_CHECKING
from isaaclab.utils.math import quat_apply_inverse, yaw_quat


import torch

if TYPE_CHECKING:
    from ..config.anymal_d.exploration_base import P4RLExplorationEnvCfg

def feet_air_time_positive_biped(env, command_name: str, threshold: float, sensor_cfg: SceneEntityCfg) -> torch.Tensor:
    """Reward a biped for holding a single-stance phase, up to a time cap.

    For every selected foot the time spent in its current mode (contact time if it is
    touching, air time otherwise) is taken. The reward is the smallest of those times,
    but only for environments where exactly one selected foot is in contact; otherwise
    it is zero. The result is capped at ``threshold`` and zeroed whenever the norm of
    the first two command components is not above 0.1.

    Args:
        env: Environment exposing ``scene.sensors`` and ``command_manager``.
        command_name: Name of the command term used to gate the reward.
        threshold: Upper bound applied to the rewarded time.
        sensor_cfg: Selects the contact sensor (``name``) and the feet (``body_ids``).

    Returns:
        One reward value per environment.
    """
    sensor_data = env.scene.sensors[sensor_cfg.name].data
    feet = sensor_cfg.body_ids

    time_in_air = sensor_data.current_air_time[:, feet]
    time_on_ground = sensor_data.current_contact_time[:, feet]

    # a foot counts as touching as soon as its contact timer is running
    touching = time_on_ground > 0.0
    # time each foot has spent in whichever mode it is currently in
    phase_time = torch.where(touching, time_on_ground, time_in_air)

    # only environments with exactly one foot down are rewarded
    exactly_one_down = touching.int().sum(dim=1) == 1
    gated_time = torch.where(exactly_one_down.unsqueeze(-1), phase_time, 0.0)
    reward = gated_time.min(dim=1).values.clamp(max=threshold)

    # no reward when the commanded planar motion is (almost) zero
    command_xy = env.command_manager.get_command(command_name)[:, :2]
    is_moving = torch.norm(command_xy, dim=1) > 0.1
    return reward * is_moving


def feet_slide(env, sensor_cfg: SceneEntityCfg, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
    """Penalize planar foot motion while the foot is in contact.

    A body is treated as being in contact when the largest net-force norm found in the
    sensor's force history exceeds 1.0. The returned value is the sum, over the selected
    bodies, of the xy-velocity norm of each body that is in contact.

    Args:
        env: Environment exposing ``scene`` and ``scene.sensors``.
        sensor_cfg: Selects the contact sensor (``name``) and its bodies (``body_ids``).
        asset_cfg: Selects the asset (``name``) and the bodies whose velocity is read.

    Returns:
        One value per environment.
    """
    sensor = env.scene.sensors[sensor_cfg.name]
    force_history = sensor.data.net_forces_w_history[:, :, sensor_cfg.body_ids, :]
    # strongest force magnitude seen along dim 1 of the history buffer
    peak_force = force_history.norm(dim=-1).max(dim=1).values
    touching = peak_force > 1.0

    robot = env.scene[asset_cfg.name]
    planar_speed = robot.data.body_lin_vel_w[:, asset_cfg.body_ids, :2].norm(dim=-1)

    # the boolean mask removes the contribution of bodies that are not in contact
    return (planar_speed * touching).sum(dim=1)


def track_lin_vel_xy_yaw_frame_exp(
    env, std: float, command_name: str, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
    """Reward xy linear-velocity tracking in the yaw-only root frame with an exponential kernel.

    The root linear velocity is rotated by the inverse of the yaw-only root orientation,
    and its xy part is compared with the first two command components. The summed squared
    error ``e`` is mapped to ``exp(-e / std**2)``.

    Args:
        env: Environment exposing ``scene`` and ``command_manager``.
        std: Width of the exponential kernel.
        command_name: Name of the command term holding the velocity target.
        asset_cfg: Selects the asset whose root state is read.

    Returns:
        One reward value per environment.
    """
    robot_data = env.scene[asset_cfg.name].data

    # keep only the heading part of the orientation before rotating the velocity
    heading = yaw_quat(robot_data.root_quat_w)
    vel_in_heading_frame = quat_apply_inverse(heading, robot_data.root_lin_vel_w[:, :3])

    target_xy = env.command_manager.get_command(command_name)[:, :2]
    squared_error = (target_xy - vel_in_heading_frame[:, :2]).square().sum(dim=1)
    return torch.exp(-squared_error / std**2)


def track_ang_vel_z_world_exp(
    env, command_name: str, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
    """Reward yaw-rate tracking with an exponential kernel.

    The third command component is compared with the z component of
    ``root_ang_vel_w``; the squared error ``e`` is mapped to ``exp(-e / std**2)``.

    Args:
        env: Environment exposing ``scene`` and ``command_manager``.
        command_name: Name of the command term holding the yaw-rate target.
        std: Width of the exponential kernel.
        asset_cfg: Selects the asset whose root angular velocity is read.

    Returns:
        One reward value per environment.
    """
    robot = env.scene[asset_cfg.name]
    target_yaw_rate = env.command_manager.get_command(command_name)[:, 2]
    measured_yaw_rate = robot.data.root_ang_vel_w[:, 2]
    squared_error = (target_yaw_rate - measured_yaw_rate).square()
    return torch.exp(-squared_error / std**2)


def build_mlp(input_dims: int, hidden_dims: list[int], output_dims: int):
    """Build a fully connected network with ELU activations between layers.

    The network is ``Linear -> ELU`` for every hidden layer followed by a final
    ``Linear`` layer without activation. With an empty ``hidden_dims`` the result
    is a single ``Linear(input_dims, output_dims)`` layer.

    Args:
        input_dims: Size of the input features.
        hidden_dims: Width of each hidden layer; an entry of ``-1`` is replaced
            by ``input_dims``.
        output_dims: Size of the output features.

    Returns:
        An ``nn.Sequential`` holding the layers in order.
    """
    # resolve hidden dimensions: -1 means "as wide as the input"
    widths = [input_dims if dim == -1 else dim for dim in hidden_dims]
    layer_sizes = [input_dims, *widths, output_dims]

    network_layers = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        network_layers.append(nn.Linear(fan_in, fan_out))
        # each hidden layer gets its own (stateless) activation module
        network_layers.append(nn.ELU())
    # the output layer is linear: drop the activation appended after it
    network_layers.pop()
    return nn.Sequential(*network_layers)



def rle_reward(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Reward alignment between a random embedding of the state and the latent command.

    The ``"inv_dynamics_input"`` observation is passed through a fixed, randomly
    initialized MLP (created on first call and cached on the function object). The
    reward is the dot product between that embedding and the ``"rle_command"``
    command, clipped from below at zero.

    Args:
        env: Environment exposing ``command_manager`` and ``observation_manager``.

    Returns:
        One non-negative reward value per environment.
    """
    command_manager: CommandManager = env.command_manager
    obs_manager: ObservationManager = env.observation_manager
    latent_target = command_manager.get_command("rle_command")
    # NOTE(review): this reads the observation manager's private buffer, so it assumes
    # the observations have already been computed at least once -- confirm call order.
    inv_input = obs_manager._obs_buffer["inv_dynamics_input"]

    # lazily create the random embedding network, sized from the observation itself
    if not hasattr(rle_reward, "latent_model"):
        rle_reward.latent_model = build_mlp(inv_input.shape[-1], [128, 128], 128).to(inv_input.device)

    # the network is never trained: evaluate it without recording an autograd graph
    # so the reward does not keep references to the network's computation
    with torch.no_grad():
        latent_state = rle_reward.latent_model(inv_input)

    # The lower bound of 0.0 is intentional: the agent is only rewarded for getting
    # close to the target and is never punished for being in a state far from it.
    # Punishment can make the agent too conservative and hurt exploration.
    return torch.clip(torch.sum(latent_target * latent_state, dim=-1), min=0.0)


def feet_air_time(
    env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg, threshold: float
) -> torch.Tensor:
    """Reward long steps using a squared (L2) kernel on the last air time.

    For each selected foot that has just made contact, the duration of its last
    air phase is capped at ``threshold`` and squared; the reward is the sum of
    these values over the feet.

    P4RL: a linear air-time reward cannot favor longer steps. For a given gait, a
    faster cadence collects smaller rewards more often and a slower cadence collects
    larger rewards less often, so both end up with a similar return. Squaring the
    air time makes longer steps pay off more.

    Args:
        env: Environment exposing ``scene.sensors`` and ``step_dt``.
        sensor_cfg: Selects the contact sensor (``name``) and the feet (``body_ids``).
        threshold: Upper bound applied to the air time before squaring.

    Returns:
        One reward value per environment.
    """
    sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    feet = sensor_cfg.body_ids

    # feet that touched down during the last environment step
    just_landed = sensor.compute_first_contact(env.step_dt)[:, feet]
    capped_air_time = sensor.data.last_air_time[:, feet].clamp(max=threshold)

    # only feet that just landed contribute their squared air time
    return (capped_air_time.square() * just_landed).sum(dim=1)
