import matplotlib.pyplot as plt
import numpy as np
import mpltex

# Force arrays for the displacement scan, one per model, read once at import
# time and consumed by plot_force. The directory is relative, so it is
# resolved against the current working directory.
_DATA_DIR = "../appendix/force_vs_displacement/data"
forces = np.load(f"{_DATA_DIR}/forces_pet-oam-c.npy")
forces_nc = np.load(f"{_DATA_DIR}/forces_pet-oam-nc.npy")
forces_c_as_nc = np.load(f"{_DATA_DIR}/forces_pet-oam-c-as-nc.npy")


@mpltex.acs_decorator
def plot_force(frac=0.3):
    """Plot the force magnitude on atom 1 versus its displacement and save it.

    Three force sets are drawn, in legend order: ``forces`` ("C"),
    ``forces_c_as_nc`` ("NC (trained as C)") and ``forces_nc`` ("NC").
    The figure is written to
    ``../figures/forces_close_to_equilibrium.pdf`` and ``.svg``.

    Parameters
    ----------
    frac : float, optional
        Width scale factor: the figure width is
        ``2 * frac * plt.rcParams["figure.figsize"][0]``; the height is the
        rcParams default height.

    Returns
    -------
    tuple
        ``(fig, ax)`` as created by ``plt.subplots``.
    """
    fs = plt.rcParams["figure.figsize"]
    fig, ax = plt.subplots(dpi=150, figsize=(2 * fs[0] * frac, fs[1]))

    # (force array, legend label, line colour), in the order they should
    # appear in the legend.
    series = (
        (forces, "C", "C0"),
        (forces_c_as_nc, "NC (trained as C)", "black"),
        (forces_nc, "NC", "C1"),
    )
    for data, label, color in series:
        # Norm of the force on the first atom (index 0, "atom 1" in the axis
        # labels) at each displacement step.
        # Assumes data is (n_steps, n_atoms, 3) — TODO confirm.
        magnitude = np.linalg.norm(data[:, 0], axis=1)
        # Build the displacement grid from the actual number of samples, not a
        # hard-coded 100, so arrays of another length don't raise a shape
        # mismatch. Assumes the scan is uniform over [-0.15, 0.15] Å — confirm
        # against the data-generation script.
        dx = np.linspace(-0.15, 0.15, len(magnitude))
        ax.plot(dx, magnitude, label=label, color=color)

    # Dashed reference lines through zero force / zero displacement.
    ax.axhline(0, color="0.2", ls="--")
    ax.axvline(0, color="0.2", ls="--")
    ax.set_xlabel(r"Displacement of atom 1 [$\mathrm{\AA}$]")
    ax.set_ylabel(r"$|\mathbf{f}_1|$ [eV/$\mathrm{\AA}$]")
    ax.legend(frameon=True, framealpha=1)

    # Zoom in on the region around equilibrium (the data spans ±0.15 Å).
    ax.set_xlim(-0.1, 0.1)
    ax.set_ylim(0, 0.65)

    # Panel label in the top-left corner, in axes coordinates.
    ax.text(0.02, 0.98, "a", transform=ax.transAxes, ha="left", va="top")

    stem = "../figures/forces_close_to_equilibrium"
    fig.savefig(f"{stem}.pdf", dpi=1200, bbox_inches="tight")
    fig.savefig(f"{stem}.svg", dpi=1200, transparent=True)

    return fig, ax


# Build the figure (width factor 0.4) and write the PDF/SVG outputs on run.
plot_force(0.4)
