from rsl_rl.addons.invdynamics.inv_dynamics_utils import *
from typing import List, Dict
from rsl_rl.addons.invdynamics.samples_visualization import InvSamplesVisualization, sns_jointplot
from rsl_rl.addons.invdynamics.inv_dynamics_dataset_paths import dataset_paths
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsMLP, InvDynamicsTransformer
from torch.utils.data import random_split
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# NOT TESTED YET


def cal_weight_distance(param1: torch.Tensor, param2: torch.Tensor) -> float:
    """Return the norm of the difference between two tensors as a Python float.

    The tensors are subtracted element-wise and the default tensor norm
    (Frobenius / L2 over all elements) of the result is taken.

    Args:
        param1: First tensor.
        param2: Second tensor; must be subtractable from ``param1``.

    Returns:
        The scalar distance, extracted with ``.item()``.
    """
    delta = param1 - param2
    return delta.norm().item()

def cal_model_weights_dist(init_model: InvDynamicsMLP, trained_model: InvDynamicsMLP) -> Dict[str, float]:
    """Measure how far each Linear layer's weight moved between two models.

    The sub-modules ``state_encoder``, ``action_encoder``, ``main_MLP`` and
    ``output_layer`` are read from both models and iterated in lockstep.
    For every position whose layer in ``init_model`` is a ``torch.nn.Linear``,
    the distance between the two weight tensors is recorded. Biases are not
    compared.

    Args:
        init_model: Model holding the reference (initial) weights.
        trained_model: Model holding the weights to compare against.
            NOTE(review): assumed to have the same architecture as
            ``init_model``; ``zip`` silently stops at the shorter sub-module.

    Returns:
        Mapping from ``"<module_name>.<layer_index>"`` to the weight distance,
        in sub-module order and then layer order.
    """
    dist_dict: Dict[str, float] = {}
    module_names = ("state_encoder", "action_encoder", "main_MLP", "output_layer")
    for module_name in module_names:
        init_module = getattr(init_model, module_name)
        trained_module = getattr(trained_model, module_name)
        for idx, (i_layer, t_layer) in enumerate(zip(init_module, trained_module)):
            if not isinstance(i_layer, torch.nn.Linear):
                continue
            # Key on the layer's position rather than its repr: two Linear
            # layers with identical shapes inside one sub-module have the same
            # repr and would overwrite each other's entry.
            dist_dict[f"{module_name}.{idx}"] = cal_weight_distance(i_layer.weight, t_layer.weight)
    return dist_dict

def _load_inv_dynamics_model(model_cls, inv_dynamics_cfg: dict, ckpt_path: str):
    """Build ``model_cls`` on CUDA from the config and load the state dict stored at ``ckpt_path``."""
    # NOTE(review): the whole config, including "class_name", is forwarded as
    # keyword arguments exactly as before; confirm the constructor accepts it.
    model = model_cls(device="cuda", **inv_dynamics_cfg)
    model.load_state_dict(torch.load(ckpt_path))
    return model


def compare_updated_distances(inv_dynamics_cfg: dict, model1_init_path: str, model1_trained_path: str, model2_init_path: str, model2_trained_path: str,):
    """Plot per-layer weight distances (init vs. trained) for two models side by side.

    Four models are instantiated on CUDA from ``inv_dynamics_cfg`` and loaded
    from the given checkpoint paths. For each init/trained pair the per-layer
    weight distances are computed with ``cal_model_weights_dist`` and the two
    results are drawn as a grouped bar chart saved to
    ``logs/analysis/plots/weight_distances_comparison.png``.

    Args:
        inv_dynamics_cfg: Model configuration; ``"class_name"`` selects the
            model class and the full dict is passed to its constructor.
        model1_init_path: State-dict file with model 1's initial weights.
        model1_trained_path: State-dict file with model 1's trained weights.
        model2_init_path: State-dict file with model 2's initial weights.
        model2_trained_path: State-dict file with model 2's trained weights.

    Returns:
        None. The plot is written to disk.
    """
    import os

    import matplotlib.pyplot as plt
    import pandas as pd

    # Resolve the class once so that init and trained models always share the
    # same type (previously the trained models were hard-coded to
    # InvDynamicsMLP regardless of the configured class name).
    # NOTE(review): eval() on a config value is only acceptable while the
    # config is trusted, in-repo data; switch to an explicit registry if this
    # ever reads external input.
    model_cls = eval(inv_dynamics_cfg["class_name"])

    model1_init = _load_inv_dynamics_model(model_cls, inv_dynamics_cfg, model1_init_path)
    model1_trained = _load_inv_dynamics_model(model_cls, inv_dynamics_cfg, model1_trained_path)
    model2_init = _load_inv_dynamics_model(model_cls, inv_dynamics_cfg, model2_init_path)
    model2_trained = _load_inv_dynamics_model(model_cls, inv_dynamics_cfg, model2_trained_path)

    model1_init_dist = cal_model_weights_dist(model1_init, model1_trained)
    model2_init_dist = cal_model_weights_dist(model2_init, model2_trained)

    # Plot the distances comparison.
    # Both dicts are expected to list the same layers in the same order, since
    # both models are built from the same config; rows are labelled with
    # model 1's keys.
    df = pd.DataFrame({
        "Module": list(model1_init_dist.keys()),
        "Model 1 Distance": list(model1_init_dist.values()),
        "Model 2 Distance": list(model2_init_dist.values())
    })

    df.set_index("Module", inplace=True)
    df.plot(kind='bar', figsize=(12, 6))
    plt.title("Weight Distances Comparison")
    plt.ylabel("Distance")
    plt.xticks(rotation=45)
    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target
    # folder exists before writing.
    output_path = "logs/analysis/plots/weight_distances_comparison.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path, dpi=300)

def main():
    """Run the weight-distance comparison for two hard-coded checkpoint pairs."""
    inv_dynamics_cfg = {
        "class_name": "InvDynamicsMLP",
        "dim_states": 33,  # original note: 33 + 9 (contact booleans) -- TODO confirm which applies
        "dim_actions": 12,
        "representation_dim": 256,
        "hidden_dims": [512, 256, 128],
        "mode": "fwd",  # alternative value: "inv"
        "lstm_core": False,  # True for LSTM, False for MLP
        "activation_name": "elu",  # or "siren"
        "input_timesteps": 5,
    }

    # Order matters: (model 1 init, model 1 trained, model 2 init, model 2 trained).
    checkpoint_paths = (
        "logs/pretrain/lightning/inv_dynamics_mlp_init.ckpt",
        "logs/pretrain/lightning/inv_dynamics_mlp_trained.ckpt",
        "logs/pretrain/lightning/inv_dynamics_mlp2_init.ckpt",
        "logs/pretrain/lightning/inv_dynamics_mlp2_trained.ckpt",
    )

    compare_updated_distances(inv_dynamics_cfg, *checkpoint_paths)