import h5py
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os

# Joint labels, one per column of the joint-position arrays used below.
JOINT_NAMES = [
    "LF_HAA", "LH_HAA", "RF_HAA", "RH_HAA",
    "LF_HFE", "LH_HFE", "RF_HFE", "RH_HFE",
    "LF_KFE", "LH_KFE", "RF_KFE", "RH_KFE",
]

# Recording under analysis (pedipulation with kinematics).
# Alternative recording (vanilla locomotion):
#   "logs/rsl_rl/anymal_d_flat/default_vanilla/recordings/model_299.h5"
dataset_path = "logs/rsl_rl/pedipulation_EAC_with_pretrained_kinematics/2025-04-11_17-54-12/recordings/recordings/model_3999.h5"
# Generated plots are written next to the recording file.
log_dir = Path(dataset_path).parent

# Copy the three datasets into memory so the HDF5 file can be closed immediately.
with h5py.File(dataset_path, "r") as h5_file:
    observations = h5_file["observations"][:]
    actions = h5_file["actions"][:]
    dones = h5_file["dones"][:]

# The actions are treated as joint position targets.
joint_position_target = actions
# Measured joint positions are read from observation columns 9..20.
# NOTE(review): the vanilla locomotion recording appears to need columns 12..23
# instead (observations[..., 12:24]) -- confirm against the task's observation layout.
joint_position_actual = observations[..., 9:21]

# Mean absolute tracking error per joint, averaged over the first two axes
# (presumably segment and time step -- TODO confirm).
abs_error = np.abs(joint_position_target - joint_position_actual)
mean_error_joint = abs_error.mean(axis=(0, 1))

# Report the per-joint errors followed by their overall average.
for joint_idx, name in enumerate(JOINT_NAMES):
    print(f"{name}: {mean_error_joint[joint_idx]}")
mean_error = mean_error_joint.mean()
print(f"Mean Error: {mean_error}")



# Function to plot comparison for a given segment index
def plot_comparison(segment_index, steps_plot=150):
    """Plot measured vs. target joint positions for one recorded segment.

    One subplot is drawn per joint (two per row) showing the reading, the
    target, their difference (target - reading) with a shaded error area, and
    a dashed red vertical line at every step whose ``dones`` entry equals 1.
    The figure is written to ``comparison_segment_<segment_index>.png`` inside
    ``log_dir`` and then displayed.

    Args:
        segment_index: Index along the first axis of the module-level
            ``joint_position_actual``, ``joint_position_target`` and ``dones``
            arrays.
        steps_plot: Maximum number of leading steps of the segment to plot.
    """
    reading_seg = joint_position_actual[segment_index, :steps_plot]
    target_segment = joint_position_target[segment_index, :steps_plot]
    dones_segment = dones[segment_index, :steps_plot]
    n_dim = reading_seg.shape[-1]

    # Derive the row count from the number of joints instead of hard-coding it,
    # so any joint count fits on the grid (12 joints still give 6 rows).
    n_col = 2
    n_row = max(1, -(-n_dim // n_col))
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_row, n_col, figsize=(20, 5 * n_dim), squeeze=False)

    # Both are identical for every subplot, so compute them once.
    steps = np.arange(reading_seg.shape[0])
    done_steps = [t for t, done in enumerate(dones_segment) if done == 1]

    for i in range(n_dim):
        ax = axes[i // n_col][i % n_col]

        # Plot the reading and target curves
        ax.plot(reading_seg[:, i], label="reading", color="blue")
        ax.plot(target_segment[:, i], label="target", color="purple")

        # Plot the difference curve and shade the area between it and the x-axis
        difference = target_segment[:, i] - reading_seg[:, i]
        ax.plot(difference, label="difference", color="orange")
        ax.fill_between(
            steps,
            difference,
            0,
            color="orange",
            alpha=0.3,
            label="error area",
        )

        ax.set_ylim(-1.5, 1.5)

        # Fall back to a generic title if there are more columns than names.
        title = JOINT_NAMES[i] if i < len(JOINT_NAMES) else f"joint_{i}"
        ax.set_title(title)
        ax.legend()

        # Add vertical red lines where dones are 1
        for t in done_steps:
            ax.axvline(x=t, color="red", linestyle="--")

    # Hide grid cells that have no joint assigned (odd joint counts).
    for unused_ax in axes.flat[n_dim:]:
        unused_ax.set_visible(False)

    fig.tight_layout()
    # Save BEFORE showing: once the window opened by plt.show() is closed the
    # figure is destroyed, and a later savefig would write an empty image.
    fig.savefig(os.path.join(log_dir, f"comparison_segment_{segment_index}.png"))
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Plot the first 10 segments, or fewer when the recording holds less than 10,
# so the loop never indexes past the end of the recorded data.
for i in range(min(10, len(joint_position_actual))):
    plot_comparison(i)