import os
import logging
from pathlib import Path

from ase.build import bulk
from ase.constraints import FixSymmetry
from ase.filters import FrechetCellFilter
from ase.optimize import FIRE
from tqdm import tqdm

# ---------------------------------------
# Logging
# Verbosity is read from the LOGLEVEL environment variable (default "INFO",
# case-insensitive). A name that is not an attribute of the ``logging``
# module falls back to INFO.
LOGLEVEL = os.getenv("LOGLEVEL", "INFO").upper()
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    level=getattr(logging, LOGLEVEL, logging.INFO),
)
logger = logging.getLogger(name="Geometry Optimization")

# ---------------------------------------


# Model files, resolved relative to the current working directory.
models_dir = Path("../models")
# "c" is loaded with non_conservative=False, "nc" with non_conservative=True.
c_model, nc_model = "model-oam-c.pt", "model-oam-nc.pt"

# Output directories for trajectories, final structures and optimizer logs.
trajs, xyz, logs = (Path(name) for name in ("trajs", "xyz", "logs"))


# Helpers
class TQDMLogger:
    """Optimizer observer that advances a progress bar by one tick per call.

    Intended for ``opt.attach(TQDMLogger(pbar), interval=1)``. Any positional
    or keyword arguments passed by the caller are accepted and ignored.
    """

    def __init__(self, pbar):
        # Any object exposing ``update(n)``, e.g. a ``tqdm`` progress bar.
        self.pbar = pbar

    def __call__(self, *_args, **_kwargs):
        advance = self.pbar.update
        advance(1)


def optimize_with_constraints(atoms, calc, fmax=1e-8, steps=500, trajectory=None):
    """Relax a copy of ``atoms`` with FIRE while preserving its initial symmetry.

    The input ``atoms`` object is not modified. A ``FixSymmetry`` constraint is
    applied during the run. The ``FrechetCellFilter`` mask lets only the three
    normal Voigt strain components (xx, yy, zz) of the cell relax; shear
    components are held fixed.

    Parameters
    ----------
    atoms : ase.Atoms
        Starting structure (copied, never modified).
    calc : ase calculator
        Calculator attached to the copy.
    fmax : float
        Force convergence threshold passed to ``FIRE.run``.
    steps : int
        Maximum number of optimizer steps.
    trajectory : str or Path, optional
        FIRE trajectory file. When given, the optimizer log is written to
        ``logs/<trajectory name with .log suffix>``; otherwise no log file
        is written.

    Returns
    -------
    ase.Atoms
        The relaxed copy, with ``calc`` attached and the symmetry constraint
        removed.
    """
    atm = atoms.copy()
    atm.calc = calc
    atm.set_constraint(FixSymmetry(atm))
    # Voigt mask order is (xx, yy, zz, yz, xz, xy): relax normal strains only.
    ucf = FrechetCellFilter(atm, mask=3 * [True] + 3 * [False])

    logfile = None
    if trajectory is not None:
        log_name = Path(trajectory).with_suffix(".log").name
        logfile = str(logs / log_name)

    logger.info(
        f"Starting constrained optimization: natoms={len(atm)}, "
        f"fmax={fmax:.3e}, steps={steps}, traj={trajectory}"
    )
    # Both context managers guarantee that the trajectory/log files and the
    # progress bar are closed even if the calculator raises mid-run.
    # ASE calls observers once for the initial configuration and then once per
    # step, so a full run produces steps + 1 ticks.
    with FIRE(ucf, trajectory=trajectory, logfile=logfile) as opt, tqdm(
        total=steps + 1, desc="Geometry optimization", unit="step"
    ) as pbar:
        opt.attach(TQDMLogger(pbar), interval=1)
        converged = opt.run(fmax=fmax, steps=steps)
    atm.constraints = None

    try:
        energy = atm.get_potential_energy()
        forces = atm.get_forces()
        # Largest per-atom force norm: the quantity FIRE compares with fmax.
        fmax_val = ((forces**2).sum(axis=1) ** 0.5).max()
        stress_max = abs(atm.get_stress()).max()
        logger.info(
            f"Finished constrained optimization: converged={converged}, "
            f"E={energy:.6f} eV, max|F|={fmax_val:.3e} eV/Å, "
            f"max|stress|={stress_max:.3e} eV/Å³"
        )
    except Exception as e:
        logger.warning(f"Could not compute final energy/forces: {e}")

    return atm


def optimize_without_constraints(atoms, calc, fmax=1e-8, steps=500, trajectory=None):
    """Relax a copy of ``atoms`` (positions and full cell) with FIRE.

    The input ``atoms`` object is not modified. No symmetry constraint is
    added, and ``FrechetCellFilter`` is used without a mask, so all cell
    degrees of freedom are allowed to relax.

    Parameters
    ----------
    atoms : ase.Atoms
        Starting structure (copied, never modified).
    calc : ase calculator
        Calculator attached to the copy.
    fmax : float
        Force convergence threshold passed to ``FIRE.run``.
    steps : int
        Maximum number of optimizer steps.
    trajectory : str or Path, optional
        FIRE trajectory file. When given, the optimizer log is written to
        ``logs/<trajectory name with .log suffix>``; otherwise no log file
        is written.

    Returns
    -------
    ase.Atoms
        The relaxed copy, with ``calc`` attached.
    """
    atm = atoms.copy()
    atm.calc = calc
    ucf = FrechetCellFilter(atm)

    logfile = None
    if trajectory is not None:
        log_name = Path(trajectory).with_suffix(".log").name
        logfile = str(logs / log_name)

    logger.info(
        f"Starting unconstrained optimization: natoms={len(atm)}, "
        f"fmax={fmax:.3e}, steps={steps}, traj={trajectory}"
    )
    # Both context managers guarantee that the trajectory/log files and the
    # progress bar are closed even if the calculator raises mid-run.
    # ASE calls observers once for the initial configuration and then once per
    # step, so a full run produces steps + 1 ticks.
    with FIRE(ucf, trajectory=trajectory, logfile=logfile) as opt, tqdm(
        total=steps + 1, desc="Geometry optimization", unit="step"
    ) as pbar:
        opt.attach(TQDMLogger(pbar), interval=1)
        converged = opt.run(fmax=fmax, steps=steps)

    try:
        energy = atm.get_potential_energy()
        forces = atm.get_forces()
        # Largest per-atom force norm: the quantity FIRE compares with fmax.
        fmax_val = ((forces**2).sum(axis=1) ** 0.5).max()
        stress_max = abs(atm.get_stress()).max()
        logger.info(
            f"Finished unconstrained optimization: converged={converged}, "
            f"E={energy:.6f} eV, max|F|={fmax_val:.3e} eV/Å, "
            f"max|stress|={stress_max:.3e} eV/Å³"
        )
    except Exception as e:
        logger.warning(f"Could not compute final energy/forces: {e}")

    return atm


def get_calc(
    model_path: str,
    non_conservative: bool = False,
    device="cuda",
    dtype="float64",
):
    """Load a metatomic model and wrap it in a ``MetatomicCalculator``.

    Parameters
    ----------
    model_path : str or os.PathLike
        Path to the model file; converted with ``os.fspath`` before loading.
    non_conservative : bool
        Forwarded unchanged to ``MetatomicCalculator``.
    device : str
        Torch device the model is moved to and the calculator runs on.
    dtype : {"float64", "float32"}
        Floating-point precision written into the model capabilities and used
        to cast the model weights.

    Returns
    -------
    MetatomicCalculator
        Calculator ready to attach to ``ase.Atoms``.

    Raises
    ------
    ValueError
        If ``dtype`` is not "float64" or "float32".
    """
    # Validate before the heavy imports. Without this check, an unknown name
    # (e.g. "float16") would be written into the capabilities while the
    # weights were silently cast to float32.
    supported = ("float64", "float32")
    if dtype not in supported:
        raise ValueError(
            f"Unsupported dtype {dtype!r}; expected one of {supported}"
        )

    from metatomic.torch.ase_calculator import MetatomicCalculator
    from metatomic.torch import load_atomistic_model
    import torch

    torch_dtype = {"float64": torch.float64, "float32": torch.float32}[dtype]

    logger.info(
        f"Loading model: path={model_path}, non_conservative={non_conservative}, "
        f"device={device}, dtype={dtype}"
    )
    model = load_atomistic_model(os.fspath(model_path))
    # NOTE(review): this assumes capabilities() returns the model's own object
    # rather than a copy; confirm against the metatomic API.
    model.capabilities().dtype = dtype
    model = model.to(dtype=torch_dtype, device=device)
    calc = MetatomicCalculator(model, device=device, non_conservative=non_conservative)
    logger.info("Calculator initialized")
    return calc


# Starting structures: initial lattice guesses whose positions and cells are
# relaxed by the driver loop.
bcc_Ti = bulk(name="Ti", crystalstructure="bcc", a=3.0, cubic=True)
hcp_Ti = bulk(name="Ti", crystalstructure="hcp", a=3.0, c=5.0)

model_path_c = models_dir / c_model
model_path_nc = models_dir / nc_model

# calc_types[i] labels the calculator built from calc_setups[i]. Both setups
# run on CUDA in float64 and differ only in model file and conservativeness.
calc_types = ["c", "nc"]
calc_setups = [
    dict(model_path=path, non_conservative=is_nc, device="cuda", dtype="float64")
    for path, is_nc in ((model_path_c, False), (model_path_nc, True))
]

# Convergence settings shared by every relaxation.
# NOTE(review): 1e-15 is likely below achievable numerical precision, so runs
# will probably stop at `steps` rather than converge — confirm this is intended.
fmax = 1e-15
steps = 300

# Output directories, created relative to the working directory if missing.
trajs.mkdir(exist_ok=True)
logs.mkdir(exist_ok=True)
xyz.mkdir(exist_ok=True)

# For each calculator, run four relaxations in a fixed order:
# BCC constrained, BCC unconstrained, HCP constrained, HCP unconstrained.
# A failure in one relaxation is logged and does not stop the others.
for calc_type, calc_setup in zip(calc_types, calc_setups):
    try:
        calc = get_calc(**calc_setup)
    except Exception as e:
        logger.exception(f"Failed to initialize calculator for {calc_type}: {e}")
        continue

    jobs = (
        ("bcc", bcc_Ti, "constrained", optimize_with_constraints),
        ("bcc", bcc_Ti, "unconstrained", optimize_without_constraints),
        ("hcp", hcp_Ti, "constrained", optimize_with_constraints),
        ("hcp", hcp_Ti, "unconstrained", optimize_without_constraints),
    )
    for lattice, start, mode, relax in jobs:
        # e.g. "constrained_bcc_ti_c" -> traj_<tag>.traj / opt_<tag>.xyz
        tag = f"{mode}_{lattice}_ti_{calc_type}"
        try:
            logger.info(
                f"{mode.capitalize()} minimization of {lattice.upper()} Ti "
                f"with {calc_type} calculator"
            )
            relaxed = relax(
                start,
                calc,
                fmax=fmax,
                steps=steps,
                trajectory=trajs / f"traj_{tag}.traj",
            )
            relaxed.write(xyz / f"opt_{tag}.xyz")
        except Exception as e:
            logger.exception(
                f"Failed {mode} {lattice.upper()} optimization for {calc_type}: {e}"
            )
