import os
import numpy as np
from torch.multiprocessing import Pool
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

def load_scalar_from_run(run_dir, tag, max_step=1e7):
    """
    Read one scalar tag from a TensorBoard run directory.

    Args:
        run_dir: directory handed to ``EventAccumulator``.
        tag: scalar tag to read.
        max_step: steps strictly below this value are kept, plus the first
            step at or beyond it (when there is one).

    Returns:
        (steps, values): an int64 array of steps sorted ascending and the
        float64 array of the matching values.

    Raises:
        KeyError: if the accumulator returns no events for ``tag``.
    """
    accumulator = event_accumulator.EventAccumulator(
        run_dir,
        size_guidance={event_accumulator.SCALARS: 0},
    )
    accumulator.Reload()

    scalar_events = accumulator.Scalars(tag)
    if not scalar_events:
        raise KeyError(f"Tag '{tag}' not found in run {run_dir}")

    raw_steps = np.array([event.step for event in scalar_events], dtype=np.int64)
    raw_values = np.array([event.value for event in scalar_events], dtype=np.float64)

    # Reorder both arrays together so steps are ascending.
    order = np.argsort(raw_steps)
    ordered_steps, ordered_values = raw_steps[order], raw_values[order]

    keep = ordered_steps < max_step
    # argmin of a boolean array is the index of its first False entry (or 0
    # when every entry is True), so this also keeps the first step that
    # reaches max_step and is a no-op when nothing was cut off.
    first_cut = np.argmin(keep)
    keep[first_cut] = True
    return ordered_steps[keep], ordered_values[keep]


def try_tags_for_run(run_dir, tags):
    """
    Load the first tag from ``tags`` that can be read for ``run_dir``.

    Tags are tried in order; a tag whose load raises ``KeyError`` is skipped.

    Returns:
        (steps, values, used_tag) for the first tag that loads.

    Raises:
        KeyError: if no tag in ``tags`` could be loaded.
    """
    for candidate in tags:
        try:
            loaded = load_scalar_from_run(run_dir, candidate)
            steps, values = loaded
        except KeyError:
            # This tag is not available for the run; fall through to the next.
            continue
        return steps, values, candidate
    raise KeyError(f"None of tags {tags} found in run {run_dir}")


def align_runs(runs, strategy='union'):
    """
    Align multiple runs on a common, ascending set of steps.

    Args:
        runs: non-empty sequence of ``(steps, values)`` or
            ``(steps, values, tag)`` tuples; only the first two entries of
            each tuple are used.
        strategy (str): 'union' keeps every step seen in any run,
            'intersection' keeps only the steps present in all runs.

    Returns:
        (common_steps, aligned): ``common_steps`` is a sorted int64 array and
        ``aligned`` is a 2-D array with one row per run, holding that run's
        values linearly interpolated (``np.interp``) onto ``common_steps``.
        An empty intersection yields an empty ``common_steps`` and an
        ``aligned`` array with zero columns.

    Raises:
        ValueError: if ``strategy`` is not 'union' or 'intersection', or if
            ``runs`` is empty.
    """
    # Reject typos explicitly instead of silently treating them as 'union'.
    if strategy not in ('union', 'intersection'):
        raise ValueError(
            f"Unknown strategy '{strategy}'; expected 'union' or 'intersection'"
        )
    runs = list(runs)
    if not runs:
        raise ValueError("align_runs needs at least one run")

    step_arrays = [np.asarray(run[0]) for run in runs]
    if strategy == 'intersection':
        common = np.unique(step_arrays[0])
        for other in step_arrays[1:]:
            # intersect1d returns the sorted, unique common elements.
            common = np.intersect1d(common, other)
    else:
        common = np.unique(np.concatenate(step_arrays))
    common_steps = common.astype(np.int64)

    # np.interp holds a run's first/last value for steps outside its range,
    # so runs that are shorter than the union are extended with constants.
    aligned = [np.interp(common_steps, run[0], run[1]) for run in runs]

    return common_steps, np.vstack(aligned)


def plot_methods(
    method_runs,
    tags,
    axis_name=None,
    experiment_name=None,
    strategy='union',
    figsize=(8, 5)
):
    """
    Plot mean and std for multiple methods, trying multiple tags per run,
    and save the figure as a PNG.

    The output file is ``<experiment_name>_<axis_name>.png``. When
    ``experiment_name`` is None the prefix is omitted; when ``axis_name`` is
    None the tags joined with '-' (with '/' replaced by '-') are used instead.
    The figure is closed before returning, also when an error occurs.

    Args:
        method_runs (dict): method_name -> list of run directories
        tags (list): list of scalar tags to try for each run
        axis_name (str): label for the y-axis; defaults to joined tags
        experiment_name (str): prefix of the output file name
        strategy (str): 'union' or 'intersection'
        figsize (tuple): figure size

    Raises:
        KeyError: if none of ``tags`` can be loaded for some run.
        ValueError: if the runs of a method share no steps to plot.
    """
    ylabel = axis_name if axis_name is not None else "/".join(tags)

    # Build the file name without concatenating None: both name arguments
    # default to None, which previously made the save step raise TypeError.
    name_parts = []
    if experiment_name is not None:
        name_parts.append(experiment_name)
    if axis_name is not None:
        name_parts.append(axis_name)
    else:
        # Tags contain '/', which would be read as a directory separator.
        name_parts.append("-".join(tags).replace("/", "-"))
    out_path = "_".join(name_parts) + ".png"

    fig, ax = plt.subplots(figsize=figsize)
    try:
        for method_name, run_dirs in method_runs.items():
            runs = []
            used_tags = set()
            for d in run_dirs:
                steps, values, used = try_tags_for_run(d, tags)
                runs.append((steps, values, used))
                used_tags.add(used)

            common_steps, aligned_vals = align_runs(runs, strategy=strategy)
            if common_steps.size == 0:
                raise ValueError(
                    f"Method '{method_name}' has no common steps to plot "
                    f"(strategy='{strategy}')"
                )
            mean_vals = aligned_vals.mean(axis=0)
            std_vals = aligned_vals.std(axis=0)
            # Report the statistics at the last aligned step.
            print(experiment_name,axis_name,mean_vals[-1],std_vals[-1])
            ax.plot(common_steps, mean_vals, label=method_name)
            ax.fill_between(common_steps, mean_vals - std_vals, mean_vals + std_vals, alpha=0.3)
            print(f"Method '{method_name}' used tags: {sorted(used_tags)}")

        ax.set_xlabel('step')
        ax.set_ylabel(ylabel)
        ax.set_title("Mean ± std")
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

if __name__ == "__main__":
    # Run directories per experiment. Each experiment is plotted as a single
    # method named "SafeMPO".
    experiments = {
        "SafetyCarGoal": [
            # NOTE(review): this path is listed twice in this experiment, so
            # the run is counted double in the mean/std - confirm intent.
            "runs/lr3e-4_FAST_40STEPs_training_steps_8_kappa100_envs12-1744200041",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarGoal1-v0-1744804724",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory100-SafetyCarGoal1-v0-1744721541",
            "runs/lr3e-4_FAST_20STEPs_training_steps_8_kappa100_envs12-1744183222",
            "runs/lr3e-4_FAST_40STEPs_training_steps_8_kappa100_envs12-1744200041",
            # NOTE(review): the ...-1745831618 run is also listed twice.
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarGoal1-v0-1745831618",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarGoal1-v0-1745838713",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarGoal1-v0-1745831618",
        ],
        "SafetyPointGoal": [
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyPointGoal1-v0-1745339213",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyPointGoal1-v0-1745339164",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyPointGoal1-v0-1745338974",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyPointGoal1-v0-1745843309",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyPointGoal1-v0-1745843393",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyPointGoal1-v0-1746017972",
        ],
        "SafetyCarButton": [
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarButton1-v0-1745339297",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarButton1-v0-1745339246",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarButton1-v0-1745237076",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarButton1-v0-1745913492",
            # NOTE(review): the ...-1745914168 run is listed twice - confirm.
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarButton1-v0-1745914168",
            "runs/lr3e-4_FAST_20STEPs_harder_retrace_kappa10_envs12_BS1024_Memory01-SafetyCarButton1-v0-1745914168",
        ],
    }
    # (scalar tag, axis name) pairs plotted for every experiment.
    metrics = [
        ("eval/episodic_cost", "costs"),
        ("eval/episodic_return", "returns"),
    ]

    failed = []
    with Pool(12) as pool:
        pending = []
        for experiment_name, run_dirs in experiments.items():
            for tag, axis_name in metrics:
                job = pool.apply_async(
                    plot_methods,
                    kwds=dict(
                        method_runs={"SafeMPO": run_dirs},
                        experiment_name=experiment_name,
                        tags=[tag],
                        strategy="union",
                        axis_name=axis_name,
                    ),
                )
                pending.append((f"{experiment_name}_{axis_name}", job))
        pool.close()

        # apply_async stores a worker's exception in its result object; it is
        # only surfaced by get(), so every result is collected here instead
        # of being dropped.
        for label, job in pending:
            try:
                job.get()
            except Exception as exc:  # script boundary: report each failure
                failed.append(label)
                print(f"Plot '{label}' failed: {exc!r}")
        pool.join()

    if failed:
        raise SystemExit(f"{len(failed)} plot job(s) failed: {', '.join(failed)}")
