
import os
import re
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, RandomSampler, DataLoader
from torch.utils.data import random_split

from rsl_rl.addons.invdynamics.inv_dynamics_sequence_utils import *
from rsl_rl.addons.invdynamics.inv_dynamics_utils import DynamicSlidingWindowDataset, INVLightningModule
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsMLP
from rsl_rl.addons.invdynamics.inv_dynamics_dataset_paths import dataset_paths


def plot_gt_with_errors(gt_abs, error, num_bins=20, fig_name="logs/analysis/plots/dyna_model_error_detail.png",
                        value_range=(0, 2.0)):
    """
    Plot a histogram of ground-truth magnitudes with the per-bin mean normalized error overlaid.

    The normalized error of a sample is ``|error / gt_abs|``. Samples are grouped into
    equal-width bins over ``value_range``. The left y-axis shows the number of samples per bin
    (axis fixed to 0..1000); the right y-axis shows the mean normalized error per bin in percent
    (axis fixed to 0..200) together with dashed reference lines at 40% and 50%.
    The figure is saved to ``fig_name`` and then closed.

    Parameters:
    - gt_abs (array): Ground-truth magnitudes. Samples outside ``value_range`` are ignored.
    - error (array): Error values, same shape as ``gt_abs``.
    - num_bins (int): The number of bins to use for the histogram (default is 20).
    - fig_name (str): Output file path; its parent directory is created if it does not exist.
    - value_range (tuple): ``(low, high)`` range of the histogram and of the x-axis (default is (0, 2.0)).
    """
    gt_abs = np.asarray(gt_abs)
    error = np.asarray(error)
    x_min, x_max = value_range

    # NOTE: samples with gt_abs == 0 yield inf/nan here and propagate into the mean of their bin.
    normalized_error = np.abs(error / gt_abs)

    # Compute the histogram of gt
    counts, bin_edges = np.histogram(gt_abs, bins=num_bins, range=(x_min, x_max))
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Mean normalized error per bin; bins without samples stay NaN so they show up as gaps.
    mean_normalized_error = np.full_like(bin_centers, np.nan)
    last_bin = len(bin_centers) - 1
    for i, (lower, upper) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        if i == last_bin:
            # np.histogram closes the last bin on the right; use the same rule so the
            # per-bin means cover exactly the samples that were counted.
            in_bin = (gt_abs >= lower) & (gt_abs <= upper)
        else:
            in_bin = (gt_abs >= lower) & (gt_abs < upper)
        if np.any(in_bin):
            mean_normalized_error[i] = np.mean(normalized_error[in_bin])

    fig, ax1 = plt.subplots()

    # Histogram on the left y-axis
    color = 'tab:blue'
    ax1.set_xlabel('Absolute GT Value')
    ax1.set_ylabel('Count', color=color)
    ax1.bar(bin_centers, counts, width=np.diff(bin_edges), color=color, alpha=0.6, label='GT Histogram')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.set_ylim(0, 1000)

    # Line plots on a secondary y-axis
    ax2 = ax1.twinx()

    color1 = 'navy'
    ax2.set_ylabel('Normalized Error (%)', color=color1)
    ax2.tick_params(axis='y', labelcolor=color1)
    ax2.set_ylim(0, 200)

    ax2.plot(bin_centers, mean_normalized_error*100, color=color1, marker='o', label='PIDM Mean Normalized Error')
    ax2.axhline(y=40, color='purple', linestyle='--', label='40% Normalized Error')
    ax2.axhline(y=50, color='blue', linestyle='--', label='50% Normalized Error')

    ax2.set_xlim(x_min, x_max)

    # Add legends
    fig.tight_layout()
    fig.legend(loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=ax1.transAxes)

    # The grid is drawn on the secondary axes (the current axes after twinx()).
    ax2.grid(True)

    # A bare file name has no directory part; os.makedirs("") would raise.
    out_dir = os.path.dirname(fig_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(fig_name)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Figure saved to {fig_name}")




def plot_gt_histogram(gt_abs_flat, gt_abs_rough, num_bins=20, fig_name="logs/analysis/plots/dyna_model_error_detail.png",
                      value_range=(0, 2.0)):
    """
    Plot side-by-side histograms of two sets of action magnitudes.

    Both inputs are binned with the same equal-width bins over ``value_range``. Each bin is
    drawn as two half-width bars: the left one (blue, 'Exploration (Flat)') for
    ``gt_abs_flat`` and the right one (orange, 'Exploration (Rough)') for ``gt_abs_rough``.
    The count axis is fixed to 0..1000. The figure is saved to ``fig_name`` and then closed.

    Parameters:
    - gt_abs_flat (array): Magnitudes plotted as the 'Exploration (Flat)' series.
    - gt_abs_rough (array): Magnitudes plotted as the 'Exploration (Rough)' series.
    - num_bins (int): The number of bins to use for the histograms (default is 20).
    - fig_name (str): Output file path; its parent directory is created if it does not exist.
    - value_range (tuple): ``(low, high)`` range of the histograms and of the x-axis (default is (0, 2.0)).
    """
    x_min, x_max = value_range

    # Same bins/range for both series, so the edges of the second call match the first.
    counts, bin_edges = np.histogram(gt_abs_flat, bins=num_bins, range=(x_min, x_max))
    counts_rough, _ = np.histogram(gt_abs_rough, bins=num_bins, range=(x_min, x_max))

    # Compute the bin centers
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    fig, ax1 = plt.subplots()

    fs = 16  # desired font size
    color = 'tab:blue'

    # Set axis labels with font size
    ax1.set_xlabel('Action Magnitude ($|a_t-q_t|$, rad)', fontsize=fs)
    ax1.set_ylabel('Num. Action Sample', fontsize=fs)

    # Each series gets half of the bin width; the bars are shifted so they sit next to each other.
    width = 0.5*np.diff(bin_edges)
    ax1.bar(bin_centers-0.5*width, counts, width=width, color=color, alpha=1, label='Exploration (Flat)')
    ax1.bar(bin_centers+0.5*width, counts_rough, width=width, color='orange', alpha=1, label='Exploration (Rough)')

    # Set tick parameters with font size
    ax1.tick_params(axis='x', labelsize=fs)
    ax1.tick_params(axis='y', labelsize=fs)

    # Set axis limits
    ax1.set_ylim(0, 1000)
    ax1.set_xlim(x_min, x_max)

    # Add legend with matching font size
    ax1.legend(fontsize=fs)

    fig.tight_layout()
    ax1.grid(True)

    # A bare file name has no directory part; os.makedirs("") would raise.
    out_dir = os.path.dirname(fig_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(fig_name)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Figure saved to {fig_name}")


def plot_absolute_and_normalized_errors(gt_abs, error, num_bins=20, fig_name="logs/analysis/plots/dyna_model_error_detail.png",
                                        value_range=(0, 2.0)):
    """
    Plot the per-bin mean error and mean normalized error against the ground-truth magnitude.

    Samples are grouped into equal-width bins of ``gt_abs`` over ``value_range``. The left
    y-axis (fixed to 0..1.7) shows the mean of ``error`` per bin; the right y-axis (fixed to
    0..200) shows the mean of ``|error / gt_abs|`` per bin in percent, with a dashed reference
    line at 40%. The figure is saved to ``fig_name`` and then closed.

    Parameters:
    - gt_abs (array): Ground-truth magnitudes. Samples outside ``value_range`` are ignored.
    - error (array): Error values, same shape as ``gt_abs``. They are averaged as given
      (no absolute value is taken), so the 'Absolute Error' axis presumably expects
      non-negative inputs -- TODO confirm against the caller.
    - num_bins (int): The number of bins (default is 20).
    - fig_name (str): Output file path; its parent directory is created if it does not exist.
    - value_range (tuple): ``(low, high)`` range of the bins and of the x-axis (default is (0, 2.0)).
    """
    gt_abs = np.asarray(gt_abs)
    error = np.asarray(error)
    x_min, x_max = value_range

    # NOTE: samples with gt_abs == 0 yield inf/nan here and propagate into the mean of their bin.
    normalized_error = np.abs(error / gt_abs)

    # Only the bin edges are needed here; the per-bin counts are not plotted.
    bin_edges = np.histogram_bin_edges(gt_abs, bins=num_bins, range=(x_min, x_max))
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Per-bin means; bins without samples stay NaN so they show up as gaps in the lines.
    mean_normalized_error = np.full_like(bin_centers, np.nan)
    mean_error = np.full_like(bin_centers, np.nan)
    last_bin = len(bin_centers) - 1
    for i, (lower, upper) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        if i == last_bin:
            # Follow numpy's histogram convention: the last bin includes its right edge.
            in_bin = (gt_abs >= lower) & (gt_abs <= upper)
        else:
            in_bin = (gt_abs >= lower) & (gt_abs < upper)
        if np.any(in_bin):
            mean_normalized_error[i] = np.mean(normalized_error[in_bin])
            mean_error[i] = np.mean(error[in_bin])

    fig, ax1 = plt.subplots()

    # Mean error on the left y-axis
    fs = 16
    color = 'tab:blue'
    ax1.set_xlabel('Action Magnitude ($|a_t-q_t|$, rad)', fontsize=fs)
    ax1.set_ylabel('Absolute Error (rad)', color=color, fontsize=fs)
    ax1.tick_params(axis='y', labelcolor=color, labelsize=fs)
    ax1.tick_params(axis='x', labelsize=fs)
    ax1.plot(bin_centers, mean_error, color=color, marker='x')
    ax1.set_ylim(0, 1.7)

    # Mean normalized error on a secondary y-axis
    ax2 = ax1.twinx()

    color2 = 'tab:green'
    ax2.set_ylabel('Normalized Error (%)', color=color2, fontsize=fs)
    ax2.tick_params(axis='y', labelcolor=color2, labelsize=fs)
    ax2.set_ylim(0, 200)

    ax2.plot(bin_centers, mean_normalized_error*100, color=color2, marker='o')
    ax2.axhline(y=40, color='blue', linestyle='--',)

    ax2.set_xlim(x_min, x_max)

    fig.tight_layout()

    # The grid is drawn on the secondary axes (the current axes after twinx()).
    ax2.grid(True)

    # A bare file name has no directory part; os.makedirs("") would raise.
    out_dir = os.path.dirname(fig_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(fig_name)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Figure saved to {fig_name}")





def validate_plot_model_rnn(ckpt_path, dataset_path, vis_traj_num=500):
    """
    Evaluate a sequence-model checkpoint on one random batch and plot its error profile.

    Builds an ``INVSequenceDataset`` from ``dataset_path``, draws a single batch of up to
    ``vis_traj_num`` items without replacement, runs ``validate_batch_detail`` of the
    ``P4RLSequenceLightningModule`` restored from ``ckpt_path`` on it, and passes the
    resulting magnitudes and losses to ``plot_gt_with_errors``.
    """
    sequence_dataset = INVSequenceDataset(h5_path=dataset_path)
    shuffled_order = RandomSampler(
        sequence_dataset,
        replacement=False,
        num_samples=len(sequence_dataset),
    )
    loader = DataLoader(
        sequence_dataset,
        batch_size=vis_traj_num,
        num_workers=0,
        sampler=shuffled_order,
        collate_fn=inv_dynamics_sequence_collate_fn,
    )
    lightning_model = P4RLSequenceLightningModule.load_from_checkpoint(ckpt_path)

    # Only the first batch of the shuffled loader is evaluated.
    first_batch = next(iter(loader))
    per_sample_loss, per_sample_magnitude = lightning_model.validate_batch_detail(first_batch)

    plot_gt_with_errors(
        per_sample_magnitude.cpu().numpy(),
        per_sample_loss.cpu().numpy(),
        fig_name="logs/analysis/plots/new_model_error_detail.png",
    )


def validate_plot_model_mlp(ckpt_path, dataset_path, vis_traj_num = 500):
    """
    Evaluate an MLP inverse-dynamics model on one random validation batch and plot its error profile.

    The dataset at ``dataset_path`` is loaded with a window size of 5 and split 90/10 with a
    fixed seed (42); a single batch of up to ``vis_traj_num`` samples is drawn from the 10%
    validation part. The resulting magnitudes and losses are passed to ``plot_gt_with_errors``.
    """
    dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, window_size=5)
    # Seed the global RNG so both the split and the sampled batch are reproducible.
    torch.manual_seed(42)
    _, validation_dataset = random_split(dataset, [0.9, 0.1])

    sampler_val = RandomSampler(validation_dataset, replacement=False, num_samples=len(validation_dataset))
    dataloader = DataLoader(validation_dataset, batch_size=vis_traj_num, num_workers=0, sampler=sampler_val)

    l_model = INVLightningModule.load_from_checkpoint(ckpt_path)
    # NOTE(review): the weights of the inner model restored from ``ckpt_path`` are immediately
    # replaced by this hard-coded state dict -- confirm that this override is intended.
    # map_location="cpu" lets weights saved from a CUDA device load on CPU-only machines;
    # load_state_dict copies the values onto whatever device the model already lives on.
    l_model.model.load_state_dict(
        torch.load("p4rl_assets/inv_dynamics_new/absolute_0808_normalized_loss.pt", map_location="cpu")
    )

    # Only the first batch of the shuffled validation loader is evaluated.
    losses_vec, magnitude_vec = l_model.validate_batch_detail(next(iter(dataloader)))
    plot_gt_with_errors(magnitude_vec.cpu().numpy(), losses_vec.cpu().numpy(), fig_name=f"logs/analysis/plots/new_model_error_detail.png")


def load_pretrained_model_from_RL_model(model, weights_path: str, actor_or_critic: Literal["actor", "critic"] = "critic"):
    """
    Load the pretrained module weights from a saved RL model.

    The checkpoint at ``weights_path`` must contain a ``"model_state_dict"`` entry. Every key
    of that state dict starting with ``"<actor_or_critic>.pretrained_module."`` is selected,
    the prefix is stripped, and the result is loaded into ``model`` via ``load_state_dict``.

    Parameters:
    - model: Module receiving the weights; it is modified in place.
    - weights_path (str): Path to the saved RL checkpoint.
    - actor_or_critic: Which branch of the RL model to take the pretrained module from.

    Returns:
    - The same ``model`` object, with the extracted weights loaded.
    """
    prefix = f"{actor_or_critic}.pretrained_module."
    # map_location="cpu" lets checkpoints saved from a CUDA device load on CPU-only machines;
    # load_state_dict copies the values onto whatever device the model already lives on.
    state_dict = torch.load(weights_path, map_location="cpu")["model_state_dict"]
    pm_state_dict = {k.removeprefix(prefix): v for k, v in state_dict.items() if k.startswith(prefix)}
    model.load_state_dict(pm_state_dict)
    return model


def get_plot_data(pt_path, dataset_name, vis_traj_num = 500, dim_states=33, dim_actions=12, embodiment="anymal"):
    """
    Evaluate an inverse-dynamics MLP on one random validation batch and return the raw results.

    The dataset registered under ``dataset_name`` in ``dataset_paths`` is opened (window size 5,
    not loaded into memory) and split 90/10 with a fixed seed (0). An ``InvDynamicsMLP`` is built
    on the CPU, its weights are loaded from ``pt_path``, and ``validate_batch_detail`` is run on a
    single batch of up to ``vis_traj_num`` samples drawn from the 10% validation part, with the
    left/right symmetry transform disabled.

    Parameters:
    - pt_path (str): Path to the saved state dict of the inverse-dynamics model.
    - dataset_name (str): Key into ``dataset_paths``.
    - vis_traj_num (int): Batch size, i.e. the number of samples evaluated.
    - dim_states (int): ``dim_states`` passed to the model configuration.
    - dim_actions (int): ``dim_actions`` passed to the model configuration.
    - embodiment (str): Forwarded to ``validate_batch_detail``.

    Returns:
    - Tuple ``(magnitudes, losses)`` of numpy arrays as produced by ``validate_batch_detail``.
    """
    dataset_path = dataset_paths[dataset_name]
    dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, window_size=5, load_into_memory=False)
    # Seed the global RNG so both the split and the sampled batch are reproducible.
    torch.manual_seed(0)
    _, validation_dataset = random_split(dataset, [0.9, 0.1])

    sampler_val = RandomSampler(validation_dataset, replacement=False, num_samples=len(validation_dataset))
    dataloader = DataLoader(validation_dataset, batch_size=vis_traj_num, num_workers=0, sampler=sampler_val)

    inv_dynamics_cfg = {
        "class_name": "InvDynamicsMLP",
        "dim_states": dim_states,
        "dim_actions": dim_actions,
        "representation_dim": 256,
        "hidden_dims": [512, 256, 128],
        "mode": "inv",
        "lstm_core": False,  # True for LSTM, False for MLP
        "activation_name": "elu",  # or "siren"
        "input_timesteps": 5,
    }

    model = InvDynamicsMLP(device="cpu", **inv_dynamics_cfg)
    l_model = INVLightningModule(model=model, mode=model.mode, lr=0.1)
    # The model lives on the CPU, so map the stored tensors there as well; without
    # map_location a state dict saved from a CUDA device fails to load on CPU-only machines.
    l_model.model.load_state_dict(torch.load(pt_path, map_location="cpu"))

    # For debugging, the weights can instead be taken from an RL checkpoint with
    # load_pretrained_model_from_RL_model(l_model.model, weights_path=..., actor_or_critic="critic").

    # The symmetry transform can be turned on or off here; only the first batch is evaluated.
    losses_vec, magnitude_vec = l_model.validate_batch_detail(next(iter(dataloader)),
                                                              symmetry_left_right_transform=False,
                                                              embodiment=embodiment)

    return magnitude_vec.cpu().numpy(), losses_vec.cpu().numpy()


def main_plots_for_paper():
    """
    Produce the combined absolute / normalized error figure for the flat and rough exploration datasets.

    The same model weights are evaluated on "Exploration Flat" and then on "Exploration Rough"
    (5000 samples each); the results are concatenated in that order and plotted into a single PDF.
    """
    checkpoint = "p4rl_assets/inv_dynamics_new/absolute_0906_mixed.pt"

    magnitudes = []
    losses = []
    for terrain_dataset in ("Exploration Flat", "Exploration Rough"):
        magnitude_vec, losses_vec = get_plot_data(
            pt_path=checkpoint,
            dataset_name=terrain_dataset,
            vis_traj_num=5000,
        )
        magnitudes.append(magnitude_vec)
        losses.append(losses_vec)

    plot_absolute_and_normalized_errors(
        np.concatenate(magnitudes),
        np.concatenate(losses),
        fig_name="logs/analysis/plots/exploration_absolute_and_normalized_errors.pdf",
    )
    

    

def plots_for_paper_rebuttal_sensitivity_to_accuracy(step_str="0100", vis_traj_num=5000):
    """
    Produce the absolute / normalized error figure for one intermediate training checkpoint.

    The checkpoint ``epoch_0000_step_00<step_str>.pt`` is evaluated on "Exploration Flat" and
    "Exploration Rough"; the results are concatenated in that order and plotted into a PDF
    whose file name contains ``step_str``.

    Parameters:
    - step_str (str): Step suffix of the checkpoint file name (default is "0100").
    - vis_traj_num (int): Number of samples evaluated per dataset (default is 5000).
    """
    # An earlier variant of this analysis used per-epoch checkpoints from
    # p4rl_assets/inv_dynamics_new/rebuttal/accuracy_sensitivity_testing_models/2025-11-21_14-58-55/epoch_00<epoch>.pt
    pt_path = f"p4rl_assets/models_for_recording_videos/rebuttal_sensitivity/epoch_0000_step_00{step_str}.pt"
    magnitude_vec_flat, losses_vec_flat = get_plot_data(
                            pt_path=pt_path,
                            dataset_name="Exploration Flat",
                            vis_traj_num=vis_traj_num)
    magnitude_vec_rough, losses_vec_rough = get_plot_data(
                            pt_path=pt_path,
                            dataset_name="Exploration Rough",
                            vis_traj_num=vis_traj_num)

    plot_absolute_and_normalized_errors(np.concatenate([magnitude_vec_flat, magnitude_vec_rough]),
                                        np.concatenate([losses_vec_flat, losses_vec_rough]),
                                        fig_name=f"logs/analysis/plots/exploration_absolute_and_normalized_errors_rebuttal_epoch00_step_{step_str}.pdf")
    


def plots_for_paper_rebuttal_g1_analysis():
    """
    Produce the absolute / normalized error figure for the G1 exploration dataset.

    Evaluates the G1 model weights on 500 samples of "Exploration G1 80 Iter"
    (83 state dimensions, 37 action dimensions, embodiment "g1") and plots the result into a PDF.
    """
    evaluation_kwargs = {
        "pt_path": "p4rl_assets/inv_dynamics_new/rebuttal/g1_exploration_80_it_100_epoch_2211.pt",
        "dataset_name": "Exploration G1 80 Iter",
        "vis_traj_num": 500,
        "dim_states": 83,
        "dim_actions": 37,
        "embodiment": "g1",
    }
    magnitude_vec, losses_vec = get_plot_data(**evaluation_kwargs)

    plot_absolute_and_normalized_errors(
        magnitude_vec,
        losses_vec,
        fig_name="logs/analysis/plots/exploration_g1_rebuttal.pdf",
    )


if __name__ == "__main__":
    # Script entry point: regenerates the main paper figure.
    # The other figure generators are kept commented out; enable them one at a time as needed.

    main_plots_for_paper()

    # plots_for_paper_rebuttal_sensitivity_to_accuracy()

    # plots_for_paper_rebuttal_g1_analysis()

    # NOTE(review): `PT_validate_plot_model_mlp` is not defined in this file; the keyword
    # arguments below match `get_plot_data` -- confirm the intended target before re-enabling.
    # PT_validate_plot_model_mlp(
    #                         # pt_path="p4rl_assets/inv_dynamics_new/absolute_0811_pedi_output_clamped.pt",
    #                         # pt_path="p4rl_assets/inv_dynamics_new/absolute_exploration_0813.pt",
    #                         # pt_path="p4rl_assets/inv_dynamics_new/absolute_0831_exploration_rough.pt",
    #                         pt_path="p4rl_assets/inv_dynamics_new/absolute_0906_mixed.pt",
    #                         # pt_path="p4rl_assets/inv_dynamics_new/absolute_pedipulation_0908.pt",

    #                         # dataset_name = "Pedi 100 Absolute",
    #                         dataset_name = "Exploration Rough",
    #                         # dataset_name = "Exploration Flat",
    #                         vis_traj_num=500)
    
   
    