import numpy as np
from pathlib import Path

import ase.io
import tqdm

from fairchem.core.datasets import AseDBDataset
import glob
import tqdm


# Discover every *.aselmdb file one directory below the working directory.
# Sorted so the structure order (and therefore the memmap layout) is
# reproducible across runs instead of depending on filesystem listing order.
all_aselmdb_files = sorted(glob.glob("./*/*.aselmdb"))

# ---------------------------------------------------------------------------
# Pass 1: count structures and atoms so the memmaps can be sized up front.
# ---------------------------------------------------------------------------
structure_count = 0
atom_counts = []
# (path, number of structures) for every file that opened successfully here.
# Pass 2 must visit exactly these files, in this order, with these lengths;
# otherwise the offsets saved in n.npy would not match the data written.
loaded_files = []
for dataset_path in all_aselmdb_files:
    try:
        print(f"Reading {dataset_path}")
        dataset = AseDBDataset(config=dict(src=dataset_path))
    except Exception as e:
        print(f"Failed to load dataset {dataset_path}: {e}")
        continue

    loaded_files.append((dataset_path, len(dataset)))
    for index in range(len(dataset)):
        atoms = dataset.get_atoms(index)
        structure_count += 1
        atom_counts.append(len(atoms))

if structure_count == 0:
    # Sizing the memmaps below would require zero-length mappings, which
    # np.memmap cannot create; stop early with an explicit message instead.
    raise SystemExit("No structures found in ./*/*.aselmdb; nothing to write.")

# Whether to also store a 3x3 cell and a 3x3 stress tensor per structure.
do_cell_and_stress = True

root = Path("omat_train_mm")
root.mkdir(exist_ok=True)

N_path = root / "N.npy"  # number of structures (0-d integer array)
n_path = root / "n.npy"  # int64 atom offsets: structure k owns atoms n[k]:n[k+1]
a_path = root / "a.bin"  # atomic numbers, int32, shape (total_atoms,)
x_path = root / "x.bin"  # positions, float32, shape (total_atoms, 3)
c_path = root / "c.bin"  # cells, float32, shape (N, 3, 3)
e_path = root / "e.bin"  # energies, float32, shape (N, 1)
f_path = root / "f.bin"  # forces, float32, shape (total_atoms, 3)
s_path = root / "s.bin"  # stresses (full 3x3, not Voigt), float32, shape (N, 3, 3)

N = structure_count
n = np.cumsum(np.array([0] + atom_counts, dtype=np.int64))
np.save(N_path, N)
np.save(n_path, n)

total_atoms = n[-1]
a_mm = np.memmap(a_path, dtype="int32", mode="w+", shape=(total_atoms,))
x_mm = np.memmap(x_path, dtype="float32", mode="w+", shape=(total_atoms, 3))
e_mm = np.memmap(e_path, dtype="float32", mode="w+", shape=(N, 1))
f_mm = np.memmap(f_path, dtype="float32", mode="w+", shape=(total_atoms, 3))
# Cell/stress maps are None when disabled so later code can test for them.
c_mm = s_mm = None
if do_cell_and_stress:
    c_mm = np.memmap(c_path, dtype="float32", mode="w+", shape=(N, 3, 3))
    s_mm = np.memmap(s_path, dtype="float32", mode="w+", shape=(N, 3, 3))

# ---------------------------------------------------------------------------
# Pass 2: fill the memmaps, visiting exactly the files counted in pass 1.
# ---------------------------------------------------------------------------
i = 0  # global structure index across all files
for dataset_path, expected_len in loaded_files:
    # No try/except here: silently skipping a file now would shift every
    # later structure relative to the offsets already saved in n.npy.
    print(f"Reading {dataset_path}")
    dataset = AseDBDataset(config=dict(src=dataset_path))
    if len(dataset) != expected_len:
        raise RuntimeError(
            f"{dataset_path} has {len(dataset)} structures now but had "
            f"{expected_len} during the counting pass; aborting to avoid "
            f"writing data inconsistent with {n_path}"
        )

    for index in tqdm.tqdm(range(len(dataset))):
        s = dataset.get_atoms(index)
        lo, hi = n[i], n[i + 1]  # this structure's slice in the per-atom arrays

        a_mm[lo:hi] = s.numbers
        x_mm[lo:hi] = s.get_positions()
        if do_cell_and_stress:
            c_mm[i] = s.get_cell()[:]
        e_mm[i] = s.get_potential_energy()
        f_mm[lo:hi] = s.get_forces()
        if do_cell_and_stress:
            s_mm[i] = s.get_stress(voigt=False)

        i += 1

# Push all written pages to disk; disabled cell/stress maps are None.
for mm in (a_mm, x_mm, c_mm, e_mm, f_mm, s_mm):
    if mm is not None:
        mm.flush()
