import torch
import numpy as np
from openmm import *
from openmm.app import *
from openmm.unit import *


class TorchEnergyFunction:
    """Differentiable PyTorch re-implementation of an OpenMM potential energy.

    The constructor builds an OpenMM ``System`` from ``pdbfile`` and the
    force-field XML ``forcefield`` (no constraints). It copies the parameters
    of its HarmonicBondForce, HarmonicAngleForce, PeriodicTorsionForce and
    NonbondedForce into torch tensors, and stores one energy closure per term
    in ``self.forces``. Other forces in the system are ignored.

    Parameters are stripped of units with ``value_in_unit(q.unit)``, so they
    stay in whatever units OpenMM reports them in. This is presumably nm, rad,
    kJ/mol and elementary charges; the Coulomb prefactor below is expressed in
    those units. Positions passed to the closures are expected as an
    ``(num_particles, 3)`` tensor in the same length unit, with the default
    torch float dtype (the parameter tensors are created with
    ``torch.Tensor``).

    The nonbonded term sums over all particle pairs, with no cutoff and no
    periodic images.

    Attributes:
        forces: dict mapping ``"HB"``, ``"HA"``, ``"PT"`` and ``"NB"`` to
            callables ``f(pos)`` that return a scalar energy tensor. The
            first three also accept ``interim=True`` and then return
            ``(energy, values)``. Here ``values`` is an ``(n_terms, 1)``
            tensor of bond lengths, bond angles or dihedral angles.
    """

    # 1/(4*pi*eps0) in kJ*nm/(mol*e^2), the value of OpenMM's ONE_4PI_EPS0.
    # Note: 1745.81796678/(4*pi) ~= 138.9278 is ~5.5e-5 too small.
    _ONE_4PI_EPS0 = 138.935456
    # Floor for vector norms, and margin that keeps arccos arguments strictly
    # inside (-1, 1), where arccos and its gradient are finite.
    _EPS = 1e-5

    def __init__(self, pdbfile, forcefield):
        """Build the OpenMM system and extract the per-term parameters.

        Args:
            pdbfile: path passed to ``PDBFile``.
            forcefield: force-field file name/path passed to ``ForceField``.

        Raises:
            ValueError: if the system lacks one of the four supported forces.
        """
        pdb = PDBFile(pdbfile)
        system = ForceField(forcefield).createSystem(pdb.topology, constraints=None)
        found = self._find_forces(system)
        # Insertion order fixes the summation order in total_energy.
        self.forces = {
            "HB": self._bond_term(found["HB"]),
            "HA": self._angle_term(found["HA"]),
            "PT": self._torsion_term(found["PT"]),
            "NB": self._nonbonded_term(found["NB"], system.getNumParticles()),
        }

    def total_energy(self, pos):
        """Return the sum of all energy terms evaluated at positions ``pos``."""
        return sum(force(pos) for force in self.forces.values())

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _find_forces(system):
        """Map "HB"/"HA"/"PT"/"NB" to the matching force objects of ``system``."""
        kinds = {
            "HB": HarmonicBondForce,
            "HA": HarmonicAngleForce,
            "PT": PeriodicTorsionForce,
            "NB": NonbondedForce,
        }
        found = {}
        for force in system.getForces():
            for key, force_type in kinds.items():
                if isinstance(force, force_type):
                    found[key] = force  # if a type repeats, the last one wins
                    break
        missing = [kinds[key].__name__ for key in kinds if key not in found]
        if missing:
            raise ValueError(
                "TorchEnergyFunction requires these forces, missing from the "
                f"system: {', '.join(missing)}"
            )
        return found

    @staticmethod
    def _value(quantity):
        """Strip the unit from an OpenMM Quantity, keeping its own unit."""
        return quantity.value_in_unit(quantity.unit)

    @staticmethod
    def _column(values):
        """Return ``values`` as an ``(n, 1)`` tensor (works for n == 0)."""
        return torch.Tensor(np.array(values, dtype=float).reshape(-1, 1))

    @staticmethod
    def _index(params, k):
        """Return the k-th field of every parameter tuple as a long index tensor."""
        return torch.tensor([p[k] for p in params], dtype=torch.long)

    @classmethod
    def _normalize(cls, x):
        """Normalise the rows of ``x``, flooring each norm at ``_EPS`` to avoid 0/0."""
        lengths = torch.linalg.norm(x, dim=1, keepdim=True)
        return x / torch.clamp(lengths, min=cls._EPS)

    @classmethod
    def _bond_term(cls, hbf):
        """Build ``calc_HB``: the sum over bonds of 0.5*k*(r - r0)**2."""
        params = [hbf.getBondParameters(i) for i in range(hbf.getNumBonds())]
        a, b = cls._index(params, 0), cls._index(params, 1)
        r0 = cls._column([cls._value(p[2]) for p in params])
        k = cls._column([cls._value(p[3]) for p in params])

        def calc_HB(pos, interim=False):
            bond_dists = torch.sqrt(torch.sum((pos[a] - pos[b]) ** 2, dim=1, keepdim=True))
            energy = 0.5 * k.T @ ((bond_dists - r0) ** 2)
            return (energy[0, 0], bond_dists) if interim else energy[0, 0]

        return calc_HB

    @classmethod
    def _angle_term(cls, haf):
        """Build ``calc_HA``: the sum over angles of 0.5*k*(theta - theta0)**2."""
        params = [haf.getAngleParameters(i) for i in range(haf.getNumAngles())]
        a, center, c = (cls._index(params, j) for j in range(3))
        theta0 = cls._column([cls._value(p[3]) for p in params])
        k = cls._column([cls._value(p[4]) for p in params])

        def calc_HA(pos, interim=False):
            u1 = cls._normalize(pos[a] - pos[center])
            u2 = cls._normalize(pos[c] - pos[center])
            # Round-off can push |u1.u2| slightly above 1 -> arccos NaN; clamp.
            cos_theta = torch.sum(u1 * u2, dim=1).clamp(-1.0 + cls._EPS, 1.0 - cls._EPS)
            bond_theta = torch.arccos(cos_theta).reshape(-1, 1)
            energy = 0.5 * k.T @ ((bond_theta - theta0) ** 2)
            return (energy[0, 0], bond_theta) if interim else energy[0, 0]

        return calc_HA

    @classmethod
    def _torsion_term(cls, ptf):
        """Build ``calc_PT``: the sum over torsions of k*(1 + cos(n*phi - phase))."""
        params = [ptf.getTorsionParameters(i) for i in range(ptf.getNumTorsions())]
        i0, i1, i2, i3 = (cls._index(params, j) for j in range(4))
        periodicity = cls._column([p[4] for p in params])
        phase = cls._column([cls._value(p[5]) for p in params])
        k = cls._column([cls._value(p[6]) for p in params])

        def calc_PT(pos, interim=False):
            b1 = pos[i1] - pos[i0]
            b2 = pos[i2] - pos[i1]
            b3 = pos[i3] - pos[i2]
            # Explicit dim: torch.cross's default picks the first size-3 axis,
            # which is the wrong one when there are exactly 3 torsions.
            n1 = torch.cross(b1, b2, dim=1)
            n2 = torch.cross(b2, b3, dim=1)
            # phi = atan2(|b2| * b1.(b2 x b3), (b1 x b2).(b2 x b3)).
            # The sign agrees with sign((b1 x b2).b3). Unlike a clamped arccos,
            # this is exact near 0 and pi and returns pi (not 0) for a planar
            # trans dihedral.
            y = torch.linalg.norm(b2, dim=1) * torch.sum(b1 * n2, dim=1)
            x = torch.sum(n1 * n2, dim=1)
            bond_theta = torch.atan2(y, x).reshape(-1, 1)
            energy = k.T @ (1 + torch.cos(periodicity * bond_theta - phase))
            return (energy[0, 0], bond_theta) if interim else energy[0, 0]

        return calc_PT

    @classmethod
    def _nonbonded_term(cls, nbf, num_particles):
        """Build ``calc_NB``: all-pairs Coulomb + Lennard-Jones energy.

        Pair parameters use the arithmetic mean of sigmas and the geometric
        mean of epsilons. The force's exception list then overrides the
        charge product, sigma and epsilon of the listed pairs.
        """
        params = [nbf.getParticleParameters(i) for i in range(nbf.getNumParticles())]
        charges = np.array([cls._value(p[0]) for p in params], dtype=float).reshape(-1, 1)
        sigmas = np.array([cls._value(p[1]) for p in params], dtype=float).reshape(-1, 1)
        epsilons = np.array([cls._value(p[2]) for p in params], dtype=float).reshape(-1, 1)

        q_prod = charges @ charges.T
        sig_mix = 0.5 * (sigmas + sigmas.T)
        eps_mix = np.sqrt(epsilons @ epsilons.T)

        for n in range(nbf.getNumExceptions()):
            a, b, qq, sig, eps = nbf.getExceptionParameters(n)
            q_prod[a, b] = q_prod[b, a] = cls._value(qq)
            sig_mix[a, b] = sig_mix[b, a] = cls._value(sig)
            eps_mix[a, b] = eps_mix[b, a] = cls._value(eps)

        # No self-interaction.
        for matrix in (q_prod, sig_mix, eps_mix):
            np.fill_diagonal(matrix, 0.0)

        q_prod_t = torch.Tensor(q_prod)
        sig_mix_t = torch.Tensor(sig_mix)
        eps_mix_t = torch.Tensor(eps_mix)
        # Adding 1 to every coordinate of the diagonal offsets makes R_ii = sqrt(3).
        # The zeroed diagonal parameters then give 0 there, not 0/0.
        # Built once here, not on every call.
        diag_offset = torch.eye(num_particles)[:, :, None]
        coulomb = cls._ONE_4PI_EPS0

        def calc_NB(pos):
            deltas = pos[:, None, :] - pos[None, :, :] + diag_offset
            R = torch.sqrt(torch.sum(deltas ** 2, dim=2))
            s_over_r = sig_mix_t / R
            lj = s_over_r ** 12 - s_over_r ** 6
            energy = coulomb * torch.sum(q_prod_t / R) + 4 * torch.sum(eps_mix_t * lj)
            # The symmetric matrices count every pair twice.
            return energy / 2

        return calc_NB

if __name__ == "__main__":
    # Compare the torch re-implementation with OpenMM's own energy and forces
    # at an energy-minimised configuration.
    pdbname = 'alanine-dipeptide.pdb'
    forcefieldname = "amber14-all.xml"

    torch_energy = TorchEnergyFunction(pdbname, forcefieldname)

    pdb = PDBFile(pdbname)
    forcefield = ForceField(forcefieldname)
    system = forcefield.createSystem(pdb.topology, constraints=None)
    integrator = LangevinIntegrator(300 * kelvin, 1 / picosecond, 2 * femtoseconds)
    simulation = Simulation(pdb.topology, system, integrator)
    simulation.context.setPositions(pdb.positions)
    simulation.minimizeEnergy()

    state = simulation.context.getState(getPositions=True, getForces=True, getEnergy=True)
    position = state.getPositions(asNumpy=True)
    forces_true = state.getForces(asNumpy=True)
    energy_true = state.getPotentialEnergy()
    # Strip units, keeping the units OpenMM reports them in.
    position = position.value_in_unit(position.unit)
    forces_true = forces_true.value_in_unit(forces_true.unit)
    energy_true = energy_true.value_in_unit(energy_true.unit)

    pos = torch.tensor(position, dtype=torch.float32, requires_grad=True)
    energy_torch = torch_energy.total_energy(pos)
    energy_torch.backward()
    # Forces are the negative gradient of the potential energy.
    forces_torch = -pos.grad.numpy()

    print(energy_true, energy_torch.item())
    print("max |force difference|:", np.abs(forces_true - forces_torch).max())