import torch

from .Abstract import *
from net.force_model import ForceModel

# Path of the pretrained ForceModel state dict loaded by LangevinDerivator.__init__.
# It is a relative path, so it is resolved against the current working directory.
# Commented-out alternative checkpoint, kept for reference:
# MODEL_PATH = 'train/force/wasted/force-qm7.pkl'
MODEL_PATH = 'train/force/force-qm9.pkl'


class LangevinDerivator(Derivator):
    """Derivator that produces position/momentum time derivatives from a pretrained force model.

    Given masses ``m``, momenta ``p`` and positions ``q``, :meth:`forward` returns
    ``dp/dt = F`` and ``dq/dt = p / m``. ``F`` is the first output of a
    :class:`ForceModel` whose weights are loaded from ``MODEL_PATH`` and which
    is put in eval mode.
    """

    def __init__(self, *args, **kwargs):
        """Build the derivator and load the pretrained force model.

        All arguments are forwarded to ``Derivator.__init__``. That call is
        expected to set ``self.use_cuda``, which is read below.

        Raises:
            AssertionError: if the checkpoint at ``MODEL_PATH`` is missing
                (``FileNotFoundError``) or truncated/empty (``EOFError``).
                The original error is chained as ``__cause__``.
        """
        super().__init__(*args, **kwargs)
        # Held in a plain list instead of being assigned as an attribute directly.
        # Presumably the intent is that, if Derivator is an nn.Module, the
        # pretrained force model is NOT registered as a submodule, so its
        # parameters stay out of this module's parameters()/state_dict().
        # Confirm before changing.
        self.forces = [ForceModel(use_cuda=self.use_cuda)]
        try:
            # Load onto the CPU first; the model is moved to the GPU below if
            # requested. Note: torch.load unpickles the file, so MODEL_PATH
            # must only point at trusted checkpoints.
            state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
        except (FileNotFoundError, EOFError) as err:
            # A tuple is required here: `except A or B` evaluates to `except A`.
            # Raise explicitly rather than `assert False`, so the check still
            # fires under `python -O`. The exception type is kept for callers.
            raise AssertionError(
                f"Force pretrain model dict {MODEL_PATH} can't be loaded, please run force.slurm first."
            ) from err
        self.forces[0].load_state_dict(state_dict)
        self.forces[0].eval()
        if self.use_cuda:
            self.forces[0].cuda()

    def forward(self, v: torch.Tensor, e: torch.Tensor, m: torch.Tensor, p: torch.Tensor, q: torch.Tensor,
                mask_matrices: MaskMatrices, return_list: List[str], **kwargs
                ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """Return the time derivatives ``(dp, dq, extras)`` of momenta and positions.

        Args:
            v: Unused here. Presumably accepted to match the shared Derivator
                interface.
            e: Unused here, as with ``v``.
            m: Masses. Must broadcast against ``p``; no shape check is done here.
            p: Momenta.
            q: Positions. Passed to the force model as ``pos``.
            mask_matrices: Forwarded unchanged to the force model.
            return_list: Unused here.
            **kwargs: Must contain ``'atom_ftr'`` and ``'bond_ftr'``, which are
                forwarded to the force model. A missing key raises ``KeyError``.

        Returns:
            A tuple of:
            - ``dp``: the first output of the force model (``dp/dt = F``).
            - ``dq``: ``p / m`` (``dq/dt = v = p / m``).
            - An empty dict of extras.
        """
        atom_ftr = kwargs['atom_ftr']
        bond_ftr = kwargs['bond_ftr']
        # dq / dt = v = p / m
        dq = p / m

        # dp / dt = F. The force model returns three values; only the first,
        # the force, is used.
        dp, _, _ = self.forces[0].forward(
            atom_ftr=atom_ftr,
            bond_ftr=bond_ftr,
            pos=q,
            mask_matrices=mask_matrices
        )

        return dp, dq, {}
