"""Launch Isaac Sim Simulator first. We don't need it here, but it's necessary to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher

# add argparse arguments
parser = argparse.ArgumentParser(
    description="Visualize action and observation time series from a recorded dynamics dataset."
)
# NOTE(review): --num_envs is not read by main(); kept so existing command lines keep working.
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument(
    "--dataset_root_dir",
    type=str,
    default="logs/datasets/dynamics/pedipulation_vanilla_EAC_new_all/",
    help="Path to HDF5 dataset",
)

# append AppLauncher cli args (e.g. --headless)
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args = parser.parse_args()

# Launch the omniverse app before the remaining imports, which fail without a running app.
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Rest everything follows."""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.addons.dynamics.modules import DynamicsSubmoduleConfig
from rsl_rl.addons.resolve_submodule import resolve_pretrained_module
from einops import rearrange
from rsl_rl.rsl_rl.addons.dynamics.data_utils import DynamicsDataset
import numpy as np
import matplotlib.pyplot as plt

def main(file_idx: int = 135, num_steps: int = 24, plot_dir: str = "logs/analysis/plots"):
    """Plot one action channel and one observation channel from a recorded dynamics dataset.

    Reads ``num_steps`` consecutive samples from the dataset at ``args.dataset_root_dir``,
    starting at the first sample of file ``file_idx``. It plots channel 0 of the actions
    and of the observations, saves the figure as a PNG under ``plot_dir``, then shows it.

    Args:
        file_idx: Index of the dataset file to visualize, passed to
            ``DynamicsDataset.find_start_idx_of_file``.
        num_steps: Number of consecutive samples to plot. The default of 24 matches the
            per-environment rollout length (``num_steps_per_env`` in rsl_rl_ppo_cfg.py).
        plot_dir: Output directory for the PNG; created if it does not exist.

    Raises:
        ValueError: If ``num_steps`` is not positive.
    """
    if num_steps <= 0:
        raise ValueError(f"num_steps must be positive, got {num_steps}")

    ds = DynamicsDataset(args.dataset_root_dir)

    # Index of the first sample belonging to file `file_idx`; the following samples are read
    # sequentially from there.
    idx = ds.find_start_idx_of_file(file_idx)

    obs_list = []
    actions_list = []
    for step in range(num_steps):
        obs_tensor, act_tensor = ds[idx + step]  # e.g. [6, 21], [6, 12]
        obs_list.append(obs_tensor)
        actions_list.append(act_tensor)

    # Keep only the last entry along each sample's first axis.
    # NOTE(review): that axis looks like a history window — confirm against DynamicsDataset.
    # .cpu() lets matplotlib convert the tensors even if the dataset returns them on a GPU.
    action_time_series = torch.stack(actions_list, dim=0)[:, -1].cpu()  # [num_steps, 12]
    obs_time_series = torch.stack(obs_list, dim=0)[:, -1].cpu()  # [num_steps, 21]

    # visualize the action and obs data separately
    plt.figure(figsize=(20, 10))
    plt.subplot(2, 1, 1)
    plt.plot(action_time_series[:, 0], label="action")
    plt.title("Action Time Series")
    plt.xlabel("Time")
    plt.ylabel("Action Value")
    plt.legend()
    plt.grid()
    plt.subplot(2, 1, 2)
    plt.plot(obs_time_series[:, 0], label="obs")
    plt.title("Observation Time Series")
    plt.xlabel("Time")
    plt.ylabel("Observation Value")
    plt.legend()
    plt.grid()

    # normpath strips a trailing separator, so the dataset folder name is used whether or not
    # the path ends in "/" (dirname() alone yields the *parent* folder when it does not).
    dataset_name = os.path.basename(os.path.normpath(args.dataset_root_dir))
    os.makedirs(plot_dir, exist_ok=True)
    # Save BEFORE show(): once show() returns the figure may already be closed, and a later
    # savefig() would write an empty image.
    plt.savefig(os.path.join(plot_dir, f"dyna_data_{dataset_name}_{idx}.png"))
    plt.show()

if __name__ == "__main__":
    try:
        main()
    finally:
        # Shut down the Isaac Sim app launched at the top of the script, even if main() raises.
        simulation_app.close()