# P4RL: this script plots how the error of the pretrained inverse-dynamics module on its pretraining dataset evolves across saved RL checkpoints.

import torch
import seaborn as sns
from typing import Dict, Literal
import numpy as np
from rsl_rl.addons.invdynamics.inv_dynamics_dataset_paths import dataset_paths
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsMLP
from rsl_rl.addons.invdynamics.inv_dynamics_utils import DynamicSlidingWindowDataset, INVLightningModule
import os
import re
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from lightning import Trainer


def make_inv_input(x, a):
    """
    Split windowed state/action arrays into inverse-dynamics inputs and target.

    Both inputs are indexed as 3-D arrays; axis 1 is treated as the window axis
    (presumably time, given the sliding-window dataset -- confirm against the
    dataset class).

    Args:
        x: State windows, indexed as ``x[:, t, :]``.
        a: Action windows, indexed as ``a[:, t, :]``.

    Returns:
        A tuple ``(x_cut, a_cut, delta_desired, a_target)`` where
        ``x_cut`` is every state except the last one in the window,
        ``a_cut`` is the actions with the first and the last entry dropped,
        ``delta_desired`` is the last state minus the second-to-last state, and
        ``a_target`` is the last action in the window.
    """
    final_state, previous_state = x[:, -1, :], x[:, -2, :]
    state_history = x[:, :-1]
    action_history = a[:, 1:-1]
    final_action = a[:, -1, :]
    return state_history, action_history, final_state - previous_state, final_action


def validate_model(model: InvDynamicsMLP, dataset: DynamicSlidingWindowDataset, sample_num: int = 1000):
    """
    Measure the inverse-dynamics prediction error of `model` on entries sampled from `dataset`.

    Args:
        model: Inverse-dynamics model; its `mode` must be "inv".
        dataset: Dataset queried through `get_sample_entries_in_file`.
        sample_num: Number of entries requested from the dataset.

    Returns:
        The mean absolute difference (a Python float) between the last action of
        each sampled window and the action predicted by `model.forward_inv`.

    Raises:
        AssertionError: If the model is not in "inv" mode.
    """
    assert model.mode == "inv", f'validate_model expects mode "inv", got {model.mode!r}'
    x, a = dataset.get_sample_entries_in_file(sample_num=sample_num)
    x, a = x.to(model.device), a.to(model.device)
    x_cut, a_cut, delta_desired, a_target = make_inv_input(x, a)
    # Evaluation only: no gradients are needed, so avoid building an autograd
    # graph for every checkpoint that gets validated.
    with torch.no_grad():
        out_hat = model.forward_inv(x_cut, a_cut, delta_desired)
        accuracy_error = torch.mean(torch.abs(a_target - out_hat)).item()
    return accuracy_error

# NOTE: alternative implementation of `validate_model`, kept for reference only.
# It runs the dataset through Lightning's `Trainer.validate` instead of evaluating
# a single batch of `sample_num` entries.
#
# def validate_model(model: InvDynamicsMLP, dataset: DynamicSlidingWindowDataset, sample_num: int = 1000):
#     """
#     Validate the model by checking that it can process the "accuracy_errors" dataset correctly.
#     """
#     assert model.mode == "inv"
#     val_loader = DataLoader(dataset, batch_size=20480*4, shuffle=False)
#     trainer = Trainer(accelerator="gpu", devices=1)
#     l_model = INVLightningModule(model=model, mode=model.mode, lr=0.1)
#     val_dict = trainer.validate(l_model, dataloaders=val_loader)
#     return val_dict[0]["val_error"]


def load_pretrained_model_from_RL_model(model, weights_path: str, actor_or_critic: Literal["actor", "critic"] = "critic"):
    """
    Load the pretrained module weights from a saved RL model.

    The checkpoint is expected to hold a "model_state_dict" entry. Only keys that
    start with "<actor_or_critic>.pretrained_module." are used; that prefix is
    stripped before the weights are handed to `model.load_state_dict`.

    Args:
        model: Module receiving the weights (modified in place).
        weights_path: Path to the RL checkpoint file.
        actor_or_critic: Which network's pretrained module to extract.

    Returns:
        The same `model` object, with the extracted weights loaded.
    """
    prefix = actor_or_critic + ".pretrained_module."
    # Deserialize onto the CPU so the checkpoint loads even when the device it
    # was saved from is not available here; `load_state_dict` copies the values
    # into the model's own parameters afterwards.
    state_dict = torch.load(weights_path, map_location="cpu")["model_state_dict"]
    pm_state_dict = {k.removeprefix(prefix): v for k, v in state_dict.items() if k.startswith(prefix)}
    model.load_state_dict(pm_state_dict)
    return model

def get_pretrained_module_RL_accuracy_curve(weights_dir: str, dataset_path: str, sample_num: int, actor_or_critic: Literal["actor", "critic"]="critic"):
    """
    Evaluate the pretrained module stored in every RL checkpoint of a run.

    Checkpoints named exactly ``model_<number>.pt`` inside `weights_dir` are
    processed in ascending order of ``<number>``. For each one, the pretrained
    module weights are loaded into a single reused model and scored with
    `validate_model`.

    Args:
        weights_dir: Directory containing the RL checkpoints.
        dataset_path: Path of the HDF5 dataset used for evaluation.
        sample_num: Number of dataset entries requested per checkpoint.
        actor_or_critic: Which network's pretrained module to evaluate.

    Returns:
        A NumPy array with one error value per checkpoint, in checkpoint order
        (empty if the directory holds no matching checkpoint).
    """
    # get the dataset
    dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, window_size=5, load_into_memory=True)

    inv_dynamics_cfg = {
        "class_name": "InvDynamicsMLP",
        "dim_states": 33,  # NOTE(review): earlier comment read "33 + 9 (contact booleans)" -- confirm which is intended
        "dim_actions": 12,
        "representation_dim": 256,
        "hidden_dims": [512, 256, 128],
        "mode": "inv",
        # "mode": "fwd",
        "lstm_core": False,  # True for LSTM, False for MLP
        # "mode": "dl",
        "activation_name": "elu",  # or "siren"
        "input_timesteps": 5,  # presumably has to agree with the dataset window_size above -- confirm
    }
    # Instantiate the class directly instead of eval()-ing its name; the full
    # config (including "class_name") is still forwarded as keyword arguments.
    model: InvDynamicsMLP = InvDynamicsMLP(device="cuda", **inv_dynamics_cfg)

    # A single pattern both filters out irregular names such as final_model.pt
    # and captures the iteration number used for sorting.
    checkpoint_pattern = re.compile(r"model_(\d+)\.pt")
    numbered_files = []
    for file_name in os.listdir(weights_dir):
        match = checkpoint_pattern.fullmatch(file_name)
        if match is not None:
            numbered_files.append((int(match.group(1)), file_name))
    numbered_files.sort(key=lambda item: item[0])  # sort by iteration number
    sorted_files = [file_name for _, file_name in numbered_files]

    error_list = []
    for file_name in tqdm(sorted_files):
        load_pretrained_model_from_RL_model(model, os.path.join(weights_dir, file_name), actor_or_critic)
        error_list.append(validate_model(model, dataset, sample_num=sample_num))

    return np.array(error_list)

def plot_accuracy_curve(data: Dict, save_dir: str):
    """
    Plot one error curve per run and save the figure as ``accuracy_curve.jpg``.

    Args:
        data: Mapping from run label to a dict providing "accuracy_errors"
            (sequence of error values) and "iteration_offset" (value added to
            the 0-based index to form the x coordinate).
        save_dir: Output directory; created if it does not exist.
    """
    # Draw on a dedicated figure rather than pyplot's implicit current one, so
    # repeated calls do not pile curves onto axes left over from earlier plots.
    fig, ax = plt.subplots()
    try:
        for run_name, property_dict in data.items():
            accuracy_errors = property_dict["accuracy_errors"]
            iterations = np.arange(len(accuracy_errors)) + property_dict["iteration_offset"]

            ax.plot(iterations, accuracy_errors, label=run_name)

        ax.set_xlabel("Iterations")
        ax.set_ylabel("Error on pretraining dataset")
        ax.grid(True)
        ax.legend()
        fig.tight_layout()
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "accuracy_curve.jpg")
        fig.savefig(save_path)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)
    print(f"Plot saved to {save_path}")



def plot_pretrained_RL_accuracy_evolution():
    """
    Compute and plot the pretrained-module error curve for each configured run.

    Every entry of the run table gives a checkpoint directory ("dir_path") and
    an "iteration_offset" for the x axis. The computed curve is stored back into
    the entry under "accuracy_errors" before the table is handed to
    `plot_accuracy_curve`.
    """
    runs = {
        # "Critic burn in + unfrozen pretrained module":
        # {"dir_path": "logs/from_cluster/critic_investigation/inv/burn_in_unfrozen",
        #  "iteration_offset": 0,
        #  },

        # "No critic burn in + unfrozen pretrained module":
        # {"dir_path": "logs/from_cluster/critic_investigation/inv/no_burn_in_unfrozen",
        #  "iteration_offset": 100,
        #  },

        "Mirror RL w/ INV (pedi data)": {
            "dir_path": "logs/rsl_rl/pedipulation_RL_asymmetric_actor_critic/cold_start_inv_pedidata_unfrozen",
            "iteration_offset": 0,
        },
    }

    pretraining_dataset = dataset_paths["Pedipulation Init (Absolute, Noise)"]
    entries_per_checkpoint = 2000

    for run_cfg in runs.values():
        run_cfg["accuracy_errors"] = get_pretrained_module_RL_accuracy_curve(
            run_cfg["dir_path"],
            pretraining_dataset,
            entries_per_checkpoint,
            actor_or_critic="critic",
        )

    plot_accuracy_curve(runs, save_dir="logs/analysis/plots/accuracy_evolution")
    

# Script entry point: generate the accuracy-evolution plot when the file is run directly.
if __name__ == "__main__":
    plot_pretrained_RL_accuracy_evolution()