import numpy as np
from pathlib import Path

import ase.io
import tqdm
import glob

import sys


# Dataset split name from the command line. Input is read from "<split>.xyz";
# output is written into the directory "spice_<split>_mm/".
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} SPLIT  (reads SPLIT.xyz, writes spice_SPLIT_mm/)")
split = sys.argv[1]

# When True, also write per-structure cell matrices (c.bin) and 3x3 stress
# tensors (s.bin). Enabling it requires every structure to provide a cell and
# a stress (s.get_stress) — disabled here.
do_cell_and_stress = False

# index=":" makes ase.io.read return a list with every frame in the file.
structures = ase.io.read(f"{split}.xyz", index=":")
if not structures:
    # np.memmap cannot map a zero-byte file, so the allocation step below
    # would fail with a cryptic error; stop here with a clear message instead.
    sys.exit(f"no structures found in {split}.xyz")

root = Path(f"spice_{split}_mm")
root.mkdir(exist_ok=True)

# One output file per array, all inside `root`.
N_path = root / "N.npy"  # number of structures (0-d array)
n_path = root / "n.npy"  # atom offsets: structure i owns atoms n[i]:n[i+1]
a_path = root / "a.bin"  # atomic numbers (per atom)
x_path = root / "x.bin"  # positions (per atom)
c_path = root / "c.bin"  # cell matrices (per structure, optional)
e_path = root / "e.bin"  # potential energies (per structure)
f_path = root / "f.bin"  # forces (per atom)
s_path = root / "s.bin"  # stress tensors (per structure, optional)

N = len(structures)

# n[0] = 0 and n[i + 1] = n[i] + (atom count of structure i), stored as int64
# so that the offsets cannot overflow on large datasets.
atom_counts = np.array([len(s) for s in structures], dtype=np.int64)
n = np.zeros(N + 1, dtype=np.int64)
np.cumsum(atom_counts, out=n[1:])

np.save(N_path, N)
np.save(n_path, n)

# Create (or overwrite, mode="w+") the on-disk arrays. Per-atom arrays have
# n[-1] rows (total atom count); per-structure arrays have N rows.
a_mm = np.memmap(a_path, dtype="int32", mode="w+", shape=(n[-1],))  # atomic numbers
x_mm = np.memmap(x_path, dtype="float32", mode="w+", shape=(n[-1], 3))  # positions
if do_cell_and_stress:
    c_mm = np.memmap(c_path, dtype="float32", mode="w+", shape=(N, 3, 3))  # cells
e_mm = np.memmap(e_path, dtype="float32", mode="w+", shape=(N, 1))  # energies
f_mm = np.memmap(f_path, dtype="float32", mode="w+", shape=(n[-1], 3))  # forces
if do_cell_and_stress:
    s_mm = np.memmap(s_path, dtype="float32", mode="w+", shape=(N, 3, 3))  # stresses

# Copy each structure into its slice of the memmaps. The float64 values from
# ASE are stored as float32 (the memmap dtypes).
# total=N lets tqdm show percentage and ETA; an enumerate object has no len().
for i, s in tqdm.tqdm(enumerate(structures), total=N):
    start, stop = n[i], n[i + 1]
    a_mm[start:stop] = s.numbers
    x_mm[start:stop] = s.get_positions()
    if do_cell_and_stress:
        c_mm[i] = s.get_cell()[:]
    e_mm[i] = s.get_potential_energy()
    f_mm[start:stop] = s.get_forces()
    if do_cell_and_stress:
        s_mm[i] = s.get_stress(voigt=False)  # full 3x3 tensor, not Voigt form

# Write any pending changes so the .bin files are complete on disk.
a_mm.flush()
x_mm.flush()
if do_cell_and_stress:
    c_mm.flush()
e_mm.flush()
f_mm.flush()
if do_cell_and_stress:
    s_mm.flush()
