from pathlib import Path
import numpy as np

import matplotlib.pyplot as plt
import mpltex

import ase.io
from pymatgen.core import Structure


def plot_phonon(
    phonons,
    bandpath=None,
    ax=None,
    emin=0.0,
    emax=None,
    color=None,
    atoms_for_path=None,
    npoints=200,
    max_path_chars=6,
):
    """Plot the phonon band structure of one or more phonon objects.

    Parameters
    ----------
    phonons : object or list/tuple of objects
        Phonon object(s) exposing ``.atoms`` and ``.get_band_structure(path)``
        (presumably ``ase.phonons.Phonons``). A single object is wrapped in a
        list.
    bandpath : str or None
        Special-point path string passed to ``cell.bandpath``; ``None`` lets
        ASE choose its default path for the lattice.
    ax : matplotlib Axes or None
        Axes to draw into; a new figure/axes is created when ``None``.
    emin, emax : float or None
        Energy window forwarded to ``BandStructure.plot``.
    color : color spec or None
        Line color forwarded to ``BandStructure.plot``.
    atoms_for_path : ase.Atoms or None
        Structure whose cell defines the k-path. When ``None``, each phonon
        object's own ``atoms`` is used.
    npoints : int
        Number of k-points along the path.
    max_path_chars : int or None
        Keep only the first ``max_path_chars`` characters of the resolved path
        string (i.e. only its leading special points). ``None`` keeps the
        full path.

    Returns
    -------
    matplotlib Axes
        The axes that was drawn into.
    """
    if ax is None:
        _, ax = plt.subplots()

    if not isinstance(phonons, (list, tuple)):
        phonons = [phonons]

    for ph in phonons:
        # Use a single reference cell for the initial path, the fallback and
        # the truncated re-resolution, so that special-point labels always
        # belong to the same lattice (and atoms_for_path=None works).
        ref_atoms = atoms_for_path if atoms_for_path is not None else ph.atoms
        try:
            path = ref_atoms.cell.bandpath(bandpath, npoints=npoints)
        except KeyError:
            # Presumably raised when a requested label is not a special point
            # of this lattice: fall back to the lattice's default path.
            path = ref_atoms.cell.bandpath(npoints=npoints)
        if max_path_chars is not None:
            # Drop a trailing comma so a cut at a segment break does not
            # leave an empty segment.
            short_path = path.path[:max_path_chars].rstrip(",")
            path = ref_atoms.cell.bandpath(short_path, npoints=npoints)
        bs = ph.get_band_structure(path)
        bs.plot(ax=ax, emin=emin, emax=emax, color=color, lw=0.75)

    return ax


# Line color per panel, keyed by the part of an ``all_phonons`` key that
# follows its 4-character "BCC_"/"HCP_" prefix.
colors = dict(C_f="C0", NC_f="C1", C_c="C2", NC_c="C3")

# Caption text per panel, keyed the same way as ``colors``.
titles = dict(
    C_f="C, free",
    NC_f="NC, free",
    C_c="C, constr.",
    NC_c="NC, constr.",
)

# Input locations are resolved relative to the current working directory.
xyz = Path("../geometry_optimization/xyz")
if not xyz.is_dir():
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise NotADirectoryError(f"Geometry directory not found: {xyz}")

# Relaxed Ti structures; not referenced elsewhere in this script.
bcc_Ti = ase.io.read(xyz / "opt_constrained_bcc_ti_c.xyz")
hcp_Ti = ase.io.read(xyz / "opt_constrained_hcp_ti_c.xyz")

phonons_dir = Path("../phonons/ase_phonons")
if not phonons_dir.is_dir():
    raise NotADirectoryError(f"Phonon directory not found: {phonons_dir}")


def _load_phonons(tag):
    """Load the object stored in ``phonons_dir / f"ph_pet-oam-{tag}.npy"``.

    The file holds a pickled Python object (presumably written with
    ``np.save``), unwrapped from its 0-d object array via ``.item()``.
    ``allow_pickle=True`` can execute arbitrary code, so only load files
    produced by this project's own phonon workflow.
    """
    path = phonons_dir / f"ph_pet-oam-{tag}.npy"
    return np.load(path, allow_pickle=True).item()


ph_bcc_Ti = _load_phonons("bcc-C-constrained")
ph_hcp_Ti = _load_phonons("hcp-C-constrained")

ph_bcc_Ti_unc = _load_phonons("bcc-C-unconstrained")
ph_hcp_Ti_unc = _load_phonons("hcp-C-unconstrained")

ph_bcc_Ti_unc_nc = _load_phonons("bcc-NC-unconstrained")
ph_hcp_Ti_unc_nc = _load_phonons("hcp-NC-unconstrained")


# Panel order for the 2x3 figure (row-major): BCC row first, then HCP row.
# The text after the 4-character prefix indexes `colors` and `titles`.
all_phonons = {
    "BCC_C_c": ph_bcc_Ti,
    "BCC_C_f": ph_bcc_Ti_unc,
    "BCC_NC_f": ph_bcc_Ti_unc_nc,
    "HCP_C_c": ph_hcp_Ti,
    "HCP_C_f": ph_hcp_Ti_unc,
    "HCP_NC_f": ph_hcp_Ti_unc_nc,
}


@mpltex.acs_decorator
def plot_ph(frac=0.6, y_factor=1.0):
    """Draw the 2x3 grid of phonon band structures and save it as SVG and PDF.

    Each panel shows one entry of the module-level ``all_phonons`` dict (in
    row-major order). The k-path is taken from the primitive cell of that
    phonon object's atoms, as found by pymatgen ``to_primitive(symprec=1e-2)``.
    The figure is written to ``../figures/phonons_SI.svg`` and
    ``../figures/phonons_SI.pdf``, relative to the working directory.

    Parameters
    ----------
    frac : float
        Figure width as a fraction of twice the rcParams default width.
    y_factor : float
        Figure height as a multiple of the rcParams default height.

    Returns
    -------
    matplotlib Figure
        The created figure.
    """
    fs = plt.rcParams["figure.figsize"]
    fig, axes = plt.subplots(
        2,
        3,
        figsize=(2 * fs[0] * frac, y_factor * fs[1]),
        dpi=250,
        sharey=True,
        gridspec_kw={"wspace": 0.15, "hspace": 0.45},
    )

    # X-axis caption per panel, in the same order as all_phonons.
    path_names = ["BCC", "FCC", "FCO", "HCP", "HCP", "BCO"]
    # Panel letters start at "b"; zip stops after the six axes.
    letters = ["b", "c", "d", "e", "f", "g", "h", "i"]

    for ax, ph_key, path_name, letter in zip(
        axes.flatten(), all_phonons, path_names, letters
    ):
        ph = all_phonons[ph_key]
        # Strip the 4-character "BCC_"/"HCP_" prefix to index the style maps.
        variant = ph_key[4:]

        # Build the k-path on the symmetrized primitive cell, not the raw cell.
        ph_atoms = (
            Structure.from_ase_atoms(ph.atoms)
            .to_primitive(symprec=1e-2)
            .to_ase_atoms()
        )
        plot_phonon(
            ph,
            ax=ax,
            emin=-0.005,
            emax=0.04,
            color=colors[variant],
            atoms_for_path=ph_atoms,
        )
        # Clear the per-panel y-label; only the left column gets one below.
        ax.set_ylabel("")

        ax.text(
            0.95,
            0.95,
            titles[variant],
            transform=ax.transAxes,
            ha="right",
            va="top",
            fontsize=7,
            bbox=dict(
                boxstyle="square,pad=0.2",
                fc="white",
                alpha=1,
                edgecolor="k",
                linewidth=0.1,
            ),
        )
        ax.set_yticks([0, 0.02, 0.04])

        ax.set_xlabel(path_name + " paths")
        ax.text(0.02, 0.98, letter, transform=ax.transAxes, ha="left", va="top")

    axes[0, 0].set_ylabel("Energy (eV)")
    axes[1, 0].set_ylabel("Energy (eV)")

    # The base name no longer repeats the directory, so the paths are not
    # built as "../figures/../figures/...".
    out_dir = Path("../figures")
    out_dir.mkdir(parents=True, exist_ok=True)
    figname = "phonons_SI"
    fig.savefig(out_dir / f"{figname}.svg", dpi=1200, transparent=True)
    fig.savefig(out_dir / f"{figname}.pdf", dpi=1200, bbox_inches="tight")

    return fig


# Script entry: render the 2x3 SI phonon grid and write it to disk.
fig = plot_ph(frac=0.6, y_factor=1.0)
