import ase.io
from metatensor.torch.atomistic import systems_to_torch, load_atomistic_model
import torch
import metatensor.torch
import random
from metatensor.torch.atomistic import ModelEvaluationOptions, ModelOutput
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
import sys
from metatensor.torch import Labels, TensorBlock, TensorMap
import numpy as np
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.data import atomic_masses
import tqdm


device = "cuda"
# Per-element mass table from ``ase.data``; indexed below with ``system.types``.
atomic_masses_tensor = torch.tensor(atomic_masses, device=device)

# Maps "<model_type>-<time_step>" to the mean of the values sampled for that model.
all_values = {}

# Number of independent velocity draws evaluated for every model.
n_samples = 100


def get_random_degree_of_freedom(system):
    """Pick one Cartesian degree of freedom of ``system`` uniformly at random.

    :param system: object whose ``positions`` has one row per atom
    :return: tuple ``(atom_index, component_index)`` with
        ``0 <= atom_index < n_atoms`` and ``0 <= component_index <= 2``
    """
    n_dof = 3 * len(system.positions)
    # flat index -> (atom, xyz component)
    return divmod(random.randint(0, n_dof - 1), 3)


def _single_block_tensormap(values, samples, components):
    """Wrap ``values`` in a one-block TensorMap with trivial keys and properties.

    :param values: tensor stored in the block (already on the target device)
    :param samples: ``Labels`` describing the first dimension of ``values``
    :param components: list of ``Labels``, one per intermediate dimension
    :return: ``TensorMap`` living on the same device as ``values``
    """
    single = Labels.single().to(values.device)
    return TensorMap(
        keys=single,
        blocks=[
            TensorBlock(
                values=values,
                samples=samples,
                components=components,
                properties=single,
            )
        ],
    )


# The starting structure is the same for every sample and every model, so it is
# read once; each sample works on its own copy with freshly drawn momenta.
reference_atoms = ase.io.read("water.xyz")

# Metadata shared by all per-atom quantities attached to the systems.
atom_samples = Labels(
    names=["system", "atom"],
    values=torch.tensor(
        [[0, j] for j in range(len(reference_atoms))], device=device
    ),
)
xyz_components = Labels(
    names="xyz", values=torch.tensor([[0], [1], [2]], device=device)
)

for model_type in ["water", "universal"]:
    for time_step in [1, 4, 16]:
        print(f"{model_type}-{time_step}")
        model_string = f"../../models/{model_type}_{time_step}fs.pt"

        # Everything below depends only on the model, not on the sample, so it
        # is set up once per model instead of once per sample.
        model = load_atomistic_model(model_string)
        # Time lag expressed in units of the model's base time step. ``time_step``
        # is the same number that is embedded in the file name.
        # NOTE(review): ``int`` truncates; this assumes ``time_step`` is an
        # exact multiple of ``base_time_step`` -- confirm.
        time_lag = int(time_step / model.module.base_time_step)
        is_implicit = "implicit" in model_string
        model = model.to(device)
        if is_implicit:
            model.module.train(True)
        requested_neighbor_lists = model.requested_neighbor_lists()

        time_lag_tmap = _single_block_tensormap(
            values=torch.tensor([[time_lag]], dtype=torch.float32, device=device),
            samples=Labels.single().to(device),
            components=[],
        )

        delta_q_name = f"mtt::delta_{time_lag}_q"
        p_name = f"mtt::p_{time_lag}"
        evaluation_options = ModelEvaluationOptions(
            length_unit="Angstrom",
            outputs={
                delta_q_name: ModelOutput(per_atom=True),
                p_name: ModelOutput(per_atom=True),
            },
        )

        values = []

        for _ in tqdm.tqdm(range(n_samples)):
            # fresh copy of the structure with newly drawn momenta
            atoms = reference_atoms.copy()
            MaxwellBoltzmannDistribution(atoms, temperature_K=450.0)

            system = systems_to_torch(
                atoms,
                dtype=torch.float32,
                device=device,
                positions_requires_grad=True,
            )
            momenta_values = torch.tensor(
                atoms.get_momenta(),
                dtype=torch.float32,
                device=device,
                requires_grad=True,
            ).unsqueeze(-1)
            system.add_data(
                "momenta",
                _single_block_tensormap(
                    momenta_values, atom_samples, [xyz_components]
                ),
            )
            masses_values = torch.tensor(
                atoms.get_masses(), dtype=torch.float32, device=device
            ).unsqueeze(-1)
            system.add_data(
                "masses",
                _single_block_tensormap(masses_values, atom_samples, []),
            )
            system.add_data("time_lag", time_lag_tmap)

            # Return value intentionally unused, as in the original script; the
            # neighbor lists are presumably attached to ``system`` in place.
            get_system_with_neighbor_lists(system, requested_neighbor_lists)

            model_outputs = model(
                [system], evaluation_options, check_consistency=False
            )
            delta_q = model_outputs[delta_q_name]

            # Rescale the raw model outputs with sqrt(mass) to obtain the new
            # momenta P and the new positions Q.
            sqrt_masses = torch.sqrt(
                atomic_masses_tensor[system.types].unsqueeze(-1)
            )
            P = model_outputs[p_name].block().values.squeeze(-1) * sqrt_masses
            Q = system.positions + delta_q.block().values.squeeze(-1) / sqrt_masses

            # Differentiate one randomly chosen degree of freedom of the outputs
            # with respect to the input momenta and positions.
            atom_index, component_index = get_random_degree_of_freedom(system)
            grad_inputs = [
                system.get_data("momenta").block().values,
                system.positions,
            ]
            # the graph is needed again for Q, hence retain_graph=True
            dPdp, dPdq = torch.autograd.grad(
                P[atom_index, component_index],
                grad_inputs,
                create_graph=False,
                retain_graph=True,
            )
            dQdp, dQdq = torch.autograd.grad(
                Q[atom_index, component_index],
                grad_inputs,
            )
            # Momenta carry a trailing property dimension of size 1; drop it so
            # all four gradients are indexed the same way.
            dPdp = dPdp.squeeze(-1)
            dQdp = dQdp.squeeze(-1)

            dPdp_element = dPdp[atom_index, component_index].item()
            dPdq_element = dPdq[atom_index, component_index].item()
            dQdp_element = dQdp[atom_index, component_index].item()
            dQdq_element = dQdq[atom_index, component_index].item()

            # 2x2 determinant of d(Q, P)/d(q, p) restricted to the chosen
            # degree of freedom.
            values.append(
                dQdq_element * dPdp_element - dPdq_element * dQdp_element
            )

        mean_value = np.mean(np.array(values))
        print(mean_value)
        all_values[f"{model_type}-{time_step}"] = mean_value

print(all_values)
