import argparse
import pickle

import matplotlib.pyplot as plt
import numpy as np


def get_args(argv=None):
    """Parse the command-line options selecting which results file to plot.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``regime`` (one of "Stochastic",
        "StochasticWithCorruption", "Adversary") and ``time_horizon`` (int).
    """
    parser = argparse.ArgumentParser(description="argparse script")
    parser.add_argument(
        "-regime",
        "--regime",
        type=str,
        choices=["Stochastic", "StochasticWithCorruption", "Adversary"],
        default="Stochastic",
        help="Which regime to consider.",
    )
    parser.add_argument(
        "-time_horizon",
        "--time_horizon",
        type=int,
        default=2000,
        # Previously mislabelled as "The value of sigma."; this value only
        # selects which results file (and output figure) name is used.
        help="Time horizon of the experiment whose results are plotted.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = get_args()
    cost_min, cost_max = 0.1, 0.5
    file_name = "results/regret_result_{}_{}_{}_{}.txt".format(
        args.regime,
        cost_min,
        cost_max,
        args.time_horizon,
    )

    algorithms_list = [
        "GenLBINFV (LS)",
        "GenLBINFV (GD)",
        "LBINFV (LS)",
        "LBINFV (GD)",
        "GenCTS",
        "CTS",
    ]

    with open(file_name, "rb") as f:
        data = pickle.load(f)
        data_length = len(data)
        print(data_length)
        for algorithm in algorithms_list:
            time_horizon = len(data[0][algorithm])
            x = np.arange(time_horizon)
            y = np.zeros((time_horizon, data_length))
            for num_exp in range(data_length):
                data_algorithm = data[num_exp][algorithm]
                for t in range(time_horizon):
                    y[t][num_exp] = data_algorithm[t]
            mean = np.mean(y, axis=1)
            std = np.std(y, axis=1)

            ax = plt.subplot(111)

            (line,) = ax.plot(x, mean, label=algorithm)
            color = line.get_color()
            ax.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
            ax.legend()
        ax.set_title(args.regime)
        plt.savefig(
            "results/regret_result_{}_{}_{}_{}.pdf".format(
                args.regime,
                cost_min,
                cost_max,
                args.time_horizon,
            )
        )
        plt.show()
