"""Launch Isaac Sim Simulator first. We don't need it here, but it's necessary to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher

# add argparse arguments
parser = argparse.ArgumentParser(description="Evaluate a pretrained dynamics model and plot its joint-position prediction error.")
# NOTE(review): --num_envs is not read by this script's own code; kept for CLI compatibility.
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument("--dataset_root_dir", type=str, default="logs/datasets/dynamics/pedipulation_vanilla_EAC_only_initial_2M/", help="Path to HDF5 dataset")
# NOTE(review): --batch_size is not read by this script's own code; kept for CLI compatibility.
parser.add_argument("--batch_size", type=int, default=1024, help="Batch size")
parser.add_argument("--weights_path", type=str, default="logs/pretrain/dynamics/only_initial/dynamics_joints_plus_base_only_initial_2M_aligned_epoch_20.pt", help="Path to the trained model")
# NOTE(review): headless mode is NOT forced here; presumably pass the AppLauncher
# --headless flag to run without a GUI — confirm against the AppLauncher version in use.
# parser.add_argument("--record_supporting_point", action="store_true", default=False, help="NOT IMPLEMENTED.")

# append AppLauncher cli args (adds the simulator-related options to the same parser)
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args = parser.parse_args()

# launch omniverse app; this must happen before the remaining imports below
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Rest everything follows."""


import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.addons.dynamics.modules import DynamicsSubmoduleConfig
from rsl_rl.addons.resolve_submodule import resolve_pretrained_module
from einops import rearrange
from rsl_rl.rsl_rl.addons.dynamics.data_utils import DynamicsDataset
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def mean_distance(y_pred, y_gt):
    """Return the mean absolute difference between prediction and ground truth.

    The absolute difference is averaged over every element (all dimensions),
    producing a 0-dim tensor. Runs without gradient tracking.
    """
    abs_diff = (y_pred - y_gt).abs()
    return abs_diff.mean()

@torch.no_grad()
def normalized_error(y_pred, y_gt, eps=0.0):
    """Return the element-wise absolute error divided by the magnitude of the ground truth.

    Args:
        y_pred: predicted tensor.
        y_gt: ground-truth tensor, broadcast-compatible with `y_pred`.
        eps: optional value added to |y_gt| before dividing. It defaults to 0.0, i.e. a
            plain ratio; with eps == 0, elements where y_gt == 0 yield inf
            (or nan when the error is also 0).

    Returns:
        Tensor of |y_pred - y_gt| / (|y_gt| + eps), computed without gradient tracking.
    """
    # Previously the result was computed but never returned, so callers got None.
    return torch.abs(y_pred - y_gt) / (torch.abs(y_gt) + eps)

def plot_gt_with_errors(
    gt_abs,
    normalized_error,
    error,
    num_bins=20,
    fig_name="logs/analysis/plots/dyna_model_error_detail.png",
    value_range=(0, 0.35),
    count_ylim=(0, 2500),
    error_ylim=(0, 0.4),
):
    """
    Plot a histogram of `gt_abs` overlaid with the per-bin mean normalized error and mean error, then save it.

    Parameters:
    - gt_abs (array): Absolute ground-truth values. Any shape; `normalized_error` and `error`
      must have the same shape because one boolean mask indexes all three.
    - normalized_error (array): Error divided by |gt|, same shape as `gt_abs`.
    - error (array): Absolute error, same shape as `gt_abs`.
    - num_bins (int): Number of histogram bins (default 20).
    - fig_name (str): Output image path. A missing parent directory is created. The file name is
      also used as the plot title.
    - value_range (tuple): (min, max) range of `gt_abs` covered by the bins. Values outside it are
      excluded from both the histogram and the per-bin means.
    - count_ylim (tuple): y-axis limits of the histogram (left axis).
    - error_ylim (tuple): y-axis limits of the error curves (right axis).

    Bins containing no samples get NaN means, so matplotlib leaves gaps in those curves.
    """
    import os

    gt_abs = np.asarray(gt_abs)
    normalized_error = np.asarray(normalized_error)
    error = np.asarray(error)

    # Histogram of gt (np.histogram flattens multi-dimensional input).
    counts, bin_edges = np.histogram(gt_abs, bins=num_bins, range=value_range)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Per-bin means; NaN marks empty bins.
    mean_normalized_error = np.full_like(bin_centers, np.nan)
    mean_error = np.full_like(bin_centers, np.nan)
    last_bin = len(bin_centers) - 1
    for i, (lo, hi) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        # np.histogram closes the last bin on the right ([lo, hi]); match it so the
        # plotted means are computed over exactly the samples that were counted.
        below_hi = (gt_abs <= hi) if i == last_bin else (gt_abs < hi)
        mask = (gt_abs >= lo) & below_hi
        if np.any(mask):
            mean_normalized_error[i] = np.mean(normalized_error[mask])
            mean_error[i] = np.mean(error[mask])

    fig, ax1 = plt.subplots()
    try:
        # Histogram on the left y-axis
        color = 'tab:blue'
        ax1.set_xlabel('Absolute GT Value')
        ax1.set_ylabel('Count', color=color)
        ax1.bar(bin_centers, counts, width=np.diff(bin_edges), color=color, alpha=0.6, label='GT Histogram')
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.set_ylim(*count_ylim)

        # Error curves on a secondary y-axis sharing the x-axis
        ax2 = ax1.twinx()
        ax2.set_ylabel('Mean Error', color='black')
        ax2.set_ylim(*error_ylim)
        ax2.plot(bin_centers, mean_normalized_error, color='tab:red', marker='o', label='Mean Normalized Error (percentage)')
        ax2.plot(bin_centers, mean_error, color='tab:green', marker='x', label='Mean Error (radian)')
        ax2.tick_params(axis='y', labelcolor='black')

        ax1.set_title(os.path.basename(fig_name))
        # Grid follows the error-axis ticks.
        ax2.grid(True)

        # One legend for both axes, anchored to the top-right of the plot area
        fig.legend(loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=ax1.transAxes)
        # Lay out after the title exists so it is accounted for.
        fig.tight_layout()

        out_dir = os.path.dirname(fig_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(fig_name)
    finally:
        # Without this, every call leaks an open figure (this is called in a loop).
        plt.close(fig)

def train_criterion(y_pred, y_gt):
    """Training loss: Smooth L1 loss with a very small transition point (beta=0.001).

    The loss is quadratic only where |y_pred - y_gt| < 0.001 and linear (L1-like)
    elsewhere. It is averaged over all elements (the default 'mean' reduction).
    """
    smooth_l1 = nn.functional.smooth_l1_loss
    loss = smooth_l1(y_pred, y_gt, beta=0.001)
    return loss

# def make_model_input_output(obs_tensor, actions_tensor, device="cuda"):
#     """
#     Make model input from the output of dataloader.
#     args:
#         obs_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_obs]
#         actions_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_actions]
#     """
#     # X = torch.cat([rearrange(obs_tensor[:, :-1, 9:21], "b t d -> b (t d)"), actions_tensor[:, -2]], dim=-1)
#     # obs_i, act_i = rearrange(obs_tensor[:, :-1, 9:21], "b t d -> b (t d)"), actions_tensor[:, -2]
#     obs_i, act_i = obs_tensor[:, :-1], actions_tensor[:, :-1]
#     Y = obs_tensor[:, -1] - obs_tensor[:, -2]
#     return obs_i.to(device), act_i.to(device), Y.to(device)

def make_model_input_output(obs_tensor, actions_tensor, device="cuda", noise_magnitudes=None, mode="joint_only"):
    """
    Split a dataloader sample into model inputs and the next-step delta target.

    args:
        obs_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_obs]
        actions_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_actions]
        device: device the returned tensors are moved to
        noise_magnitudes: optional 1D tensor with total_dim_obs entries. If given, uniform noise
            in [-m, m] (per observation dimension) is added to the observation inputs only.
            The target Y is always computed from the clean observations.
        mode: "joint_only" or "joint_plus_base" (both currently use identical slicing)
    returns:
        obs_i: obs_tensor[:, :-1] (all but the last timestep, optionally noised)
        act_i: actions_tensor[:, 1:] (all but the first timestep; presumably the dataset stores
            actions shifted by one step relative to observations — confirm)
        Y: obs_tensor[:, -1] - obs_tensor[:, -2] (change between the last two observations)
    raises:
        ValueError: if `mode` is unknown or `noise_magnitudes` does not match the observation dim
    """
    if mode not in ("joint_only", "joint_plus_base"):
        raise ValueError("Mode should be either 'joint_only' or 'joint_plus_base'")

    # Both modes share the same slicing today.
    obs_i = obs_tensor[:, :-1]
    act_i = actions_tensor[:, 1:]
    Y = obs_tensor[:, -1] - obs_tensor[:, -2]

    if noise_magnitudes is not None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if noise_magnitudes.shape[0] != obs_i.shape[-1]:
            raise ValueError(
                f"Noise magnitude should match the observation dimension: got "
                f"{noise_magnitudes.shape[0]} magnitudes for {obs_i.shape[-1]} observation dims"
            )
        # Align device/dtype with the observations so e.g. GPU-resident magnitudes work with
        # CPU observations, and float64 magnitudes don't promote the inputs to float64.
        magnitudes = noise_magnitudes.to(device=obs_i.device, dtype=obs_i.dtype)
        # Uniform noise in [-m, m]; broadcasts over the batch and time dimensions.
        noise = (torch.rand_like(obs_i) * 2 - 1) * magnitudes
        obs_i = obs_i + noise
    return obs_i.to(device), act_i.to(device), Y.to(device)


# Validation function
def validate_it(model, it_ns, dataset: DynamicsDataset, model_name):
    """Evaluate `model` on selected dataset files and save one error plot per file.

    For every index in `it_ns`, the samples are loaded with
    `dataset.get_sample_entries_in_file`. The model predicts the next-step delta, and
    the first 12 output channels (joint positions) are compared with the ground truth.
    A histogram of |GT| with per-bin errors is saved to
    logs/analysis/plots/<model_name>_buffer_<it_n>.png.

    Args:
        model: torch module called as model(obs_i, act_i); switched to eval mode.
        it_ns: iterable of file indices to evaluate.
        dataset: DynamicsDataset providing the samples.
        model_name: prefix used in the plot file names and progress messages.

    Returns:
        list[float]: mean absolute joint-position error per entry of `it_ns`, in order.
    """
    model.eval()
    # Send inputs to wherever the model lives instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else "cuda"

    mean_errors = []
    for it_n in tqdm(it_ns):
        fig_name = f"logs/analysis/plots/{model_name}_buffer_{it_n}.png"
        with torch.no_grad():
            obs_tensor, actions_tensor = dataset.get_sample_entries_in_file(it_n)
            obs_i, act_i, Y = make_model_input_output(obs_tensor, actions_tensor, device=device, mode="joint_plus_base")
            predictions = model(obs_i, act_i)

            # Only the first 12 channels (joint positions) are analyzed.
            pred_joints = predictions[..., :12]
            gt_joints = Y[..., :12]
            gt_abs = torch.abs(gt_joints)
            abs_error = torch.abs(pred_joints - gt_joints)
            rel_error = abs_error / gt_abs  # inf/nan where the ground truth is 0

            dist = mean_distance(pred_joints, gt_joints).item()
            mean_errors.append(dist)
            tqdm.write(f"[{model_name}] file {it_n}: mean joint abs error = {dist:.6f}")

            plot_gt_with_errors(gt_abs.cpu().numpy(), rel_error.cpu().numpy(), abs_error.cpu().numpy(), fig_name=fig_name)
    return mean_errors

        



def main():
    """Build the dynamics model, load pretrained weights and plot its validation error.

    Reads `--dataset_root_dir` and `--weights_path` from the module-level `args`.
    """
    ds = DynamicsDataset(args.dataset_root_dir)

    # Define the model (architecture must match the checkpoint being loaded)
    model = resolve_pretrained_module(DynamicsSubmoduleConfig(
                                            input_dim_states=21,
                                            input_dim_actions=12,
                                            input_timesteps=5,
                                            output_dim=18,
                                            hidden_dims=[512, 256, 128],
                                            representation_dim=64,
                                            input_slice_states=[12, 72], # TODO
                                            input_slice_actions=[84, 132], # TODO
                                            input_slice_policy=[270, 318], # TODO
                                            backbone_output_dim=256,
                                            ))

    # Load the weights onto the CPU first. load_state_dict then copies them onto the model's own
    # device, so this also works when the checkpoint was saved from a GPU that isn't available.
    model.load_state_dict(torch.load(args.weights_path, map_location="cpu"))

    # File indices evaluated previously: range(0, 20, 5), [108, 118, 128, 138]
    # Validate the model and write the error plots
    validate_it(model, range(0, 20, 5), ds, "dyna_model_all")

# Run script
if __name__ == "__main__":
    try:
        main()
    finally:
        # Shut down the Omniverse app launched at import time, even if main() fails.
        simulation_app.close()

# python ./rsl_rl/rsl_rl/addons/kinematics/train_kinematics.py --run_name kinematic_mlp_v3