"""Launch Isaac Sim Simulator first. We don't need it here, but it's necessary to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher

# Build the command-line interface. The simulator app must be launched from these
# arguments before any module that depends on it is imported further down the file.
parser = argparse.ArgumentParser(
    description="Train the dynamics-analysis prediction model on a recorded dataset, one run per RL checkpoint."
)
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument("--dataset_path", type=str, default="logs/datasets/dynamics_analysis", help="Path to HDF5 dataset")
parser.add_argument(
    "--load_model_dir",
    type=str,
    default="p4rl_assets/dynamics_analysis_base_models/vanilla",
    help="Directory containing the RL-trained checkpoints (model_<iteration>.pt)",
)
parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
parser.add_argument(
    "--training_samples_number",
    type=int,
    default=11000,
    help=(
        "Total number of training samples. However, those samples will be split into "
        "train and val datasets with ratio 10:1."
    ),
)
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
# TODO: a "--record_supporting_point" flag was planned but is not implemented.
# parser.add_argument("--record_supporting_point", action="store_true", default=False, help="NOT IMPLEMENTED.")

# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args = parser.parse_args()

# launch omniverse app
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Rest everything follows."""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.rsl_rl.modules.actor_critic import ActorCriticForAnalysis
from einops import rearrange, repeat
from rsl_rl.rsl_rl.addons.dynamics_analysis.data_utils import DynamicsAnalysisDataset
from typing import Dict
from p4rl.rsl_rl.rl_cfg import (
    RslRlPpoActorCriticForAnalysisCfg
)
from scipy.stats import norm


def get_target(s_t_plus_1, s_t, indices=None):
    """
    Return the change of selected state components between two time steps.

    The difference ``s_t_plus_1 - s_t`` is computed with ``s_t`` broadcast over a
    newly inserted axis 1, then only the requested components of the last axis
    are kept.

    args:
        s_t_plus_1: next states; must broadcast against ``s_t[:, None, :]``
            (presumably [num_subsets, num_samples, dim] -- confirm with the dataset)
        s_t: current states (presumably [num_subsets, dim] -- confirm with the dataset)
        indices: iterable of positions along the last axis to keep. Defaults to
            ``range(9, 21)``, which the original author describes as the joint
            positions -- confirm against the observation layout.
    returns:
        Tensor of state differences restricted to ``indices`` on the last axis.
    """
    # Materialise as a plain list so the selection is always treated as one
    # index array along the last axis (a tuple could be read as multi-axis indexing).
    selected = list(range(9, 21)) if indices is None else list(indices)
    deltas = s_t_plus_1 - s_t[:, None, :]
    return deltas[..., selected]


# def subset_loss_gaussian(mu_pred, sigma_pred, s_t_plus_1, s_t):
#     """
#     Compute the subset loss as the negative log likelihood loss.
#     args:
#         mu_pred: predicted mean, [num_subsets, dim]
#         sigma_pred: predicted standard deviation, [num_subsets, dim]
#         samples: samples from the dataset [num_subsets, num_samples, dim]
#     """
#     samples = get_target(s_t_plus_1, s_t)
    
#     # Compute the log likelihood of the samples given the predicted mean and standard deviation
#     # using the multivariate normal distribution
#     num_samples = samples.shape[1]
#     dist = torch.distributions.Normal(repeat(mu_pred, 's d -> (s repeat) d', repeat=num_samples), repeat(sigma_pred, 's d -> (s repeat) d', repeat=num_samples))
#     log_likelihood = dist.log_prob(samples.view(-1, samples.shape[-1]))
    
#     # Compute the negative log likelihood loss
#     loss = -torch.mean(log_likelihood)
#     return loss


def subset_loss_l1(mu_pred, s_t_plus_1, s_t, beta=0.02):
    """
    Smooth-L1 loss between the predicted mean and the empirical mean target.

    The target is obtained from ``get_target(s_t_plus_1, s_t)`` and averaged
    over axis 1 before being compared with ``mu_pred``.

    args:
        mu_pred: predicted mean, [num_subsets, dim]
        s_t_plus_1: next states, [num_subsets, num_samples, dim]
        s_t: current states; get_target inserts the sample axis itself, so this
            is presumably [num_subsets, dim] -- TODO confirm with the dataset
        beta: transition point of the smooth L1 loss
    returns:
        Scalar loss tensor.
    """
    deltas = get_target(s_t_plus_1, s_t)  # [num_subsets, num_samples, dim]
    target_mean = torch.mean(deltas, dim=1)  # [num_subsets, dim]
    return nn.functional.smooth_l1_loss(mu_pred, target_mean, beta=beta)


def to_device(batch, device):
    """
    Move a batch onto ``device``.

    A tuple or list is moved element by element and returned as a list;
    anything else is moved with a single ``.to(device)`` call.
    """
    if not isinstance(batch, (tuple, list)):
        return batch.to(device)

    moved = []
    for item in batch:
        moved.append(item.to(device))
    return moved

# Training Function
def train(model: ActorCriticForAnalysis, train_loader, val_loader, optimizer, epochs, run_name):
    """
    Train the model and report training/validation losses to wandb.

    Each batch is expected to unpack into ``(s_t, a_t, s_t_plus_1)``; only the
    states are used. The loss is ``subset_loss_l1`` between
    ``model.get_dynamic_predictions(s_t)`` and the target derived from the states.

    args:
        model: the model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        optimizer: optimizer
        epochs: number of epochs to train
        run_name: name of the run for logging (currently unused inside this function)
    """
    # Follow the device the model already lives on instead of hard-coding "cuda",
    # so the loop also works when the model was placed on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0
        for batch in tqdm(train_loader):
            s_t, _, s_t_plus_1 = to_device(batch, device)
            optimizer.zero_grad()
            mu_pred = model.get_dynamic_predictions(s_t)
            loss = subset_loss_l1(mu_pred, s_t_plus_1, s_t)
            loss.backward()
            optimizer.step()

            # Log a plain float rather than the tensor that still references the graph.
            loss_value = loss.item()
            wandb.log({"epoch": epoch + 1, "loss": loss_value})
            epoch_loss += loss_value

        # Guard against an empty loader so the average stays defined.
        avg_loss = epoch_loss / max(len(train_loader), 1)

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader):
                s_t, _, s_t_plus_1 = to_device(batch, device)
                mu_pred = model.get_dynamic_predictions(s_t)
                loss_value = subset_loss_l1(mu_pred, s_t_plus_1, s_t).item()
                wandb.log({"epoch": epoch + 1, "val_loss": loss_value})
                val_loss += loss_value

        avg_val_loss = val_loss / max(len(val_loader), 1)

        # Log training and validation loss
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        wandb.log({"epoch": epoch + 1, "loss_train_epoch": avg_loss, "val_epoch_loss": avg_val_loss})

        # Training is fast, so no per-epoch checkpoint is written here.


# Main Function with Argument Parsing
def main():
    """
    Train and save one dynamics-analysis model per RL checkpoint.

    Steps:
        1. Load the dataset from ``args.dataset_path`` and draw a seeded random
           subset of ``args.training_samples_number`` samples.
        2. Split that subset 10:1 into training and validation sets.
        3. For every checkpoint iteration in ``range(500, 4000, 250)``: start a
           wandb run, build the model, load the trunk weights from
           ``<args.load_model_dir>/model_<i>.pt``, train, and save the state dict.

    raises:
        ValueError: if ``args.training_samples_number`` is not in
            ``[1, len(dataset)]``.
    """
    ds = DynamicsAnalysisDataset(args.dataset_path)
    generator = torch.Generator().manual_seed(42)

    total_number_of_samples = len(ds)
    print(f"Total number of samples in the dataset: {total_number_of_samples}")

    num_selected = args.training_samples_number
    if not 0 < num_selected <= total_number_of_samples:
        raise ValueError(
            f"--training_samples_number must be between 1 and the dataset size "
            f"({total_number_of_samples}), got {num_selected}."
        )

    # Use exact integer lengths: fractions derived from a sample count can be off
    # by one after float rounding and fail cryptically when the count is too large.
    ds_subset, _ = random_split(
        ds, [num_selected, total_number_of_samples - num_selected], generator=generator
    )
    print(f"Total samples for training and validation combined: {len(ds_subset)}")

    # 10:1 train/validation split (e.g. 11000 samples -> 10000 train + 1000 val).
    train_dataset, val_dataset = random_split(ds_subset, [1.0/1.1, 1-1.0/1.1], generator=generator)
    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

    model_save_dir = "./logs/pretrain/dynamics_analysis/series_exp_mixed_data_10k_samples_input_0"

    # make directory for saving models (no-op if it already exists)
    os.makedirs(model_save_dir, exist_ok=True)

    # The device does not change between runs, so pick it once.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for i in range(500, 4000, 250):

        run_name = f"analysis_model_it_{i}_mixed_data_10k_samples_input_0"
        # Initialize wandb and log configurations
        wandb.init(project="dynamics_analysis", name=run_name)
        # Log argparser arguments
        wandb.config.update(vars(args))  # Log all arguments from argparse

        policy_cfg = RslRlPpoActorCriticForAnalysisCfg(
            init_noise_std=1.0,
            actor_hidden_dims=[256, 256, 256],
            critic_hidden_dims=[256, 256, 256],
            activation="elu",
            layer_to_dynamics=[0, ],
            dim_dynamics_hidden=64,
            dim_dynamics_prediction=12,
        )

        # Log model configuration
        wandb.config.update(policy_cfg.to_dict())  # Log the ActorCriticForAnalysis configuration

        model = ActorCriticForAnalysis(
            num_actor_obs=48,
            num_critic_obs=48,
            num_actions=12,
            **policy_cfg.to_dict()
        ).to(device)

        load_model_path = os.path.join(args.load_model_dir, f"model_{i}.pt")
        # Load weights
        model.load_trunk(load_model_path)

        optimizer = optim.Adam(model.parameters(), lr=args.lr)
        # Train model
        train(model, train_loader, val_loader, optimizer, args.epochs, run_name)

        # Save trained model
        save_path = model_save_dir + "/" + run_name + ".pt"
        torch.save(model.state_dict(), save_path)
        print(f"Model saved to {save_path}")

        wandb.finish()

# Run script
if __name__ == "__main__":
    # run the training loop
    main()
    # shut down the simulator app that was launched at the top of the file
    simulation_app.close()
