In [79]:
import os
import json
import math
import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
from plotly.subplots import make_subplots
# Static-export and display settings for the rest of the notebook:
# MathJax is disabled for Kaleido image export, and DataFrames are never row-truncated.
pio.kaleido.scope.mathjax = None
pd.set_option("display.max_rows", None)
In [80]:
# Every algorithm whose logs may be loaded (directory names under LOG_DIR).
ALGOS = [
    "bc_beta0.0", "bc_beta0.5", "bc_beta1.0",
    "indd", "vdn", "miso-mini",
    "miso_alpha0.00", "miso_alpha0.05", "miso_alpha0.10",
    "miso_alpha0.50", "miso_alpha1.00", "miso_alpha10.00",
]
# Baselines plus the MisoDICE configuration reported as "ours" (alpha = 0.05).
MAIN_ALGOS = ["bc_beta0.0", "bc_beta0.5", "bc_beta1.0", "indd", "vdn", "miso_alpha0.05"]
# Default curve set for create_scatters; currently identical to MAIN_ALGOS (separate list object).
MAIN_ALGOS_2 = list(MAIN_ALGOS)
# Alpha sweep for the MisoDICE ablation: every "miso_alpha*" entry of ALGOS, in order.
ABLATION_ALGOS = [algo for algo in ALGOS if algo.startswith("miso_alpha")]

SMACV1_ENV_NAMES = ["2c_vs_64zg", "5m_vs_6m", "6h_vs_8z", "corridor"]
# SMACv2 tasks named "<race>_<scenario>", race-major (all protoss, then terran, then zerg).
# show_fig relies on this order to line subplot titles up with its race x scenario grid.
SMACV2_ENV_NAMES = [
    f"{race}_{scenario}"
    for race in ("protoss", "terran", "zerg")
    for scenario in ("5_vs_5", "10_vs_10", "10_vs_11", "20_vs_20", "20_vs_23")
]
MAMUJOCO_ENV_NAMES = ["Hopper-v2", "Ant-v2", "HalfCheetah-v2"]

# Display names used for DataFrame columns and plot legends.
RENAME_ALGOS = {
    "bc_beta0.0": "BC (β=0.0)",
    "bc_beta0.5": "BC (β=0.5)",
    "bc_beta1.0": "BC (β=1.0)",
    "omapl": "OMAPL",
    "indd": "INDD",
    "vdn": "VDN",
    "miso-sl": "MARL-SL",
    "miso-mini": "MisoDICE (4o-mini)",
    "miso_alpha0.00": "MisoDICE (α=0.00)",
    "miso_alpha0.05": "MisoDICE (ours)",
    "miso_alpha0.10": "MisoDICE (α=0.10)",
    "miso_alpha0.50": "MisoDICE (α=0.50)",
    "miso_alpha1.00": "MisoDICE (α=1.00)",
    "miso_alpha10.00": "MisoDICE (α=10.00)",
}

colors = px.colors.qualitative.Plotly
# Line colours keyed by algorithm name (str) and, for the exsize ablation, by exsize (int).
# NOTE(review): "vdn"/"miso-mini", "omapl"/"miso_alpha0.50" and "miso-sl"/"miso_alpha1.00"
# share a colour; that is only safe while those pairs never appear in the same figure.
COLOR_MAPS = {
    "bc_beta0.0": colors[4],
    "bc_beta0.5": colors[9],
    "bc_beta1.0": colors[2],
    "omapl": colors[7],
    "indd": colors[3],
    "vdn": colors[0],
    "miso-sl": colors[8],
    "miso-mini": colors[0],
    "miso_alpha0.00": colors[5],
    "miso_alpha0.05": colors[1],
    "miso_alpha0.10": colors[6],
    "miso_alpha0.50": colors[7],
    "miso_alpha1.00": colors[8],
    "miso_alpha10.00": "#feb406",
    # exsize 50 -> colors[0], 200 -> colors[1], ..., 1200 -> colors[4]
    **{exsize: colors[idx] for idx, exsize in enumerate((50, 200, 400, 800, 1200))},
}

# Root directory of the per-run results.json files.
LOG_DIR = "logs"
In [81]:
def load_results(path):
    """Read one results JSON file.

    Args:
        path: location of a ``results.json`` file.

    Returns:
        The parsed JSON object, or an empty dict when ``path`` does not exist
        (so callers can treat "missing run" as "no data").
    """
    if os.path.exists(path):
        with open(path, "r") as handle:
            return json.load(handle)
    return {}
In [82]:
def smooth(scalars, weight=0.75):
    """Exponential moving average with zero-initialised state and bias correction.

    After the k-th value (1-based) the running average is divided by
    ``1 - weight**k``, so early points are not pulled towards 0.

    Args:
        scalars: iterable of numbers.
        weight: decay factor; 0 returns the input unchanged, and 1 divides by zero.

    Returns:
        A list with one smoothed value per input value.
    """
    ema = 0
    result = []
    for count, value in enumerate(scalars, start=1):
        ema = ema * weight + (1 - weight) * value
        result.append(ema / (1 - math.pow(weight, count)))
    return result
In [83]:
def load_data(use=False, exsize=200, algos=None, n_seeds=4, n_steps=101):
    """Collect per-seed evaluation results for every (algo, env, step).

    Args:
        use: when True, MaMuJoCo environments are skipped (the results path is the
            same either way).
        exsize: value of the ``exsize<N>`` path component to read.
        algos: algorithm directory names to load; falls back to ALGOS when falsy.
        n_seeds: number of seed directories ``seed0 .. seed{n_seeds-1}`` to read.
        n_steps: number of checkpoints ``step_0 .. step_{n_steps-1}`` read per file.

    Returns:
        ``(data_returns, data_winrates)``, nested defaultdicts indexed as
        ``[algo][env_name][step]`` whose values are lists holding one
        ``np.mean(...)`` per loaded seed. Missing results files (``load_results``
        returns ``{}``) and checkpoints absent from a file are skipped instead of
        raising KeyError.
    """
    data_returns = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    data_winrates = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    if not algos:
        algos = ALGOS
    for algo in algos:
        for env_name in MAMUJOCO_ENV_NAMES + SMACV1_ENV_NAMES + SMACV2_ENV_NAMES:
            if use and env_name in MAMUJOCO_ENV_NAMES:
                continue
            for seed in range(n_seeds):
                if algo == "omapl":
                    # OMAPL results have no exsize directory and no step_N keys;
                    # the same dict is reused for every step below.
                    path = f"{LOG_DIR}/{algo}/{env_name}/seed{seed}/results.json"
                else:
                    path = f"{LOG_DIR}/{algo}/{env_name}/seed{seed}/exsize{exsize}/results.json"

                data = load_results(path)
                if not data:
                    # Missing (or empty) results file: nothing to record for this seed.
                    continue
                for step in range(n_steps):
                    result = data if algo == "omapl" else data.get(f"step_{step}", {})
                    if "returns" in result:
                        data_returns[algo][env_name][step].append(np.mean(result["returns"]))
                    if "winrates" in result:
                        data_winrates[algo][env_name][step].append(np.mean(result["winrates"]))
    return data_returns, data_winrates
In [84]:
def analyze_data(data, tag="returns"):
    """Summarise the final smoothed value of each seed's curve.

    Args:
        data: mapping step -> list of per-seed values (the innermost level built by
            load_data), or a float NaN for a missing DataFrame cell.
        tag: "returns" -> "mean ± std" string; "winrates" -> same, scaled by 100;
            anything else -> the array of per-seed final smoothed values.

    Returns:
        A formatted string, the per-seed array, or "NaN" for a float input.
    """
    # Missing (algo, env) cells come out of the DataFrame as float NaN.
    if isinstance(data, float):
        return "NaN"
    # Rows are steps in ascending order; columns are seeds (assumes every step has the same seed count).
    per_step = np.array([data[step] for step in sorted(data)])
    smoothed_per_seed = np.array([smooth(per_step[:, seed]) for seed in range(per_step.shape[-1])])
    finals = smoothed_per_seed[:, -1]
    mean, std = np.mean(finals), np.std(finals)
    if tag == "returns":
        return f"{mean:.1f} ± {std:.1f}"
    if tag == "winrates":
        return f"{100*mean:.1f} ± {100*std:.1f}"
    return finals
In [85]:
# Main comparison at the default exsize: one column per algorithm, one row per environment.
data_returns, data_winrates = load_data(use=False, algos=MAIN_ALGOS)
pd_returns, pd_winrates = (
    pd.DataFrame.from_dict(results).rename(columns=RENAME_ALGOS)
    for results in (data_returns, data_winrates)
)
In [86]:
# Final smoothed returns, mean ± std over seeds.
pd_returns.map(analyze_data, tag="returns")
Out[86]:
BC (β=0.0) BC (β=0.5) BC (β=1.0) INDD VDN MisoDICE (ours)
Hopper-v2 123.0 ± 23.1 154.1 ± 44.6 158.1 ± 46.4 135.4 ± 32.8 171.4 ± 19.2 206.5 ± 22.5
Ant-v2 1026.5 ± 77.7 1631.1 ± 51.8 1910.8 ± 70.2 1826.4 ± 16.6 1869.7 ± 9.1 2025.8 ± 3.2
HalfCheetah-v2 -201.5 ± 7.9 -206.5 ± 6.0 -73.4 ± 34.5 -277.9 ± 5.7 -243.9 ± 14.8 -72.1 ± 31.8
2c_vs_64zg 8.4 ± 0.2 9.3 ± 0.3 9.7 ± 0.7 11.8 ± 0.4 13.2 ± 1.5 15.0 ± 1.4
5m_vs_6m 4.9 ± 1.3 6.3 ± 0.6 5.9 ± 0.4 6.6 ± 0.2 6.7 ± 0.1 7.2 ± 0.1
6h_vs_8z 7.0 ± 0.1 7.4 ± 0.1 7.1 ± 0.2 7.4 ± 0.3 7.8 ± 0.1 8.5 ± 0.1
corridor 1.4 ± 0.1 1.5 ± 0.2 3.0 ± 0.6 3.0 ± 0.6 1.7 ± 0.0 4.7 ± 0.6
protoss_5_vs_5 9.8 ± 0.4 10.4 ± 0.4 10.2 ± 0.6 10.4 ± 0.5 11.5 ± 0.6 12.2 ± 0.7
protoss_10_vs_10 9.4 ± 0.2 11.7 ± 0.8 10.1 ± 0.3 10.9 ± 0.5 11.8 ± 0.6 12.6 ± 0.6
protoss_10_vs_11 7.7 ± 0.4 9.4 ± 0.3 8.4 ± 0.5 9.1 ± 0.4 9.5 ± 0.4 10.6 ± 0.4
protoss_20_vs_20 10.5 ± 0.3 10.3 ± 0.4 9.9 ± 0.6 11.3 ± 0.3 11.4 ± 0.3 13.1 ± 0.4
protoss_20_vs_23 7.9 ± 0.3 8.3 ± 0.3 8.3 ± 0.2 9.4 ± 0.2 9.5 ± 0.4 10.4 ± 0.4
terran_5_vs_5 6.5 ± 0.5 7.2 ± 0.7 6.9 ± 0.6 7.6 ± 0.4 8.0 ± 0.7 8.8 ± 0.5
terran_10_vs_10 6.1 ± 0.5 7.1 ± 0.7 6.6 ± 0.6 7.3 ± 0.7 7.6 ± 0.3 8.6 ± 0.3
terran_10_vs_11 4.6 ± 0.2 5.4 ± 0.6 5.0 ± 0.2 5.5 ± 0.1 5.7 ± 0.2 6.2 ± 0.5
terran_20_vs_20 6.5 ± 0.5 7.3 ± 0.5 6.7 ± 0.5 7.8 ± 0.6 8.4 ± 0.5 9.1 ± 0.4
terran_20_vs_23 4.0 ± 0.3 5.0 ± 0.2 4.2 ± 0.3 5.1 ± 0.5 5.0 ± 0.1 5.5 ± 0.4
zerg_5_vs_5 5.5 ± 0.3 6.5 ± 0.3 5.6 ± 0.2 6.2 ± 0.5 6.6 ± 0.3 7.4 ± 0.5
zerg_10_vs_10 7.1 ± 0.2 8.2 ± 0.5 7.2 ± 0.5 8.0 ± 0.3 8.5 ± 0.8 10.0 ± 0.2
zerg_10_vs_11 6.5 ± 0.4 8.2 ± 0.3 6.8 ± 0.1 7.9 ± 0.3 8.4 ± 0.2 9.1 ± 0.6
zerg_20_vs_20 7.3 ± 0.2 9.0 ± 0.3 7.6 ± 0.2 8.1 ± 0.7 8.6 ± 0.4 10.0 ± 0.6
zerg_20_vs_23 7.0 ± 0.2 7.9 ± 0.3 6.9 ± 0.1 8.1 ± 0.2 8.8 ± 0.3 9.5 ± 0.4
In [87]:
# Final smoothed win rates in percent, mean ± std over seeds.
pd_winrates.map(analyze_data, tag="winrates")
Out[87]:
BC (β=0.0) BC (β=0.5) BC (β=1.0) INDD VDN MisoDICE (ours)
2c_vs_64zg 0.1 ± 0.1 0.5 ± 0.3 1.0 ± 1.0 2.3 ± 0.1 7.1 ± 4.6 8.4 ± 5.9
5m_vs_6m 0.2 ± 0.4 0.7 ± 0.5 0.1 ± 0.1 0.1 ± 0.1 1.1 ± 0.8 1.3 ± 0.5
6h_vs_8z 0.2 ± 0.2 0.0 ± 0.0 0.2 ± 0.3 0.1 ± 0.1 1.0 ± 0.6 1.1 ± 0.8
corridor 0.1 ± 0.1 0.6 ± 0.7 0.3 ± 0.4 0.1 ± 0.1 0.9 ± 0.6 1.4 ± 0.6
protoss_5_vs_5 15.8 ± 2.0 11.8 ± 1.7 13.7 ± 3.4 10.1 ± 1.6 14.3 ± 3.4 18.4 ± 1.3
protoss_10_vs_10 7.4 ± 3.3 9.4 ± 3.2 8.7 ± 2.9 7.2 ± 1.6 9.8 ± 2.1 12.2 ± 1.6
protoss_10_vs_11 1.4 ± 1.0 2.9 ± 1.5 1.7 ± 0.2 1.7 ± 0.2 2.8 ± 0.7 4.1 ± 1.1
protoss_20_vs_20 7.6 ± 1.7 2.9 ± 0.7 5.0 ± 2.4 3.9 ± 0.9 3.6 ± 1.4 8.7 ± 1.8
protoss_20_vs_23 1.4 ± 0.3 0.7 ± 0.5 1.7 ± 1.2 2.0 ± 0.6 1.6 ± 0.8 3.2 ± 0.8
terran_5_vs_5 8.7 ± 2.7 10.3 ± 2.6 10.4 ± 2.3 9.3 ± 1.6 8.8 ± 1.8 14.0 ± 3.5
terran_10_vs_10 7.0 ± 2.0 7.8 ± 3.1 7.8 ± 2.8 6.6 ± 1.0 5.9 ± 2.3 11.3 ± 1.8
terran_10_vs_11 1.5 ± 0.5 2.3 ± 0.8 2.0 ± 0.4 1.2 ± 0.8 2.2 ± 1.1 2.7 ± 1.2
terran_20_vs_20 3.5 ± 1.7 5.6 ± 0.7 3.5 ± 2.8 4.3 ± 1.2 5.5 ± 1.9 7.6 ± 0.7
terran_20_vs_23 0.5 ± 0.6 0.9 ± 0.3 0.5 ± 0.8 0.6 ± 0.6 0.9 ± 0.6 1.7 ± 1.3
zerg_5_vs_5 5.7 ± 0.4 5.2 ± 0.7 5.5 ± 1.4 5.4 ± 1.2 6.0 ± 2.1 6.6 ± 1.0
zerg_10_vs_10 3.7 ± 1.2 3.7 ± 0.9 3.7 ± 1.4 4.5 ± 1.1 4.3 ± 2.0 8.8 ± 1.7
zerg_10_vs_11 3.3 ± 1.0 4.1 ± 1.8 3.2 ± 1.4 3.6 ± 0.5 5.0 ± 1.4 6.7 ± 2.6
zerg_20_vs_20 0.3 ± 0.1 0.9 ± 0.5 0.9 ± 0.3 0.1 ± 0.1 1.1 ± 0.4 1.7 ± 0.4
zerg_20_vs_23 0.7 ± 0.4 0.6 ± 0.4 0.9 ± 0.4 0.6 ± 0.6 1.6 ± 0.4 2.4 ± 0.2
In [88]:
def update_legend(fig, tag="returns", distance=1.1, yrange=None):
    """Apply the shared legend and axis styling to a (multi-subplot) figure.

    Only the first trace carrying each non-None name keeps a legend entry, so
    repeated subplots do not duplicate legend items.

    Args:
        fig: plotly figure to modify in place.
        tag: "winrates" -> percent ticks every 10%; otherwise SI-style ticks.
        distance: y position of the horizontal legend.
        yrange: optional [low, high] applied to every y axis.

    Returns:
        The same figure, for chaining.
    """
    seen_names = set()
    for trace in fig.data:
        name = trace.name
        first_occurrence = name is not None and name not in seen_names
        if first_occurrence:
            seen_names.add(name)
        trace.update(showlegend=first_occurrence)

    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=distance, xanchor="right", x=1))
    fig.update_xaxes(showgrid=True, range=[0, 100], dtick=50, minor=dict(ticklen=3, nticks=3))

    if tag == "winrates":
        fig.update_yaxes(showgrid=True, tickformat=".0%", dtick=0.1, minor=dict(ticklen=3, nticks=2))
    else:
        fig.update_yaxes(showgrid=True, tickformat="~s")
    if yrange is not None:
        fig.update_yaxes(range=yrange)
    return fig
In [89]:
def create_scatters(env_name, data_dict, y_range=None, tag="returns", algos=None):
    """Small figure of smoothed mean learning curves with ±1 std bands for one env.

    Args:
        env_name: environment key into ``data_dict[algo]``.
        data_dict: nested mapping ``[algo][env_name][step] -> per-seed values``.
        y_range: optional y-axis range.
        tag: "winrates" clips bands to [0, 1] and uses percent ticks.
        algos: curves to draw (algo names or exsize ints); falls back to MAIN_ALGOS_2.

    Returns:
        A go.Figure with, per algorithm, a line trace and a filled band trace.
        Steps 1..100 are plotted; step 0 is skipped. Only the means are smoothed.
    """
    is_winrate = tag == "winrates"
    fig = go.Figure()
    for algo in (algos or MAIN_ALGOS_2):
        per_env = data_dict[algo][env_name]
        xs = list(range(1, 101))
        seed_arrays = [np.array(per_env[step]) for step in xs]
        means = smooth([arr.mean() for arr in seed_arrays])
        spreads = [arr.std() for arr in seed_arrays]

        band_top = [m + s for m, s in zip(means, spreads)]
        band_bottom = [m - s for m, s in zip(means, spreads)]
        if is_winrate:
            band_top = [min(1.0, v) for v in band_top]
            band_bottom = [max(0.0, v) for v in band_bottom]

        color = COLOR_MAPS.get(algo, colors[2])
        fig.add_trace(go.Scatter(x=xs, y=means, mode="lines", name=RENAME_ALGOS.get(algo, algo), line_color=color, line_width=1.5))
        # Closed polygon: along the top edge, then back along the bottom edge.
        fig.add_trace(go.Scatter(x=xs + xs[::-1], y=band_top + band_bottom[::-1], fill="toself", fillcolor=color, line_color=color, opacity=0.1, line_width=1, showlegend=False))

    tickformat = ".0%" if is_winrate else "~s"

    fig.update_layout(template='simple_white', margin=dict(l=0, r=0, t=0, b=0, pad=0, autoexpand=True))
    fig.update_layout(height=120, width=180)
    fig.update_xaxes(range=[0, 100], dtick=50, minor=dict(ticklen=3, nticks=4))
    fig.update_yaxes(range=y_range, tickformat=tickformat)
    fig.update_layout(showlegend=False)
    return fig
In [90]:
def show_fig(plotly_figs, env_names, n_rows, n_cols, height, width, title):
    """Combine per-environment figures into one subplot grid, show it, and save PDFs.

    Each per-environment figure is written to ``graphs/<prefix><env>_<tag>.pdf``, and
    the combined grid to ``graphs/<sanitised title>.pdf``.

    Args:
        plotly_figs: mapping env name -> go.Figure (e.g. from create_scatters).
        env_names: environments to place; also used as subplot titles. If any of them
            is a SMACv2 task, the full race (rows) x scenario (columns) SMACv2 grid is
            drawn from hard-coded names, regardless of the rest of the list.
        n_rows, n_cols: subplot grid shape.
        height, width: size of the combined figure.
        title: figure title. It also selects the tag ("returns" if it contains
            "returns", else "winrates") and the "ablation_" file prefix.
    """
    tag = "returns" if "returns" in title.lower() else "winrates"
    prefix = "ablation_" if "ablation" in title.lower() else ""
    # write_image cannot create missing directories; make sure graphs/ exists.
    os.makedirs("graphs", exist_ok=True)
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=env_names)
    if any(env_name in SMACV2_ENV_NAMES for env_name in env_names):
        for i, mode in enumerate(["5_vs_5", "10_vs_10", "10_vs_11", "20_vs_20", "20_vs_23"]):
            for j, race in enumerate(["protoss", "terran", "zerg"]):
                plotly_fig = plotly_figs[f"{race}_{mode}"]
                plotly_fig.write_image(f"graphs/{prefix}{race}_{mode}_{tag}.pdf")
                fig.add_traces(plotly_fig.data, rows=j+1, cols=i+1)
    else:
        for i, env_name in enumerate(env_names):
            plotly_fig = plotly_figs[env_name]
            plotly_fig.write_image(f"graphs/{prefix}{env_name}_{tag}.pdf")
            fig.add_traces(plotly_fig.data, rows=1, cols=i+1)
    fig.update_layout(template='simple_white')
    fig.update_layout(height=height, width=width, title_text=title)
    fig = update_legend(fig, tag=tag, distance=1.15)
    fig.show("svg")
    # The saved PDF carries no title; the file name encodes it instead.
    fig.update_layout(title_text=None)
    title = title.replace(" - ", "_").replace(" ", "_").replace(":", "_").replace("__", "_")
    print(f"Saving {title}.pdf")
    fig.write_image(f"graphs/{title}.pdf")
In [91]:
# One small returns figure per environment, then the combined grids per benchmark.
plotly_figs = {
    env_name: create_scatters(env_name, data_returns)
    for env_name in MAMUJOCO_ENV_NAMES + SMACV1_ENV_NAMES + SMACV2_ENV_NAMES
}

show_fig(plotly_figs, MAMUJOCO_ENV_NAMES, n_rows=1, n_cols=3, height=300, width=950, title="MAMUJOCO - Returns - Rule-based")
show_fig(plotly_figs, SMACV1_ENV_NAMES, n_rows=1, n_cols=4, height=300, width=950, title="SMACv1 - Returns - Rule-based")
show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200, title="SMACv2 - Returns - Rule-based")
No description has been provided for this image
Saving MAMUJOCO_Returns_Rule-based.pdf
No description has been provided for this image
Saving SMACv1_Returns_Rule-based.pdf
No description has been provided for this image
Saving SMACv2_Returns_Rule-based.pdf
In [92]:
# Win rates only exist for SMAC tasks.
plotly_figs = {
    env_name: create_scatters(env_name, data_winrates, tag="winrates")
    for env_name in SMACV1_ENV_NAMES + SMACV2_ENV_NAMES
}

show_fig(plotly_figs, SMACV1_ENV_NAMES, n_rows=1, n_cols=4, height=300, width=950, title="SMACv1 - Winrates - Rule-based")
show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200, title="SMACv2 - Winrates - Rule-based")
No description has been provided for this image
Saving SMACv1_Winrates_Rule-based.pdf
No description has been provided for this image
Saving SMACv2_Winrates_Rule-based.pdf
In [93]:
def show_box_chart(pd_data, env_names, tag="returns", exsizes=(50, 200, 400, 800, 1200)):
    """Box plot of per-seed final values pooled over ``env_names``, one box per exsize.

    Args:
        pd_data: DataFrame with one column per exsize and one row per environment;
            cells are step -> per-seed value mappings (or NaN when missing).
        env_names: environments whose seeds are pooled into each box. The file prefix
            is taken from the race part of the last name (e.g. "protoss").
        tag: "winrates" -> percent ticks; otherwise SI ticks. Also part of the file name.
        exsizes: column keys to plot, in box order.

    Raises:
        ValueError: if ``env_names`` is empty.
    """
    if not env_names:
        raise ValueError("show_box_chart needs at least one environment name")
    final_values = pd_data.map(lambda x: analyze_data(x, tag=None))
    fig = go.Figure()
    for exsize in exsizes:
        all_values = []
        for env_name in env_names:
            values = final_values[exsize][env_name]
            # analyze_data returns the string "NaN" for a missing cell; extending with it
            # would add the characters 'N', 'a', 'N' to the box data.
            if isinstance(values, str):
                continue
            all_values.extend(values)
        fig.add_trace(go.Box(y=all_values, name=exsize, marker_color=colors[0], boxpoints=False))

    tickformat = ".0%" if tag == "winrates" else "~s"

    fig.update_layout(template='simple_white', margin=dict(l=0, r=0, t=0, b=0, pad=0, autoexpand=True))
    fig.update_layout(height=140, width=220)
    fig.update_layout(showlegend=False)
    fig.update_yaxes(tickformat=tickformat)
    race = env_names[-1].split("_")[0]
    os.makedirs("graphs", exist_ok=True)
    fig.write_image(f"graphs/{race}_{tag}_box.pdf")
    fig.show()
In [94]:
# Exsize ablation for the default MisoDICE configuration (alpha = 0.05).
# Per-exsize results go into local names so the main-results `data_returns` /
# `data_winrates` are not overwritten. Otherwise, re-running the main plotting cells
# after this one would silently draw empty (NaN) curves for every baseline.
ablation_exsize_returns = {}
ablation_exsize_winrates = {}
for exsize in [50, 200, 400, 800, 1200]:
    exsize_returns, exsize_winrates = load_data(use=False, exsize=exsize, algos=["miso_alpha0.05"])
    ablation_exsize_returns[exsize] = exsize_returns["miso_alpha0.05"]
    ablation_exsize_winrates[exsize] = exsize_winrates["miso_alpha0.05"]
In [95]:
# Column labels for the exsize ablation tables.
topk_names = {exsize: f"topk={exsize}" for exsize in [50, 200, 400, 800, 1200]}
# SMACv2 tasks grouped by race (substring match on the env name), in SMACV2_ENV_NAMES order.
PROTOSS_ENVS, TERRAN_ENVS, ZERG_ENVS = (
    [env_name for env_name in SMACV2_ENV_NAMES if race in env_name]
    for race in ("protoss", "terran", "zerg")
)
In [96]:
# Exsize ablation returns: one column per exsize, relabelled "topk=<exsize>" for display only.
pd_returns = pd.DataFrame.from_dict(ablation_exsize_returns)
pd_returns.rename(columns=topk_names).map(analyze_data, tag="returns")
Out[96]:
topk=50 topk=200 topk=400 topk=800 topk=1200
Hopper-v2 166.8 ± 23.5 206.5 ± 22.5 138.0 ± 17.6 129.8 ± 24.7 150.7 ± 11.4
Ant-v2 2012.1 ± 7.7 2025.8 ± 3.2 1522.7 ± 48.2 1228.5 ± 60.4 1091.0 ± 55.8
HalfCheetah-v2 -118.0 ± 24.0 -72.1 ± 31.8 -263.5 ± 19.9 -231.3 ± 19.0 -224.1 ± 11.5
2c_vs_64zg 14.3 ± 0.6 15.0 ± 1.4 10.6 ± 0.5 11.1 ± 0.4 10.5 ± 0.8
5m_vs_6m 6.7 ± 0.5 7.2 ± 0.1 6.5 ± 0.4 6.6 ± 0.1 6.3 ± 0.2
6h_vs_8z 7.9 ± 0.2 8.5 ± 0.1 7.8 ± 0.1 7.5 ± 0.2 7.4 ± 0.1
corridor 2.7 ± 1.1 4.7 ± 0.6 2.1 ± 0.2 1.8 ± 0.1 1.7 ± 0.1
protoss_5_vs_5 11.6 ± 0.9 12.2 ± 0.7 11.2 ± 0.1 10.9 ± 0.4 10.5 ± 0.6
protoss_10_vs_10 11.5 ± 0.3 12.6 ± 0.6 11.5 ± 0.4 11.0 ± 0.4 10.6 ± 0.4
protoss_10_vs_11 10.2 ± 0.2 10.6 ± 0.4 10.0 ± 0.4 9.7 ± 0.4 8.9 ± 0.4
protoss_20_vs_20 12.4 ± 0.2 13.1 ± 0.4 11.7 ± 0.4 11.9 ± 0.3 11.5 ± 0.7
protoss_20_vs_23 10.2 ± 0.3 10.4 ± 0.4 9.9 ± 0.3 9.6 ± 0.3 8.9 ± 0.3
terran_5_vs_5 7.3 ± 0.4 8.8 ± 0.5 7.8 ± 0.9 7.1 ± 0.6 7.0 ± 0.4
terran_10_vs_10 7.8 ± 0.5 8.6 ± 0.3 7.8 ± 0.5 7.0 ± 0.6 6.9 ± 0.4
terran_10_vs_11 5.7 ± 0.4 6.2 ± 0.5 5.3 ± 0.1 5.5 ± 0.2 5.1 ± 0.3
terran_20_vs_20 8.2 ± 0.6 9.1 ± 0.4 8.3 ± 0.1 8.1 ± 0.5 7.8 ± 0.2
terran_20_vs_23 5.3 ± 0.1 5.5 ± 0.4 5.1 ± 0.5 4.9 ± 0.3 4.9 ± 0.5
zerg_5_vs_5 7.2 ± 0.4 7.4 ± 0.5 7.4 ± 0.7 6.8 ± 0.5 6.2 ± 0.2
zerg_10_vs_10 9.1 ± 0.2 10.0 ± 0.2 9.2 ± 0.3 8.5 ± 0.3 8.6 ± 0.4
zerg_10_vs_11 8.6 ± 0.3 9.1 ± 0.6 8.7 ± 0.6 8.1 ± 0.6 8.2 ± 0.2
zerg_20_vs_20 8.4 ± 0.4 10.0 ± 0.6 8.7 ± 0.6 8.3 ± 0.6 8.3 ± 0.7
zerg_20_vs_23 9.2 ± 0.3 9.5 ± 0.4 8.6 ± 0.1 8.5 ± 0.2 7.9 ± 0.3
In [97]:
# Exsize ablation win rates (percent); the unrenamed frame is kept for show_box_chart.
pd_winrates = pd.DataFrame.from_dict(ablation_exsize_winrates)
pd_winrates.rename(columns=topk_names).map(analyze_data, tag="winrates")
Out[97]:
topk=50 topk=200 topk=400 topk=800 topk=1200
2c_vs_64zg 18.2 ± 8.1 8.4 ± 5.9 5.7 ± 4.1 1.7 ± 0.5 1.7 ± 2.5
5m_vs_6m 0.0 ± 0.0 1.3 ± 0.5 0.0 ± 0.0 0.0 ± 0.0 0.0 ± 0.0
6h_vs_8z 0.0 ± 0.0 1.1 ± 0.8 0.0 ± 0.0 0.0 ± 0.0 0.0 ± 0.0
corridor 0.0 ± 0.0 1.4 ± 0.6 0.0 ± 0.0 0.0 ± 0.0 0.0 ± 0.0
protoss_5_vs_5 11.3 ± 3.8 18.4 ± 1.3 10.7 ± 2.1 10.5 ± 2.7 10.1 ± 2.3
protoss_10_vs_10 4.1 ± 1.9 12.2 ± 1.6 7.0 ± 0.9 7.5 ± 1.5 5.4 ± 0.6
protoss_10_vs_11 1.6 ± 1.1 4.1 ± 1.1 2.6 ± 1.0 2.7 ± 0.8 2.0 ± 1.3
protoss_20_vs_20 3.1 ± 0.1 8.7 ± 1.8 4.0 ± 1.9 5.4 ± 2.4 4.4 ± 0.7
protoss_20_vs_23 1.9 ± 0.6 3.2 ± 0.8 2.1 ± 0.6 1.1 ± 0.6 0.6 ± 0.7
terran_5_vs_5 8.8 ± 0.7 14.0 ± 3.5 8.9 ± 2.3 9.4 ± 2.2 5.1 ± 0.3
terran_10_vs_10 4.8 ± 1.8 11.3 ± 1.8 6.8 ± 3.2 7.7 ± 2.5 4.7 ± 1.5
terran_10_vs_11 1.0 ± 0.3 2.7 ± 1.2 1.1 ± 0.6 1.8 ± 0.7 1.4 ± 0.5
terran_20_vs_20 4.8 ± 2.1 7.6 ± 0.7 4.6 ± 0.7 4.7 ± 1.4 3.5 ± 0.7
terran_20_vs_23 0.3 ± 0.1 1.7 ± 1.3 0.4 ± 0.5 0.4 ± 0.3 0.4 ± 0.3
zerg_5_vs_5 5.6 ± 0.8 6.6 ± 1.0 7.5 ± 1.3 5.1 ± 1.3 3.2 ± 1.2
zerg_10_vs_10 4.3 ± 0.6 8.8 ± 1.7 5.6 ± 1.3 5.7 ± 0.7 4.0 ± 1.8
zerg_10_vs_11 3.5 ± 2.1 6.7 ± 2.6 4.0 ± 1.1 3.5 ± 2.2 3.1 ± 1.6
zerg_20_vs_20 0.1 ± 0.1 1.7 ± 0.4 0.4 ± 0.5 0.4 ± 0.5 0.4 ± 0.7
zerg_20_vs_23 0.4 ± 0.4 2.4 ± 0.2 1.7 ± 0.5 1.2 ± 0.9 0.7 ± 0.6
In [98]:
# One returns box chart per SMACv2 race.
for race_envs in (PROTOSS_ENVS, TERRAN_ENVS, ZERG_ENVS):
    show_box_chart(pd_returns, race_envs, tag="returns")
In [99]:
# One win-rate box chart per SMACv2 race.
for race_envs in (PROTOSS_ENVS, TERRAN_ENVS, ZERG_ENVS):
    show_box_chart(pd_winrates, race_envs, tag="winrates")
In [100]:
# Exsize ablation curves: each "algo" is an exsize int (coloured via COLOR_MAPS int keys).
plotly_figs = {
    env_name: create_scatters(env_name, ablation_exsize_returns, algos=[50, 200, 400, 800, 1200])
    for env_name in SMACV2_ENV_NAMES
}

show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200, title="Ablation: SMACv2 - Returns - Rule-based")
No description has been provided for this image
Saving Ablation_SMACv2_Returns_Rule-based.pdf
In [101]:
# Exsize ablation win-rate curves, one figure per SMACv2 task.
plotly_figs = {
    env_name: create_scatters(env_name, ablation_exsize_winrates, algos=[50, 200, 400, 800, 1200], tag="winrates")
    for env_name in SMACV2_ENV_NAMES
}

show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200, title="Ablation: SMACv2 - Winrates - Rule-based")
No description has been provided for this image
Saving Ablation_SMACv2_Winrates_Rule-based.pdf
In [102]:
# Alpha-ablation column labels (here alpha=0.05 is shown by value, not as "ours").
RENAME_ALGOS_2 = {
    "miso_alpha0.00": "MisoDICE (α=0.00)",
    "miso_alpha0.05": "MisoDICE (α=0.05)",
    "miso_alpha0.10": "MisoDICE (α=0.10)",
    "miso_alpha0.50": "MisoDICE (α=0.50)",
    "miso_alpha1.00": "MisoDICE (α=1.00)",
    "miso_alpha10.00": "MisoDICE (α=10.0)",
}
# Local names keep the main-results `data_returns` / `data_winrates` intact, so earlier
# plotting cells still produce the main figures if re-run after this one.
# NOTE(review): `pd_returns` / `pd_winrates` are still reassigned here (the next cells read them).
alpha_returns, alpha_winrates = load_data(use=False, algos=ABLATION_ALGOS)
pd_returns = pd.DataFrame.from_dict(alpha_returns).rename(columns=RENAME_ALGOS_2)
pd_winrates = pd.DataFrame.from_dict(alpha_winrates).rename(columns=RENAME_ALGOS_2)
In [103]:
# Alpha ablation: final smoothed returns, mean ± std over seeds.
pd_returns.map(analyze_data, tag="returns")
Out[103]:
MisoDICE (α=0.00) MisoDICE (α=0.05) MisoDICE (α=0.10) MisoDICE (α=0.50) MisoDICE (α=1.00) MisoDICE (α=10.0)
Hopper-v2 200.0 ± 46.6 206.5 ± 22.5 165.3 ± 37.9 175.3 ± 36.8 171.8 ± 27.7 144.2 ± 43.6
Ant-v2 2039.6 ± 4.1 2025.8 ± 3.2 1988.9 ± 8.2 1964.5 ± 5.2 1929.7 ± 33.7 1403.3 ± 67.2
HalfCheetah-v2 -108.1 ± 21.8 -72.1 ± 31.8 -86.2 ± 31.3 -125.6 ± 16.5 -185.9 ± 20.3 -238.8 ± 14.3
2c_vs_64zg 13.0 ± 0.7 15.0 ± 1.4 12.5 ± 0.5 12.1 ± 0.5 11.4 ± 0.4 10.8 ± 0.3
5m_vs_6m 6.4 ± 1.3 7.2 ± 0.1 6.3 ± 1.0 6.1 ± 1.5 6.5 ± 0.5 6.6 ± 0.2
6h_vs_8z 8.2 ± 0.2 8.5 ± 0.1 8.0 ± 0.2 7.9 ± 0.2 7.7 ± 0.2 7.6 ± 0.1
corridor 4.7 ± 0.7 4.7 ± 0.6 4.6 ± 0.7 4.3 ± 0.6 2.9 ± 0.6 1.7 ± 0.1
protoss_5_vs_5 11.7 ± 0.5 12.2 ± 0.7 11.4 ± 0.5 11.4 ± 0.8 11.4 ± 0.4 11.5 ± 0.4
protoss_10_vs_10 11.8 ± 0.3 12.6 ± 0.6 12.1 ± 0.6 11.6 ± 0.3 11.7 ± 0.6 11.9 ± 0.8
protoss_10_vs_11 10.0 ± 0.3 10.6 ± 0.4 10.0 ± 0.7 9.9 ± 0.2 9.6 ± 0.4 9.7 ± 0.2
protoss_20_vs_20 12.2 ± 0.4 13.1 ± 0.4 12.2 ± 0.3 11.8 ± 0.4 11.6 ± 0.2 11.8 ± 0.5
protoss_20_vs_23 10.2 ± 0.4 10.4 ± 0.4 9.9 ± 0.5 9.8 ± 0.4 9.7 ± 0.1 9.4 ± 0.3
terran_5_vs_5 8.4 ± 0.8 8.8 ± 0.5 8.2 ± 0.1 8.3 ± 0.8 7.8 ± 0.9 7.9 ± 0.7
terran_10_vs_10 8.3 ± 0.2 8.6 ± 0.3 8.0 ± 0.3 8.3 ± 0.5 7.7 ± 0.4 7.6 ± 0.6
terran_10_vs_11 5.8 ± 0.3 6.2 ± 0.5 5.6 ± 0.2 5.7 ± 0.6 5.6 ± 0.2 5.2 ± 0.2
terran_20_vs_20 8.1 ± 0.9 9.1 ± 0.4 8.1 ± 0.6 7.8 ± 0.6 7.5 ± 0.7 8.3 ± 0.5
terran_20_vs_23 5.2 ± 0.4 5.5 ± 0.4 5.1 ± 0.2 5.1 ± 0.3 4.9 ± 0.2 5.0 ± 0.3
zerg_5_vs_5 7.2 ± 0.3 7.4 ± 0.5 6.7 ± 0.5 6.4 ± 0.3 6.3 ± 0.6 6.3 ± 0.4
zerg_10_vs_10 9.4 ± 0.3 10.0 ± 0.2 9.0 ± 0.3 8.8 ± 0.2 8.5 ± 0.4 8.6 ± 0.3
zerg_10_vs_11 9.0 ± 0.3 9.1 ± 0.6 8.4 ± 0.4 8.7 ± 0.3 8.5 ± 0.1 8.3 ± 0.4
zerg_20_vs_20 9.0 ± 0.6 10.0 ± 0.6 9.0 ± 0.7 9.0 ± 0.5 8.6 ± 0.6 8.8 ± 0.6
zerg_20_vs_23 8.9 ± 0.0 9.5 ± 0.4 8.6 ± 0.2 8.7 ± 0.4 8.4 ± 0.3 8.5 ± 0.3
In [ ]:
# Alpha ablation: final smoothed win rates in percent, mean ± std over seeds.
pd_winrates.map(analyze_data, tag="winrates")
Out[ ]:
MisoDICE (α=0.00) MisoDICE (α=0.05) MisoDICE (α=0.10) MisoDICE (α=0.50) MisoDICE (α=1.00) MisoDICE (α=10.0)
2c_vs_64zg 3.9 ± 1.0 8.4 ± 5.9 3.4 ± 1.2 2.9 ± 0.8 2.7 ± 1.1 1.6 ± 0.9
5m_vs_6m 1.1 ± 1.0 1.3 ± 0.5 0.6 ± 0.4 0.7 ± 0.4 0.9 ± 0.7 0.6 ± 0.3
6h_vs_8z 1.0 ± 0.8 1.1 ± 0.8 1.1 ± 0.8 0.8 ± 0.5 0.4 ± 0.3 1.0 ± 0.5
corridor 0.4 ± 0.2 1.4 ± 0.6 1.2 ± 0.5 0.5 ± 0.4 1.4 ± 0.7 0.9 ± 0.4
protoss_5_vs_5 11.6 ± 2.4 18.4 ± 1.3 11.0 ± 1.3 11.3 ± 2.0 12.0 ± 2.0 16.3 ± 2.1
protoss_10_vs_10 8.2 ± 2.0 12.2 ± 1.6 8.8 ± 3.2 9.4 ± 2.6 8.7 ± 2.7 10.3 ± 1.2
protoss_10_vs_11 1.9 ± 0.3 4.1 ± 1.1 3.1 ± 1.3 3.0 ± 1.0 2.2 ± 1.1 3.6 ± 1.2
protoss_20_vs_20 4.9 ± 1.3 8.7 ± 1.8 4.0 ± 1.5 6.3 ± 1.7 4.8 ± 1.0 5.2 ± 1.1
protoss_20_vs_23 3.0 ± 1.4 3.2 ± 0.8 2.0 ± 0.6 2.9 ± 1.4 3.2 ± 0.6 2.2 ± 1.4
terran_5_vs_5 10.2 ± 1.7 14.0 ± 3.5 10.5 ± 2.8 11.4 ± 1.5 10.4 ± 1.6 9.2 ± 1.6
terran_10_vs_10 8.7 ± 1.8 11.3 ± 1.8 8.4 ± 1.4 6.4 ± 1.4 8.0 ± 3.1 8.7 ± 3.9
terran_10_vs_11 2.0 ± 0.5 2.7 ± 1.2 3.1 ± 1.1 2.1 ± 1.0 1.7 ± 0.5 1.8 ± 0.8
terran_20_vs_20 4.1 ± 1.8 7.6 ± 0.7 4.3 ± 2.2 5.3 ± 1.5 4.3 ± 2.6 7.8 ± 1.4
terran_20_vs_23 1.4 ± 0.7 1.7 ± 1.3 1.2 ± 0.6 1.7 ± 1.2 0.8 ± 0.4 1.6 ± 0.7
zerg_5_vs_5 5.9 ± 1.6 6.6 ± 1.0 6.2 ± 0.8 4.7 ± 1.1 5.2 ± 0.8 5.8 ± 0.8
zerg_10_vs_10 4.6 ± 1.4 8.8 ± 1.7 5.8 ± 1.0 5.5 ± 0.8 5.9 ± 1.1 6.4 ± 1.4
zerg_10_vs_11 4.1 ± 1.4 6.7 ± 2.6 4.6 ± 0.7 3.8 ± 0.9 4.6 ± 1.1 5.6 ± 1.2
zerg_20_vs_20 1.1 ± 0.4 1.7 ± 0.4 0.8 ± 0.5 1.4 ± 1.3 1.4 ± 0.5 2.5 ± 0.8
zerg_20_vs_23 1.3 ± 0.3 2.4 ± 0.2 0.9 ± 0.4 1.8 ± 0.1 1.9 ± 1.7 1.8 ± 0.4
The Kernel crashed while executing code in the current cell or a previous cell. 

Please review the code in the cell(s) to identify a possible cause of the failure. 

Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. 

View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.