In [112]:
import os
import json
import math
import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
from plotly.subplots import make_subplots
pio.kaleido.scope.mathjax = None
pd.options.display.max_rows = None
In [113]:
# Every algorithm identifier; load_data falls back to this list when no `algos` is given.
# The identifiers are used verbatim as a directory component under LOG_DIR.
ALGOS = ["bc_beta0.0", "bc_beta0.5", "bc_beta1.0", "omapl", "indd", "vdn", "miso-sl", "miso-mini", "miso_alpha0.00", "miso_alpha0.05", "miso_alpha0.10", "miso_alpha0.50", "miso_alpha1.00", "miso_alpha10.00"]
# Algorithms loaded for the main comparison tables.
MAIN_ALGOS = ["bc_beta0.0", "bc_beta0.5", "bc_beta1.0", "omapl", "indd", "vdn", "miso-sl", "miso_alpha0.05"]
# MAIN_ALGOS without "omapl"; this is the default set drawn by create_scatters.
MAIN_ALGOS_2 = ["bc_beta0.0", "bc_beta0.5", "bc_beta1.0", "indd", "vdn", "miso-sl", "miso_alpha0.05"]
# The "miso_alpha*" runs, which differ only in the alpha suffix of their identifier.
ABLATION_ALGOS = ["miso_alpha0.00", "miso_alpha0.05", "miso_alpha0.10", "miso_alpha0.50", "miso_alpha1.00", "miso_alpha10.00"]
# Environment names; each one is a directory component of the results path built in load_data.
SMACV1_ENV_NAMES = ["2c_vs_64zg", "5m_vs_6m", "6h_vs_8z", "corridor"]
# "<race>_<mode>" for every race/mode pair, ordered race-major (all protoss modes first, then terran, then zerg).
SMACV2_ENV_NAMES = [f"{map_name}_{map_mode}" for map_name in ["protoss", "terran", "zerg"] for map_mode in ["5_vs_5", "10_vs_10", "10_vs_11", "20_vs_20", "20_vs_23"]]
# load_data skips these environments when it is called with use_llm=True.
MAMUJOCO_ENV_NAMES = ["Hopper-v2", "Ant-v2", "HalfCheetah-v2"]
# Algorithm identifier -> display name used for DataFrame columns and plot legends.
RENAME_ALGOS = {
    "bc_beta0.0": "BC (β=0.0)",
    "bc_beta0.5": "BC (β=0.5)",
    "bc_beta1.0": "BC (β=1.0)",
    "omapl": "OMAPL",
    "indd": "INDD",
    "vdn": "VDN",
    "miso-sl": "MARL-SL",
    "miso-mini": "MisoDICE (4o-mini)",
    "miso_alpha0.00": "MisoDICE (α=0.00)",
    "miso_alpha0.05": "MisoDICE (ours)",
    "miso_alpha0.10": "MisoDICE (α=0.10)",
    "miso_alpha0.50": "MisoDICE (α=0.50)",
    "miso_alpha1.00": "MisoDICE (α=1.00)",
    "miso_alpha10.00": "MisoDICE (α=10.00)",
}
# Plotly's qualitative palette; COLOR_MAPS below indexes into it.
colors = px.colors.qualitative.Plotly
# Trace colour per key. Keys are algorithm identifiers plus the integer exsize values used by the
# exsize ablation plots. Several keys share a palette entry (e.g. "vdn" and "miso-mini" both use
# colors[0]), so those pairs are indistinguishable if drawn in the same figure.
COLOR_MAPS = {
    "bc_beta0.0": colors[4],
    "bc_beta0.5": colors[9],
    "bc_beta1.0": colors[2],
    "omapl": colors[7],
    "indd": colors[3],
    "vdn": colors[0],
    "miso-sl": colors[8],
    "miso-mini": colors[0],
    "miso_alpha0.00": colors[5],
    "miso_alpha0.05": colors[1],
    "miso_alpha0.10": colors[6],
    "miso_alpha0.50": colors[7],
    "miso_alpha1.00": colors[8],
    "miso_alpha10.00": "#feb406",

    # Integer keys: looked up when create_scatters is called with algos=[50, 200, 400, 800, 1200].
    50: colors[0],
    200: colors[1],
    400: colors[2],
    800: colors[3],
    1200: colors[4],
}

# Root directory that load_data reads results.json files from.
LOG_DIR = "logs"
In [114]:
def darken_color(hex_color, factor=0.92):
    """Scale the RGB channels of a ``#rrggbb`` colour string by ``factor``.

    Args:
        hex_color: Colour written as ``#`` followed by six hex digits.
        factor: Multiplier applied to each channel; values below 1 darken.

    Returns:
        The scaled colour as a lowercase ``#rrggbb`` string. Each channel is
        truncated to an integer and clamped to the 0-255 range.
    """
    hex_pairs = (hex_color[1:3], hex_color[3:5], hex_color[5:7])
    scaled_channels = []
    for pair in hex_pairs:
        level = int(int(pair, 16) * factor)
        # Clamp so that factors above 1 (or below 0) still give a valid channel.
        scaled_channels.append(max(0, min(level, 255)))
    return "#{:02x}{:02x}{:02x}".format(*scaled_channels)
In [115]:
# Replace every palette entry with its darkened version (default darken_color factor).
# NOTE: this overwrites COLOR_MAPS with a function of itself, so running this cell a second
# time without re-running the cell that defines COLOR_MAPS darkens the colours again.
COLOR_MAPS = dict((key, darken_color(shade)) for key, shade in COLOR_MAPS.items())
In [116]:
def load_results(path):
    """Return the parsed JSON stored at ``path``.

    Args:
        path: Filesystem path of a JSON file.

    Returns:
        The decoded JSON content, or an empty dict when ``path`` does not exist.
    """
    if os.path.exists(path):
        with open(path, "r") as handle:
            return json.load(handle)
    # A missing file is reported as "no results" rather than as an error.
    return {}
In [117]:
def smooth(scalars, weight=0.75):
    """Exponentially smooth a sequence, correcting the start-up bias.

    Args:
        scalars: Iterable of numbers, in order.
        weight: Share of the previous running value kept at each update.

    Returns:
        A list with one smoothed value per input element. Each running value is
        divided by ``1 - weight ** n`` (``n`` = number of elements seen so far),
        which compensates for the running value starting at zero.
    """
    running = 0
    smoothed_values = []
    seen = 0
    for value in scalars:
        seen += 1
        running = running * weight + (1 - weight) * value
        correction = 1 - math.pow(weight, seen)
        smoothed_values.append(running / correction)
    return smoothed_values
In [118]:
def load_data(use_llm=False, exsize=200, algos=None, log_dir=None, steps=None):
    """Collect per-step mean returns and winrates from the results.json files.

    Files are looked up at
    ``{log_dir}/{algo}/{env}[_llm]/seed{seed}/exsize{exsize}/results.json`` for
    seeds 0-3. The ``"omapl"`` algorithm is special-cased: its file is always read
    from ``{log_dir}/omapl/{env}_llm/seed{seed}/results.json`` and its single
    result dict is reused for every step.

    Args:
        use_llm: When true, read the ``*_llm`` directories and skip the MaMuJoCo
            environments.
        exsize: Value substituted into the ``exsize{...}`` path component.
        algos: Algorithm identifiers to load; a falsy value means ``ALGOS``.
        log_dir: Root directory of the logs; ``None`` means ``LOG_DIR``.
        steps: Iterable of step indices to read; ``None`` means ``range(101)``.

    Returns:
        ``(data_returns, data_winrates)``: nested defaultdicts indexed as
        ``[algo][env_name][step]``, each leaf being a list with one mean value
        per seed that provided the metric.

    Missing files and missing ``step_<n>`` entries are skipped instead of raising
    ``KeyError``, matching ``load_results`` returning ``{}`` for absent files. An
    environment with no data therefore never appears in the result.
    """
    data_returns = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    data_winrates = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    if not algos:
        algos = ALGOS
    if log_dir is None:
        log_dir = LOG_DIR
    # Materialise once: the same step sequence is replayed for every file.
    steps = tuple(range(101)) if steps is None else tuple(steps)
    env_names = MAMUJOCO_ENV_NAMES + SMACV1_ENV_NAMES + SMACV2_ENV_NAMES

    for algo in algos:
        is_omapl = algo == "omapl"
        for env_name in env_names:
            if use_llm and env_name in MAMUJOCO_ENV_NAMES:
                continue
            for seed in range(4):
                if is_omapl:
                    # omapl logs have no exsize level and always live under "<env>_llm".
                    path = f"{log_dir}/{algo}/{env_name}_llm/seed{seed}/results.json"
                elif use_llm:
                    path = f"{log_dir}/{algo}/{env_name}_llm/seed{seed}/exsize{exsize}/results.json"
                else:
                    path = f"{log_dir}/{algo}/{env_name}/seed{seed}/exsize{exsize}/results.json"

                data = load_results(path)
                for step in steps:
                    # omapl files are not keyed by step: the whole dict is the result.
                    result = data if is_omapl else data.get(f"step_{step}", {})
                    if "returns" in result:
                        data_returns[algo][env_name][step].append(np.mean(result["returns"]))
                    if "winrates" in result:
                        data_winrates[algo][env_name][step].append(np.mean(result["winrates"]))
    return data_returns, data_winrates
In [119]:
def analyze_data(data, tag="returns"):
    """Summarise one (environment, algorithm) cell by its smoothed final value.

    Args:
        data: Mapping ``step -> list of values``, or a float (a DataFrame cell
            with no data).
        tag: ``"returns"`` or ``"winrates"`` selects a formatted string; any other
            value returns the raw numbers.

    Returns:
        ``"NaN"`` for a float input. Otherwise each column of the step-ordered
        values is smoothed with ``smooth`` and its last element kept; the result
        is ``"mean ± std"`` (multiplied by 100 for winrates) or, for other tags,
        the array of last elements.

    Assumes every step holds the same number of values, so the data forms a 2-D
    array — TODO confirm against the loader.
    """
    if isinstance(data, float):
        return "NaN"

    ordered_steps = sorted(data.keys())
    per_step = np.array([data[step] for step in ordered_steps])
    n_columns = per_step.shape[-1]
    curves = np.array([smooth(per_step[:, column]) for column in range(n_columns)])
    final_values = curves[:, -1]

    mean, std = np.mean(final_values), np.std(final_values)
    if tag == "winrates":
        return f"{100*mean:.1f} ± {100*std:.1f}"
    if tag == "returns":
        return f"{mean:.1f} ± {std:.1f}"
    return final_values
In [120]:
data_returns_llm, data_winrates_llm = load_data(use_llm=True, algos=MAIN_ALGOS)
pd_returns_llm = pd.DataFrame.from_dict(data_returns_llm).rename(columns=RENAME_ALGOS)
pd_winrates_llm = pd.DataFrame.from_dict(data_winrates_llm).rename(columns=RENAME_ALGOS)
In [121]:
pd_returns_llm.map(analyze_data)
Out[121]:
BC (β=0.0) BC (β=0.5) BC (β=1.0) OMAPL INDD VDN MARL-SL MisoDICE (ours)
2c_vs_64zg 8.5 ± 0.1 9.7 ± 0.3 12.6 ± 0.3 12.2 ± 0.4 14.6 ± 1.0 14.0 ± 1.6 12.7 ± 0.6 16.4 ± 1.3
5m_vs_6m 5.0 ± 1.1 6.7 ± 0.0 6.1 ± 0.1 5.7 ± 0.2 6.7 ± 0.1 6.8 ± 0.1 6.2 ± 1.4 7.3 ± 0.1
6h_vs_8z 7.0 ± 0.0 7.4 ± 0.0 7.2 ± 0.1 6.6 ± 0.2 7.5 ± 0.2 7.8 ± 0.1 8.2 ± 0.2 8.7 ± 0.2
corridor 1.5 ± 0.1 1.5 ± 0.2 4.3 ± 0.7 2.2 ± 1.3 4.4 ± 1.2 1.8 ± 0.2 4.7 ± 0.6 5.8 ± 0.8
protoss_5_vs_5 9.2 ± 0.1 11.7 ± 0.5 10.2 ± 0.5 9.6 ± 1.1 10.9 ± 0.1 11.6 ± 0.3 11.5 ± 0.2 12.4 ± 0.5
protoss_10_vs_10 10.3 ± 0.6 11.8 ± 0.5 10.6 ± 0.2 10.1 ± 0.9 11.0 ± 0.7 11.9 ± 0.4 12.4 ± 0.2 12.9 ± 0.2
protoss_10_vs_11 8.2 ± 0.4 9.6 ± 0.4 8.7 ± 0.3 8.5 ± 1.2 9.4 ± 0.4 9.9 ± 0.3 10.4 ± 0.1 10.7 ± 0.4
protoss_20_vs_20 10.1 ± 0.2 10.4 ± 0.5 10.5 ± 0.3 9.4 ± 0.4 11.4 ± 0.5 13.1 ± 0.4 12.1 ± 0.5 13.5 ± 0.5
protoss_20_vs_23 8.1 ± 0.2 8.6 ± 0.3 8.3 ± 0.2 7.9 ± 0.3 9.6 ± 0.3 9.6 ± 0.3 10.3 ± 0.4 10.6 ± 0.2
terran_5_vs_5 6.5 ± 0.8 8.1 ± 0.5 7.1 ± 0.6 6.2 ± 0.6 7.9 ± 0.5 8.1 ± 0.4 8.3 ± 1.0 9.1 ± 0.3
terran_10_vs_10 6.6 ± 0.3 7.4 ± 0.4 6.7 ± 0.6 6.9 ± 1.1 7.6 ± 0.4 7.7 ± 0.2 8.0 ± 0.4 9.1 ± 1.3
terran_10_vs_11 4.7 ± 0.2 5.7 ± 0.3 5.2 ± 0.3 4.2 ± 0.6 5.7 ± 0.5 5.7 ± 0.4 6.0 ± 0.2 6.4 ± 0.5
terran_20_vs_20 6.9 ± 0.4 7.9 ± 0.8 6.7 ± 0.2 6.9 ± 0.5 8.0 ± 0.5 8.6 ± 0.3 8.2 ± 0.5 9.2 ± 0.6
terran_20_vs_23 4.0 ± 0.3 5.1 ± 0.4 4.3 ± 0.3 4.3 ± 0.4 5.1 ± 0.4 5.1 ± 0.6 5.6 ± 0.3 5.6 ± 0.4
zerg_5_vs_5 5.7 ± 0.5 6.6 ± 0.4 5.9 ± 0.3 6.1 ± 0.5 6.4 ± 0.2 7.1 ± 0.5 7.1 ± 0.9 7.5 ± 0.1
zerg_10_vs_10 7.3 ± 0.1 8.7 ± 0.6 7.4 ± 0.7 6.8 ± 0.6 8.2 ± 0.2 9.0 ± 0.4 9.7 ± 0.5 10.2 ± 0.6
zerg_10_vs_11 7.3 ± 0.2 8.3 ± 0.4 7.3 ± 0.5 7.2 ± 0.4 8.0 ± 0.2 8.8 ± 0.4 9.1 ± 0.2 9.4 ± 0.3
zerg_20_vs_20 7.4 ± 0.6 9.0 ± 0.5 7.7 ± 0.2 6.9 ± 0.5 8.3 ± 0.4 8.8 ± 0.6 9.0 ± 0.5 10.2 ± 0.6
zerg_20_vs_23 7.1 ± 0.3 7.9 ± 0.3 7.0 ± 0.2 7.1 ± 0.4 8.2 ± 0.4 8.8 ± 0.2 8.7 ± 0.5 9.5 ± 0.2
In [122]:
pd_winrates_llm.map(lambda x: analyze_data(x, tag="winrates"))
Out[122]:
BC (β=0.0) BC (β=0.5) BC (β=1.0) OMAPL INDD VDN MARL-SL MisoDICE (ours)
2c_vs_64zg 0.2 ± 0.2 0.5 ± 0.3 8.9 ± 2.9 3.9 ± 3.4 11.7 ± 5.5 10.6 ± 6.0 2.7 ± 1.5 13.0 ± 9.0
5m_vs_6m 0.2 ± 0.4 0.9 ± 0.6 0.1 ± 0.1 0.0 ± 0.0 0.2 ± 0.2 1.1 ± 0.8 0.9 ± 0.9 1.2 ± 0.5
6h_vs_8z 0.2 ± 0.2 0.0 ± 0.0 0.2 ± 0.3 0.0 ± 0.0 0.1 ± 0.1 1.0 ± 0.6 1.2 ± 0.1 1.1 ± 0.8
corridor 0.1 ± 0.1 0.6 ± 0.7 0.3 ± 0.4 0.0 ± 0.0 0.1 ± 0.1 0.9 ± 0.6 0.7 ± 0.7 1.4 ± 0.6
protoss_5_vs_5 13.8 ± 2.7 17.5 ± 3.8 14.2 ± 2.4 14.1 ± 10.2 12.4 ± 3.1 15.6 ± 4.5 10.8 ± 1.6 20.7 ± 0.9
protoss_10_vs_10 12.1 ± 2.3 12.7 ± 0.9 11.3 ± 2.9 11.7 ± 3.4 8.9 ± 1.4 11.8 ± 4.3 9.5 ± 1.7 14.1 ± 2.1
protoss_10_vs_11 2.1 ± 0.9 3.5 ± 2.0 1.8 ± 0.6 0.8 ± 1.4 2.0 ± 0.4 3.5 ± 0.3 2.9 ± 0.9 4.7 ± 0.3
protoss_20_vs_20 4.5 ± 1.8 3.4 ± 1.9 7.0 ± 2.5 3.1 ± 3.8 5.2 ± 1.3 8.6 ± 2.0 5.2 ± 1.9 11.0 ± 3.2
protoss_20_vs_23 1.5 ± 0.8 0.8 ± 0.4 1.8 ± 0.8 0.8 ± 1.4 2.4 ± 0.9 2.0 ± 0.9 2.6 ± 1.0 3.8 ± 2.0
terran_5_vs_5 10.1 ± 2.3 12.8 ± 1.3 13.7 ± 3.5 10.2 ± 4.6 10.6 ± 1.3 9.7 ± 1.9 10.4 ± 3.2 14.2 ± 3.1
terran_10_vs_10 7.3 ± 2.0 7.7 ± 2.2 7.9 ± 2.4 9.4 ± 3.8 8.0 ± 2.4 8.2 ± 1.6 7.0 ± 1.9 12.0 ± 1.7
terran_10_vs_11 1.9 ± 0.8 3.1 ± 1.4 2.5 ± 1.1 0.8 ± 1.4 1.8 ± 0.5 2.6 ± 1.2 2.4 ± 0.8 4.2 ± 1.6
terran_20_vs_20 4.5 ± 1.3 7.4 ± 2.5 4.4 ± 0.6 4.7 ± 3.5 5.8 ± 0.7 7.1 ± 1.4 4.5 ± 1.5 8.8 ± 1.5
terran_20_vs_23 0.6 ± 1.0 1.2 ± 0.9 0.7 ± 0.7 0.0 ± 0.0 0.8 ± 0.4 1.3 ± 0.7 1.3 ± 1.2 1.8 ± 1.4
zerg_5_vs_5 6.2 ± 0.8 5.6 ± 1.1 5.9 ± 0.5 6.2 ± 5.8 6.0 ± 1.4 6.7 ± 1.1 7.0 ± 1.8 7.9 ± 1.0
zerg_10_vs_10 4.5 ± 1.4 6.9 ± 1.8 5.0 ± 1.8 3.9 ± 3.4 4.6 ± 1.0 5.7 ± 1.2 5.2 ± 0.6 8.9 ± 2.1
zerg_10_vs_11 5.7 ± 1.9 5.8 ± 1.9 5.2 ± 1.3 2.3 ± 1.4 4.7 ± 1.6 6.0 ± 1.9 3.4 ± 0.7 6.8 ± 1.1
zerg_20_vs_20 0.8 ± 0.9 1.5 ± 0.6 1.3 ± 0.8 0.0 ± 0.0 0.2 ± 0.2 1.2 ± 0.5 0.6 ± 0.2 2.4 ± 0.6
zerg_20_vs_23 1.2 ± 0.9 0.7 ± 0.3 1.2 ± 0.6 1.6 ± 1.6 1.4 ± 0.5 1.9 ± 0.9 1.8 ± 1.1 2.7 ± 0.4
In [123]:
def update_legend(fig, tag="returns", distance=1.1, yrange=None):
    """Deduplicate legend entries and apply the shared axis styling to ``fig``.

    Args:
        fig: Plotly figure, modified in place.
        tag: ``"winrates"`` formats the y axis as whole percentages with a tick
            every 0.1; any other value uses the ``"~s"`` tick format.
        distance: Vertical legend position (``y`` of the horizontal legend).
        yrange: Optional y-axis range applied to every y axis.

    Returns:
        The same figure object.
    """
    # Only the first trace carrying a given name keeps its legend entry.
    names_seen = set()
    for trace in fig.data:
        label = trace.name
        is_first = label is not None and label not in names_seen
        if is_first:
            names_seen.add(label)
        trace.update(showlegend=is_first)

    legend_style = dict(orientation="h", yanchor="bottom", y=distance, xanchor="right", x=1)
    fig.update_layout(legend=legend_style)

    fig.update_xaxes(showgrid=True, range=[0, 100], dtick=50, minor=dict(ticklen=3, nticks=3))

    y_style = dict(showgrid=True)
    if tag == "winrates":
        y_style.update(tickformat=".0%", dtick=0.1, minor=dict(ticklen=3, nticks=2))
    else:
        y_style.update(tickformat="~s")
    if yrange is not None:
        y_style.update(range=yrange)
    fig.update_yaxes(**y_style)
    return fig
In [124]:
def create_scatters(env_name, data_dict, y_range=None, tag="returns", algos=None, rename_algos=None):
    """Build the learning-curve figure of one environment.

    Args:
        env_name: Environment whose curves are drawn.
        data_dict: Nested mapping indexed as ``[algo][env_name][step]`` whose
            leaves are lists of values.
        y_range: Optional y-axis range.
        tag: ``"winrates"`` clips the band to [0, 1] and uses percent ticks.
        algos: Keys of ``data_dict`` to draw; a falsy value means ``MAIN_ALGOS_2``.
        rename_algos: Key -> legend name mapping; a falsy value means
            ``RENAME_ALGOS``. Keys without an entry are shown unchanged.

    Returns:
        A 180x130 figure with, per key, one line (smoothed mean over steps 1-100)
        and one filled band (smoothed mean +/- the unsmoothed per-step std).
    """
    algos = algos or MAIN_ALGOS_2
    rename_algos = rename_algos or RENAME_ALGOS
    is_winrate = tag == "winrates"

    fig = go.Figure()
    for algo in algos:
        # Step 0 is left out of the curves.
        steps = list(range(1, 101))
        samples_per_step = [np.array(data_dict[algo][env_name][step]) for step in steps]
        values = smooth([samples.mean() for samples in samples_per_step])
        stds = [samples.std() for samples in samples_per_step]

        uppers, lowers = [], []
        for value, std in zip(values, stds):
            upper, lower = value + std, value - std
            if is_winrate:
                # A winrate band cannot leave the [0, 1] interval.
                upper, lower = min(1.0, upper), max(0.0, lower)
            uppers.append(upper)
            lowers.append(lower)

        color = COLOR_MAPS.get(algo, colors[2])
        algo_name = rename_algos.get(algo, algo)
        mean_line = go.Scatter(x=steps, y=values, mode="lines", name=algo_name, line_color=color, line_width=1.5)
        # Closed polygon: upper edge left-to-right, then lower edge right-to-left.
        band = go.Scatter(x=steps+steps[::-1], y=uppers+lowers[::-1], fill="toself", fillcolor=color, line_color=color, opacity=0.1, line_width=1, showlegend=False)
        fig.add_trace(mean_line)
        fig.add_trace(band)

    tickformat = ".0%" if is_winrate else "~s"

    fig.update_layout(template='simple_white', margin=dict(l=0, r=0, t=0, b=0, pad=0, autoexpand=True))
    fig.update_layout(height=130, width=180)
    fig.update_xaxes(range=[0, 100], dtick=50, minor=dict(ticklen=3, nticks=4))
    fig.update_yaxes(range=y_range, tickformat=tickformat)
    fig.update_layout(showlegend=False)
    return fig
In [125]:
def show_fig(plotly_figs, env_names, n_rows, n_cols, height, width, title):
    """Save the per-environment figures, then show and save them as one grid.

    Args:
        plotly_figs: Mapping ``env_name -> figure`` (as built by create_scatters).
        env_names: Environments to lay out; also used as subplot titles.
        n_rows: Number of subplot rows.
        n_cols: Number of subplot columns.
        height: Height of the combined figure.
        width: Width of the combined figure.
        title: Figure title. It also selects the metric ("returns" in the title,
            otherwise winrates), the "ablation_" file prefix, and the file name
            of the combined PDF.

    Side effects:
        Writes ``graphs/{prefix}{env}_llm_{tag}.pdf`` for every environment and
        ``graphs/{sanitised title}.pdf`` for the grid, and renders the grid as SVG.

    If any name in ``env_names`` is a SMACv2 environment, the grid is the fixed
    3 races x 5 modes layout and ``plotly_figs`` must contain all 15 SMACv2
    environments. Otherwise the environments fill the grid row by row.
    """
    tag = "returns" if "returns" in title.lower() else "winrates"
    prefix = "ablation_" if "ablation" in title.lower() else ""
    # write_image does not create the target directory; make sure it exists.
    os.makedirs("graphs", exist_ok=True)

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=env_names)
    if any(env_name in SMACV2_ENV_NAMES for env_name in env_names):
        for col, mode in enumerate(["5_vs_5", "10_vs_10", "10_vs_11", "20_vs_20", "20_vs_23"], start=1):
            for row, race in enumerate(["protoss", "terran", "zerg"], start=1):
                plotly_fig = plotly_figs[f"{race}_{mode}"]
                plotly_fig.write_image(f"graphs/{prefix}{race}_{mode}_llm_{tag}.pdf")
                fig.add_traces(plotly_fig.data, rows=row, cols=col)
    else:
        for i, env_name in enumerate(env_names):
            # Row-major placement, matching the order of subplot_titles.
            row, col = divmod(i, n_cols)
            plotly_fig = plotly_figs[env_name]
            plotly_fig.write_image(f"graphs/{prefix}{env_name}_llm_{tag}.pdf")
            fig.add_traces(plotly_fig.data, rows=row+1, cols=col+1)
    fig.update_layout(template='simple_white')
    fig.update_layout(height=height, width=width, title_text=title)
    fig = update_legend(fig, tag=tag, distance=1.15)
    fig.show("svg")
    # The saved PDF carries no title.
    fig.update_layout(title_text=None)
    title = title.replace(" - ", "_").replace(" ", "_").replace(":", "_").replace("__", "_")
    print(f"Saving {title}.pdf")
    fig.write_image(f"graphs/{title}.pdf")
In [126]:
plotly_figs = {}
for env_name in SMACV1_ENV_NAMES + SMACV2_ENV_NAMES:
    plotly_figs[env_name] = create_scatters(env_name, data_returns_llm)

show_fig(plotly_figs, SMACV1_ENV_NAMES, n_rows=1, n_cols=4, height=300, width=950, title="SMACv1 - Returns - LLM-based")
show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200, title="SMACv2 - Returns - LLM-based")
No description has been provided for this image
Saving SMACv1_Returns_LLM-based.pdf
No description has been provided for this image
Saving SMACv2_Returns_LLM-based.pdf
In [127]:
plotly_figs = {}
for env_name in SMACV1_ENV_NAMES + SMACV2_ENV_NAMES:
    plotly_figs[env_name] = create_scatters(env_name, data_winrates_llm, tag="winrates")

show_fig(plotly_figs, SMACV1_ENV_NAMES, n_rows=1, n_cols=4, height=300, width=950, title="SMACv1 - Winrates - LLM-based")
show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200, title="SMACv2 - Winrates - LLM-based")
No description has been provided for this image
Saving SMACv1_Winrates_LLM-based.pdf
No description has been provided for this image
Saving SMACv2_Winrates_LLM-based.pdf
In [128]:
def show_box_chart(pd_data, env_names, tag="returns"):
    """Draw one box per exsize value, pooling the final values of ``env_names``.

    Args:
        pd_data: DataFrame with one column per exsize (50, 200, 400, 800, 1200)
            and one row per environment; each cell is a ``step -> values``
            mapping as accepted by analyze_data.
        env_names: Non-empty sequence of environments (rows) pooled into each box.
        tag: ``"winrates"`` uses percent ticks, anything else the ``"~s"`` format;
            it is also part of the output file name.

    Side effects:
        Writes ``graphs/{first "_"-separated part of the last environment name}_llm_{tag}_box.pdf``
        and renders the figure as SVG.

    Raises:
        ValueError: If ``env_names`` is empty (no file name can be derived).
    """
    if len(env_names) == 0:
        raise ValueError("env_names must contain at least one environment")
    # write_image does not create the target directory; make sure it exists.
    os.makedirs("graphs", exist_ok=True)

    pd_data = pd_data.map(lambda x: analyze_data(x, tag=None))
    fig = go.Figure()
    for exsize in [50, 200, 400, 800, 1200]:
        all_values = []
        for env_name in env_names:
            values = pd_data[exsize][env_name]
            if isinstance(values, str):
                # analyze_data returns the string "NaN" for cells without data;
                # extending with it would add its characters to the box.
                continue
            all_values.extend(values)
        fig.add_trace(go.Box(y=all_values, name=exsize, marker_color=colors[0], boxpoints=False))

    tickformat = ".0%" if tag == "winrates" else "~s"

    fig.update_layout(template='simple_white', margin=dict(l=0, r=0, t=0, b=0, pad=0, autoexpand=True))
    fig.update_layout(height=160, width=240)
    fig.update_layout(showlegend=False)
    fig.update_yaxes(tickformat=tickformat)
    # The file is named after the last environment's leading name part (e.g. "protoss").
    file_prefix = env_names[-1].split("_")[0]
    fig.write_image(f"graphs/{file_prefix}_llm_{tag}_box.pdf")
    fig.show("svg")
In [129]:
ablation_exsize_returns_llm = {}
ablation_exsize_winrates_llm = {}
for exsize in [50, 200, 400, 800, 1200]:
    data_returns, data_winrates = load_data(use_llm=True, exsize=exsize, algos=["miso_alpha0.05"])
    ablation_exsize_returns_llm[exsize] = data_returns["miso_alpha0.05"]
    ablation_exsize_winrates_llm[exsize] = data_winrates["miso_alpha0.05"]
In [130]:
topk_names = {exsize: f"topk={exsize}" for exsize in [50, 200, 400, 800, 1200]}
PROTOSS_ENVS = [env_name for env_name in SMACV2_ENV_NAMES if "protoss" in env_name]
TERRAN_ENVS = [env_name for env_name in SMACV2_ENV_NAMES if "terran" in env_name]
ZERG_ENVS = [env_name for env_name in SMACV2_ENV_NAMES if "zerg" in env_name]
In [131]:
pd_returns_llm = pd.DataFrame.from_dict(ablation_exsize_returns_llm)
pd_returns_llm.rename(columns=topk_names).map(analyze_data)
Out[131]:
topk=50 topk=200 topk=400 topk=800 topk=1200
2c_vs_64zg 11.8 ± 1.4 16.4 ± 1.3 9.7 ± 0.1 10.8 ± 0.7 10.4 ± 0.5
5m_vs_6m 6.9 ± 0.4 7.3 ± 0.1 6.8 ± 0.2 6.5 ± 0.1 6.4 ± 0.2
6h_vs_8z 7.9 ± 0.1 8.7 ± 0.2 7.8 ± 0.2 7.5 ± 0.1 7.3 ± 0.1
corridor 3.1 ± 1.0 5.8 ± 0.8 1.9 ± 0.1 1.7 ± 0.2 1.7 ± 0.1
protoss_5_vs_5 11.9 ± 0.2 12.4 ± 0.5 11.3 ± 0.2 10.9 ± 0.2 10.9 ± 0.5
protoss_10_vs_10 11.8 ± 0.1 12.9 ± 0.2 12.1 ± 0.3 11.5 ± 0.3 11.2 ± 0.4
protoss_10_vs_11 9.6 ± 0.1 10.7 ± 0.4 9.5 ± 0.3 9.3 ± 0.5 9.2 ± 0.3
protoss_20_vs_20 11.4 ± 0.5 13.5 ± 0.5 12.1 ± 0.1 11.4 ± 0.3 11.2 ± 0.3
protoss_20_vs_23 9.8 ± 0.3 10.6 ± 0.2 10.0 ± 0.2 9.3 ± 0.2 8.7 ± 0.4
terran_5_vs_5 8.0 ± 0.3 9.1 ± 0.3 7.3 ± 0.4 7.7 ± 0.4 7.0 ± 0.5
terran_10_vs_10 8.4 ± 0.4 9.1 ± 1.3 8.1 ± 0.5 7.6 ± 0.7 7.0 ± 0.4
terran_10_vs_11 5.5 ± 0.1 6.4 ± 0.5 6.1 ± 0.2 5.7 ± 0.4 5.4 ± 0.5
terran_20_vs_20 8.0 ± 0.4 9.2 ± 0.6 8.2 ± 0.6 8.0 ± 0.5 7.9 ± 0.5
terran_20_vs_23 5.1 ± 0.2 5.6 ± 0.4 5.1 ± 0.2 5.2 ± 0.2 4.8 ± 0.2
zerg_5_vs_5 7.4 ± 0.4 7.5 ± 0.1 6.9 ± 0.5 6.4 ± 0.5 6.2 ± 0.5
zerg_10_vs_10 9.2 ± 0.3 10.2 ± 0.6 9.0 ± 0.3 8.8 ± 0.2 8.3 ± 0.2
zerg_10_vs_11 8.5 ± 0.4 9.4 ± 0.3 8.5 ± 0.0 8.2 ± 0.2 7.6 ± 0.3
zerg_20_vs_20 8.9 ± 0.1 10.2 ± 0.6 8.8 ± 0.9 8.3 ± 0.5 8.1 ± 0.6
zerg_20_vs_23 8.4 ± 0.4 9.5 ± 0.2 8.8 ± 0.2 8.5 ± 0.2 8.0 ± 0.1
In [132]:
pd_winrates_llm = pd.DataFrame.from_dict(ablation_exsize_winrates_llm)
pd_winrates_llm.rename(columns=topk_names).map(lambda x: analyze_data(x, tag="winrates"))
Out[132]:
topk=50 topk=200 topk=400 topk=800 topk=1200
2c_vs_64zg 5.6 ± 4.9 13.0 ± 9.0 1.6 ± 0.7 1.7 ± 0.8 0.6 ± 0.7
5m_vs_6m 0.0 ± 0.0 1.2 ± 0.5 0.0 ± 0.0 0.0 ± 0.0 0.0 ± 0.0
6h_vs_8z 0.0 ± 0.0 1.1 ± 0.8 0.0 ± 0.0 0.0 ± 0.0 0.0 ± 0.0
corridor 0.0 ± 0.0 1.4 ± 0.6 0.0 ± 0.0 0.0 ± 0.0 0.0 ± 0.0
protoss_5_vs_5 8.9 ± 1.9 20.7 ± 0.9 10.1 ± 2.8 12.5 ± 1.7 9.4 ± 1.7
protoss_10_vs_10 5.3 ± 1.8 14.1 ± 2.1 11.1 ± 3.3 10.0 ± 3.0 5.7 ± 1.6
protoss_10_vs_11 0.3 ± 0.4 4.7 ± 0.3 2.2 ± 1.2 2.0 ± 1.0 1.6 ± 0.6
protoss_20_vs_20 2.8 ± 0.8 11.0 ± 3.2 3.6 ± 1.4 6.1 ± 1.9 3.4 ± 0.9
protoss_20_vs_23 0.6 ± 0.4 3.8 ± 2.0 2.2 ± 0.8 1.8 ± 1.2 1.0 ± 0.4
terran_5_vs_5 8.5 ± 1.9 14.2 ± 3.1 9.8 ± 2.1 10.4 ± 2.1 5.1 ± 1.1
terran_10_vs_10 7.7 ± 0.6 12.0 ± 1.7 7.9 ± 2.1 8.1 ± 3.8 5.2 ± 0.3
terran_10_vs_11 0.8 ± 0.6 4.2 ± 1.6 2.2 ± 1.2 1.4 ± 0.9 0.6 ± 0.4
terran_20_vs_20 3.7 ± 0.4 8.8 ± 1.5 4.3 ± 0.6 6.0 ± 1.6 5.2 ± 1.2
terran_20_vs_23 0.5 ± 0.3 1.8 ± 1.4 0.6 ± 0.6 0.3 ± 0.4 0.1 ± 0.1
zerg_5_vs_5 5.6 ± 1.3 7.9 ± 1.0 5.3 ± 1.1 4.7 ± 1.6 3.8 ± 0.6
zerg_10_vs_10 3.4 ± 1.7 8.9 ± 2.1 5.6 ± 1.5 6.7 ± 1.1 3.5 ± 1.0
zerg_10_vs_11 2.5 ± 1.0 6.8 ± 1.1 4.6 ± 1.6 4.5 ± 1.3 2.8 ± 0.6
zerg_20_vs_20 0.3 ± 0.3 2.4 ± 0.6 0.3 ± 0.3 0.4 ± 0.5 0.3 ± 0.3
zerg_20_vs_23 0.6 ± 0.3 2.7 ± 0.4 1.4 ± 0.7 1.1 ± 1.0 1.2 ± 0.7
In [133]:
show_box_chart(pd_returns_llm, PROTOSS_ENVS, tag="returns")
show_box_chart(pd_returns_llm, TERRAN_ENVS, tag="returns")
show_box_chart(pd_returns_llm, ZERG_ENVS, tag="returns")
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [134]:
show_box_chart(pd_winrates_llm, PROTOSS_ENVS, tag="winrates")
show_box_chart(pd_winrates_llm, TERRAN_ENVS, tag="winrates")
show_box_chart(pd_winrates_llm, ZERG_ENVS, tag="winrates")
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [135]:
# Return curves of the exsize ablation: one figure per SMACv2 environment, one line per exsize.
# NOTE(review): the curves come from `ablation_exsize_returns_llm`, which was loaded with
# use_llm=True, yet the title below (and therefore the saved PDF name) says "Rule-based" —
# confirm which label is intended.
# NOTE(review): `topk_names` is not passed as rename_algos, so the legend shows the bare
# integers (50, 200, ...) rather than "topk=..." labels.
plotly_figs = {}
for env_name in SMACV2_ENV_NAMES:
    plotly_figs[env_name] = create_scatters(env_name, ablation_exsize_returns_llm, algos=[50, 200, 400, 800, 1200])

show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200, title="Ablation: SMACv2 - Returns - Rule-based")
No description has been provided for this image
Saving Ablation_SMACv2_Returns_Rule-based.pdf
In [136]:
# Winrate curves of the exsize ablation: one figure per SMACv2 environment, one line per exsize.
# NOTE(review): the curves come from `ablation_exsize_winrates_llm`, which was loaded with
# use_llm=True, yet the title below (and therefore the saved PDF name) says "Rule-based" —
# confirm which label is intended.
# NOTE(review): `topk_names` is not passed as rename_algos, so the legend shows the bare
# integers (50, 200, ...) rather than "topk=..." labels.
plotly_figs = {}
for env_name in SMACV2_ENV_NAMES:
    plotly_figs[env_name] = create_scatters(env_name, ablation_exsize_winrates_llm, algos=[50, 200, 400, 800, 1200], tag="winrates")

show_fig(plotly_figs, SMACV2_ENV_NAMES, n_rows=3, n_cols=5, height=640, width=1200, title="Ablation: SMACv2 - Winrates - Rule-based")
No description has been provided for this image
Saving Ablation_SMACv2_Winrates_Rule-based.pdf
In [137]:
RENAME_ALGOS_GPT = {
    "miso-mini": "MisoDICE (gpt-4o-mini)",
    "miso_alpha0.05": "MisoDICE (gpt-4o)",
}
In [138]:
data_returns_llm, data_winrates_llm = load_data(use_llm=True, algos=["miso-mini", "miso_alpha0.05"])
pd_returns_llm = pd.DataFrame.from_dict(data_returns_llm).rename(columns=RENAME_ALGOS_GPT)
pd_winrates_llm = pd.DataFrame.from_dict(data_winrates_llm).rename(columns=RENAME_ALGOS_GPT)
In [139]:
pd_returns_llm.map(analyze_data)
Out[139]:
MisoDICE (gpt-4o-mini) MisoDICE (gpt-4o)
2c_vs_64zg 12.2 ± 0.6 16.4 ± 1.3
5m_vs_6m 6.7 ± 0.9 7.3 ± 0.1
6h_vs_8z 8.1 ± 0.2 8.7 ± 0.2
corridor 3.8 ± 0.8 5.8 ± 0.8
protoss_5_vs_5 12.0 ± 0.3 12.4 ± 0.5
protoss_10_vs_10 12.0 ± 0.5 12.9 ± 0.2
protoss_10_vs_11 10.6 ± 0.3 10.7 ± 0.4
protoss_20_vs_20 12.4 ± 0.5 13.5 ± 0.5
protoss_20_vs_23 10.3 ± 0.1 10.6 ± 0.2
terran_5_vs_5 8.1 ± 0.5 9.1 ± 0.3
terran_10_vs_10 8.6 ± 0.9 9.1 ± 1.3
terran_10_vs_11 6.0 ± 0.3 6.4 ± 0.5
terran_20_vs_20 8.1 ± 0.5 9.2 ± 0.6
terran_20_vs_23 5.3 ± 0.3 5.6 ± 0.4
zerg_5_vs_5 7.0 ± 0.4 7.5 ± 0.1
zerg_10_vs_10 9.6 ± 0.2 10.2 ± 0.6
zerg_10_vs_11 9.2 ± 0.6 9.4 ± 0.3
zerg_20_vs_20 8.9 ± 0.4 10.2 ± 0.6
zerg_20_vs_23 9.0 ± 0.2 9.5 ± 0.2
In [140]:
pd_winrates_llm.map(lambda x: analyze_data(x, tag="winrates"))
Out[140]:
MisoDICE (gpt-4o-mini) MisoDICE (gpt-4o)
2c_vs_64zg 2.1 ± 1.1 13.0 ± 9.0
5m_vs_6m 1.2 ± 0.4 1.2 ± 0.5
6h_vs_8z 0.1 ± 0.1 1.1 ± 0.8
corridor 0.6 ± 0.7 1.4 ± 0.6
protoss_5_vs_5 12.6 ± 2.6 20.7 ± 0.9
protoss_10_vs_10 10.5 ± 3.5 14.1 ± 2.1
protoss_10_vs_11 4.0 ± 1.8 4.7 ± 0.3
protoss_20_vs_20 6.0 ± 0.9 11.0 ± 3.2
protoss_20_vs_23 2.3 ± 1.4 3.8 ± 2.0
terran_5_vs_5 10.0 ± 1.4 14.2 ± 3.1
terran_10_vs_10 9.2 ± 2.1 12.0 ± 1.7
terran_10_vs_11 2.2 ± 1.1 4.2 ± 1.6
terran_20_vs_20 6.1 ± 2.1 8.8 ± 1.5
terran_20_vs_23 0.9 ± 0.6 1.8 ± 1.4
zerg_5_vs_5 5.3 ± 1.1 7.9 ± 1.0
zerg_10_vs_10 7.1 ± 0.9 8.9 ± 2.1
zerg_10_vs_11 5.0 ± 0.9 6.8 ± 1.1
zerg_20_vs_20 0.8 ± 0.8 2.4 ± 0.6
zerg_20_vs_23 1.7 ± 0.8 2.7 ± 0.4
In [141]:
RENAME_ALGOS_2 = {
    "miso_alpha0.00": "MisoDICE (α=0.00)",
    "miso_alpha0.05": "MisoDICE (α=0.05)",
    "miso_alpha0.10": "MisoDICE (α=0.10)",
    "miso_alpha0.50": "MisoDICE (α=0.50)",
    "miso_alpha1.00": "MisoDICE (α=1.00)",
    "miso_alpha10.00": "MisoDICE (α=10.0)",
}
data_returns_llm, data_winrates_llm = load_data(use_llm=True, algos=ABLATION_ALGOS)
pd_returns_llm = pd.DataFrame.from_dict(data_returns_llm).rename(columns=RENAME_ALGOS_2)
pd_winrates_llm = pd.DataFrame.from_dict(data_winrates_llm).rename(columns=RENAME_ALGOS_2)
In [142]:
pd_returns_llm.map(analyze_data)
Out[142]:
MisoDICE (α=0.00) MisoDICE (α=0.05) MisoDICE (α=0.10) MisoDICE (α=0.50) MisoDICE (α=1.00) MisoDICE (α=10.0)
2c_vs_64zg 15.9 ± 1.0 16.4 ± 1.3 15.3 ± 1.1 14.9 ± 0.7 13.3 ± 0.9 10.9 ± 0.3
5m_vs_6m 6.6 ± 1.4 7.3 ± 0.1 6.3 ± 1.0 6.5 ± 0.5 6.9 ± 0.2 6.7 ± 0.1
6h_vs_8z 8.3 ± 0.2 8.7 ± 0.2 8.0 ± 0.2 8.0 ± 0.2 7.7 ± 0.2 7.6 ± 0.1
corridor 5.3 ± 0.6 5.8 ± 0.8 5.3 ± 0.6 5.1 ± 0.2 3.8 ± 0.9 1.7 ± 0.1
protoss_5_vs_5 12.4 ± 0.4 12.4 ± 0.5 11.7 ± 0.4 11.8 ± 0.4 11.9 ± 0.3 11.7 ± 0.4
protoss_10_vs_10 12.1 ± 0.2 12.9 ± 0.2 12.2 ± 0.6 12.1 ± 0.3 11.6 ± 0.5 12.1 ± 0.4
protoss_10_vs_11 10.7 ± 0.5 10.7 ± 0.4 10.1 ± 0.5 10.1 ± 0.4 9.8 ± 0.3 9.9 ± 0.1
protoss_20_vs_20 12.4 ± 0.2 13.5 ± 0.5 12.3 ± 0.4 12.2 ± 0.5 11.8 ± 0.2 12.1 ± 0.4
protoss_20_vs_23 10.4 ± 0.4 10.6 ± 0.2 10.1 ± 0.3 10.1 ± 0.2 10.0 ± 0.2 9.6 ± 0.5
terran_5_vs_5 8.8 ± 0.7 9.1 ± 0.3 8.4 ± 0.8 8.5 ± 0.2 7.9 ± 0.4 8.2 ± 0.3
terran_10_vs_10 8.9 ± 0.8 9.1 ± 1.3 8.5 ± 0.6 8.6 ± 0.9 7.8 ± 0.7 8.1 ± 0.7
terran_10_vs_11 5.9 ± 0.4 6.4 ± 0.5 5.8 ± 0.3 5.9 ± 0.4 5.7 ± 0.5 5.8 ± 0.6
terran_20_vs_20 9.0 ± 0.6 9.2 ± 0.6 8.5 ± 0.7 8.4 ± 0.6 8.2 ± 0.6 8.4 ± 0.7
terran_20_vs_23 5.6 ± 0.2 5.6 ± 0.4 5.3 ± 0.3 5.1 ± 0.3 5.1 ± 0.3 5.1 ± 0.1
zerg_5_vs_5 7.3 ± 0.5 7.5 ± 0.1 7.2 ± 0.6 6.8 ± 0.2 6.4 ± 0.3 6.4 ± 0.3
zerg_10_vs_10 9.7 ± 0.4 10.2 ± 0.6 9.2 ± 0.3 9.2 ± 0.4 8.7 ± 0.2 8.9 ± 0.4
zerg_10_vs_11 9.0 ± 0.5 9.4 ± 0.3 8.5 ± 0.2 8.8 ± 0.3 8.7 ± 0.4 8.3 ± 0.5
zerg_20_vs_20 9.0 ± 0.4 10.2 ± 0.6 9.0 ± 0.3 9.1 ± 0.7 8.8 ± 0.4 8.9 ± 0.4
zerg_20_vs_23 8.9 ± 0.2 9.5 ± 0.2 8.8 ± 0.4 8.9 ± 0.3 8.5 ± 0.4 8.6 ± 0.4
In [143]:
pd_winrates_llm.map(lambda x: analyze_data(x, tag="winrates"))
Out[143]:
MisoDICE (α=0.00) MisoDICE (α=0.05) MisoDICE (α=0.10) MisoDICE (α=0.50) MisoDICE (α=1.00) MisoDICE (α=10.0)
2c_vs_64zg 11.0 ± 5.8 13.0 ± 9.0 10.0 ± 5.0 9.3 ± 2.9 7.1 ± 4.1 1.7 ± 0.9
5m_vs_6m 1.1 ± 1.0 1.2 ± 0.5 0.6 ± 0.4 0.7 ± 0.4 0.9 ± 0.7 0.7 ± 0.5
6h_vs_8z 1.0 ± 0.8 1.1 ± 0.8 1.1 ± 0.8 0.8 ± 0.5 0.4 ± 0.3 1.0 ± 0.5
corridor 0.4 ± 0.2 1.4 ± 0.6 1.2 ± 0.5 0.5 ± 0.4 1.4 ± 0.7 0.9 ± 0.4
protoss_5_vs_5 12.1 ± 1.6 20.7 ± 0.9 14.4 ± 2.1 12.9 ± 2.0 14.1 ± 2.2 17.2 ± 2.5
protoss_10_vs_10 8.4 ± 3.3 14.1 ± 2.1 9.9 ± 2.2 9.5 ± 2.6 9.0 ± 1.7 11.1 ± 2.3
protoss_10_vs_11 3.3 ± 1.1 4.7 ± 0.3 4.3 ± 0.4 3.5 ± 1.3 3.5 ± 1.1 4.3 ± 1.2
protoss_20_vs_20 5.4 ± 0.4 11.0 ± 3.2 5.5 ± 1.1 6.4 ± 1.0 6.4 ± 0.7 7.8 ± 2.4
protoss_20_vs_23 4.3 ± 0.6 3.8 ± 2.0 2.3 ± 1.4 3.5 ± 1.6 3.3 ± 0.8 2.5 ± 1.3
terran_5_vs_5 12.4 ± 2.9 14.2 ± 3.1 11.2 ± 0.8 12.1 ± 1.6 12.0 ± 3.2 12.2 ± 3.2
terran_10_vs_10 9.3 ± 1.2 12.0 ± 1.7 8.6 ± 1.9 9.2 ± 3.6 8.8 ± 1.9 9.3 ± 3.2
terran_10_vs_11 2.3 ± 0.4 4.2 ± 1.6 3.5 ± 1.2 2.3 ± 0.7 2.2 ± 0.8 3.6 ± 1.5
terran_20_vs_20 4.8 ± 1.6 8.8 ± 1.5 4.9 ± 1.5 5.9 ± 2.5 4.9 ± 1.9 9.3 ± 2.1
terran_20_vs_23 1.5 ± 0.7 1.8 ± 1.4 1.2 ± 0.7 1.8 ± 0.7 1.2 ± 0.9 1.7 ± 1.0
zerg_5_vs_5 7.1 ± 1.6 7.9 ± 1.0 7.4 ± 1.0 5.9 ± 0.7 6.4 ± 2.7 7.5 ± 2.1
zerg_10_vs_10 6.7 ± 1.2 8.9 ± 2.1 7.2 ± 2.6 7.0 ± 0.8 7.1 ± 0.9 6.9 ± 1.0
zerg_10_vs_11 5.0 ± 1.4 6.8 ± 1.1 5.0 ± 1.5 4.1 ± 1.9 5.0 ± 1.2 4.5 ± 1.4
zerg_20_vs_20 1.3 ± 0.5 2.4 ± 0.6 1.2 ± 0.4 1.7 ± 2.1 1.6 ± 0.3 2.6 ± 0.6
zerg_20_vs_23 1.4 ± 0.5 2.7 ± 0.4 1.0 ± 0.4 2.5 ± 0.7 2.7 ± 1.7 2.1 ± 0.7
In [144]:
MAIN_ALGOS_TMP = ["bc_beta0.0", "bc_beta0.5", "bc_beta1.0", "indd", "vdn", "miso_alpha0.05"]
In [145]:
def load_data_tmp(use_llm=False, exsize=200, algos=None):
    """Collect the step-100 mean returns and winrates from ``logs_ablation``.

    Files are looked up at
    ``logs_ablation/{algo}/{env}[_llm]/seed{seed}/exsize{exsize}/results.json``
    for seeds 0-3. The ``"omapl"`` algorithm is special-cased: its file is always
    read from ``logs_ablation/omapl/{env}_llm/seed{seed}/results.json`` and the
    whole file content is used as the result.

    Args:
        use_llm: When true, read the ``*_llm`` directories and skip the MaMuJoCo
            environments.
        exsize: Value substituted into the ``exsize{...}`` path component.
        algos: Algorithm identifiers to load; a falsy value means ``ALGOS``.

    Returns:
        ``(data_returns, data_winrates)``: nested defaultdicts indexed as
        ``[algo][env_name][100]``, each leaf being a list with one mean value per
        seed that provided the metric.

    Missing files and a missing ``step_100`` entry are skipped instead of raising
    ``KeyError``, matching ``load_results`` returning ``{}`` for absent files.
    """
    # Only the final evaluation step is read.
    final_step = 100
    data_returns = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    data_winrates = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    if not algos:
        algos = ALGOS
    env_names = MAMUJOCO_ENV_NAMES + SMACV1_ENV_NAMES + SMACV2_ENV_NAMES

    for algo in algos:
        is_omapl = algo == "omapl"
        for env_name in env_names:
            if use_llm and env_name in MAMUJOCO_ENV_NAMES:
                continue
            for seed in range(4):
                if is_omapl:
                    # omapl logs have no exsize level and always live under "<env>_llm".
                    path = f"logs_ablation/{algo}/{env_name}_llm/seed{seed}/results.json"
                elif use_llm:
                    path = f"logs_ablation/{algo}/{env_name}_llm/seed{seed}/exsize{exsize}/results.json"
                else:
                    path = f"logs_ablation/{algo}/{env_name}/seed{seed}/exsize{exsize}/results.json"

                data = load_results(path)
                # omapl files are not keyed by step: the whole dict is the result.
                result = data if is_omapl else data.get(f"step_{final_step}", {})
                if "returns" in result:
                    data_returns[algo][env_name][final_step].append(np.mean(result["returns"]))
                if "winrates" in result:
                    data_winrates[algo][env_name][final_step].append(np.mean(result["winrates"]))
    return data_returns, data_winrates
In [146]:
data_returns_llm, data_winrates_llm = load_data_tmp(use_llm=True, algos=MAIN_ALGOS_TMP)
pd_returns_llm = pd.DataFrame.from_dict(data_returns_llm).rename(columns=RENAME_ALGOS)
pd_winrates_llm = pd.DataFrame.from_dict(data_winrates_llm).rename(columns=RENAME_ALGOS)
In [147]:
pd_returns_llm.map(analyze_data)
Out[147]:
BC (β=0.0) BC (β=0.5) BC (β=1.0) INDD VDN MisoDICE (ours)
2c_vs_64zg 11.7 ± 0.4 11.3 ± 0.1 13.1 ± 0.8 15.1 ± 1.3 15.9 ± 2.0 16.1 ± 1.8
5m_vs_6m 6.1 ± 1.4 6.8 ± 1.4 7.2 ± 0.2 7.4 ± 0.2 7.4 ± 0.3 7.4 ± 0.1
6h_vs_8z 8.4 ± 0.1 8.4 ± 0.2 8.3 ± 0.3 8.2 ± 0.3 8.5 ± 0.1 8.9 ± 0.1
corridor 2.1 ± 0.3 1.9 ± 0.3 5.7 ± 1.5 2.0 ± 0.2 5.8 ± 1.5 6.1 ± 1.6
protoss_5_vs_5 13.3 ± 0.1 12.6 ± 2.8 12.0 ± 1.2 11.9 ± 1.9 13.3 ± 0.5 13.7 ± 1.2
protoss_10_vs_10 13.1 ± 0.9 12.5 ± 1.4 12.8 ± 0.9 12.3 ± 0.8 13.2 ± 0.5 14.0 ± 1.2
protoss_10_vs_11 10.8 ± 1.3 10.6 ± 0.9 10.9 ± 0.8 10.1 ± 1.1 11.4 ± 0.5 12.0 ± 1.3
protoss_20_vs_20 12.4 ± 1.4 12.2 ± 0.5 13.0 ± 0.4 12.7 ± 0.6 13.3 ± 1.4 13.9 ± 1.6
protoss_20_vs_23 10.1 ± 1.5 9.5 ± 0.5 10.2 ± 0.8 10.1 ± 0.9 10.4 ± 0.8 10.5 ± 1.1
terran_5_vs_5 8.9 ± 0.9 8.4 ± 1.0 8.7 ± 0.9 8.9 ± 1.8 9.4 ± 2.3 9.4 ± 1.5
terran_10_vs_10 7.5 ± 1.0 8.2 ± 1.0 7.7 ± 1.2 7.9 ± 0.7 8.9 ± 1.3 9.0 ± 0.6
terran_10_vs_11 6.3 ± 0.4 5.9 ± 1.4 5.8 ± 0.4 5.9 ± 0.7 7.1 ± 1.2 7.3 ± 1.0
terran_20_vs_20 8.6 ± 1.2 8.3 ± 0.8 7.9 ± 0.2 8.3 ± 1.0 8.7 ± 1.0 9.3 ± 1.2
terran_20_vs_23 5.1 ± 0.2 5.5 ± 0.5 5.4 ± 0.3 4.8 ± 0.4 5.6 ± 0.5 5.9 ± 0.7
zerg_5_vs_5 7.8 ± 1.2 7.0 ± 0.9 7.1 ± 1.4 6.8 ± 0.3 7.8 ± 0.8 8.0 ± 0.6
zerg_10_vs_10 8.7 ± 1.2 8.2 ± 0.5 9.2 ± 0.8 9.6 ± 1.2 9.6 ± 1.7 9.9 ± 0.5
zerg_10_vs_11 9.6 ± 0.4 9.3 ± 1.5 8.8 ± 1.2 9.3 ± 1.0 9.8 ± 0.9 9.9 ± 0.9
zerg_20_vs_20 9.2 ± 0.8 9.3 ± 0.6 9.4 ± 1.1 9.2 ± 0.7 9.5 ± 0.8 11.1 ± 1.3
zerg_20_vs_23 9.2 ± 1.1 8.9 ± 0.5 8.2 ± 0.6 9.3 ± 1.0 9.5 ± 0.3 9.5 ± 0.1
In [148]:
pd_winrates_llm.map(lambda x: analyze_data(x, tag="winrates"))
Out[148]:
BC (β=0.0) BC (β=0.5) BC (β=1.0) INDD VDN MisoDICE (ours)
2c_vs_64zg 1.6 ± 1.6 0.0 ± 0.0 5.5 ± 2.6 9.4 ± 2.2 10.2 ± 7.1 11.7 ± 8.1
5m_vs_6m 0.8 ± 1.4 1.6 ± 1.6 0.8 ± 1.4 0.8 ± 1.4 1.6 ± 1.6 1.6 ± 1.6
6h_vs_8z 0.0 ± 0.0 0.8 ± 1.4 0.0 ± 0.0 0.0 ± 0.0 1.6 ± 1.6 2.3 ± 2.6
corridor 0.8 ± 1.4 0.8 ± 1.4 0.8 ± 1.4 0.8 ± 1.4 1.6 ± 2.7 2.3 ± 2.6
protoss_5_vs_5 13.3 ± 3.4 17.2 ± 8.4 20.3 ± 4.7 17.2 ± 3.5 20.3 ± 6.8 21.1 ± 11.6
protoss_10_vs_10 10.2 ± 3.4 10.2 ± 3.4 8.6 ± 6.0 6.2 ± 2.2 10.2 ± 2.6 15.6 ± 4.4
protoss_10_vs_11 1.6 ± 1.6 3.9 ± 1.4 3.1 ± 3.1 4.7 ± 3.5 4.7 ± 1.6 6.2 ± 2.2
protoss_20_vs_20 8.6 ± 6.4 2.3 ± 2.6 9.4 ± 2.2 7.0 ± 7.8 9.4 ± 2.2 10.9 ± 5.2
protoss_20_vs_23 2.3 ± 2.6 1.6 ± 1.6 1.6 ± 1.6 2.3 ± 1.4 3.1 ± 3.8 3.9 ± 5.1
terran_5_vs_5 10.9 ± 3.5 7.0 ± 4.1 9.4 ± 3.8 11.7 ± 4.1 16.4 ± 7.5 17.2 ± 8.4
terran_10_vs_10 8.6 ± 2.6 7.0 ± 4.1 7.8 ± 3.5 5.5 ± 2.6 9.4 ± 3.8 10.9 ± 3.5
terran_10_vs_11 1.6 ± 1.6 1.6 ± 2.7 2.3 ± 1.4 1.6 ± 2.7 4.7 ± 3.5 5.5 ± 2.6
terran_20_vs_20 7.0 ± 2.6 6.2 ± 2.2 4.7 ± 3.5 4.7 ± 1.6 8.6 ± 5.1 9.4 ± 2.2
terran_20_vs_23 0.8 ± 1.4 0.0 ± 0.0 1.6 ± 2.7 0.8 ± 1.4 1.6 ± 1.6 2.3 ± 1.4
zerg_5_vs_5 6.2 ± 2.2 5.5 ± 1.4 3.9 ± 2.6 6.2 ± 3.8 7.0 ± 4.6 7.8 ± 4.7
zerg_10_vs_10 7.8 ± 2.7 5.5 ± 3.4 5.5 ± 4.6 3.1 ± 2.2 8.6 ± 6.8 8.6 ± 8.1
zerg_10_vs_11 3.9 ± 1.4 4.7 ± 3.5 0.8 ± 1.4 3.9 ± 2.6 8.6 ± 3.4 8.6 ± 4.6
zerg_20_vs_20 1.6 ± 2.7 1.6 ± 1.6 0.8 ± 1.4 0.0 ± 0.0 3.9 ± 2.6 2.3 ± 2.6
zerg_20_vs_23 0.8 ± 1.4 1.6 ± 1.6 1.6 ± 1.6 0.0 ± 0.0 1.6 ± 2.7 1.6 ± 2.7