# Root directory under which the dataset, cached pickles, model weights and
# plots are read and written. Change this to match your own environment.
# Keep the trailing slash: paths below are built by plain string
# concatenation (e.g. rootdir + 'data/').
rootdir = '/content/drive/MyDrive/colab/cmnist/'

#@title Imports
import os
import copy
import math
import pickle
import dill
import random
from math import factorial
from itertools import chain, combinations

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D

#@title Set Seed
# Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators
# with the same value so that data shuffling and training can be repeated.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels and turn off the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#@title MLP
class MLP(nn.Module):
    """Perceptron with one hidden layer: 28*28 inputs -> num_hid -> 10 outputs.

    Args:
        num_hid: Number of units in the hidden layer.
    """

    def __init__(self, num_hid):
        super().__init__()
        # The attribute names double as state_dict keys, so they must stay.
        self.linear1 = nn.Linear(28 * 28, num_hid)
        self.linear2 = nn.Linear(num_hid, 10)

    def forward(self, x):
        """Flatten everything after the batch dimension and return the scores."""
        flat = x.flatten(start_dim=1)
        hidden = F.relu(self.linear1(flat))
        return self.linear2(hidden)

def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

def test_wrapper(test_loader):
    def test(model):
        correct = 0
        loss = 0
        total = 0
        device = next(model.parameters()).device
        model.eval()
        with torch.no_grad():
            for (x, y) in test_loader:
                x, y = x.to(device), y.to(device)
                o = model(x)
                loss += nn.CrossEntropyLoss(reduction='sum')(o, y).item()
                p = o.softmax(dim=1)
                p = p.argmax(dim=1).to(y.dtype)
                p = (p == y).float()
                correct += p.sum()
                total += len(y)
        acc = (correct / total).item()
        loss = loss / total
        return (acc, loss)
    return test

# Load the MNIST train and test splits as tensors, downloading them into
# rootdir + 'data/' when missing. The test split is wrapped in a fixed-order
# loader and bound into `test`, the evaluation function called by train()
# and by the valuation / realization code further down.
transform = transforms.ToTensor()
train_set = torchvision.datasets.MNIST(rootdir + 'data/', train=True, transform=transform, download=True)
test_set = torchvision.datasets.MNIST(rootdir + 'data/', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
test = test_wrapper(test_loader)

# Reduce the full training set to a random subset of 20000 samples, since the
# MLP is too powerful for MNIST. The subset is cached in train_set.pkl and
# reused on later runs instead of being redrawn.
# NOTE: pickle.load can execute arbitrary code; only load a cache file that
# was produced by this script.
if os.path.exists(rootdir + 'train_set.pkl'):
    with open(rootdir + 'train_set.pkl', 'rb') as f:
        train_set = pickle.load(f)
else:
    indices = list(range(len(train_set)))
    random.shuffle(indices)
    subset_indices = indices[:20000]
    train_set = torch.utils.data.Subset(train_set, subset_indices)
    with open(rootdir + 'train_set.pkl', 'wb') as f:
        pickle.dump(train_set, f)

def train(train_set, model_save_path=None):
    """Train a fresh ``MLP(16)`` with SGD and return the best-scoring copy.

    After every epoch the model is scored with the module-level ``test``
    function; a deep copy is kept whenever the accuracy beats the best so far.

    Args:
        train_set: Dataset fed to a shuffling ``DataLoader`` (batch size 128).
        model_save_path: Optional path. When given, the ``state_dict`` of each
            new best model is saved there.

    Returns:
        The deep copy taken at the best epoch, or a copy of the untrained
        model when no epoch reaches an accuracy above 0.
    """
    num_epochs = 10
    learning_rate = 4e-3
    device = get_device()
    loss_fn = nn.CrossEntropyLoss()
    loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
    net = MLP(16)
    net.to(device)
    sgd = optim.SGD(net.parameters(), lr=learning_rate)
    top_acc = 0
    top_model = copy.deepcopy(net)
    for _epoch in range(num_epochs):
        net.train()
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            sgd.zero_grad()
            batch_loss = loss_fn(net(inputs), targets)
            batch_loss.backward()
            sgd.step()
        test_acc = test(net)[0]
        if test_acc > top_acc:
            top_acc = test_acc
            top_model = copy.deepcopy(net)
            if model_save_path is not None:
                torch.save(net.state_dict(), model_save_path)
    return top_model

#@title Valuation
def split_by_label(labels):
    """Partition the module-level ``train_set`` into three parties by label.

    ``labels`` is shuffled in place and cut into the groups ``[:3]``,
    ``[3:6]`` and ``[6:]``; each party receives the indices of every sample
    whose label is in its group. The pair ``(p_indices, p_labels)`` is also
    pickled to ``rootdir + 'party_partition.pkl'``.

    Args:
        labels: List of integer labels, used to index the per-label buckets.
            It is mutated (shuffled) by this call.

    Returns:
        A list holding one list of ``train_set`` indices per party.
    """
    buckets = [[] for _ in labels]
    for sample_idx, (_, label) in enumerate(train_set):
        buckets[label].append(sample_idx)
    random.shuffle(labels)
    p_labels = [labels[:3], labels[3:6], labels[6:]]
    p_indices = [
        [sample_idx for label in group for sample_idx in buckets[label]]
        for group in p_labels
    ]
    with open(rootdir + 'party_partition.pkl', 'wb') as f:
        pickle.dump((p_indices, p_labels), f)
    return p_indices

# Reuse the cached party partition when it exists; otherwise create it from
# the ten labels 0..9 (split_by_label also writes the cache file).
# NOTE(review): the else branch binds only p_indices; p_labels is not defined
# here and only becomes available when the pickle is loaded.
if os.path.exists(rootdir + 'party_partition.pkl'):
    with open(rootdir + 'party_partition.pkl', 'rb') as f:
        p_indices, p_labels = pickle.load(f)
else:
    p_indices = split_by_label(list(range(10)))

def is_submodular(v):
    """Check ``v[S] + v[T] >= v[S | T] + v[S & T]`` for all distinct key pairs.

    Args:
        v: Mapping from frozenset coalitions to values. It must contain the
            union and the intersection of every pair of its keys, otherwise a
            ``KeyError`` is raised.

    Returns:
        ``True`` when the inequality holds for every ordered pair of distinct
        keys, ``False`` as soon as one pair violates it.
    """
    coalitions = list(v)
    for s in coalitions:
        for t in coalitions:
            if s == t:
                continue
            if not v[s] + v[t] >= v[s | t] + v[s & t]:
                return False
    return True

def is_superadditive(v):
    """Check ``v[S | T] >= v[S] + v[T]`` for every pair of disjoint keys.

    Args:
        v: Mapping from frozenset coalitions to values. It must contain the
            union of every pair of disjoint keys, otherwise a ``KeyError`` is
            raised.

    Returns:
        ``True`` when no disjoint pair violates the inequality.
    """
    coalitions = list(v)
    for s in coalitions:
        for t in coalitions:
            if (s & t) == set() and not v[s | t] >= v[s] + v[t]:
                return False
    return True

def get_v_c(c):
    """Return the value of coalition ``c``.

    The empty coalition is worth ``0.0``. Otherwise the data of all parties
    in ``c`` (looked up in the module-level ``p_indices``) is pooled, a model
    is trained on it and its test accuracy is returned.

    Args:
        c: Collection of party indices into ``p_indices``.
    """
    if len(c) == 0:
        return 0.0
    pooled_indices = [idx for party in c for idx in p_indices[party]]
    coalition_data = torch.utils.data.Subset(train_set, pooled_indices)
    acc, _loss = test(train(coalition_data))
    return acc

def get_v(n):
    """Evaluate every coalition of the parties ``0..n-1``.

    Coalitions are visited by increasing size, so the returned dict is
    ordered from the empty coalition up to the grand coalition.

    Args:
        n: Number of parties.

    Returns:
        Dict mapping each frozenset coalition to its ``get_v_c`` value.

    Raises:
        AssertionError: If the resulting game is not submodular.
    """
    parties = list(range(n))
    v = {}
    for size in range(n + 1):
        for members in combinations(parties, size):
            v[frozenset(members)] = get_v_c(list(members))
    assert is_submodular(v), "v is not submodular!"
    return v

def get_v_dual(v):
    """Return the dual game ``v_dual[c] = v[N] - v[N - c]`` without the empty set.

    Args:
        v: Mapping from frozenset coalitions to values, containing the empty
            coalition and the complement of each key within ``N``.

    Returns:
        Dict over the same keys as ``v`` except ``frozenset()``.

    Raises:
        AssertionError: If the dual game is not superadditive.
    """
    # max() compares frozensets by the subset relation; the grand coalition
    # is a superset of every other key, so it wins whenever it is present.
    grand = max(v)
    grand_value = v[grand]
    v_dual = {}
    for coalition in v:
        v_dual[coalition] = grand_value - v[grand - coalition]
    del v_dual[frozenset()]
    assert is_superadditive(v_dual), "v_dual is not superadditive!"
    return v_dual

#@title Reward Value
def get_shapley_value(i, v):
    """Return the Shapley value of player ``i`` in the game ``v``.

    The number of players ``n`` is taken to be the size of the largest key.

    Args:
        i: Player index.
        v: Mapping from frozenset coalitions to values. It must contain
            ``S | {i}`` for every key ``S`` that lacks ``i``. The empty
            coalition is optional; when absent its value is taken as 0.

    Returns:
        The weighted sum of the marginal contributions of ``i``.
    """
    n = max(len(s) for s in v)
    n_factorial = factorial(n)
    phi = 0.0
    for coalition, value in v.items():
        # The empty coalition is handled once, after the loop. Visiting it
        # here as well would count its term twice when v contains it.
        if not coalition or i in coalition:
            continue
        weight = factorial(len(coalition)) * factorial(n - len(coalition) - 1) / n_factorial
        phi += weight * (v[coalition | {i}] - value)
    # Marginal contribution of i to the empty coalition.
    weight = factorial(0) * factorial(n - 1) / n_factorial
    phi += weight * (v[frozenset([i])] - v.get(frozenset(), 0.0))
    return phi

def get_shapley_values(v):
    """Return the Shapley values of players ``0..n-1`` as a list.

    ``n`` is the size of the largest coalition among the keys of ``v``.
    """
    n = max(len(s) for s in v)
    return [get_shapley_value(player, v) for player in range(n)]

def get_dividends(v):
    """Return the dividend of every coalition in ``v``.

    The dividend of a coalition is its value minus the dividends of all its
    proper subsets that are keys of ``v``.

    Args:
        v: Mapping from frozenset coalitions to values.

    Returns:
        Dict mapping each coalition to its dividend, ordered by coalition
        size (ties keep the order they have in ``v``).
    """
    dividends = {}
    # Every proper subset must be resolved before its supersets. Sorting by
    # size guarantees that regardless of the insertion order of v; the sort
    # is stable, so an already size-ordered v is processed exactly as given.
    for coalition in sorted(v, key=len):
        dividend = v[coalition]
        for subset, subset_dividend in dividends.items():
            if subset < coalition:
                dividend -= subset_dividend
        dividends[coalition] = dividend
    return dividends

def get_time_aware_v(v, times, weight_function):
    """Rebuild ``v`` from its dividends, discounting multi-party dividends.

    Each coalition's new value is the sum of the dividends of its subsets.
    Singleton dividends are added unchanged; any other dividend is multiplied
    by the smallest weight among the subset's members.

    Args:
        v: Mapping from frozenset coalitions to values.
        times: Passed to ``weight_function``.
        weight_function: Callable returning one weight per party index.

    Returns:
        Dict with the same keys as ``v`` and the discounted values.
    """
    weights = weight_function(times)
    dividends = get_dividends(v)
    time_aware_v = {}
    for coalition in v:
        total = 0.0
        for subset, dividend in dividends.items():
            if not subset <= coalition:
                continue
            if len(subset) == 1:
                total += dividend
                continue
            smallest_weight = min(
                [weight for party, weight in enumerate(weights) if party in subset]
            )
            total += smallest_weight * dividend
        time_aware_v[coalition] = total
    return time_aware_v

def get_N_tau(times, tau):
    """Return the frozenset of indices ``i`` with ``times[i] <= tau``."""
    return frozenset(i for i, t in enumerate(times) if t <= tau)

def get_N_tau_complement(times, tau):
    """Return the frozenset of indices ``i`` with ``times[i] > tau``."""
    return frozenset(i for i, t in enumerate(times) if t > tau)

def get_v_tau(v, N_tau):
    """Return the entries of ``v`` whose coalition is a subset of ``N_tau``."""
    return {coalition: value for coalition, value in v.items() if coalition <= N_tau}

def get_reward_cumulation(v, times, beta):
    """Average per-step rewards over ``tau = 0..int(max(times))``.

    At step ``tau`` (weight ``beta ** tau``) a party with ``times[i] <= tau``
    is credited its Shapley value in the game restricted to those parties;
    every other party is credited its own singleton value ``v[{i}]``. The
    weighted sums are divided by the sum of the weights.

    Args:
        v: Mapping from frozenset coalitions to values.
        times: One time per party.
        beta: Base of the per-step weight.

    Returns:
        List with one averaged reward per party.
    """
    n = max(len(s) for s in v)
    rewards = [0.0] * n
    w_sum = 0.0
    last_step = int(max(times))
    for tau in range(last_step + 1):
        w_tau = beta ** tau
        w_sum += w_tau
        joined = get_N_tau(times, tau)
        not_joined = get_N_tau_complement(times, tau)
        v_tau = get_v_tau(v, joined)
        for i in joined:
            rewards[i] += w_tau * get_shapley_value(i, v_tau)
        for i in not_joined:
            rewards[i] += w_tau * v[frozenset([i])]
    return [reward / w_sum for reward in rewards]

def get_scaling_factor(v):
    """Return the largest value in ``v`` divided by the largest Shapley value."""
    largest_value = max(v.values())
    largest_shapley = max(get_shapley_values(v))
    return largest_value / largest_shapley

def scale_rewards(scaling_factor, rewards):
    """Return a new list with every reward multiplied by ``scaling_factor``."""
    return list(map(lambda reward: reward * scaling_factor, rewards))

def get_reward_gamma(v, times, gamma):
    """Scaled Shapley rewards of the time-aware game with weights ``exp(-t * gamma)``.

    Args:
        v: Mapping from frozenset coalitions to values.
        times: One time per party; its length must equal the size of the
            largest coalition in ``v``.
        gamma: Decay rate of the per-party weight.

    Returns:
        List of rewards, multiplied by ``get_scaling_factor(v)``.
    """
    assert max(len(s) for s in v.keys()) == len(times)

    def weight_function(times):
        return [math.exp(-time * gamma) for time in times]

    time_aware_v = get_time_aware_v(v, times, weight_function)
    rewards = get_shapley_values(time_aware_v)
    return scale_rewards(get_scaling_factor(v), rewards)

def get_reward_beta(v, times, beta):
    """Scaled rewards from ``get_reward_cumulation`` with weight base ``beta``.

    ``times`` must have one entry per party, i.e. as many entries as the
    largest coalition in ``v`` has members.
    """
    assert max(len(s) for s in v.keys()) == len(times)
    cumulated = get_reward_cumulation(v, times, beta)
    return scale_rewards(get_scaling_factor(v), cumulated)

def generate_reward_value(filename, v, gamma_list, beta_list):
    """Tabulate rewards while the time of party 0 sweeps over 0..4.

    All other times stay at 0. One row is produced per (hyperparameter, t)
    pair, first for every gamma and then for every beta; the column of the
    method that was not used holds NaN.

    Args:
        filename: Unused.
        v: Mapping from frozenset coalitions to values (three parties, to
            match the fixed column names).
        gamma_list: Values passed to ``get_reward_gamma``.
        beta_list: Values passed to ``get_reward_beta``.

    Returns:
        DataFrame with the time, the three rewards, the hyperparameter
        columns and the differences ``r_1 - r_2`` and ``r_1 - r_3``.
    """
    n = max(len(s) for s in v)
    times = np.zeros(n)
    rows = []

    def sweep(hypers, reward_fn, hyper_columns):
        # Shared body of the gamma and beta sweeps; only the reward function
        # and the placement of the hyperparameter in the row differ.
        for hyper in hypers:
            for t in range(5):
                times[0] = t
                rewards = reward_fn(v, times, hyper)
                rows.append([t] + rewards + hyper_columns(hyper))

    sweep(gamma_list, get_reward_gamma, lambda gamma: [gamma, np.nan])
    sweep(beta_list, get_reward_beta, lambda beta: [np.nan, beta])

    column_names = ['$t$', '$r_1$', '$r_2$', '$r_3$', r'$\gamma$', r'$\beta$']
    df = pd.DataFrame(rows, columns=column_names)
    df['$r_1-r_2$'] = df['$r_1$'] - df['$r_2$']
    df['$r_1-r_3$'] = df['$r_1$'] - df['$r_3$']
    return df

# One-off computation of the cached artefacts loaded below. Uncomment and run
# once to (re)create v.pkl, v_dual.pkl and model_all.pth; every coalition
# value requires training a model, so this is kept disabled by default.
# v = get_v(3)
# v_dual = get_v_dual(v)
# with open(rootdir + 'v.pkl', 'wb') as f:
#     pickle.dump(v, f)
# with open(rootdir + 'v_dual.pkl', 'wb') as f:
#     pickle.dump(v_dual, f)
# train(train_set, rootdir + 'model_all.pth')

# Load the cached games, the party partition and the model trained on the
# whole training subset. These files must already exist.
# NOTE: pickle.load can execute arbitrary code; only load files produced by
# this script.
with open(rootdir + 'v.pkl', 'rb') as f:
    v = pickle.load(f)
with open(rootdir + 'v_dual.pkl', 'rb') as f:
    v_dual = pickle.load(f)
with open(rootdir + 'party_partition.pkl', 'rb') as f:
    p_indices, p_labels = pickle.load(f)
model_all = MLP(16)
model_all.load_state_dict(torch.load(rootdir + 'model_all.pth', map_location=get_device()))

# Print the loaded state as a quick sanity check.
print(f'{v_dual=}')
print(f'{p_labels=}')
print(f'{test(model_all)=}')

# Hyperparameters swept for the two reward schemes, and the resulting table
# of reward values computed on the dual game (the filename argument is unused).
gamma_list = [0, 1, 10]
beta_list = [0, 1, 10]
df = generate_reward_value(None, v_dual, gamma_list, beta_list)

#@title Reward Value Plot
def plot_value(df, v, gamma_list, beta_list, plotdir, format='png'):
    """Save the reward-value figures into ``plotdir``.

    Two figures show the absolute rewards (one per method) and two show the
    differences ``r_1 - r_2`` and ``r_1 - r_3`` with both methods overlaid.

    Side effect: the reward columns of ``df`` are renamed in place to their
    starred form (``$r_1$`` -> ``$r_1^*$`` and so on).

    Args:
        df: Table produced by ``generate_reward_value``.
        v: Mapping from frozenset coalitions to values (three parties).
        gamma_list: Hyperparameter values plotted for the gamma method.
        beta_list: Hyperparameter values plotted for the beta method.
        plotdir: Output directory; created when missing.
        format: Image format, also used as the file extension.
    """
    os.makedirs(plotdir, exist_ok=True)
    figsize = (8, 6)
    starred_names = {
        '$r_1$': '$r_1^*$',
        '$r_2$': '$r_2^*$',
        '$r_3$': '$r_3^*$',
        '$r_1-r_2$': '$r_1^*-r_2^*$',
        '$r_1-r_3$': '$r_1^*-r_3^*$',
    }
    df.rename(columns=starred_names, inplace=True)
    v_self = [v[frozenset({party})] for party in range(3)]
    n = max(len(s) for s in v)
    v_N = v[frozenset(range(n))]
    # (column name, hyperparameter values, colormap used in the diff plots)
    sweeps = [
        (r'$\gamma$', gamma_list, plt.cm.Blues),
        (r'$\beta$', beta_list, plt.cm.Oranges),
    ]

    def save_and_close(fig, plotname):
        plt.tight_layout()
        plt.savefig(os.path.join(plotdir, plotname), format=format, bbox_inches='tight', dpi=300)
        plt.close(fig)

    for method, hyper, _colormap in sweeps:
        fig, ax = plt.subplots(figsize=figsize)
        plot_abs(ax, df, v_self, v_N, values=hyper, method=method)
        method_str = method.replace('\\', '').replace('$', '')
        save_and_close(fig, f'value_{method_str}.{format}')

    for diff in ['$r_1^*-r_2^*$', '$r_1^*-r_3^*$']:
        fig, ax = plt.subplots(figsize=figsize)
        for method, hyper, colormap in sweeps:
            plot_diff(ax, df, values=hyper, method=method, y_label=diff, colormap=colormap)
        diff_str = diff.replace('$', '').replace('^', '').replace('*', '')
        save_and_close(fig, f'value_{diff_str}.{format}')

def plot_abs(ax, df, v_self, v_N, values, method):
    """Plot the three parties' rewards against ``$t$`` on ``ax``.

    One line is drawn per (hyperparameter value, party): the party selects
    the colormap and linestyle, the hyperparameter value selects the marker
    and the shade. Grey horizontal lines and a second y-axis mark the
    singleton values and the grand-coalition value.

    Args:
        ax: Matplotlib axes to draw on.
        df: Table with the columns ``$t$``, ``$r_1^*$``..``$r_3^*$`` and the
            ``method`` column.
        v_self: The three singleton values, in party order.
        v_N: Value of the grand coalition.
        values: Hyperparameter values to plot. At least two are required,
            because the legend colour is taken from the second shade.
        method: Name of the hyperparameter column in ``df``.
    """
    # An ordered list, not a set: the parties are zipped against the ordered
    # colormaps and linestyles, and set iteration order of strings varies
    # between interpreter runs, which made the pairing irreproducible.
    parties = ['$r_1^*$', '$r_2^*$', '$r_3^*$']
    colormaps = [plt.cm.Blues, plt.cm.Oranges, plt.cm.Greens]
    linestyles = ['-', 'dotted', 'dashed']
    markers = ['o', 's', '^', '.', '>', 's', 'v', 'x']
    markersize = 13
    linewidth = 4.0

    # Evenly spaced shades starting at 0.2; vmin is pushed below the smallest
    # shade so that even the first hyperparameter value is not drawn too pale.
    n = len(values)
    del_ = (1 - 0.2) / n
    map_values = [(0.2 + i * del_) for i in range(n)]
    vmin = min(map_values) - 0.5
    vmax = 1
    color_norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    scalar_map = {}
    for (party, colormap) in zip(parties, colormaps):
        scalar_map[party] = cm.ScalarMappable(norm=color_norm, cmap=colormap)

    for (value, map_value, marker) in zip(values, map_values, markers):
        df_new = df[df[method] == value].copy()
        for (party, linestyle) in zip(parties, linestyles):
            ax.plot(df_new['$t$'], df_new[party],
                    color=scalar_map[party].to_rgba(map_value),
                    markersize=markersize,
                    linewidth=linewidth,
                    marker=marker,
                    linestyle=linestyle)

    # Legend: one black entry per hyperparameter value (marker), followed by
    # one coloured entry per party (linestyle).
    legends = []
    for (value, marker) in zip(values, markers):
        legends.append(Line2D([], [],
                              color='black',
                              linewidth=linewidth,
                              markersize=markersize,
                              marker=marker,
                              label=f'{method}={value}'))

    label_color_value = map_values[1]  # same shade as the second hyperparameter
    for (party, linestyle) in zip(parties, linestyles):
        legends.append(Line2D([], [],
                              color=scalar_map[party].to_rgba(label_color_value),
                              linewidth=linewidth,
                              linestyle=linestyle,
                              label=party))

    ax.legend(handles=legends, ncol=2, prop={'size': 25}, loc='best')
    ax.set_xlabel('$t_1$', fontsize=40)
    ax.set_ylabel('Reward', fontsize=40)
    # Ticks come from the rows of the last plotted hyperparameter value.
    ax.set_xticks(df_new['$t$'].unique())
    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', labelsize=30)

    # Horizontal reference lines, drawn on ax explicitly rather than on
    # whatever pyplot currently considers the active axes.
    yticks_ax2 = v_self + [v_N]
    for level in yticks_ax2:
        ax.axhline(y=level, color='grey', alpha=0.3, linewidth=2)

    # Second y-axis on the same scale, labelling the reference lines.
    ax2 = ax.twinx()
    ax2.set_ylim(ax.get_ylim())
    ax2.set_yticks(yticks_ax2)
    ax2.set_yticklabels(['$v_1$', '$v_2$', '$v_3$', '$v_N$'], fontsize=25)
    ax2.tick_params(axis='y', labelsize=30)
    # Colour each v_i label like its party; zip stops before the v_N label.
    for (v_self_label, party) in zip(ax2.get_yticklabels(), parties):
        v_self_label.set_color(scalar_map[party].to_rgba(label_color_value))

def plot_diff(ax, df, values, method, y_label, colormap):
    """Plot the column ``y_label`` against ``$t$``, one line per hyperparameter.

    Args:
        ax: Matplotlib axes to draw on.
        df: Table with the columns ``$t$``, ``y_label`` and ``method``.
        values: Hyperparameter values to plot.
        method: Name of the hyperparameter column in ``df``.
        y_label: Column to plot; also used as the y-axis label.
        colormap: Colormap from which one shade per value is taken.
    """
    marker_cycle = ['o', 's', '^', '.', '>', 's', 'v', 'x']

    # Evenly spaced shades starting at 0.2, normalised against a lowered vmin.
    count = len(values)
    spacing = (1 - 0.2) / count
    shades = [(0.2 + k * spacing) for k in range(count)]
    norm = mcolors.Normalize(vmin=min(shades) - 0.5, vmax=1)
    shade_to_color = cm.ScalarMappable(norm=norm, cmap=colormap)

    for value, shade, marker in zip(values, shades, marker_cycle):
        df_new = df[df[method] == value].copy()
        ax.plot(
            df_new['$t$'],
            df_new[y_label],
            color=shade_to_color.to_rgba(shade),
            linewidth=4.0,
            markersize=13,
            marker=marker,
            label=f'{method}={value}',
        )

    ax.legend(ncol=2, prop={'size': 25}, loc='best')
    ax.set_xlabel('$t_1$', fontsize=40)
    # Ticks come from the rows of the last plotted hyperparameter value.
    ax.set_xticks(df_new['$t$'].unique())
    ax.set_ylabel(y_label, fontsize=40)
    for axis_name in ('x', 'y'):
        ax.tick_params(axis=axis_name, labelsize=30)

    # Dashed zero line, drawn on pyplot's current axes.
    plt.axhline(y=0, color='grey', alpha=0.3, linewidth=2, linestyle='dashed')

# Besides writing the figures, this call renames the reward columns of df in
# place to their starred form ($r_i$ -> $r_i^*$). get_reward_values_from_df
# below looks the columns up under the starred names.
plot_value(df, v_dual, gamma_list, beta_list, rootdir + 'plots/')

#@title Reward Realization
def normalize_acc(acc):
    """Return ``acc`` lowered by the fixed offset 0.1."""
    offset = 0.1
    return acc - offset

def get_dual_accuracy(train_set, non_i_indices, acc_all):
    """Build a function mapping a cut position to a dual accuracy.

    Args:
        train_set: Dataset from which the training subset is taken.
        non_i_indices: Indices of the samples not owned by party i.
        acc_all: Value returned when the cut position is the end of the list.

    Returns:
        ``dual_accuracy(pos)``. For ``pos == len(non_i_indices)`` it returns
        ``acc_all``. Otherwise it trains a model on ``non_i_indices[pos:]``
        and returns ``acc_all`` minus that model's normalized test accuracy.
        It raises ``IndexError`` when ``pos`` is outside
        ``0..len(non_i_indices)``.
    """
    def dual_accuracy(pos):
        if pos < 0 or pos > len(non_i_indices):
            # The backslash is escaped so the literal carries no invalid
            # escape sequence; the message text itself is unchanged.
            raise IndexError("R_i is not a subset of (D\\D_i)!")
        if pos == len(non_i_indices):  # marks the case R_i = D\D_i
            return acc_all
        indices = non_i_indices[pos:]
        subset = torch.utils.data.Subset(train_set, indices)
        model = train(subset)
        (acc, loss) = test(model)
        return acc_all - normalize_acc(acc)  # zero normalization for random models
    return dual_accuracy

def tol_fn(val):
    """Return the absolute tolerance used when matching a reward value.

    The tolerance is 5e-2 for values of at least 0.70, 35e-3 for values of at
    least 0.53 and 25e-3 below that.
    """
    if val >= 0.70:
        return 5e-2
    if val >= 0.53:
        return 35e-3
    return 25e-3

def get_reward_index(reward_value, dual_accuracy, start_pos, step=100, tol_fn=tol_fn):
    """Walk the cut position down from ``start_pos`` until the target is met.

    The position is lowered by ``step`` before every evaluation, so
    ``start_pos`` itself is never tested. The search ends at the first
    position whose dual accuracy is within ``tol_fn(reward_value)`` of
    ``reward_value``; there is no other stopping condition inside this
    function.

    Args:
        reward_value: Target dual accuracy.
        dual_accuracy: Callable mapping a position to a dual accuracy.
        start_pos: Position from which the search starts.
        step: Amount by which the position is lowered per iteration.
        tol_fn: Callable mapping the target to an absolute tolerance.

    Returns:
        The first position that matches the target.
    """
    tol = tol_fn(reward_value)
    total = 0
    pos = start_pos
    while True:
        pos = pos - step
        total += step
        val = dual_accuracy(pos)
        print(f'get_reward_index: {pos=} target={reward_value} {total=} {val=}')
        if math.isclose(reward_value, val, abs_tol=tol):
            break
    return pos

def get_reward_indices(reward_values, dual_accuracy, start_pos):
    """Return one cut position per reward value.

    The first value is assigned ``start_pos``. Each later value reuses the
    current position when it is within 0.001 of its predecessor; otherwise
    the position is moved with ``get_reward_index``, starting from the
    current one.

    Args:
        reward_values: Non-empty sequence of reward values in descending
            order. It is only read, never modified.
        dual_accuracy: Callable passed on to ``get_reward_index``.
        start_pos: Position assigned to the first reward value.

    Returns:
        List of positions, as long as ``reward_values``.

    Raises:
        IndexError: If ``reward_values`` is empty.
    """
    poss = [start_pos]
    current_pos = start_pos
    # Read the first value by index instead of pop(0), which removed it from
    # the caller's list as a side effect.
    current_val = reward_values[0]
    for val in reward_values[1:]:
        if not math.isclose(val, current_val, abs_tol=0.001):
            current_pos = get_reward_index(val, dual_accuracy, current_pos)
        current_val = val
        poss.append(current_pos)
    return poss

def realize_reward(train_set, D_i_indices, non_i_indices, pos):
    """Train on party i's data plus the first ``pos`` samples of the others.

    Args:
        train_set: Dataset from which the training subset is taken.
        D_i_indices: Indices owned by party i (left unmodified).
        non_i_indices: Indices of the remaining samples.
        pos: Number of leading entries of ``non_i_indices`` to add.

    Returns:
        Tuple ``(num_data, acc, loss)``: the subset size and the test
        accuracy and loss of the trained model.
    """
    indices = [*D_i_indices, *non_i_indices[:pos]]
    reward_data = torch.utils.data.Subset(train_set, indices)
    acc, loss = test(train(reward_data))
    return (len(indices), acc, loss)

def get_reward_values_from_df(df, party_index, method, hyper):
    """Return a party's reward values for one hyperparameter setting.

    Args:
        df: Table whose reward columns carry the starred names
            (``$r_1^*$`` and so on).
        party_index: Party number used in the reward column name.
        method: Either ``'gamma'`` or ``'beta'``.
        hyper: Hyperparameter value whose rows are selected.

    Returns:
        List of the matching reward values in row order.
    """
    assert method in ('gamma', 'beta'), "Unknown method!"
    reward_column = fr'$r_{party_index}^*$'
    hyper_column = fr'$\{method}$'
    matching_rows = df[hyper_column] == hyper
    return df.loc[matching_rows, reward_column].tolist()

# party index in descending order: [3, 2, 1]
def init_realization(df, party_list, p_indices, train_set, acc_all):
    """Prepare (or load from cache) the per-party state for realization.

    For every party in ``party_list`` this collects its own indices, a
    shuffled list of everyone else's indices, a dual-accuracy function over
    that list, the party's first reward value for gamma = 0, and a start
    position. The first party starts at the end of its list; the others are
    located with ``get_reward_index`` (step 500, tolerance 25e-3).

    The result is cached with dill in ``rootdir + 'rewards/init.pkl'``; the
    directory is created when missing.

    Args:
        df: Reward table with starred column names.
        party_list: 1-based party numbers, in descending order.
        p_indices: One list of ``train_set`` indices per party.
        train_set: Dataset used for training.
        acc_all: Accuracy value passed to ``get_dual_accuracy``.

    Returns:
        Tuple ``(D_i_indices_list, non_i_indices_list, dual_accuracy_list,
        start_val_list, start_pos_list)``, each ordered like ``party_list``.
    """
    cache_path = rootdir + 'rewards/init.pkl'
    if os.path.exists(cache_path):
        # NOTE: dill.load can execute arbitrary code; only load a cache file
        # that was produced by this script.
        with open(cache_path, 'rb') as f:
            return dill.load(f)
    # Create the output directory before the expensive search below, so a
    # missing directory cannot make the final dump fail after all the work.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)

    D_i_indices_list = []
    non_i_indices_list = []
    dual_accuracy_list = []
    start_val_list = []
    start_pos_list = []

    for party in party_list:
        D_i_indices_list.append(p_indices[party - 1])
        non_i_indices = []
        for (i, indices) in enumerate(p_indices):
            if party - 1 != i:
                non_i_indices.extend(indices)
        random.shuffle(non_i_indices)
        non_i_indices_list.append(non_i_indices)
        dual_accuracy_list.append(get_dual_accuracy(train_set, non_i_indices, acc_all))
        start_val_list.append(get_reward_values_from_df(df, party, 'gamma', 0)[0])

    for i, start_val in enumerate(start_val_list):
        if i == 0:
            # The first party starts from the full list of others' samples.
            start_pos = len(non_i_indices_list[i])
        else:
            start_pos = get_reward_index(
                start_val,
                dual_accuracy_list[i],
                len(non_i_indices_list[i]),
                step=500,
                tol_fn=lambda val: 25e-3)
        start_pos_list.append(start_pos)

    init = (D_i_indices_list, non_i_indices_list, dual_accuracy_list, start_val_list, start_pos_list)
    with open(cache_path, 'wb') as f:
        dill.dump(init, f)
    return init

def realize_reward_one_party(party, gamma_list, beta_list, df, train_set, init):
    """Realize every reward of one party and store it in ``rewards.pkl``.

    For each gamma and then each beta, the party's reward values are turned
    into cut positions and a model is trained for each position. The results
    are stored under the keys ``'gamma=<value>'`` / ``'beta=<value>'`` in the
    party's entry of the pickled ``rewards`` dict, which must already exist
    on disk.

    Args:
        party: Party number (1, 2 or 3).
        gamma_list: Gamma values to realize.
        beta_list: Beta values to realize.
        df: Reward table with starred column names.
        train_set: Dataset used for training.
        init: Tuple returned by ``init_realization``.
    """
    # Position of each party inside the lists of `init` (party_list = [3, 2, 1]).
    slot_of_party = {3: 0, 2: 1, 1: 2}
    i = slot_of_party[party]
    with open(rootdir + 'rewards/rewards.pkl', 'rb') as f:
        rewards = pickle.load(f)
    (D_i_indices_list, non_i_indices_list, dual_accuracy_list, _start_val_list, start_pos_list) = init
    D_i_indices = D_i_indices_list[i]
    non_i_indices = non_i_indices_list[i]
    dual_accuracy = dual_accuracy_list[i]
    start_pos = start_pos_list[i]
    print(f'starting party {party}')
    rewards_one_party = {}

    def realize_all(method, hypers):
        # Shared body of the gamma and beta passes.
        for hyper in hypers:
            reward_values = get_reward_values_from_df(df, party, method, hyper)
            reward_indices = get_reward_indices(reward_values, dual_accuracy, start_pos)
            rewards_one_party[f'{method}={hyper}'] = [
                realize_reward(train_set, D_i_indices, non_i_indices, pos)
                for pos in reward_indices
            ]

    realize_all('gamma', gamma_list)
    print(f'finish gamma for party {party}')
    realize_all('beta', beta_list)
    print(f'finish beta for party {party}')
    rewards[party] = rewards_one_party
    with open(rootdir + 'rewards/rewards.pkl', 'wb') as f:
        pickle.dump(rewards, f)

# Parties are processed in this order; realize_reward_one_party maps party
# numbers to list positions assuming exactly this order.
party_list = [3, 2, 1]
acc_all = max(v.values())
# One-off creation of the rewards file that realize_reward_one_party reads
# and updates. Uncomment for the first run.
# rewards = {party: None for party in party_list}
# with open(rootdir + 'rewards/rewards.pkl', 'wb') as f:
#     pickle.dump(rewards, f)
init = init_realization(df, party_list, p_indices, train_set, acc_all)

# Each call reloads rewards.pkl, fills in one party and writes the file back.
realize_reward_one_party(3, gamma_list, beta_list, df, train_set, init)
realize_reward_one_party(2, gamma_list, beta_list, df, train_set, init)
realize_reward_one_party(1, gamma_list, beta_list, df, train_set, init)

#@title Rewards Plot
# Position of each metric inside the (num_data, acc, loss) tuples returned by
# realize_reward, mapped to its axis label and to its file-name fragment.
metric_to_ylabel = {0: '# data points', 1: 'Accuracy', 2: 'Loss'}
metric_to_filename = {0: 'num', 1: 'acc', 2: 'loss'}

def get_plot_data(rewards, method, value, metric, party):
    """Return entry ``metric`` of every result stored under ``'<method>=<value>'``.

    Args:
        rewards: Dict mapping a party to a dict of result lists.
        method: Method name used in the key.
        value: Hyperparameter value used in the key.
        metric: Index into each result.
        party: Key of the party in ``rewards``.
    """
    results = rewards[party][fr'{method}={value}']
    return [result[metric] for result in results]

def plot_reward(ax, rewards, values, method, metric):
    """Plot one realized metric for the three parties on ``ax``.

    One line is drawn per (hyperparameter value, party): the party selects
    the colormap and linestyle, the hyperparameter value selects the marker
    and the shade. The x positions are 0..len(data)-1.

    Args:
        ax: Matplotlib axes to draw on.
        rewards: Dict read through ``get_plot_data``.
        values: Hyperparameter values to plot; at least two are needed,
            because the legend colour is taken from the second shade.
        method: Method name, e.g. ``'gamma'`` or ``'beta'``.
        metric: Key into ``metric_to_ylabel`` and index into each result.
    """
    markers = ['o', 's', '^', '.', '>', 's', 'v', 'x']
    parties = [1, 2, 3]
    party_labels = ['Party 1', 'Party 2', 'Party 3']
    linestyles = ['-', 'dotted', 'dashed']
    colormaps = [plt.cm.Blues, plt.cm.Oranges, plt.cm.Greens]
    markersize = 13
    linewidth = 4.0

    # Evenly spaced shades starting at 0.2, normalised against a lowered vmin.
    n = len(values)
    del_ = (1 - 0.2) / n
    map_values = [(0.2 + i * del_) for i in list(range(n))]
    vmin = min(map_values) - 0.5
    vmax = 1
    color_norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    scalar_map = {}
    for (party, colormap) in zip(parties, colormaps):
        scalar_map[party] = cm.ScalarMappable(norm=color_norm, cmap=colormap)
    for (value, map_value, marker) in zip(values, map_values, markers):
        # The colormap from this zip is not used; colours come from scalar_map.
        for (party, colormap, linestyle) in zip(parties, colormaps, linestyles):
            data = get_plot_data(rewards, method, value, metric, party)
            times = list(range(len(data)))
            ax.plot(times, data,
                    color=scalar_map[party].to_rgba(map_value),
                    markersize=markersize,
                    linewidth=linewidth,
                    marker=marker,
                    linestyle=linestyle,
                    )

    # Legend handles: black marker entries per hyperparameter value and
    # coloured linestyle entries per party.
    party_legends = []
    hyper_legends = []
    for (value, marker) in zip(values, markers):
        hyper_legends.append(Line2D([], [],
                             color='black',
                             linewidth=linewidth,
                             markersize=markersize,
                             marker=marker,
                             label=fr'$\{method}$={value}'))

    label_color_value = map_values[1]
    for (party, linestyle, party_label) in zip(parties, linestyles, party_labels):
        party_legends.append(Line2D([], [],
                              color=scalar_map[party].to_rgba(label_color_value),
                              linewidth=linewidth,
                              linestyle=linestyle,
                              label=party_label))
    # Interleave party and hyperparameter handles pairwise; zip stops at the
    # shorter of the two lists.
    legends = []
    for (party_legend, hyper_legend) in zip(party_legends, hyper_legends):
        legends.extend([party_legend, hyper_legend])

    # The legend is placed below the axes.
    ax.legend(handles=legends, ncol=3, prop={'size': 23.5}, handlelength=1, loc='upper center', bbox_to_anchor=(0.5, -0.19))
    ax.set_xlabel('$t_1$', fontsize=40)
    ax.set_ylabel(f'{metric_to_ylabel[metric]}', fontsize=40)
    # `times` is whatever the last plotted line left behind in the loop above.
    ax.set_xticks(times)
    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', labelsize=30)

def plot_mnist(rewards, gamma_list, beta_list, plotdir, format='png'):
    """Save one figure per (method, metric) pair into ``plotdir``.

    Files are named ``<method>_<metric>.<format>`` with the metric fragment
    taken from ``metric_to_filename``.

    Args:
        rewards: Realized rewards, as read by ``plot_reward``.
        gamma_list: Hyperparameter values plotted for the gamma method.
        beta_list: Hyperparameter values plotted for the beta method.
        plotdir: Output directory; created when missing.
        format: Image format, also used as the file extension.
    """
    os.makedirs(plotdir, exist_ok=True)
    figsize = (9, 7.5)
    hypers_by_method = {'gamma': gamma_list, 'beta': beta_list}
    for method, hyper in hypers_by_method.items():
        for metric in (0, 1, 2):  # 0: num_data, 1: acc, 2: loss
            fig, ax = plt.subplots(figsize=figsize)
            plot_reward(ax, rewards, method=method, values=hyper, metric=metric)
            target = os.path.join(plotdir, f'{method}_{metric_to_filename[metric]}.{format}')
            plt.tight_layout()
            plt.savefig(target, format=format, bbox_inches='tight', dpi=300)
            plt.close(fig)

# Reload the realized rewards written by realize_reward_one_party and plot
# them. NOTE: pickle.load can execute arbitrary code; only load files
# produced by this script.
with open(rootdir + 'rewards/rewards.pkl', 'rb') as f:
    rewards = pickle.load(f)
plot_mnist(rewards, gamma_list, beta_list, rootdir + 'plots/')