import numpy as np
from pathlib import Path

import ase.io
import tqdm

from fairchem.core.datasets import AseDBDataset
import glob
import tqdm

# When True, the per-structure cell (c.bin) and stress (s.bin) arrays are
# allocated and filled in addition to numbers/positions/energy/forces.
do_cell_and_stress = True

all_aselmdb_files = glob.glob("../salex/train/*.aselmdb")

# First pass: count structures and record each structure's atom count so the
# flat output arrays can be sized and the per-structure offsets computed
# before anything is written.
structure_count = 0
atom_counts = []
readable_aselmdb_files = []
for dataset_path in all_aselmdb_files:
    print(f"Reading {dataset_path}")
    try:
        dataset = AseDBDataset(config=dict(src=dataset_path))
    except Exception as e:
        # Best effort: an unreadable file is reported and left out.
        print(f"Failed to load dataset {dataset_path}: {e}")
        continue
    readable_aselmdb_files.append(dataset_path)

    for index in range(len(dataset)):
        atoms = dataset.get_atoms(index)
        structure_count += 1
        atom_counts.append(len(atoms))

# Keep only the files that were actually counted. The write pass iterates
# this list again; retrying a file that failed here (and might open on the
# second attempt) would emit structures that have no slot in the offsets.
all_aselmdb_files = readable_aselmdb_files

# Read every frame (index=":") of every extxyz file and keep them all in one
# flat in-memory list, in file order.
all_files = glob.glob("../mptrj/mptrj-gga-ggapu/*.extxyz")
structures_mptrj = [
    frame
    for extxyz_path in tqdm.tqdm(all_files)
    for frame in ase.io.read(extxyz_path, index=":")
]

# Output directory for the packed dataset (created if missing).
root = Path("mpa_train_mm")
root.mkdir(exist_ok=True)

# N.npy holds the structure count, n.npy the cumulative atom offsets.
N_path, n_path = (root / filename for filename in ("N.npy", "n.npy"))

# One raw binary file per quantity: a = atomic numbers, x = positions,
# c = cell, e = potential energy, f = forces, s = stress.
a_path, x_path, c_path, e_path, f_path, s_path = (
    root / f"{key}.bin" for key in "axcefs"
)

# The mptrj frames are stored eight times over, after the aselmdb structures.
mptrj_atom_counts = [len(frame) for frame in structures_mptrj]

# N: total number of stored structures.
N = structure_count + 8 * len(structures_mptrj)

# n: cumulative atom offsets with a leading 0, so structure k owns the rows
# n[k]:n[k+1] of the per-atom arrays and n[-1] is the total atom count.
atoms_per_structure = [0, *atom_counts, *(8 * mptrj_atom_counts)]
n = np.cumsum(np.array(atoms_per_structure, dtype=np.int64))

np.save(N_path, N)
np.save(n_path, n)

# Allocate the output files as writable memory maps ("w+" creates or
# overwrites). Per-atom arrays have n[-1] rows, per-structure arrays N rows.
total_atoms = n[-1]

a_mm = np.memmap(a_path, dtype=np.int32, mode="w+", shape=(total_atoms,))
x_mm = np.memmap(x_path, dtype=np.float32, mode="w+", shape=(total_atoms, 3))
if do_cell_and_stress:
    c_mm = np.memmap(c_path, dtype=np.float32, mode="w+", shape=(N, 3, 3))
e_mm = np.memmap(e_path, dtype=np.float32, mode="w+", shape=(N, 1))
f_mm = np.memmap(f_path, dtype=np.float32, mode="w+", shape=(total_atoms, 3))
if do_cell_and_stress:
    s_mm = np.memmap(s_path, dtype=np.float32, mode="w+", shape=(N, 3, 3))

# Second pass over the aselmdb files: write every structure into the slot
# reserved for it by the precomputed offsets `n`. `i` is the running
# structure index and is carried on into the mptrj section that follows.
i = 0
for dataset_path in all_aselmdb_files:
    print(f"Reading {dataset_path}")
    try:
        dataset = AseDBDataset(config=dict(src=dataset_path))
    except Exception as e:
        print(f"Failed to load dataset {dataset_path}: {e}")
        continue

    for index in tqdm.tqdm(range(len(dataset))):
        s = dataset.get_atoms(index)

        start, stop = n[i], n[i + 1]
        # The offsets come from the counting pass. If this pass sees different
        # data, a 1-atom structure would silently broadcast into a wrong-sized
        # slot, so check the size explicitly before writing.
        if len(s) != stop - start:
            raise RuntimeError(
                f"Structure {index} of {dataset_path} has {len(s)} atoms but "
                f"slot {i} was sized for {stop - start}; the counting pass and "
                f"the write pass disagree."
            )

        a_mm[start:stop] = s.numbers
        x_mm[start:stop] = s.get_positions()
        if do_cell_and_stress:
            c_mm[i] = s.get_cell()[:]
        e_mm[i] = s.get_potential_energy()
        f_mm[start:stop] = s.get_forces()
        if do_cell_and_stress:
            s_mm[i] = s.get_stress(voigt=False)
        i += 1

# A dataset that loaded in only one of the two passes leaves `i` out of step
# with the layout; everything written after this point would be misplaced.
if i != structure_count:
    raise RuntimeError(
        f"Wrote {i} aselmdb structures but {structure_count} were counted; "
        f"the output layout is inconsistent."
    )

# Append the in-memory mptrj frames eight times, continuing from the running
# structure index `i`; the offsets in `n` reserve the matching slots.
for replica in range(8):
    print(f"Processing replica {replica+1}/8 of mptrj data")
    for s in tqdm.tqdm(structures_mptrj):
        atom_rows = slice(n[i], n[i + 1])

        a_mm[atom_rows] = s.numbers
        x_mm[atom_rows] = s.get_positions()
        if do_cell_and_stress:
            c_mm[i] = s.get_cell()[:]
        e_mm[i] = s.get_potential_energy()
        f_mm[atom_rows] = s.get_forces()
        if do_cell_and_stress:
            s_mm[i] = s.get_stress(voigt=False)

        i += 1

# Flush every memory map so pending writes reach the files on disk. The maps
# are independent files, so the flush order does not matter.
for mm in (a_mm, x_mm, e_mm, f_mm):
    mm.flush()
if do_cell_and_stress:
    for mm in (c_mm, s_mm):
        mm.flush()
