In [79]:
import json
import math
import os
import warnings
from collections import defaultdict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
# Turn off MathJax for kaleido static exports (commonly done to keep a
# "Loading [MathJax]" banner out of exported PDFs).
pio.kaleido.scope.mathjax = None
# Let pandas print every row of the (long) per-environment summary tables.
pd.set_option("display.max_rows", None)
In [80]:
# Every algorithm run directory that may exist under LOG_DIR.
ALGOS = [
    "bc_beta0.0", "bc_beta0.5", "bc_beta1.0",
    "indd", "vdn", "miso-mini",
    "miso_alpha0.00", "miso_alpha0.05", "miso_alpha0.10",
    "miso_alpha0.50", "miso_alpha1.00", "miso_alpha10.00",
]
# Baselines plus MisoDICE at α=0.05, loaded for the main comparison tables.
MAIN_ALGOS = ["bc_beta0.0", "bc_beta0.5", "bc_beta1.0", "indd", "vdn", "miso_alpha0.05"]
# Default curve set for create_scatters; currently the same algorithms as MAIN_ALGOS.
MAIN_ALGOS_2 = list(MAIN_ALGOS)
# MisoDICE α sweep.
ABLATION_ALGOS = [f"miso_alpha{alpha}" for alpha in ("0.00", "0.05", "0.10", "0.50", "1.00", "10.00")]
SMACV1_ENV_NAMES = ["2c_vs_64zg", "5m_vs_6m", "6h_vs_8z", "corridor"]
# Race-major order: all protoss scenarios, then terran, then zerg.
SMACV2_ENV_NAMES = [
    f"{race}_{scenario}"
    for race in ("protoss", "terran", "zerg")
    for scenario in ("5_vs_5", "10_vs_10", "10_vs_11", "20_vs_20", "20_vs_23")
]
MAMUJOCO_ENV_NAMES = ["Hopper-v2", "Ant-v2", "HalfCheetah-v2"]
# Run-directory name -> label used in table headers and plot legends.
RENAME_ALGOS = {
    "bc_beta0.0": "BC (β=0.0)",
    "bc_beta0.5": "BC (β=0.5)",
    "bc_beta1.0": "BC (β=1.0)",
    "omapl": "OMAPL",
    "indd": "INDD",
    "vdn": "VDN",
    "miso-sl": "MARL-SL",
    "miso-mini": "MisoDICE (4o-mini)",
    "miso_alpha0.00": "MisoDICE (α=0.00)",
    "miso_alpha0.05": "MisoDICE (ours)",
    "miso_alpha0.10": "MisoDICE (α=0.10)",
    "miso_alpha0.50": "MisoDICE (α=0.50)",
    "miso_alpha1.00": "MisoDICE (α=1.00)",
    "miso_alpha10.00": "MisoDICE (α=10.00)",
}
colors = px.colors.qualitative.Plotly
# Line colour per algorithm (string keys) and per exsize / top-k value (int keys, used
# by the exsize-ablation plots). A few colours are reused ("vdn"/"miso-mini",
# "omapl"/"miso_alpha0.50", "miso-sl"/"miso_alpha1.00"); the algorithm sets that
# this notebook draws together do not pair any of them.
COLOR_MAPS = {
    "bc_beta0.0": colors[4],
    "bc_beta0.5": colors[9],
    "bc_beta1.0": colors[2],
    "omapl": colors[7],
    "indd": colors[3],
    "vdn": colors[0],
    "miso-sl": colors[8],
    "miso-mini": colors[0],
    "miso_alpha0.00": colors[5],
    "miso_alpha0.05": colors[1],
    "miso_alpha0.10": colors[6],
    "miso_alpha0.50": colors[7],
    "miso_alpha1.00": colors[8],
    "miso_alpha10.00": "#feb406",
    # exsize k -> k-th Plotly colour: 50->0, 200->1, 400->2, 800->3, 1200->4
    **{exsize: colors[idx] for idx, exsize in enumerate((50, 200, 400, 800, 1200))},
}
# Root directory of the per-run results.json files.
LOG_DIR = "logs"
In [81]:
def load_results(path):
    """Load one run's ``results.json``.

    Args:
        path: path to the JSON results file.

    Returns:
        The decoded JSON object, or ``{}`` when ``path`` does not exist, so callers
        can treat a missing run as "no results".
    """
    if not os.path.exists(path):
        return {}
    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
In [82]:
def smooth(scalars, weight=0.75):
    """Bias-corrected exponential moving average of a sequence.

    The running average starts at 0. Each output is divided by ``1 - weight**k``
    (k = 1-based position), so early points are not pulled toward zero.

    Args:
        scalars: iterable of numbers.
        weight: decay factor. ``weight == 1`` divides by zero.

    Returns:
        A list with one smoothed value per input value.
    """
    ema = 0
    result = []
    position = 0
    for value in scalars:
        position += 1
        ema = ema * weight + (1 - weight) * value
        correction = 1 - math.pow(weight, position)
        result.append(ema / correction)
    return result
In [83]:
def load_data(use=False, exsize=200, algos=None, seeds=range(4), n_steps=101):
    """Collect per-seed mean returns and win rates for every (algo, env, step).

    Results are read from ``{LOG_DIR}/{algo}/{env}/seed{seed}/exsize{exsize}/results.json``.
    OMAPL runs live at ``{LOG_DIR}/omapl/{env}/seed{seed}/results.json``; their whole JSON
    object is used as the result for every step. A run whose results file is missing
    is skipped with a warning. Its (algo, env) entry therefore has fewer seeds, or is
    absent entirely, which shows up as NaN once the dict is turned into a DataFrame.

    Args:
        use: if True, skip the MaMuJoCo environments.
        exsize: value of the ``exsize{...}`` path component (ignored for OMAPL).
        algos: run-directory names to load; falls back to ``ALGOS`` when empty/None.
        seeds: seed indices to read (default 0-3).
        n_steps: number of checkpoints ``step_0 .. step_{n_steps-1}`` per run.

    Returns:
        ``(data_returns, data_winrates)``: nested dicts algo -> env -> step -> list holding
        one ``np.mean`` value per loaded seed. A metric key absent from a step's result
        contributes nothing.

    Raises:
        KeyError: if an existing (non-OMAPL) results file lacks a ``step_k`` entry.
    """
    data_returns = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    data_winrates = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    if not algos:
        algos = ALGOS
    for algo in algos:
        for env_name in MAMUJOCO_ENV_NAMES + SMACV1_ENV_NAMES + SMACV2_ENV_NAMES:
            if use and env_name in MAMUJOCO_ENV_NAMES:
                continue
            for seed in seeds:
                if algo == "omapl":
                    path = f"{LOG_DIR}/{algo}/{env_name}/seed{seed}/results.json"
                else:
                    path = f"{LOG_DIR}/{algo}/{env_name}/seed{seed}/exsize{exsize}/results.json"
                data = load_results(path)
                if not data:
                    # Missing run: load_results returned {}; skip the whole run so every
                    # step keeps the same number of seeds (analyze_data stacks them).
                    warnings.warn(f"missing results, skipping: {path}")
                    continue
                for step in range(n_steps):
                    result = data if algo == "omapl" else data[f"step_{step}"]
                    if "returns" in result:
                        data_returns[algo][env_name][step].append(np.mean(result["returns"]))
                    if "winrates" in result:
                        data_winrates[algo][env_name][step].append(np.mean(result["winrates"]))
    return data_returns, data_winrates
In [84]:
def analyze_data(data, tag="returns"):
    """Summarise one (algo, env) cell by each seed's final smoothed value.

    Args:
        data: mapping step -> list of per-seed values (as built by ``load_data``), or a
            float (pandas' NaN for a missing cell), which yields ``"NaN"``.
        tag: ``"returns"`` -> ``"mean ± std"`` with one decimal; ``"winrates"`` -> the
            same scaled by 100; anything else -> the raw array of per-seed final values.

    Each seed's curve (in sorted step order) is passed through ``smooth`` and its last
    value is kept. ``std`` is the population standard deviation (``np.std`` default).
    """
    if isinstance(data, float):
        return "NaN"
    # Rows follow the sorted step keys; each column is one seed's curve.
    per_step = np.array([data[step] for step in sorted(data)])
    smoothed_runs = np.array([smooth(per_step[:, run]) for run in range(per_step.shape[-1])])
    final_values = smoothed_runs[:, -1]
    mean, std = np.mean(final_values), np.std(final_values)
    if tag == "returns":
        return f"{mean:.1f} ± {std:.1f}"
    if tag == "winrates":
        return f"{100*mean:.1f} ± {100*std:.1f}"
    return final_values
In [85]:
# Main comparison: baselines vs. MisoDICE, one DataFrame per metric (rows = envs,
# columns = display names from RENAME_ALGOS).
data_returns, data_winrates = load_data(use=False, algos=MAIN_ALGOS)
pd_returns, pd_winrates = (
    pd.DataFrame.from_dict(results).rename(columns=RENAME_ALGOS)
    for results in (data_returns, data_winrates)
)
In [86]:
# Final smoothed return per (env, algorithm): mean ± std over seeds.
pd_returns.map(lambda cell: analyze_data(cell, tag="returns"))
Out[86]:
| | BC (β=0.0) | BC (β=0.5) | BC (β=1.0) | INDD | VDN | MisoDICE (ours) |
|---|---|---|---|---|---|---|
| Hopper-v2 | 123.0 ± 23.1 | 154.1 ± 44.6 | 158.1 ± 46.4 | 135.4 ± 32.8 | 171.4 ± 19.2 | 206.5 ± 22.5 |
| Ant-v2 | 1026.5 ± 77.7 | 1631.1 ± 51.8 | 1910.8 ± 70.2 | 1826.4 ± 16.6 | 1869.7 ± 9.1 | 2025.8 ± 3.2 |
| HalfCheetah-v2 | -201.5 ± 7.9 | -206.5 ± 6.0 | -73.4 ± 34.5 | -277.9 ± 5.7 | -243.9 ± 14.8 | -72.1 ± 31.8 |
| 2c_vs_64zg | 8.4 ± 0.2 | 9.3 ± 0.3 | 9.7 ± 0.7 | 11.8 ± 0.4 | 13.2 ± 1.5 | 15.0 ± 1.4 |
| 5m_vs_6m | 4.9 ± 1.3 | 6.3 ± 0.6 | 5.9 ± 0.4 | 6.6 ± 0.2 | 6.7 ± 0.1 | 7.2 ± 0.1 |
| 6h_vs_8z | 7.0 ± 0.1 | 7.4 ± 0.1 | 7.1 ± 0.2 | 7.4 ± 0.3 | 7.8 ± 0.1 | 8.5 ± 0.1 |
| corridor | 1.4 ± 0.1 | 1.5 ± 0.2 | 3.0 ± 0.6 | 3.0 ± 0.6 | 1.7 ± 0.0 | 4.7 ± 0.6 |
| protoss_5_vs_5 | 9.8 ± 0.4 | 10.4 ± 0.4 | 10.2 ± 0.6 | 10.4 ± 0.5 | 11.5 ± 0.6 | 12.2 ± 0.7 |
| protoss_10_vs_10 | 9.4 ± 0.2 | 11.7 ± 0.8 | 10.1 ± 0.3 | 10.9 ± 0.5 | 11.8 ± 0.6 | 12.6 ± 0.6 |
| protoss_10_vs_11 | 7.7 ± 0.4 | 9.4 ± 0.3 | 8.4 ± 0.5 | 9.1 ± 0.4 | 9.5 ± 0.4 | 10.6 ± 0.4 |
| protoss_20_vs_20 | 10.5 ± 0.3 | 10.3 ± 0.4 | 9.9 ± 0.6 | 11.3 ± 0.3 | 11.4 ± 0.3 | 13.1 ± 0.4 |
| protoss_20_vs_23 | 7.9 ± 0.3 | 8.3 ± 0.3 | 8.3 ± 0.2 | 9.4 ± 0.2 | 9.5 ± 0.4 | 10.4 ± 0.4 |
| terran_5_vs_5 | 6.5 ± 0.5 | 7.2 ± 0.7 | 6.9 ± 0.6 | 7.6 ± 0.4 | 8.0 ± 0.7 | 8.8 ± 0.5 |
| terran_10_vs_10 | 6.1 ± 0.5 | 7.1 ± 0.7 | 6.6 ± 0.6 | 7.3 ± 0.7 | 7.6 ± 0.3 | 8.6 ± 0.3 |
| terran_10_vs_11 | 4.6 ± 0.2 | 5.4 ± 0.6 | 5.0 ± 0.2 | 5.5 ± 0.1 | 5.7 ± 0.2 | 6.2 ± 0.5 |
| terran_20_vs_20 | 6.5 ± 0.5 | 7.3 ± 0.5 | 6.7 ± 0.5 | 7.8 ± 0.6 | 8.4 ± 0.5 | 9.1 ± 0.4 |
| terran_20_vs_23 | 4.0 ± 0.3 | 5.0 ± 0.2 | 4.2 ± 0.3 | 5.1 ± 0.5 | 5.0 ± 0.1 | 5.5 ± 0.4 |
| zerg_5_vs_5 | 5.5 ± 0.3 | 6.5 ± 0.3 | 5.6 ± 0.2 | 6.2 ± 0.5 | 6.6 ± 0.3 | 7.4 ± 0.5 |
| zerg_10_vs_10 | 7.1 ± 0.2 | 8.2 ± 0.5 | 7.2 ± 0.5 | 8.0 ± 0.3 | 8.5 ± 0.8 | 10.0 ± 0.2 |
| zerg_10_vs_11 | 6.5 ± 0.4 | 8.2 ± 0.3 | 6.8 ± 0.1 | 7.9 ± 0.3 | 8.4 ± 0.2 | 9.1 ± 0.6 |
| zerg_20_vs_20 | 7.3 ± 0.2 | 9.0 ± 0.3 | 7.6 ± 0.2 | 8.1 ± 0.7 | 8.6 ± 0.4 | 10.0 ± 0.6 |
| zerg_20_vs_23 | 7.0 ± 0.2 | 7.9 ± 0.3 | 6.9 ± 0.1 | 8.1 ± 0.2 | 8.8 ± 0.3 | 9.5 ± 0.4 |
In [87]:
# Final smoothed win rate per (env, algorithm), in percent: mean ± std over seeds.
pd_winrates.map(lambda cell: analyze_data(cell, tag="winrates"))
Out[87]:
| | BC (β=0.0) | BC (β=0.5) | BC (β=1.0) | INDD | VDN | MisoDICE (ours) |
|---|---|---|---|---|---|---|
| 2c_vs_64zg | 0.1 ± 0.1 | 0.5 ± 0.3 | 1.0 ± 1.0 | 2.3 ± 0.1 | 7.1 ± 4.6 | 8.4 ± 5.9 |
| 5m_vs_6m | 0.2 ± 0.4 | 0.7 ± 0.5 | 0.1 ± 0.1 | 0.1 ± 0.1 | 1.1 ± 0.8 | 1.3 ± 0.5 |
| 6h_vs_8z | 0.2 ± 0.2 | 0.0 ± 0.0 | 0.2 ± 0.3 | 0.1 ± 0.1 | 1.0 ± 0.6 | 1.1 ± 0.8 |
| corridor | 0.1 ± 0.1 | 0.6 ± 0.7 | 0.3 ± 0.4 | 0.1 ± 0.1 | 0.9 ± 0.6 | 1.4 ± 0.6 |
| protoss_5_vs_5 | 15.8 ± 2.0 | 11.8 ± 1.7 | 13.7 ± 3.4 | 10.1 ± 1.6 | 14.3 ± 3.4 | 18.4 ± 1.3 |
| protoss_10_vs_10 | 7.4 ± 3.3 | 9.4 ± 3.2 | 8.7 ± 2.9 | 7.2 ± 1.6 | 9.8 ± 2.1 | 12.2 ± 1.6 |
| protoss_10_vs_11 | 1.4 ± 1.0 | 2.9 ± 1.5 | 1.7 ± 0.2 | 1.7 ± 0.2 | 2.8 ± 0.7 | 4.1 ± 1.1 |
| protoss_20_vs_20 | 7.6 ± 1.7 | 2.9 ± 0.7 | 5.0 ± 2.4 | 3.9 ± 0.9 | 3.6 ± 1.4 | 8.7 ± 1.8 |
| protoss_20_vs_23 | 1.4 ± 0.3 | 0.7 ± 0.5 | 1.7 ± 1.2 | 2.0 ± 0.6 | 1.6 ± 0.8 | 3.2 ± 0.8 |
| terran_5_vs_5 | 8.7 ± 2.7 | 10.3 ± 2.6 | 10.4 ± 2.3 | 9.3 ± 1.6 | 8.8 ± 1.8 | 14.0 ± 3.5 |
| terran_10_vs_10 | 7.0 ± 2.0 | 7.8 ± 3.1 | 7.8 ± 2.8 | 6.6 ± 1.0 | 5.9 ± 2.3 | 11.3 ± 1.8 |
| terran_10_vs_11 | 1.5 ± 0.5 | 2.3 ± 0.8 | 2.0 ± 0.4 | 1.2 ± 0.8 | 2.2 ± 1.1 | 2.7 ± 1.2 |
| terran_20_vs_20 | 3.5 ± 1.7 | 5.6 ± 0.7 | 3.5 ± 2.8 | 4.3 ± 1.2 | 5.5 ± 1.9 | 7.6 ± 0.7 |
| terran_20_vs_23 | 0.5 ± 0.6 | 0.9 ± 0.3 | 0.5 ± 0.8 | 0.6 ± 0.6 | 0.9 ± 0.6 | 1.7 ± 1.3 |
| zerg_5_vs_5 | 5.7 ± 0.4 | 5.2 ± 0.7 | 5.5 ± 1.4 | 5.4 ± 1.2 | 6.0 ± 2.1 | 6.6 ± 1.0 |
| zerg_10_vs_10 | 3.7 ± 1.2 | 3.7 ± 0.9 | 3.7 ± 1.4 | 4.5 ± 1.1 | 4.3 ± 2.0 | 8.8 ± 1.7 |
| zerg_10_vs_11 | 3.3 ± 1.0 | 4.1 ± 1.8 | 3.2 ± 1.4 | 3.6 ± 0.5 | 5.0 ± 1.4 | 6.7 ± 2.6 |
| zerg_20_vs_20 | 0.3 ± 0.1 | 0.9 ± 0.5 | 0.9 ± 0.3 | 0.1 ± 0.1 | 1.1 ± 0.4 | 1.7 ± 0.4 |
| zerg_20_vs_23 | 0.7 ± 0.4 | 0.6 ± 0.4 | 0.9 ± 0.4 | 0.6 ± 0.6 | 1.6 ± 0.4 | 2.4 ± 0.2 |
In [88]:
def update_legend(fig, tag="returns", distance=1.1, yrange=None):
    """Apply the shared legend and axis styling to a multi-panel results figure.

    Only the first trace carrying each non-None name keeps a legend entry, so an
    algorithm drawn in every subplot is listed once. The legend is horizontal and
    anchored bottom/right at ``y=distance``, ``x=1``. X axes span 0-100 with a tick
    every 50. Y axes get percent ticks (step 0.1) for ``tag="winrates"`` and
    SI-suffixed ticks otherwise, plus an optional fixed ``yrange``.

    Returns:
        The same figure, modified in place.
    """
    seen_names = set()
    for trace in fig.data:
        first_occurrence = trace.name is not None and trace.name not in seen_names
        if first_occurrence:
            seen_names.add(trace.name)
        trace.update(showlegend=first_occurrence)

    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=distance, xanchor="right", x=1))
    fig.update_xaxes(showgrid=True)
    fig.update_yaxes(showgrid=True)
    fig.update_xaxes(range=[0, 100], dtick=50, minor=dict(ticklen=3, nticks=3))
    y_axis_style = (
        dict(tickformat=".0%", dtick=0.1, minor=dict(ticklen=3, nticks=2))
        if tag == "winrates"
        else dict(tickformat="~s")
    )
    fig.update_yaxes(**y_axis_style)
    if yrange is not None:
        fig.update_yaxes(range=yrange)
    return fig
In [89]:
def create_scatters(env_name, data_dict, y_range=None, tag="returns", algos=None):
    """Small line plot of one metric over steps 1-100 for a single environment.

    For each algorithm, the per-step mean over seeds is smoothed with ``smooth`` and
    drawn as a line. A shaded band shows mean ± the (unsmoothed) per-step std over seeds.

    Args:
        env_name: environment key, looked up as ``data_dict[algo][env_name][step]``.
        data_dict: algo -> env -> step -> list of per-seed values (e.g. from
            ``load_data``, or keyed by exsize for the top-k ablation).
        y_range: optional y-axis range.
        tag: ``"winrates"`` clips the band to [0, 1] and uses percent ticks; any other
            value uses SI-suffixed ticks.
        algos: keys of ``data_dict`` to draw; ``MAIN_ALGOS_2`` when empty/None.

    Returns:
        A 180x120 px plotly Figure with two traces per algorithm (line, band) and the
        legend hidden.
    """
    fig = go.Figure()
    algos = algos or MAIN_ALGOS_2
    clip_to_unit = tag == "winrates"
    for algo in algos:
        runs_by_step = data_dict[algo][env_name]
        steps = list(range(1, 101))  # step 0 is not plotted
        samples = [np.array(runs_by_step[step]) for step in steps]
        means = smooth([sample.mean() for sample in samples])
        stds = [sample.std() for sample in samples]
        band_top = [m + s for m, s in zip(means, stds)]
        band_bottom = [m - s for m, s in zip(means, stds)]
        if clip_to_unit:
            band_top = [min(1.0, v) for v in band_top]
            band_bottom = [max(0.0, v) for v in band_bottom]
        color = COLOR_MAPS.get(algo, colors[2])
        label = RENAME_ALGOS.get(algo, algo)
        fig.add_trace(go.Scatter(x=steps, y=means, mode="lines", name=label, line_color=color, line_width=1.5))
        # Closed polygon: upper edge left-to-right, then lower edge right-to-left.
        fig.add_trace(go.Scatter(
            x=steps + steps[::-1],
            y=band_top + band_bottom[::-1],
            fill="toself",
            fillcolor=color,
            line_color=color,
            opacity=0.1,
            line_width=1,
            showlegend=False,
        ))
    fig.update_layout(template='simple_white', margin=dict(l=0, r=0, t=0, b=0, pad=0, autoexpand=True))
    fig.update_layout(height=120, width=180)
    fig.update_xaxes(range=[0, 100], dtick=50, minor=dict(ticklen=3, nticks=4))
    fig.update_yaxes(range=y_range, tickformat=".0%" if clip_to_unit else "~s")
    fig.update_layout(showlegend=False)
    return fig
In [90]:
def show_fig(plotly_figs, env_names, n_rows, n_cols, height, width, title):
    """Tile per-environment figures into one subplot grid, show it, and export PDFs.

    Panel ``i`` of ``env_names`` goes to row ``i // n_cols + 1`` and column
    ``i % n_cols + 1`` (row-major). With SMACv2's race-major name list and
    ``n_cols=5``, this gives one race per row and one scenario per column.

    The metric tag is ``"returns"`` if ``title`` contains "returns" (case-insensitive),
    else ``"winrates"``. Titles containing "ablation" add an ``ablation_`` file prefix.
    Each single-env figure is written to ``graphs/{prefix}{env_name}_{tag}.pdf``. The
    grid is shown as SVG, then written without its title to ``graphs/<sanitised title>.pdf``.

    Args:
        plotly_figs: env name -> figure (e.g. from ``create_scatters``).
        env_names: environments to place, in panel order; also used as subplot titles.
        n_rows, n_cols: grid shape.
        height, width: grid size in pixels.
        title: figure title; also determines tag, prefix and output file name.

    Raises:
        ValueError: if ``env_names`` has more entries than the grid has cells.
    """
    tag = "returns" if "returns" in title.lower() else "winrates"
    prefix = "ablation_" if "ablation" in title.lower() else ""
    if len(env_names) > n_rows * n_cols:
        raise ValueError(f"{len(env_names)} panels do not fit a {n_rows}x{n_cols} grid")
    os.makedirs("graphs", exist_ok=True)
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=env_names)
    for i, env_name in enumerate(env_names):
        plotly_fig = plotly_figs[env_name]
        plotly_fig.write_image(f"graphs/{prefix}{env_name}_{tag}.pdf")
        row, col = divmod(i, n_cols)
        fig.add_traces(plotly_fig.data, rows=row + 1, cols=col + 1)
    fig.update_layout(template='simple_white')
    fig.update_layout(height=height, width=width, title_text=title)
    fig = update_legend(fig, tag=tag, distance=1.15)
    fig.show("svg")
    fig.update_layout(title_text=None)
    # "Ablation: SMACv2 - Returns - Rule-based" -> "Ablation_SMACv2_Returns_Rule-based"
    file_stem = title.replace(" - ", "_").replace(" ", "_").replace(":", "_").replace("__", "_")
    print(f"Saving {file_stem}.pdf")
    fig.write_image(f"graphs/{file_stem}.pdf")
In [91]:
# Return curves for every environment, then one grid per benchmark suite.
plotly_figs = {
    env_name: create_scatters(env_name, data_returns)
    for env_name in MAMUJOCO_ENV_NAMES + SMACV1_ENV_NAMES + SMACV2_ENV_NAMES
}
show_fig(plotly_figs, MAMUJOCO_ENV_NAMES, n_rows=1, n_cols=3, height=300, width=950,
         title="MAMUJOCO - Returns - Rule-based")
show_fig(plotly_figs, SMACV1_ENV_NAMES, n_rows=1, n_cols=4, height=300, width=950,
         title="SMACv1 - Returns - Rule-based")
show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200,
         title="SMACv2 - Returns - Rule-based")
Saving MAMUJOCO_Returns_Rule-based.pdf
Saving SMACv1_Returns_Rule-based.pdf
Saving SMACv2_Returns_Rule-based.pdf
In [92]:
# Win-rate curves exist only for the SMAC suites (MaMuJoCo has no win rate).
plotly_figs = {
    env_name: create_scatters(env_name, data_winrates, tag="winrates")
    for env_name in SMACV1_ENV_NAMES + SMACV2_ENV_NAMES
}
show_fig(plotly_figs, SMACV1_ENV_NAMES, n_rows=1, n_cols=4, height=300, width=950,
         title="SMACv1 - Winrates - Rule-based")
show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200,
         title="SMACv2 - Winrates - Rule-based")
Saving SMACv1_Winrates_Rule-based.pdf
Saving SMACv2_Winrates_Rule-based.pdf
In [93]:
def show_box_chart(pd_data, env_names, tag="returns", exsizes=(50, 200, 400, 800, 1200)):
    """Box plot of per-seed final scores, pooled over ``env_names``, for each exsize.

    ``pd_data`` is the exsize-ablation frame: columns are the integer exsize values and
    rows are env names. Each cell is reduced with ``analyze_data(..., tag=None)`` to its
    per-seed final smoothed values. Cells that ``analyze_data`` reports as "NaN"
    (missing runs) are skipped. The figure is written to
    ``graphs/{group}_{tag}_box.pdf``, where ``group`` is the first env name up to its
    first underscore (e.g. "protoss"), and then shown.

    Args:
        pd_data: exsize-ablation DataFrame as described above.
        env_names: rows of ``pd_data`` to pool into each box.
        tag: ``"winrates"`` for percent y ticks; anything else uses SI-suffixed ticks.
        exsizes: columns of ``pd_data`` to plot, one box each, in this order.

    Raises:
        ValueError: if ``env_names`` is empty.
    """
    if len(env_names) == 0:
        raise ValueError("show_box_chart needs at least one env name")
    # Output-file prefix taken explicitly from the inputs, not from a loop variable
    # that leaked out of the pooling loop.
    group_name = env_names[0].split("_")[0]
    final_values = pd_data.map(lambda x: analyze_data(x, tag=None))
    fig = go.Figure()
    for exsize in exsizes:
        all_values = []
        for env_name in env_names:
            values = final_values[exsize][env_name]
            if isinstance(values, str):
                # analyze_data returns "NaN" for a missing cell; extending with it would
                # add the characters 'N', 'a', 'N' as data points.
                continue
            all_values.extend(values)
        fig.add_trace(go.Box(y=all_values, name=exsize, marker_color=colors[0], boxpoints=False))
    tickformat = ".0%" if tag == "winrates" else "~s"
    fig.update_layout(template='simple_white', margin=dict(l=0, r=0, t=0, b=0, pad=0, autoexpand=True))
    fig.update_layout(height=140, width=220)
    fig.update_layout(showlegend=False)
    fig.update_yaxes(tickformat=tickformat)
    os.makedirs("graphs", exist_ok=True)
    fig.write_image(f"graphs/{group_name}_{tag}_box.pdf")
    fig.show()
In [94]:
# MisoDICE (α=0.05) results for each expert-set size (exsize / top-k).
# Cell-local names are used so the main-comparison `data_returns` / `data_winrates`
# loaded earlier are not silently overwritten by the last exsize in this loop.
ablation_exsize_returns = {}
ablation_exsize_winrates = {}
for exsize in [50, 200, 400, 800, 1200]:
    exsize_returns, exsize_winrates = load_data(use=False, exsize=exsize, algos=["miso_alpha0.05"])
    ablation_exsize_returns[exsize] = exsize_returns["miso_alpha0.05"]
    ablation_exsize_winrates[exsize] = exsize_winrates["miso_alpha0.05"]
In [95]:
# Column labels for the exsize (top-k) ablation tables.
topk_names = {size: "topk=" + str(size) for size in (50, 200, 400, 800, 1200)}
# SMACv2 scenarios grouped by race, for the per-race box plots.
PROTOSS_ENVS, TERRAN_ENVS, ZERG_ENVS = (
    [name for name in SMACV2_ENV_NAMES if race in name]
    for race in ("protoss", "terran", "zerg")
)
In [96]:
pd_returns = pd.DataFrame.from_dict(ablation_exsize_returns)
# Displayed with "topk=N" headers; pd_returns itself keeps the integer exsize columns
# that show_box_chart indexes by.
pd_returns.rename(columns=topk_names).map(lambda cell: analyze_data(cell, tag="returns"))
Out[96]:
| | topk=50 | topk=200 | topk=400 | topk=800 | topk=1200 |
|---|---|---|---|---|---|
| Hopper-v2 | 166.8 ± 23.5 | 206.5 ± 22.5 | 138.0 ± 17.6 | 129.8 ± 24.7 | 150.7 ± 11.4 |
| Ant-v2 | 2012.1 ± 7.7 | 2025.8 ± 3.2 | 1522.7 ± 48.2 | 1228.5 ± 60.4 | 1091.0 ± 55.8 |
| HalfCheetah-v2 | -118.0 ± 24.0 | -72.1 ± 31.8 | -263.5 ± 19.9 | -231.3 ± 19.0 | -224.1 ± 11.5 |
| 2c_vs_64zg | 14.3 ± 0.6 | 15.0 ± 1.4 | 10.6 ± 0.5 | 11.1 ± 0.4 | 10.5 ± 0.8 |
| 5m_vs_6m | 6.7 ± 0.5 | 7.2 ± 0.1 | 6.5 ± 0.4 | 6.6 ± 0.1 | 6.3 ± 0.2 |
| 6h_vs_8z | 7.9 ± 0.2 | 8.5 ± 0.1 | 7.8 ± 0.1 | 7.5 ± 0.2 | 7.4 ± 0.1 |
| corridor | 2.7 ± 1.1 | 4.7 ± 0.6 | 2.1 ± 0.2 | 1.8 ± 0.1 | 1.7 ± 0.1 |
| protoss_5_vs_5 | 11.6 ± 0.9 | 12.2 ± 0.7 | 11.2 ± 0.1 | 10.9 ± 0.4 | 10.5 ± 0.6 |
| protoss_10_vs_10 | 11.5 ± 0.3 | 12.6 ± 0.6 | 11.5 ± 0.4 | 11.0 ± 0.4 | 10.6 ± 0.4 |
| protoss_10_vs_11 | 10.2 ± 0.2 | 10.6 ± 0.4 | 10.0 ± 0.4 | 9.7 ± 0.4 | 8.9 ± 0.4 |
| protoss_20_vs_20 | 12.4 ± 0.2 | 13.1 ± 0.4 | 11.7 ± 0.4 | 11.9 ± 0.3 | 11.5 ± 0.7 |
| protoss_20_vs_23 | 10.2 ± 0.3 | 10.4 ± 0.4 | 9.9 ± 0.3 | 9.6 ± 0.3 | 8.9 ± 0.3 |
| terran_5_vs_5 | 7.3 ± 0.4 | 8.8 ± 0.5 | 7.8 ± 0.9 | 7.1 ± 0.6 | 7.0 ± 0.4 |
| terran_10_vs_10 | 7.8 ± 0.5 | 8.6 ± 0.3 | 7.8 ± 0.5 | 7.0 ± 0.6 | 6.9 ± 0.4 |
| terran_10_vs_11 | 5.7 ± 0.4 | 6.2 ± 0.5 | 5.3 ± 0.1 | 5.5 ± 0.2 | 5.1 ± 0.3 |
| terran_20_vs_20 | 8.2 ± 0.6 | 9.1 ± 0.4 | 8.3 ± 0.1 | 8.1 ± 0.5 | 7.8 ± 0.2 |
| terran_20_vs_23 | 5.3 ± 0.1 | 5.5 ± 0.4 | 5.1 ± 0.5 | 4.9 ± 0.3 | 4.9 ± 0.5 |
| zerg_5_vs_5 | 7.2 ± 0.4 | 7.4 ± 0.5 | 7.4 ± 0.7 | 6.8 ± 0.5 | 6.2 ± 0.2 |
| zerg_10_vs_10 | 9.1 ± 0.2 | 10.0 ± 0.2 | 9.2 ± 0.3 | 8.5 ± 0.3 | 8.6 ± 0.4 |
| zerg_10_vs_11 | 8.6 ± 0.3 | 9.1 ± 0.6 | 8.7 ± 0.6 | 8.1 ± 0.6 | 8.2 ± 0.2 |
| zerg_20_vs_20 | 8.4 ± 0.4 | 10.0 ± 0.6 | 8.7 ± 0.6 | 8.3 ± 0.6 | 8.3 ± 0.7 |
| zerg_20_vs_23 | 9.2 ± 0.3 | 9.5 ± 0.4 | 8.6 ± 0.1 | 8.5 ± 0.2 | 7.9 ± 0.3 |
In [97]:
pd_winrates = pd.DataFrame.from_dict(ablation_exsize_winrates)
# Displayed with "topk=N" headers; pd_winrates itself keeps the integer exsize columns
# that show_box_chart indexes by.
pd_winrates.rename(columns=topk_names).map(lambda cell: analyze_data(cell, tag="winrates"))
Out[97]:
| | topk=50 | topk=200 | topk=400 | topk=800 | topk=1200 |
|---|---|---|---|---|---|
| 2c_vs_64zg | 18.2 ± 8.1 | 8.4 ± 5.9 | 5.7 ± 4.1 | 1.7 ± 0.5 | 1.7 ± 2.5 |
| 5m_vs_6m | 0.0 ± 0.0 | 1.3 ± 0.5 | 0.0 ± 0.0 | 0.0 ± 0.0 | 0.0 ± 0.0 |
| 6h_vs_8z | 0.0 ± 0.0 | 1.1 ± 0.8 | 0.0 ± 0.0 | 0.0 ± 0.0 | 0.0 ± 0.0 |
| corridor | 0.0 ± 0.0 | 1.4 ± 0.6 | 0.0 ± 0.0 | 0.0 ± 0.0 | 0.0 ± 0.0 |
| protoss_5_vs_5 | 11.3 ± 3.8 | 18.4 ± 1.3 | 10.7 ± 2.1 | 10.5 ± 2.7 | 10.1 ± 2.3 |
| protoss_10_vs_10 | 4.1 ± 1.9 | 12.2 ± 1.6 | 7.0 ± 0.9 | 7.5 ± 1.5 | 5.4 ± 0.6 |
| protoss_10_vs_11 | 1.6 ± 1.1 | 4.1 ± 1.1 | 2.6 ± 1.0 | 2.7 ± 0.8 | 2.0 ± 1.3 |
| protoss_20_vs_20 | 3.1 ± 0.1 | 8.7 ± 1.8 | 4.0 ± 1.9 | 5.4 ± 2.4 | 4.4 ± 0.7 |
| protoss_20_vs_23 | 1.9 ± 0.6 | 3.2 ± 0.8 | 2.1 ± 0.6 | 1.1 ± 0.6 | 0.6 ± 0.7 |
| terran_5_vs_5 | 8.8 ± 0.7 | 14.0 ± 3.5 | 8.9 ± 2.3 | 9.4 ± 2.2 | 5.1 ± 0.3 |
| terran_10_vs_10 | 4.8 ± 1.8 | 11.3 ± 1.8 | 6.8 ± 3.2 | 7.7 ± 2.5 | 4.7 ± 1.5 |
| terran_10_vs_11 | 1.0 ± 0.3 | 2.7 ± 1.2 | 1.1 ± 0.6 | 1.8 ± 0.7 | 1.4 ± 0.5 |
| terran_20_vs_20 | 4.8 ± 2.1 | 7.6 ± 0.7 | 4.6 ± 0.7 | 4.7 ± 1.4 | 3.5 ± 0.7 |
| terran_20_vs_23 | 0.3 ± 0.1 | 1.7 ± 1.3 | 0.4 ± 0.5 | 0.4 ± 0.3 | 0.4 ± 0.3 |
| zerg_5_vs_5 | 5.6 ± 0.8 | 6.6 ± 1.0 | 7.5 ± 1.3 | 5.1 ± 1.3 | 3.2 ± 1.2 |
| zerg_10_vs_10 | 4.3 ± 0.6 | 8.8 ± 1.7 | 5.6 ± 1.3 | 5.7 ± 0.7 | 4.0 ± 1.8 |
| zerg_10_vs_11 | 3.5 ± 2.1 | 6.7 ± 2.6 | 4.0 ± 1.1 | 3.5 ± 2.2 | 3.1 ± 1.6 |
| zerg_20_vs_20 | 0.1 ± 0.1 | 1.7 ± 0.4 | 0.4 ± 0.5 | 0.4 ± 0.5 | 0.4 ± 0.7 |
| zerg_20_vs_23 | 0.4 ± 0.4 | 2.4 ± 0.2 | 1.7 ± 0.5 | 1.2 ± 0.9 | 0.7 ± 0.6 |
In [98]:
# One pooled box plot of final returns per SMACv2 race.
for race_envs in (PROTOSS_ENVS, TERRAN_ENVS, ZERG_ENVS):
    show_box_chart(pd_returns, race_envs, tag="returns")
In [99]:
# One pooled box plot of final win rates per SMACv2 race.
for race_envs in (PROTOSS_ENVS, TERRAN_ENVS, ZERG_ENVS):
    show_box_chart(pd_winrates, race_envs, tag="winrates")
In [100]:
# Return curves per exsize (one line per top-k value) for every SMACv2 scenario.
plotly_figs = {
    env_name: create_scatters(env_name, ablation_exsize_returns, algos=[50, 200, 400, 800, 1200])
    for env_name in SMACV2_ENV_NAMES
}
show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200,
         title="Ablation: SMACv2 - Returns - Rule-based")
Saving Ablation_SMACv2_Returns_Rule-based.pdf
In [101]:
# Win-rate curves per exsize (one line per top-k value) for every SMACv2 scenario.
plotly_figs = {
    env_name: create_scatters(env_name, ablation_exsize_winrates, algos=[50, 200, 400, 800, 1200], tag="winrates")
    for env_name in SMACV2_ENV_NAMES
}
show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200,
         title="Ablation: SMACv2 - Winrates - Rule-based")
Saving Ablation_SMACv2_Winrates_Rule-based.pdf
In [102]:
# Column labels for the α-sweep tables (every setting shown with its α value).
RENAME_ALGOS_2 = {
    "miso_alpha0.00": "MisoDICE (α=0.00)",
    "miso_alpha0.05": "MisoDICE (α=0.05)",
    "miso_alpha0.10": "MisoDICE (α=0.10)",
    "miso_alpha0.50": "MisoDICE (α=0.50)",
    "miso_alpha1.00": "MisoDICE (α=1.00)",
    "miso_alpha10.00": "MisoDICE (α=10.0)",
}
data_returns, data_winrates = load_data(use=False, algos=ABLATION_ALGOS)
pd_returns, pd_winrates = (
    pd.DataFrame.from_dict(results).rename(columns=RENAME_ALGOS_2)
    for results in (data_returns, data_winrates)
)
In [103]:
# α sweep: final smoothed return per (env, α), mean ± std over seeds.
pd_returns.map(lambda cell: analyze_data(cell, tag="returns"))
Out[103]:
| | MisoDICE (α=0.00) | MisoDICE (α=0.05) | MisoDICE (α=0.10) | MisoDICE (α=0.50) | MisoDICE (α=1.00) | MisoDICE (α=10.0) |
|---|---|---|---|---|---|---|
| Hopper-v2 | 200.0 ± 46.6 | 206.5 ± 22.5 | 165.3 ± 37.9 | 175.3 ± 36.8 | 171.8 ± 27.7 | 144.2 ± 43.6 |
| Ant-v2 | 2039.6 ± 4.1 | 2025.8 ± 3.2 | 1988.9 ± 8.2 | 1964.5 ± 5.2 | 1929.7 ± 33.7 | 1403.3 ± 67.2 |
| HalfCheetah-v2 | -108.1 ± 21.8 | -72.1 ± 31.8 | -86.2 ± 31.3 | -125.6 ± 16.5 | -185.9 ± 20.3 | -238.8 ± 14.3 |
| 2c_vs_64zg | 13.0 ± 0.7 | 15.0 ± 1.4 | 12.5 ± 0.5 | 12.1 ± 0.5 | 11.4 ± 0.4 | 10.8 ± 0.3 |
| 5m_vs_6m | 6.4 ± 1.3 | 7.2 ± 0.1 | 6.3 ± 1.0 | 6.1 ± 1.5 | 6.5 ± 0.5 | 6.6 ± 0.2 |
| 6h_vs_8z | 8.2 ± 0.2 | 8.5 ± 0.1 | 8.0 ± 0.2 | 7.9 ± 0.2 | 7.7 ± 0.2 | 7.6 ± 0.1 |
| corridor | 4.7 ± 0.7 | 4.7 ± 0.6 | 4.6 ± 0.7 | 4.3 ± 0.6 | 2.9 ± 0.6 | 1.7 ± 0.1 |
| protoss_5_vs_5 | 11.7 ± 0.5 | 12.2 ± 0.7 | 11.4 ± 0.5 | 11.4 ± 0.8 | 11.4 ± 0.4 | 11.5 ± 0.4 |
| protoss_10_vs_10 | 11.8 ± 0.3 | 12.6 ± 0.6 | 12.1 ± 0.6 | 11.6 ± 0.3 | 11.7 ± 0.6 | 11.9 ± 0.8 |
| protoss_10_vs_11 | 10.0 ± 0.3 | 10.6 ± 0.4 | 10.0 ± 0.7 | 9.9 ± 0.2 | 9.6 ± 0.4 | 9.7 ± 0.2 |
| protoss_20_vs_20 | 12.2 ± 0.4 | 13.1 ± 0.4 | 12.2 ± 0.3 | 11.8 ± 0.4 | 11.6 ± 0.2 | 11.8 ± 0.5 |
| protoss_20_vs_23 | 10.2 ± 0.4 | 10.4 ± 0.4 | 9.9 ± 0.5 | 9.8 ± 0.4 | 9.7 ± 0.1 | 9.4 ± 0.3 |
| terran_5_vs_5 | 8.4 ± 0.8 | 8.8 ± 0.5 | 8.2 ± 0.1 | 8.3 ± 0.8 | 7.8 ± 0.9 | 7.9 ± 0.7 |
| terran_10_vs_10 | 8.3 ± 0.2 | 8.6 ± 0.3 | 8.0 ± 0.3 | 8.3 ± 0.5 | 7.7 ± 0.4 | 7.6 ± 0.6 |
| terran_10_vs_11 | 5.8 ± 0.3 | 6.2 ± 0.5 | 5.6 ± 0.2 | 5.7 ± 0.6 | 5.6 ± 0.2 | 5.2 ± 0.2 |
| terran_20_vs_20 | 8.1 ± 0.9 | 9.1 ± 0.4 | 8.1 ± 0.6 | 7.8 ± 0.6 | 7.5 ± 0.7 | 8.3 ± 0.5 |
| terran_20_vs_23 | 5.2 ± 0.4 | 5.5 ± 0.4 | 5.1 ± 0.2 | 5.1 ± 0.3 | 4.9 ± 0.2 | 5.0 ± 0.3 |
| zerg_5_vs_5 | 7.2 ± 0.3 | 7.4 ± 0.5 | 6.7 ± 0.5 | 6.4 ± 0.3 | 6.3 ± 0.6 | 6.3 ± 0.4 |
| zerg_10_vs_10 | 9.4 ± 0.3 | 10.0 ± 0.2 | 9.0 ± 0.3 | 8.8 ± 0.2 | 8.5 ± 0.4 | 8.6 ± 0.3 |
| zerg_10_vs_11 | 9.0 ± 0.3 | 9.1 ± 0.6 | 8.4 ± 0.4 | 8.7 ± 0.3 | 8.5 ± 0.1 | 8.3 ± 0.4 |
| zerg_20_vs_20 | 9.0 ± 0.6 | 10.0 ± 0.6 | 9.0 ± 0.7 | 9.0 ± 0.5 | 8.6 ± 0.6 | 8.8 ± 0.6 |
| zerg_20_vs_23 | 8.9 ± 0.0 | 9.5 ± 0.4 | 8.6 ± 0.2 | 8.7 ± 0.4 | 8.4 ± 0.3 | 8.5 ± 0.3 |
In [ ]:
# α sweep: final smoothed win rate per (env, α), in percent, mean ± std over seeds.
pd_winrates.map(lambda cell: analyze_data(cell, tag="winrates"))
Out[ ]:
| | MisoDICE (α=0.00) | MisoDICE (α=0.05) | MisoDICE (α=0.10) | MisoDICE (α=0.50) | MisoDICE (α=1.00) | MisoDICE (α=10.0) |
|---|---|---|---|---|---|---|
| 2c_vs_64zg | 3.9 ± 1.0 | 8.4 ± 5.9 | 3.4 ± 1.2 | 2.9 ± 0.8 | 2.7 ± 1.1 | 1.6 ± 0.9 |
| 5m_vs_6m | 1.1 ± 1.0 | 1.3 ± 0.5 | 0.6 ± 0.4 | 0.7 ± 0.4 | 0.9 ± 0.7 | 0.6 ± 0.3 |
| 6h_vs_8z | 1.0 ± 0.8 | 1.1 ± 0.8 | 1.1 ± 0.8 | 0.8 ± 0.5 | 0.4 ± 0.3 | 1.0 ± 0.5 |
| corridor | 0.4 ± 0.2 | 1.4 ± 0.6 | 1.2 ± 0.5 | 0.5 ± 0.4 | 1.4 ± 0.7 | 0.9 ± 0.4 |
| protoss_5_vs_5 | 11.6 ± 2.4 | 18.4 ± 1.3 | 11.0 ± 1.3 | 11.3 ± 2.0 | 12.0 ± 2.0 | 16.3 ± 2.1 |
| protoss_10_vs_10 | 8.2 ± 2.0 | 12.2 ± 1.6 | 8.8 ± 3.2 | 9.4 ± 2.6 | 8.7 ± 2.7 | 10.3 ± 1.2 |
| protoss_10_vs_11 | 1.9 ± 0.3 | 4.1 ± 1.1 | 3.1 ± 1.3 | 3.0 ± 1.0 | 2.2 ± 1.1 | 3.6 ± 1.2 |
| protoss_20_vs_20 | 4.9 ± 1.3 | 8.7 ± 1.8 | 4.0 ± 1.5 | 6.3 ± 1.7 | 4.8 ± 1.0 | 5.2 ± 1.1 |
| protoss_20_vs_23 | 3.0 ± 1.4 | 3.2 ± 0.8 | 2.0 ± 0.6 | 2.9 ± 1.4 | 3.2 ± 0.6 | 2.2 ± 1.4 |
| terran_5_vs_5 | 10.2 ± 1.7 | 14.0 ± 3.5 | 10.5 ± 2.8 | 11.4 ± 1.5 | 10.4 ± 1.6 | 9.2 ± 1.6 |
| terran_10_vs_10 | 8.7 ± 1.8 | 11.3 ± 1.8 | 8.4 ± 1.4 | 6.4 ± 1.4 | 8.0 ± 3.1 | 8.7 ± 3.9 |
| terran_10_vs_11 | 2.0 ± 0.5 | 2.7 ± 1.2 | 3.1 ± 1.1 | 2.1 ± 1.0 | 1.7 ± 0.5 | 1.8 ± 0.8 |
| terran_20_vs_20 | 4.1 ± 1.8 | 7.6 ± 0.7 | 4.3 ± 2.2 | 5.3 ± 1.5 | 4.3 ± 2.6 | 7.8 ± 1.4 |
| terran_20_vs_23 | 1.4 ± 0.7 | 1.7 ± 1.3 | 1.2 ± 0.6 | 1.7 ± 1.2 | 0.8 ± 0.4 | 1.6 ± 0.7 |
| zerg_5_vs_5 | 5.9 ± 1.6 | 6.6 ± 1.0 | 6.2 ± 0.8 | 4.7 ± 1.1 | 5.2 ± 0.8 | 5.8 ± 0.8 |
| zerg_10_vs_10 | 4.6 ± 1.4 | 8.8 ± 1.7 | 5.8 ± 1.0 | 5.5 ± 0.8 | 5.9 ± 1.1 | 6.4 ± 1.4 |
| zerg_10_vs_11 | 4.1 ± 1.4 | 6.7 ± 2.6 | 4.6 ± 0.7 | 3.8 ± 0.9 | 4.6 ± 1.1 | 5.6 ± 1.2 |
| zerg_20_vs_20 | 1.1 ± 0.4 | 1.7 ± 0.4 | 0.8 ± 0.5 | 1.4 ± 1.3 | 1.4 ± 0.5 | 2.5 ± 0.8 |
| zerg_20_vs_23 | 1.3 ± 0.3 | 2.4 ± 0.2 | 0.9 ± 0.4 | 1.8 ± 0.1 | 1.9 ± 1.7 | 1.8 ± 0.4 |
The Kernel crashed while executing code in the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.