import os
import logging
from pathlib import Path

import numpy as np
from ase.build import bulk
from ase.constraints import FixSymmetry
from ase.filters import FrechetCellFilter
from ase.optimize import FIRE

# ---------------------------------------
# Logging
# Verbosity is taken from the LOGLEVEL environment variable (case-insensitive,
# default "INFO"); a name that is not an attribute of the logging module falls
# back to INFO.
LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper()
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    level=getattr(logging, LOGLEVEL, logging.INFO),
)
logger = logging.getLogger("Phonons")

# ---------------------------------------
# Input / output locations. Relative paths are interpreted with respect to the
# current working directory (the resolved absolute paths are logged below).

models_dir = Path("../models")
c_model, nc_model = "model-oam-c.pt", "model-oam-nc.pt"  # conservative, non-conservative
ph_dir = Path("./ase_phonons")

logger.info(f"Models directory: {models_dir.resolve()}")
logger.info(f"Conservative model: {c_model}")
logger.info(f"Non-conservative model: {nc_model}")
logger.info(f"Phonons output directory: {ph_dir.resolve()}")


# ---------------------------------------
# Helpers
def compute_phonons(
    atoms,
    calc,
    supercell=(3, 3, 3),
    delta=0.01,
    symmetrize=3,
    acoustic=True,
):
    """Run a finite-displacement phonon calculation with ASE.

    Any data left in the ``Phonons`` cache is removed first. The displaced
    supercells are then evaluated with ``calc``, the force constants are read
    back, and the on-disk cache is deleted again.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure whose phonons are computed.
    calc : ase calculator
        Calculator used to evaluate the displaced supercells.
    supercell : tuple of int, optional
        Supercell repetition passed to ``Phonons``.
    delta : float, optional
        Displacement amplitude passed to ``Phonons``.
    symmetrize : int, optional
        Forwarded to ``Phonons.read``.
    acoustic : bool, optional
        Forwarded to ``Phonons.read``.

    Returns
    -------
    ase.phonons.Phonons
        The phonon object after ``read``, with ``phonons.atoms.calc`` set to
        ``None``.

    Notes
    -----
    ``Phonons`` presumably keeps a reference to ``atoms`` rather than a copy.
    If so, the caller's ``atoms.calc`` is cleared as well. Confirm this
    against the installed ASE version.
    """
    from ase.phonons import Phonons

    formula = atoms.get_chemical_formula()
    logger.info(
        f"Computing phonons for {formula}: supercell={supercell}, delta={delta}"
    )

    phonons = Phonons(atoms, calc, supercell=supercell, delta=delta)
    try:
        phonons.clean()
    except FileNotFoundError:
        # Nothing cached yet (e.g. the cache directory does not exist).
        logger.debug("No previous phonon data to clean.")
    except Exception as exc:
        # Cleanup is best-effort. Leftover cached displacements may be picked
        # up by run() and mixed into this calculation, so surface the failure
        # instead of hiding it at debug level.
        logger.warning(
            f"Could not clean previous phonon data for {formula}: {exc}. "
            "Stale cached displacements may be reused."
        )
    else:
        logger.debug("Previous phonon data cleaned.")
    phonons.run()
    phonons.read(acoustic=acoustic, symmetrize=symmetrize)
    # The force constants are now held in memory; drop the on-disk cache.
    phonons.clean()

    # Detach the calculator from the reference structure before returning.
    phonons.atoms.calc = None

    return phonons


def optimize_with_constraints(
    atoms,
    calc,
    fmax=1e-8,
    steps=500,
    trajectory=None,
    hydrostatic_strain=True,
):
    """Relax a copy of ``atoms`` while preserving its symmetry.

    Atomic positions and the cell are optimised with FIRE through a
    ``FrechetCellFilter``, with a ``FixSymmetry`` constraint attached. Only
    the first three (normal) strain components are unmasked; shear components
    are frozen.

    Parameters
    ----------
    atoms : ase.Atoms
        Starting structure. It is copied and not modified.
    calc : ase calculator
        Calculator attached to the copy.
    fmax : float, optional
        Force convergence criterion passed to ``FIRE.run``.
    steps : int, optional
        Maximum number of optimiser steps.
    trajectory : str or None, optional
        Optional trajectory file passed to ``FIRE``.
    hydrostatic_strain : bool, optional
        Forwarded to ``FrechetCellFilter``. With the default ``True`` only
        uniform scaling of the cell is allowed, so e.g. the c/a ratio of an
        hcp cell is kept at its starting value. Pass ``False`` to relax the
        normal strain components independently.

    Returns
    -------
    ase.Atoms
        The relaxed copy, with ``calc`` still attached and the symmetry
        constraint removed.
    """
    formula = atoms.get_chemical_formula()
    logger.info(
        f"Optimizing (with symmetry constraints) {formula}: fmax={fmax}, steps={steps}"
    )
    atm = atoms.copy()
    atm.calc = calc
    atm.set_constraint(FixSymmetry(atm))
    ucf = FrechetCellFilter(
        atm,
        mask=3 * [True] + 3 * [False],
        hydrostatic_strain=hydrostatic_strain,
    )
    opt = FIRE(ucf, trajectory=trajectory)
    converged = opt.run(fmax=fmax, steps=steps)
    if not converged:
        # FIRE stops after `steps` iterations even if fmax was not reached;
        # the returned structure is then only partially relaxed.
        logger.warning(
            f"Optimization (constrained) did not converge for {formula}: "
            f"fmax={fmax} not reached within {steps} steps"
        )
    try:
        energy = atm.get_potential_energy()
        logger.info(
            f"Optimization (constrained) finished for {formula}: "
            f"E={energy:.6f} eV, V={atm.get_volume():.3f} Å^3"
        )
    except Exception as e:
        # Best-effort summary only; a failure here must not abort the run.
        logger.warning(f"Could not compute energy for {formula}: {e}")
    # Drop FixSymmetry so the returned structure carries no constraints.
    atm.constraints = None
    return atm


def optimize_without_constraints(atoms, calc, fmax=1e-8, steps=500, trajectory=None):
    """Relax a copy of ``atoms`` with no symmetry or strain restrictions.

    Atomic positions and all cell degrees of freedom are optimised with FIRE
    through a default ``FrechetCellFilter``.

    Parameters
    ----------
    atoms : ase.Atoms
        Starting structure. It is copied and not modified.
    calc : ase calculator
        Calculator attached to the copy.
    fmax : float, optional
        Force convergence criterion passed to ``FIRE.run``.
    steps : int, optional
        Maximum number of optimiser steps.
    trajectory : str or None, optional
        Optional trajectory file passed to ``FIRE``.

    Returns
    -------
    ase.Atoms
        The relaxed copy, with ``calc`` still attached.
    """
    formula = atoms.get_chemical_formula()
    logger.info(f"Optimizing (unconstrained) {formula}: fmax={fmax}, steps={steps}")
    atm = atoms.copy()
    atm.calc = calc
    ucf = FrechetCellFilter(atm)
    opt = FIRE(ucf, trajectory=trajectory)
    converged = opt.run(fmax=fmax, steps=steps)
    if not converged:
        # FIRE stops after `steps` iterations even if fmax was not reached;
        # the returned structure is then only partially relaxed.
        logger.warning(
            f"Optimization (unconstrained) did not converge for {formula}: "
            f"fmax={fmax} not reached within {steps} steps"
        )
    try:
        energy = atm.get_potential_energy()
        logger.info(
            f"Optimization (unconstrained) finished for {formula}: "
            f"E={energy:.6f} eV, V={atm.get_volume():.3f} Å^3"
        )
    except Exception as e:
        # Best-effort summary only; a failure here must not abort the run.
        logger.warning(f"Could not compute energy for {formula}: {e}")
    return atm


def get_calc(
    model_path: str,
    non_conservative: bool = False,
    device="cuda",
    dtype="float64",
):
    """Load a metatomic model and wrap it in an ASE calculator.

    Parameters
    ----------
    model_path : str or os.PathLike
        Path to the exported model file. Path objects are converted to
        ``str`` before loading.
    non_conservative : bool, optional
        Forwarded to ``MetatomicCalculator``.
    device : str, optional
        Torch device for the model and the calculator.
    dtype : {"float64", "float32"}, optional
        Precision written into the model capabilities; the model is cast to
        the matching torch dtype.

    Returns
    -------
    MetatomicCalculator
        Calculator wrapping the loaded model.

    Raises
    ------
    ValueError
        If ``dtype`` is not "float64" or "float32". Previously any other
        value silently cast the model to float32 while advertising ``dtype``
        in its capabilities.
    """
    supported_dtypes = ("float64", "float32")
    if dtype not in supported_dtypes:
        raise ValueError(
            f"Unsupported dtype {dtype!r}; expected one of {supported_dtypes}"
        )

    from metatomic.torch.ase_calculator import MetatomicCalculator
    from metatomic.torch import load_atomistic_model
    import torch

    torch_dtype = torch.float64 if dtype == "float64" else torch.float32

    logger.info(
        f"Loading model from {model_path} | "
        f"non_conservative={non_conservative}, device={device}, dtype={dtype}"
    )
    model = load_atomistic_model(os.fspath(model_path))
    # Keep the advertised precision consistent with the actual tensor dtype.
    model.capabilities().dtype = dtype
    model = model.to(dtype=torch_dtype, device=device)
    calc = MetatomicCalculator(model, device=device, non_conservative=non_conservative)
    logger.info("Calculator initialized.")
    return calc


# ---------------------------------------
# Calculators: one model with non_conservative=False and one with
# non_conservative=True, both on CPU in double precision.

model_path_c = models_dir / c_model
model_path_nc = models_dir / nc_model

# The generator is consumed in order, so the conservative model loads first.
calc_c, calc_nc = (
    get_calc(
        model_path=path,
        non_conservative=is_nc,
        device="cpu",
        dtype="float64",
    )
    for path, is_nc in ((model_path_c, False), (model_path_nc, True))
)

# Starting structures, both built with the same lattice parameter a.
bcc_Ti = bulk("Ti", "bcc", a=3.32)
hcp_Ti = bulk("Ti", "hcp", a=3.32, c=5.2)
logger.info("Initial structures prepared: BCC Ti and HCP Ti.")

supercell, fmax, steps = (5, 5, 5), 1e-8, 300
logger.info(f"Global settings: supercell={supercell}, fmax={fmax}, steps={steps}")

# Starting from BCC: symmetry-constrained (conservative only), then
# unconstrained with the conservative and the non-conservative model.
logger.info("Starting optimizations from BCC Ti.")
bcc_Ti_constrained_c = optimize_with_constraints(
    bcc_Ti, calc_c, fmax=fmax, steps=steps
)
bcc_Ti_unconstrained_c, bcc_Ti_unconstrained_nc = (
    optimize_without_constraints(bcc_Ti, model_calc, fmax=fmax, steps=steps)
    for model_calc in (calc_c, calc_nc)
)

# Starting from HCP: same three relaxations.
logger.info("Starting optimizations from HCP Ti.")
hcp_Ti_constrained_c = optimize_with_constraints(
    hcp_Ti, calc_c, fmax=fmax, steps=steps
)
hcp_Ti_unconstrained_c, hcp_Ti_unconstrained_nc = (
    optimize_without_constraints(hcp_Ti, model_calc, fmax=fmax, steps=steps)
    for model_calc in (calc_c, calc_nc)
)

# Create the output directory before the (expensive) phonon calculations, so
# an unusable output path is reported immediately rather than after all force
# constants have been computed.
ph_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Output directory ensured: {ph_dir.resolve()}")

# Phonons: runs with the conservative model use delta=0.03 and runs with the
# non-conservative model use delta=0.1 (displacement passed to ASE Phonons).
logger.info("Computing phonons for BCC Ti (constrained, conservative).")
ph_bcc_Ti_constrained_c = compute_phonons(
    bcc_Ti_constrained_c, calc_c, supercell=supercell, delta=0.03
)
logger.info("Computing phonons for BCC Ti (unconstrained, conservative).")
ph_bcc_Ti_unconstrained_c = compute_phonons(
    bcc_Ti_unconstrained_c, calc_c, supercell=supercell, delta=0.03
)
logger.info("Computing phonons for BCC Ti (unconstrained, non-conservative).")
ph_bcc_Ti_unconstrained_nc = compute_phonons(
    bcc_Ti_unconstrained_nc, calc_nc, supercell=supercell, delta=0.1
)

logger.info("Computing phonons for HCP Ti (constrained, conservative).")
ph_hcp_Ti_constrained_c = compute_phonons(
    hcp_Ti_constrained_c, calc_c, supercell=supercell, delta=0.03
)
logger.info("Computing phonons for HCP Ti (unconstrained, conservative).")
ph_hcp_Ti_unconstrained_c = compute_phonons(
    hcp_Ti_unconstrained_c, calc_c, supercell=supercell, delta=0.03
)
logger.info("Computing phonons for HCP Ti (unconstrained, non-conservative).")
ph_hcp_Ti_unconstrained_nc = compute_phonons(
    hcp_Ti_unconstrained_nc, calc_nc, supercell=supercell, delta=0.1
)

all_phonons = [
    ph_bcc_Ti_constrained_c,
    ph_bcc_Ti_unconstrained_c,
    ph_bcc_Ti_unconstrained_nc,
    ph_hcp_Ti_constrained_c,
    ph_hcp_Ti_unconstrained_c,
    ph_hcp_Ti_unconstrained_nc,
]

all_phonons_labels = [
    "pet-oam-bcc-C-constrained",
    "pet-oam-bcc-C-unconstrained",
    "pet-oam-bcc-NC-unconstrained",
    "pet-oam-hcp-C-constrained",
    "pet-oam-hcp-C-unconstrained",
    "pet-oam-hcp-NC-unconstrained",
]

# zip() silently truncates to the shorter input. Make sure every result has
# exactly one label before any file is written.
if len(all_phonons) != len(all_phonons_labels):
    raise ValueError(
        f"Got {len(all_phonons)} phonon results but "
        f"{len(all_phonons_labels)} labels."
    )

for ph, lbl in zip(all_phonons, all_phonons_labels):
    # Detach the calculator so it is not pickled together with the phonons.
    ph.calc = None
    out_path = ph_dir / f"ph_{lbl}.npy"
    # np.save stores the Phonons object as a pickled 0-d object array;
    # read it back with np.load(out_path, allow_pickle=True).item().
    np.save(out_path, ph, allow_pickle=True)
    logger.info(f"Saved phonons: {out_path}")
