"""Launch Isaac Sim Simulator first. We don't need it here, but it's necessary to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher

# Command-line interface. Arguments are parsed at import time because the
# Omniverse app has to be launched (with the parsed args) before the remaining
# imports further down in this file are executed.
# NOTE(review): the parser description below looks like a leftover from a
# template; the rest of the file trains a dynamics-prediction model.
parser = argparse.ArgumentParser(description="This script demonstrates how to use the concept of an Environment.")
# NOTE(review): --num_envs and --dataset_path are not referenced by main();
# main() builds the dataset path itself from a hard-coded pattern.
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument("--dataset_path", type=str, default="logs/datasets/dynamics_analysis/pedi_it_4000_sample_10k_new.h5", help="Path to HDF5 dataset")
# Directory that main() joins with "model_<iteration>.pt" to find the weights to load.
parser.add_argument("--load_model_dir", type=str, default="p4rl_assets/dynamics_analysis_base_models/vanilla", help="Path to the RL trained model")
parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
# NOTE(review): main() splits the selected samples with fractions
# [1.0/1.1, 1-1.0/1.1], i.e. roughly 10:1, not 9:1 as the help text says.
parser.add_argument("--training_samples_number", type=int, default=10000, help="Total number of training samples. However, those samples will be split into train and val datasets with ratio 9:1.")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
# NOTE(review): headless mode is not made the default anywhere in this block;
# the usage example at the end of the file passes --headless explicitly.
# parser.add_argument("--record_supporting_point", action="store_true", default=False, help="NOT IMPLEMENTED.")

# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args = parser.parse_args()

# launch omniverse app
app_launcher = AppLauncher(args)
# Handle to the running application; kept at module level for the rest of the script.
simulation_app = app_launcher.app

"""Rest everything follows."""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.rsl_rl.modules.actor_critic import ActorCriticForAnalysis
from einops import rearrange, repeat
from rsl_rl.rsl_rl.addons.dynamics_analysis.data_utils import DynamicsAnalysisDataset
from typing import Dict
from p4rl.rsl_rl.rl_cfg import (
    RslRlPpoActorCriticForAnalysisCfg
)
from scipy.stats import norm


def get_target(s_t_plus_1, s_t, indices=None):
    """
    Return the change between two time steps for a subset of state entries.

    The element-wise difference ``s_t_plus_1 - s_t`` is computed first, then
    only the selected positions along the last dimension are kept.

    args:
        s_t_plus_1: state at the next time step, shaped [..., dim]
        s_t: state at the current time step, same shape as ``s_t_plus_1``
        indices: positions along the last dimension to keep (list, tuple,
            range, slice or index tensor). Defaults to 9..20 inclusive, the
            selection this script has always used (described by the original
            author as the joint positions; the state layout is not visible
            here).
    returns:
        the selected differences, shaped [..., number of selected positions]
    """
    if indices is None:
        indices = range(9, 21)
    # A tuple would be interpreted as multi-dimensional indexing, so tuples and
    # ranges are normalised to a plain list of positions.
    if isinstance(indices, (range, tuple)):
        indices = list(indices)
    delta = s_t_plus_1 - s_t
    return delta[..., indices]


def subset_loss_l1(mu_pred, s_t_plus_1, s_t, beta=0.002):
    """
    Mean absolute error between the predicted and the observed state change.

    The ground truth is obtained from ``get_target(s_t_plus_1, s_t)``. A plain
    L1 loss is applied (not the smooth variant), so ``beta`` has no effect; it
    is only kept in the signature.

    args:
        mu_pred: predicted change, must be broadcastable against the target
        s_t_plus_1: state at the next time step
        s_t: state at the current time step, same shape as ``s_t_plus_1``
        beta: unused
    returns:
        loss tensor averaged over all elements
    """
    observed_change = get_target(s_t_plus_1, s_t)
    return nn.functional.l1_loss(input=mu_pred, target=observed_change, reduction="mean")


def to_device(batch, device):
    """
    Move a batch to ``device`` by calling ``.to(device)``.

    A tuple or list is moved element by element and returned as a new list;
    any other object is moved directly and returned as-is.
    """
    if not isinstance(batch, (list, tuple)):
        return batch.to(device)
    return [element.to(device) for element in batch]

# Training Function
def train(model: ActorCriticForAnalysis, train_loader, val_loader, optimizer, epochs, run_name):
    """
    Train the model and evaluate it on the validation data after every epoch.

    Per-batch losses ("loss", "val_loss") and per-epoch averages
    ("loss_train_epoch", "val_epoch_loss") are logged to the active wandb run;
    the per-epoch averages are also printed.

    args:
        model: the model to train; must provide ``get_dynamic_predictions(s_t)``
        train_loader: DataLoader for training data, yielding
            (s_t, a_t, s_t_plus_1) batches
        val_loader: DataLoader for validation data, same batch layout
        optimizer: optimizer
        epochs: number of epochs to train
        run_name: name of the run for logging (currently unused, kept so that
            existing callers keep working)
    """
    # Move batches to wherever the model lives instead of assuming CUDA, so the
    # function also works when the caller placed the model on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    for epoch in range(epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0
        for batch in tqdm(train_loader):
            # The action a_t is part of the batch but is not used for the prediction.
            s_t, _, s_t_plus_1 = to_device(batch, device)
            optimizer.zero_grad()
            mu_pred = model.get_dynamic_predictions(s_t)
            loss = subset_loss_l1(mu_pred, s_t_plus_1, s_t)
            loss.backward()
            optimizer.step()

            # Convert once to a Python float; used for both logging and the running sum.
            loss_value = loss.item()
            wandb.log({"epoch": epoch + 1, "loss": loss_value})
            epoch_loss += loss_value

        num_train_batches = len(train_loader)
        # An empty loader yields NaN instead of a ZeroDivisionError.
        avg_loss = epoch_loss / num_train_batches if num_train_batches else float("nan")

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader):
                s_t, _, s_t_plus_1 = to_device(batch, device)
                mu_pred = model.get_dynamic_predictions(s_t)
                loss_value = subset_loss_l1(mu_pred, s_t_plus_1, s_t).item()
                wandb.log({"epoch": epoch + 1, "val_loss": loss_value})
                val_loss += loss_value

        num_val_batches = len(val_loader)
        avg_val_loss = val_loss / num_val_batches if num_val_batches else float("nan")

        # Log training and validation loss
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        wandb.log({"epoch": epoch + 1, "loss_train_epoch": avg_loss, "val_epoch_loss": avg_val_loss})

        # No per-epoch checkpoint is written here: training is fast, and the
        # caller saves the model once after the last epoch.


# Main function (command-line arguments are parsed at module level above)
def main():
    """
    Train and save one dynamics-prediction model per selected input layer.

    For each processed dataset, ``--training_samples_number`` samples are drawn
    at random and split into a training and a validation part. Then, for every
    layer index ``j`` in 0..2, an ``ActorCriticForAnalysis`` is built with
    ``layer_to_dynamics=[j]``, weights are loaded through ``load_trunk``, the
    model is trained with ``train`` and its state dict is written to
    ``model_save_dir``. Every (dataset, layer) combination is logged as its
    own wandb run.

    Reads the module-level ``args`` (parsed command-line arguments).

    raises:
        ValueError: if ``--training_samples_number`` is not positive
    """

    torch.manual_seed(24)

    model_nums = [0, 1000, 2000, 3000, 3999]  # List of model indices to load
    data_nums = [0, 1000, 2000, 3000, 4000]  # List of data indices to load

    model_save_dir = "./logs/pretrain/dynamics_analysis/series_exp_midterm"

    # make directory for saving models (exist_ok avoids the check-then-create race)
    os.makedirs(model_save_dir, exist_ok=True)

    requested_samples = args.training_samples_number
    if requested_samples <= 0:
        raise ValueError(f"--training_samples_number must be positive, got {requested_samples}")

    # The device does not change between runs, so it is selected once.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only the first (model, dataset) pair is processed; widen this range
    # (up to len(data_nums)) to run the other pairs as well.
    for i in range(0, 1):

        ds = DynamicsAnalysisDataset(f"logs/datasets/dynamics_analysis/pedi_it_{data_nums[i]}_sample_10k_new.h5")

        total_number_of_samples = len(ds)
        print(f"Total number of samples in the dataset: {total_number_of_samples}")

        # Integer lengths give exactly the requested number of samples and,
        # unlike float fractions, cannot produce a negative share when more
        # samples are requested than the dataset contains.
        num_selected = min(requested_samples, total_number_of_samples)
        if num_selected < requested_samples:
            print(
                f"Requested {requested_samples} samples but the dataset only has "
                f"{total_number_of_samples}; using all of them."
            )
        ds_subset, _ = random_split(ds, [num_selected, total_number_of_samples - num_selected])
        print(f"Total samples for training and validation combined: {len(ds_subset)}")

        # NOTE: the fractions 1/1.1 and 0.1/1.1 give a train/val split of roughly 10:1.
        train_dataset, val_dataset = random_split(ds_subset, [1.0/1.1, 1-1.0/1.1])
        # Create DataLoaders
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

        ###########################################################################################################

        for j in range(0, 3):

            # NOTE(review): the run name hard-codes model iteration "-1" while the
            # weights loaded below come from model_{model_nums[i]}.pt - confirm
            # that this label is intended.
            run_name = f"NEW_model_it_-1_data_it_{data_nums[i]}_10k_samples_input_{j}"

            # Initialize wandb and log configurations
            wandb.init(project="dynamics_analysis_", name=run_name)
            # Log argparser arguments
            wandb.config.update(vars(args))  # Log all arguments from argparse

            policy_cfg = RslRlPpoActorCriticForAnalysisCfg(
                init_noise_std=1.0,
                actor_hidden_dims=[256, 256, 256],
                critic_hidden_dims=[256, 256, 256],
                activation="elu",
                layer_to_dynamics=[j],
                dim_dynamics_hidden=64,
                dim_dynamics_prediction=12,
            )

            # Log model configuration
            wandb.config.update(policy_cfg.to_dict())  # Log the ActorCriticForAnalysis configuration

            model = ActorCriticForAnalysis(
                num_actor_obs=48,
                num_critic_obs=48,
                num_actions=12,
                **policy_cfg.to_dict()
            ).to(device)

            load_model_path = os.path.join(args.load_model_dir, f"model_{model_nums[i]}.pt")
            # Load weights
            model.load_trunk(load_model_path)

            optimizer = optim.Adam(model.parameters(), lr=args.lr)
            # Train model
            train(model, train_loader, val_loader, optimizer, args.epochs, run_name)

            # Save trained model
            save_path = f"{model_save_dir}/{run_name}.pt"
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

            wandb.finish()

# Run script
if __name__ == "__main__":
    main()
    # Shut down the Omniverse app that was started through AppLauncher at the
    # top of this file once training is done.
    simulation_app.close()


"""

python ./rsl_rl/rsl_rl/addons/dynamics_analysis/train_dynamics_predition_layer_series_exp.py \
--headless \
--num_envs 1

"""