from .velocity_verlet import VelocityVerlet
import ase.units
from ase.md.md import MolecularDynamics
from typing import List
# from ..utils.pretrained import load_pretrained_models
from metatensor.torch.atomistic import MetatensorAtomisticModel
from metatensor.torch import Labels, TensorBlock, TensorMap
import ase.units
import torch
from metatensor.torch.atomistic.ase_calculator import _ase_to_torch_data
from metatensor.torch.atomistic import System
import ase
from typing import Optional
import numpy as np


class Langevin(VelocityVerlet):
    """Langevin thermostat wrapped around :class:`VelocityVerlet`.

    Each :meth:`step` is a symmetric splitting: an Ornstein-Uhlenbeck
    friction/noise update of the momenta over half a timestep, the parent
    class's ``step()``, then another friction/noise half-step.

    Args:
        atoms: System to propagate.
        timestep: Integration timestep (ASE time units).
        temperature_K: Target temperature in kelvin; must be >= 0.
        model: Model (or list of models), forwarded to ``VelocityVerlet``.
        time_constant: Thermostat relaxation time (ASE time units). The
            friction coefficient is ``1 / time_constant``. Must be > 0;
            ``float("inf")`` gives zero friction (momenta left untouched).
        device: Forwarded to ``VelocityVerlet``.
        rescale_energy: Forwarded to ``VelocityVerlet``.
        rng: Optional random source exposing ``standard_normal(size)``
            (e.g. ``np.random.default_rng(seed)``). Defaults to NumPy's
            global generator, which draws the same stream as before.
        **kwargs: Forwarded to ``VelocityVerlet``.

    Raises:
        ValueError: If ``temperature_K`` is negative/NaN or
            ``time_constant`` is not strictly positive.
    """

    def __init__(
        self,
        atoms: ase.Atoms,
        timestep: float,
        temperature_K: float,
        model: MetatensorAtomisticModel | List[MetatensorAtomisticModel],
        time_constant: float = 100.0 * ase.units.fs,
        device: str | torch.device = "auto",
        rescale_energy: bool = True,
        *,
        rng: np.random.Generator | None = None,
        **kwargs
    ):
        # Validate before building the (possibly expensive) parent state.
        # Written as `not (x >= 0)` so NaN is rejected too.
        if not temperature_K >= 0:
            raise ValueError(
                f"temperature_K must be non-negative, got {temperature_K!r}"
            )
        if not time_constant > 0:
            raise ValueError(
                f"time_constant must be strictly positive, got {time_constant!r}"
            )

        super().__init__(atoms, timestep, model, device, rescale_energy, **kwargs)

        self.temperature_K = temperature_K
        self.friction = 1.0 / time_constant
        # np.random.standard_normal(shape) draws the same global stream as the
        # previous np.random.randn(*shape), so seeded legacy runs are unchanged.
        self.rng = np.random if rng is None else rng

    def step(self):
        """Advance one timestep: O(dt/2) -> parent ``step()`` -> O(dt/2).

        Returns whatever ``VelocityVerlet.step()`` returns.
        """
        self.apply_langevin_half_step()
        result = super().step()
        self.apply_langevin_half_step()
        return result

    def apply_langevin_half_step(self):
        """Apply the friction/noise momentum update over half a timestep.

        p <- c1 * p + sqrt((1 - c1**2) * m * kB * T) * xi, where
        c1 = exp(-friction * dt / 2) and xi is standard normal per component.
        """
        momenta = self.atoms.get_momenta()
        gamma_dt = self.friction * self.dt
        decay = np.exp(-0.5 * gamma_dt)
        # 1 - c1**2 == 1 - exp(-gamma*dt); expm1 keeps precision when gamma*dt is small.
        variance_per_mass = -np.expm1(-gamma_dt) * ase.units.kB * self.temperature_K
        # Per-atom masses as a column so they broadcast across momentum components.
        masses = self.atoms.get_masses()[:, None]
        noise = self.rng.standard_normal(momenta.shape)
        new_momenta = decay * momenta + np.sqrt(variance_per_mass * masses) * noise
        self.atoms.set_momenta(new_momenta)

