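"""Pretrain the kinematic MLP submodule on an HDF5 dataset.

The Isaac Sim simulator is launched first: training does not use it, but it
must be running to avoid import errors in the modules imported further down.
"""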


import argparse

from isaaclab.app import AppLauncher

# Command-line arguments for this training script (AppLauncher appends its own below).
parser = argparse.ArgumentParser(description="Pretrain the kinematic MLP submodule on an HDF5 dataset.")
# NOTE(review): --num_envs is not read anywhere in this script (presumably left
# over from an environment template); kept so existing command lines still parse.
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument("--data", type=str, default="./logs/datasets/kine_data_v3.h5", help="Path to HDF5 dataset")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=2048, help="Batch size")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
parser.add_argument("--output_feature_per_body", type=int, default=3, help="Output feature dimension, e.g. 3 for translation, 6 for pose")
parser.add_argument("--model_save_dir", type=str, default="./logs/pretrain", help="Directory to save trained model")
parser.add_argument("--run_name", type=str, default="kinematic_mlp_4_layer_out_128", help="Run name for logging")
# NOTE(review): headless mode is NOT forced here; any headless switch comes from
# the AppLauncher CLI arguments added below — confirm the flag name there.

# Append the AppLauncher CLI arguments, then parse everything at once.
AppLauncher.add_app_launcher_args(parser)
args = parser.parse_args()

# Launch the Omniverse app; it must be running before the imports further down.
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Everything below runs after the app has been launched."""


import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.addons.kinematics.modules import KinematicSubmoduleConfig
from rsl_rl.addons.resolve_submodule import resolve_pretrained_module
from einops import rearrange

@torch.no_grad()
def mean_distance(y_pred, y_gt, body_indices=(3, 4, 7, 8, 11, 12, 15, 16)):
    """Return the mean Euclidean error between predicted and ground-truth bodies.

    Only the bodies listed in ``body_indices`` are compared (by default the
    shank and foot bodies). Evaluated without autograd, so it is safe to call
    on predictions that are part of a training graph.

    Args:
        y_pred: Flattened predictions of shape ``[B, K * D]`` with
            ``K = len(body_indices)``; predicted body ``k`` occupies columns
            ``k * D:(k + 1) * D``.
        y_gt: Ground truth of shape ``[B, N, D]`` covering all ``N`` bodies;
            must use the same feature size ``D`` as the predictions.
        body_indices: Indices along ``y_gt``'s body axis that correspond, in
            order, to the ``K`` predicted bodies.

    Returns:
        Scalar tensor: L2 norm over ``D``, averaged over batch and bodies.
    """
    # Equivalent to einops 'b (n d) -> b n d' with n = number of selected bodies.
    y_pred = y_pred.reshape(y_pred.shape[0], len(body_indices), -1)
    y_gt = y_gt[:, list(body_indices)]
    return torch.mean(torch.norm(y_pred - y_gt, dim=-1))

def test_criterion(y_pred, y_gt, body_indices=(3, 4, 7, 8, 11, 12, 15, 16)):
    """Return the smooth-L1 training loss (``beta=0.02``) on the selected bodies.

    Args:
        y_pred: Flattened predictions of shape ``[B, K * D]`` with
            ``K = len(body_indices)``; predicted body ``k`` occupies columns
            ``k * D:(k + 1) * D``.
        y_gt: Ground truth of shape ``[B, N, D]`` covering all ``N`` bodies;
            must use the same feature size ``D`` as the predictions.
        body_indices: Indices along ``y_gt``'s body axis that correspond, in
            order, to the ``K`` predicted bodies (default: shanks and feet).

    Returns:
        Scalar loss tensor, averaged over all compared elements; differentiable
        with respect to ``y_pred``.
    """
    # Equivalent to einops 'b (n d) -> b n d' with n = number of selected bodies.
    y_pred = y_pred.reshape(y_pred.shape[0], len(body_indices), -1)
    return nn.functional.smooth_l1_loss(y_pred, y_gt[:, list(body_indices)], beta=0.02)


# The name starts with ``test_``: opt out of pytest/nose collection so this loss
# is never mistaken for a test case when the module is imported from a test file.
test_criterion.__test__ = False


# Load Data from HDF5
def load_h5_dataset(h5_file):
    """Load the input/target arrays of a kinematics dataset from an HDF5 file.

    The file must contain the datasets ``"X"`` and ``"Y"`` plus the
    comma-separated string attributes ``input_joint_names``,
    ``output_body_names`` and ``output_feature_dimensions``. The attributes
    are only printed as a summary; they are not returned.

    Args:
        h5_file: Path to the HDF5 file.

    Returns:
        Tuple ``(X, Y)`` of float32 CPU tensors with the shapes stored in the
        file. NOTE(review): presumably ``Y`` is (num_samples, num_bodies,
        num_features), matching how ``main`` and the loss index it — confirm
        against the dataset writer.
    """

    def _split_names(value):
        # h5py returns variable-length string attributes as ``str`` but
        # fixed-length ones as ``numpy.bytes_`` (a ``bytes`` subclass), whose
        # ``split(",")`` would raise TypeError; decode first so both work.
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        return value.split(",")

    with h5py.File(h5_file, "r") as f:
        # as_tensor reuses the freshly read array when it is already float32,
        # avoiding a second full copy of the dataset.
        X = torch.as_tensor(f["X"][:], dtype=torch.float32)
        Y = torch.as_tensor(f["Y"][:], dtype=torch.float32)
        input_joint_names = _split_names(f.attrs["input_joint_names"])
        output_body_names = _split_names(f.attrs["output_body_names"])
        output_feature_dimensions = _split_names(f.attrs["output_feature_dimensions"])
        print(f"Loaded dataset with {len(X)} samples. Input joint names: {input_joint_names}, Output body names: {output_body_names}, Output feature dimensions: {output_feature_dimensions}")
    return X, Y

# Training Function
def _run_epoch(model, loader, criterion, optimizer=None):
    """Make one pass over *loader*; return the per-sample mean ``(loss, distance)``.

    When *optimizer* is given, the model is put in train mode and updated after
    every batch. Otherwise it runs in eval mode with gradients disabled.
    Averages are weighted by batch size, so a short final batch does not skew
    them. An empty loader yields ``(nan, nan)`` instead of a ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()
    total_loss = 0.0
    total_dist = 0.0
    num_samples = 0
    with torch.set_grad_enabled(training):
        for batch_x, batch_y in tqdm(loader):
            if training:
                optimizer.zero_grad()
            predictions = model(batch_x)
            loss = criterion(predictions, batch_y)
            dist = mean_distance(predictions, batch_y)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = batch_x.shape[0]
            total_loss += loss.item() * batch_size
            total_dist += dist.item() * batch_size
            num_samples += batch_size
    if num_samples == 0:
        return float("nan"), float("nan")
    return total_loss / num_samples, total_dist / num_samples


def train(model, train_loader, val_loader, criterion, optimizer, epochs, run_name):
    """Train *model* for *epochs* epochs, validating and logging after each one.

    Args:
        model: Module called as ``model(batch_x)``.
        train_loader: Iterable of ``(batch_x, batch_y)`` batches used for updates.
        val_loader: Iterable of ``(batch_x, batch_y)`` batches used for validation.
        criterion: Callable ``criterion(predictions, batch_y)`` returning a
            scalar loss tensor.
        optimizer: Optimizer over the model's parameters.
        epochs: Number of passes over ``train_loader``.
        run_name: Name of the Weights & Biases run in project "kinematic-mlp".
    """
    wandb.init(project="kinematic-mlp", name=run_name)
    try:
        for epoch in range(epochs):
            avg_loss, avg_dist = _run_epoch(model, train_loader, criterion, optimizer)
            avg_val_loss, avg_val_dist = _run_epoch(model, val_loader, criterion)

            # Log training and validation metrics
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            wandb.log({"epoch": epoch + 1, "loss": avg_loss, "val_loss": avg_val_loss, "dist": avg_dist, "val_dist": avg_val_dist})
    finally:
        # Close the run even if training raises (e.g. out of memory, Ctrl-C).
        wandb.finish()

# Entry point; configuration comes from the module-level `args` parsed above.
def main():
    """Load the dataset, train the kinematic submodule and save its weights.

    All configuration is read from the module-level ``args`` namespace. The
    trained ``state_dict`` is written to ``<model_save_dir>/<run_name>.pt``.
    """
    import os  # local import: only needed to build the output path

    # Prefer the GPU, but fall back to the CPU when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    run_name = args.run_name

    # Create the output directory before training, so that a missing path
    # cannot make torch.save() fail after the whole run has finished.
    os.makedirs(args.model_save_dir, exist_ok=True)
    save_path = os.path.join(args.model_save_dir, run_name + ".pt")

    # Load the whole dataset and move it to the device once, so every batch
    # produced by the DataLoaders below is already on the device.
    X, Y = load_h5_dataset(args.data)
    X = X.to(device)
    Y = Y.to(device)

    # Keep only the first `output_feature_per_body` features of every body
    # (e.g. 3 = translation); this must match the model's per-body output width.
    Y = Y[..., :args.output_feature_per_body]

    # Random 90% / 10% train/validation split.
    dataset = TensorDataset(X, Y)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    model = resolve_pretrained_module(
        KinematicSubmoduleConfig(
            num_output_features_per_body=args.output_feature_per_body,
            backbone_output_dim=128,
        ),
        device,
    )
    criterion = test_criterion  # smooth-L1 loss on the selected bodies
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    train(model, train_loader, val_loader, criterion, optimizer, args.epochs, run_name)

    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

# Run script; always shut down the Omniverse app launched at the top, even on errors.
if __name__ == "__main__":
    try:
        main()
    finally:
        simulation_app.close()

# Example usage:
# python ./rsl_rl/rsl_rl/addons/kinematics/train_kinematics.py --run_name kinematic_mlp_v3