#!/usr/bin/env python
import numpy as np
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit import DataStructs
from sklearn import svm
import pickle
import re
import os.path as op
# Silence RDKit's 'rdApp.error' log channel for this process.
rdBase.DisableLog('rdApp.error')

# Activity scores from an SVM classifier over folded Morgan count
# fingerprints (see fingerprints_from_mol / get_score).

# Module-level classifier, created lazily: get_score() calls load_model()
# the first time it finds this still set to None.
clf_model = None

# History: the estimator used to be unpickled from 'clf_py36.pkl' (and,
# before that, 'clf_py27.pkl') located next to this module.  That pickle only
# loads under Python 3.6, so load_model() below instead rebuilds the DRD2
# estimator directly from its exported weights.
def load_model():
    """Rebuild the DRD2 SVM classifier from its exported weights.

    The estimator was originally stored as the pickle ``clf_py36.pkl``, which
    only unpickles under Python 3.6.  Instead, this function constructs an
    ``svm.SVC`` with the same hyper-parameters and copies the fitted
    attributes from ``clf_py36_weights.npz`` (next to this module) onto it.
    The result is stored in the module-level ``clf_model``.

    Reference values of the original pickled estimator (arrays abbreviated):
    kernel='rbf', degree=3, gamma=0.015625, coef0=0.0, tol=0.001, C=128,
    shrinking=True, probability=True, cache_size=200, class_weight=None,
    verbose=False, max_iter=-1, random_state=None, _impl='c_svc',
    _sparse=False, _gamma=0.015625, classes_=[0, 1], class_weight_=[1., 1.],
    n_support_=[1655, 504], intercept_=[-0.7125836], probA_=[-6.21194521],
    probB_=[0.88511078], fit_status_=0, shape_fit_=(10000, 2048).

    The global is assigned only after every attribute has been loaded, so a
    failed load leaves ``clf_model`` unchanged and a later call can retry.

    Raises:
        FileNotFoundError: if the weights archive does not exist.
        KeyError: if the archive lacks one of the expected arrays.
    """
    global clf_model

    # Hyper-parameters copied from the original estimator.  'nu' and
    # 'epsilon' are left out, as SVC does not take them as constructor
    # arguments.
    model = svm.SVC(decision_function_shape=None,
                    kernel='rbf',
                    degree=3,
                    gamma=0.015625,
                    coef0=0.0,
                    tol=0.001,
                    C=128,
                    shrinking=True,
                    probability=True,
                    cache_size=200,
                    class_weight=None,
                    verbose=False,
                    max_iter=-1,
                    random_state=None)
    # Attributes that SVC.fit() would normally set; fit() is never called
    # here, so they are filled in by hand with the recorded values.
    model._impl = 'c_svc'
    model._sparse = False
    model._gamma = 0.015625
    model.fit_status_ = 0
    model.shape_fit_ = (10000, 2048)

    name = op.join(op.dirname(__file__), 'clf_py36_weights.npz')
    # np.load returns an NpzFile that keeps the archive open; the with-block
    # closes it once every array has been read into memory.
    with np.load(name) as weight_data:
        model.class_weight_ = weight_data['class_weight_']
        model.classes_ = weight_data['classes_']
        model.support_ = weight_data['support_']
        model.support_vectors_ = weight_data['support_vectors_']
        # n_support_/probA_/probB_ are written to private names; presumably
        # the targeted scikit-learn exposes the public ones as read-only
        # properties -- confirm when upgrading scikit-learn.
        model._n_support = weight_data['n_support_']
        model.dual_coef_ = weight_data['dual_coef_']
        model.intercept_ = weight_data['intercept_']
        model._probA = weight_data['probA_']
        model._probB = weight_data['probB_']
        model._intercept_ = weight_data['_intercept_']
        model._dual_coef_ = weight_data['_dual_coef_']

    # Publish only a fully initialised model.
    clf_model = model

def get_score(smile):
    """Return the classifier's probability for the positive class (label 1).

    The model is loaded on first use.

    Args:
        smile: SMILES string of the molecule to score.

    Returns:
        The predicted probability of label 1 as a Python ``float``, or
        ``0.0`` when RDKit cannot parse ``smile``.
    """
    if clf_model is None:
        load_model()

    mol = Chem.MolFromSmiles(smile)
    if not mol:
        # Unparseable SMILES (MolFromSmiles gave no usable molecule).
        return 0.0

    fp = fingerprints_from_mol(mol)
    # predict_proba returns one row per sample; take the scalar [0, 1]
    # directly, because float() on a 1-element array is deprecated since
    # NumPy 1.25.
    return float(clf_model.predict_proba(fp)[0, 1])

def fingerprints_from_mol(mol):
    """Fold a Morgan count fingerprint of *mol* into a dense feature row.

    The fingerprint is computed with radius 3, feature invariants and counts
    (``useFeatures=True, useCounts=True``).  Each non-zero bit id is mapped
    onto one of 2048 slots by ``id % 2048``; counts landing in the same slot
    are summed.

    Args:
        mol: an RDKit molecule.

    Returns:
        ``np.ndarray`` of dtype int32 and shape ``(1, 2048)`` -- a
        single-sample batch as consumed by the classifier in get_score.
    """
    n_bits = 2048
    sparse_counts = AllChem.GetMorganFingerprint(mol, 3, useCounts=True, useFeatures=True)
    folded = np.zeros((1, n_bits), np.int32)
    row = folded[0]  # view into the single row; writes land in `folded`
    for bit_id, count in sparse_counts.GetNonzeroElements().items():
        row[bit_id % n_bits] += int(count)
    return folded

"""
load_model()
print(type(clf_model))
print(clf_model.__dict__)
np.savez("clf_py36_weights.npz",
         class_weight_=clf_model.class_weight_,
         classes_=clf_model.classes_,
         support_=clf_model.support_,
         support_vectors_=clf_model.support_vectors_,
         n_support_=clf_model.n_support_,
         dual_coef_=clf_model.dual_coef_,
         intercept_=clf_model.intercept_,
         probA_=clf_model.probA_,
         probB_=clf_model.probB_,
         _intercept_=clf_model._intercept_,
         _dual_coef_=clf_model._dual_coef_)
"""