from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
import critic_utils
import numpy as np
import torch
import random
from rdkit import Chem
from scipy.spatial import distance_matrix
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from rdkit.Chem import AllChem
# Seeds only Python's built-in `random` module. NumPy's global RNG (used by
# dfDTISampler through np.random.choice) is NOT seeded by this call.
random.seed(0)
import os
import logging
import glob
import sascorer

def get_atom_feature(m, is_ligand=True):
    """Stack the per-atom feature vectors of an RDKit molecule into an array.

    Args:
        m: RDKit molecule whose atoms are featurized, in atom-index order.
        is_ligand: when true, use ``critic_utils.atom_feature_ligand``;
            otherwise use ``critic_utils.atom_feature_protein``.

    Returns:
        ``np.array`` built from one feature vector per atom (row ``i`` belongs
        to atom ``i``). Presumably 2-D ``(n_atoms, feature_dim)`` when every
        vector has the same length; TODO confirm against critic_utils.
    """
    featurize = (critic_utils.atom_feature_ligand if is_ligand
                 else critic_utils.atom_feature_protein)
    rows = []
    for atom_idx in range(m.GetNumAtoms()):
        # The helper returns (feature_vector, size); only the vector is kept.
        feature_vec, _width = featurize(m, atom_idx, None, None)
        rows.append(feature_vec)
    return np.array(rows)



class spMolDataset(Dataset):
    """Map-style dataset pairing each ligand with one fixed protein pocket.

    Item ``idx`` combines ligand ``ligand_smiles_list[idx]`` with the pocket
    PDB file found under ``<pocket_dir>/<query_protein>/``. Items are dicts
    with the following keys:

    - ``H1`` / ``A1``: ligand atom features and adjacency plus identity.
    - ``H2`` / ``A2``: the same for the pocket.
    - ``V``: 1 for the ligand atoms and 0 for the pocket atoms.
    - ``sa_score``: SA score divided by 10.
    - ``key``: the RDKit SMILES of the ligand.

    A sample that fails to featurize is logged and returned as ``None``, so
    ``spcollate_fn`` can drop it.

    Args:
        pocket_dir: directory holding one sub-directory per protein.
        query_protein: name of the sub-directory that holds the pocket file.
        ligand_smiles_list: ligands, given as RDKit ``Mol`` objects or as
            SMILES strings (strings are parsed with ``Chem.MolFromSmiles``).
    """

    def __init__(self, pocket_dir, query_protein, ligand_smiles_list):
        self.protein_dir = pocket_dir
        self.query_protein = query_protein
        self.ligand_smiles_list = ligand_smiles_list
        # (n_atoms, adjacency, features) of the pocket. The pocket is the same
        # for every item, so it is parsed once on first successful use rather
        # than on every __getitem__ call.
        self._pocket_cache = None

    def __len__(self):
        return len(self.ligand_smiles_list)

    def _load_pocket(self):
        """Return ``(n_atoms, adjacency, features)`` for the query pocket.

        Raises:
            FileNotFoundError: no file matches ``<pocket_dir>/<query_protein>/*``.
            ValueError: RDKit could not parse the pocket PDB file.
        """
        if self._pocket_cache is None:
            pattern = os.path.join(self.protein_dir, self.query_protein, "*")
            # glob order is filesystem-dependent; sort so the chosen file is
            # deterministic when several files match.
            matches = sorted(glob.glob(pattern))
            if not matches:
                raise FileNotFoundError("no pocket file matches " + pattern)
            protein = Chem.MolFromPDBFile(matches[0])
            if protein is None:
                raise ValueError("RDKit could not parse pocket file " + matches[0])
            n_atoms = protein.GetNumAtoms()
            adjacency = GetAdjacencyMatrix(protein) + np.eye(n_atoms)
            features = get_atom_feature(protein, False)
            self._pocket_cache = (n_atoms, adjacency, features)
        return self._pocket_cache

    def __getitem__(self, idx):
        """Featurize ligand ``idx`` together with the (cached) query pocket.

        Returns:
            The sample dict described on the class, or ``None`` if
            featurization fails (the failure is logged as a warning).

        Raises:
            IndexError: ``idx`` is out of range. This is deliberately not
                swallowed, so that iterating the dataset terminates.
        """
        # Outside the try: an out-of-range index must propagate.
        ligand = self.ligand_smiles_list[idx]
        try:
            n2, adj2, H2 = self._load_pocket()
            mol = Chem.MolFromSmiles(ligand) if isinstance(ligand, str) else ligand
            if mol is None:
                raise ValueError("could not parse ligand %r" % (ligand,))
            n1 = mol.GetNumAtoms()
            adj1 = GetAdjacencyMatrix(mol) + np.eye(n1)
            H1 = get_atom_feature(mol, True)
            # Mask over the concatenated [ligand atoms, pocket atoms] axis.
            valid = np.zeros((n1 + n2,))
            valid[:n1] = 1
            return {
                'H1': H1,
                # Copies, so each sample owns its arrays just as it did
                # when the pocket was re-parsed per item.
                'H2': H2.copy(),
                'A1': adj1,
                'A2': adj2.copy(),
                'V': valid,
                'sa_score': sascorer.calculateScore(mol) / 10,
                'key': Chem.MolToSmiles(mol),
            }
        except Exception as e:
            # Best-effort: the sample is dropped by spcollate_fn, but the
            # failure is made visible.
            logging.warning("spMolDataset: skipping item %s: %s", idx, e)
            return None


class dfDTISampler(Sampler):
    """Yield dataset indices drawn with probability proportional to ``weights``.

    Args:
        weights: per-index weights. They are divided by their sum to form the
            sampling probabilities, stored on ``self.weights``.
        num_samples: number of indices yielded per pass over the sampler.
        replacement: draw with replacement (default) or without.

    Indices are drawn from NumPy's global RNG (``np.random``), so seed that
    RNG to make the sampled order reproducible.
    """

    def __init__(self, weights, num_samples, replacement=True):
        raw = np.array(weights)
        total = np.sum(weights)
        self.weights = raw / total
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        population = len(self.weights)
        drawn = np.random.choice(
            population,
            self.num_samples,
            replace=self.replacement,
            p=self.weights,
        )
        return iter(drawn.tolist())

    def __len__(self):
        return self.num_samples


def spcollate_fn(batch):
    """Collate ``spMolDataset`` samples into zero-padded float tensors.

    ``None`` entries (samples that failed to featurize) are dropped first.

    Returns:
        ``(H1, H2, A1, A2, V, sa_scores, keys)``. For ``B`` surviving samples:

        - ``H1`` is ``(B, max len(H1), F1)``.
        - ``H2`` is ``(B, max len(H2), F2)``.
        - ``A1`` and ``A2`` are the matching square padded adjacencies.
        - ``V`` is ``(B, max len(A1) + max len(H2))``.
        - ``sa_scores`` and ``keys`` are Python lists.

        The feature widths F1 and F2 are taken from the first surviving
        sample. If every sample was ``None``, the five tensors are 0-d zeros
        and ``sa_scores`` and ``keys`` are the integer ``0``.
    """
    samples = [entry for entry in batch if entry is not None]
    logging.info("batch size:" + str(len(samples)))

    if not samples:
        # Sentinel output for an all-failed batch: five separate 0-d tensors.
        zeros = [torch.from_numpy(np.array(0)).float() for _ in range(5)]
        return (*zeros, 0, 0)

    count = len(samples)
    prot_size = max(len(s['H2']) for s in samples)
    lig_rows = max(len(s['H1']) for s in samples)
    lig_size = max(len(s['A1']) for s in samples)

    h1 = np.zeros((count, lig_rows, samples[0]['H1'].shape[1]))    # ligand
    h2 = np.zeros((count, prot_size, samples[0]['H2'].shape[1]))   # protein
    a1 = np.zeros((count, lig_size, lig_size))
    a2 = np.zeros((count, prot_size, prot_size))
    v = np.zeros((count, lig_size + prot_size))
    sa_scores, keys = [], []

    for row, s in enumerate(samples):
        n_lig_rows = len(s['H1'])
        n_lig = len(s['A1'])
        n_prot = len(s['H2'])
        h1[row, :n_lig_rows] = s['H1']
        h2[row, :n_prot] = s['H2']
        a1[row, :n_lig, :n_lig] = s['A1']
        a2[row, :n_prot, :n_prot] = s['A2']
        v[row, :n_lig + n_prot] = s['V']
        sa_scores.append(s['sa_score'])
        keys.append(s['key'])

    return (
        torch.from_numpy(h1).float(),
        torch.from_numpy(h2).float(),
        torch.from_numpy(a1).float(),
        torch.from_numpy(a2).float(),
        torch.from_numpy(v).float(),
        sa_scores,
        keys,
    )





