"""
The molecule dataset for finetuning.
This implementation is adapted from
https://github.com/chemprop/chemprop/blob/master/chemprop/data/data.py
"""
import random
import json
import torch
from argparse import Namespace
from typing import Callable, List, Union
from torch_geometric.utils import to_undirected, k_hop_subgraph


import numpy as np
from rdkit import Chem
from torch.utils.data.dataset import Dataset

from grover.data.molfeaturegenerator import get_features_generator
from grover.data.scaler import StandardScaler

# Neighbourhood radius setting for knowledge-graph subgraphs (get_subgraph expands KHOP - 1 hops).
KHOP: int = 2

class MoleculeDatapoint:
    """A MoleculeDatapoint contains a single molecule and its associated features and targets."""

    def __init__(self,
                 line: List[str],
                 args: Namespace = None,
                 features: np.ndarray = None,
                 use_compound_names: bool = False):
        """
        Initializes a MoleculeDatapoint, which contains a single molecule.

        :param line: A list of strings generated by separating a line in a data CSV file by comma.
                     Layout: ``[compound_name,] smiles, target_1, ..., target_n``.
        :param args: Arguments. ``args.features_generator`` (if present) lists feature-generator names;
                     ``args.num_bits`` is read for the 'morgan' and 'morgan_count' generators.
        :param features: A numpy array containing additional features (ex. Morgan fingerprint).
        :param use_compound_names: Whether the data CSV includes the compound name on each line.
        :raises ValueError: If both ``features`` and a features generator are provided.
        """
        self.features_generator = None
        self.args = None
        if args is not None:
            if hasattr(args, "features_generator"):
                self.features_generator = args.features_generator
            self.args = args

        if features is not None and self.features_generator is not None:
            raise ValueError('Currently cannot provide both loaded features and a features generator.')

        self.features = features
        # Declared up front so that readers of these attributes (e.g. a dataset collecting ``kges``)
        # never hit an AttributeError on a datapoint that was never assigned them.
        self.kges = None
        self.subgraph = None

        if use_compound_names:
            self.compound_name = line[0]  # str
            line = line[1:]
        else:
            self.compound_name = None

        self.smiles = line[0]  # str

        # Generate additional features if given a generator
        if self.features_generator is not None:
            self.features = []
            mol = Chem.MolFromSmiles(self.smiles)
            for fg in self.features_generator:
                features_generator = get_features_generator(fg)
                # Invalid SMILES (mol is None) or molecules without heavy atoms contribute no features.
                if mol is not None and mol.GetNumHeavyAtoms() > 0:
                    if fg in ['morgan', 'morgan_count']:
                        self.features.extend(features_generator(mol, num_bits=args.num_bits))
                    else:
                        self.features.extend(features_generator(mol))

            self.features = np.array(self.features)

        # Fix nans in features
        if self.features is not None:
            replace_token = 0
            self.features = np.where(np.isnan(self.features), replace_token, self.features)

        # Create targets; an empty CSV cell becomes None (missing target).
        self.targets = [float(x) if x != '' else None for x in line[1:]]

    def set_features(self, features: np.ndarray):
        """
        Sets the features of the molecule.

        :param features: A 1-D numpy array of features for the molecule.
        """
        self.features = features

    def num_tasks(self) -> int:
        """
        Returns the number of prediction tasks.

        :return: The number of tasks (the number of target columns on this molecule's line).
        """
        return len(self.targets)

    def set_targets(self, targets: List[float]):
        """
        Sets the targets of a molecule.

        :param targets: A list of floats containing the targets.
        """
        self.targets = targets

    def set_subgraph(self, subgraph):
        """
        Attaches a knowledge-graph subgraph to this molecule.

        :param subgraph: The subgraph object to store on ``self.subgraph``.
        """
        self.subgraph = subgraph


class MoleculeKGNNDataset(Dataset):
    """
    A dataset of MoleculeDatapoints paired with their knowledge-graph neighbourhoods.

    Each item is a ``(datapoint, subgraph)`` tuple, where ``subgraph`` is produced by
    ``get_subgraph`` around the datapoint's node id (``self.node_ids``) in the knowledge graph.
    ``self.node_ids`` is kept aligned with ``self.data`` by every method that reorders the data.
    """

    # Default locations of the knowledge-graph artefacts; overridable per instance.
    DEFAULT_DATA_PATH = '/data/pj20/molkg/pretrain_data'
    DEFAULT_SMILES2ID_PATH = '/home/pj20/gode/data_process/smiles2id.json'

    def __init__(self,
                 data: 'List[MoleculeDatapoint]',
                 data_path: str = DEFAULT_DATA_PATH,
                 smiles2id_path: str = DEFAULT_SMILES2ID_PATH):
        """
        Initializes a MoleculeKGNNDataset from a list of MoleculeDatapoints.

        :param data: A list of MoleculeDatapoints.
        :param data_path: Directory holding the knowledge-graph files read by ``get_everything``.
        :param smiles2id_path: JSON file mapping SMILES strings to knowledge-graph node ids
                               (must contain an ``'UNK'`` entry for unmapped SMILES).
        """
        self.data = data
        self.args = self.data[0].args if len(self.data) > 0 else None
        self.scaler = None
        self.smiles2id_path = smiles2id_path
        self.node_ids = None
        self.get_node_ids()
        self.ent_type, self.motifs, self.G_tg = get_everything(data_path)

    def compound_names(self) -> List[str]:
        """
        Returns the compound names associated with the molecule (if they exist).

        :return: A list of compound names or None if the dataset does not contain compound names.
        """
        if len(self.data) == 0 or self.data[0].compound_name is None:
            return None

        return [d.compound_name for d in self.data]

    def smiles(self) -> List[str]:
        """
        Returns the smiles strings associated with the molecules.

        :return: A list of smiles strings.
        """
        return [d.smiles for d in self.data]

    def get_node_ids(self, smiles2id_path: str = None) -> List:
        """
        Maps every molecule's SMILES string to its knowledge-graph node id.

        SMILES missing from the mapping fall back to the id stored under ``'UNK'``.
        The ids are stored on ``self.node_ids`` (aligned with ``self.data``) and also returned.

        :param smiles2id_path: JSON file mapping SMILES -> node id. Defaults to the path given
                               at construction time.
        :return: One node id per molecule, in dataset order.
        :raises KeyError: If a SMILES is unmapped and the mapping has no ``'UNK'`` entry.
        """
        if smiles2id_path is None:
            smiles2id_path = getattr(self, 'smiles2id_path', self.DEFAULT_SMILES2ID_PATH)
        with open(smiles2id_path, 'r', encoding='utf-8') as f:
            smiles2id = json.load(f)

        # 'UNK' is only looked up on a miss, so a mapping without it works as long as
        # every SMILES is present.
        self.node_ids = [
            smiles2id[smile] if smile in smiles2id else smiles2id['UNK']
            for smile in self.smiles()
        ]
        return self.node_ids

    def features(self) -> List[np.ndarray]:
        """
        Returns the features associated with each molecule (if they exist).

        :return: A list of 1D numpy arrays containing the features for each molecule or None if there are no features.
        """
        if len(self.data) == 0 or self.data[0].features is None:
            return None

        return [d.features for d in self.data]

    def kges(self) -> List[np.ndarray]:
        """
        Returns the ``kges`` attribute of each datapoint (presumably knowledge-graph embeddings;
        nothing in this module assigns them a non-None value).

        :return: A list with each datapoint's ``kges`` value, or None if the dataset is empty or the
                 first datapoint has no ``kges`` (missing attribute or None).
        """
        if len(self.data) == 0 or getattr(self.data[0], 'kges', None) is None:
            return None

        return [getattr(d, 'kges', None) for d in self.data]

    def targets(self) -> List[List[float]]:
        """
        Returns the targets associated with each molecule.

        :return: A list of lists of floats containing the targets.
        """
        return [d.targets for d in self.data]

    def num_tasks(self) -> int:
        """
        Returns the number of prediction tasks.

        For ``args.dataset_type == 'multiclass'`` this is the number of classes, i.e. the largest
        non-missing first-column label plus one; otherwise it is the first datapoint's task count.

        :return: The number of tasks, or None for an empty dataset.
        :raises ValueError: In multiclass mode if every first-column label is missing.
        """
        if len(self.data) == 0:
            return None
        if getattr(self.args, 'dataset_type', None) == 'multiclass':
            # Missing labels (None) cannot be compared with floats; skip them.
            labels = [t[0] for t in self.targets() if t[0] is not None]
            return int(max(labels)) + 1
        return self.data[0].num_tasks()

    def features_size(self) -> int:
        """
        Returns the size of the features array associated with each molecule.

        :return: The size of the features.
        """
        return len(self.data[0].features) if len(self.data) > 0 and self.data[0].features is not None else None

    def _reorder(self, order: List[int]):
        """Permutes ``self.data`` (in place) and ``self.node_ids`` by the same index order."""
        self.data[:] = [self.data[i] for i in order]
        node_ids = getattr(self, 'node_ids', None)
        if node_ids is not None:
            self.node_ids = [node_ids[i] for i in order]

    def shuffle(self, seed: int = None):
        """
        Shuffles the dataset in place, keeping ``self.node_ids`` aligned with ``self.data``.

        :param seed: Optional random seed. Note that it reseeds Python's global ``random`` module.
        """
        if seed is not None:
            random.seed(seed)
        # random.shuffle's draws depend only on the sequence length, so shuffling an index list
        # yields exactly the permutation that shuffling self.data directly would have produced.
        order = list(range(len(self.data)))
        random.shuffle(order)
        self._reorder(order)

    def normalize_features(self, scaler: 'StandardScaler' = None, replace_nan_token: int = 0) -> 'StandardScaler':
        """
        Normalizes the features of the dataset using a StandardScaler (subtract mean, divide by standard deviation).

        If a scaler is provided, uses that scaler to perform the normalization. Otherwise fits a scaler to the
        features in the dataset and then performs the normalization.

        :param scaler: A fitted StandardScaler. Used if provided. Otherwise a StandardScaler is fit on
        this dataset and is then used.
        :param replace_nan_token: What to replace nans with.
        :return: A fitted StandardScaler. If a scaler is provided, this is the same scaler. Otherwise, this is
        a scaler fit on this dataset.
        """
        if len(self.data) == 0 or self.data[0].features is None:
            return None

        if scaler is not None:
            self.scaler = scaler

        elif self.scaler is None:
            features = np.vstack([d.features for d in self.data])
            self.scaler = StandardScaler(replace_nan_token=replace_nan_token)
            self.scaler.fit(features)

        for d in self.data:
            d.set_features(self.scaler.transform(d.features.reshape(1, -1))[0])

        return self.scaler

    def set_targets(self, targets: List[List[float]]):
        """
        Sets the targets for each molecule in the dataset. Assumes the targets are aligned with the datapoints.

        :param targets: A list of lists of floats containing targets for each molecule. This must be the
        same length as the underlying dataset.
        """
        assert len(self.data) == len(targets)
        for datapoint, datapoint_targets in zip(self.data, targets):
            datapoint.set_targets(datapoint_targets)

    def sort(self, key: Callable):
        """
        Sorts the dataset (stably, in place) using the provided key, keeping ``self.node_ids`` aligned.

        :param key: A function on a MoleculeDatapoint to determine the sorting order.
        """
        order = sorted(range(len(self.data)), key=lambda i: key(self.data[i]))
        self._reorder(order)

    def __len__(self) -> int:
        """
        Returns the length of the dataset (i.e. the number of molecules).

        :return: The length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx) -> tuple:
        """
        Returns the datapoint at ``idx`` together with its knowledge-graph subgraph.

        :param idx: Integer position in the dataset.
        :return: ``(datapoint, subgraph)`` where ``subgraph`` comes from ``get_subgraph`` centred on
                 ``self.node_ids[idx]``.
        """
        return (
            self.data[idx],
            get_subgraph(
                G_tg=self.G_tg,
                center_molecule_id=self.node_ids[idx],
                motifs=self.motifs,
                ent_type=self.ent_type,
            ),
        )


def get_everything(data_path):
    """
    Loads the knowledge-graph artefacts stored under ``data_path``.

    Reads three files:
      * ``ent_type_onehot.npy`` - per-entity type matrix; one all-zero padding row is appended.
      * ``id2motifs.json`` - maps stringified entity ids to motif vectors; ids without an entry
        get an all-zero vector, and one more all-zero padding row is appended.
      * ``graph.pt`` - the whole knowledge graph, saved with ``torch.save``.

    :param data_path: Directory containing the three files.
    :return: ``(ent_type, motifs, G_tg)``; ``motifs`` is a ``torch.long`` tensor.
    :raises ValueError: If ``id2motifs.json`` is empty, so the motif length cannot be determined.
    """
    # Training Labels
    ## Load entity type labels
    print('Loading entity type labels...')
    ent_type = torch.tensor(np.load(f'{data_path}/ent_type_onehot.npy'))  # (num_ent, num_ent_type)
    # All-zero padding row (presumably for the dataset's 'UNK' node id — TODO confirm). It uses the
    # loaded dtype so torch.cat does not rely on cross-dtype promotion.
    additional_ent_type = torch.zeros((1, ent_type.shape[1]), dtype=ent_type.dtype)
    ent_type = torch.cat((ent_type, additional_ent_type), dim=0)  # (num_ent + 1, num_ent_type)

    ## Load center molecule motifs
    print('Loading center molecule motifs...')
    with open(f'{data_path}/id2motifs.json', 'r') as f:
        id2motifs = json.load(f)
    if not id2motifs:
        raise ValueError(f'{data_path}/id2motifs.json contains no motifs; cannot infer motif length.')
    # Motif vectors are assumed to share one length; take it from any entry instead of
    # requiring that entity id '0' has one.
    motif_len = len(next(iter(id2motifs.values())))
    empty_motif = [0] * motif_len
    # NOTE(review): this iterates over len(ent_type) *after* its padding row was appended, and another
    # padding row is concatenated below, so motifs ends up with one more row than ent_type. Kept as-is
    # to preserve indexing behaviour; confirm whether the extra row is intended.
    motifs = np.array([id2motifs.get(str(i), empty_motif) for i in range(len(ent_type))])
    motifs = torch.tensor(motifs, dtype=torch.long)  # (num_ent + 1, motif_len)
    additional_motif = torch.zeros((1, motif_len), dtype=torch.long)
    motifs = torch.cat((motifs, additional_motif), dim=0)

    # Entire Knowledge Graph (MolKG)
    print('Loading entire knowledge graph...')
    # graph.pt holds a graph object (callers use .edge_index / .subgraph), not just tensors, so it must
    # be fully unpickled: torch >= 2.6 defaults to weights_only=True and would refuse it.
    # SECURITY: unpickling can execute arbitrary code — only load trusted files.
    with open(f'{data_path}/graph.pt', 'rb') as f:
        try:
            G_tg = torch.load(f, weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the ``weights_only`` argument.
            f.seek(0)
            G_tg = torch.load(f)

    # Writes the (possibly inferred) node count back as an explicit attribute; for a PyG Data object
    # this presumably avoids re-inferring it later — TODO confirm intent.
    G_tg.num_nodes = G_tg.num_nodes

    return ent_type, motifs, G_tg

def get_k_hop_nodes(node_index, num_hops, edge_index):
    """
    Collects the ids of every node within ``num_hops`` hops of ``node_index``.

    The edge list is made undirected first, so edge direction is ignored while expanding
    the neighbourhood. Note that the conversion runs on the full ``edge_index`` on every call.

    :param node_index: Id of the node the neighbourhood is centred on.
    :param num_hops: Number of hops to expand.
    :param edge_index: Edge list of the whole graph, in the format ``to_undirected`` accepts.
    :return: The ``subset`` tensor returned by ``k_hop_subgraph`` (node ids of the neighbourhood).
    """
    undirected_edges = to_undirected(edge_index)

    subset, _sub_edges, _mapping, _edge_mask = k_hop_subgraph(
        node_index,
        num_hops,
        undirected_edges,
        relabel_nodes=False,
        num_nodes=None,
        flow='source_to_target',
    )
    return subset


def get_subgraph(G_tg, center_molecule_id, motifs, ent_type, num_hops: int = None):
    """
    Extracts the knowledge-graph neighbourhood of one molecule and attaches its training labels.

    :param G_tg: The whole knowledge graph; must expose ``edge_index`` and ``subgraph(...)``.
    :param center_molecule_id: Node id of the centre molecule (anything ``int()`` accepts).
    :param motifs: Per-entity motif tensor, indexed by node id.
    :param ent_type: Per-entity type tensor, indexed by node id; column 0 is treated as the
                     'molecule' type flag.
    :param num_hops: Neighbourhood radius. Defaults to ``KHOP - 1``, read at call time.
    :return: ``G_tg.subgraph(masked_node_ids)`` with extra attributes:
        ``masked_node_ids`` - ids of the kept nodes in the full graph;
        ``center_molecule_id`` - position of the centre molecule inside ``masked_node_ids``;
        ``motif_labels`` - the centre molecule's row of ``motifs``;
        ``non_molecule_node_ids`` - positions inside ``masked_node_ids`` whose type column 0 is not 1.
    """
    if num_hops is None:
        num_hops = KHOP - 1
    # Normalise once so the k-hop lookup, motif lookup and position search all use the same int id.
    center = int(center_molecule_id)

    masked_node_ids = get_k_hop_nodes(center, num_hops, G_tg.edge_index)
    subgraph = G_tg.subgraph(masked_node_ids)
    motif_labels = motifs[center]  # (motif_len,) — a single row when motifs is (num_ent, motif_len)
    node_labels = ent_type[masked_node_ids]  # (num_masked_nodes, num_ent_type)

    # Column 0 of the one-hot type encoding marks the 'molecule' type.
    molecule_encoding = node_labels[:, 0]
    # Nodes whose 'molecule' flag is not 1 are non-molecule entities.
    non_molecule_mask = molecule_encoding != 1
    # Positions are relative to masked_node_ids, which presumably matches the node order of
    # `subgraph` since the same tensor is used as its subset — TODO confirm.
    non_molecule_node_ids_relative = torch.where(non_molecule_mask)[0]

    subgraph.masked_node_ids = masked_node_ids
    subgraph.center_molecule_id = torch.where(masked_node_ids == center)[0][0]
    subgraph.motif_labels = motif_labels
    subgraph.non_molecule_node_ids = non_molecule_node_ids_relative

    return subgraph

