import os

import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

from plot_utils import *


# Plot settings. ``config`` is not defined in this file; presumably it is
# provided by ``from plot_utils import *``. Commented-out lines are settings
# for other experiment runs -- switch limits, x-label and results file together.

# config['xlim'] = (2, 20.)
# config['ylim'] = (70., 105.)

config['ylim'] = (40., 105.)
config['xlim'] = (1., 4.)

# config['ylim'] = (0., 105.)
# config['xlim'] = (2., 32.)
# config['xlim'] = (4., 32.)

# The x values are log10 of the data size, so use the upright \log operator
# with its base; mathtext typesets ``$log$`` as the italic product l*o*g.
config['xlabel'] = r'$\log_{10}$ Data Size'
# config['xlabel'] = 'Horizon'
# config['xlabel'] = 'Num Trajectories'
# config['xlabel'] = 'Feature Dim'
config['ylabel'] = 'Relative Performance'
# NOTE(review): read by plot_utils; confirm that 1 means "no smoothing".
config['smooth_range'] = 1

# Raw results: channels 0 and 1 of the last axis are the two query
# strategies, channel 2 is the reference they are normalized by later.
# result = np.load("./results_num_chain_2.npy")
# result = np.load("./results_horizon.npy")
result = np.load("./results.npy")
# result = np.load("./results_feature_dim.npy")


def data_process(xs, paths):
    """Summarise repeated runs into mean/std curves for plotting.

    Args:
        xs: x-axis values; returned unchanged.
        paths: 2-D array-like of scores. Statistics are taken over axis 1,
            so axis 0 is expected to line up with ``xs``. Plain nested lists
            are accepted as well as ndarrays (and ndarray subclasses such as
            masked arrays, which are kept as-is).

    Returns:
        Tuple ``(xs, mean, std, runs)`` where ``mean`` and ``std`` are the
        per-row mean and population standard deviation (``ddof=0``), and
        ``runs`` is ``paths`` transposed so that each row is one run.
    """
    # Coerce lists to arrays so .T works; asanyarray avoids dropping masks.
    paths = np.asanyarray(paths)
    mean = np.mean(paths, axis=1)
    std = np.std(paths, axis=1)
    return xs, mean, std, paths.T


# x-axis values: data sizes on a log10 scale.
x = np.log10([10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000])
# Alternative x grids for the other experiment runs; switch together with the
# matching results file, axis config and output path.
# x = np.array([2, 4, 8, 12, 16, 20, 24, 28, 32])
# x = np.array([4, 8, 12, 16, 20, 24, 28, 32])
# x = np.array([2, 4, 6, 8, 10, 12, 14, 16, 18, 20])

# Express both strategies as a percentage of channel 2 of the last axis
# (the reference score -- presumably the optimal return; confirm).
reference = result[:, :, 2]
normalized_result_1 = result[:, :, 0] / reference * 100
normalized_result_2 = result[:, :, 1] / reference * 100

# Each is the (xs, mean, std, per-run curves) tuple built by data_process.
y1 = data_process(x, normalized_result_1)
y2 = data_process(x, normalized_result_2)
legends = ["Random Query", "Random Weighted Query"]

plot_all([y1, y2], legends, 1)

plt.title('Performance in Chain MDP', size=30)
legend()
fig = plt.gcf()
fig.set_size_inches(8, 7)

out_file = "./plot/random_weight.pdf"
# savefig does not create missing directories, so ensure the folder exists.
out_dir = os.path.dirname(out_file)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
plt.savefig(out_file, format="pdf")
# Output files for the other experiment runs:
# plt.savefig("./plot/random_weight_num_chain.pdf", format="pdf")
# plt.savefig("./plot/random_weight_horizon.pdf", format="pdf")
# plt.savefig("./plot/random_weight_feature_dim.pdf", format="pdf")

plt.show()
