import warnings
from pathlib import Path

import numpy as np
from ase.build import bulk
from ase.constraints import FixSymmetry
from ase.filters import FrechetCellFilter
from ase.optimize import FIRE


# Location of the model checkpoints. This is a relative path, so it is
# resolved against the current working directory, not this file's location.
models_dir = Path("..") / ".." / "models"
c_model = "model-oam-c.pt"  # checkpoint of the conservative ("c") model
nc_model = "model-oam-nc.pt"  # checkpoint of the non-conservative ("nc") model


def get_calc(
    model_path: "str | Path",
    non_conservative: bool = False,
    device="cuda",
    dtype="float64",
    impose_space_group_symmetry: bool = False,
    equivariant_up_to_l_max: int = 0,
    o3_batch_size: int = 16,
):
    """Load a metatomic model and wrap it in an ASE calculator.

    Parameters
    ----------
    model_path
        Path of the exported model file, passed to ``load_atomistic_model``.
    non_conservative
        Forwarded to ``MetatomicCalculator``. Presumably this makes the
        calculator use the model's directly predicted forces/stress instead
        of energy gradients (see the metatomic docs).
    device
        Torch device string that both the model and the calculator use.
    dtype
        Either ``"float32"`` or ``"float64"``. It is written into the model
        capabilities, and the model parameters are cast to the matching
        torch dtype.
    impose_space_group_symmetry, equivariant_up_to_l_max
        If either is enabled, the calculator is wrapped in an
        ``O3AveragedCalculator`` with ``apply_group_symmetry`` and ``l_max``
        set from these values.
    o3_batch_size
        ``batch_size`` passed to ``O3AveragedCalculator``. It defaults to 16,
        which was previously hard-coded.

    Returns
    -------
    MetatomicCalculator or O3AveragedCalculator

    Raises
    ------
    ValueError
        If ``dtype`` is not one of the supported strings.
    """
    # Validate before any heavy import or model load. Previously, any string
    # other than "float64" silently cast the weights to float32 while the
    # capabilities still advertised the requested dtype.
    supported_dtypes = ("float32", "float64")
    if dtype not in supported_dtypes:
        raise ValueError(
            f"unsupported dtype {dtype!r}; expected one of {supported_dtypes}"
        )

    from metatomic.torch.ase_calculator import MetatomicCalculator, O3AveragedCalculator
    from metatomic.torch import load_atomistic_model
    import torch

    torch_dtype = {"float32": torch.float32, "float64": torch.float64}[dtype]

    model = load_atomistic_model(model_path)
    # Keep the declared capabilities in sync with the dtype the weights are
    # cast to. NOTE(review): this assumes capabilities() returns the model's
    # live object rather than a copy. Confirm against the metatomic API.
    model.capabilities().dtype = dtype
    model = model.to(dtype=torch_dtype, device=device)
    calc = MetatomicCalculator(model, device=device, non_conservative=non_conservative)

    if equivariant_up_to_l_max > 0 or impose_space_group_symmetry:
        calc = O3AveragedCalculator(
            calc,
            l_max=equivariant_up_to_l_max,
            apply_group_symmetry=impose_space_group_symmetry,
            batch_size=o3_batch_size,
        )

    return calc


# Starting hcp Ti cell (ASE length units, i.e. Angstrom).
unit_cell = bulk("Ti", "hcp", a=3.32, c=5.2)

# Build the three calculators in order: the conservative checkpoint, the
# non-conservative checkpoint, and the conservative checkpoint evaluated
# with non_conservative=True.
calc, calc_nc, calc_c_as_nc = (
    get_calc(models_dir / checkpoint, non_conservative=use_nc)
    for checkpoint, use_nc in (
        (c_model, False),
        (nc_model, True),
        (c_model, True),
    )
)


# Relax the starting cell separately with each calculator, so that every
# force scan starts from that calculator's own optimized structure.
RELAX_FMAX = 1e-4
RELAX_STEPS = 300

opt_atoms = []
for label, c in zip(
    ("pet-oam-c", "pet-oam-nc", "pet-oam-c-as-nc"),
    (calc, calc_nc, calc_c_as_nc),
):
    atoms = unit_cell.copy()
    atoms.calc = c
    # Keep the relaxation within the symmetry of the starting structure.
    atoms.set_constraint(FixSymmetry(atoms))
    # Voigt-ordered mask (xx, yy, zz, yz, xz, xy): only the normal strain
    # components of the cell are relaxed; shear components stay fixed.
    ucf = FrechetCellFilter(atoms, mask=3 * [True] + 3 * [False])
    dyn = FIRE(ucf)
    # Optimizer.run reports whether fmax was reached. Check it so that an
    # unconverged geometry does not silently seed the force scan.
    converged = dyn.run(fmax=RELAX_FMAX, steps=RELAX_STEPS)
    if not converged:
        warnings.warn(
            f"{label}: relaxation did not reach fmax={RELAX_FMAX} within "
            f"{RELAX_STEPS} steps; the force scan will start from a "
            "non-equilibrium structure.",
            RuntimeWarning,
            stacklevel=1,
        )
    # Drop FixSymmetry so the stored copy carries no constraints.
    atoms.set_constraint()
    opt_atoms.append(atoms.copy())

# Empty per-calculator lists that collect one force array per scan step.
forces, forces_nc, forces_c_as_nc = [], [], []


def move_and_get_forces(
    atoms, calc, index, displacement, *, supercell=(4, 4, 4), axis=2
):
    """Return the forces on a supercell in which one atom has been displaced.

    ``atoms`` itself is left untouched, because ``Atoms.repeat`` builds and
    returns a new object.

    Parameters
    ----------
    atoms : ase.Atoms
        Cell to expand into a supercell.
    calc
        ASE calculator attached to the displaced supercell.
    index : int
        Index of the atom to move, counted in the repeated supercell.
    displacement : float
        Amount added to that atom's Cartesian coordinate along ``axis``.
    supercell : tuple of int, keyword-only
        Repetitions passed to ``Atoms.repeat``. Defaults to ``(4, 4, 4)``.
    axis : int, keyword-only
        Cartesian axis of the displacement (0=x, 1=y, 2=z). Defaults to z.

    Returns
    -------
    numpy.ndarray
        Result of ``get_forces()`` on the displaced supercell.
    """
    # repeat() already works on a copy, so no separate copy() is needed.
    displaced = atoms.repeat(supercell)
    displaced.positions[index, axis] += displacement
    displaced.calc = calc
    return displaced.get_forces()


# For each displacement on the grid, query all three calculators in order
# (c, nc, c-as-nc). Each one sees its own relaxed structure with atom 0 moved.
scan_jobs = (
    (opt_atoms[0], calc, forces),
    (opt_atoms[1], calc_nc, forces_nc),
    (opt_atoms[2], calc_c_as_nc, forces_c_as_nc),
)
for dx in np.linspace(-0.15, 0.15, 100):
    for relaxed_atoms, model_calc, trace in scan_jobs:
        trace.append(move_and_get_forces(relaxed_atoms, model_calc, 0, dx))

# Stack each list of per-step force arrays along a leading displacement axis.
forces, forces_nc, forces_c_as_nc = map(
    np.array, (forces, forces_nc, forces_c_as_nc)
)

Path("data").mkdir(exist_ok=True)
for target, stacked in (
    ("data/forces_pet-oam-c.npy", forces),
    ("data/forces_pet-oam-nc.npy", forces_nc),
    ("data/forces_pet-oam-c-as-nc.npy", forces_c_as_nc),
):
    np.save(target, stacked)
