from ase import Atoms
from ase.io import Trajectory
import numpy as np
import ase.io
import torch
import tqdm
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import NeighborListOptions, systems_to_torch
from metatrain.utils.data import DiskDatasetWriter
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
import copy


# Frame offsets between the "now" and "ahead" frames: powers of two, 1 ... 1024.
time_lags = [2**exponent for exponent in range(11)]
# One output archive per lag, e.g. "water_scaled_16.zip".
disk_dataset_writers = {
    time_lag: DiskDatasetWriter(f"water_scaled_{time_lag}.zip")
    for time_lag in time_lags
}
# Stride (in trajectory frames) between consecutive starting frames.
correlation_time = 400

def _single_block_tensor_map(values, samples, components):
    """Wrap ``values`` in a one-block ``TensorMap`` with a single key and property.

    Args:
        values: float tensor whose first axis matches ``samples``, whose
            middle axes match ``components`` and whose last axis is the
            single property.
        samples: ``Labels`` describing the first axis of ``values``.
        components: list of ``Labels`` for the intermediate axes (may be empty).
    """
    return TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=values,
                samples=samples,
                components=components,
                properties=Labels.single(),
            )
        ],
    )


def write_to_dataset(frame_now, frame_ahead, time_lag, i, disk_dataset_writer):
    """Build one training sample from a pair of frames and write it to disk.

    The input system is ``frame_now`` converted to torch (float64), with a
    full, strict neighbor list at a 5.0 cutoff. It carries two extra data
    entries: ``"momenta"`` (raw momenta of ``frame_now``) and ``"masses"``.
    The targets describe ``frame_ahead`` in mass-weighted coordinates:

    - ``mtt::delta_{time_lag}_q``: ``(q_ahead - q_now) * sqrt(m)``
    - ``mtt::p_{time_lag}``: ``p_ahead / sqrt(m)``
    - ``mtt::energy_{time_lag}``: ``frame_ahead.info["total_energy"]``

    Args:
        frame_now: ``ase.Atoms`` for the starting frame; must carry momenta.
        frame_ahead: ``ase.Atoms`` for the later frame, with the same atoms in
            the same order. It must carry momenta and
            ``info["total_energy"]``.
        time_lag: separation between the two frames. It is used in the
            target names and in the displacement sanity check.
        i: index of this sample, used as the ``system`` sample label.
        disk_dataset_writer: writer that receives the sample via
            ``write_sample(system, targets)``.
    """
    # All per-atom blocks of this sample share identical (system, atom)
    # samples and the same Cartesian component axis, so build them once.
    atom_samples = Labels(
        names=["system", "atom"],
        values=torch.tensor([[i, j] for j in range(len(frame_now))]),
    )
    xyz = Labels(names="xyz", values=torch.tensor([[0], [1], [2]]))

    system = systems_to_torch(frame_now, dtype=torch.float64)
    system = get_system_with_neighbor_lists(
        system,
        [NeighborListOptions(cutoff=5.0, full_list=True, strict=True)],
    )
    system.add_data(
        "momenta",
        _single_block_tensor_map(
            torch.tensor(frame_now.get_momenta(), dtype=torch.float64).unsqueeze(-1),
            atom_samples,
            [xyz],
        ),
    )

    # Column vector so it broadcasts against the per-atom (x, y, z) arrays below.
    masses = frame_now.get_masses()[:, np.newaxis]
    sqrt_masses = np.sqrt(masses)
    system.add_data(
        "masses",
        _single_block_tensor_map(
            torch.tensor(masses, dtype=torch.float64), atom_samples, []
        ),
    )

    distances = frame_ahead.get_positions() - frame_now.get_positions()
    # Heuristic sanity check: flag any per-axis displacement above 0.1 A per
    # unit of lag (presumably to catch atoms wrapped across a periodic
    # boundary -- confirm).
    if np.any(np.abs(distances) > 1.0 * time_lag / 10.0):
        print(f"WARNING: distance between atoms is too large: {np.max(np.abs(distances))} A, {time_lag} fs")

    # Targets are expressed in mass-weighted coordinates: q * sqrt(m), p / sqrt(m).
    delta_q = _single_block_tensor_map(
        torch.tensor(distances * sqrt_masses, dtype=torch.float64).unsqueeze(-1),
        atom_samples,
        [xyz],
    )
    p_prime = _single_block_tensor_map(
        torch.tensor(frame_ahead.get_momenta() / sqrt_masses, dtype=torch.float64).unsqueeze(-1),
        atom_samples,
        [xyz],
    )
    energy = _single_block_tensor_map(
        torch.tensor([frame_ahead.info["total_energy"]], dtype=torch.float64).unsqueeze(-1),
        Labels(names=["system"], values=torch.tensor([[i]])),
        [],
    )

    disk_dataset_writer.write_sample(
        system,
        {
            f"mtt::delta_{time_lag}_q": delta_q,
            f"mtt::p_{time_lag}": p_prime,
            f"mtt::energy_{time_lag}": energy,
        },
    )

# Walk every (box size, temperature) trajectory. For each starting frame (one
# every `correlation_time` frames) and each lag in `time_lags`, write a forward
# sample and its time-reversed counterpart to that lag's dataset.
structure_counter = 0
for size in ["smaller", "normal", "larger"]:
    for temperature in [20.0 * (k + 1) for k in range(0, 50)]:
        print(size, temperature)
        # The "normal" trajectories carry no size tag in their file name.
        if size == "normal":
            traj = ase.io.read(f"water_{temperature}.traj", index=":")
        else:
            traj = ase.io.read(f"water_{size}_{temperature}.traj", index=":")
        # Stop early enough that frame `i + max(time_lags)` still exists.
        for i in tqdm.tqdm(range(0, len(traj) - max(time_lags), correlation_time)):
            frame_now = traj[i]
            # The time-reversed starting frame does not depend on the lag, so
            # build it once per starting frame instead of once per lag
            # (write_to_dataset only reads the frames it is given).
            frame_now_trev = copy.deepcopy(frame_now)
            frame_now_trev.set_momenta(-frame_now_trev.get_momenta())
            for time_lag in time_lags:
                writer = disk_dataset_writers[time_lag]
                frame_ahead = traj[i + time_lag]
                write_to_dataset(frame_now, frame_ahead, time_lag, structure_counter, writer)
                # Time reversal: start from the later frame with negated
                # momenta and take the (reversed) earlier frame as the target.
                frame_ahead_trev = copy.deepcopy(frame_ahead)
                frame_ahead_trev.set_momenta(-frame_ahead_trev.get_momenta())
                write_to_dataset(frame_ahead_trev, frame_now_trev, time_lag, structure_counter + 1, writer)
            # Each lag's dataset received two systems (forward + reversed).
            structure_counter += 2

# Drop the last references to the writers.
# NOTE(review): this relies on DiskDatasetWriter finalizing its archive when it
# is garbage-collected -- confirm against the metatrain version in use.
disk_dataset_writers.clear()
