"""Launch Isaac Sim Simulator first. We don't need it here, but it's necessary to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher

# Build the command-line interface for this script.
parser = argparse.ArgumentParser(
    description="This script demonstrates how to use the concept of an Environment."
)
parser.add_argument(
    "--num_envs",
    type=int,
    default=1,
    help="Number of environments to spawn.",
)
# Placeholder for a flag that is not implemented yet:
# parser.add_argument("--record_supporting_point", action="store_true", default=False, help="NOT IMPLEMENTED.")

# Let AppLauncher register its own options on the same parser, then parse everything at once.
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()

# Start the Omniverse application and keep a handle to it so it can be closed when the script ends.
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app

"""Rest everything follows."""

import torch
import h5py
from pathlib import Path
from einops import rearrange
from p4rl.tasks.locomotion.velocity.config.anymal_d.agents.rsl_rl_ppo_cfg import AnymalDRoughPPORunnerCfg
from rsl_rl.rsl_rl.modules.actor_critic import ExtendableActorCritic
from rsl_rl.rsl_rl.addons.kinematics.modules import KinematicSubmoduleConfig


# Load Data from HDF5
def load_h5_dataset(h5_file):
    """Read the two observation datasets of an HDF5 file as float32 tensors.

    Args:
        h5_file: Path of the HDF5 file to open (anything ``h5py.File`` accepts).
            The file must contain the datasets ``"proprioceptive_obs"`` and
            ``"command_and_last_action"``.

    Returns:
        A tuple ``(obs_except_command, command)`` of ``torch.float32`` tensors
        holding the contents of ``"proprioceptive_obs"`` and
        ``"command_and_last_action"`` respectively.

    Raises:
        KeyError: Raised by h5py if one of the two datasets is missing.
    """
    with h5py.File(h5_file, "r") as f:
        # Read each dataset into a NumPy array with one bulk read (``[...]``)
        # before handing it to torch. Passing the h5py Dataset object itself can
        # be very slow because torch then walks it element by element.
        obs_except_command = torch.tensor(f["proprioceptive_obs"][...], dtype=torch.float32)
        command = torch.tensor(f["command_and_last_action"][...], dtype=torch.float32)
    return obs_except_command, command

def main():
    """Export kinematic-submodule latents of the actor for recorded observations.

    For every (data source task, policy type) pair this routine:

    1. loads the checkpoint registered for the policy type into the network,
    2. reads the observations stored under ``./logs/analysis/data/`` in the file
       named after the checkpoint registered for the data source task,
    3. feeds the concatenated observation through the actor's ``"kinematic"``
       submodule via ``get_latents``,
    4. writes the observations and every returned latent to a new HDF5 file in
       ``./logs/analysis/`` and prints a confirmation line.
    """
    device = "cuda"

    # Sizes handed to the network constructor.
    num_obs = 48
    num_critic_obs = 48
    num_actions = 12

    # Whose recorded observations to replay, and through which policies.
    # Further policy keys available in ``paths`` below: "loco", "pedi", "scratch".
    data_source_tasks = ["loco", "pedi"]
    policy_types = ["kine"]

    kinematic_submodule = KinematicSubmoduleConfig(
        input_dim=12,
        num_bodies=8,
        num_output_features_per_body=3,
        backbone_output_dim=30,
        input_slice=(9, 21),  # must be consistent with the ObservationCfg!
        weight_path="./logs/pretrain/kinematic_mlp_4_layer_out_30.pt",
    ).to_dict()

    actor_critic = ExtendableActorCritic(
        num_obs,
        num_critic_obs,
        num_actions,
        direct_pathway_dim=48,
        final_mlp_dims=[128, 128, 128],
        init_noise_std=1.0,
        activation="elu",
        submodule_configs=[kinematic_submodule],
    )
    actor_critic = actor_critic.to(device)

    # Checkpoint registry; the file stems double as names of the recorded datasets.
    paths = {
        "pedi": "logs/pretrain/CDAC-pedi-baseline-2.0.pt",
        "loco": "logs/pretrain/CDAC-loco-baseline-2.0.pt",
        "scratch": "logs/pretrain/CDAC-scratch-2.0.pt",
        "kine": "logs/selected_weights/EAC-kine-30.pt",
    }

    # Same visiting order as two nested loops: tasks outermost, policies innermost.
    combinations = [(task, policy) for task in data_source_tasks for policy in policy_types]

    for data_source_task, policy_type in combinations:
        # Restore the weights of the policy under inspection.
        checkpoint = torch.load(paths[policy_type], weights_only=False)
        actor_critic.load_state_dict(checkpoint["model_state_dict"])

        # The dataset file is named after the checkpoint of the data-collecting policy.
        policy_collecting_dataset = Path(paths[data_source_task])
        proprio_obs, command_obs = load_h5_dataset(f"./logs/analysis/data/{policy_collecting_dataset.stem}.h5")
        proprio_obs = proprio_obs.to(device)
        command_obs = command_obs.to(device)

        # Re-assemble the full observation by joining both parts along the last axis.
        full_obs = torch.cat((proprio_obs, command_obs), dim=-1)

        with torch.inference_mode():
            kinematic_module = actor_critic.actor.submodules["kinematic"]
            # Iterated below as a sequence of tensors, one output dataset per entry.
            latents = kinematic_module.get_latents(full_obs)

        # Only actor latents are exported; critic latents are not computed here.
        with h5py.File(f"./logs/analysis/{data_source_task}_obs_in_{policy_type}_latent.h5", "w") as f:
            f.create_dataset("proprioceptive_obs", data=proprio_obs.cpu().numpy())
            f.create_dataset("command_and_last_action", data=command_obs.cpu().numpy())
            for i, layer_latent in enumerate(latents):
                f.create_dataset(f"actor_latent_{i}", data=layer_latent.cpu().numpy())

        print(f"[INFO] ./logs/analysis/{data_source_task}_obs_in_{policy_type}_latent.h5 saved.")


if __name__ == "__main__":
    # Script entry point: run the latent-export routine first ...
    main()
    # ... then shut down the simulator application that was launched at the top of this file.
    simulation_app.close()