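"""Train the dynamics-prediction layer of an ActorCriticForAnalysis model on an HDF5 dataset.

The Isaac Sim simulator is launched first. This script does not use it, but the Isaac Lab modules imported below fail to import without a running app."""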


import argparse

from isaaclab.app import AppLauncher

# Command-line arguments for this training script.
parser = argparse.ArgumentParser(
    description="Train the dynamics-prediction layer of an ActorCriticForAnalysis model on an HDF5 dataset."
)
# NOTE(review): --num_envs does not appear to be read anywhere in this script; it is
# kept so that existing command lines keep parsing.
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument("--dataset_path", type=str, default="logs/datasets/dynamics_analysis/pedi_it_4k_sample_10k_new.h5", help="Path to HDF5 dataset")
parser.add_argument("--load_model_path", type=str, default="p4rl_assets/dynamics_analysis_base_models/vanilla/model_3999.pt", help="Path to the RL trained model")
parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
parser.add_argument(
    "--training_samples_number",
    type=int,
    default=1000,
    help="Total number of samples drawn from the dataset. Those samples are then split into "
    "train and val datasets with ratio of roughly 10:1 (train fraction 1/1.1).",
)
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
parser.add_argument("--run_name", type=str, default="analysis_0", help="Run name for logging")

# Append the AppLauncher CLI args (e.g. --headless) so they can be passed through.
AppLauncher.add_app_launcher_args(parser)
args = parser.parse_args()

# Directory that receives the trained weights; one sub-directory per run name.
model_save_dir = f"./logs/pretrain/dynamics_analysis/{args.run_name}"

# Launch the Omniverse app. It must be running before the Isaac Lab-dependent imports below.
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Everything below runs after the simulator app has been launched."""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.rsl_rl.modules.actor_critic import ActorCriticForAnalysis
from einops import rearrange, repeat
from rsl_rl.rsl_rl.addons.dynamics_analysis.data_utils import DynamicsAnalysisDataset
from typing import Dict
from p4rl.rsl_rl.rl_cfg import (
    RslRlPpoActorCriticForAnalysisCfg
)
from scipy.stats import norm


def get_target(s_t_plus_1, s_t):
    """
    Build the regression target for the dynamics head.

    The target is the element-wise state difference ``s_t_plus_1 - s_t``,
    restricted to feature columns 9..20 (12 columns) of the last dimension.
    NOTE(review): these columns are presumably the joint positions; confirm
    against the observation layout.
    """
    delta = s_t_plus_1 - s_t
    selected_columns = list(range(9, 21))
    return delta[..., selected_columns]


def subset_loss_l1(mu_pred, s_t_plus_1, s_t, beta=0.002):
    """
    Smooth L1 loss between predicted and observed state changes.

    args:
        mu_pred: predicted state change; must have the same shape as
            ``get_target(s_t_plus_1, s_t)``, e.g. [num_samples, 12]
        s_t_plus_1, s_t: states at two consecutive steps, same shape
        beta: smooth L1 threshold between the quadratic and linear regimes
    returns:
        scalar tensor (torch's default ``reduction="mean"``)
    """
    target = get_target(s_t_plus_1, s_t)
    return nn.functional.smooth_l1_loss(mu_pred, target, beta=beta)


def to_device(batch, device):
    """
    Move a tensor, or every element of a tuple/list of tensors, to ``device``.

    A tuple or list input always comes back as a list; any other input is
    returned as ``batch.to(device)``.
    """
    if not isinstance(batch, (tuple, list)):
        return batch.to(device)
    return [tensor.to(device) for tensor in batch]

# Training Function
def train(model: ActorCriticForAnalysis, train_loader, val_loader, optimizer, epochs, run_name):
    """
    Train the model's dynamics predictions and log progress to wandb.

    Each epoch runs one optimisation pass over ``train_loader`` and then one
    evaluation pass over ``val_loader``, both using ``subset_loss_l1``.
    Per-batch losses and per-epoch averages are logged to wandb. A baseline
    loss for predicting zero change ("zero_epoch_loss") is also tracked on the
    training data. The wandb run is finished on return.

    args:
        model: the model to train; must provide ``get_dynamic_predictions``
        train_loader: DataLoader yielding ``(s_t, a_t, s_t_plus_1)`` training batches
        val_loader: DataLoader yielding ``(s_t, a_t, s_t_plus_1)`` validation batches
        optimizer: optimizer over the model parameters
        epochs: number of epochs to train
        run_name: wandb run name, used only if no wandb run is active yet

    Epoch averages are NaN when the corresponding loader is empty.
    """
    # Reuse an already-active run (main() opens one and logs its config into it);
    # a second wandb.init could otherwise start a separate run.
    if wandb.run is None:
        wandb.init(project="dynamics_analysis", name=run_name)

    # Move batches to wherever the model lives rather than assuming CUDA.
    device = next(model.parameters()).device
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)

    for epoch in range(epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0
        zero_epoch_loss = 0.0
        for batch in tqdm(train_loader):
            s_t, _, s_t_plus_1 = to_device(batch, device)
            optimizer.zero_grad()
            mu_pred = model.get_dynamic_predictions(s_t)
            loss = subset_loss_l1(mu_pred, s_t_plus_1, s_t)
            loss.backward()
            optimizer.step()

            # Baseline: loss of always predicting "no change"; no gradients needed.
            with torch.no_grad():
                zero_loss = subset_loss_l1(torch.zeros_like(mu_pred), s_t_plus_1, s_t)

            # Log plain floats, not tensors that still carry the autograd graph.
            loss_value = loss.item()
            wandb.log({"epoch": epoch + 1, "loss": loss_value})
            epoch_loss += loss_value
            zero_epoch_loss += zero_loss.item()

        avg_loss = epoch_loss / num_train_batches if num_train_batches else float("nan")
        avg_zero_loss = zero_epoch_loss / num_train_batches if num_train_batches else float("nan")

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader):
                s_t, _, s_t_plus_1 = to_device(batch, device)
                mu_pred = model.get_dynamic_predictions(s_t)
                loss_value = subset_loss_l1(mu_pred, s_t_plus_1, s_t).item()
                wandb.log({"epoch": epoch + 1, "val_loss": loss_value})
                val_loss += loss_value

        avg_val_loss = val_loss / num_val_batches if num_val_batches else float("nan")

        # Log training and validation loss
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        wandb.log({"epoch": epoch + 1, "loss_train_epoch": avg_loss, "val_epoch_loss": avg_val_loss, "zero_epoch_loss": avg_zero_loss})

        # Training is fast, so checkpoints are not saved every epoch; the caller
        # saves the final weights.

    wandb.finish()


# Main Function with Argument Parsing
def main():
    """
    Build data, model and optimizer from the parsed CLI ``args``, then train.

    Steps:
      1. Open a wandb run and log the CLI arguments and the policy config.
      2. Randomly draw ``args.training_samples_number`` samples from the HDF5
         dataset (generator seeded with 42) and split them into train/val
         DataLoaders with train fraction 1/1.1 (roughly 10:1).
      3. Build an ActorCriticForAnalysis, call ``model.load_trunk`` on
         ``args.load_model_path``, and train it with Adam.
      4. Save the final ``state_dict`` to ``<model_save_dir>/<run_name>.pt``.

    raises:
        ValueError: if ``args.training_samples_number`` is not in
            ``[1, len(dataset)]``.
    """
    # Initialize wandb and log all argparse arguments
    wandb.init(project="dynamics_analysis", name=args.run_name)
    wandb.config.update(vars(args))

    ds = DynamicsAnalysisDataset(args.dataset_path)
    # Seeded generator so the same subset and train/val split are drawn every run.
    generator = torch.Generator().manual_seed(42)

    total_number_of_samples = len(ds)
    print(f"Total number of samples in the dataset: {total_number_of_samples}")

    num_selected = args.training_samples_number
    if not 0 < num_selected <= total_number_of_samples:
        raise ValueError(
            f"--training_samples_number must be in [1, {total_number_of_samples}], got {num_selected}"
        )
    # Integer lengths select exactly `num_selected` samples. Fractional lengths
    # are floored by random_split and can be off by one from float rounding.
    ds_subset, _ = random_split(
        ds, [num_selected, total_number_of_samples - num_selected], generator=generator
    )
    print(f"Total samples for training and validation combined: {len(ds_subset)}")

    # Train fraction 1/1.1 -> roughly a 10:1 train:val split.
    train_fraction = 1.0 / 1.1
    train_dataset, val_dataset = random_split(
        ds_subset, [train_fraction, 1 - train_fraction], generator=generator
    )
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

    # Make the directory for saving models
    os.makedirs(model_save_dir, exist_ok=True)

    run_name = args.run_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    policy_cfg = RslRlPpoActorCriticForAnalysisCfg(
        init_noise_std=1.0,
        actor_hidden_dims=[256, 256, 256],
        critic_hidden_dims=[256, 256, 256],
        activation="elu",
        layer_to_dynamics=[0],
        dim_dynamics_hidden=64,
        dim_dynamics_prediction=12,  # matches the 12 columns produced by get_target
    )
    # Log the model configuration alongside the CLI arguments
    wandb.config.update(policy_cfg.to_dict())

    model = ActorCriticForAnalysis(
        num_actor_obs=48,
        num_critic_obs=48,
        num_actions=12,
        **policy_cfg.to_dict(),
    ).to(device)

    # Load the weights from the RL-trained checkpoint
    model.load_trunk(args.load_model_path)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    train(model, train_loader, val_loader, optimizer, args.epochs, run_name)

    # Save the trained model
    save_path = os.path.join(model_save_dir, run_name + ".pt")
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

# Run script
if __name__ == "__main__":
    # Run training, then shut down the Omniverse app that was launched at import time.
    main()
    simulation_app.close()


"""
Example invocations:

python ./rsl_rl/rsl_rl/addons/dynamics_analysis/train_dynamics_predition_layer.py \
--headless \
--num_envs 1 \
--dataset_path logs/datasets/dynamics_analysis/pedi_it_0_sample_10k_new.h5 \
--load_model_path p4rl_assets/dynamics_analysis_base_models/vanilla/model_0.pt \
--epochs 20 \
--batch_size 128 \
--training_samples_number 10000 \
--lr 0.001 \
--run_name analysis_model_it_0_data_it_0_10k_samples_input_0


python ./rsl_rl/rsl_rl/addons/dynamics_analysis/train_dynamics_predition_layer.py \
--headless \
--num_envs 1 \
--dataset_path logs/datasets/dynamics_analysis/pedi_it_4k_sample_10k_new.h5 \
--load_model_path p4rl_assets/dynamics_analysis_base_models/vanilla/model_3999.pt \
--epochs 20 \
--batch_size 128 \
--training_samples_number 10000 \
--lr 0.001 \
--run_name analysis_model_it_4k_data_it_4k_10k_samples_input_0


python ./rsl_rl/rsl_rl/addons/dynamics_analysis/train_dynamics_predition_layer.py \
--headless \
--num_envs 1 \
--dataset_path logs/datasets/dynamics_analysis/pedi_it_4k_sample_10k_new.h5 \
--load_model_path p4rl_assets/dynamics_analysis_base_models/vanilla/model_0.pt \
--epochs 20 \
--batch_size 128 \
--training_samples_number 10000 \
--lr 0.001 \
--run_name analysis_model_it_0_data_it_4k_10k_samples_input_0


python ./rsl_rl/rsl_rl/addons/dynamics_analysis/train_dynamics_predition_layer.py \
--headless \
--num_envs 1 \
--dataset_path logs/datasets/dynamics_analysis/pedi_it_0_sample_10k_new.h5 \
--load_model_path p4rl_assets/dynamics_analysis_base_models/vanilla/model_3999.pt \
--epochs 20 \
--batch_size 128 \
--training_samples_number 10000 \
--lr 0.001 \
--run_name analysis_model_it_4k_data_it_0_10k_samples_input_0


"""



"""
Example invocation:
python ./rsl_rl/rsl_rl/addons/dynamics_analysis/train_dynamics_predition_layer.py \
--headless \
--num_envs 1 \
--dataset_path logs/datasets/dynamics_analysis/pedi_it_4000_sample_10k_new.h5 \
--load_model_path p4rl_assets/dynamics_analysis_base_models/vanilla/model_0.pt \
--epochs 20 \
--batch_size 128 \
--training_samples_number 10000 \
--lr 0.001 \
--run_name NEW_model_it_0_data_it_4000_10k_samples_input_0

"""
