from pathlib import Path
import numpy as np
import ase.io
from ase.filters import FrechetCellFilter

import matplotlib.pyplot as plt
import mpltex

# Make sure the output folder for the saved figures exists. The path is
# resolved against the current working directory, not this file's location.
(Path("..") / "figures").mkdir(exist_ok=True)


def binned_xy(x, y, nbins=20, lo=32, hi=68):
    """Summarise ``y`` as a function of ``x`` using equal-population bins.

    The bin edges are the quantiles of ``x``, so every bin holds roughly the
    same number of samples. Bins that end up empty (possible when ``x`` has
    many tied values) are skipped.

    Parameters
    ----------
    x, y : numpy.ndarray
        Paired samples of equal length; both are indexed with a boolean mask.
    nbins : int
        Number of quantile bins.
    lo, hi : float
        Percentiles (0-100) of ``y`` reported as the lower/upper band of each
        bin.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x_median, y_median, y_lo, y_hi)``, one entry per non-empty bin.
    """
    edges = np.quantile(x, np.linspace(0, 1, nbins + 1))
    # Only the interior edges are handed to digitize, so the bin index runs
    # from 0 to nbins - 1 and the extreme samples fall into the end bins.
    bin_of = np.digitize(x, edges[1:-1])

    x_mid, y_mid, y_low, y_high = [], [], [], []
    for b in range(nbins):
        mask = bin_of == b
        if not mask.any():
            continue
        ys = y[mask]
        x_mid.append(np.median(x[mask]))
        y_mid.append(np.median(ys))
        y_low.append(np.percentile(ys, lo))
        y_high.append(np.percentile(ys, hi))

    return tuple(map(np.array, (x_mid, y_mid, y_low, y_high)))


# Line colour (matplotlib colour-cycle name) and legend title per run label.
colors = dict(C_f="C0", NC_f="C1", C_c="C2", NC_c="C3")
titles = dict(
    C_f="C, free",
    NC_f="NC, free",
    C_c="C, constr.",
    NC_c="NC, constr.",
)

# Retrieve and prepare data

# Cell lengths and generalized forces from the trajectories
# (path is relative to the current working directory).
trajs = Path("../geometry_optimization/trajs")


def _final_cell_lengths(path):
    """Return the cell-vector lengths of the last frame in trajectory ``path``.

    The reader is used as a context manager so its file handle is closed as
    soon as the frame has been read.
    """
    with ase.io.trajectory.Trajectory(path) as traj:
        return traj[-1].cell.lengths()


# Reference cell lengths per lattice: the final frame of the
# "traj_constrained_<lattice>_ti_c" relaxation.
eq_cell_length = {
    lattice: _final_cell_lengths(trajs / f"traj_constrained_{lattice}_ti_c.traj")
    for lattice in ["bcc", "hcp"]
}

cell_lengths = {}  # label -> lattice -> (n_frames, 3) array of cell lengths
fmax = {}  # label -> lattice -> (n_frames,) array of force-convergence values

# Run labels, in plotting order.
# NOTE(review): the values are not used when building the trajectory paths
# below; only the keys (and their order) are read.
label_to_filename = {
    "C_f": "c_no_constr",
    "NC_f": "nc_no_constr",
    "C_c": "c",
}


def _max_atomic_force(atoms):
    """Return ASE's ``fmax`` measure of the FrechetCellFilter forces of ``atoms``.

    The filter's generalized forces contain one row per atom plus rows for the
    cell degrees of freedom; the measure is the largest Euclidean norm over
    those rows, i.e. the quantity ASE optimizers compare against ``fmax``.
    Taking ``forces.max()`` instead would return the largest *signed*
    component and ignore large negative forces.
    """
    forces = FrechetCellFilter(atoms).get_forces()
    return np.sqrt((forces**2).sum(axis=1).max())


for label in label_to_filename:
    cell_lengths[label] = {}
    fmax[label] = {}
    # Filename fragments encoded in the label: a "_c" suffix selects the
    # "constrained" trajectory, labels containing "NC" the "nc" variant.
    sym_label = "constrained" if label.endswith("_c") else "unconstrained"
    cons_label = "nc" if "NC" in label else "c"
    for lattice in ["bcc", "hcp"]:
        path = trajs / f"traj_{sym_label}_{lattice}_ti_{cons_label}.traj"
        cl = []
        fm = []
        # The context manager closes the trajectory file once all frames are read.
        with ase.io.trajectory.Trajectory(path) as traj:
            for atoms in traj:
                cl.append(atoms.get_cell().lengths())
                fm.append(_max_atomic_force(atoms))
        cell_lengths[label][lattice] = np.array(cl)
        fmax[label][lattice] = np.array(fm)


@mpltex.acs_decorator
def plot_geom(frac=0.4, y_factor=1.0):
    """Plot |cell-length deviation| against the force measure for BCC and HCP.

    Draws two stacked panels (a: BCC, b: HCP). For every run label in
    ``label_to_filename`` and each of the three cell vectors, the deviation of
    the cell length from ``eq_cell_length`` is binned over ``fmax`` with
    ``binned_xy`` (100 quantile bins). The absolute value of each bin's median
    is plotted against the bin's median force on a reversed log x-axis.

    The figure is saved to ``../figures/geometry_optimization.svg`` and
    ``../figures/geometry_optimization.pdf``.

    Parameters
    ----------
    frac : float
        Figure width as a fraction of twice the default rcParams width.
    y_factor : float
        Figure height as a multiple of the default rcParams height.

    Returns
    -------
    tuple
        ``(fig, axes)`` as returned by ``plt.subplots``.
    """
    default_w, default_h = plt.rcParams["figure.figsize"]

    fig, axes = plt.subplots(
        nrows=2,
        ncols=1,
        dpi=300,
        figsize=(2 * default_w * frac, y_factor * default_h),
        gridspec_kw={"hspace": 0.45, "height_ratios": [1, 1]},
    )

    # Top panel shows BCC, bottom panel HCP.
    panels = tuple(zip(axes, ["bcc", "hcp"]))

    for label in label_to_filename:
        for ax, lattice in panels:
            forces = fmax[label][lattice]
            deltas = cell_lengths[label][lattice] - eq_cell_length[lattice]
            # One curve per cell vector, all in the label's colour.
            for i in range(3):
                xs, ys, _, _ = binned_xy(forces, deltas[:, i], 100)
                ax.plot(xs, abs(ys), color=colors[label], ls="-", lw=1)

    ylabel_quantity = r"$|\Delta a|$ [$\mathrm{\AA}$]"
    for ax, lattice_name, tag in zip(axes, ["BCC", "HCP"], ["a", "b"]):
        ax.set_xscale("log")
        ax.invert_xaxis()
        # Limits given high-to-low, so the force decreases left to right.
        ax.set_xlim(1, 9e-6)
        ax.axhline(0.0, color="0.2", ls="--", lw=1, zorder=0)
        ax.set_ylim(-0.1, 1)
        ax.set_ylabel(
            lattice_name + "\n" + ylabel_quantity, horizontalalignment="center"
        )
        ax.text(0.02, 0.98, tag, transform=ax.transAxes, ha="left", va="top")

    axes[1].set_xlabel(r"$|f_\mathrm{max}|$ [eV/$\mathrm{\AA}$]")

    legend_handles = [
        plt.Line2D([0], [0], color=colors[key], lw=2, label=titles[key])
        for key in label_to_filename
    ]
    axes[1].legend(handles=legend_handles)

    fig.savefig("../figures/geometry_optimization.svg", dpi=1200, transparent=True)
    fig.savefig("../figures/geometry_optimization.pdf", dpi=1200, bbox_inches="tight")

    return fig, axes


if __name__ == "__main__":
    # Render and save the figure only when run as a script, not on import.
    plot_geom(frac=0.25)
