"""Launch Isaac Sim Simulator first. We don't need it here, but it's necessary to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher

# Command-line interface of the training script. The Isaac Sim app has to be launched
# before the remaining imports further down in this file are executed.
parser = argparse.ArgumentParser(description="Pretrain a dynamics model on a recorded dataset.")
# NOTE(review): --num_envs does not appear to be read anywhere in the visible part of this script.
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument(
    "--dataset_root_dir",
    type=str,
    default="logs/datasets/dynamics/pedipulation_vanilla_EAC_new_all_subset/",
    help="Path to HDF5 dataset",
)
parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=1024, help="Batch size")
parser.add_argument(
    "--training_samples_number",
    type=int,
    default=2000000,
    help="Total number of training samples. However, those samples will be split into train and val datasets with ratio 9:1.",
)
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
parser.add_argument("--run_name", type=str, default="dynamics_joints_plus_base_2M_aligned", help="Run name for logging")
# `choices` rejects an unknown mode at parse time instead of after the dataset has been loaded.
parser.add_argument(
    "--dynamics_mode",
    type=str,
    default="joint_plus_base",
    choices=["joint_only", "joint_plus_base"],
    help="Mode of dynamics model. joint_only or joint_plus_base",
)
# Headless mode is not forced here; it is left to the AppLauncher CLI flags (e.g. --headless).
# parser.add_argument("--record_supporting_point", action="store_true", default=False, help="NOT IMPLEMENTED.")

# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args = parser.parse_args()

# Directory that receives the per-epoch checkpoints and the final model of this run.
model_save_dir = "./logs/pretrain/dynamics/" + args.run_name

# launch omniverse app
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Rest everything follows."""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.addons.dynamics.modules import DynamicsSubmoduleConfig
from rsl_rl.addons.dynamics.modules_recurrent import *
from rsl_rl.addons.resolve_submodule import resolve_pretrained_module
from einops import rearrange
from rsl_rl.rsl_rl.addons.dynamics.data_utils import DynamicsDataset
from typing import Dict

@torch.no_grad()
def mean_absolute_distance_error_joint_space(y_pred, y_gt):
    """
    Mean absolute error restricted to the first 12 entries of the last axis.
    args:
        y_pred: predicted output; only ``[..., :12]`` is used
        y_gt: ground truth output; only ``[..., :12]`` is used
    returns:
        scalar tensor, computed with gradient tracking disabled
    """
    # Slice before subtracting so inputs with different trailing sizes still work.
    joint_residual = y_pred[..., :12] - y_gt[..., :12]
    return joint_residual.abs().mean()

# TODO: write loss function for the case where observation contains joints plus velocities and gravity

def train_criterion(y_pred, y_gt, mode="joint_only") -> Dict[str, torch.Tensor]:
    """
    Compute the smooth-L1 training loss and its individual terms.
    args:
        y_pred: predicted output
        y_gt: ground truth output
        mode: "joint_only" compares the full tensors; "joint_plus_base" compares the
            slices [0:12], [12:15] and [15:18] of the last axis separately and sums them
    returns:
        dict holding every individual term plus the total under the key "loss"
    raises:
        ValueError: if mode is neither "joint_only" nor "joint_plus_base"
    """
    if mode not in ("joint_only", "joint_plus_base"):
        raise ValueError("Mode should be either 'joint_only' or 'joint_plus_base'")

    if mode == "joint_only":
        total = nn.functional.smooth_l1_loss(y_pred, y_gt, beta=0.001)
        return {"joint_loss": total, "loss": total}

    # joint_plus_base: one loss term per slice of the last axis.
    term_slices = (
        ("joint_loss", slice(None, 12)),
        ("base_lin_vel_loss", slice(12, 15)),
        ("base_ang_vel_loss", slice(15, 18)),
    )
    terms = {
        name: nn.functional.smooth_l1_loss(y_pred[..., span], y_gt[..., span], beta=0.001)
        for name, span in term_slices
    }
    terms["loss"] = terms["joint_loss"] + terms["base_lin_vel_loss"] + terms["base_ang_vel_loss"]
    return terms



def make_model_input_output(obs_tensor, actions_tensor, device="cuda", noise_magnitudes=None, mode="joint_only"):
    """
    Turn one dataloader batch into model inputs and the regression target.
    args:
        obs_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_obs]
        actions_tensor: [batch_size, input_timesteps+out_timesteps, total_dim_actions]
        device: device the three returned tensors are moved to
        noise_magnitudes: optional per-feature scale of uniform noise in [-scale, scale]
            added to the observation input only; its length must equal total_dim_obs
        mode: "joint_only" or "joint_plus_base" (both currently use the same slicing)
    returns:
        (observations without the last timestep, actions without the first timestep,
        difference between the last two observations)
    raises:
        ValueError: if mode is neither "joint_only" nor "joint_plus_base"
    """
    if mode not in ("joint_only", "joint_plus_base"):
        raise ValueError("Mode should be either 'joint_only' or 'joint_plus_base'")

    history = obs_tensor[:, :-1]
    shifted_actions = actions_tensor[:, 1:]
    # The target is taken from the clean observations, before any noise is added.
    target_delta = obs_tensor[:, -1] - obs_tensor[:, -2]

    if noise_magnitudes is not None:
        assert noise_magnitudes.shape[0] == history.shape[-1], "Noise magnitude should match the output dimension"
        # Broadcast the per-feature scale over the batch and time axes.
        scale = noise_magnitudes[None, None]
        uniform_pm_one = torch.rand_like(history) * 2 - 1
        history = history + uniform_pm_one * scale

    return history.to(device), shifted_actions.to(device), target_delta.to(device)



# Training Function
def train(model, train_loader, val_loader, optimizer, epochs, run_name, noise_magnitudes, mode):
    """
    Train the model, validate after every epoch, and log metrics to Weights & Biases.

    After each epoch the model's state dict is saved into the module-level
    ``model_save_dir`` (the directory has to exist before this function is called).

    args:
        model: the model to train; called as ``model(obs_i, act_i)``
        train_loader: DataLoader for training data, yielding (obs_tensor, actions_tensor)
        val_loader: DataLoader for validation data, yielding (obs_tensor, actions_tensor)
        optimizer: optimizer
        epochs: number of epochs to train
        run_name: name of the run for logging and for the checkpoint file names
        noise_magnitudes: noise magnitudes for the model input (applied to training batches only)
        mode: "joint_only" or "joint_plus_base"
    """

    assert mode in ["joint_only", "joint_plus_base"], "Mode should be either 'joint_only' or 'joint_plus_base'"

    wandb.init(project="dynamics-mlp", name=run_name)

    for epoch in range(epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0
        epoch_dist = 0.0
        for obs_tensor, actions_tensor in tqdm(train_loader):
            obs_i, act_i, Y = make_model_input_output(obs_tensor, actions_tensor, noise_magnitudes=noise_magnitudes, mode=mode)
            optimizer.zero_grad()
            predictions = model(obs_i, act_i)
            loss_dict = train_criterion(predictions, Y, mode=mode)
            dist = mean_absolute_distance_error_joint_space(predictions, Y)
            # Mean magnitude of the target over its first 12 features, logged next to `dist` for comparison.
            dist_between_consecutive_gt_frames = torch.mean(torch.abs(Y[..., :12]))

            loss = loss_dict["loss"]
            loss.backward()
            optimizer.step()
            logging_dict = {"epoch": epoch+1, "loss": loss, "dist": dist, "dist_between_consecutive_gt_frames": dist_between_consecutive_gt_frames}
            logging_dict.update(loss_dict)
            wandb.log(logging_dict)

            epoch_loss += loss.item()
            epoch_dist += dist.item()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        num_train_batches = max(len(train_loader), 1)
        avg_loss = epoch_loss / num_train_batches
        avg_dist = epoch_dist / num_train_batches

        # Validation phase: no input noise, no gradients, no optimizer interaction.
        model.eval()
        val_loss = 0.0
        val_dist = 0.0
        with torch.no_grad():
            for obs_tensor, actions_tensor in tqdm(val_loader):
                obs_i, act_i, Y = make_model_input_output(obs_tensor, actions_tensor, mode=mode)
                predictions = model(obs_i, act_i)
                loss_dict = train_criterion(predictions, Y, mode=mode)
                loss = loss_dict["loss"]
                dist = mean_absolute_distance_error_joint_space(predictions, Y)
                logging_dict = {"epoch": epoch+1, "val_loss": loss, "val_dist": dist}
                # Prefix the individual loss terms so validation values are not logged
                # under the keys used by the training batches ("loss", "joint_loss", ...).
                logging_dict.update({f"val_{name}": value for name, value in loss_dict.items()})
                wandb.log(logging_dict)
                val_loss += loss.item()
                val_dist += dist.item()

        num_val_batches = max(len(val_loader), 1)
        avg_val_loss = val_loss / num_val_batches
        avg_val_dist = val_dist / num_val_batches

        # Log training and validation loss
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        # NOTE(review): the epoch-level "val_dist" shares its key with the per-batch "val_dist" logged above.
        wandb.log({"epoch": epoch + 1, "loss_train_epoch": avg_loss, "val_epoch_loss": avg_val_loss, "dist_train_epoch": avg_dist, "val_dist": avg_val_dist})
        # Per-epoch checkpoint.
        torch.save(model.state_dict(), model_save_dir + "/" + run_name + f"_epoch_{epoch+1}.pt")

    wandb.finish()

# Main Function with Argument Parsing
def main():
    """
    Build the dataset splits, the model and the optimizer from the parsed CLI
    arguments (module-level ``args``), run training, and save the final weights.

    At most ``args.training_samples_number`` samples are drawn from the dataset;
    that subset is then split 9:1 into training and validation data.

    raises:
        ValueError: if ``args.training_samples_number`` is not positive or
            ``args.dynamics_mode`` is neither "joint_only" nor "joint_plus_base"
    """
    ds = DynamicsDataset(args.dataset_root_dir)
    # Fixed seed so the subset and the train/val split are reproducible.
    generator = torch.Generator().manual_seed(42)

    total_number_of_samples = len(ds)
    print(f"Total number of samples in the dataset: {total_number_of_samples}")

    if args.training_samples_number <= 0:
        raise ValueError("--training_samples_number must be a positive integer")

    # Cap the requested budget at the dataset size. The previous max(ratio, 1.0) never
    # dropped below 1.0, so the requested sample count was silently ignored.
    subset_size = min(args.training_samples_number, total_number_of_samples)
    if subset_size < total_number_of_samples:
        # Integer lengths give exactly `subset_size` samples and avoid dividing by the dataset size.
        ds_subset, _ = random_split(ds, [subset_size, total_number_of_samples - subset_size], generator=generator)
    else:
        ds_subset = ds
    print(f"Training and validation subset size: {len(ds_subset)}")

    train_dataset, val_dataset = random_split(ds_subset, [0.9, 0.1], generator=generator)
    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

    # make directory for saving models
    os.makedirs(model_save_dir, exist_ok=True)

    run_name = args.run_name

    if args.dynamics_mode == "joint_only":
        # One noise magnitude per observation feature (12 features in this mode).
        noise_magnitudes = torch.tensor([0.01]*12)
        model = resolve_pretrained_module(DynamicsSubmoduleConfig(
                                            input_dim_states=12,
                                            input_dim_actions=12,
                                            input_timesteps=5,
                                            output_dim=12,
                                            hidden_dims=[512, 256, 128],
                                            representation_dim=64,
                                            input_slice_states=[12, 72],
                                            input_slice_actions=[84, 132],
                                            input_slice_policy=[270, 318],
                                            backbone_output_dim=256,
                                            ))
    elif args.dynamics_mode == "joint_plus_base":
        # One noise magnitude per observation feature (12 + 3 + 3 + 3 = 21 features in this mode).
        noise_magnitudes = torch.tensor([0.01]*12 + [0.1]*3 + [0.2]*3 + [0.05]*3)
        model = resolve_pretrained_module(DynamicsSubmoduleConfig(
                                            # class_name="DynamicGRU",
                                            class_name="DynamicMLP",
                                            input_dim_states=21,
                                            input_dim_actions=12,
                                            input_timesteps=5,
                                            output_dim=18,
                                            hidden_dims=[512, 256, 128],
                                            representation_dim=64,
                                            input_slice_states=[12, 72], # Not used during pretraining so can be arbitrary
                                            input_slice_actions=[84, 132], # Not used during pretraining so can be arbitrary
                                            input_slice_policy=[270, 318], # Not used during pretraining so can be arbitrary
                                            backbone_output_dim=256,
                                            ))
    else:
        raise ValueError("Mode should be either 'joint_only' or 'joint_plus_base'")

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Train model
    train(model, train_loader, val_loader, optimizer, args.epochs, run_name, noise_magnitudes, mode=args.dynamics_mode)

    # Save trained model
    final_model_path = model_save_dir + "/" + run_name + ".pt"
    torch.save(model.state_dict(), final_model_path)
    print(f"Model saved to {final_model_path}")

# Run script
if __name__ == "__main__":
    # Run the training pipeline.
    main()
    # Shut down the simulator app that was launched at the top of this script.
    simulation_app.close()

# python ./rsl_rl/rsl_rl/addons/kinematics/train_kinematics.py --run_name kinematic_mlp_v3
"""

python ./rsl_rl/rsl_rl/addons/dynamics/train_dynamics.py \
--run_name dynamics_loco_only_initial_2M \
--dataset_root_dir logs/datasets/dynamics/velocity_vanilla_EAC_only_initial_4M 

python ./rsl_rl/rsl_rl/addons/dynamics/train_dynamics.py \
--run_name dynamics_GRU_pedi_only_initial_2M \
--dataset_root_dir logs/datasets/dynamics/pedipulation_vanilla_EAC_new_all_subset/ \
--headless

python ./rsl_rl/rsl_rl/addons/dynamics/train_dynamics.py \
--run_name dynamics_rel_pedi_only_init \
--dataset_root_dir logs/datasets/dynamics_rel/pedi_vanilla_EAC_only_initial \
--headless

python ./rsl_rl/rsl_rl/addons/dynamics/train_dynamics.py \
--run_name MLP_dynamics_rel_pedi_only_init \
--dataset_root_dir logs/datasets/dynamics_rel/pedi_vanilla_EAC_only_initial \
--headless

"""