import os
import itertools
from utils.energy_function import TorchEnergyFunction
import torch
import numpy as np
from openmm import *
from openmm.app import *
from openmm.unit import *

import matplotlib.pyplot as plt
from matplotlib.pyplot import cm

from abc import ABC, abstractmethod

def spectrum_abs(L):
    """Eigen-decompose a Hermitian matrix, ordering results by |eigenvalue|.

    ``torch.linalg.eigh`` sorts eigenvalues by signed value; the matrix may
    not be positive semi-definite, so the pairs are re-ordered from smallest
    to largest magnitude instead.

    Args:
        L: Hermitian (real symmetric) square tensor.

    Returns:
        (values, vectors): eigenvalues sorted by absolute value, and the
        matching eigenvectors as the *columns* of ``vectors``.
    """
    evals, evecs = torch.linalg.eigh(L)
    order = evals.abs().argsort()
    return evals[order], evecs[:, order]

def Laplacian(A):
    """Return the graph Laplacian ``D - A`` of adjacency matrix ``A``.

    ``D`` is the diagonal matrix of row sums of ``A``.
    """
    degree = A.sum(1)
    return torch.diag(degree) - A

def get_Ls(H):
    """Build and diagonalise the (n^2 x n^2) operator whose eigenvectors are
    candidate (n x n) generator matrices.

    Args:
        H: 4-D array indexed as ``H[:, i, :, j]`` with ``n = H.shape[0]`` and
            ``d = H.shape[1]``. Presumably the (n, d, n, d) Hessian produced
            by ``torch.autograd.functional.hessian`` on (n, d) positions --
            TODO confirm for other callers.

    Returns:
        (eigs, eigv): eigenvalues ordered by magnitude (see ``spectrum_abs``)
        and an (n^2, n, n) tensor where ``eigv[k]`` is the eigenvector of
        ``eigs[k]`` reshaped to an n x n matrix.
    """
    n, d = H.shape[0], H.shape[1]

    M0 = sum([H[:, i, :, i] for i in range(d)])
    M1 = np.einsum('ij,kl-> jilk', M0, M0)
    M2 = sum([np.einsum('ij,kl-> jkli', H[:, i, :, j], H[:, j, :, i]) for i in range(d) for j in range(d)])
    M3 = np.einsum('ij,kl-> ikjl', sum([H[:, i, :, j].T @ H[:, i, :, j] for i in range(d) for j in range(d)]), np.eye(n))
    M = (M1 + M2 + M3).reshape((n**2, n**2))

    eigs, eigv = spectrum_abs(torch.Tensor(M))
    # torch.linalg.eigh (via spectrum_abs) returns eigenvectors as COLUMNS.
    # Reshaping eigv directly would turn its rows (not eigenvectors) into the
    # returned matrices; transpose first so eigv[k] pairs with eigs[k].
    return eigs, eigv.T.reshape(eigv.shape[1], n, n)

def get_eigs(H):
    """Diagonalise ``sum_{i,j} H[:, i, :, j] @ H[:, i, :, j]``.

    Args:
        H: 4-D array indexed as ``H[:, i, :, j]``; ``H.shape[1]`` gives the
            range of i and j.

    Returns:
        (eigs, eigv) from ``spectrum_abs``: eigenvalues ordered by magnitude
        and eigenvectors as columns.
    """
    d = H.shape[1]
    A = sum(H[:, a, :, b] @ H[:, a, :, b] for a in range(d) for b in range(d))
    return spectrum_abs(torch.Tensor(A))

def basis_elem(dim, i, j):
    """Return the dim x dim float tensor with +1 at (i, j) and -1 at (j, i).

    For i != j this is the antisymmetric generator E_ij - E_ji. When i == j
    the second assignment overwrites the first, so the only nonzero entry is
    -1 at (i, i).
    """
    elem = torch.zeros(dim, dim)
    elem[i, j] = 1.0
    elem[j, i] = -1.0
    return elem

class DOF(ABC):
    """Base class for discovering collective degrees of freedom of a molecule.

    Subclasses must assign ``self.mutation_type`` *before* calling
    ``DOF.__init__`` (it is used to build ``results_dir``) and must implement
    ``get_basis``.
    """

    def __init__(self, pdbname, forcefieldname):
        self.name_mod = pdbname.split("/")[-1].split(".")[0]
        self.energy = TorchEnergyFunction(pdbname, forcefieldname)
        self.total_energy = self.energy.total_energy
        self.pdb = PDBFile(pdbname)
        self.forcefield = ForceField(forcefieldname)
        self.results_dir = f'./results/min_op/{self.mutation_type}/{self.name_mod}'

    def init_sim(self):
        """Energy-minimise the PDB structure and differentiate the energy there.

        Returns:
            init_position: float32 tensor of the minimised positions, in the
                unit OpenMM reports them in.
            full_Hessian: Hessian of ``self.total_energy`` at those positions
                (shape ``init_position.shape * 2``).
        """
        system = self.forcefield.createSystem(self.pdb.topology, constraints=None)
        self.Num_parts = system.getNumParticles()
        integrator = LangevinIntegrator(300 * kelvin, 1 / picosecond, 2 * femtoseconds)

        simulation = Simulation(self.pdb.topology, system, integrator)
        simulation.context.setPositions(self.pdb.positions)
        simulation.minimizeEnergy()

        state = simulation.context.getState(getPositions=True)
        position = state.getPositions(asNumpy=True)
        position = position.value_in_unit(position.unit)

        init_position = torch.Tensor(position)
        # Separate leaf tensor that tracks gradients (replaces the deprecated
        # torch.autograd.Variable wrapper).
        pos_var = torch.Tensor(position).requires_grad_(True)
        full_Hessian = torch.autograd.functional.hessian(self.total_energy, pos_var)
        del simulation

        return init_position, full_Hessian

    @staticmethod
    def _group_eigenvalues(values, tol=5e-3, max_groups=4):
        """Cluster consecutive near-equal values into groups of indices.

        Index 0 is skipped. From index 1 on, consecutive indices whose value
        is within ``tol`` of the first value of the current run form one run.
        Runs of length > 1 are kept, including the final run. At most
        ``max_groups`` groups are returned, each a list of indices.
        """
        groups = []
        if len(values) < 2:
            return groups
        group = [1]
        anchor = values[1]
        for idx in range(2, len(values)):
            val = values[idx]
            if abs(anchor - val) > tol:
                if len(group) > 1:
                    groups.append(group)
                group = [idx]
                anchor = val
            else:
                group.append(idx)
        # Without this the last run was silently dropped.
        if len(group) > 1:
            groups.append(group)
        return groups[:max_groups]

    @staticmethod
    def _steps_per_axis(divs, dim):
        """Return the largest integer ``s`` with ``s ** dim <= divs``.

        ``int(np.exp(np.log(divs) / dim))`` can land just below an exact
        integer root and truncate one step short. This computes the root
        exactly in integers.

        Raises:
            ValueError: if ``divs`` or ``dim`` is smaller than 1.
        """
        if divs < 1 or dim < 1:
            raise ValueError(f"divs and dim must be >= 1, got divs={divs}, dim={dim}")
        steps = int(round(divs ** (1.0 / dim)))
        while steps ** dim > divs:
            steps -= 1
        return steps

    def discover_DOF(self):
        """Find groups of near-degenerate eigenvectors of a Hessian-derived matrix.

        Side effects:
            - creates ``self.results_dir``;
            - saves a histogram of log-eigenvalues to
              ``{name_mod}_{mutation_type}_spectrum.png`` in the working dir;
            - saves the eigenvector groups to ``results_dir/eigenvectors.npz``;
            - sets ``self.init_position``, ``self.val_groups`` (normalised
              eigenvalues per group) and ``self.eig_groups`` (eigenvector
              column blocks per group).
        """
        os.makedirs(self.results_dir, exist_ok=True)

        self.init_position, full_Hessian = self.init_sim()
        eigs_P, eigv_P = get_eigs(full_Hessian.detach().numpy())

        eigs_P_normalized = (eigs_P / torch.amax(eigs_P)).reshape(-1)

        plt.figure(figsize=(6, 4))
        eigvals = np.array(list(np.log(eigs_P.reshape(-1))[1:]))
        plt.hist(eigvals, bins=60)
        plt.yticks(np.arange(10))
        plt.xlabel('Log of EigenValue')
        plt.ylabel('Frequency of Occurence')
        plt.tight_layout()
        plt.savefig(f"{self.name_mod}_{self.mutation_type}_spectrum.png", dpi=300)
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close()

        groups = self._group_eigenvalues(eigs_P_normalized)
        print(groups)

        data = [eigv_P[:, group[0]:(group[-1] + 1)] for group in groups]
        # Previously re-written once per group with identical contents.
        if data:
            np.savez(f"{self.results_dir}/eigenvectors", *data)

        self.val_groups = [eigs_P_normalized[group[0]:(group[-1] + 1)] for group in groups]
        self.eig_groups = data

    @abstractmethod
    def get_basis(self, *args, **kwargs):
        pass

    def oneD_trajectory(self, eta, steps, Ms, eigs, init_pos, f, dim=1):
        """Trace a ``dim``-dimensional grid of configurations generated by ``Ms``.

        Starting from ``init_pos``, takes ``steps`` forward-Euler steps
        ``pos <- pos + eta * eigs @ M @ eigs.T @ pos`` with ``M = Ms[0]``.
        For ``dim > 1`` each visited point starts a recursive walk over
        ``Ms[1:]``. At the innermost level each configuration is written as a
        PDB model to ``f`` and its energy-term features are recorded.

        Returns:
            numpy array with one leading axis of length ``steps`` per
            recursion level.
        """
        M, newMs = Ms[0], Ms[1:]
        data = []

        def get_direction(pos, eta):
            return pos + eta*eigs@M@eigs.T@pos

        pos = init_pos
        for step in range(steps):
            if dim == 1:
                # x10: presumably nm -> Angstrom for the PDB writer.
                PDBFile.writeModel(self.modeller.topology, 10*pos.numpy(), file=f, modelIndex=step)
                # Evaluate the 'HB' term once and reuse both of its outputs.
                hb_terms = self.energy.forces['HB'](pos, interim=True)
                energy_val = hb_terms[0].numpy().reshape((-1, 1))
                bonds = hb_terms[1].numpy().reshape((1, -1))
                angles = self.energy.forces['HA'](pos, interim=True)[1].numpy().reshape((1, -1))
                torsions = self.energy.forces['PT'](pos, interim=True)[1].numpy().reshape((1, -1))
                data.append(np.concatenate([energy_val, bonds, angles, torsions], axis=1).reshape(-1))
            else:
                data.append(self.oneD_trajectory(eta, steps, newMs, eigs, pos, f, dim=dim-1))
            pos = get_direction(pos, eta)
        return np.stack(data, axis=0)

    def make_trajectories(self, *args, dim=1, divs=1000, **kwargs):
        """Write PDB trajectories and feature arrays for every ``dim``-subset of
        generators returned by ``get_basis`` for each eigenvector group.

        ``divs`` is the approximate total number of frames per trajectory;
        each of the ``dim`` axes gets the largest ``steps`` with
        ``steps ** dim <= divs``, spanning an angle of 2*pi.
        """
        steps = self._steps_per_axis(divs, dim)
        eta = 2*np.pi/steps

        self.modeller = Modeller(self.pdb.topology, self.init_position)

        for idx, (vals, eigs) in enumerate(zip(self.val_groups, self.eig_groups)):
            Basis = self.get_basis(vals, eigs, idx, *args, **kwargs)
            print(len(Basis))

            for basis_idx, Ms in enumerate(itertools.combinations(Basis, dim)):
                with open(f'{self.results_dir}/eigen_{idx+1}_{basis_idx+1}.pdb', 'w') as f:
                    PDBFile.writeHeader(self.modeller.topology, f)
                    data = self.oneD_trajectory(eta, steps, list(Ms), eigs, self.init_position, f, dim=dim)
                    PDBFile.writeFooter(self.modeller.topology, f)
                    np.save(f"{self.results_dir}/traj_{idx+1}_{basis_idx+1}.npy", data)

class Lie(DOF):
    """DOF variant that ranks candidate Lie-algebra generators per eigenvector group."""

    def __init__(self, pdbname, forcefieldname):
        # Must be set before DOF.__init__, which builds results_dir from it.
        self.mutation_type = "Lie"
        super().__init__(pdbname, forcefieldname)

    def get_basis(self, vals, eigs, idx, *args, top_K=2, **kwargs):
        """Return ``top_K`` generator matrices for eigenvector group ``idx``.

        Candidates (all p x p with ``p = eigs.shape[1]``, built with
        ``basis_elem``):
            - group 0: for every (r, c), ``basis_elem(p, r, r)`` on the
              diagonal, else ``(basis_elem(p, r, c) - basis_elem(p, c, r)) / 2``;
            - other groups: ``basis_elem(p, r, c)`` for r < c. If p == 2 the
              single generator ``basis_elem(2, 0, 1)`` is returned directly.

        A response matrix ``R = sum_c V_c @ diag(vals**2) @ V_c.T`` is built,
        where ``V_c`` stacks each candidate applied to column c of
        ``eigs.T @ self.init_position``. The returned generators are the
        candidate combinations weighted by R's eigenvectors at the
        largest-|eigenvalue| end of ``spectrum_abs``'s ordering.
        """
        size = eigs.shape[1]
        coords = eigs.T @ self.init_position
        weight = torch.diag(vals ** 2)
        shift = 1

        if idx != 0:
            if size == 2:
                return [basis_elem(size, 0, 1)]
            generators = torch.stack([basis_elem(size, r, c) for r in range(size) for c in range(r + 1, size)])
        else:
            candidates = []
            for r in range(size):
                for c in range(size):
                    if r == c:
                        candidates.append(basis_elem(size, r, c))
                    else:
                        candidates.append((basis_elem(size, r, c) - basis_elem(size, c, r)) / 2)
            generators = torch.stack(candidates)

        response = torch.zeros(generators.shape[0], generators.shape[0])
        for col in range(coords.shape[1]):
            applied = torch.tensordot(generators, coords[:, col], dims=1)
            response += applied @ weight @ (applied.T)

        _, rotation = spectrum_abs(response)
        print(rotation[:, -top_K:])

        def combine(weights):
            return torch.tensordot(weights, generators, dims=1)

        return [combine(rotation[:, -k]) for k in range(shift, top_K + shift)]

class CannonicalLie(DOF):
    """DOF variant using the canonical antisymmetric generators E_rc - E_cr (r < c)."""

    def __init__(self, pdbname, forcefieldname):
        # Must be set before DOF.__init__, which builds results_dir from it.
        self.mutation_type = "CannonicalLie"
        super().__init__(pdbname, forcefieldname)

    def get_basis(self, vals, eigs, *args, **kwargs):
        """Return ``basis_elem(p, r, c)`` for every r < c, with ``p = eigs.shape[1]``."""
        size = eigs.shape[1]
        generators = []
        for row in range(size):
            for col in range(row + 1, size):
                generators.append(basis_elem(size, row, col))
        return generators
    
class TrueDOF(DOF):
    """DOF variant deriving candidate generators from the full Hessian via ``get_Ls``."""

    def __init__(self, pdbname, forcefieldname):
        # Must be set before DOF.__init__, which builds results_dir from it.
        self.mutation_type = "TrueLie"
        super().__init__(pdbname, forcefieldname)

    def discover_DOF(self, n_basis=40):
        """Compute candidate (n x n) generators and the flattened Hessian.

        Sets ``self.init_position``, ``self.eig_groups`` (a single group with
        the first ``min(n_basis, n*n)`` matrices from ``get_Ls``) and
        ``self.hessian`` (the Hessian reshaped to (n*d, n*d)).
        """
        os.makedirs(self.results_dir, exist_ok=True)

        self.init_position, full_Hessian = self.init_sim()
        n, d = self.init_position.shape
        _, eigv_P = get_Ls(full_Hessian.detach().numpy())
        # get_Ls yields n*n candidates; never ask for more than exist.
        nB = min(n_basis, n * n)
        self.eig_groups = [torch.Tensor(eigv_P[:nB, :].reshape(nB, n, n))]
        self.hessian = full_Hessian.reshape((n*d, n*d))

    def get_basis(self, hessian, Ls, *args, top_K=2, **kwargs):
        """Return ``top_K`` combinations of the matrices ``Ls``.

        Builds ``M = sum_{i,j} V_i @ Lambda[:, i, :, j] @ V_j.T`` with
        ``Lambda = (hessian.T @ hessian)`` reshaped to (n, d, n, d) and
        ``V_i`` stacking each ``L`` applied to coordinate column i. Returns
        the combinations weighted by M's eigenvectors at the
        largest-|eigenvalue| end of ``spectrum_abs``'s ordering.
        """
        n, d = self.init_position.shape
        Lambda = (hessian.T @ hessian).reshape((n, d, n, d))
        basis = Ls
        x0 = torch.Tensor(self.init_position)
        # Loop-invariant: apply every basis matrix to each coordinate column once.
        V = [torch.tensordot(basis, x0[:, k], dims=1) for k in range(x0.shape[1])]
        M = torch.zeros(basis.shape[0], basis.shape[0])
        for i in range(x0.shape[1]):
            for j in range(x0.shape[1]):
                M += V[i] @ Lambda[:, i, :, j] @ (V[j].T)

        _, eigv_R = spectrum_abs(M)

        return [torch.tensordot(eigv_R[:, -i], basis, dims=1) for i in range(1, top_K + 1)]

    def oneD_trajectory(self, eta, steps, Ms, init_pos, f, dim=1):
        """Trace a ``dim``-dimensional grid of configurations generated by ``Ms``.

        Takes ``steps`` forward-Euler steps ``pos <- pos + eta * M @ pos``
        with ``M = Ms[0]``, recursing over ``Ms[1:]`` when ``dim > 1``. At
        the innermost level each configuration is written as a PDB model to
        ``f`` and its energy-term features are recorded.

        Returns:
            numpy array with one leading axis of length ``steps`` per
            recursion level.
        """
        M, newMs = Ms[0], Ms[1:]
        data = []

        def get_direction(pos, eta):
            return pos + eta*M@pos

        pos = torch.Tensor(init_pos)
        for step in range(steps):
            if dim == 1:
                # x10: presumably nm -> Angstrom for the PDB writer.
                PDBFile.writeModel(self.modeller.topology, 10*pos.numpy(), file=f, modelIndex=step)
                # Evaluate the 'HB' term once and reuse both of its outputs.
                hb_terms = self.energy.forces['HB'](pos, interim=True)
                energy_val = hb_terms[0].numpy().reshape((-1, 1))
                bonds = hb_terms[1].numpy().reshape((1, -1))
                angles = self.energy.forces['HA'](pos, interim=True)[1].numpy().reshape((1, -1))
                torsions = self.energy.forces['PT'](pos, interim=True)[1].numpy().reshape((1, -1))
                data.append(np.concatenate([energy_val, bonds, angles, torsions], axis=1).reshape(-1))
            else:
                data.append(self.oneD_trajectory(eta, steps, newMs, pos, f, dim=dim-1))
            pos = get_direction(pos, eta)
        return np.stack(data, axis=0)

    def make_trajectories(self, *args, dim=1, divs=1000, **kwargs):
        """Write PDB trajectories and feature arrays for every ``dim``-subset of
        generators returned by ``get_basis``.

        Each axis gets the largest integer ``steps`` with
        ``steps ** dim <= divs``, spanning an angle of 2*pi.

        Raises:
            ValueError: if ``divs`` or ``dim`` is smaller than 1.
        """
        if divs < 1 or dim < 1:
            raise ValueError(f"divs and dim must be >= 1, got divs={divs}, dim={dim}")
        # Exact integer dim-th root: int(np.exp(np.log(divs)/dim)) can fall
        # just below an exact root and truncate one step short.
        steps = int(round(divs ** (1.0 / dim)))
        while steps ** dim > divs:
            steps -= 1
        eta = 2*np.pi/steps

        self.modeller = Modeller(self.pdb.topology, self.init_position)

        for idx, Ls in enumerate(self.eig_groups):
            Basis = self.get_basis(self.hessian, Ls, idx, *args, **kwargs)
            print(len(Basis))

            for basis_idx, Ms in enumerate(itertools.combinations(Basis, dim)):
                with open(f'{self.results_dir}/eigen_{idx+1}_{basis_idx+1}.pdb', 'w') as f:
                    PDBFile.writeHeader(self.modeller.topology, f)
                    data = self.oneD_trajectory(eta, steps, list(Ms), self.init_position, f, dim=dim)
                    PDBFile.writeFooter(self.modeller.topology, f)
                    np.save(f"{self.results_dir}/traj_{idx+1}_{basis_idx+1}.npy", data)

class OptDOF:
    """Discover degrees of freedom from forces sampled around a minimised structure.

    Args:
        pdbname: path to the PDB file.
        forcefieldnames: sequence of OpenMM force-field files. The first one
            also drives ``TorchEnergyFunction``. More than one file switches
            to the cutoff/HBonds system and appends ``"_water"`` to the name.
        eps: (eps1, eps2) noise scales for the two sampling stages.
        modifiers: suffix appended to the results directory name.
    """

    def __init__(self, pdbname, forcefieldnames, eps=(1.0,0.1), modifiers=''):
        self.name_mod = pdbname.split("/")[-1].split(".")[0]
        self.pdb = PDBFile(pdbname)
        self.forcefield = ForceField(*forcefieldnames)
        # TorchEnergyFunction takes a single force-field file. This previously
        # read the undefined name ``forcefieldname`` and only worked through a
        # global set in the __main__ block.
        self.energy = TorchEnergyFunction(pdbname, forcefieldnames[0])
        self.total_energy = self.energy.total_energy
        self.mutation_type = "Opt"
        self.eps1, self.eps2 = eps[0], eps[1]
        if len(forcefieldnames) > 1:
            self.name_mod += "_water"
            self.system = self.forcefield.createSystem(self.pdb.topology, nonbondedCutoff=3 * nanometer, constraints=HBonds)
        else:
            self.system = self.forcefield.createSystem(self.pdb.topology, constraints=None)
        self.results_dir = f'./results/min_op/{self.mutation_type}/{self.name_mod}{modifiers}'

    def init_sim(self):
        """Energy-minimise the structure.

        Returns:
            (positions, simulation): minimised positions as a unitless numpy
            array, and the live ``Simulation`` for further force queries.
        """
        self.Num_parts = self.system.getNumParticles()
        integrator = LangevinIntegrator(300 * kelvin, 1 / picosecond, 2 * femtoseconds)

        simulation = Simulation(self.pdb.topology, self.system, integrator)
        simulation.context.setPositions(self.pdb.positions)
        simulation.minimizeEnergy()

        state = simulation.context.getState(getPositions=True)
        position = state.getPositions(asNumpy=True)
        position = position.value_in_unit(position.unit)

        return position, simulation

    @staticmethod
    def _steps_per_axis(divs, dim):
        """Return the largest integer ``s`` with ``s ** dim <= divs``.

        ``int(np.exp(np.log(divs) / dim))`` can land just below an exact
        integer root and truncate one step short.

        Raises:
            ValueError: if ``divs`` or ``dim`` is smaller than 1.
        """
        if divs < 1 or dim < 1:
            raise ValueError(f"divs and dim must be >= 1, got divs={divs}, dim={dim}")
        steps = int(round(divs ** (1.0 / dim)))
        while steps ** dim > divs:
            steps -= 1
        return steps

    def _sample_force_products(self, simulation, eps, m):
        """Return an (m, n, n) array of ``noise_k @ force_k.T`` products.

        ``noise_k = eps * N(0, 1) / sqrt(n*d)`` perturbs ``self.init_position``
        and ``force_k`` is the OpenMM force evaluated at the perturbed
        positions. The simulation context is left at the last sample.
        """
        n, d = self.init_position.shape
        noises = eps*np.random.randn(m, n, d)/np.sqrt(n*d)
        samples = self.init_position + noises
        products = np.empty((m, n, n))
        for k in range(m):
            simulation.context.setPositions(samples[k])
            state = simulation.context.getState(getForces=True)
            force = state.getForces(asNumpy=True)
            force = force.value_in_unit(force.unit)
            products[k] = noises[k] @ force.T
        return products

    def discover_DOF(self, n_basis=40):
        """Two-stage sampling to find candidate generators and their response.

        Stage 1 (scale ``eps1``) builds ``M = sum_k vec(P_k) vec(P_k)^T`` and
        keeps its ``min(n_basis, n*n)`` smallest-|eigenvalue| eigenvectors,
        reshaped to (n, n), in ``self.eig_groups``. Stage 2 (scale ``eps2``)
        stores the flattened products as ``self.hessian`` with shape (m, n*n).
        """
        os.makedirs(self.results_dir, exist_ok=True)

        self.init_position, simulation = self.init_sim()
        n, d = self.init_position.shape
        m = 16*(n**2)

        # Rows are vec(P_k); X.T @ X equals the sum of outer products without
        # materialising m separate (n^2 x n^2) matrices.
        X = self._sample_force_products(simulation, self.eps1, m).reshape(m, -1)
        M = torch.Tensor(X.T @ X)
        eigs_P, eigv_P = spectrum_abs(M)

        eigs_P_normalized = (eigs_P/torch.amax(eigs_P)).reshape(-1)
        print(eigs_P_normalized[:2])

        nB = min(n_basis, n * n)
        # Eigenvectors are the COLUMNS of eigv_P; take the first nB of them.
        self.eig_groups = [eigv_P[:, :nB].T.reshape(nB, n, n)]

        X = self._sample_force_products(simulation, self.eps2, m).reshape(m, -1)
        self.hessian = torch.Tensor(X)
        print(self.hessian.shape)
        del simulation

    def get_basis(self, hessian, Ls, *args, top_K=2, **kwargs):
        """Return ``top_K`` combinations of the (nB, n, n) matrices ``Ls``.

        With ``V = hessian @ vec(Ls).T``, the combinations are weighted by the
        eigenvectors of ``V.T @ V`` at the largest-|eigenvalue| end of
        ``spectrum_abs``'s ordering.
        """
        nB, n, _ = Ls.shape
        basis = Ls
        V = hessian @ (basis.reshape(nB, n**2).T)
        M = V.T @ V

        _, eigv_R = spectrum_abs(M)
        return [torch.tensordot(eigv_R[:, -i], basis, dims=1) for i in range(1, top_K + 1)]

    def oneD_trajectory(self, eta, steps, Ms, init_pos, f, dim=1):
        """Trace a ``dim``-dimensional grid of configurations generated by ``Ms``.

        Takes ``steps`` forward-Euler steps ``pos <- pos + eta * M @ pos``
        with ``M = Ms[0]``, recursing over ``Ms[1:]`` when ``dim > 1``. At
        the innermost level each configuration is written as a PDB model to
        ``f`` and its energy-term features are recorded.

        Returns:
            numpy array with one leading axis of length ``steps`` per
            recursion level.
        """
        M, newMs = Ms[0], Ms[1:]
        data = []

        def get_direction(pos, eta):
            return pos + eta*M@pos

        pos = torch.Tensor(init_pos)
        for step in range(steps):
            if dim == 1:
                # x10: presumably nm -> Angstrom for the PDB writer.
                PDBFile.writeModel(self.modeller.topology, 10*pos.numpy(), file=f, modelIndex=step)
                # Evaluate the 'HB' term once and reuse both of its outputs.
                hb_terms = self.energy.forces['HB'](pos, interim=True)
                energy_val = hb_terms[0].numpy().reshape((-1, 1))
                bonds = hb_terms[1].numpy().reshape((1, -1))
                angles = self.energy.forces['HA'](pos, interim=True)[1].numpy().reshape((1, -1))
                torsions = self.energy.forces['PT'](pos, interim=True)[1].numpy().reshape((1, -1))
                data.append(np.concatenate([energy_val, bonds, angles, torsions], axis=1).reshape(-1))
            else:
                data.append(self.oneD_trajectory(eta, steps, newMs, pos, f, dim=dim-1))
            pos = get_direction(pos, eta)
        return np.stack(data, axis=0)

    def make_trajectories(self, *args, dim=1, divs=1000, **kwargs):
        """Write PDB trajectories and feature arrays for every ``dim``-subset of
        generators returned by ``get_basis``. File names carry eps1/eps2.
        """
        steps = self._steps_per_axis(divs, dim)
        eta = 2*np.pi/steps

        self.modeller = Modeller(self.pdb.topology, self.init_position)

        for idx, Ls in enumerate(self.eig_groups):
            Basis = self.get_basis(self.hessian, Ls, idx, *args, **kwargs)
            print(len(Basis))

            for basis_idx, Ms in enumerate(itertools.combinations(Basis, dim)):
                with open(f'{self.results_dir}/eigen_{idx+1}_{basis_idx+1}_{self.eps1}_{self.eps2}.pdb', 'w') as f:
                    PDBFile.writeHeader(self.modeller.topology, f)
                    data = self.oneD_trajectory(eta, steps, list(Ms), self.init_position, f, dim=dim)
                    PDBFile.writeFooter(self.modeller.topology, f)
                    np.save(f"{self.results_dir}/traj_{idx+1}_{basis_idx+1}_{self.eps1}_{self.eps2}.npy", data)


if __name__ == "__main__":

    pdbname = "molecule_files/alanine-dipeptide.pdb"
    forcefieldname = "amber99sbnmr.xml"
    # Force-field combinations: with the amber99_obc file, and without it.
    forcefield_sets = [
        ["amber99sbnmr.xml", "amber99_obc.xml"],
        ["amber99sbnmr.xml"],
    ]
    eps_settings = [(0.1, 0.01), (0.5, 0.01)]

    def run_opt(forcefieldnames, eps, modifiers=''):
        """Run OptDOF discovery and write 2-D trajectories for one configuration."""
        opt = OptDOF(pdbname, forcefieldnames, eps=eps, modifiers=modifiers)
        opt.discover_DOF()
        opt.make_trajectories(dim=2, divs=1000, top_K=2)

    # Three tagged repeats of every configuration...
    for i in range(3):
        for forcefieldnames in forcefield_sets:
            for eps in eps_settings:
                run_opt(forcefieldnames, eps, modifiers=f'_run_{i}')

    # ...followed by one untagged run of each.
    for forcefieldnames in forcefield_sets:
        for eps in eps_settings:
            run_opt(forcefieldnames, eps)

    # Lowercase name so the TrueDOF class is not shadowed by its instance.
    true_dof = TrueDOF(pdbname, forcefieldname)
    true_dof.discover_DOF()
    true_dof.make_trajectories(dim=2, divs=1000, top_K=2)

    LieDOF = Lie(pdbname, forcefieldname)
    LieDOF.discover_DOF()
    LieDOF.make_trajectories(dim=2, divs=1000, top_K=2)

    # Cannon = CannonicalLie(pdbname, forcefieldname)
    # Cannon.discover_DOF()
    # Cannon.make_trajectories(dim=1, divs=1000)