from pathlib import Path
import pandas as pd
import h5py
import numpy as np
import torch
from torch_geometric.data import Data
from torch.utils.data import Dataset

from dataset.mol_protein_graph import atom_mapping, one_of_k_encoding_unk_indices, mol_df_to_graph_for_qm, prot_df_to_point_cloud


class MolProtDataset(Dataset):
    """
    Dataset over frame pairs of the MD trajectories stored in an HDF5 file.

    Every id listed in ``idx_file`` contributes ``total_frame`` samples. Sample
    ``index`` maps to id ``index // total_frame`` and frame
    ``index % total_frame``; the coordinates of that frame and of the next one
    are concatenated along axis 1 and returned as ``pos`` of a PyTorch
    Geometric ``Data`` object.
    """

    def __init__(self, md_data_file, idx_file, transform=None, post_transform=None, total_frame=99):
        """

        Args:
            md_data_file (str): H5 file path
            idx_file (str): path of txt file which contains pdb ids for a specific split such as train, val or test
                (one id per line; blank lines are skipped).
            transform (callable, optional): applied to the ``Data`` object built in ``__getitem__``. Defaults to None.
            post_transform (callable, optional): applied after ``transform`` (e.g. data augmentation). Defaults to None.
            total_frame (int, optional): number of samples generated per id. Frame ``f`` is paired with frame
                ``f + 1``, so each trajectory must hold at least ``total_frame + 1`` frames. Defaults to 99.
        """

        self.md_data_file = Path(md_data_file).absolute()

        with open(idx_file, 'r') as f:
            # A blank line can never name an HDF5 group, so drop it instead of
            # letting it become a sample that fails on access.
            self.ids = [line.strip() for line in f.read().splitlines() if line.strip()]

        # The handle stays open for the lifetime of the dataset; release it with close().
        self.f = h5py.File(self.md_data_file, 'r')

        self._transform = transform

        self._post_transform = post_transform
        self.total_frame = total_frame

    def __len__(self) -> int:
        """Return the number of samples: ``len(ids) * total_frame``."""
        return len(self.ids) * self.total_frame

    def __getitem__(self, index: int):
        """
        Build the sample for flat ``index``.

        Args:
            index (int): flat sample index in ``[0, len(self))``.

        Raises:
            IndexError: if ``index`` does not map to one of the listed ids
                (this includes negative indices).

        Returns:
            The ``Data`` object, after ``transform`` and ``post_transform`` if they are set.
        """
        p_index, f_index = divmod(index, self.total_frame)
        if not 0 <= p_index < len(self.ids):
            raise IndexError(p_index)

        protein_id = self.ids[p_index]
        pitem = self.f[protein_id]
        # Last entry of the per-molecule start indices.
        # NOTE(review): presumably the index where the ligand atoms begin - confirm against the data layout.
        cutoff = pitem["molecules_begin_atom_index"][:][-1]

        # Slice the HDF5 dataset directly so only the two frames that are needed
        # are read from disk, instead of materialising the whole trajectory.
        frame_pair = pitem["trajectory_coordinates"][f_index:f_index + 2]
        pos = np.concatenate([frame_pair[0], frame_pair[1]], 1)

        node_feats = self._node_features(pitem["atoms_element"][:])
        item = Data(node_feats, pos=torch.from_numpy(pos), ids=protein_id, frame=f_index, pro_mol_cutoff=cutoff)

        if self._transform:
            item = self._transform(item)

        if self._post_transform:
            item = self._post_transform(item)

        return item

    def get_protein(self, protein_id: str):
        """
        Return the full trajectory and the node features of one entry.

        Args:
            protein_id (str): key of the entry in the HDF5 file.

        Returns:
            tuple: ``(pos, node_feats)`` where ``pos`` is a float tensor built from the whole
            ``trajectory_coordinates`` dataset and ``node_feats`` is the float tensor of per-atom encodings.
        """
        pitem = self.f[protein_id]
        pos = torch.FloatTensor(pitem["trajectory_coordinates"][:])
        node_feats = self._node_features(pitem["atoms_element"][:])
        return pos, node_feats

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.f.close()

    @staticmethod
    def _node_features(elements):
        """Encode each entry of ``elements`` with ``one_of_k_encoding_unk_indices`` into one float tensor."""
        # NOTE(review): ``e - 1`` presumably turns 1-based element numbers into 0-based indices - confirm.
        return torch.FloatTensor([one_of_k_encoding_unk_indices(e - 1, atom_mapping) for e in elements])


class MDTransform:
    """
    Callable that hands a dataset item to ``prot_graph_transform``.
    """

    def __init__(self, edge_dist_cutoff=4.5):
        """

        Args:
            edge_dist_cutoff (float, optional): distance cutoff for edges. It is stored on the instance;
                ``__call__`` does not read it. Defaults to 4.5.
        """
        self.edge_dist_cutoff = edge_dist_cutoff

    def __call__(self, item):
        # Only the 'atoms' entry of the item is converted.
        return prot_graph_transform(item, atom_keys=['atoms'])

def prot_graph_transform(item, atom_keys):
    """Replace the entries named in ``atom_keys`` with PyTorch Geometric ``Data`` objects.

    For every key, ``prot_df_to_point_cloud(item, item[key])`` supplies node features and positions;
    these are wrapped in a ``Data`` object together with ``item["id"]`` and ``item['frame']`` and
    written back to ``item[key]`` (the item is modified in place).

    :param item: Dataset item to transform; must provide ``"id"`` and ``'frame'`` when ``atom_keys`` is non-empty
    :type item: dict
    :param atom_keys: keys of ``item`` to convert
    :type atom_keys: list
    :return: the value stored under ``item['atoms']`` after the conversion, whatever ``atom_keys`` contains
    :rtype: torch_geometric.data.Data when ``'atoms'`` is one of ``atom_keys``
    """

    for name in atom_keys:
        features, coords = prot_df_to_point_cloud(item, item[name])
        graph = Data(features, pos=coords, ids=item["id"], frame=item['frame'])
        item[name] = graph

    return item['atoms']


def mol_graph_transform_for_qm(item, atom_key, label_key, allowable_atoms, use_bonds, onehot_edges, edge_dist_cutoff):
    """Replace ``item[atom_key]`` with a PyTorch Geometric ``Data`` graph built by ``mol_df_to_graph_for_qm``.

    The item is modified in place and also returned.

    :param item: Dataset item to transform
    :type item: dict
    :param atom_key: key of ``item`` holding the molecule structure
    :type atom_key: str
    :param label_key: key of ``item`` whose value becomes ``y`` of the graph
    :type label_key: str
    :param allowable_atoms: forwarded unchanged to ``mol_df_to_graph_for_qm``
    :param use_bonds: when truthy, ``item['bonds']`` is passed on as bond information; otherwise ``None`` is passed
    :type use_bonds: bool
    :param onehot_edges: forwarded unchanged to ``mol_df_to_graph_for_qm``
    :param edge_dist_cutoff: forwarded unchanged to ``mol_df_to_graph_for_qm``
    :return: Transformed Dataset item
    :rtype: dict
    """

    # Bond data is looked up only on request, so items without a 'bonds' key
    # work as long as use_bonds is falsy.
    if use_bonds:
        bond_table = item['bonds']
    else:
        bond_table = None

    graph_parts = mol_df_to_graph_for_qm(
        item[atom_key],
        bonds=bond_table,
        onehot_edges=onehot_edges,
        allowable_atoms=allowable_atoms,
        edge_dist_cutoff=edge_dist_cutoff,
    )
    features, edges, edge_attrs, coords = graph_parts
    item[atom_key] = Data(features, edges, edge_attrs, y=item[label_key], pos=coords)

    return item