from ase.io import Trajectory
import numpy as np
import torch
import tqdm
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import NeighborListOptions, systems_to_torch
from metatrain.utils.data import DiskDatasetWriter
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists
import copy
import sys


# Frame offset(s) between the "now" and "ahead" snapshots, read from the first CLI argument.
time_lags = [int(sys.argv[1])]
# One output archive per time lag, named after that lag.
disk_dataset_writers = {
    time_lag: DiskDatasetWriter(f"universal_{time_lag}.zip")
    for time_lag in time_lags
}
# Stride (in trajectory steps) between successive starting frames:
# 2000 steps = 500 fs, i.e. 0.25 fs per step.
correlation_time = 2_000

def write_to_dataset(frame_now, frame_ahead, time_lag, i, disk_dataset_writer):
    system = systems_to_torch(frame_now, dtype=torch.float64)
    system = get_system_with_neighbor_lists(
        system,
        [NeighborListOptions(cutoff=5.0, full_list=True, strict=True)],
    )
    system.add_data(
        "momenta",
        TensorMap(
            keys=Labels.single(),
            blocks=[
                TensorBlock(
                    values=torch.tensor(frame_now.get_momenta(), dtype=torch.float64).unsqueeze(-1),
                    samples=Labels(
                        names=["system", "atom"],
                        values=torch.tensor([[i, j] for j in range(len(frame_now))]),
                    ),
                    components=[Labels(names="xyz", values=torch.tensor([[0], [1], [2]]))],
                    properties=Labels.single(),
                )
            ],
        )
    )
    masses = frame_now.get_masses()[:, np.newaxis]
    system.add_data(
        "masses",
        TensorMap(
            keys=Labels.single(),
            blocks = [
                TensorBlock(
                    values=torch.tensor(masses, dtype=torch.float64),
                    samples=Labels(
                        names=["system", "atom"],
                        values=torch.tensor([[i, j] for j in range(len(frame_now))]),
                    ),
                    components=[],
                    properties=Labels.single(),
                )
            ],
        )
    )

    distances = frame_ahead.get_positions() - frame_now.get_positions()
    if np.any(np.abs(distances) > 10.0 * 0.25 * time_lag):
        # cut anything over 10 angstrom/fs speed, these have to be wrong
        return False
    delta_q = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.tensor(distances*np.sqrt(masses), dtype=torch.float64).unsqueeze(-1),
                samples=Labels(
                    names=["system", "atom"],
                    values=torch.tensor([[i, j] for j in range(len(frame_now))]),
                ),
                components=[Labels(names="xyz", values=torch.tensor([[0], [1], [2]]))],
                properties=Labels.single(),
            )
        ],
    )
    p_prime = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.tensor(frame_ahead.get_momenta()/np.sqrt(masses), dtype=torch.float64).unsqueeze(-1),
                samples=Labels(
                    names=["system", "atom"],
                    values=torch.tensor([[i, j] for j in range(len(frame_now))]),
                ),
                components=[Labels(names="xyz", values=torch.tensor([[0], [1], [2]]))],
                properties=Labels.single(),
            )
        ],
    )
    energy = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.tensor([frame_ahead.info["total_energy"]], dtype=torch.float64).unsqueeze(-1),
                samples=Labels(names=["system"], values=torch.tensor([[i]])),
                components=[],
                properties=Labels.single(),
            )
        ],
    )
    disk_dataset_writer.write_sample(
        system,
        {
            f"mtt::delta_{time_lag}_q": delta_q,
            f"mtt::p_{time_lag}": p_prime,
            f"mtt::energy_{time_lag}": energy
        }
    )
    return True

assert len(time_lags) == 1, "Only one time lag is supported at the moment"
# see structure counting below

structure_counter = 0
for trj_num in range(10000):
    for temp in range(10):
        print(trj_num, temp)
        try:
            traj = Trajectory(f'/work/anon/anonymous/skipmd_fine_ts_trajs/str{trj_num}/str{trj_num}_NVE_{temp}.traj')
        except:
            print(f"Trajectory {trj_num} at temperature {temp} not found.")
            continue
        traj_len = len(traj)
        for i in tqdm.tqdm(range(0, traj_len-max(time_lags), correlation_time)):
            for time_lag in time_lags:
                frame_now = traj[i]
                frame_ahead = traj[i+time_lag]
                written = write_to_dataset(frame_now, frame_ahead, time_lag, structure_counter, disk_dataset_writers[time_lag])
                if written:
                    structure_counter += 1
                else:
                    print("Not written")
                frame_now_trev = copy.deepcopy(frame_now)
                frame_ahead_trev = copy.deepcopy(frame_ahead)
                frame_now_trev.set_momenta(-frame_now_trev.get_momenta())
                frame_ahead_trev.set_momenta(-frame_ahead_trev.get_momenta())
                written = write_to_dataset(frame_ahead_trev, frame_now_trev, time_lag, structure_counter+1, disk_dataset_writers[time_lag])
                if written:
                    structure_counter += 1
                else:
                    print("Not written")

for k in list(disk_dataset_writers.keys()):
    disk_dataset_writer = disk_dataset_writers.pop(k)
    del disk_dataset_writer
