"""Launch Isaac Sim Simulator first. We don't need it here, but it's necessary to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher

# Build the command-line interface. Arguments are parsed at import time because
# the simulator app must be launched before the remaining imports are executed.
parser = argparse.ArgumentParser(
    description="Visualize the dynamics predictions of a trained analysis model against samples from an HDF5 dataset."
)
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument(
    "--dataset_path",
    type=str,
    default="logs/datasets/dynamics_analysis/pedi_it_4k_sample_10k_new.h5",
    help="Path to HDF5 dataset",
)
parser.add_argument(
    "--load_model_path",
    type=str,
    default="logs/pretrain/dynamics_analysis/begin_analysis_0_10k_samples_new_input_0/begin_analysis_0_10k_samples_new_input_0.pt",
    help="Path to the RL trained model",
)
parser.add_argument(
    "--training_samples_number",
    type=int,
    default=10000,
    help="Total number of training samples. However, those samples will be split into train and val datasets with ratio 9:1.",
)
# TODO: a "--record_supporting_point" flag was planned here but is NOT IMPLEMENTED.

# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args = parser.parse_args()

# launch omniverse app
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Rest everything follows."""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.rsl_rl.modules.actor_critic import ActorCriticForAnalysis
from einops import rearrange, repeat
from rsl_rl.rsl_rl.addons.dynamics_analysis.data_utils import DynamicsAnalysisDataset
from typing import Dict
from p4rl.rsl_rl.rl_cfg import (
    RslRlPpoActorCriticForAnalysisCfg
)
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
from typing import Sequence, Union


def get_target(s_t_plus_1, s_t, indices: Sequence[int] = tuple(range(9, 21))):
    """
    Return the change between two time steps for selected state dimensions.

    ``s_t`` gets an extra axis inserted at position 1 so that it broadcasts
    against every sample stored along axis 1 of ``s_t_plus_1``.

    Args:
        s_t_plus_1: Next-state samples, indexed as [batch, sample, dim].
        s_t: Current states, indexed as [batch, dim].
        indices: Positions along the last axis to keep. The default (9..20)
            selects the 12 entries this script treats as joint positions
            (assumed observation layout -- confirm against the dataset).

    Returns:
        The differences restricted to ``indices``, i.e. an array/tensor of
        shape [batch, sample, len(indices)].
    """
    delta = s_t_plus_1 - s_t[:, None, :]
    # Convert to a list so ranges and tuples are all treated as an advanced
    # index on the last axis.
    return delta[..., list(indices)]

def to_device(batch, device):
    """Move ``batch`` onto ``device``.

    A tuple or list is handled element by element and always comes back as a
    list; any other object is moved with a single ``.to(device)`` call.
    """
    is_collection = isinstance(batch, (tuple, list))
    if not is_collection:
        return batch.to(device)
    return [item.to(device) for item in batch]

def plot_subset(mu_pred, samples, vis_dim: Sequence[int], save_dir="./logs/analysis/dynamics_analysis_plots_0/", sample_idx=0):
    """
    Plot per-dimension sample histograms with a fitted Gaussian and the predicted mean.

    Only the first batch element of ``mu_pred`` and ``samples`` is used. One
    subplot is drawn for every entry of ``vis_dim``; each shows the histogram of
    the samples, a Gaussian fitted to those samples ("GT"), and a vertical line
    at the predicted mean. The figure is saved as a PNG inside ``save_dir``.

    Args:
        mu_pred (torch.Tensor): Predicted mean, shape [1, dim]
        samples (torch.Tensor): Samples from the dataset, shape [1, subset_size, dim]
        vis_dim (Sequence[int]): Dimensions to visualize
        save_dir (str): Directory to save the plots (created if it does not exist)
        sample_idx (int): Index that only appears in the output file name
    """
    os.makedirs(save_dir, exist_ok=True)
    # detach() so tensors that still track gradients can be converted as well
    samples = samples[0].detach().to('cpu').numpy()  # shape: [subset_size, dim]
    mu_pred = mu_pred[0].detach().to('cpu').numpy()  # shape: [dim]

    num_dims = len(vis_dim)
    # squeeze=False always returns a 2-D array of axes, so a single dimension
    # needs no special handling.
    fig, axes = plt.subplots(1, num_dims, figsize=(5 * num_dims, 4), squeeze=False)

    # The same bins are used in every subplot so the histograms are comparable.
    common_bins = np.linspace(-0.5, 0.5, 50 + 1)

    for ax1, dim in zip(axes[0], vis_dim):
        ax2 = ax1.twinx()  # second y-axis for the probability density

        sample_values = samples[:, dim]
        mu = mu_pred[dim]

        # Plot histogram (left y-axis)
        counts, _, _ = ax1.hist(sample_values, bins=common_bins, alpha=0.6, color='skyblue', label='Samples')

        # Fit a Gaussian distribution to the samples
        mu_gt, sigma_gt = norm.fit(sample_values)
        if sigma_gt > 0:
            x_vals = np.linspace(sample_values.min() - 3*sigma_gt, sample_values.max() + 3*sigma_gt, 500)
            y_gt_dist = norm.pdf(x_vals, mu_gt, sigma_gt)
            ax2.plot(x_vals, y_gt_dist, 'b-', lw=2, label='GT Gaussian')
            ax2.set_ylim(0, 1.2 * y_gt_dist.max())  # Adjust y-axis limits for Gaussian
        # With zero spread the pdf is NaN and set_ylim would raise, so the
        # curve is skipped and only the vertical marker lines are drawn.

        # Labeling
        ax1.set_title(f"Dimension {dim}")
        ax1.set_xlabel("Value")
        ax1.set_ylabel("Frequency (Histogram)", color='skyblue')

        ax1.axvline(mu, color='red', linestyle='-', label='Predicted Mean')

        ax1.axvline(mu_gt, color='blue', linestyle='-', label='GT Mean')
        ax1.axvline(mu_gt + sigma_gt, color='blue', linestyle='--', label='GT Std Dev')
        ax1.axvline(mu_gt - sigma_gt, color='blue', linestyle='--')

        # The right axis belongs to the blue GT Gaussian curve, so it is coloured to match.
        ax2.set_ylabel("PDF (Gaussian)", color='blue')
        # Guard against an all-zero histogram (every sample outside the bins),
        # which would otherwise request identical lower and upper limits.
        hist_top = counts.max()
        ax1.set_ylim(0, hist_top * 1.2 if hist_top > 0 else 1.0)  # Adjust y-axis limits for histogram
        ax1.set_xlim(min(-0.5, sample_values.min()-0.2), max(0.5, sample_values.max()+0.2))  # Adjust x-axis limits

        # Match tick color to axis
        ax1.tick_params(axis='y', labelcolor='skyblue')
        ax2.tick_params(axis='y', labelcolor='blue')

        # grid
        ax2.grid(True, linestyle='--', alpha=0.5)

        # Show the labels assigned above. The legend is attached to the twin
        # axis because that axis is drawn on top of ax1.
        handles_left, labels_left = ax1.get_legend_handles_labels()
        handles_right, labels_right = ax2.get_legend_handles_labels()
        ax2.legend(handles_left + handles_right, labels_left + labels_right, loc='upper right', fontsize='small')

    fig.tight_layout()
    save_path = os.path.join(save_dir, f"sample_{sample_idx}_gaussian_histo_{vis_dim}.png")
    fig.savefig(save_path)
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Saved plot to: {save_path}")


# Visualization Function
def error_vis(model: ActorCriticForAnalysis, val_loader, vis_num=10, device=None):
    """
    Plot the model's dynamics predictions against dataset samples.

    The model is switched to eval mode and, with gradients disabled, up to
    ``vis_num`` batches are taken from ``val_loader``. For every batch three
    figures are written via ``plot_subset``, covering target dimensions
    0-3, 4-7 and 8-11.

    args:
        model: the trained model to evaluate
        val_loader: DataLoader yielding (s_t, a_t, s_t_plus_1) batches
        vis_num: maximum number of batches to visualize
        device: device the batches are moved to. When None, the device of the
            model's first parameter is used so inputs always match the model
            (assumes ``model`` exposes ``parameters()`` like a torch module).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    with torch.no_grad():
        # zip() stops at whichever runs out first, so a loader with fewer than
        # vis_num batches ends the loop instead of raising StopIteration.
        for i, batch in zip(range(vis_num), val_loader):
            s_t, _, s_t_plus_1 = to_device(batch, device)
            mu_pred = model.get_dynamic_predictions(s_t)
            # The target is identical for all three plots; compute it once.
            target = get_target(s_t_plus_1, s_t)
            # HAA
            plot_subset(mu_pred, target, vis_dim=range(0, 4), sample_idx=i)
            # HFE
            plot_subset(mu_pred, target, vis_dim=range(4, 8), sample_idx=i)
            # KFE
            plot_subset(mu_pred, target, vis_dim=range(8, 12), sample_idx=i)


# Main Function
def main():
    """
    Load the dataset and the trained model, then write the error-visualization plots.

    The dataset is reduced to ``args.training_samples_number`` randomly chosen
    samples (fixed seed 42), split 9:1 into train/val parts, and the first five
    samples of the resulting loader are passed to ``error_vis``.

    Raises:
        ValueError: if the dataset is empty or the requested sample count is
            not positive.
    """
    ds = DynamicsAnalysisDataset(args.dataset_path)
    generator = torch.Generator().manual_seed(42)

    total_number_of_samples = len(ds)
    print(f"Total number of samples in the dataset: {total_number_of_samples}")
    if total_number_of_samples == 0:
        raise ValueError(f"Dataset at {args.dataset_path} contains no samples.")

    requested_samples = args.training_samples_number
    if requested_samples <= 0:
        raise ValueError(f"--training_samples_number must be positive, got {requested_samples}.")
    if requested_samples > total_number_of_samples:
        print(
            f"Requested {requested_samples} samples but the dataset only has "
            f"{total_number_of_samples}; using all of them."
        )
    num_selected = min(requested_samples, total_number_of_samples)

    # Integer lengths give exactly the requested number of samples. A float
    # ratio can round to one sample more or less than requested.
    ds_subset, _ = random_split(
        ds, [num_selected, total_number_of_samples - num_selected], generator=generator
    )
    print(f"Total samples for training and validation combined: {len(ds_subset)}")

    train_dataset, val_dataset = random_split(ds_subset, [0.9, 0.1], generator=generator)

    # NOTE: The loader is built from the TRAIN split although it is named
    # val_loader. Whether the VAL split should be used instead is an open
    # question left by the original author; the behavior is kept as is.
    val_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    policy_cfg = RslRlPpoActorCriticForAnalysisCfg(
        init_noise_std=1.0,
        actor_hidden_dims=[256, 256, 256],
        critic_hidden_dims=[256, 256, 256],
        activation="elu",
        layer_to_dynamics=[0, ],
        dim_dynamics_hidden=64,
        dim_dynamics_prediction=12,
    )

    model = ActorCriticForAnalysis(
        num_actor_obs=48,
        num_critic_obs=48,
        num_actions=12,
        **policy_cfg.to_dict()
    ).to(device)

    # Load weights
    model.load_all(args.load_model_path)
    # Visualize prediction errors (no training happens in this script)
    error_vis(model, val_loader, vis_num=5)


# Run script
if __name__ == "__main__":
    # run the visualization
    main()
    # shut down the simulator app that was launched at the top of the file
    simulation_app.close()


"""


python ./rsl_rl/rsl_rl/addons/dynamics_analysis/error_visualization.py \
--headless \
--num_envs 1 \
--dataset_path logs/datasets/dynamics_analysis/pedi_it_4k_sample_10k_new.h5 \
--load_model_path logs/pretrain/dynamics_analysis/analysis_model_it_4k_data_it_4k_10k_samples_input_0/analysis_model_it_4k_data_it_4k_10k_samples_input_0.pt \
--training_samples_number 10000 


python ./rsl_rl/rsl_rl/addons/dynamics_analysis/error_visualization.py \
--headless \
--num_envs 1 \
--dataset_path logs/datasets/dynamics_analysis/pedi_it_0_sample_10k_new.h5 \
--load_model_path logs/pretrain/dynamics_analysis/analysis_model_it_0_data_it_0_10k_samples_input_0/analysis_model_it_0_data_it_0_10k_samples_input_0.pt \
--training_samples_number 10000 

"""