from pathlib import Path
import numpy as np

import matplotlib.pyplot as plt
import mpltex

import ase.io


def plot_phonon(
    phonons,
    bandpath=None,
    ax=None,
    emin=0.0,
    emax=None,
    color=None,
    atoms_for_path=None,
):
    """Draw the band structure of one or more phonon calculations on *ax*.

    Parameters
    ----------
    phonons
        A single phonon object or a list/tuple of them. Each object must
        expose ``atoms`` (with a ``cell``) and ``get_band_structure(path)``.
    bandpath
        Band-path specification forwarded to ``cell.bandpath``; ``None`` is
        passed through unchanged.
    ax
        Axes to draw on. A new figure and axes are created when omitted.
    emin, emax
        Energy limits forwarded to the band structure's ``plot`` call.
    color
        Line color forwarded to the band structure's ``plot`` call.
    atoms_for_path
        Optional atoms object whose ``cell`` is used to build the band path
        instead of each phonon object's own ``atoms.cell``.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Accept a bare object as well as any list/tuple of objects.
    if not isinstance(phonons, (list, tuple)):
        phonons = [phonons]

    for ph in phonons:
        # Pick the cell once so the fallback below builds its path from the
        # same cell as the first attempt (previously the fallback silently
        # ignored ``atoms_for_path``).
        source_atoms = ph.atoms if atoms_for_path is None else atoms_for_path
        cell = source_atoms.cell
        try:
            path = cell.bandpath(bandpath, npoints=200)
        except KeyError:
            # The requested path could not be resolved for this cell; fall
            # back to the cell's default band path.
            path = cell.bandpath(npoints=200)

        band_structure = ph.get_band_structure(path)
        band_structure.plot(ax=ax, emin=emin, emax=emax, color=color, lw=0.75)

    return ax


# Line colors per model variant. The keys are the part of an ``all_phonons``
# key that follows the lattice prefix (e.g. "BCC_C_f" -> "C_f").
colors = {"C_f": "C0", "NC_f": "C1", "C_c": "C2", "NC_c": "C3"}
# Panel labels for the same variant keys.
titles = {
    "C_f": "C, free",
    "NC_f": "NC, free",
    "C_c": "C, constr.",
    "NC_c": "NC, constr.",
}


def _require_dir(path):
    """Return *path* after checking that it is an existing directory."""
    if not path.is_dir():
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``.
        raise FileNotFoundError(f"Required directory not found: {path}")
    return path


# Both input directories are relative to the current working directory.
xyz = _require_dir(Path("../geometry_optimization/xyz"))

bcc_Ti = ase.io.read(xyz / "opt_constrained_bcc_ti_c.xyz")
hcp_Ti = ase.io.read(xyz / "opt_constrained_hcp_ti_c.xyz")

phonons_dir = _require_dir(Path("../phonons/ase_phonons"))


def _load_phonons(filename):
    """Load the pickled object stored in ``phonons_dir / filename``."""
    # NOTE(review): allow_pickle=True unpickles arbitrary Python objects;
    # only load files from a trusted source.
    return np.load(phonons_dir / filename, allow_pickle=True).item()


ph_bcc_Ti = _load_phonons("ph_pet-oam-bcc-C-constrained.npy")
ph_hcp_Ti = _load_phonons("ph_pet-oam-hcp-C-constrained.npy")

ph_bcc_Ti_unc = _load_phonons("ph_pet-oam-bcc-C-unconstrained.npy")
ph_hcp_Ti_unc = _load_phonons("ph_pet-oam-hcp-C-unconstrained.npy")

ph_bcc_Ti_unc_nc = _load_phonons("ph_pet-oam-bcc-NC-unconstrained.npy")
ph_hcp_Ti_unc_nc = _load_phonons("ph_pet-oam-hcp-NC-unconstrained.npy")


# Insertion order matters: the three BCC entries come first, then the three
# HCP entries. Keys are "<lattice>_<variant>", where <variant> indexes
# ``colors`` and ``titles``.
all_phonons = {
    "BCC_C_c": ph_bcc_Ti,
    "BCC_C_f": ph_bcc_Ti_unc,
    "BCC_NC_f": ph_bcc_Ti_unc_nc,
    "HCP_C_c": ph_hcp_Ti,
    "HCP_C_f": ph_hcp_Ti_unc,
    "HCP_NC_f": ph_hcp_Ti_unc_nc,
}


@mpltex.acs_decorator
def plot_ph(frac=0.6, y_factor=1.0):
    """Build the 2x4 phonon figure and save it as SVG and PDF.

    Each row holds three band-structure panels followed by one density-of-
    states panel. The first three entries of ``all_phonons`` go to the top
    row (BCC band path) and the last three to the bottom row (HCP band path),
    so the layout relies on the insertion order of ``all_phonons``.

    Parameters
    ----------
    frac
        Scale factor for the figure width; the width is
        ``2 * frac`` times the default ``figure.figsize`` width.
    y_factor
        Scale factor applied to the default ``figure.figsize`` height.

    Returns
    -------
    The created figure. As a side effect the figure is written to
    ``../figures/phonons.svg`` and ``../figures/phonons.pdf``.
    """
    # NOTE(review): the decorator presumably applies a journal style to
    # rcParams before this body runs - confirm against the mpltex docs.
    base_width, base_height = plt.rcParams["figure.figsize"]
    fig, axes = plt.subplots(
        2,
        4,
        figsize=(2 * base_width * frac, y_factor * base_height),
        dpi=250,
        sharey=True,
        gridspec_kw={"width_ratios": [1, 1, 1, 0.5], "wspace": 0.15, "hspace": 0.45},
    )

    # Row-major order: top-row band panels first, then bottom-row ones.
    ph_axes = axes[:, :3].ravel()
    dos_axes = axes[:, 3]

    def variant(ph_key):
        # "BCC_NC_f" -> "NC_f": drop the lattice prefix to index
        # ``colors`` / ``titles``.
        return ph_key.split("_", 1)[1]

    from pymatgen.core import Structure

    # Primitive cells used to define the band paths for each lattice.
    bcc_prim = Structure.from_ase_atoms(ph_bcc_Ti.atoms).to_primitive().to_ase_atoms()
    hcp_prim = Structure.from_ase_atoms(ph_hcp_Ti.atoms).to_primitive().to_ase_atoms()

    # (band path, cell source) for each of the six band panels.
    panel_setup = [("GHNGPH", bcc_prim)] * 3 + [("GMKGALHA", hcp_prim)] * 3

    # Band structures.
    for (ph_key, ph), ax, (path, prim_atoms) in zip(
        all_phonons.items(), ph_axes, panel_setup
    ):
        plot_phonon(
            ph,
            bandpath=path,
            ax=ax,
            emin=-0.005,
            emax=0.04,
            color=colors[variant(ph_key)],
            atoms_for_path=prim_atoms,
        )
        # Only the leftmost panel of each row keeps a y label (set below).
        ax.set_ylabel("")
    for ax in axes[:, 0]:
        ax.set_ylabel("Energy (eV)")

    # Densities of states: one panel per row, one line per variant.
    kpts = (20, 20, 20)
    npts = 200
    width = 1e-3
    ph_keys = list(all_phonons)
    for ax, row_keys in zip(dos_axes, (ph_keys[:3], ph_keys[3:])):
        for ph_key, ls in zip(row_keys, ["-", "--", ":"]):
            dos = all_phonons[ph_key].get_dos(kpts=kpts)
            dos = dos.sample_grid(npts=npts, width=width)
            # Weights on x and energies on y so the DOS shares the band
            # panels' energy axis.
            ax.plot(
                dos.get_weights(),
                dos.get_energies(),
                color=colors[variant(ph_key)],
                ls=ls,
            )

    for ax in dos_axes:
        ax.set_xlim(0)
        ax.set_xlabel("VDOS")
        ax.set_xticks([])
        ax.axhline(0, color="k", ls=":")

    # Variant label boxes in the top-right corner of each band panel.
    for ax, ph_key in zip(ph_axes, all_phonons):
        ax.text(
            0.95,
            0.95,
            titles[variant(ph_key)],
            transform=ax.transAxes,
            ha="right",
            va="top",
            fontsize=7,
            bbox=dict(
                boxstyle="square,pad=0.2",
                fc="white",
                alpha=1,
                edgecolor="k",
                linewidth=0.1,
            ),
        )
        ax.set_yticks([0, 0.02, 0.04])

    # Row captions, centred under the middle band panel of each row.
    for ax, caption in zip(axes[:, 1], ["BCC paths", "HCP paths"]):
        ax.text(0.5, -0.33, caption, transform=ax.transAxes, ha="center")

    # Panel letters start at "c".
    for ax, letter in zip(axes.flat, "cdefghij"):
        ax.text(0.02, 0.98, letter, transform=ax.transAxes, ha="left", va="top")

    # ``figname`` is the bare file stem; the directory is added exactly once
    # (previously the stem already contained "../figures/", yielding
    # "../figures/../figures/phonons.*").
    figname = "phonons"
    fig.savefig(f"../figures/{figname}.svg", dpi=1200, transparent=True)
    fig.savefig(f"../figures/{figname}.pdf", dpi=1200, bbox_inches="tight")

    return fig


# Build the figure with a width factor of 0.75; plot_ph also saves the
# SVG and PDF files.
fig = plot_ph(frac=0.75)
