import dataclasses
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem, GetPeriodicTable, RemoveHs
from rdkit.Geometry import Point3D
import logging
import numpy as np
from entity.molecule_constants import *
import entity.entity_constants as ec
import networkx as nx
logger = logging.getLogger(__name__)


def safe_index_one_hot(l, e):
    """Return a one-hot list of length ``len(l)`` marking where ``e`` sits in ``l``.

    If ``e`` is not present in ``l``, the last position is set instead, so the
    final entry of ``l`` effectively acts as a catch-all bucket.

    Args:
        l: Sequence of allowed values; must provide ``index`` and ``len``.
        e: Value to encode.

    Returns:
        A list of ``0``s with a single ``1``.

    Raises:
        IndexError: If ``l`` is empty (there is no position to set).
    """
    try:
        idx = l.index(e)
    except ValueError:
        # Only "value not found" maps to the fallback slot; any other error
        # (e.g. ``l`` has no ``index`` method) is a caller bug and propagates.
        idx = len(l) - 1
    res = [0] * len(l)
    res[idx] = 1
    return res


def mol_from_file(input_file: str, sanitize=True, calc_charges=True, remove_hs=True):
    """Load a molecule with RDKit, choosing the parser from the suffix of ``input_file``.

    Supported suffixes are ``.mol2``, ``.sdf`` (only the first record is used),
    ``.pdbqt``, ``.pdb`` and ``.smiles``. For ``.smiles`` nothing is read from
    disk: the text in front of the suffix is itself parsed as a SMILES string,
    hydrogens are added, a 3D conformer is embedded with a fixed random seed
    and then optimised with UFF.

    Args:
        input_file: Path to the molecule file, or ``"<SMILES>.smiles"``.
        sanitize: Passed to the file parsers and to ``Chem.RemoveHs``; the
            molecule is also sanitized after loading when this or
            ``calc_charges`` is true.
        calc_charges: Attempt to compute Gasteiger charges; a failure here is
            logged and does not discard the molecule.
        remove_hs: Passed to the file parsers; hydrogens are also removed
            after loading.

    Returns:
        The RDKit molecule, or ``None`` (after logging a warning) if the
        molecule could not be parsed, embedded or post-processed.

    Raises:
        ValueError: If ``input_file`` has none of the supported suffixes.
    """
    if input_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(input_file, sanitize=sanitize, removeHs=remove_hs)
    elif input_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(input_file, sanitize=sanitize, removeHs=remove_hs)
        mol = supplier[0]
    elif input_file.endswith('.smiles'):
        # Strip only the trailing suffix rather than every '.smiles' occurrence.
        smiles = input_file[:-len('.smiles')]
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            logger.warning('RDKit was unable to parse the SMILES string %r.', smiles)
            return None
        mol = Chem.AddHs(mol)
        # EmbedMolecule returns -1 when no conformer could be generated; the
        # UFF optimisation below needs a conformer to work on.
        if AllChem.EmbedMolecule(mol, randomSeed=42) < 0:
            logger.warning('RDKit was unable to embed a 3D conformer for %r.', smiles)
            return None
        AllChem.UFFOptimizeMolecule(mol)
    elif input_file.endswith('.pdbqt'):
        # Keep only the first 66 characters of every line before handing the
        # text to the PDB parser (presumably to drop the PDBQT-specific
        # trailing columns -- confirm against the file format).
        with open(input_file) as file:
            pdb_block = ''.join('{}\n'.format(line[:66]) for line in file)
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=sanitize, removeHs=remove_hs)
    elif input_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(input_file, sanitize=sanitize, removeHs=remove_hs)
    else:
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt, .smiles, and .pdb, got {}'.format(input_file))

    # The RDKit parsers signal failure by returning None instead of raising.
    if mol is None:
        logger.warning("RDKit was unable to read the molecule.")
        return None

    try:
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)

        if calc_charges:
            # Compute Gasteiger charges on the molecule; this is best-effort.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                logger.warning('Unable to compute charges for the molecule.')

        if remove_hs:
            mol = Chem.RemoveHs(mol, sanitize=sanitize)
    except Exception as e:
        logger.warning(e)
        logger.warning("RDKit was unable to read the molecule.")
        return None

    return mol


def mol_extra_featurizer(mol):
    """Build a per-atom feature matrix of concatenated one-hot encodings.

    For every atom, each property below is one-hot encoded against the
    matching list in ``allowable_features`` (via ``safe_index_one_hot``) and
    the encodings are concatenated in the order listed.

    Args:
        mol: RDKit molecule to featurize.

    Returns:
        ``np.float32`` array with one row per atom.
    """
    ring_info = mol.GetRingInfo()
    rows = []
    for atom_idx, atom in enumerate(mol.GetAtoms()):
        # (key into allowable_features, observed value) in output order.
        descriptors = (
            ('possible_chirality_list', str(atom.GetChiralTag())),
            ('possible_degree_list', atom.GetTotalDegree()),
            ('possible_formal_charge_list', atom.GetFormalCharge()),
            ('possible_implicit_valence_list', atom.GetImplicitValence()),
            ('possible_numH_list', atom.GetTotalNumHs()),
            ('possible_number_radical_e_list', atom.GetNumRadicalElectrons()),
            ('possible_hybridization_list', str(atom.GetHybridization())),
            ('possible_is_aromatic_list', atom.GetIsAromatic()),
            ('possible_numring_list', ring_info.NumAtomRings(atom_idx)),
            ('possible_is_in_ring3_list', ring_info.IsAtomInRingOfSize(atom_idx, 3)),
            ('possible_is_in_ring4_list', ring_info.IsAtomInRingOfSize(atom_idx, 4)),
            ('possible_is_in_ring5_list', ring_info.IsAtomInRingOfSize(atom_idx, 5)),
            ('possible_is_in_ring6_list', ring_info.IsAtomInRingOfSize(atom_idx, 6)),
            ('possible_is_in_ring7_list', ring_info.IsAtomInRingOfSize(atom_idx, 7)),
            ('possible_is_in_ring8_list', ring_info.IsAtomInRingOfSize(atom_idx, 8)),
        )
        row = []
        for key, value in descriptors:
            row.extend(safe_index_one_hot(allowable_features[key], value))
        rows.append(row)

    return np.array(rows, dtype=np.float32)


def mol_pair_featurizer(mol):
    """Encode the bonds of ``mol`` as dense, symmetric atom-pair tensors.

    Args:
        mol: RDKit molecule.

    Returns:
        A tuple ``(pair_feat, edges)`` of ``np.float32`` arrays:

        * ``pair_feat`` has shape ``(n_atoms, n_atoms, pair_feat_num)``; for
          each bond the channel ``bond_types[bond type]`` is set to 1 in both
          directions.
        * ``edges`` has shape ``(n_atoms, n_atoms, ec.edge_type_num)``; for
          each bond the ``ec.edge_type_order["lig_bond"]`` channel is set to 1
          in both directions.
    """
    n_atoms = mol.GetNumAtoms()
    pair_feat = np.zeros((n_atoms, n_atoms, pair_feat_num), dtype=np.float32)
    edges = np.zeros((n_atoms, n_atoms, ec.edge_type_num), dtype=np.float32)

    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        finish = bond.GetEndAtomIdx()
        type_channel = bond_types[bond.GetBondType()]
        # Bonds are undirected, so fill both orientations.
        for src, dst in ((begin, finish), (finish, begin)):
            pair_feat[src, dst, type_channel] = 1
        for src, dst in ((begin, finish), (finish, begin)):
            edges[src, dst, ec.edge_type_order["lig_bond"]] = 1

    return pair_feat, edges

def mol_make_fape_frame_idx(mol):
    """Pick one three-atom frame per atom, chosen deterministically by canonical rank.

    The bond graph is built over RDKit canonical atom ranks. For every atom the
    candidate frames are enumerated as rank triples with the atom in the middle
    position, and the lexicographically smallest triple is kept, so the choice
    does not depend on the input atom order. (Presumably these frames feed a
    frame-aligned point error loss, going by the function name.)

    Args:
        mol: RDKit molecule.

    Returns:
        ``np.int32`` array with one row ``[a, atom, b]`` of atom indices per
        atom, in atom-index order.

    Raises:
        IndexError: If some atom has no candidate frame, i.e. it has no
            neighbours, or its only neighbour has no other neighbour (an
            isolated atom or a two-atom fragment).
    """
    def find_possible_frames(G, node):
        """Return all candidate ``(a, node, b)`` rank triples for ``node``."""
        neighbors = list(G.neighbors(node))

        if len(neighbors) == 1:
            # Terminal atom: extend the frame through its single neighbour to
            # that neighbour's other neighbours.
            neighbor = neighbors[0]
            return [(neighbor, node, second_neighbor)
                    for second_neighbor in G.neighbors(neighbor)
                    if second_neighbor != node]

        # Otherwise every ordered pair of distinct neighbours is a candidate.
        return [(first, node, second)
                for first in neighbors
                for second in neighbors
                if first != second]

    canonical_rank = list(Chem.CanonicalRankAtoms(mol))
    G = nx.Graph()
    G.add_nodes_from(canonical_rank)

    # Inverse lookup: canonical rank -> atom index.
    rank2id = np.zeros(len(canonical_rank), dtype=np.int32)
    for idx, rank in enumerate(canonical_rank):
        rank2id[rank] = idx

    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        G.add_edge(canonical_rank[start], canonical_rank[end])

    fape_frame_idx = []
    for idx, rank in enumerate(canonical_rank):
        paths = find_possible_frames(G, rank)
        if not paths:
            # Same exception type as the former ``paths[0]`` failure, but with
            # a message that says which atom is the problem and why.
            raise IndexError(
                'Atom {} has no three-atom frame: it needs two neighbours, or '
                'one neighbour that has another neighbour.'.format(idx))
        # Smallest triple in rank space, mapped back to atom indices.
        fape_frame_idx.append([rank2id[x] for x in min(paths)])
    return np.array(fape_frame_idx, dtype=np.int32)
