import ase.io
from metatensor.torch.atomistic import systems_to_torch, load_atomistic_model
import torch
import metatensor.torch
import random
from metatensor.torch.atomistic import ModelEvaluationOptions, ModelOutput
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
import sys
from metatensor.torch import Labels, TensorBlock, TensorMap
import numpy as np
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.data import atomic_masses
import tqdm


# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing on the first tensor allocation.
device = "cuda" if torch.cuda.is_available() else "cpu"
# ASE's table of atomic masses as a tensor on `device`; the sampling loop
# indexes it with `system.types` to get one mass per atom.
atomic_masses_tensor = torch.tensor(atomic_masses, device=device)

# NOTE(review): not used anywhere in the visible part of the script; kept in
# case code outside this view relies on it.
all_values = {}

# Path of the exported model that gets evaluated.
model_string = "../../models/water_1fs_300.pt"

# One scalar per sampled configuration, appended by the sampling loop.
values = []

def get_random_degree_of_freedom(system):
    """Pick one Cartesian degree of freedom of ``system`` uniformly at random.

    Args:
        system: object whose ``positions`` has one entry per atom.

    Returns:
        ``(atom_index, component_index)`` with ``atom_index`` in
        ``[0, n_atoms - 1]`` and ``component_index`` in ``[0, 2]``.
    """
    n_dof = 3 * len(system.positions)
    # randrange(n) samples the same range as randint(0, n - 1); divmod splits
    # the flat index into (index // 3, index % 3).
    return divmod(random.randrange(n_dof), 3)


def _single_block_map(block_values, samples, components):
    """Wrap ``block_values`` in a one-block ``TensorMap`` on ``device``.

    Keys and properties are the trivial ``Labels.single()``; ``samples`` and
    ``components`` are used as given.
    """
    return TensorMap(
        keys=Labels.single().to(device),
        blocks=[
            TensorBlock(
                values=block_values,
                samples=samples,
                components=components,
                properties=Labels.single().to(device),
            )
        ],
    )


# ---------------------------------------------------------------------------
# Loop-invariant setup: none of this depends on the sampled momenta, so it is
# done once instead of on every iteration.
# ---------------------------------------------------------------------------
model = load_atomistic_model(model_string)
time_lag = 4
is_implicit = "implicit" in model_string
model = model.to(device)
if is_implicit:
    model.module.train(True)
requested_neighbor_lists = model.requested_neighbor_lists()

evaluation_options = ModelEvaluationOptions(
    length_unit="Angstrom",
    outputs={
        f"mtt::delta_{time_lag}_q": ModelOutput(per_atom=True),
        f"mtt::p_{time_lag}": ModelOutput(per_atom=True),
    },
)

# The structure itself never changes between samples; only the momenta do.
reference_atoms = ase.io.read("water.xyz")

# Labels shared by the per-atom "momenta" and "masses" blocks.
atom_samples = Labels(
    names=["system", "atom"],
    values=torch.tensor([[0, j] for j in range(len(reference_atoms))], device=device),
)
xyz_components = Labels(names="xyz", values=torch.tensor([[0], [1], [2]], device=device))

for _ in tqdm.tqdm(range(100)):

    # Fresh copy of the structure with newly drawn Maxwell-Boltzmann momenta.
    atoms = reference_atoms.copy()
    MaxwellBoltzmannDistribution(atoms, temperature_K=450.0)

    system = systems_to_torch(atoms, dtype=torch.float32, device=device, positions_requires_grad=True)
    system.add_data(
        "momenta",
        _single_block_map(
            # trailing axis of size 1 is the (single) property dimension
            torch.tensor(atoms.get_momenta(), dtype=torch.float32, device=device, requires_grad=True).unsqueeze(-1),
            atom_samples,
            [xyz_components],
        ),
    )
    system.add_data(
        "masses",
        _single_block_map(
            torch.tensor(atoms.get_masses(), dtype=torch.float32, device=device).unsqueeze(-1),
            atom_samples,
            [],
        ),
    )

    time_lag_tmap = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.tensor([[int(time_lag)]], dtype=torch.float32),
                samples=Labels.single(),
                components=[],
                properties=Labels.single(),
            )
        ],
    ).to(system.device)
    system.add_data("time_lag", time_lag_tmap)

    # Return value intentionally unused, as in the original call; presumably
    # the neighbor lists are attached to `system` in place -- TODO confirm.
    get_system_with_neighbor_lists(system, requested_neighbor_lists)

    model_outputs = model([system], evaluation_options, check_consistency=False)
    delta_q = model_outputs[f"mtt::delta_{time_lag}_q"]
    masses = atomic_masses_tensor[system.types]
    sqrt_masses = torch.sqrt(masses.unsqueeze(-1))
    # The two model outputs are rescaled by sqrt(mass): multiplied for the
    # momenta, divided for the position increment.
    P = model_outputs[f"mtt::p_{time_lag}"].block().values.squeeze(-1) * sqrt_masses
    Q = system.positions + delta_q.block().values.squeeze(-1) / sqrt_masses

    # get a random degree of freedom
    atom_index, component_index = get_random_degree_of_freedom(system)
    P_element = P[atom_index, component_index]
    Q_element = Q[atom_index, component_index]

    # Differentiate with respect to the very tensor stored on the system.
    momenta_values = system.get_data("momenta").block().values
    dPdp, dPdq = torch.autograd.grad(
        P_element,
        [momenta_values, system.positions],
        create_graph=False,
        retain_graph=True,  # the graph is needed again for Q_element below
    )
    dQdp, dQdq = torch.autograd.grad(
        Q_element,
        [momenta_values, system.positions],
    )
    # Gradients w.r.t. the momenta carry the trailing size-1 property axis;
    # drop it for BOTH so all four gradients are indexed the same way. The
    # position gradients have the shape of `system.positions` and are left
    # untouched.
    dPdp = dPdp.squeeze(-1)
    dQdp = dQdp.squeeze(-1)

    dPdq_element = dPdq[atom_index, component_index].item()
    dQdp_element = dQdp[atom_index, component_index].item()
    dQdq_element = dQdq[atom_index, component_index].item()
    dPdp_element = dPdp[atom_index, component_index].item()

    # Determinant of the 2x2 matrix [[dQ/dq, dQ/dp], [dP/dq, dP/dp]] for the
    # selected degree of freedom.
    values.append(dQdq_element * dPdp_element - dPdq_element * dQdp_element)

values = np.array(values)
print(np.mean(values))
