import torch
from ase.io import write
from ase.calculators.singlepoint import SinglePointCalculator
import sys
from ase.io import read
from tqdm import tqdm
import os
from metatomic.torch.ase_calculator import MetatomicCalculator
from metatomic.torch import load_atomistic_model
from pet_mad.utils import get_so3_rotations, rotate_atoms, compute_rotational_average
from typing import Dict, Any

# Lebedev grid order passed to ``get_so3_rotations``.
LEBEDEV_GRID_ORDER = 5
# Number of primitive rotations passed to ``get_so3_rotations``; presumably
# combined with the Lebedev grid to sample SO(3) -- TODO confirm in pet_mad.utils.
NUM_PRIMITIVE_ROTATIONS = 12
# Number of rotated copies of one structure evaluated per ``compute_energy`` call.
BATCH_SIZE = 8
# "float64" casts the model to double precision before evaluation; any other
# value leaves the model in the dtype it was loaded with.
DTYPE = "float64"

USAGE = "usage: {prog} MODEL_PATH DATASET_PATH WORKDIR"


def is_non_conservative(model_path: str) -> bool:
    """Return True when the model *file name* contains "nc" (case-insensitive).

    Only the base name is inspected. Matching against the full path would also
    fire on unrelated directory names such as ``benchmarks/`` or ``sync/`` and
    silently switch the calculator to non-conservative forces.
    """
    return "nc" in os.path.basename(model_path).lower()


def split_into_batches(items, size: int) -> list:
    """Split the sequence ``items`` into consecutive slices of at most ``size`` elements."""
    return [items[start : start + size] for start in range(0, len(items), size)]


def concatenate_batch_results(batch_results_list) -> Dict[str, Any]:
    """Merge several per-batch result dicts into one list per key.

    A bare ``float`` value is appended as a single element. Any other value is
    treated as a sequence and its elements are appended in order.
    """
    merged: Dict[str, Any] = {}
    for batch_results in batch_results_list:
        for key, value in batch_results.items():
            merged.setdefault(key, []).extend(
                [value] if isinstance(value, float) else value
            )
    return merged


def predict_structure(calculator, item, rotations):
    """Evaluate ``item`` under every rotation and return an annotated copy.

    The copy carries a ``SinglePointCalculator`` with the ``energy`` and
    ``forces`` returned by ``compute_rotational_average``. Its ``info`` also
    stores that function's ``energy_rot_std`` / ``forces_rot_std`` outputs as
    ``energy_rot_discrepancy`` / ``forces_rot_discrepancy``.
    """
    rotated = rotate_atoms(item, rotations)
    # The positional ``True`` presumably requests forces (and stresses) as well
    # as energies -- NOTE(review): confirm against MetatomicCalculator.compute_energy.
    batch_results = [
        calculator.compute_energy(batch, True)
        for batch in split_into_batches(rotated, BATCH_SIZE)
    ]
    averaged = compute_rotational_average(
        concatenate_batch_results(batch_results), rotations
    )

    prediction = item.copy()
    prediction.info["energy_rot_discrepancy"] = averaged["energy_rot_std"]
    prediction.info["forces_rot_discrepancy"] = averaged["forces_rot_std"]
    prediction.calc = SinglePointCalculator(
        atoms=prediction, energy=averaged["energy"], forces=averaged["forces"]
    )
    return prediction


def main(argv=None) -> None:
    """Write rotationally averaged predictions for every structure in a dataset.

    ``argv`` holds the command-line arguments without the program name and
    defaults to ``sys.argv[1:]``. It must contain MODEL_PATH, DATASET_PATH and
    WORKDIR; any extra arguments are ignored. Raises ``SystemExit`` with a
    usage message when fewer than three arguments are given.

    Results go to ``predictions_float64.xyz`` (when DTYPE == "float64") or
    ``predictions.xyz`` inside WORKDIR, which is created if missing.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 3:
        raise SystemExit(USAGE.format(prog=os.path.basename(sys.argv[0])))
    model_path, dataset_path, workdir = args[:3]
    dataset_name = os.path.basename(dataset_path)

    structures = read(dataset_path, index=":")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(workdir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    non_conservative = is_non_conservative(model_path)

    model = load_atomistic_model(model_path)
    if DTYPE == "float64":
        # Mutates the object returned by capabilities(). This only has an effect
        # if the model returns its stored capabilities rather than a copy.
        # NOTE(review): confirm against the metatomic API.
        model.capabilities().dtype = "float64"
        model.to(dtype=torch.float64, device=device)
    calculator = MetatomicCalculator(
        model, device=device, non_conservative=non_conservative
    )

    print(
        f"Running {model_path} on {dataset_name} with non_conservative={non_conservative}"
    )

    rotations = get_so3_rotations(LEBEDEV_GRID_ORDER, NUM_PRIMITIVE_ROTATIONS)
    predictions = [
        predict_structure(calculator, item, rotations) for item in tqdm(structures)
    ]

    filename = "predictions_float64.xyz" if DTYPE == "float64" else "predictions.xyz"
    predictions_path = os.path.join(workdir, filename)
    write(predictions_path, predictions)
    print(f"Predictions written to {predictions_path}")


if __name__ == "__main__":
    main()
