import time
from glob import glob
import numpy as np
import math
import itertools
import seaborn as sns
import matplotlib.pyplot as plt
# Global matplotlib style overrides (font sizes and line width) applied via
# plt.rcParams at import time, so they affect every figure drawn afterwards.
params = {'legend.fontsize': 24, #24, 28
          'axes.labelsize': 28,
          'axes.titlesize': 28,
          'xtick.labelsize': 20,
          'ytick.labelsize': 20,
          'lines.linewidth': 2.5
          }
plt.rcParams.update(params)

from matplotlib import scale as mscale
from matplotlib import transforms as mtransforms
from matplotlib.ticker import FuncFormatter


# Root directory that the plot_* functions below prepend to their log paths.
output_dir = "outputs"

class AsymScale(mscale.ScaleBase):
    """Piecewise-linear axis scale named ``'asym'``.

    Non-negative coordinates are left untouched while negative coordinates
    are multiplied by the factor ``a`` (keyword argument, default 1), so the
    negative half of the axis can be compressed or stretched independently
    of the positive half.
    """
    name = 'asym'

    def __init__(self, axis, **kwargs):
        super().__init__(axis)
        # Factor applied to negative values only.
        self.a = kwargs.get("a", 1)

    def get_transform(self):
        """Return the forward transform configured with this scale's factor."""
        return self.AsymTrans(self.a)

    def set_default_locators_and_formatters(self, axis):
        """Install a major formatter that prints the absolute tick value."""
        # possibly, set a different locator and formatter here.
        def absolute_label(x, pos):
            return "{}".format(np.abs(x))

        axis.set_major_formatter(FuncFormatter(absolute_label))

    class AsymTrans(mtransforms.Transform):
        """Forward transform: multiply negative inputs by ``a``."""
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, a):
            super().__init__()
            self.a = a

        def transform_non_affine(self, x):
            # The comparisons act as 0/1 masks selecting which term applies.
            non_negative = (x >= 0)
            negative = (x < 0)
            return non_negative*x + negative*x*self.a

        def inverted(self):
            return AsymScale.InvertedAsymTrans(self.a)

    class InvertedAsymTrans(AsymTrans):
        """Inverse transform: divide negative inputs by ``a``."""

        def transform_non_affine(self, x):
            non_negative = (x >= 0)
            negative = (x < 0)
            return non_negative*x + negative*x/self.a

        def inverted(self):
            return AsymScale.AsymTrans(self.a)

# Register the scale so axes can select it by name, e.g. set_yscale("asym", a=...).
mscale.register_scale(AsymScale)

def flip(items, ncol):
    """Return an iterator over ``items`` regrouped by stride.

    Yields every element of ``items[0::ncol]``, then ``items[1::ncol]``, and
    so on up to offset ``ncol - 1``.
    """
    strided = []
    for offset in range(ncol):
        strided.append(items[offset::ncol])
    return itertools.chain.from_iterable(strided)


def smooth(scalars, weight):  # Weight between 0 and 1
    """Exponentially smooth a sequence of values.

    Each output is ``previous_output * weight + (1 - weight) * value``, with
    the first input value used as the initial "previous output".

    Args:
        scalars: Indexable, sized sequence of numbers (list or 1-D array).
        weight: Smoothing factor; 0 reproduces the input, values close to 1
            smooth heavily.

    Returns:
        A list with one smoothed value per input value; an empty list when
        ``scalars`` is empty.
    """
    # len() rather than truthiness so numpy arrays are handled as well.
    if len(scalars) == 0:
        return []

    last = scalars[0]  # Seed with the first value (first timestep)
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point  # Blend with the running value
        smoothed.append(last)
    return smoothed


def is_continuous_index(data):
    """Return True when each element of ``data`` is exactly 1 more than its predecessor.

    Sequences with fewer than two elements are considered continuous.
    """
    for pos in range(1, len(data)):
        if data[pos] - data[pos - 1] != 1:
            return False
    return True


def get_sums(data):
    """Return the running (cumulative) sums of ``data`` as a float array.

    Element ``i`` of the result is ``data[0] + ... + data[i]``.
    """
    count = len(data)
    totals = np.zeros(count)
    if count:
        totals[0] = data[0]
    for pos in range(1, count):
        # Build on the previously stored total, not a separate accumulator.
        totals[pos] = totals[pos - 1] + data[pos]
    return totals


def get_multi_runs_mean_var_new(ax, folder, num, frames, label, color, pick_top=False, use_label=False, pick_middle=False):
    """Plot the smoothed mean reward of several runs against cumulative steps.

    Every directory matching ``folder + "*/"`` is one run and must contain a
    ``log.csv``.  For each run the rows are sorted and de-duplicated on
    column 0, column 1 is accumulated into a running step count and columns
    ``2:2+num`` are averaged into one reward per row.  Runs are aligned on
    the union of all step counts: a run holds its latest reward until its
    next logged step and contributes 0 before its first one.  Every 1000th
    aligned step is kept; the across-run mean and half the across-run
    standard deviation are smoothed with ``smooth(..., 0.99)`` and drawn on
    ``ax`` as a line plus a translucent band.

    Args:
        ax: Object providing ``plot``, ``fill_between`` and ``set_aspect``.
        folder: Path prefix of the run directories.
        num: Number of reward columns (starting at column 2) to average.
        frames: Unused here; kept so the signature matches
            ``get_multi_runs_mean_var``.
        label: Legend label, attached only when ``use_label`` is true.
        color: Colour of the line and the band.
        pick_top: Drop the five runs with the lowest final value.
        use_label: Whether to pass ``label`` to the plotted line.
        pick_middle: Keep only the runs ranked between 25% and 75% by final
            value; applied only when exactly 10 runs remain.

    Raises:
        ValueError: If no run directory matches ``folder``.
    """
    stime = time.time()
    test = False
    if not test:
        folders = glob(folder + "*/", recursive=False)
        if not folders:
            raise ValueError(f"No run directories found matching '{folder}*/'")

        data_to_plot = []
        for f in folders:
            file_name = f + "log.csv"
            data = np.genfromtxt(file_name, delimiter=',')

            # clean the data: sort by index and keep the first row per index
            sorted_data = data[data[:, 0].argsort()]
            _, unique_indices = np.unique(sorted_data[:, 0], return_index=True)
            data = sorted_data[unique_indices]
            print("data cleaned up!")

            # check if the episode index contains all continuous numbers
            flag = set(data[:, 0]) == set(np.arange(data[-1, 0] + 1))
            print("Continuous index: ", flag)

            # column 0: cumulative step count, column 1: mean reward of the row
            new_data = np.zeros((len(data), 2))
            new_data[:, 0] = np.cumsum(data[:, 1])
            new_data[:, 1] = np.mean(data[:, 2: 2+num], axis=1)
            data_to_plot.append(new_data)

        # Sorted union of the step counts of all runs.
        steps = np.unique(np.hstack([cur_data[:, 0] for cur_data in data_to_plot]))
        values = np.zeros((len(data_to_plot), len(steps)))
        for i, cur_data in enumerate(data_to_plot):
            # Positions of this run's own steps inside the shared step axis.
            ids = np.nonzero(np.isin(steps, cur_data[:, 0]))[0]
            values[i, ids] = cur_data[:, 1]

            # Forward-fill every gap with the most recent logged reward.
            for sid, eid in zip(ids[:-1], ids[1:]):
                values[i, sid:eid] = values[i, sid]
            # Take the last position from this run's own ids so that a run
            # with a single row neither fails nor reuses another run's index.
            if len(ids):
                values[i, ids[-1]:] = values[i, ids[-1]]

        if pick_top:
            idx = np.argsort(values[:, -1].squeeze())
            values = values[idx[5:]]
        if pick_middle:
            idx = np.argsort(values[:, -1].squeeze())
            sid = math.floor(len(idx) * 0.25)
            eid = math.ceil(len(idx) * 0.75)
            if len(idx) == 10:
                values = values[idx[sid:eid]]
    else:
        # For test only: random data on an array-valued step axis so the
        # subsampling below can index it.
        steps = np.arange(1, 4900000, 100)
        values = np.random.rand(10, len(steps)) * 10

    # Keep every 1000th aligned step to limit the number of plotted points.
    index = np.arange(0, len(steps), 1000)
    steps = steps[index]
    values = values[:, index]

    rmean = np.array(smooth(np.mean(values, axis=0), 0.99))
    rstd = 0.5 * np.array(smooth(np.std(values, axis=0), 0.99))

    ax.plot(steps, rmean, color=color, label=label if use_label else None)
    ax.fill_between(steps, rmean + rstd, rmean - rstd, color=color, alpha=0.1)
    ax.set_aspect("auto")

    print(f"New Plot Time: {time.time() - stime}")


def get_multi_runs_mean_var(ax, folder, num, frames, label, color, pick_top=False, use_label=False, pick_middle=False):
    """Plot the smoothed mean reward of several runs on a per-frame axis.

    Every directory matching ``folder + "*/"`` is one run and must contain a
    ``log.csv``.  Column 1 of each row is an integer frame count and columns
    ``2:2+num`` hold rewards (averaged; a file with a single reward column
    uses it directly).  A per-frame reward array of length ``frames`` is
    built for each run: the frames spanned by a row carry the reward of the
    preceding row (the first row's own reward for the first span) and the
    frames after the last row carry the last row's reward.  Every 1000th
    frame is kept; the across-run mean and half the across-run standard
    deviation are smoothed with ``smooth(..., 0.99)`` and drawn on ``ax``.

    Args:
        ax: Object providing ``plot``, ``fill_between`` and ``set_aspect``.
        folder: Path prefix of the run directories.
        num: Number of reward columns (starting at column 2) to average.
        frames: Length of the frame axis.
        label: Legend label, attached only when ``use_label`` is true.
        color: Colour of the line and the band.
        pick_top: Drop the five runs with the lowest final value.
        use_label: Whether to pass ``label`` to the plotted line.
        pick_middle: Keep only the runs ranked between 25% and 75% by final
            value; applied only when exactly 10 runs remain.

    Raises:
        ValueError: If no run directory matches ``folder``.
    """
    stime = time.time()
    test = False
    if not test:
        folders = glob(folder + "*/", recursive=False)
        if not folders:
            raise ValueError(f"No run directories found matching '{folder}*/'")

        step = 1000  # keep every 1000th frame for plotting
        rewards = []
        for f in folders:
            file_name = f + "log.csv"
            data = np.genfromtxt(file_name, delimiter=',')

            # clean the data: sort by index and keep the first row per index
            # NOTE(review): this is a substring test, so any path containing
            # "ata" (e.g. "data") also skips the clean-up - confirm intent.
            if "ata" not in folder:
                sorted_data = data[data[:, 0].argsort()]
                _, unique_indices = np.unique(sorted_data[:, 0], return_index=True)
                data = sorted_data[unique_indices]
                print(f"data cleaned up! Max index: {data[-1, 0]}")

            # check if the index contains all continuous numbers
            flag = set(data[:, 0]) == set(np.arange(data[-1, 0] + 1))
            print("Continuous index: ", flag)

            single_reward = data.shape[1] - 2 == 1
            cur_rewards = np.zeros(frames)
            cur_f = 0
            cur_r = data[0, 2] if single_reward else data[0, 2: 2+num].mean()
            for d in data:
                last_f = cur_f
                cur_f += int(d[1])
                # The span of this row shows the reward known before it.
                cur_rewards[last_f:cur_f] = cur_r
                cur_r = d[2] if single_reward else d[2: 2 + num].mean()
            cur_rewards[cur_f:] = cur_r
            # Copy the subsample so the full per-frame array can be freed.
            rewards.append(cur_rewards[::step].copy())

        xs = np.arange(0, frames, step)
        rewards = np.array(rewards)

        if pick_top:
            idx = np.argsort(rewards[:, -1].squeeze())
            rewards = rewards[idx[5:]]
        if pick_middle:
            idx = np.argsort(rewards[:, -1].squeeze())
            sid = math.floor(len(idx) * 0.25)
            eid = math.ceil(len(idx) * 0.75)
            if len(idx) == 10:
                rewards = rewards[idx[sid:eid]]
    else:
        # For test only
        xs = list(range(1, 4900000, 100))
        rewards = np.random.rand(10, len(xs)) * 10

    rmean = np.array(smooth(np.mean(rewards, axis=0), 0.99))
    rstd = 0.5 * np.array(smooth(np.std(rewards, axis=0), 0.99))

    ax.plot(xs, rmean, color=color, label=label if use_label else None)
    ax.fill_between(xs, rmean + rstd, rmean - rstd, color=color, alpha=0.1)
    ax.set_aspect("auto")
    print(f"Time: {time.time() - stime}")


def plot_multi_runs_mean_var(folder, num, label, color):
    """Plot the smoothed mean and half-std band of rewards across runs.

    Every directory matching ``folder + "*/"`` is one run and must contain a
    ``log.csv``.  Columns ``2:2+num`` of each row are averaged into one
    reward; all runs are truncated to the length of the shortest run, whose
    column 0 is used as the x axis.  The result is drawn on the current
    pyplot figure.

    Args:
        folder: Path prefix of the run directories.
        num: Number of reward columns (starting at column 2) to average.
        label: Legend label of the mean line.
        color: Colour of the mean line and of the band.

    Returns:
        The x values (column 0 of the shortest run).

    Raises:
        ValueError: If no run directory matches ``folder``.
    """
    folders = glob(folder + "*/", recursive=False)
    if not folders:
        raise ValueError(f"No run directories found matching '{folder}*/'")

    rewards = []
    xs = None
    for f in folders:
        file_name = f + "log.csv"
        data = np.genfromtxt(file_name, delimiter=',')
        rewards.append(data[:, 2: 2+num].mean(axis=1))
        # Remember the x axis of the shortest run seen so far.
        if xs is None or len(data[:, 0]) < len(xs):
            xs = data[:, 0]
    curl = len(xs)
    rewards = np.array([rs[:curl] for rs in rewards])
    rmean = np.array(smooth(np.mean(rewards, axis=0), 0.99))
    rstd = 0.5 * np.array(smooth(np.std(rewards, axis=0), 0.99))

    plt.plot(xs, rmean, color=color, label=label)
    # Use the caller's colour for the band so it matches the mean line.
    plt.fill_between(xs, rmean + rstd, rmean - rstd, color=color, alpha=0.1)

    return xs


def plot_multi_runs(folder, run_num):
    """Show one smoothed curve per run on a grid that is five panels wide.

    Run ``i`` is read from ``folder + "_run" + str(i) + "/log.csv"``; column 1
    (smoothed with weight 0.99) is plotted against column 0.

    Args:
        folder: Path prefix shared by the run directories.
        run_num: Number of runs to plot.  Up to 10 runs use a 2x5 grid;
            more runs add rows of five panels as needed.
    """
    ncols = 5
    # At least two rows keeps the layout for <= 10 runs and keeps ``axs`` 2-D.
    nrows = max(2, math.ceil(run_num / ncols))
    fig, axs = plt.subplots(nrows, ncols, figsize=(30, 4 * nrows))

    for i in range(run_num):
        file_name = folder + "_run" + str(i) + "/log.csv"
        data = np.genfromtxt(file_name, delimiter=',')
        row, col = divmod(i, ncols)
        axs[row, col].plot(data[:, 0], smooth(data[:, 1], 0.99))

    plt.show()


def plot_single_run(folder):
    """Show every data column of ``folder + "/log.csv"`` in its own panel.

    Column 0 is the x axis; each remaining column is smoothed with weight
    0.99 and drawn, with a grid, in one panel of a single-row figure.

    Args:
        folder: Directory containing ``log.csv``.
    """
    data = np.genfromtxt(folder + "/log.csv", delimiter=',')
    plt_num = data.shape[1] - 1
    # squeeze=False always returns a 2-D array of axes, so a file with a
    # single data column (one panel) is handled like any other.
    fig, axs = plt.subplots(1, plt_num, figsize=(30, 8), squeeze=False)
    for i, panel in enumerate(axs[0]):
        panel.plot(data[:, 0], smooth(data[:, i + 1], 0.99))
        panel.grid()
    plt.show()


def plot_lava_Ns(ax):
    """Plot the 2-agent lava curves for N = 100, 1000 and 10000 plus the oracle.

    One curve per N is drawn with ``get_multi_runs_mean_var`` (labelled
    ``B=<N>``), followed by a dashed red "Oracle" line at 8.38, a legend and
    the axis labels.

    Args:
        ax: Axes to draw on.
    """
    agent_num = 2
    frames = 4900000
    Ns = [100, 1000, 10000]
    color = ["green", "blue", "orange"]
    # f-string so that the module-level output_dir is actually substituted.
    root_folder = f"{output_dir}/outputs_lava_N/"
    for N, c in zip(Ns, color):
        folder = root_folder + "centerSquare6x6_" + str(
            agent_num) + "a_PPO_ep4_nbatch4_wprior_N" + str(N) + "_gradNoise_pw1.0_pd0.995"
        label = "B=" + str(N)
        get_multi_runs_mean_var(ax, folder, agent_num, frames, label, c, pick_top=False, use_label=True)

    # plot the oracle: a constant line only needs its two end points
    ax.plot([0, frames - 1], [8.38, 8.38], '--', color='red', label="Oracle")
    ax.legend(loc='lower right')
    # ax.set_title("EG_MARL on the gridworld lava environment with 2 agents")
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")

def plot_lava_Ns_1a(ax):
    """Plot the 1-agent lava curves for N = 10 and 100 plus the oracle.

    One curve per N is drawn with ``get_multi_runs_mean_var`` (labelled
    ``B=<N>``, with ``pick_top=True``), followed by a dashed red "Oracle"
    line at 8.38, a legend and the axis labels.

    Args:
        ax: Axes to draw on.
    """
    agent_num = 1
    frames = 2000000
    Ns = [10, 100]
    color = ["blue", "orange"]
    # f-string so that the module-level output_dir is actually substituted.
    root_folder = f"{output_dir}/centerSquare6x6_1a_0/outputs_lava_suboptimal/"
    for N, c in zip(Ns, color):
        folder = root_folder + "centerSquare6x6_" + str(
            agent_num) + "a_0_PPO_ep4_nbatch4_wprior_N" + str(N) + "_gradNoise_pw1.0_pd0.995"
        label = "B=" + str(N)
        get_multi_runs_mean_var(ax, folder, agent_num, frames, label, c, pick_top=True, use_label=True)

    # plot the oracle: a constant line only needs its two end points
    ax.plot([0, frames - 1], [8.38, 8.38], '--', color='red', label="Oracle")
    ax.legend(loc='upper left')
    # ax.set_title("EG_MARL on the gridworld lava environment with 2 agents")
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")


def plot_lava(ax, agent_num, mode, pick_top=False, use_label=False, pick_middle=False):
    """Plot all lava-environment baselines for one agent count on ``ax``.

    Six curves are drawn with ``get_multi_runs_mean_var`` using the seaborn
    "colorblind" palette (new-algo, GEG-MARL, EG-MARL(B=100), MAPPO, MAGAIL,
    ATA), followed by a dashed red oracle line at 8.38.  For 2, 3 or 4
    agents the y axis is switched to the ``"asym"`` scale with fixed ticks.

    Args:
        ax: Axes to draw on.
        agent_num: Number of agents; selects the log directories, the number
            of reward columns and the y-axis layout.
        mode: ``"optimal"`` or ``"suboptimal"``; selects the
            ``outputs_lava_<mode>`` directory.
        pick_top: Forwarded to ``get_multi_runs_mean_var``.
        use_label: Whether the curves and the oracle get legend labels.
        pick_middle: Forwarded to ``get_multi_runs_mean_var`` for every curve
            except EG-MARL(B=100), which keeps the default.
    """
    frames = 4900000

    default_palette = sns.color_palette("colorblind", as_cmap=True)

    env_root = f"{output_dir}/centerSquare6x6_{agent_num}a"
    mode_prefix = f"{env_root}/outputs_lava_{mode}/centerSquare6x6_{agent_num}a"
    mappo_prefix = f"{env_root}/outputs_lava_MAPPO/centerSquare6x6_{agent_num}a"

    # (log folder prefix, legend label, name used in the progress message,
    #  pick_middle value); the palette colour is the position in the list.
    # NOTE(review): EG-MARL was called without pick_middle in the original
    # code; that is preserved here - confirm whether it is intentional.
    curves = [
        (mode_prefix + "_POfD2_ep4_nbatch4_pw0.05_pd1.0", "new-algo", "new algo", pick_middle),
        (mode_prefix + "_POfD_ep4_nbatch4_pw0.02_pd1.0", "GEG-MARL", "GEG-MARL", pick_middle),
        (mode_prefix + "_PPO_ep4_nbatch4_wprior_N100_gradNoise_pw1.0_pd0.995", "EG-MARL(B=100)", "EG-MARL", False),
        (mappo_prefix + "_PPO_ep4_nbatch4", "MAPPO", "MAPPO", pick_middle),
        (mode_prefix + "_POfD_ep4_nbatch4_pw1.0_pd1.0", "MAGAIL", "MAGAIL", pick_middle),
        (f"{env_root}/ata/csv_logs/", "ATA", "ATA", pick_middle),
    ]
    for color_id, (folder, label, tag, middle) in enumerate(curves):
        get_multi_runs_mean_var(ax, folder, agent_num, frames, label, default_palette[color_id],
                                pick_top, use_label, middle)
        print(f"----------------- agent num: {agent_num}, {tag} done")

    # Oracle: a constant line only needs its two end points.
    ax.plot([0, frames - 1], [8.38, 8.38], '--', color="red", label="Oracle" if use_label else None)
    # ax.legend(loc='lower right')

    # agent count -> (factor applied to negative values, y ticks)
    asym_axes = {
        2: (1/8, [-20, -10, 0, 2, 4, 6, 8, 10]),
        3: (1/10, [-40, -30, -20, -10, 0, 2, 4, 6, 8, 10]),
        4: (1/16, [-50, -40, -30, -20, -10, 0, 2, 4, 6, 8, 10]),
    }
    if agent_num in asym_axes:
        a, ticks = asym_axes[agent_num]
        ax.set_yscale("asym", a=a)
        ax.set_yticks(ticks)
        ax.set_yticklabels(ticks)
    ax.set_title(f"{agent_num} agents", style='italic')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")
    # plt.show()


def plot_lava_all(ax, agent_num, mode, pick_top=False, use_label=False, pick_middle=False):
    """Plot the lava-environment curves of one demonstration mode on ``ax``.

    PegMARL and MAGAIL are always drawn from ``outputs_lava_<mode>``.  For
    ``mode == "optimal"`` the ATA and MAPPO curves and a dashed grey oracle
    line at 8.38 are added; for any other mode the labels get a ``*`` suffix
    and the neighbouring (index - 1) colour of the seaborn "Paired" palette.
    For 2, 3 or 4 agents the y axis uses the ``"asym"`` scale with fixed
    ticks.

    Args:
        ax: Axes to draw on.
        agent_num: Number of agents; selects the log directories, the number
            of reward columns and the y-axis layout.
        mode: ``"optimal"`` or ``"suboptimal"``.
        pick_top: Forwarded to ``get_multi_runs_mean_var``.
        use_label: Whether the curves and the oracle get legend labels.
        pick_middle: Forwarded to ``get_multi_runs_mean_var``.
    """
    frames = 4900000

    # default_palette = sns.color_palette("colorblind", as_cmap=True)
    default_palette = sns.color_palette("Paired").as_hex()
    optimal = mode == "optimal"
    s, suffix = (0, "") if optimal else (-1, "*")

    env_root = f"{output_dir}/centerSquare6x6_{agent_num}a"
    mode_prefix = f"{env_root}/outputs_lava_{mode}/centerSquare6x6_{agent_num}a"

    # (log folder prefix, legend label, palette index, progress-message name)
    # Disabled curves, kept for reference:
    #   GEG-MARL{suffix}:       mode_prefix + "_POfD_ep4_nbatch4_pw0.02_pd1.0", palette 3+s
    #   EG-MARL(B=100){suffix}: mode_prefix + "_PPO_ep4_nbatch4_wprior_N100_gradNoise_pw1.0_pd0.995",
    #                           palette 5+s, called without pick_middle
    curves = [
        (mode_prefix + "_POfD2_ep4_nbatch4_pw0.05_pd1.0", f"PegMARL{suffix}", 1 + s, "new algo"),
        (mode_prefix + "_POfD_ep4_nbatch4_pw1.0_pd1.0", f"MAGAIL{suffix}", 7 + s, "MAGAIL"),
    ]
    if optimal:
        curves += [
            (f"{env_root}/ata/csv_logs/", "ATA", 9, "ATA"),
            (f"{env_root}/outputs_lava_MAPPO/centerSquare6x6_{agent_num}a_PPO_ep4_nbatch4", "MAPPO", 11, "MAPPO"),
        ]

    for folder, label, color_id, tag in curves:
        get_multi_runs_mean_var(ax, folder, agent_num, frames, label, default_palette[color_id],
                                pick_top, use_label, pick_middle)
        print(f"----------------- agent num: {agent_num}, {tag} done")

    if optimal:
        # Oracle: a constant line only needs its two end points.
        ax.plot([0, frames - 1], [8.38, 8.38], '--', color="grey", label="Oracle" if use_label else None)

    # ax.legend(loc='lower right')

    # agent count -> (factor applied to negative values, y ticks)
    asym_axes = {
        2: (1/8, [-20, -10, 0, 2, 4, 6, 8, 10]),
        3: (1/10, [-40, -30, -20, -10, 0, 2, 4, 6, 8, 10]),
        4: (1/16, [-50, -40, -30, -20, -10, 0, 2, 4, 6, 8, 10]),
    }
    if agent_num in asym_axes:
        a, ticks = asym_axes[agent_num]
        ax.set_yscale("asym", a=a)
        ax.set_yticks(ticks)
        ax.set_yticklabels(ticks)
    ax.set_title(f"{agent_num} agents", style='italic')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")
    # plt.show()


def plot_lava_all2(ax, agent_num, mode, pick_top=False, use_label=False, pick_middle=False):
    """Plot optimal and suboptimal lava-environment curves together on ``ax``.

    PegMARL, DM2 and MAGAIL are each drawn twice: from
    ``outputs_lava_optimal`` (plain label) and from
    ``outputs_lava_suboptimal`` (label with ``*``), using paired colours of
    the seaborn "Paired" palette.  When ``mode == "optimal"`` the ATA and
    MAPPO curves and a dashed red oracle line at 8.38 are added.  For 2, 3
    or 4 agents the y axis uses the ``"asym"`` scale with fixed ticks.

    Args:
        ax: Axes to draw on.
        agent_num: Number of agents; selects the log directories, the number
            of reward columns and the y-axis layout.
        mode: Only decides whether ATA, MAPPO and the oracle are drawn
            (``"optimal"``) or not.
        pick_top: Forwarded to ``get_multi_runs_mean_var``.
        use_label: Whether the curves and the oracle get legend labels.
        pick_middle: Forwarded to ``get_multi_runs_mean_var``.
    """
    frames = 4900000

    # default_palette = sns.color_palette("colorblind", as_cmap=True)
    default_palette = sns.color_palette("Paired").as_hex()

    env_root = f"{output_dir}/centerSquare6x6_{agent_num}a"
    run_name = f"centerSquare6x6_{agent_num}a"
    optimal_prefix = f"{env_root}/outputs_lava_optimal/{run_name}"
    suboptimal_prefix = f"{env_root}/outputs_lava_suboptimal/{run_name}"

    # (log folder prefix, legend label, palette index, progress-message name)
    # Disabled curve, kept for reference:
    #   EG-MARL(B=100): "<prefix>_PPO_ep4_nbatch4_wprior_N100_gradNoise_pw1.0_pd0.995"
    curves = [
        (optimal_prefix + "_POfD2_ep4_nbatch4_pw0.05_pd1.0", "PegMARL", 1, "new algo"),
        (suboptimal_prefix + "_POfD2_ep4_nbatch4_pw0.05_pd1.0", "PegMARL*", 0, "new algo"),
        (optimal_prefix + "_POfD_ep4_nbatch4_pw0.02_pd1.0", "DM2", 3, "GEG-MARL"),
        (suboptimal_prefix + "_POfD_ep4_nbatch4_pw0.02_pd1.0", "DM2*", 2, "GEG-MARL"),
        (optimal_prefix + "_POfD_ep4_nbatch4_pw1.0_pd1.0", "MAGAIL", 7, "MAGAIL"),
        (suboptimal_prefix + "_POfD_ep4_nbatch4_pw1.0_pd1.0", "MAGAIL*", 6, "MAGAIL"),
    ]
    if mode == "optimal":
        curves += [
            (f"{env_root}/ata/csv_logs/", "ATA", 9, "ATA"),
            (f"{env_root}/outputs_lava_MAPPO/{run_name}_PPO_ep4_nbatch4", "MAPPO", 11, "MAPPO"),
        ]

    for folder, label, color_id, tag in curves:
        get_multi_runs_mean_var(ax, folder, agent_num, frames, label, default_palette[color_id],
                                pick_top, use_label, pick_middle)
        print(f"----------------- agent num: {agent_num}, {tag} done")

    if mode == "optimal":
        # Oracle: a constant line only needs its two end points.
        ax.plot([0, frames - 1], [8.38, 8.38], '--', color="red", label="Oracle" if use_label else None)

    # ax.legend(loc='lower right')

    # agent count -> (factor applied to negative values, y ticks)
    asym_axes = {
        2: (1/8, [-20, -10, 0, 2, 4, 6, 8, 10]),
        3: (1/10, [-40, -30, -20, -10, 0, 2, 4, 6, 8, 10]),
        4: (1/16, [-50, -40, -30, -20, -10, 0, 2, 4, 6, 8, 10]),
    }
    if agent_num in asym_axes:
        a, ticks = asym_axes[agent_num]
        ax.set_yscale("asym", a=a)
        ax.set_yticks(ticks)
        ax.set_yticklabels(ticks)
    ax.set_title(f"{agent_num} agents", style='italic')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")
    # plt.show()


def plot_appleDoor(env_type, ax, mode, pick_top=False, use_label=False, pick_middle=False):
    """Plot all appleDoor baselines for one environment variant on ``ax``.

    Six curves are drawn with ``get_multi_runs_mean_var`` using the seaborn
    "colorblind" palette (new-algo, GEG-MARL, EG-MARL(B=100), MAPPO, MAGAIL,
    ATA), followed by a dashed red oracle line at 9.01 for variant ``"a"``
    and 8.65 otherwise.  Variant ``"b"`` uses the ``"asym"`` y scale for the
    ``"optimal"`` and ``"suboptimal"`` modes.

    Args:
        env_type: Environment variant, ``"a"`` or ``"b"``.
        ax: Axes to draw on.
        mode: ``"optimal"`` or ``"suboptimal"``; selects the
            ``outputs_appleDoor_<mode>`` directory.
        pick_top: Forwarded to ``get_multi_runs_mean_var``.
        use_label: Whether the curves and the oracle get legend labels.
        pick_middle: Forwarded to ``get_multi_runs_mean_var``.
    """
    agent_num = 2
    frames = 4900000

    default_palette = sns.color_palette("colorblind", as_cmap=True)
    # default_palette = sns.color_palette("Paired", as_cmap=True)

    env_root = f"{output_dir}/appleDoor_{env_type}"
    mode_prefix = f"{env_root}/outputs_appleDoor_{mode}/appleDoor_{env_type}"

    # (log folder prefix, legend label, progress-message name); the palette
    # colour is the position in the list.
    curves = [
        (mode_prefix + "_POfD2_ep4_nbatch4_pw0.05_pd1.0", "new-algo", "new algo"),
        (mode_prefix + "_POfD_ep4_nbatch4_pw0.02_pd1.0", "GEG-MARL", "GEG-MARL"),
        (mode_prefix + "_PPO_ep4_nbatch4_wprior_N100_gradNoise_pw1.0_pd0.995", "EG-MARL(B=100)", "EG-MARL"),
        (f"{env_root}/outputs_appleDoor_MAPPO/appleDoor_{env_type}_PPO_ep4_nbatch4", "MAPPO", "MAPPO"),
        (mode_prefix + "_POfD_ep4_nbatch4_pw1.0_pd1.0", "MAGAIL", "MAGAIL"),
        (f"{env_root}/ata/csv_logs/seed", "ATA", "ATA"),
    ]
    for color_id, (folder, label, tag) in enumerate(curves):
        get_multi_runs_mean_var(ax, folder, agent_num, frames, label, default_palette[color_id],
                                pick_top, use_label, pick_middle)
        print(f"----------------- agent num: {agent_num}, {tag} done")

    # plot the oracle; the variant is in env_type (comparing the builtin
    # ``type`` to "a" is never true and always selected 8.65)
    best_return = 9.01 if env_type == "a" else 8.65
    # A constant line only needs its two end points.
    ax.plot([0, frames - 1], [best_return, best_return], '--', color='red',
            label="Oracle" if use_label else None)
    # ax.legend(loc='lower right')

    # mode -> (factor applied to negative values, y ticks), variant "b" only
    asym_axes = {
        "optimal": (0.25, [-10, -5, 0, 2, 4, 6, 8, 10]),
        "suboptimal": (0.5, [-2, 0, 2, 4, 6, 8, 10]),
    }
    if env_type == "b" and mode in asym_axes:
        a, ticks = asym_axes[mode]
        ax.set_yscale("asym", a=a)
    else:
        ticks = [0, 2, 4, 6, 8, 10]
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)
    ax.set_title("case " + env_type, style='italic')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")
    # plt.show()


def plot_appleDoor_all(env_type, ax, mode, pick_top=False, use_label=False, pick_middle=False):
    """Plot the appleDoor curves of one demonstration mode on ``ax``.

    PegMARL, DM2 and MAGAIL are always drawn from
    ``outputs_appleDoor_<mode>``.  For ``mode == "optimal"`` the ATA and
    MAPPO curves are added; for any other mode the labels get a ``*``
    suffix, the neighbouring (index - 1) colour of the seaborn "Paired"
    palette is used, and a dashed red oracle line is drawn at 9.01 for
    variant ``"a"`` and 8.65 otherwise.  The title is "easy" for variant
    ``"a"`` and "hard" otherwise.

    Args:
        env_type: Environment variant, ``"a"`` or ``"b"``.
        ax: Axes to draw on.
        mode: ``"optimal"`` or ``"suboptimal"``.
        pick_top: Forwarded to ``get_multi_runs_mean_var``.
        use_label: Whether the curves and the oracle get legend labels.
        pick_middle: Forwarded to ``get_multi_runs_mean_var``.
    """
    agent_num = 2
    frames = 4900000

    # default_palette = sns.color_palette("colorblind", as_cmap=True)
    default_palette = sns.color_palette("Paired").as_hex()
    optimal = mode == "optimal"
    s, suffix = (0, "") if optimal else (-1, "*")

    env_root = f"{output_dir}/appleDoor_{env_type}"
    mode_prefix = f"{env_root}/outputs_appleDoor_{mode}/appleDoor_{env_type}"

    # (log folder prefix, legend label, palette index, progress-message name)
    # Disabled curve, kept for reference:
    #   EG-MARL(B=100){suffix}: mode_prefix + "_PPO_ep4_nbatch4_wprior_N100_gradNoise_pw1.0_pd0.995",
    #                           palette 5+s
    curves = [
        (mode_prefix + "_POfD2_ep4_nbatch4_pw0.05_pd1.0", f"PegMARL{suffix}", 1 + s, "new algo"),
        (mode_prefix + "_POfD_ep4_nbatch4_pw0.02_pd1.0", f"DM2{suffix}", 3 + s, "GEG-MARL"),
        (mode_prefix + "_POfD_ep4_nbatch4_pw1.0_pd1.0", f"MAGAIL{suffix}", 7 + s, "MAGAIL"),
    ]
    if optimal:
        curves += [
            (f"{env_root}/ata/csv_logs/seed", "ATA", 9, "ATA"),
            (f"{env_root}/outputs_appleDoor_MAPPO/appleDoor_{env_type}_PPO_ep4_nbatch4", "MAPPO", 11, "MAPPO"),
        ]

    for folder, label, color_id, tag in curves:
        get_multi_runs_mean_var(ax, folder, agent_num, frames, label, default_palette[color_id],
                                pick_top, use_label, pick_middle)
        print(f"----------------- agent num: {agent_num}, {tag} done")

    if not optimal:
        # plot the oracle; the variant is in env_type (comparing the builtin
        # ``type`` to "a" is never true and always selected 8.65)
        best_return = 9.01 if env_type == "a" else 8.65
        # A constant line only needs its two end points.
        ax.plot([0, frames - 1], [best_return, best_return], '--', color='red',
                label="Oracle" if use_label else None)

    # ax.legend(loc='lower right')

    # mode -> (factor applied to negative values, y ticks), variant "b" only
    asym_axes = {
        "optimal": (0.25, [-10, -5, 0, 2, 4, 6, 8, 10]),
        "suboptimal": (0.5, [-2, 0, 2, 4, 6, 8, 10]),
    }
    if env_type == "b" and mode in asym_axes:
        a, ticks = asym_axes[mode]
        ax.set_yscale("asym", a=a)
    else:
        ticks = [0, 2, 4, 6, 8, 10]
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)
    ax.set_title("easy" if env_type == "a" else "hard", style='italic')
    # ax.set_title("case " + env_type, style='italic')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")


def plot_mpe(ax):
    """Draw the MPE simple-spread comparison (PegMARL vs. Count-based) on ``ax``.

    Both curves are produced by ``get_multi_runs_mean_var`` over 12,000,000
    frames with colours from the seaborn "Paired" palette; a legend and the
    axis labels are added afterwards.
    """
    agent_num = 2
    default_palette = sns.color_palette("Paired").as_hex()

    # frames = 19000000
    frames = 12000000
    root_folder = f"{output_dir}/mpe/outputs_mpeSimpleSpread_2a2g1w_fixedmap/"

    # NOTE: an earlier variant of this function plotted single seed-1 runs of
    # GEG-MARL and MAPPO straight from their log.csv files.

    # (path below root_folder, legend label, palette index)
    curves = (
        ("optimal_2a2g_1221_wall_noise_new/mpeMidSparse_newR_fixedMap_simple_spread_2a_reachAll_POfD2_ep4_nbatch4_pw0.2_pd1.0_seed",
         "PegMARL", 1),
        ("optimal_2a2g_1221_wall_noise_new/mpeMidSparse_newR_fixedMap_simple_spread_2a_reachAll_PPO_ep4_nbatch4_wprior_N100_gradNoise_pw1.0_pd0.995_seed",
         "Count-based", 3),
        # Disabled curves:
        # ("optimal_2a2g_1221_wall_noise/mpeMidSparse_newR_fixedMap_simple_spread_2a_reachAll_POfD_ep4_nbatch4_pw0.1_pd1.0_seed",
        #  "DM2", 3),
        # ("MAPPO/mpeMidSparse_newR_fixedMap_simple_spread_2a_reachAll_PPO_ep4_nbatch4_seed",
        #  "MAPPO", 11),
    )
    for sub_path, label, color_id in curves:
        folder = root_folder + sub_path
        get_multi_runs_mean_var(ax, folder, agent_num, frames, label, default_palette[color_id],
                                pick_top=False, use_label=True, pick_middle=False)

    ax.legend(loc='lower right')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")


def plot_lava2a_joint_demon(ax):
    """Draw the four 2-agent centerSquare6x6 comparison curves on ``ax``.

    Each curve is produced by ``get_multi_runs_mean_var`` (defined elsewhere
    in this file) from a run-folder prefix ending in ``_seed``. The curves
    are drawn with ``pick_top=True`` and colours taken from seaborn's
    "Paired" palette. Afterwards the legend and the axis labels are set.

    Args:
        ax: Matplotlib axes that receives the curves, legend and labels.
    """
    agent_num = 2
    frames = 4900000
    default_palette = sns.color_palette("Paired").as_hex()

    # One row per curve, in drawing order:
    # (output sub-directory, algorithm tag, weight tag, legend label, palette index)
    curves = (
        ("outputs_lava_optimal_discrimFull_coTrained", "POfD2", "0.05", "PegMARL", 1),
        ("outputs_lava_optimal_discrimFull", "POfD2", "0.05", "PegMARL(diff)", 0),
        ("outputs_lava_optimal_discrimFull_coTrained", "POfD", "0.02", "DM2", 3),
        ("outputs_lava_optimal_discrimFull", "POfD", "0.02", "DM2(diff)", 2),
    )

    for sub_dir, algo, weight, label, palette_idx in curves:
        folder = (
            f"{output_dir}/centerSquare6x6_2a/{sub_dir}/"
            f"centerSquare6x6_2a_{algo}_ep4_nbatch4_pw{weight}_pd1.0_seed"
        )
        get_multi_runs_mean_var(
            ax,
            folder,
            agent_num,
            frames,
            label,
            default_palette[palette_idx],
            pick_top=True,
            use_label=True,
            pick_middle=False,
        )

    ax.legend(loc='lower right')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")


def plot_swap_pweight_ablation(ax):
    """Plot the prior-weight (eta) ablation for the 2-agent swap task on ``ax``.

    Two sweeps are drawn, each through ``get_multi_runs_mean_var`` (defined
    elsewhere in this file) with ``pick_top=False``:

    * run names tagged ``POfD2``, labelled "New algo", for five weights;
    * run names tagged ``POfD``, labelled "GEG-MARL", for two weights.

    Every folder prefix is printed before it is plotted. Afterwards the
    legend (upper left) and the axis labels are set.

    Args:
        ax: Matplotlib axes that receives the curves, legend and labels.
    """
    agent_num = 2
    frames = 2000000

    # Fixed: this literal previously lacked the ``f`` prefix, so every path
    # began with the literal text "{output_dir}" rather than the configured
    # output directory. It is also loop-invariant, so it is built once.
    root_folder = f"{output_dir}/swap_2a/outputs_swap_optimal/"

    # (algorithm tag in the run name, legend prefix, weights, line colours)
    sweeps = (
        (
            "POfD2",
            "New algo",
            [0.01, 0.02, 0.03, 0.05, 0.1],
            ["blue", "orange", "green", "red", "purple"],
        ),
        (
            "POfD",
            "GEG-MARL",
            [0.02, 0.05],
            ["yellow", "brown"],
        ),
    )

    for algo, legend_prefix, ps, colors in sweeps:
        for pweight, c in zip(ps, colors):
            folder = f"{root_folder}swap_{agent_num}a_{algo}_ep4_nbatch4_pw{pweight}_pd1.0_"
            print(folder)
            # Raw string: "\e" is not a valid escape sequence, and the
            # backslash must reach matplotlib's mathtext unchanged.
            label = legend_prefix + r", $\eta$=" + str(pweight)
            get_multi_runs_mean_var(ax, folder, agent_num, frames, label, c, pick_top=False, use_label=True)

    # plot the oracle
    # xs = np.array(range(frames)).tolist()
    # ax.plot(xs, np.ones(len(xs)) * 8.38, '--', color='red', label="Oracle")
    ax.legend(loc='upper left')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")



def plot_lava4a_pweight_ablation(ax):
    """Plot the prior-weight (eta) ablation for 4-agent centerSquare6x6 on ``ax``.

    Two sweeps are drawn, each through ``get_multi_runs_mean_var`` (defined
    elsewhere in this file) with ``pick_top=True``:

    * run names tagged ``POfD2``, labelled "New algo", for three weights;
    * run names tagged ``POfD``, labelled "GEG-MARL", for two weights.

    Every folder prefix is printed before it is plotted. Afterwards the
    legend (upper left) and the axis labels are set.

    Args:
        ax: Matplotlib axes that receives the curves, legend and labels.
    """
    agent_num = 4
    map_name = "centerSquare6x6"
    frames = 5000000

    # Loop-invariant, so built once instead of on every iteration.
    root_folder = f"{output_dir}/{map_name}_{agent_num}a/outputs_lava_optimal/"

    # (algorithm tag in the run name, legend prefix, weights, line colours)
    sweeps = (
        ("POfD2", "New algo", [0.02, 0.05, 0.1], ["blue", "orange", "green"]),
        ("POfD", "GEG-MARL", [0.02, 0.05], ["yellow", "brown"]),
    )

    for algo, legend_prefix, ps, colors in sweeps:
        for pweight, c in zip(ps, colors):
            folder = f"{root_folder}{map_name}_{agent_num}a_{algo}_ep4_nbatch4_pw{pweight}_pd1.0_"
            print(folder)
            # Raw string: "\e" is not a valid escape sequence (it triggers a
            # SyntaxWarning on recent Pythons); the runtime text is unchanged.
            label = legend_prefix + r", $\eta$=" + str(pweight)
            get_multi_runs_mean_var(ax, folder, agent_num, frames, label, c, pick_top=True, use_label=True)

    # plot the oracle
    # xs = np.array(range(frames)).tolist()
    # ax.plot(xs, np.ones(len(xs)) * 8.38, '--', color='red', label="Oracle")
    ax.legend(loc='upper left')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")


def plot_appleDoorb_pweight_ablation(ax):
    """Plot the prior-weight (eta) ablation for the ``appleDoor_b`` map on ``ax``.

    Two sweeps are drawn, each through ``get_multi_runs_mean_var`` (defined
    elsewhere in this file) with ``pick_top=False``:

    * run names tagged ``POfD2``, labelled "New algo", for three weights;
    * run names tagged ``POfD``, labelled "GEG-MARL", for two weights.

    Every folder prefix is printed before it is plotted. Afterwards the
    legend (upper left) and the axis labels are set.

    Args:
        ax: Matplotlib axes that receives the curves, legend and labels.
    """
    agent_num = 2
    map_name = "appleDoor_b"
    frames = 5000000

    # Loop-invariant, so built once instead of on every iteration.
    root_folder = f"{output_dir}/{map_name}/outputs_appleDoor_suboptimal/"

    # (algorithm tag in the run name, legend prefix, weights, line colours)
    sweeps = (
        ("POfD2", "New algo", [0.02, 0.05, 0.1], ["blue", "orange", "green"]),
        ("POfD", "GEG-MARL", [0.02, 0.05], ["yellow", "brown"]),
    )

    for algo, legend_prefix, ps, colors in sweeps:
        for pweight, c in zip(ps, colors):
            folder = f"{root_folder}{map_name}_{algo}_ep4_nbatch4_pw{pweight}_pd1.0_"
            print(folder)
            # Raw string: "\e" is not a valid escape sequence (it triggers a
            # SyntaxWarning on recent Pythons); the runtime text is unchanged.
            label = legend_prefix + r", $\eta$=" + str(pweight)
            get_multi_runs_mean_var(ax, folder, agent_num, frames, label, c, pick_top=False, use_label=True)

    # plot the oracle
    # xs = np.array(range(frames)).tolist()
    # ax.plot(xs, np.ones(len(xs)) * 8.38, '--', color='red', label="Oracle")
    ax.legend(loc='upper left')
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Episode Rewards")


if __name__ == "__main__":

    # Active figure: a single 9x6-inch axes filled in by plot_mpe and written
    # to "mpe2.pdf". The other figure recipes below are kept commented out.
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 6))
    plot_mpe(ax)
    fig.tight_layout()
    # pyplot-level save/close operate on the current figure.
    plt.savefig("mpe2.pdf")
    plt.close()

    # fig, ax = plt.subplots(1, 1, figsize=(9, 6))
    # plot_lava2a_joint_demon(ax)
    # fig.tight_layout()
    # plt.savefig("lava2a_joint_pickTop.pdf")
    # plt.close()
    

    # fig, ax = plt.subplots(1, 1, figsize=(9, 6))
    # plot_appleDoorb_pweight_ablation(ax)
    # fig.tight_layout()
    # plt.savefig("appleDoor_b_ablation.pdf")
    # plt.close()

    # fig, ax = plt.subplots(1, 1, figsize=(9, 6))
    # plot_lava4a_pweight_ablation(ax)
    # fig.tight_layout()
    # plt.savefig("lava_4a_ablation_optimal.pdf")
    # plt.close()

    # fig, ax = plt.subplots(1, 1, figsize=(9, 6))
    # # plot_lava_Ns_1a(ax)
    # plot_swap_pweight_ablation(ax)
    # fig.tight_layout()
    # plt.savefig("swap_ablation.pdf")
    # plt.close()

    # ns = [2, 3, 4]
    # modes = ["suboptimal"]  #"optimal", 
    # pick_top_5s = [True, False]
    # pick_middles = [False]
    # for mode in modes:
    #     for pick_top_5 in pick_top_5s:
    #         for pick_middle in pick_middles:
    #             num = len(ns)
    #             fig, axs = plt.subplots(1, num, figsize=(9*num, 7.8))
    #             print(f"mode: {mode}, pick_top_5: {pick_top_5}, pick_middle: {pick_middle}")
    #             for i in range(num):
    #                 use_label = True if i == 0 else False
    #                 plot_lava(axs[i], agent_num=ns[i], mode=mode, pick_top=pick_top_5, use_label=use_label, pick_middle=pick_middle)

    #             fig_name = f"lava_{mode}"
    #             if pick_top_5:
    #                 fig_name += "_top5"
    #             if pick_middle:
    #                 fig_name += "_IQM"

    #             fig.legend(loc='upper center', ncol=7)
    #             fig.tight_layout()
    #             fig.subplots_adjust(top=0.8)

    #             plt.savefig(f"{fig_name}.pdf")
    #             plt.close()

    # ns = [2, 3, 4]
    # modes = ["suboptimal"]  #"optimal", 
    # pick_top_5s = [True]
    # pick_middles = [False]
    # for pick_top_5 in pick_top_5s:
    #     for pick_middle in pick_middles:
    #         num = len(ns)
    #         fig, axs = plt.subplots(1, num, figsize=(9*num, 8))
    #         print(f"pick_top_5: {pick_top_5}, pick_middle: {pick_middle}")
    #         for i in range(num):
    #             use_label = True if i == 0 else False
    #             plot_lava_all2(axs[i], agent_num=ns[i], mode="optimal", pick_top=pick_top_5, use_label=use_label, pick_middle=pick_middle)
    #             # plot_lava_all(axs[i], agent_num=ns[i], mode="suboptimal", pick_top=pick_top_5, use_label=use_label, pick_middle=pick_middle)

    #         fig_name = f"lava_all2"
    #         if pick_top_5:
    #             fig_name += "_top5"
    #         if pick_middle:
    #             fig_name += "_IQM"
            
    #         handles, labels = axs[0].get_legend_handles_labels()
    #         # fig.legend(flip(handles, 7), flip(labels, 7), loc='upper center', ncol=7)
    #         fig.legend(loc='upper center', ncol=9)
    #         fig.tight_layout()
    #         fig.subplots_adjust(top=0.78)

    #         plt.savefig(f"{fig_name}.pdf")
    #         plt.close()

    # env_types = ["a", "b"]
    # modes = ["optimal", "suboptimal"]
    # pick_top_5s = [False, True]
    # pick_middles = [False]
    # for mode in modes:
    #     for pick_top_5 in pick_top_5s:
    #         for pick_middle in pick_middles:
    #             num = len(env_types)
    #             fig, axs = plt.subplots(1, num, figsize=(9*num, 6.5))
    #             print(f"mode: {mode}, pick_top_5: {pick_top_5}, pick_middle: {pick_middle}")
    #             for i in range(num):
    #                 use_label = True if i == 0 else False
    #                 plot_appleDoor(env_types[i], axs[i], mode, pick_top=pick_top_5, use_label=use_label, pick_middle=pick_middle)

    #             fig_name = f"appleDoor_{mode}"
    #             if pick_top_5:
    #                 fig_name += "_top5"
    #             if pick_middle:
    #                 fig_name += "_IQM"

    #             # fig.legend(loc='upper center', ncol=3)
    #             fig.tight_layout()
    #             # fig.subplots_adjust(top=0.78)

    #             plt.savefig(f"{fig_name}.pdf")
    #             plt.close()


    # env_types = ["a", "b"]
    # pick_top_5s = [False, True]
    # modes = ["optimal", "suboptimal"]
    # pick_middles = [False]
    # for pick_top_5 in pick_top_5s:
    #     for pick_middle in pick_middles:
    #         num = len(env_types)
    #         fig, axs = plt.subplots(1, num, figsize=(9*num, 7))
    #         print(f"pick_top_5: {pick_top_5}, pick_middle: {pick_middle}")
    #         for i in range(num):
    #             use_label = True if i == 0 else False
    #             plot_appleDoor_all(env_types[i], axs[i], "optimal", pick_top=pick_top_5, use_label=use_label, pick_middle=pick_middle)
    #             plot_appleDoor_all(env_types[i], axs[i], "suboptimal", pick_top=pick_top_5, use_label=use_label, pick_middle=pick_middle)

    #         fig_name = f"appleDoor_all"
    #         if pick_top_5:
    #             fig_name += "_top5"
    #         if pick_middle:
    #             fig_name += "_IQM"
            
    #         handles, labels = axs[0].get_legend_handles_labels()
    #         ncols = 5
    #         fig.legend(flip(handles, ncols), flip(labels, ncols), loc='upper center', ncol=ncols)
    #         fig.tight_layout()
    #         fig.subplots_adjust(top=0.78)

    #         plt.savefig(f"{fig_name}.pdf")
    #         plt.close()


    # fig, ax = plt.subplots()
    
    # ax.plot([-2, 0, 5], [-5, 1, 0])
    # ax.set_yscale("asym", a=0.1)
    
    # ax.annotate("negative axis", xy=(.25, 0), xytext=(0, -30),
    #             xycoords="axes fraction", textcoords="offset points", ha="center")
    # ax.annotate("positive axis", xy=(.75, 0), xytext=(0, -30),
    #             xycoords="axes fraction", textcoords="offset points", ha="center")
    # plt.show()

# lava: 8.38
# apple door a: 11 steps (9.01)
# apple door b: 15 steps (8.65)