import matplotlib.pyplot as plt
import numpy as np

################# Matplotlib Global Conf #########################
# Global plotting style shared by every figure drawn in this script:
# LaTeX-rendered text, tick labels two points smaller than the axis labels,
# and bold axis labels.
fontsize = 16

plt.rcParams.update(
    {
        "text.usetex": True,
        "xtick.labelsize": fontsize - 2,
        "ytick.labelsize": fontsize - 2,
        "axes.labelsize": fontsize,
        "axes.labelweight": "bold",
    }
)
# Optional, currently disabled: "xtick.major.pad": -1 and
# "ytick.major.pad": -1 (move tick labels closer to the axis).
# NOTE: matplotlib has no "ztick.labelsize" rcParam, so 3-D z-tick label
# sizes cannot be configured globally here.
################# Matplotlib Global Conf #########################


# Loss logs to compare; each is read with read_data (one float per line).
state_mse_loss_file = "./logs/superencoder/v0.0.7.7/loss.txt"
state_fidelity_loss_file = "./logs/superencoder/v0.0.7/loss.txt"


def read_data(file_name):
    """Read a text file containing one number per line into a list of floats.

    Whitespace-only lines (e.g. a trailing empty line at the end of a log)
    are skipped instead of aborting the whole load with ``ValueError``.

    Args:
        file_name: Path to the text file to read.

    Returns:
        list[float]: The parsed values, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line cannot be parsed by ``float``.
    """
    with open(file_name, "r", encoding="utf-8") as f:
        # Strip each line once, then drop the empty results before parsing.
        stripped_lines = (line.strip() for line in f)
        return [float(text) for text in stripped_lines if text]


# Load both loss logs (MSE run first, fidelity run second) as NumPy arrays.
state_mse_loss, state_fidelity_loss = (
    np.array(read_data(path))
    for path in (state_mse_loss_file, state_fidelity_loss_file)
)

# Divide each curve by its own maximum so the two runs share a common scale.
state_mse_loss = state_mse_loss / state_mse_loss.max()
state_fidelity_loss = state_fidelity_loss / state_fidelity_loss.max()

# Trim both curves to the number of steps in the shorter run.
max_iter = min(map(len, (state_mse_loss, state_fidelity_loss)))
state_mse_loss, state_fidelity_loss = (
    state_mse_loss[:max_iter],
    state_fidelity_loss[:max_iter],
)

# Overlay the two curves on one figure and write it to a PDF.
plt.figure()
plt.plot(state_mse_loss, color="orange", label=r"state\_MSE")
plt.plot(state_fidelity_loss, color="green", label=r"state\_fidelity")
plt.legend()
plt.xlabel(r"Step")
plt.ylabel(r"Loss")
plt.savefig("fractaldb_loss_comparison.pdf")
