"""This scripts plot figures showing how loss of AAE change with increasing circuit depth"""

import json
import sys

import matplotlib.pyplot as plt

################# Matplotlib Global Conf #########################
# Base font size for the figure: axis labels use it directly, tick labels
# are rendered two points smaller.
fontsize = 25

plt.rcParams.update(
    {
        # Route all text rendering through LaTeX.
        "text.usetex": True,
        "xtick.labelsize": fontsize - 2,
        "ytick.labelsize": fontsize - 2,
        "axes.labelsize": fontsize,
        "axes.labelweight": "bold",
    }
)
################# Matplotlib Global Conf #########################

# Expect exactly one argument: the path to the JSON results file.
# Usage errors go to stderr so they do not mix with regular stdout output.
if len(sys.argv) != 2:
    print(
        f"Usage: python {sys.argv[0]} <test-result-json> # e.g., logs/eval/aae_depth/4qubits.json",
        file=sys.stderr,
    )
    sys.exit(1)

res_json_path = sys.argv[1]

# The file is a JSON list of records; each record is read for the keys
# "num_encoder_layers", "fidelity" and "n_qubits".
with open(res_json_path, "r", encoding="utf-8") as f:
    res_list = json.load(f)

# Fail with a clear message instead of an IndexError on num_qubit_list[0].
if not res_list:
    sys.exit(f"No results found in {res_json_path}")

# Sort by depth so the line connects points left-to-right even when the
# records are stored out of order (stable sort keeps ties in file order).
res_list = sorted(res_list, key=lambda item: item["num_encoder_layers"])

num_layer_list = [item["num_encoder_layers"] for item in res_list]
fidelity_list = [item["fidelity"] for item in res_list]
num_qubit_list = [item["n_qubits"] for item in res_list]

# The output filename and the printed summary use only the first record's
# qubit count, so warn if the file mixes several qubit counts.
if any(n != num_qubit_list[0] for n in num_qubit_list):
    print(
        f"Warning: results contain multiple qubit counts {sorted(set(num_qubit_list))}; "
        f"plotting them as a single series.",
        file=sys.stderr,
    )

print(f"Number of qubits: {num_qubit_list[0]}\n")

plt.figure(figsize=(10, 5))
plt.xlabel("Number of Blocks")
plt.ylabel("Fidelity")
plt.xticks(num_layer_list)
plt.plot(num_layer_list, fidelity_list, linewidth=2, marker="X")
plt.tight_layout()
plt.grid()
plt.savefig(f"aae_fid_depth_{num_qubit_list[0]}_qubits.pdf")
