from rsl_rl.addons.invdynamics.inv_dynamics_utils import *
from typing import List, Dict, DefaultDict
import numpy as np
import os
import torch
import re
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm


def cal_weight_distance(param1: torch.Tensor, param2: torch.Tensor) -> float:
    """Return the mean absolute element-wise difference of two tensors.

    The result is a plain Python float (extracted with ``.item()``), so it can
    be stored in lists or NumPy arrays without keeping tensors alive.
    """
    elementwise_gap = (param1 - param2).abs()
    return elementwise_gap.mean().item()

def cal_weights_diff_from_state_dict(previous_model_sd: Dict[str, torch.Tensor], after_model_sd: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Compute the per-entry weight distance between two state dicts.

    Only entries of ``previous_model_sd`` whose key contains the substring
    "weight" are compared (e.g. biases are skipped). Each selected key is
    looked up in ``after_model_sd`` as well, so a key missing there raises
    ``KeyError``.

    Returns:
        Dict mapping each compared key to the value returned by
        ``cal_weight_distance`` for the two tensors stored under that key.
    """
    return {
        name: cal_weight_distance(before, after_model_sd[name])
        for name, before in previous_model_sd.items()
        if "weight" in name  # skip biases
    }

def get_state_dict_from_file(file: str, dir: str):
    """Load a checkpoint file and return its "model_state_dict" entry.

    Parameters:
        file: File name of the checkpoint inside ``dir``.
        dir: Directory that contains the checkpoint.

    The file is read with ``torch.load`` and is expected to hold a mapping
    with a "model_state_dict" key (presumably an rsl_rl checkpoint — confirm
    against the code that writes these files).
    """
    checkpoint = torch.load(os.path.join(dir, file))
    return checkpoint["model_state_dict"]

def get_single_RL_process_weights_update_per_iteration_dict(dir: str, first_N_iterations: None | int = None) -> Dict[str, np.array]:
    """Get the weights update per iteration from a single RL process. Return a dictionary with keys as layer 
    names and values as numpy arrays of distances.
    
    """

    pt_files_list = [f for f in os.listdir(dir) if f.endswith('.pt')]

    pattern = re.compile(r"^model_\d+\.pt$")
    matching_files = [f for f in pt_files_list if pattern.match(f)] # exclude pt with irregular names like final_model.pt
    sorted_files = sorted(matching_files, key=lambda x: int(re.search(r'\d+', x).group())) # sort by iteration number
    if first_N_iterations is not None:
        sorted_files = sorted_files[:first_N_iterations]
    sorted_indices = [int(re.search(r'\d+', f).group()) for f in sorted_files]
    intervals = np.diff(sorted_indices)  # Calculate intervals between iterations
    intervals_unique = np.unique(intervals)
    if not (intervals_unique.size == 1 and intervals_unique[0] == 1):
        print(f"Warning: The model files saved in {dir} are not consecutive. Observed intervals: {intervals_unique}. ")

    dist_dict: DefaultDict[str, List[float]] = defaultdict(list)

    for i in tqdm(range(len(sorted_files)-1)):
        weights_diff_between_consecutive = cal_weights_diff_from_state_dict(
            get_state_dict_from_file(sorted_files[i], dir), 
            get_state_dict_from_file(sorted_files[i+1], dir))
        for key, value in weights_diff_between_consecutive.items():
            dist_dict[key].append(value)

    # convert the defaultdict to a regular dict with numpy arrays
    regular_dict = {k: np.array(v) for k, v in dist_dict.items()}
    
    return regular_dict

    
def plot_weights_update_curves_compare_runs(
    data: Dict[str, Dict],
    save_dir: str
) -> None:
    """
    Plot one subplot per metric (layer), overlaying the curve of every run.

    Outer dict keys are treated as run names (used in the legend). The metric
    names are taken from the first run; metrics whose name contains "encoder"
    or "output_layer" are skipped. All runs must provide every plotted metric.

    Parameters:
        data: Dict where keys are run names and values are dicts with
            - "curves": dict mapping metric names to sequences of floats,
            - "iteration_offset" (optional, default 0): x value of the first
              point of that run's curves.
        save_dir: Directory in which the figure is saved as
            "weights_update_curves_compare_runs.pdf"; created if missing.

    Raises:
        ValueError: If `data` is empty, if no metric is left after filtering,
            or if a run is missing one of the plotted metrics.
    """
    if not data:
        raise ValueError("`data` must not be empty")

    # Extract metric names from the first run
    first_run_key = next(iter(data))
    metrics = [
        m for m in data[first_run_key]["curves"]
        if "encoder" not in m and "output_layer" not in m  # filter out encoder and output layer
    ]
    if not metrics:
        raise ValueError("No metrics left to plot after filtering out encoder and output layers")

    # Validate up front so that a bad input never leaves a half-drawn figure open.
    for metric in metrics:
        for run_name, property_dict in data.items():
            if metric not in property_dict["curves"]:
                raise ValueError(f"Metric '{metric}' missing in run '{run_name}'")

    num_plots = len(metrics)
    columns = 4
    # True division before ceil: `//` would floor first, giving 0 rows for
    # fewer than 4 metrics and dropping the last partial row otherwise.
    rows = int(np.ceil(num_plots / columns))

    # squeeze=False always yields a 2-D array, so flattening is uniform.
    fig, axes = plt.subplots(rows, columns, figsize=(8*columns, 4*rows), squeeze=False)
    axes = axes.ravel()

    for ax, metric in zip(axes, metrics):
        for run_name, property_dict in data.items():
            curve = property_dict["curves"][metric]
            start_iteration = property_dict.get("iteration_offset", 0)
            ax.plot(np.arange(start_iteration, start_iteration+len(curve)), curve, marker='o', label=run_name)
        ax.set_title(metric)
        ax.set_xlabel("RL Iteration")
        ax.set_ylabel("Weight Update Distance")
        ax.grid(True)
        ax.legend()

    # Hide the cells of the grid that have no metric assigned.
    for unused_ax in axes[num_plots:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "weights_update_curves_compare_runs.pdf")
    plt.savefig(save_path)
    print(f"Plot saved to {save_path}")
    plt.close(fig)

def _plot_layer_group(ax, property_dict: Dict, layers: List[str], title: str) -> None:
    """Draw one curve per layer of a single run onto `ax`, colored along viridis."""
    cmap = plt.get_cmap("viridis")
    # With a single layer the color step would be 0/0; clamp the denominator.
    color_steps = max(len(layers) - 1, 1)
    start_iteration = property_dict.get("iteration_offset", 0)
    for position, metric in enumerate(layers):
        curve = property_dict["curves"][metric]
        ax.plot(np.arange(start_iteration, start_iteration+len(curve)),
                curve,
                marker='o',
                label=metric,
                linewidth=1,
                color=cmap(position / color_steps)
                )
    ax.set_title(title)
    ax.set_xlabel("RL Iteration")
    ax.set_ylabel("Weight Update Distance")
    ax.grid(True)
    if layers:  # a legend without any labelled curve only triggers a warning
        ax.legend(
            loc="center left",
            bbox_to_anchor=(1, 0.5),
            fontsize=6,
        )


def plot_weights_update_curves_compare_layers(
    data: Dict[str, Dict],
    save_dir: str
) -> None:
    """
    Plot two subplots per run: all actor layers in one, all critic layers in the other.

    Outer dict keys are treated as run names (used in the subplot titles); the
    layer names appear in the legends. The layer names are taken from the first
    run: names containing "actor" / "critic" form the two groups, and names
    containing "encoder" or "output_layer" are skipped. All runs must provide
    every plotted layer.

    Parameters:
        data: Dict where keys are run names and values are dicts with
            - "curves": dict mapping layer names to sequences of floats,
            - "iteration_offset" (optional, default 0): x value of the first
              point of that run's curves.
        save_dir: Directory in which the figure is saved as
            "weights_update_curves_compare_layers.pdf"; created if missing.

    Raises:
        ValueError: If `data` is empty, if neither actor nor critic layers are
            left after filtering, or if a run is missing a plotted layer.
    """
    if not data:
        raise ValueError("`data` must not be empty")

    # Extract metric names from the first run
    first_run_key = next(iter(data))
    metrics = list(data[first_run_key]["curves"].keys())
    actor_layers = [k for k in metrics if "actor" in k and "encoder" not in k and "output_layer" not in k]
    critic_layers = [k for k in metrics if "critic" in k and "encoder" not in k and "output_layer" not in k]
    if not actor_layers and not critic_layers:
        raise ValueError("No actor or critic layers left to plot after filtering out encoder and output layers")

    # Validate up front so that a bad input never leaves a half-drawn figure open.
    for run_name, property_dict in data.items():
        for metric in actor_layers + critic_layers:
            if metric not in property_dict["curves"]:
                raise ValueError(f"Metric '{metric}' missing in run '{run_name}'")

    # Each run occupies two neighbouring axes (actor, critic), so the grid is
    # sized by the number of runs, not by the number of layers.
    num_plots = 2 * len(data)
    columns = 4
    # True division before ceil, and an int result: plt.subplots needs an
    # integer row count and `//` would floor away the last partial row.
    rows = int(np.ceil(num_plots / columns))

    # squeeze=False always yields a 2-D array, so flattening is uniform and
    # axes[2*i] is guaranteed to be a single Axes.
    fig, axes = plt.subplots(rows, columns, figsize=(8*columns, 4*rows), squeeze=False)
    axes = axes.ravel()

    for i, (run_name, property_dict) in enumerate(data.items()):
        _plot_layer_group(axes[2*i], property_dict, actor_layers, run_name+ " - Actor Layers")
        _plot_layer_group(axes[2*i + 1], property_dict, critic_layers, run_name+ " - Critic Layers")

    # Hide the cells of the grid that have no run assigned.
    for unused_ax in axes[num_plots:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "weights_update_curves_compare_layers.pdf")
    plt.savefig(save_path)
    print(f"Plot saved to {save_path}")
    plt.close(fig)



def test(trails: Dict[str, Dict], fig_save_path):
    """Compute the weight-update curves of every run and plot the run comparison.

    Parameters:
        trails: Dict mapping run names to per-run dicts. Each per-run dict must
            contain "dir_path" (checkpoint directory). It is modified in place:
            a "curves" entry is added, computed from the first 100 checkpoint
            files of that directory.
        fig_save_path: Passed on as the directory in which the figure is saved.
    """
    for run_config in trails.values():
        run_config["curves"] = get_single_RL_process_weights_update_per_iteration_dict(
            run_config["dir_path"],
            first_N_iterations=100,
        )

    plot_weights_update_curves_compare_runs(trails, fig_save_path)

    # plot_weights_update_curves_compare_layers(trails, fig_save_path)

# def first_study_of_actor_zero_output_pretrain_and_burn_in():
#     trails ={
#     "Zero-output-pretrain + Critic burn-in": 
    
#     {"dir_path": "logs/rsl_rl/pedipulation_residual_RL_INV_test/2025-08-12_17-01-42",
#      "iteration_offset": 0,
#      }, 
#     # 0-10 zero-output-pretrain, 10-50 critic burn-in, 50-200 RL training

#     "Critic burn-in": 
    
#     {"dir_path": "logs/rsl_rl/pedipulation_residual_RL_INV_test/2025-08-12_15-30-09",
#      "iteration_offset": 10,
#      }, 


#      "Cold start RL": 
    
#     {"dir_path": "logs/rsl_rl/pedipulation_residual_RL_INV_test/2025-08-12_15-34-36",
#      "iteration_offset": 50,
#      }, 

    
#     }
#     test(trails, fig_save_path="logs/analysis/plots/weights_update_curves/")


# def critic_investigation_inv():
#     "controlling the actor to be naive MLP, the critic is a hamburger critic with pretrained inv dynamics module"
#     trails ={
#     "Critic burn in + frozen pretrained module": 
#     {"dir_path": "logs/from_cluster/critic_investigation/inv/burn_in_frozen",
#      "iteration_offset": 0,
#      }, 
#     # 0-10 zero-output-pretrain, 10-50 critic burn-in, 50-200 RL training

#     "Critic burn in + unfrozen pretrained module":
#     {"dir_path": "logs/from_cluster/critic_investigation/inv/burn_in_unfrozen",
#      "iteration_offset": 0,
#      },

#      "No critic burn in + frozen pretrained module":
#     {"dir_path": "logs/from_cluster/critic_investigation/inv/no_burn_in_frozen",
#      "iteration_offset": 100,
#      },

#     "No critic burn in + unfrozen pretrained module":
#     {"dir_path": "logs/from_cluster/critic_investigation/inv/no_burn_in_unfrozen",
#      "iteration_offset": 100,
#      },
#     }

#     test(trails, fig_save_path="logs/analysis/plots/critic_investigation_inv/")


def whether_pretrain_inv():
    """Plot weight-update curves for the "pedi data" and "rand init" INV runs.

    Both runs use an iteration offset of 0. The run comparison figure is
    written below ``logs/analysis/plots/critic_investigation_inv/``.

    NOTE(review): the previous description ("controlling the actor to be naive
    MLP, the critic is a hamburger critic with pretrained inv dynamics module")
    and its phase note ("0-10 zero-output-pretrain, 10-50 critic burn-in,
    50-200 RL training") look copied from another study — confirm whether they
    apply to these runs.
    """
    pedi_data_run = {
        "dir_path": "logs/rsl_rl/pedipulation_RL_asymmetric_actor_critic/cold_start_inv_pedidata_unfrozen",
        "iteration_offset": 0,
    }
    rand_init_run = {
        "dir_path": "logs/rsl_rl/pedipulation_RL_asymmetric_actor_critic/cold_start_inv_rand_unfrozen",
        "iteration_offset": 0,
    }

    # Insertion order matters: the first run defines which metrics get plotted.
    trails = {
        "Mirror RL w/ INV (pedi data)": pedi_data_run,
        "Mirror RL w/ INV (rand init)": rand_init_run,
    }

    test(trails, fig_save_path="logs/analysis/plots/critic_investigation_inv/")


if __name__ == "__main__":
    # Script entry point: run the currently selected study. The two calls
    # below are disabled because their function definitions are commented
    # out earlier in this file.
    # first_study_of_actor_zero_output_pretrain_and_burn_in()
    # critic_investigation_inv()
    whether_pretrain_inv()