import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pickle

def plot_errors_dict(data_dict_task: dict, save_path: str):
    """Plot per-series error (mean +/- std over runs) across checkpoints and save it.

    Args:
        data_dict_task: Mapping from a series key to a dict that maps a
            checkpoint index to a list of per-run error values, e.g.::

                {
                    "layer_0": {0: [run_1, run_2, ...], 1000: [...], ...},
                    "layer_1": {0: [...], 1000: [...], ...},
                    ...
                }

            Checkpoints are plotted in sorted-key order. Every checkpoint of a
            series must contain the same number of runs.

            Only the keys "layer_0", "layer_1", "layer_2", "gt_magnitude" and
            "first_order_extrapolation_error" are drawn. Other keys and empty
            series are ignored.

            Values are multiplied by 180/pi before plotting. This presumably
            converts radians to degrees (the y axis is labelled "Degrees");
            confirm against the producer of the data.
        save_path: Output file path. Missing parent directories are created.
    """
    labels_map = {
        # "layer_-1": "Only Raw Observations",
        "layer_0": "Layer 0",
        "layer_1": "Layer 1",
        "layer_2": "Layer 2",
        "gt_magnitude": "Zero-Order Extrapolation Error",
        "first_order_extrapolation_error": "First-Order Extrapolation Error",
    }

    fig = plt.figure(figsize=(8, 5))
    try:
        for key, layer_data in data_dict_task.items():
            # Only convert series we actually draw, so unused or empty entries
            # cannot break the plot.
            if key not in labels_map or not layer_data:
                continue

            # Rows = checkpoints (sorted), columns = runs.
            checkpoints_idx = sorted(layer_data.keys())
            data_np = np.array([layer_data[idx] for idx in checkpoints_idx], dtype=float)
            data_np = data_np * 180.0 / np.pi

            # Checkpoints are spread evenly over [0, 1] (fraction of training);
            # the first checkpoint is deliberately left out of the plot.
            x = np.linspace(0, 1, data_np.shape[0])[1:]
            data_np = data_np[1:, :]

            mean = data_np.mean(axis=1)
            std = data_np.std(axis=1)

            plt.plot(x, mean, marker='x', label=labels_map[key])
            plt.fill_between(x, mean - std, mean + std, alpha=0.2)

        # plt.title('Errors of Dynamics Prediction Based on Representations from Different Layers')
        plt.xlabel('Training Progress (Fraction of Total Iterations)')
        plt.ylabel('Degrees')

        plt.xticks(np.linspace(0, 1, 9))  # 0, 0.125, 0.25, ..., 1.0
        plt.xlim(0.1, 1.025)
        # plt.yscale('log')
        plt.ylim(0.6, 1.7)

        plt.grid(True)
        plt.legend(loc='upper right')

        # Remove surrounding whitespace.
        plt.tight_layout(pad=0)

        # dirname is "" for a bare filename; makedirs("") would raise.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # Save with minimal margins.
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure: callers invoke this once per task, and
        # unclosed pyplot figures accumulate for the life of the process.
        plt.close(fig)

    print(f"Plot saved to {save_path}")



def main_separate_dicts(
    plot_gt=False,
    *,
    results_path='p4rl_assets/dynamics_analysis_base_models/results_rebuttal/dynamics_analysis_series_rebuttal_results_9_checkpoints_raw_obs.pkl',
    extrapolation_path='p4rl_assets/dynamics_analysis_base_models/results_rebuttal/dynamics_analysis_series_rebuttal_results_zero_or_first_order_extrapolation.pkl',
    output_dir='logs/analysis/plots',
):
    """Load per-task results, optionally merge extrapolation baselines, and plot each task.

    Args:
        plot_gt: If True, also load ``extrapolation_path``. Its
            ``"gt_magnitude"`` and ``"first_order_extrapolation_error"`` series
            are then copied into every task before plotting.
        results_path: Pickle file holding a dict keyed by task name. Each value
            is passed to ``plot_errors_dict``.
        extrapolation_path: Pickle file keyed by task name whose per-task dicts
            contain the two baseline series. Only read when ``plot_gt`` is True.
            An alternative file previously used here was
            ``.../dynamics_analysis_series_rebuttal_results_gt_magnitude.pkl``.
        output_dir: Directory for the per-task PDFs, named
            ``dynamics_analysis_rebuttal_{task}.pdf``.

    Raises:
        KeyError: If ``plot_gt`` is True and a task, or one of the two baseline
            series, is missing from the extrapolation file.
    """
    # NOTE: pickle can execute arbitrary code; only load trusted local result files.
    with open(results_path, 'rb') as f:
        data_dict = pickle.load(f)

    if plot_gt:
        with open(extrapolation_path, 'rb') as f:
            extrapolation_dict = pickle.load(f)
        # Merge the baseline series into each task's dict (in place).
        for task, data_dict_task in data_dict.items():
            task_extra = extrapolation_dict[task]
            for series_key in ("gt_magnitude", "first_order_extrapolation_error"):
                data_dict_task[series_key] = task_extra[series_key]

    for task, data_dict_task in data_dict.items():
        save_path = f'{output_dir}/dynamics_analysis_rebuttal_{task}.pdf'
        plot_errors_dict(data_dict_task, save_path)

    
def main_one_dict(
    *,
    results_path='p4rl_assets/dynamics_analysis_base_models/results_rebuttal/dynamics_analysis_series_rebuttal_results_0312_loco.pkl',
    output_dir='logs/analysis/plots',
):
    """Load a single per-task results pickle and write one plot per task.

    Args:
        results_path: Pickle file holding a dict keyed by task name. Each value
            is passed to ``plot_errors_dict``. A sibling file,
            ``.../dynamics_analysis_series_rebuttal_results_0312_pedi.pkl``,
            was the alternative input referenced here.
        output_dir: Directory for the per-task PDFs, named
            ``dynamics_analysis_rebuttal_{task}.pdf``.
    """
    # NOTE: pickle can execute arbitrary code; only load trusted local result files.
    with open(results_path, 'rb') as f:
        data_dict = pickle.load(f)

    for task, data_dict_task in data_dict.items():
        save_path = f'{output_dir}/dynamics_analysis_rebuttal_{task}.pdf'
        plot_errors_dict(data_dict_task, save_path)


if __name__ == "__main__":
    # Default run: per-task plots with the zero/first-order extrapolation
    # baselines overlaid. To plot from the single results file loaded by
    # main_one_dict(), call that function here instead.
    main_separate_dicts(plot_gt=True)