"""Launch the Isaac Sim simulator first. This script does not use it, but launching it is required to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher
# import lightning as L

# add argparse arguments
parser = argparse.ArgumentParser(description="Pretrain a recurrent (GRU) dynamics model on recorded trajectory data.")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument("--dataset_root_dir", type=str, default="logs/datasets/dynamics_rel/pedi_data_only_init.h5", help="Path to HDF5 dataset")
parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
parser.add_argument("--training_samples_number", type=int, default=2000000, help="Total number of training samples. However, those samples will be split into train and val datasets with ratio 9:1.")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
parser.add_argument("--run_name", type=str, default="traj_gru_dynamics_rel_pedi_init", help="Run name for logging")
# Restrict to the modes train_criterion supports so a typo fails at argument parsing
# instead of after the dataset has been loaded and the model built.
parser.add_argument("--dynamics_mode", type=str, default="joint_plus_base", choices=["joint_only", "joint_plus_base"], help="Mode of dynamics model. joint_only or joint_plus_base")
# parser.add_argument("--record_supporting_point", action="store_true", default=False, help="NOT IMPLEMENTED.")

# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args = parser.parse_args()

# Per-epoch checkpoints and the final weights are written under this directory.
model_save_dir = "./logs/pretrain/dynamics/" + args.run_name

# launch omniverse app
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Everything else follows."""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.addons.dynamics.modules import DynamicsSubmoduleConfig
from rsl_rl.addons.dynamics.modules_recurrent import DynamicGRU
from rsl_rl.addons.resolve_submodule import resolve_pretrained_module
from einops import rearrange
from rsl_rl.rsl_rl.addons.dynamics.data_utils_june import DynamicsTrajectoryDataset, collate_variable_length
from typing import Dict

@torch.no_grad()
def mean_absolute_distance_error_joint_space(y_pred, y_gt, mask):
    """
    Evaluation metric: mean absolute error on the first 12 channels (joint positions).

    The absolute error is averaged over those 12 channels at each position, then over the
    positions where ``mask`` is True. The remaining channels (velocities, gravity) are ignored.
    Runs without autograd.
    """
    per_position_error = (y_pred[..., :12] - y_gt[..., :12]).abs().mean(dim=-1)
    return per_position_error[mask].mean()

# TODO: write loss function for the case where observation contains joints plus velocities and gravity

def train_criterion(y_pred, y_gt, mask, mode="joint_only", beta=0.001) -> Dict[str, torch.Tensor]:
    """
    Smooth-L1 training loss for the dynamics model, averaged over valid (unpadded) positions.

    Channel groups compared per mode. Both y_pred and y_gt are sliced, so y_gt may carry extra
    trailing channels that the model does not predict:
        "joint_only":      0:12 (joint positions)
        "joint_plus_base": 0:12 (joint positions), 12:15 (base linear velocity),
                           15:18 (base angular velocity)

    args:
        y_pred: predicted output, [..., output_dim]
        y_gt: ground truth output, [..., gt_dim]
        mask: boolean mask over the leading dims of y_pred, True = valid, False = pad
        mode: "joint_only" or "joint_plus_base"
        beta: Smooth-L1 threshold between the quadratic and linear regimes
    returns:
        Dict of scalar tensors: one masked-mean entry per loss component ("joint_loss", plus
        "base_lin_vel_loss" and "base_ang_vel_loss" in "joint_plus_base" mode) and "loss",
        the differentiable optimisation target (per-position sum of the components,
        averaged over valid positions).
    raises:
        ValueError: if mode is not one of the supported values.
    """
    if mode == "joint_only":
        channel_slices = {"joint_loss": slice(0, 12)}
    elif mode == "joint_plus_base":
        channel_slices = {
            "joint_loss": slice(0, 12),  # Joint positions
            "base_lin_vel_loss": slice(12, 15),
            "base_ang_vel_loss": slice(15, 18),
        }
    else:
        raise ValueError("Mode should be either 'joint_only' or 'joint_plus_base'")

    # Per-position loss for each group, averaged over that group's channels.
    per_position = {
        name: nn.functional.smooth_l1_loss(
            y_pred[..., channels], y_gt[..., channels], beta=beta, reduction="none"
        ).mean(dim=-1)
        for name, channels in channel_slices.items()
    }
    # Reduce over valid positions only, so padded steps never contribute to loss or logs.
    losses = {name: value[mask].mean() for name, value in per_position.items()}
    losses["loss"] = sum(per_position.values())[mask].mean()
    return losses



def make_model_input_output(obs_tensor, actions_tensor, masks, device="cuda", noise_magnitudes=None):
    """
    Build one-step-ahead model inputs and delta targets from a padded trajectory batch.

    args:
        obs_tensor: [batch_size, timesteps, total_dim_obs]
        actions_tensor: [batch_size, timesteps, total_dim_actions]
        masks: validity mask aligned with the first two dims of obs_tensor (True = valid, False = pad)
        device: device the returned tensors are moved to
        noise_magnitudes: optional 1-D tensor with one entry per observation channel; uniform noise
            in [-m, m] is added to the input observations only (targets are computed from clean data)

    returns:
        (obs_i, act_i, Y, mask_i):
            obs_i, act_i: observations / actions at steps 0..T-2 (model inputs)
            Y: obs[t + 1] - obs[t] for t = 0..T-2 (prediction targets)
            mask_i: masks at steps 1..T-1, i.e. the validity of each target

    raises:
        ValueError: if noise_magnitudes does not have one entry per observation channel.

    NOTE: different from legacy train_dynamics.py, this action tensor comes straight from the transition buffer,
    not the "last_action" from observation tensor, so it is already aligned with the observation tensor!
    """
    # Drop the last timestep: there is no next observation to build a target from it.
    obs_i, act_i = obs_tensor[:, :-1], actions_tensor[:, :-1]
    Y = obs_tensor[:, 1:] - obs_tensor[:, :-1]
    mask_i = masks[:, 1:]  # mask for the output, True = valid, False = pad

    if noise_magnitudes is not None:
        if noise_magnitudes.shape[0] != obs_i.shape[-1]:
            raise ValueError(
                "Noise magnitude should match the observation dimension "
                f"({noise_magnitudes.shape[0]} != {obs_i.shape[-1]})"
            )
        # Follow the observations' device/dtype so a CPU noise vector also works with GPU batches.
        scale = noise_magnitudes.to(device=obs_i.device, dtype=obs_i.dtype)
        # A 1-D scale broadcasts over the batch and time dims.
        noise = (torch.rand_like(obs_i) * 2 - 1) * scale
        obs_i = obs_i + noise
    return obs_i.to(device), act_i.to(device), Y.to(device), mask_i.to(device)



def _masked_scalar_losses(loss_dict, mask, prefix=""):
    """Convert a train_criterion() output dict into plain floats for wandb logging.

    Scalar entries are converted as-is. Any per-position (non-scalar) entry is first averaged over
    the positions where ``mask`` is True, so padded steps are excluded. Keys get ``prefix`` prepended.
    """
    scalars = {}
    for key, value in loss_dict.items():
        value = value.detach()
        if value.dim() > 0:
            value = value[mask].mean()
        scalars[prefix + key] = value.item()
    return scalars


# Training Function
def train(model: DynamicGRU, train_loader, val_loader, optimizer, epochs, run_name, noise_magnitudes, mode):
    """
    Train the model, validate after every epoch, and checkpoint each epoch into model_save_dir.

    args:
        model: the model to train; must provide forward_traj(obs, actions)
        train_loader: DataLoader for training data, yielding (obs, actions, masks, lengths)
        val_loader: DataLoader for validation data, same batch format as train_loader
        optimizer: optimizer over the model's parameters
        epochs: number of epochs to train
        run_name: name of the wandb run and prefix of the checkpoint files
        noise_magnitudes: noise magnitudes for the training inputs (validation inputs are not noised)
        mode: "joint_only" or "joint_plus_base", forwarded to train_criterion

    Per-batch metrics are logged to wandb every ``log_interval`` batches. Validation metrics carry a
    "val_" prefix so they never share keys with the training curves.
    """
    log_interval = 100
    counter = 0  # optimizer steps taken so far; logged as "step"
    val_counter = 0  # validation batches seen so far; only used to throttle validation logging

    assert mode in ["joint_only", "joint_plus_base"], "Mode should be either 'joint_only' or 'joint_plus_base'"

    wandb.init(project="dynamics-mlp", name=run_name)

    for epoch in range(epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0
        epoch_dist = 0.0
        for obs_tensor, actions_tensor, masks, _lengths in tqdm(train_loader):
            obs_i, act_i, Y, mask_i = make_model_input_output(obs_tensor, actions_tensor, masks, noise_magnitudes=noise_magnitudes)
            optimizer.zero_grad()
            predictions = model.forward_traj(obs_i, act_i)  # [batch_size, out_timesteps, output_dim]
            loss_dict = train_criterion(predictions, Y, mask_i, mode=mode)
            dist = mean_absolute_distance_error_joint_space(predictions, Y, mask_i)

            loss = loss_dict["loss"]
            loss.backward()
            optimizer.step()
            counter += 1

            if counter % log_interval == 0:
                # Reference scale: mean absolute change of the first 12 GT channels between consecutive frames.
                dist_between_consecutive_gt_frames = torch.mean(torch.abs(Y[..., :12][mask_i]))
                logging_dict = {
                    "epoch": epoch + 1,
                    "dist": dist.item(),
                    "step": counter,
                    "dist_between_consecutive_gt_frames": dist_between_consecutive_gt_frames.item(),
                }
                logging_dict.update(_masked_scalar_losses(loss_dict, mask_i))
                wandb.log(logging_dict)

            epoch_loss += loss.item()
            epoch_dist += dist.item()

        # Mean of per-batch values; max(..., 1) guards against an empty loader.
        num_train_batches = max(len(train_loader), 1)
        avg_loss = epoch_loss / num_train_batches
        avg_dist = epoch_dist / num_train_batches

        # Validation phase (inputs are not noised)
        model.eval()
        val_loss = 0.0
        val_dist = 0.0
        with torch.no_grad():
            for obs_tensor, actions_tensor, masks, _lengths in tqdm(val_loader):
                obs_i, act_i, Y, mask_i = make_model_input_output(obs_tensor, actions_tensor, masks)
                predictions = model.forward_traj(obs_i, act_i)
                loss_dict = train_criterion(predictions, Y, mask_i, mode=mode)
                loss = loss_dict["loss"]
                dist = mean_absolute_distance_error_joint_space(predictions, Y, mask_i)
                val_counter += 1
                if val_counter % log_interval == 0:
                    logging_dict = {"epoch": epoch + 1, "val_dist": dist.item(), "step": counter}
                    # The "val_" prefix keeps validation values off the training "loss"/"*_loss" keys.
                    logging_dict.update(_masked_scalar_losses(loss_dict, mask_i, prefix="val_"))
                    wandb.log(logging_dict)
                val_loss += loss.item()
                val_dist += dist.item()

        num_val_batches = max(len(val_loader), 1)
        avg_val_loss = val_loss / num_val_batches
        avg_val_dist = val_dist / num_val_batches

        # Log training and validation loss
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        wandb.log({"epoch": epoch + 1, "loss_train_epoch": avg_loss, "val_epoch_loss": avg_val_loss, "dist_train_epoch": avg_dist, "val_dist": avg_val_dist})
        torch.save(model.state_dict(), os.path.join(model_save_dir, f"{run_name}_epoch_{epoch + 1}.pt"))

    wandb.finish()


def main():
    """Build train/val splits and the GRU dynamics model, train it, and save the final weights."""
    ds = DynamicsTrajectoryDataset(args.dataset_root_dir)
    generator = torch.Generator().manual_seed(42)

    total_number_of_samples = len(ds)
    print(f"Total number of samples in the dataset: {total_number_of_samples}")

    # Fraction of the dataset used for train+val. min() caps it at 1.0, so requesting more samples
    # than exist falls back to the full dataset. With max() the ratio would always be >= 1.0 and
    # --training_samples_number would be silently ignored.
    ratio_of_training_samples = min(args.training_samples_number / total_number_of_samples, 1.0)
    if ratio_of_training_samples < 1.0:
        ds_subset, _ = random_split(ds, [ratio_of_training_samples, 1 - ratio_of_training_samples], generator=generator)
    else:
        ds_subset = ds
    print(f"Training and validation subset size: {len(ds_subset)}")

    train_dataset, val_dataset = random_split(ds_subset, [0.9, 0.1], generator=generator)
    # Create DataLoaders.
    # Workers are disabled: probably because this dataset consists of one large h5 file, using
    # too many workers exceeds the shared memory limit.
    num_workers = 0
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_variable_length)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_variable_length)

    # make directory for saving models
    os.makedirs(model_save_dir, exist_ok=True)

    run_name = args.run_name
    # Per-channel input-noise amplitudes for the 21 observation channels (input_dim_states=21):
    # the first 12 channels, then three groups of 3.
    noise_magnitudes = torch.tensor([0.01] * 12 + [0.1] * 3 + [0.2] * 3 + [0.05] * 3)
    model = resolve_pretrained_module(DynamicsSubmoduleConfig(
                                        class_name="DynamicGRU",
                                        input_dim_states=21,
                                        input_dim_actions=12,
                                        input_timesteps=5,
                                        output_dim=18,
                                        hidden_dims=[512, 256, 128],
                                        representation_dim=64,
                                        input_slice_states=[12, 72], # Not used during pretraining so can be arbitrary
                                        input_slice_actions=[84, 132], # Not used during pretraining so can be arbitrary
                                        input_slice_policy=[270, 318], # Not used during pretraining so can be arbitrary
                                        backbone_output_dim=256,
                                        ))
    # make_model_input_output sends every batch to "cuda" by default, so the model must live there
    # too. This is a no-op if resolve_pretrained_module already placed it on the GPU.
    model.to("cuda")

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Train model
    train(model, train_loader, val_loader, optimizer, args.epochs, run_name, noise_magnitudes, mode=args.dynamics_mode)

    # Save trained model
    final_path = os.path.join(model_save_dir, run_name + ".pt")
    torch.save(model.state_dict(), final_path)
    print(f"Model saved to {final_path}")

# Run script
# Entry point: train only when this file is executed directly. Importing the module still runs the
# top-level code above (CLI argument parsing and the AppLauncher start-up), but not main().
if __name__ == "__main__":
    main()

# python ./rsl_rl/rsl_rl/addons/kinematics/train_kinematics.py --run_name kinematic_mlp_v3
"""

python ./rsl_rl/rsl_rl/addons/dynamics/train_dynamics.py \
--run_name dynamics_loco_only_initial_2M \
--dataset_root_dir logs/datasets/dynamics/velocity_vanilla_EAC_only_initial_4M 

"""