"""Launch Isaac Sim Simulator first. We don't need it here, but it's necessary to avoid import errors."""


import argparse

from isaaclab.app import AppLauncher

# add argparse arguments
parser = argparse.ArgumentParser(
    description="Fit dynamics-prediction heads on pretrained policy checkpoints and record the fitting errors."
)
# NOTE: --num_envs is not read by the rest of this script; it is kept so existing command lines keep working.
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to spawn.")
parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=1024, help="Batch size")
# The split fractions used by main() are 1/1.1 and 0.1/1.1, i.e. train:val = 10:1.
parser.add_argument(
    "--training_samples_number",
    type=int,
    default=10000,
    help="Total number of samples drawn from each dataset. Those samples will be split into train and val datasets with ratio 10:1.",
)
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
# NOTE: headless mode is not forced here; pass --headless on the command line
# (see the usage example at the end of this file).
# parser.add_argument("--record_supporting_point", action="store_true", default=False, help="NOT IMPLEMENTED.")

# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args = parser.parse_args()

# launch omniverse app (must happen before the imports below, see the module docstring)
app_launcher = AppLauncher(args)
simulation_app = app_launcher.app

"""Rest everything follows."""

import os
import torch
import torch.nn as nn
import torch.optim as optim
import h5py
import wandb
import argparse
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from rsl_rl.rsl_rl.modules.actor_critic import ActorCriticForAnalysis
from einops import rearrange, repeat
from rsl_rl.rsl_rl.addons.dynamics_analysis.data_utils import DynamicsAnalysisDataset
from typing import Dict
from p4rl.rsl_rl.rl_cfg import (
    RslRlPpoActorCriticForAnalysisCfg
)
from scipy.stats import norm
import glob
import pickle
import re
from tensordict import TensorDict


def extract_number(s):
    """Return the first run of digits found in ``s`` as an int, or None if ``s`` has no digit."""
    found = re.search(r"(\d+)", s)
    if found is None:
        return None
    return int(found.group(1))


def get_target(s_t_plus_1, s_t):
    """
    Return the change between two time steps, restricted to columns 9..20 of the last axis.

    The original author describes these columns as the joint positions.
    """
    # A wider selection, list(range(0, 6)) + list(range(9, 21)), was used at some point.
    joint_columns = [*range(9, 21)]
    delta = s_t_plus_1 - s_t
    return delta[..., joint_columns]

def get_first_order_extrapolation(s_t, dt=0.005):
    """
    Return a first-order (velocity * dt) estimate of the per-step state change.

    args:
        s_t: state tensor; columns 21..32 of the last axis are used as velocities
             (presumably the velocities of the quantities selected by get_target -- confirm
             against the observation layout)
        dt: time step in seconds; the default 0.005 corresponds to 200Hz
    returns:
        tensor shaped like s_t with the last axis reduced to 12 entries
    """
    # A wider selection, list(range(0, 6)) + list(range(9, 21)), was used at some point.
    indices = list(range(21, 33))
    velo = s_t[..., indices]
    return velo * dt


def subset_loss_l1(mu_pred, s_t_plus_1, s_t, beta=0.002):
    """
    Mean absolute error (plain L1) between the prediction and the actual state change.

    args:
        mu_pred: predicted mean, expected as [num_samples, dim]
        s_t_plus_1 and s_t: both expected as [num_samples, dim]; the target is get_target(s_t_plus_1, s_t)
        beta: not used by the current plain-L1 loss; it belonged to an earlier smooth-L1
              variant and is kept so existing calls stay valid
    """
    return nn.functional.l1_loss(mu_pred, get_target(s_t_plus_1, s_t), reduction='mean')


def to_device(batch, device):
    """Move ``batch`` to ``device``; a tuple or list is moved element-wise and returned as a list."""
    if not isinstance(batch, (tuple, list)):
        return batch.to(device)
    return [item.to(device) for item in batch]

# Training Function
def train(model: ActorCriticForAnalysis, train_loader, val_loader, optimizer, epochs, run_name):
    """
    Train the model's dynamics predictions and run a validation pass after every epoch.

    Every batch must unpack into (s_t, a_t, s_t_plus_1); a_t is not used. Per-batch
    losses and per-epoch averages are logged to the currently active wandb run.

    args:
        model: the model to train; batches are moved to the device of its parameters
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        optimizer: optimizer
        epochs: number of epochs to train
        run_name: name of the run for logging (not used inside this function at the moment)
    """
    # Follow the model's device instead of assuming CUDA: the caller falls back to CPU
    # when CUDA is unavailable, and a hard-coded "cuda" would then fail.
    device = next(model.parameters()).device

    for epoch in range(epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0
        for batch in tqdm(train_loader):
            s_t, a_t, s_t_plus_1 = to_device(batch, device)
            optimizer.zero_grad()
            mu_pred = model.get_dynamic_predictions(s_t)
            loss = subset_loss_l1(mu_pred, s_t_plus_1, s_t)
            loss.backward()
            optimizer.step()

            # Log a plain float so the logger never holds on to the autograd graph.
            loss_value = loss.item()
            wandb.log({"epoch": epoch+1, "loss": loss_value})
            epoch_loss += loss_value

        # nan (rather than a ZeroDivisionError) marks "no batches were seen".
        num_train_batches = len(train_loader)
        avg_loss = epoch_loss / num_train_batches if num_train_batches else float("nan")

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader):
                s_t, a_t, s_t_plus_1 = to_device(batch, device)
                mu_pred = model.get_dynamic_predictions(s_t)
                loss_value = subset_loss_l1(mu_pred, s_t_plus_1, s_t).item()
                wandb.log({"epoch": epoch+1, "val_loss": loss_value})
                val_loss += loss_value

        num_val_batches = len(val_loader)
        avg_val_loss = val_loss / num_val_batches if num_val_batches else float("nan")

        # Log training and validation loss
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        wandb.log({"epoch": epoch + 1, "loss_train_epoch": avg_loss, "val_epoch_loss": avg_val_loss})

        # Training is fast, so no checkpoint is written per epoch; the caller saves the final weights.

def evaluate(model: ActorCriticForAnalysis, train_loader):
    """
    Compute the average per-batch L1 loss of the model over a data loader, without gradients.

    The caller passes the training loader, so the returned value is the fitting (train) error.
    args:
        model: the model to evaluate; batches are moved to the device of its parameters
        train_loader: DataLoader to evaluate on; every batch must unpack into (s_t, a_t, s_t_plus_1)
    returns:
        mean of the per-batch losses (nan if the loader yields no batches)
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA, so a CPU-only model also works.
    device = next(model.parameters()).device

    total_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(train_loader):
            s_t, a_t, s_t_plus_1 = to_device(batch, device)
            mu_pred = model.get_dynamic_predictions(s_t)
            total_loss += subset_loss_l1(mu_pred, s_t_plus_1, s_t).item()

    num_batches = len(train_loader)
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Train Loss: {avg_loss:.4f}")
    return avg_loss

def _append_result(bucket, key, value):
    """Append ``value`` to the list stored under ``bucket[key]``, creating the list on first use."""
    bucket.setdefault(key, []).append(value)


def _save_results(analysis_results, results_path):
    """Pickle ``analysis_results`` to ``results_path``, creating the parent directory if it is missing."""
    parent_dir = os.path.dirname(results_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(results_path, "wb") as f:
        pickle.dump(analysis_results, f)


def _list_subdirs(path):
    """Return the names of the sub-directories of ``path`` in sorted order."""
    return sorted(name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name)))


def _fit_dynamics_head(h5_path, layer, run_name, model_save_dir):
    """
    Train a dynamics head attached to ``layer`` on the data in ``h5_path`` and return its train-set loss.

    The model weights are loaded from the ``.pt`` file next to ``h5_path`` whose name is the dataset
    name with "data" replaced by "model". The trained model is saved to ``model_save_dir`` and the
    run is logged to wandb under ``run_name``.
    """
    ds = DynamicsAnalysisDataset(h5_path)
    total_number_of_samples = len(ds)
    print(f"Total number of samples in the dataset: {total_number_of_samples}")
    if total_number_of_samples == 0:
        raise ValueError(f"Dataset {h5_path} contains no samples.")

    # Integer lengths (clamped to the dataset size) instead of fractions: a fraction above 1
    # is rejected by random_split when the dataset is smaller than the requested sample count.
    num_selected = min(args.training_samples_number, total_number_of_samples)
    ds_subset, _ = random_split(ds, [num_selected, total_number_of_samples - num_selected])
    print(f"Total samples for training and validation combined: {len(ds_subset)}")

    # train : val = 10 : 1
    train_dataset, val_dataset = random_split(ds_subset, [1.0/1.1, 1-1.0/1.1])
    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)

    # Initialize wandb and log configurations
    wandb.init(project="dynamics_analysis_with_raw_obs_Dec2", name=run_name)
    # Log argparser arguments
    wandb.config.update(vars(args))  # Log all arguments from argparse

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    policy_cfg = RslRlPpoActorCriticForAnalysisCfg(
        init_noise_std=1.0,
        actor_hidden_dims=[256, 256, 256],
        critic_hidden_dims=[256, 256, 256],
        activation="elu",
        layer_to_dynamics=layer,
        dim_dynamics_hidden=64,
        dim_dynamics_prediction=12,
    )

    # Log model configuration
    wandb.config.update(policy_cfg.to_dict())  # Log the ActorCriticForAnalysis configuration

    obs = TensorDict({"policy": torch.zeros((1, 48))})  # dummy obs
    obs_groups = {"policy": ["policy"], "critic": ["policy"]}  # default groups
    model = ActorCriticForAnalysis(
        obs,
        obs_groups,
        num_actions=12,
        **policy_cfg.to_dict()
    ).to(device)

    # The checkpoint sits next to the dataset: "data..." + ".h5" -> "model..." + ".pt"
    load_model_name = os.path.basename(h5_path).replace("data", "model").replace(".h5", ".pt")
    load_model_path = os.path.join(os.path.dirname(h5_path), load_model_name)
    # Load weights
    model.load_trunk(load_model_path)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Train model
    train(model, train_loader, val_loader, optimizer, args.epochs, run_name)

    # Save trained model
    save_path = os.path.join(model_save_dir, run_name + ".pt")
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    wandb.finish()
    # The reported number is the loss on the TRAINING loader, i.e. the fitting error.
    return evaluate(model, train_loader)


def _ground_truth_stats(h5_path):
    """
    Return ``(gt_magnitude, first_order_extrapolation_error)`` for the whole dataset in ``h5_path``.

    gt_magnitude is the mean absolute target; the second value is the mean absolute difference
    between the target and its first-order extrapolation.
    """
    ds = DynamicsAnalysisDataset(h5_path)
    # One batch holding the entire dataset.
    loader = DataLoader(ds, batch_size=len(ds), shuffle=False, num_workers=4)
    s_t, a_t, s_t_plus_1 = next(iter(loader))
    target = get_target(s_t_plus_1, s_t)
    gt_magnitude = torch.abs(target).mean().item()
    first_order_extrapolation_error = torch.abs(target - get_first_order_extrapolation(s_t)).mean().item()
    return gt_magnitude, first_order_extrapolation_error


# Main function: runs the analysis over every task / run / layer / dataset and pickles the results.
def main(get_fitting_errors = True, cal_gt_magnitude=False):
    """
    Run the dynamics analysis over all tasks found on disk and store the results in a pickle file.

    args:
        get_fitting_errors: train a dynamics head for each layer in -1..2 and each dataset,
            and record its loss on the training loader
        cal_gt_magnitude: record the mean absolute target and the first-order extrapolation
            error of each dataset (computed once per dataset, during the layer-0 pass)

    The results file is rewritten after every task with the structure
        {task_name: {"layer_<j>" | "gt_magnitude" | "first_order_extrapolation_error":
                        {checkpoint_number: [value_of_run_1, value_of_run_2, ...]}}}
    where checkpoint_number is the integer extract_number() finds in the dataset file name.
    Entries that stayed empty are dropped.
    """

    assert get_fitting_errors or cal_gt_magnitude, "At least one of get_fitting_errors or cal_gt_magnitude must be True."

    torch.manual_seed(24)

    model_save_dir = "./logs/pretrain/dynamics_analysis/series_exp_rebuttal"
    # make directory for saving models
    os.makedirs(model_save_dir, exist_ok=True)

    dir_path = "p4rl_assets/dynamics_analysis_base_models/rebuttal"
    results_path = "p4rl_assets/dynamics_analysis_base_models/results_rebuttal/dynamics_analysis_series_rebuttal_results.pkl"

    layers = range(-1, 3)

    # Sorted directory listings keep the processing order (and therefore the seeded random
    # splits) independent of the filesystem; stray files are skipped.
    tasks = _list_subdirs(dir_path)
    analysis_results: Dict[str, Dict[str, Dict[str, list]]] = {}

    run_counter = 0

    # Written up front, presumably so an unusable output path fails before any training starts.
    # NOTE: this overwrites results of a previous invocation.
    _save_results(analysis_results, results_path)

    for task_name in tasks:
        task_results = {"layer_0": {}, "layer_1": {}, "layer_2": {}, "layer_-1": {},}
        if cal_gt_magnitude:
            task_results["gt_magnitude"] = {}
            task_results["first_order_extrapolation_error"] = {}
        task_path = os.path.join(dir_path, task_name)
        # list runs
        runs = _list_subdirs(task_path)
        for run in runs:
            run_path = os.path.join(task_path, run)
            h5_paths = sorted(glob.glob(f"{run_path}/*.h5", recursive=False))
            for j in layers:
                for h5_path in h5_paths:
                    h5_name = os.path.basename(h5_path)
                    model_checkpoint_num = extract_number(h5_name)

                    if get_fitting_errors:
                        run_name = f"{task_name}_{run}_" + h5_name.replace(".h5", f"_layer_{j}")
                        fitting_error = _fit_dynamics_head(h5_path, j, run_name, model_save_dir)
                        _append_result(task_results[f"layer_{j}"], model_checkpoint_num, fitting_error)

                    # Dataset statistics do not depend on the layer, so compute them in one pass only.
                    if j==0 and cal_gt_magnitude:
                        gt_magnitude, first_order_extrapolation_error = _ground_truth_stats(h5_path)
                        _append_result(task_results["gt_magnitude"], model_checkpoint_num, gt_magnitude)
                        _append_result(
                            task_results["first_order_extrapolation_error"],
                            model_checkpoint_num,
                            first_order_extrapolation_error,
                        )
                        print(f"GT magnitude: {gt_magnitude}, First-order extrapolation error: {first_order_extrapolation_error}")

                    run_counter += 1
                    print("*" * 50)
                    # Rough estimate: assumes every task/run has as many runs/datasets as the current one.
                    print(f"Completed runs: {run_counter}. Estimated total runs: ~{len(tasks)*len(runs)*len(layers)*len(h5_paths)}")
                    print("*" * 50)

        # delete empty entries in task_results
        task_results = {k: v for k, v in task_results.items() if v}

        analysis_results[task_name] = task_results
        # save after every task so finished tasks survive a later crash
        _save_results(analysis_results, results_path)


# Run script
if __name__ == "__main__":
    # Run both analyses: fitting errors and ground-truth / first-order-extrapolation statistics.
    main(get_fitting_errors=True, cal_gt_magnitude=True)
    # close the simulation app that was launched at the top of the file
    simulation_app.close()


"""

python ./rsl_rl/rsl_rl/addons/dynamics_analysis/train_dynamics_predition_layer_series_exp.py \
--headless \
--num_envs 1

"""