"""Launch the Isaac Sim simulator first. It is not used by this script, but launching it is necessary to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher

# add argparse arguments
parser = argparse.ArgumentParser(description="Validate pretrained dynamics models on a recorded dataset.")
# NOTE(review): --num_envs and --batch_size are not read by the code in this script; kept for CLI compatibility.
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument(
    "--dataset_root_dir",
    type=str,
    default="logs/datasets/dynamics/velocity_vanilla_EAC_uniform_interval/",
    help="Root directory of the HDF5 dataset.",
)
parser.add_argument("--batch_size", type=int, default=1024, help="Batch size")
parser.add_argument(
    "--weights_dir",
    type=str,
    default="p4rl_assets/dynamics_modules/",
    help="Directory containing the trained model weight files to validate.",
)
# parser.add_argument("--record_supporting_point", action="store_true", default=False, help="NOT IMPLEMENTED.")

# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args = parser.parse_args()

# launch omniverse app; per the module docstring this must happen before the
# remaining imports to avoid import errors
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Rest everything follows."""

import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.addons.dynamics.modules import DynamicsSubmoduleConfig
from rsl_rl.addons.resolve_submodule import resolve_pretrained_module
from einops import rearrange
from rsl_rl.rsl_rl.addons.dynamics.data_utils import DynamicsDataset

@torch.no_grad()
def mean_distance(y_pred, y_gt):
    """Return the mean absolute error between ``y_pred`` and ``y_gt``.

    Computed without autograd tracking; the result is a 0-dim tensor averaged
    over every element of the (broadcast) inputs.
    """
    abs_err = (y_pred - y_gt).abs()
    return abs_err.mean()

def train_criterion(y_pred, y_gt):
    """Smooth-L1 training loss with a very small quadratic region.

    ``beta=0.001`` means the loss is quadratic only for |error| < 0.001 and
    L1-like elsewhere; the result is the mean over all elements.
    """
    smooth_l1 = nn.functional.smooth_l1_loss
    return smooth_l1(y_pred, y_gt, reduction="mean", beta=0.001)

# def make_model_input_output(obs_tensor, actions_tensor, device="cuda"):
#     """
#     Make model input from the output of dataloader.
#     args:
#         obs_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_obs]
#         actions_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_actions]
#     """
#     # X = torch.cat([rearrange(obs_tensor[:, :-1, 9:21], "b t d -> b (t d)"), actions_tensor[:, -2]], dim=-1)
#     # obs_i, act_i = rearrange(obs_tensor[:, :-1, 9:21], "b t d -> b (t d)"), actions_tensor[:, -2]
#     obs_i, act_i = obs_tensor[:, :-1], actions_tensor[:, :-1]
#     Y = obs_tensor[:, -1] - obs_tensor[:, -2]
#     return obs_i.to(device), act_i.to(device), Y.to(device)

def make_model_input_output(obs_tensor, actions_tensor, device="cuda", noise_magnitudes=None, mode="joint_only"):
    """
    Make model input/target tensors from the output of the dataloader.

    The model sees observations ``[:, :-1]`` and actions ``[:, 1:]`` and the
    target is the change between the last two observations.

    args:
        obs_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_obs]
        actions_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_actions]
        device: device the returned tensors are moved to
        noise_magnitudes: optional 1-D tensor of length total_dim_obs. If given,
            uniform noise in [-noise_magnitudes, +noise_magnitudes] is added per
            feature to the input observations only; the target is computed from
            the clean observations.
        mode: "joint_only" or "joint_plus_base". Both modes currently build
            identical tensors; the value is only validated.
    returns:
        (obs_i, act_i, Y): obs_i is obs_tensor[:, :-1], act_i is
        actions_tensor[:, 1:], and Y = obs_tensor[:, -1] - obs_tensor[:, -2].
    raises:
        ValueError: on an unknown ``mode``, or when ``noise_magnitudes`` does not
            match the observation (input) dimension.
    """
    if mode not in ("joint_only", "joint_plus_base"):
        raise ValueError("Mode should be either 'joint_only' or 'joint_plus_base'")

    # NOTE(review): actions are taken one step later than observations; presumably the
    # stored action at t+1 is the one that moves obs t to obs t+1 — confirm against the recorder.
    obs_i, act_i = obs_tensor[:, :-1], actions_tensor[:, 1:]
    Y = obs_tensor[:, -1] - obs_tensor[:, -2]

    if noise_magnitudes is not None:
        if noise_magnitudes.shape[0] != obs_i.shape[-1]:
            raise ValueError("Noise magnitude should match the observation (input) dimension")
        # Align with obs_i so CPU data + CUDA magnitudes (or float64 magnitudes)
        # neither fail nor silently change the input dtype.
        noise_magnitudes = noise_magnitudes.to(device=obs_i.device, dtype=obs_i.dtype)
        # [D] broadcasts over the leading [batch, time] dims.
        noise = (torch.rand_like(obs_i) * 2 - 1) * noise_magnitudes
        obs_i = obs_i + noise

    return obs_i.to(device), act_i.to(device), Y.to(device)


# Validation function
def validate(model, dataset: DynamicsDataset):
    """Evaluate ``model`` on every file of ``dataset`` and return per-file errors.

    For each file index, the entries returned by
    ``dataset.get_sample_entries_in_file(i)`` are fed through the model in a
    single forward pass. The mean absolute error over the first 12 output
    channels is recorded and printed.

    args:
        model: dynamics model, called as ``model(obs, actions)``.
        dataset: dataset exposing ``num_files()`` and ``get_sample_entries_in_file(i)``.
    returns:
        list[float] with one mean distance per file, in file-index order.
    """
    model.eval()
    # Send inputs to the device the model lives on instead of assuming "cuda";
    # fall back to the previous hard-coded "cuda" if the model has no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else "cuda"

    error_list = []
    with torch.no_grad():
        for it_n in range(dataset.num_files()):
            obs_tensor, actions_tensor = dataset.get_sample_entries_in_file(it_n)
            obs_i, act_i, Y = make_model_input_output(
                obs_tensor, actions_tensor, device=device, mode="joint_plus_base"
            )
            predictions = model(obs_i, act_i)
            # only analyze the error of joint positions (first 12 channels)
            val_dist = mean_distance(predictions[..., :12], Y[..., :12]).item()
            error_list.append(val_dist)
            print(f"Iteration: {it_n}, Validation distance: {val_dist}")

    return error_list

# Main function (arguments are parsed at module level, before the simulator is launched)
def main():
    """Validate every checkpoint in ``args.weights_dir`` on ``args.dataset_root_dir`` and plot the results.

    Each regular file in the weights directory is loaded into the same model
    and evaluated with ``validate``. The per-file error curves of all
    checkpoints are saved as a single figure.
    """
    ds = DynamicsDataset(args.dataset_root_dir)

    # Define model
    model = resolve_pretrained_module(DynamicsSubmoduleConfig(
                                            input_dim_states=21, 
                                            input_dim_actions=12, 
                                            input_timesteps=5,
                                            output_dim=18, 
                                            hidden_dims=[512, 256, 128], 
                                            representation_dim=64, 
                                            input_slice_states=[12, 72], # TODO
                                            input_slice_actions=[96, 144], # TODO
                                            input_slice_policy=[270, 318], # TODO
                                            backbone_output_dim=256,
                                            ))

    # os.listdir order is arbitrary; sort so curve/legend order is deterministic,
    # and skip sub-directories, which cannot be loaded as checkpoints.
    weights_names = sorted(
        name for name in os.listdir(args.weights_dir)
        if os.path.isfile(os.path.join(args.weights_dir, name))
    )
    if not weights_names:
        print(f"No weight files found in {args.weights_dir}; nothing to validate.")
        return

    errors_list = []
    for weight_name in weights_names:
        # NOTE(security): torch.load unpickles the file; only point --weights_dir at trusted checkpoints.
        state_dict = torch.load(os.path.join(args.weights_dir, weight_name))
        model.load_state_dict(state_dict)
        errors_list.append(validate(model, ds))

    _plot_errors(errors_list, weights_names, "logs/analysis/plots/validate_dynamics.png")


def _plot_errors(errors_list, weights_names, out_path):
    """Plot one validation-error curve per checkpoint and save the figure to ``out_path``."""
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 5))
    # errors_list is [num_checkpoints][num_files]; transposed so each plotted
    # column (one line) corresponds to one checkpoint.
    plt.plot(np.array(errors_list).transpose())
    plt.title("Mean distance of the model predictions in radian: validated on locomotion D dataset")
    plt.xlabel("saved buffer NO.")
    plt.ylabel("Distance")
    plt.legend(weights_names)
    plt.grid()
    # Create the output directory so a long validation run doesn't end in FileNotFoundError.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path)
    plt.close(fig)

# Run script
if __name__ == "__main__":
    main()

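# NOTE(review): the example command below refers to train_kinematics.py, not this validation script; update it.
# python ./rsl_rl/rsl_rl/addons/kinematics/train_kinematics.py --run_name kinematic_mlp_v3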