import matplotlib.pyplot as plt
import numpy as np

# ===================== Matplotlib global configuration =====================
# Base font size: axis labels use it as-is, tick labels are two points smaller.
fontsize = 25

plt.rcParams.update(
    {
        "text.usetex": True,  # typeset text with LaTeX
        "xtick.labelsize": fontsize - 2,
        "ytick.labelsize": fontsize - 2,
        "axes.labelsize": fontsize,
        "axes.labelweight": "bold",
    }
)
# Options that were tried and left disabled:
#   "ztick.labelsize": fontsize - 2
#   "xtick.major.pad": -1
#   "ytick.major.pad": -1
# ================== End of Matplotlib global configuration ==================


# Input loss logs, one value per line. The paths are relative, so they are
# resolved against the current working directory. In plotting order they
# become the curves labelled L_1, L_2 and L_3.
(mse_loss_file, state_mse_loss_file, state_fidelity_loss_file) = (
    "./logs/superencoder/MSE_landscape/loss.txt",  # -> L_1
    "./logs/superencoder/v0.0.7.11/loss.txt",  # -> L_2
    "./logs/superencoder/v0.0.7.10/loss.txt",  # -> L_3
)


def read_data(file_name):
    """Read a text log holding one floating-point value per line.

    Blank (whitespace-only) lines are skipped. Without this, a stray empty
    line, such as an extra newline at the end of the log, would abort parsing
    of an otherwise valid file.

    Args:
        file_name: Path of the text file to read.

    Returns:
        list[float]: The parsed values, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line is not a valid float literal.
    """
    with open(file_name, "r", encoding="utf-8") as f:
        # float() itself tolerates surrounding whitespace and the trailing
        # newline, so only fully blank lines need filtering.
        return [float(line) for line in f if line.strip()]


# Load the three loss curves (each log holds one float per line) as arrays.
mse_loss, state_mse_loss, state_fidelity_loss = (
    np.array(read_data(log_path))
    for log_path in (mse_loss_file, state_mse_loss_file, state_fidelity_loss_file)
)

# Optional normalization, currently disabled. Dividing each curve by its own
# maximum would make every curve's largest value equal to 1.
# mse_loss = mse_loss / np.max(mse_loss)
# state_mse_loss = state_mse_loss / np.max(state_mse_loss)
# state_fidelity_loss = state_fidelity_loss / np.max(state_fidelity_loss)

# Trim every curve to the shortest one so all three cover the same steps.
max_iter = min(map(len, (mse_loss, state_mse_loss, state_fidelity_loss)))
mse_loss, state_mse_loss, state_fidelity_loss = (
    curve[:max_iter] for curve in (mse_loss, state_mse_loss, state_fidelity_loss)
)

# Overlay the three (equal-length) curves on one 6x4 figure, in a fixed
# order and colour per loss label.
plt.figure(figsize=(6, 4))
for curve, label, color in (
    (mse_loss, r"$\mathcal{L}_1$", "blue"),
    (state_mse_loss, r"$\mathcal{L}_2$", "orange"),
    (state_fidelity_loss, r"$\mathcal{L}_3$", "green"),
):
    plt.plot(curve, label=label, color=color)

# Legend and tick labels use a smaller size than the rcParams tick setting.
small_fontsize = fontsize - 6
plt.legend(fontsize=small_fontsize)
plt.ylabel(r"Loss")
plt.xlabel(r"Step")
# Fixed tick positions. They were presumably chosen for roughly 800 steps and
# losses within [0, 1]; revisit them if the logs change.
plt.xticks([0, 200, 400, 600, 800], fontsize=small_fontsize)
plt.yticks([0, 0.25, 0.5, 0.75, 1], fontsize=small_fontsize)
plt.tight_layout()
plt.savefig("mnist_loss_comparison.pdf")
