'''MISATO, a database for protein-ligand interactions
    Copyright (C) 2023
                        Till Siebenmorgen  (till.siebenmorgen@helmholtz-munich.de)
                        Sabrina Benassou   (s.benassou@fz-juelich.de)
                        Filipe Menezes     (filipe.menezes@helmholtz-munich.de)
                        Erinç Merdivan     (erinc.merdivan@helmholtz-munich.de)

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 2.1 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public
    License along with this library; if not, write to the Free Software
    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA'''

import numpy as np
import scipy.spatial as ss
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_undirected
from torch_sparse import coalesce

# Index -> element symbol. Used as the one-hot vocabulary by ``prot_df_to_graph``
# and ``prot_df_to_point_cloud``; the final entry ('UNK') is the catch-all bin.
atom_mapping = dict(enumerate(
    ('H', 'C', 'N', 'O', 'F', 'P', 'S', 'CL', 'BR', 'I', 'UNK')
))

# Not referenced in this part of the file. The non-zero keys look like atomic
# numbers (6, 7, 8, 9) mapped to a compact index -- TODO confirm with callers.
md_atom_mapping = {key: idx for idx, key in enumerate((0, 6, 7, 8, 9))}

# Index -> three-letter residue name; the final entry ('UNK') is the catch-all.
# Not referenced in this part of the file.
residue_mapping = dict(enumerate((
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'CYX', 'GLN', 'GLU', 'GLY', 'HIE',
    'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP',
    'TYR', 'VAL', 'UNK',
)))

# Ligand atom type -> one-hot slot; default vocabulary of ``mol_df_to_graph_for_qm``.
# Keys appear to be atomic numbers (8 = O, 16 = S, 6 = C, ...) -- TODO confirm.
# The slot of each key is simply its position in this tuple.
ligand_atoms_mapping = {
    key: idx
    for idx, key in enumerate((
        8, 16, 6, 7, 1, 15, 17, 9, 53, 35, 5, 33, 26, 14,
        34, 44, 12, 23, 77, 27, 52, 30, 4, 45,
    ))
}


def prot_df_to_graph(item, df, edge_dist_cutoff, feat_col='element'):
    r"""
    Converts protein in dataframe representation to a graph compatible with Pytorch-Geometric, where each node is an atom.

    Node features are one-hot encoded against the module-level ``atom_mapping``. Each value of ``feat_col`` is shifted
    by one (``e - 1``) before the lookup; values that are not keys of ``atom_mapping`` land in the last bin.

    :param item: Record the structure belongs to. Only ``item['id']`` is read, for the diagnostic message printed when
        the structure cannot be processed.
    :type item: Mapping
    :param df: Protein structure in dataframe format. Must provide ``x``, ``y``, ``z`` and ``feat_col`` columns.
    :type df: pandas.DataFrame
    :param edge_dist_cutoff: Maximum distance between two atoms for an edge to be created, in the unit of the
        coordinates (presumably Angstroms).
    :type edge_dist_cutoff: float
    :param feat_col: Column of dataframe holding the integer node type to be one-hot encoded.
    :type feat_col: str, optional

    :return: tuple containing

        - node_feats (torch.FloatTensor): Features for each node, one-hot encoded against ``atom_mapping``.

        - edges (torch.LongTensor): Undirected edges in COO format, shape ``(2, E)``. Shape ``(2, 0)`` when no pair of atoms lies within the cutoff.

        - edge_weights (torch.FloatTensor): Edge weights, defined as a function of distance between atoms given by :math:`w_{i,j} = \frac{1}{d(i,j) + 10^{-5}}`, where :math:`d(i, j)` is the Euclidean distance between node :math:`i` and node :math:`j`.

        - node_pos (torch.FloatTensor): x-y-z coordinates of each node
    :rtype: Tuple
    :raises Exception: Any error raised while reading the coordinates or building the edges is re-raised after the
        offending id has been printed.
    """

    allowable_feats = atom_mapping

    try:
        node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
        kd_tree = ss.KDTree(node_pos)
        edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
        if edge_tuples:
            edges = to_undirected(torch.LongTensor(edge_tuples).t().contiguous())
        else:
            # An empty pair list would become a 1-D tensor that ``to_undirected`` cannot index;
            # build a well-formed empty COO index instead.
            edges = torch.empty((2, 0), dtype=torch.long)
    except Exception:
        # Report which structure failed, then surface the real error instead of carrying on
        # with unbound locals and failing later with an unrelated UnboundLocalError.
        print(f"Problem with PDB Id is {item['id']}")
        raise

    node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices(e - 1, allowable_feats) for e in df[feat_col]])

    # Inverse-distance weights for all edges at once; the epsilon guards against coincident atoms.
    src, dst = edges
    distances = torch.linalg.norm(node_pos[src] - node_pos[dst], dim=1)
    edge_weights = (1.0 / (distances + 1e-5)).view(-1)

    return node_feats, edges, edge_weights, node_pos

def prot_df_to_point_cloud(item, df, feat_col='element'):
    r"""
    Converts protein in dataframe representation to a point cloud, where each point is an atom.

    Node features are one-hot encoded against the module-level ``atom_mapping``. Each value of ``feat_col`` is shifted
    by one (``e - 1``) before the lookup; values that are not keys of ``atom_mapping`` land in the last bin.

    :param item: Record the structure belongs to. Only ``item['id']`` is read, for the diagnostic message printed when
        the coordinates cannot be read.
    :type item: Mapping
    :param df: Protein structure in dataframe format. Must provide ``x``, ``y``, ``z`` and ``feat_col`` columns.
    :type df: pandas.DataFrame
    :param feat_col: Column of dataframe holding the integer node type to be one-hot encoded.
    :type feat_col: str, optional

    :return: tuple containing

        - node_feats (torch.FloatTensor): Features for each node, one-hot encoded against ``atom_mapping``.

        - node_pos (torch.FloatTensor): x-y-z coordinates of each node
    :rtype: Tuple
    :raises Exception: Any error raised while reading the coordinates is re-raised after the offending id has been
        printed.
    """

    allowable_feats = atom_mapping

    try:
        node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
    except Exception:
        # Report which structure failed, then surface the real error; continuing would only
        # end in an UnboundLocalError on ``node_pos`` at the return statement.
        print(f"Problem with PDB Id is {item['id']}")
        raise

    node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices(e - 1, allowable_feats) for e in df[feat_col]])
    return node_feats, node_pos


def mol_df_to_graph_for_qm(df, bonds=None, allowable_atoms=None, edge_dist_cutoff=4.5, onehot_edges=True):
    """
    Converts molecule in dataframe to a graph compatible with Pytorch-Geometric

    Edges come from ``bonds`` when it is given; otherwise every pair of atoms closer than ``edge_dist_cutoff`` is
    connected and weighted by inverse distance.

    :param df: Molecule structure in dataframe format. Must provide ``x``, ``y``, ``z`` and ``element`` columns.
    :type df: pandas.DataFrame
    :param bonds: Bond table, one row per bond, convertible with ``torch.FloatTensor``. The first two columns are
        the indices of the bonded atoms and the last column is the bond type.
    :type bonds: array-like, optional
    :param allowable_atoms: Mapping from ``element`` value to one-hot slot. Defaults to ``ligand_atoms_mapping``.
    :type allowable_atoms: dict, optional
    :param edge_dist_cutoff: Maximum distance between two atoms for an edge to be created. Only used when ``bonds``
        is None.
    :type edge_dist_cutoff: float, optional
    :param onehot_edges: If True, bond types are one-hot encoded (Single = 1.0, Double = 2.0, Triple = 3.0,
        Aromatic = 1.5) and the edge list is passed through ``coalesce``. If False, the raw bond type values are
        returned and the edge list is left as built. Only used when ``bonds`` is given.
    :type onehot_edges: bool, optional
    :return: Tuple containing
        - node_feats (torch.FloatTensor): Features for each node, one-hot encoded by atom type in ``allowable_atoms``, with one extra trailing slot for atom types not in the mapping.
        - edge_index (torch.LongTensor): Edges in COO format, both directions present. Shape ``(2, 0)`` in the distance mode when no pair of atoms lies within the cutoff.
        - edge_feats (torch.FloatTensor): With bonds, either a 4-wide one-hot bond type per edge or the raw bond type value per edge (1-D). Without bonds, ``1 / (distance + 1e-5)`` with shape ``(E, 1)``.
        - node_pos (torch.FloatTensor): x-y-z coordinates of each node.
    :raises KeyError: If ``onehot_edges`` is True and a bond type is not one of 1.0, 2.0, 3.0, 1.5.
    """
    if allowable_atoms is None:
        allowable_atoms = ligand_atoms_mapping
    node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())

    if bonds is not None:
        N = df.shape[0]
        bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}
        bond_data = torch.FloatTensor(bonds)
        endpoints = bond_data[:, :2]
        bond_types = bond_data[:, -1]
        # Every bond is listed in both directions: (i, j) rows followed by the flipped (j, i) rows.
        edge_index = torch.cat((endpoints, torch.flip(endpoints, dims=(1,))), dim=0).t().long().contiguous()

        if onehot_edges:
            # Same class index for both directions of a bond, hence the list is simply repeated.
            bond_idx = [bond_mapping[bond_type] for bond_type in bond_types.tolist()] * 2
            # Explicit dtype: an empty list would otherwise give a float tensor, which one_hot rejects.
            edge_attr = F.one_hot(torch.tensor(bond_idx, dtype=torch.long), num_classes=4).to(torch.float)
            edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)

        else:
            edge_attr = torch.cat((bond_types, bond_types), dim=0)
    else:
        kd_tree = ss.KDTree(node_pos)
        edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
        if edge_tuples:
            edge_index = to_undirected(torch.LongTensor(edge_tuples).t().contiguous())
        else:
            # An empty pair list would become a 1-D tensor that ``to_undirected`` cannot index;
            # build a well-formed empty COO index instead.
            edge_index = torch.empty((2, 0), dtype=torch.long)
        # Inverse-distance weights for all edges at once; the epsilon guards against coincident atoms.
        src, dst = edge_index
        distances = torch.linalg.norm(node_pos[src] - node_pos[dst], dim=1)
        edge_attr = (1.0 / (distances + 1e-5)).unsqueeze(1)

    node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices_qm(e, allowable_atoms) for e in df['element']])

    return node_feats, edge_index, edge_attr, node_pos


def one_of_k_encoding_unk_indices(x, allowable_set):
    """Return a 1-hot list of length ``len(allowable_set)`` for the index ``x``.

    When ``x`` is contained in ``allowable_set`` the entry at position ``x`` is set;
    otherwise the last entry is set, which serves as the "unknown" bin.
    """
    hot_position = x if x in allowable_set else -1
    encoding = [0 for _ in range(len(allowable_set))]
    encoding[hot_position] = 1
    return encoding


def one_of_k_encoding_unk_indices_qm(x, allowable_set):
    """Return a 1-hot list of length ``len(allowable_set) + 1`` for the value ``x``.

    ``allowable_set`` maps each known value to its slot. Values that are not
    contained in it set the extra trailing entry, which serves as the "unknown" bin.
    """
    width = len(allowable_set) + 1
    hot_position = allowable_set[x] if x in allowable_set else -1
    encoding = [0 for _ in range(width)]
    encoding[hot_position] = 1
    return encoding