import sys
sys.path.append("../")
import json
# from utils import *
import matplotlib.pyplot as plt
import matplotlib.cm as cm  # Import the color map module

# A single row of four side-by-side axes that hold the loss-curve panels.
fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(10, 1.7))

from DataLoader.test_nasbench_201 import *
from LCBench.api import Benchmark
# Choose which NAS-Bench-201 task supplies the curves for the first two panels.
# Supported values: 'ImageNet', 'Cifar100', 'Cifar10'.
dataset = 'ImageNet'
if dataset == 'ImageNet':
    Bench = ImageNetNasBench201Benchmark(rng=1)
elif dataset == 'Cifar100':
    Bench = Cifar100NasBench201Benchmark(rng=1)
elif dataset == 'Cifar10':
    Bench = Cifar10ValidNasBench201Benchmark(rng=1)
else:
    # Fail fast. Otherwise `Bench` stays unbound and the script crashes much
    # later with a NameError that does not point at the typo.
    raise ValueError(
        f"Unknown NAS-Bench-201 dataset {dataset!r}; "
        "expected 'ImageNet', 'Cifar100' or 'Cifar10'"
    )

# LCBench source: the benchmark JSON dump on disk, plus the task name used
# when sampling and querying configurations from it.
bench_dir = "../LCBench/data/fashion_mnist.json"
LCBench = Benchmark(bench_dir, cache=False)
dataset_name = 'Fashion-MNIST'

# How many epochs to query per NAS-Bench-201 configuration, and how many
# (NAS-Bench-201, LCBench) configuration pairs to sample.
max_epoch = 200
sample_count = 1000
config_space = Bench.get_configuration_space(seed=1)
save_dict = dict()

for sample_idx in range(sample_count):
    nb201_config = config_space.sample_configuration()
    lcb_config = LCBench.sample_config(dataset_name)

    # NAS-Bench-201 is queried once per epoch (fidelity = epoch count) to
    # build the full curve for epochs 1..max_epoch.
    valid_losses, train_losses = [], []
    for epoch in range(1, max_epoch + 1):
        info = Bench.objective_function(
            configuration=nb201_config,
            fidelity={'epoch': epoch},
            data_seed=777,
        )['info']
        valid_losses.append(info['valid_losses'])
        train_losses.append(info['train_losses'])

    # LCBench returns a whole logged curve per query.
    lcb_train_losses = LCBench.query(dataset_name=dataset_name, tag="Train/train_cross_entropy", config_id=lcb_config)
    lcb_valid_losses = LCBench.query(dataset_name=dataset_name, tag="Train/val_cross_entropy", config_id=lcb_config)

    save_dict[sample_idx] = {
        'train_losses': train_losses,
        'valid_losses': valid_losses,
        'lcb_train_losses': lcb_train_losses,
        'lcb_valid_losses': lcb_valid_losses,
    }

# Cache the collected curves so the plotting step can reload them.
with open('save_dict.json', 'w') as out_file:
    json.dump(save_dict, out_file)

# Reload the cached curves. JSON round-trips the integer sample keys as strings.
with open('save_dict.json', 'r') as file:
    save_dict = json.load(file)

# Draw up to 10 curves per panel, or fewer if the cache holds fewer samples
# (avoids a KeyError on save_dict[str(s)]).
num_lines = min(10, len(save_dict))
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9;
# `pyplot.get_cmap(name, lut)` is the supported equivalent.
color_palette = plt.get_cmap('tab20', num_lines)
line_colors = color_palette(range(num_lines))

# Series per panel, left to right. Each benchmark shows validation first and
# training second, matching the y-labels assigned below. (The LCBench panels
# were previously swapped relative to their labels.)
panel_keys = ['valid_losses', 'train_losses', 'lcb_valid_losses', 'lcb_train_losses']
for s in range(num_lines):
    sample = save_dict[str(s)]
    for ax, key in zip(axs, panel_keys):
        ax.plot(sample[key], color=line_colors[s])

# Derive panel titles from the configured tasks instead of hard-coding ImageNet.
nb201_display_names = {'ImageNet': 'ImageNet-16-120', 'Cifar100': 'CIFAR-100', 'Cifar10': 'CIFAR-10'}
nb201_title = nb201_display_names.get(dataset, dataset)
panel_ylabels = ['Validation loss', 'Training loss', 'Validation loss', 'Training loss']
panel_titles = [nb201_title, nb201_title, dataset_name, dataset_name]
for ax, ylabel, title in zip(axs, panel_ylabels, panel_titles):
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)

plt.subplots_adjust(left=0.054, right=0.99, top=0.85, bottom=0.26, wspace=0.28)
fig.savefig('bench_overview.png')
fig.savefig('bench_overview.pdf')