
import torch
import numpy as np
import matplotlib.pyplot as plt
from rsl_rl.addons.invdynamics.inv_dynamics_sequence_utils import *
from rsl_rl.addons.invdynamics.inv_dynamics_utils import DynamicSlidingWindowDataset, INVLightningModule
from rsl_rl.addons.invdynamics.inv_dynamics_module import InvDynamicsMLP
from torch.utils.data import Dataset, RandomSampler, DataLoader
from rsl_rl.addons.invdynamics.inv_dynamics_dataset_paths import dataset_paths
from typing import Literal
import re


def plot_gt_with_errors(gt_abs, error, num_bins=20, fig_name="logs/analysis/plots/dyna_model_error_detail.png",
                        max_value=2.0, count_ylim=1000):
    """
    Plot a histogram of absolute GT magnitudes overlaid with the mean normalized error per bin.

    The normalized error of a sample is ``|error / gt_abs|``. Samples are grouped into ``num_bins`` equal-width
    bins over ``[0, max_value]``. The left axis shows the sample count per bin. The right axis shows each bin's
    mean normalized error in percent, with dashed reference lines at 40 % and 50 %. The figure is saved to
    ``fig_name`` and then closed.

    Parameters:
    - gt_abs (array-like): Absolute ground-truth magnitudes, a 1D array of shape [D,].
    - error (array-like): Error values for the same samples, a 1D array of shape [D,].
    - num_bins (int): Number of histogram bins (default is 20).
    - fig_name (str): Output path of the figure; missing parent directories are created.
    - max_value (float): Upper end of the binned range and of the x-axis (default is 2.0). Samples outside
      ``[0, max_value]`` are neither counted nor averaged.
    - count_ylim (float): Upper y-limit of the count (histogram) axis (default is 1000).
    """
    import os

    gt_abs = np.asarray(gt_abs, dtype=float)
    error = np.asarray(error, dtype=float)

    # |error / gt| is undefined where gt == 0. Leave those entries NaN (instead of inf/NaN plus a
    # RuntimeWarning) and skip them when averaging, so they cannot poison the first bin's mean.
    ratio = np.full(gt_abs.shape, np.nan)
    np.divide(error, gt_abs, out=ratio, where=gt_abs != 0)
    normalized_error = np.abs(ratio)

    # Histogram of the GT magnitudes over a fixed range so plots stay comparable across runs.
    counts, bin_edges = np.histogram(gt_abs, bins=num_bins, range=(0.0, max_value))
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    last_bin = len(bin_centers) - 1

    # Mean normalized error per bin (NaN for bins without any valid sample).
    mean_normalized_error = np.full(bin_centers.shape, np.nan)
    valid = ~np.isnan(normalized_error)
    for i, (lower, upper) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        # np.histogram closes its last bin on the right; mirror that so the averaged samples match the counts.
        below_upper = gt_abs <= upper if i == last_bin else gt_abs < upper
        in_bin = (gt_abs >= lower) & below_upper & valid
        if np.any(in_bin):
            mean_normalized_error[i] = normalized_error[in_bin].mean()

    fig, ax1 = plt.subplots()

    # Histogram on the left y-axis.
    color = 'tab:blue'
    ax1.set_xlabel('Absolute GT Value')
    ax1.set_ylabel('Count', color=color)
    ax1.bar(bin_centers, counts, width=np.diff(bin_edges), color=color, alpha=0.6, label='GT Histogram')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.set_ylim(0, count_ylim)

    # Mean normalized error (in percent) on a secondary y-axis sharing the x-axis.
    ax2 = ax1.twinx()
    color1 = 'navy'
    ax2.set_ylabel('Normalized Error (%)', color=color1)
    ax2.tick_params(axis='y', labelcolor=color1)
    ax2.set_ylim(0, 200)

    ax2.plot(bin_centers, mean_normalized_error * 100, color=color1, marker='o', label='PIDM Mean Normalized Error')
    ax2.axhline(y=40, color='purple', linestyle='--', label='40% Normalized Error')
    ax2.axhline(y=50, color='blue', linestyle='--', label='50% Normalized Error')
    ax2.set_xlim(0, max_value)

    # A figure-level legend collects the labelled artists of both axes.
    fig.tight_layout()
    fig.legend(loc='upper right', bbox_to_anchor=(1, 1), bbox_transform=ax1.transAxes)
    ax2.grid(True)

    out_dir = os.path.dirname(fig_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(fig_name)
    # Close explicitly: repeated calls would otherwise accumulate open figures.
    plt.close(fig)
    print(f"Figure saved to {fig_name}")



def validate_plot_model_rnn(ckpt_path, dataset_path, vis_traj_num=500,
                            fig_name="logs/analysis/plots/new_model_error_detail.png"):
    """
    Evaluate a sequence-model Lightning checkpoint on one random batch and plot its error breakdown.

    Parameters:
    - ckpt_path (str): Checkpoint loaded with ``P4RLSequenceLightningModule.load_from_checkpoint``.
    - dataset_path (str): Dataset file passed to ``INVSequenceDataset`` as ``h5_path``.
    - vis_traj_num (int): Batch size of the single evaluated batch.
    - fig_name (str): Output path of the plot. The default matches the previously hard-coded path, which is shared
      with ``validate_plot_model_mlp``; pass a distinct name to avoid overwriting.
    """
    dataset = INVSequenceDataset(h5_path=dataset_path)
    sampler = RandomSampler(dataset, replacement=False, num_samples=len(dataset))
    dataloader = DataLoader(dataset, batch_size=vis_traj_num, num_workers=0, sampler=sampler,
                            collate_fn=inv_dynamics_sequence_collate_fn)
    l_model = P4RLSequenceLightningModule.load_from_checkpoint(ckpt_path)

    # Only the first shuffled batch is scored; it holds up to vis_traj_num sequences.
    losses_vec, magnitude_vec = l_model.validate_batch_detail(next(iter(dataloader)))
    plot_gt_with_errors(magnitude_vec.cpu().numpy(), losses_vec.cpu().numpy(), fig_name=fig_name)


def validate_plot_model_mlp(ckpt_path, dataset_path, vis_traj_num=500, weights_path=None,
                            fig_name="logs/analysis/plots/new_model_error_detail.png"):
    """
    Evaluate an ``INVLightningModule`` checkpoint on one validation batch and plot its error breakdown.

    The dataset is wrapped in a ``DynamicSlidingWindowDataset`` (window size 5) and split 90/10 after setting the
    global torch seed to 42. One shuffled batch of ``vis_traj_num`` validation samples is then scored with
    ``validate_batch_detail``.

    Parameters:
    - ckpt_path (str): Lightning checkpoint to evaluate.
    - dataset_path (str): Dataset file passed as ``h5_path``.
    - vis_traj_num (int): Batch size of the single evaluated batch.
    - weights_path (str | None): Optional state-dict file whose weights replace the checkpoint's model weights.
      Previously this override was hard-coded to
      "p4rl_assets/inv_dynamics_new/absolute_0808_normalized_loss.pt" and always applied, so ``ckpt_path``'s
      weights were never evaluated. Pass that path explicitly to reproduce the old behaviour.
    - fig_name (str): Output path of the plot.
    """
    dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, window_size=5)
    # Seed the global torch RNG so the train/validation split is reproducible.
    torch.manual_seed(42)
    # NOTE(review): random_split is not imported explicitly in this file; it presumably comes from the
    # star import of inv_dynamics_sequence_utils — confirm.
    _, validation_dataset = random_split(dataset, [0.9, 0.1])

    sampler_val = RandomSampler(validation_dataset, replacement=False, num_samples=len(validation_dataset))
    dataloader = DataLoader(validation_dataset, batch_size=vis_traj_num, num_workers=0, sampler=sampler_val)

    l_model = INVLightningModule.load_from_checkpoint(ckpt_path)
    if weights_path is not None:
        # map_location="cpu" makes GPU-saved weights loadable anywhere; load_state_dict copies them onto the
        # model's own device.
        l_model.model.load_state_dict(torch.load(weights_path, map_location="cpu"))

    losses_vec, magnitude_vec = l_model.validate_batch_detail(next(iter(dataloader)))
    plot_gt_with_errors(magnitude_vec.cpu().numpy(), losses_vec.cpu().numpy(), fig_name=fig_name)


def load_pretrained_model_from_RL_model(model, weights_path: str, actor_or_critic: Literal["actor", "critic"] = "critic"):
    """
    Load the pretrained-module weights stored inside a saved RL checkpoint into ``model``.

    The checkpoint at ``weights_path`` must be a dict with a ``"model_state_dict"`` entry. Every key of that state
    dict that starts with ``"<actor_or_critic>.pretrained_module."`` is kept with the prefix removed. The result is
    loaded into ``model`` with ``load_state_dict`` (strict by default).

    Parameters:
    - model (torch.nn.Module): Module whose parameter names match the stripped keys; modified in place.
    - weights_path (str): Path to the RL checkpoint.
    - actor_or_critic ("actor" | "critic"): Which sub-network's pretrained module to extract.

    Returns:
    - The same ``model`` instance.

    Raises:
    - KeyError: if the checkpoint has no ``"model_state_dict"`` entry.
    - RuntimeError: if no key carries the requested prefix, or if ``load_state_dict`` rejects the extracted weights.
    """
    prefix = f"{actor_or_critic}.pretrained_module."
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines; load_state_dict then copies the
    # tensors onto whatever device the model's parameters live on.
    state_dict = torch.load(weights_path, map_location="cpu")["model_state_dict"]
    pm_state_dict = {key.removeprefix(prefix): value for key, value in state_dict.items() if key.startswith(prefix)}
    if not pm_state_dict:
        # Fail with a clear message instead of load_state_dict's generic "missing keys" error.
        raise RuntimeError(f"No parameters with prefix {prefix!r} found in {weights_path!r}")
    model.load_state_dict(pm_state_dict)
    return model


def PT_validate_plot_model_mlp(pt_path, dataset_name, vis_traj_num=500, symmetry_left_right_transform=False):
    """
    Evaluate raw ``InvDynamicsMLP`` weights on one validation batch and plot the error breakdown.

    The dataset named ``dataset_name`` is wrapped in a ``DynamicSlidingWindowDataset`` (window size 5) and split
    90/10 after setting the global torch seed to 0. One shuffled batch of ``vis_traj_num`` validation samples is
    scored with ``validate_batch_detail``. The plot is written to ``logs/analysis/plots/<cleaned name>.pdf``.

    Parameters:
    - pt_path (str): File holding the model's state dict (read with ``torch.load`` and passed to
      ``load_state_dict``).
    - dataset_name (str): Key into ``dataset_paths``.
    - vis_traj_num (int): Batch size of the single evaluated batch.
    - symmetry_left_right_transform (bool): Forwarded to ``validate_batch_detail`` (default False, as before).

    Raises:
    - KeyError: if ``dataset_name`` is not a key of ``dataset_paths``.
    """
    try:
        dataset_path = dataset_paths[dataset_name]
    except KeyError:
        raise KeyError(f"Unknown dataset_name {dataset_name!r}; available: {sorted(dataset_paths)}") from None

    dataset = DynamicSlidingWindowDataset(h5_path=dataset_path, window_size=5, load_into_memory=False)
    # Seed the global torch RNG so the train/validation split is reproducible.
    torch.manual_seed(0)
    # NOTE(review): random_split is not imported explicitly in this file; it presumably comes from the
    # star import of inv_dynamics_sequence_utils — confirm.
    _, validation_dataset = random_split(dataset, [0.9, 0.1])

    sampler_val = RandomSampler(validation_dataset, replacement=False, num_samples=len(validation_dataset))
    dataloader = DataLoader(validation_dataset, batch_size=vis_traj_num, num_workers=0, sampler=sampler_val)

    # Architecture must match the one the weights in pt_path were trained with.
    inv_dynamics_cfg = {
        "class_name": "InvDynamicsMLP",
        # NOTE(review): an earlier comment read "33 + 9 (contact booleans)", but the value is 33 — confirm.
        "dim_states": 33,
        "dim_actions": 12,
        "representation_dim": 256,
        "hidden_dims": [512, 256, 128],
        "mode": "inv",
        "lstm_core": False,  # True for LSTM, False for MLP
        "activation_name": "elu",  # or "siren"
        "input_timesteps": 5,
    }

    model = InvDynamicsMLP(device="cpu", **inv_dynamics_cfg)
    # lr has no effect here: this function never takes an optimisation step.
    l_model = INVLightningModule(model=model, mode=model.mode, lr=0.1)
    # The model lives on CPU, so map GPU-saved tensors there too.
    l_model.model.load_state_dict(torch.load(pt_path, map_location="cpu"))

    losses_vec, magnitude_vec = l_model.validate_batch_detail(
        next(iter(dataloader)), symmetry_left_right_transform=symmetry_left_right_transform
    )

    cleaned_filename = dataset_name.replace(" Absolute", "").replace(" ", "_")
    plot_gt_with_errors(magnitude_vec.cpu().numpy(), losses_vec.cpu().numpy(), fig_name=f'logs/analysis/plots/{cleaned_filename}.pdf')


if __name__ == "__main__":
    # Script entry point: runs exactly one evaluation. The lists below record previously evaluated
    # checkpoint/dataset combinations so a different one can be swapped in quickly.

    # Sequence (LSTM) checkpoints, evaluated with
    #   validate_plot_model_rnn(ckpt_path=..., dataset_path=dataset_paths["Pedipulation Init (no random)"], vis_traj_num=500):
    #   logs/pretrain/lightning/inv_vanilla_pedi_LSTM_1layers_1024hidden.ckpt
    #   logs/pretrain/lightning/fwd_vanilla_pedi_LSTM_1layers_1024hidden.ckpt

    # Lightning checkpoints, evaluated with
    #   validate_plot_model_mlp(ckpt_path=..., dataset_path=dataset_paths["Pedipulation Init (no random)"], vis_traj_num=500):
    #   logs/pretrain/lightning/inv_vanilla_pedi_MLP_size_factor_8.ckpt            (MLP x8 size model, 100 epochs)
    #   logs/pretrain/lightning/inv_vanilla_pedi_MLP_size_factor_8_200epochs.ckpt  (MLP x8 size model, 200 epochs)
    #   p4rl_assets/inv_dynamics_new/timestep_5_noisy_1000_epochs.ckpt
    #   logs/pretrain/lightning/fwd_vanilla_pedi_MLP_input_timesteps=5.ckpt     (alt. dataset: "Pedipulation Expert")
    #   logs/pretrain/lightning/inv_vanilla_pedi_TF.ckpt                         (alt. dataset: "Pedipulation Expert")
    #   logs/pretrain/lightning/fwd_vanilla_pedi_TF.ckpt                         (alt. dataset: "Pedipulation Expert")
    #   logs/pretrain/lightning/jacobian_vanilla_pedi_MLP_size_factor_1.ckpt
    #   logs/pretrain/lightning/dl_vanilla_pedi_MLP_size_factor_1.ckpt

    # Alternative raw weights for PT_validate_plot_model_mlp(pt_path=...):
    #   p4rl_assets/inv_dynamics_new/absolute_0811_pedi_output_clamped.pt
    #   p4rl_assets/inv_dynamics_new/absolute_exploration_0813.pt
    #   p4rl_assets/inv_dynamics_new/absolute_0831_exploration_rough.pt
    #   p4rl_assets/inv_dynamics_new/absolute_pedipulation_0908.pt
    # Alternative dataset_name values: "Pedi 100 Absolute", "Exploration Flat"

    PT_validate_plot_model_mlp(
        pt_path="p4rl_assets/inv_dynamics_new/absolute_0906_mixed.pt",
        dataset_name="Exploration Rough",
        vis_traj_num=500,
    )