import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

# Destination folder for the generated comparison figure (created if absent).
output_dir = 'matrix_vis'
# Global matplotlib look applied to every figure this script draws.
plt.style.use('seaborn-v0_8-whitegrid')
os.makedirs(output_dir, exist_ok=True)
# Algorithm key -> plot style. Each key is also the run-directory prefix that
# process_algo_runs globs for (runs/{key}_{env_name}_Repeat*); insertion order
# determines plotting and legend order.
algorithms = dict(
    MAPPO={'color': '#FFD700', 'label': 'MAPPO'},
    MAPPO_positive={'color': '#FF8C00', 'label': 'Optimistic MAPPO'},
    HAPPO={'color': '#32CD32', 'label': 'HAPPO'},

    MAAC={'color': '#8A2BE2', 'label': 'MAAC'},
    CORAAC={'color': '#FF4500', 'label': 'CORA-AC'},

    COMA={'color': '#FF1493', 'label': 'COMA'},
    VDPPO={'color': '#1E90FF', 'label': 'VDPPO'},  # DodgerBlue
    QMIXPPO={'color': '#4169E1', 'label': 'QMIXPPO'},  # RoyalBlue
    LICA={'color': '#4682B4', 'label': 'LICA'},  # SteelBlue
    CORAPPO_Qsa={'color': '#DC143C', 'label': 'CORA-PPO'},
)

# Scenario name, used in run-directory names and in the output filename.
# Other values used with this script: "dispersion", "give_way", "multi_give_way".
env_name = "navigation"

# Per-algorithm cap on training steps; algorithms not listed fall back to
# default_max_step. Example override: {"MAAC": 40000, "CORAAC": 40000}.
algo_max_steps = {}
default_max_step = 5e6

# Width (in resampled points) of the moving-average smoothing window; runs
# with fewer logged points than this are skipped.
smooth_window = 4


def load_tb_data(log_dir):
    """Load the ``Return`` scalar series from a TensorBoard log directory.

    Args:
        log_dir: Path to a run directory containing TensorBoard event files.

    Returns:
        ``(steps, values)`` as NumPy arrays in the order the events were read,
        or ``(None, None)`` if the events cannot be loaded or the run has no
        ``Return`` scalar. The failure is reported on stdout instead of being
        raised, so one bad run does not abort the whole plot.
    """
    ea = event_accumulator.EventAccumulator(log_dir)
    try:
        ea.Reload()
        events = ea.Scalars('Return')
    except Exception as err:
        # Best-effort skip of unreadable / tag-less runs. Unlike a bare
        # ``except:``, this does not swallow KeyboardInterrupt or SystemExit.
        print(f"[load_tb_data] skipping {log_dir}: {err!r}")
        return None, None
    steps = np.array([e.step for e in events])
    values = np.array([e.value for e in events])
    return steps, values


def process_algo_runs(algo_name):
    """Aggregate all repeats of one algorithm into a smoothed mean curve + 95% CI.

    Finds run directories matching ``runs/{algo_name}_{env_name}_Repeat*``
    and loads each one's ``Return`` series with ``load_tb_data``. Each series
    is truncated at the algorithm's step budget (``algo_max_steps`` entry, or
    ``default_max_step``). All runs are then resampled onto ONE shared
    300-point step grid and averaged point-wise.

    Args:
        algo_name: Key used in the run-directory pattern and in
            ``algo_max_steps``.

    Returns:
        ``(x, mean, ci95)``, three equal-length arrays after a moving average
        of width ``smooth_window``. ``x`` holds the centres of the averaging
        windows, in training steps. Returns ``None`` if no run has at least
        ``smooth_window`` points.
    """
    run_pattern = os.path.join("runs", f"{algo_name}_{env_name}_Repeat*")
    max_step = algo_max_steps.get(algo_name, default_max_step)
    runs = []

    for run_dir in sorted(glob.glob(run_pattern)):
        steps, values = load_tb_data(run_dir)
        if steps is None or len(steps) == 0:
            continue

        # np.interp requires increasing sample points, and TensorBoard does
        # not guarantee ordered steps (e.g. after a resumed run).
        order = np.argsort(steps, kind='stable')
        steps, values = steps[order], values[order]

        mask = steps <= max_step
        steps, values = steps[mask], values[mask]

        if len(steps) < smooth_window:
            continue
        runs.append((steps, values))

    if not runs:
        return None

    # Resample every run onto a single grid ending at the shortest run's last
    # step. The point-wise mean then compares equal training steps, and no run
    # is extrapolated (np.interp would clamp it to its final value).
    grid_end = min(s[-1] for s, _ in runs)
    grid = np.linspace(0, grid_end, 300)
    aligned = np.array([np.interp(grid, s, v) for s, v in runs])

    n_runs = len(aligned)
    mean = aligned.mean(axis=0)
    if n_runs > 1:
        # Standard error of the mean uses the sample std (ddof=1).
        sem = aligned.std(axis=0, ddof=1) / np.sqrt(n_runs)
    else:
        # A single run has no spread estimate; draw no band.
        sem = np.zeros_like(mean)
    ci95 = 1.96 * sem

    weights = np.ones(smooth_window) / smooth_window
    mean_smooth = np.convolve(mean, weights, mode='valid')
    ci95_smooth = np.convolve(ci95, weights, mode='valid')
    # Place each smoothed value at the centre of its averaging window.
    x = np.convolve(grid, weights, mode='valid')

    return x, mean_smooth, ci95_smooth


# Draw one mean curve with a shaded 95% CI band for every algorithm that has
# usable runs, then save the comparison figure as a PDF and display it.
plt.figure(figsize=(5, 4), dpi=300)

for algo_name, cfg in algorithms.items():
    curve = process_algo_runs(algo_name)
    if curve is not None and len(curve) == 3:
        x, mean, ci95 = curve
        color = cfg['color']
        plt.plot(x, mean, label=cfg['label'], color=color, linewidth=3)
        plt.fill_between(x, mean - ci95, mean + ci95, color=color, alpha=0.25)

plt.xlabel("Training Steps", fontsize=12)
plt.ylabel("Return", fontsize=12)
plt.legend(loc='best', fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()

pdf_path = os.path.join(output_dir, f"VMAS_{env_name}_comparison.pdf")
plt.savefig(pdf_path, bbox_inches='tight')
plt.show()