import pandas
import matplotlib.pyplot as plt
import numpy as np
from rdkit import Chem
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.Scaffolds import MurckoScaffold
import pickle
#from moses import get_all_metrics
# Stage-2 / stage-3 identifiers that appear in the names of the fine-tuned
# sample files read further down in this script.
stage2_num = 49999
stage3_num = 49999

training_file = "../../protein_target_data/EGFR/original/train.smiles"
test_file = "../../protein_target_data/EGFR/original/test.smiles"
regression_file = "../../chemprop/EGFR/EGFR_regression.csv"

# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only point these paths at model files you trust.
# Random-forest classifier (used via .predict for active / inactive labels).
clf_path = "predictor_models/EGFR_RFC/egfr.pkl"
with open(clf_path, "rb") as f:
    clf_c = pickle.load(f)

# Random-forest regressor (used via .predict for activity scores).
clf_path = "predictor_models/EGFR_RFR/egfr.pkl"
with open(clf_path, "rb") as f:
    clf_r = pickle.load(f)

# One SMILES string per line.  The files are only read, so open them
# read-only and close them deterministically.  splitlines() keeps the last
# record even when the file does not end with a newline.
with open(training_file, "r") as f_train:
    training_data = f_train.read().splitlines()

with open(test_file, "r") as f_test:
    test_data = f_test.read().splitlines()



def validate(smiles):
    """Return the entries of ``smiles`` that RDKit can parse.

    Args:
        smiles: iterable of SMILES strings.

    Returns:
        A new list with the strings for which ``Chem.MolFromSmiles`` did not
        return ``None``, in their original order.
    """
    return [
        candidate
        for candidate in smiles
        if Chem.MolFromSmiles(candidate) is not None
    ]

def met2str(mean, std):
    """Format a mean / spread pair as one LaTeX table cell.

    Args:
        mean: value printed in math mode with two significant digits
            (``#.2g``, so trailing zeros are kept).
        std: value printed in ``\\tiny`` size using scientific notation with
            two decimals; the exponent is typeset as ``\\mathrm{e}{<exp>}``.

    Returns:
        A string starting with the column separator ``&``, e.g.
        ``&$0.50${\\tiny 1.23 $\\mathrm{e}{-02}$}``.
    """
    cell = "&$" + "{:#.2g}".format(mean) + "$"
    mantissa, marker, exponent = "{:.2e}".format(std).partition("e")
    if marker:
        # Backslashes are escaped explicitly: "\m" is not a valid escape
        # sequence and newer Python versions warn about it.
        cell += "{\\tiny " + mantissa + " $\\mathrm{e}{" + exponent + "}$}"
    else:
        # Non-finite values format as "nan" / "inf" and carry no exponent.
        cell += "{\\tiny $" + mantissa + "$}"
    return cell



# Reference set for the similarity-based novelty metric: every training
# SMILES that RDKit can parse, and its Morgan bit-vector fingerprint
# (radius 3, 2048 bits).
true_mols = [
    mol
    for mol in (Chem.MolFromSmiles(smi) for smi in training_data)
    if mol is not None
]
true_fps = [
    AllChem.GetMorganFingerprintAsBitVect(mol, 3, 2048) for mol in true_mols
]

def _sample_file(kind, ft_type, seed):
    """Return the name of the CSV file holding ``kind`` outputs for one run.

    ``kind`` is "pred" (chemprop regressor) or "clf" (chemprop classifier).
    """
    if len(ft_type) > 4:
        return "samples_{}_{}_finetuning_{}_{}.{}.txt".format(
            kind, ft_type, stage2_num, stage3_num, seed)
    if ft_type == "test":
        return "test_{}.txt".format(kind)
    if ft_type == "RL":
        return "samples_{}_RL_finetuning.{}.txt".format(kind, seed)
    # NOTE(review): only the regressor file has a dedicated "no" variant;
    # the classifier file falls through to the default name below.
    if ft_type == "no" and kind == "pred":
        return "samples_pred_no_finetuning_19999_19999.{}.txt".format(seed)
    return "samples_{}.{}.txt".format(kind, seed)


def _histogram_file(ft_type, seed):
    """Return the PDF file name for the random-forest score histogram."""
    if len(ft_type) > 4:
        # No underscore between "finetuning" and the stage number here,
        # unlike the sample file names; kept as-is so output names are stable.
        return "sample_histogram_RF_{}_finetuning{}_{}.{}.pdf".format(
            ft_type, stage2_num, stage3_num, seed)
    if ft_type == "RL":
        return "sample_histogram_RF_RL_finetuning.{}.pdf".format(seed)
    if ft_type == "test":
        return "test_histogram_RF.pdf"
    return "sample_histogram_RF_1stage.{}.pdf".format(seed)


def _fingerprints_to_array(fps):
    """Stack RDKit bit vectors into one 2-D numpy array, one row each."""
    rows = []
    for fp in fps:
        features = np.zeros((1,))
        DataStructs.ConvertToNumpyArray(fp, features)
        rows.append(features.reshape(1, -1))
    return np.concatenate(rows, axis=0)


def _novelty(fps, reference_fps, thresholds):
    """Return, per threshold, 1 - fraction of ``fps`` close to the reference.

    A fingerprint is "close" when its highest Tanimoto similarity to any
    reference fingerprint is >= the threshold.
    """
    similar_counts = [0] * len(thresholds)
    for fp in fps:
        # The nearest-neighbour similarity is shared by all thresholds.
        nearest = max(DataStructs.BulkTanimotoSimilarity(fp, reference_fps))
        for k, threshold in enumerate(thresholds):
            if nearest >= threshold:
                similar_counts[k] += 1
    return [1 - count / len(fps) for count in similar_counts]


def _diversity(fps):
    """Return 1 - mean Tanimoto similarity over all unordered pairs."""
    similarity = 0
    for i, fp in enumerate(fps):
        # Comparing against the earlier fingerprints visits each pair once.
        similarity += sum(DataStructs.BulkTanimotoSimilarity(fp, fps[:i]))
    n = len(fps)
    n_pairs = n * (n - 1) / 2
    return 1 - similarity / n_pairs


def _report(label, values):
    """Print the mean and spread of ``values`` across seeds.

    NOTE: the number printed after "Var:" is np.std (a standard deviation);
    the label is kept unchanged so the output format stays stable.
    """
    print(label + " Mean: " + str(np.mean(values)) + " Var: " + str(np.std(values)))


# Loop invariants, built once instead of once per (variant, seed) pair.
training_set = set(training_data)
hist_bins = [0.5 * k for k in range(21)]  # 0.0, 0.5, ..., 10.0

for ft_type in ["test", "", "RL", "whole_model_adam", "inner_extra_layer_adam", "outer_extra_layer_adam"]:
    # Per-seed metric values for this fine-tuning variant.
    active_list_cpr = []
    active_list_rfc = []
    active_list_rfr = []
    active_list_cpc = []
    nuv_active_list = []
    diversitys = []
    noveltys_9 = []
    noveltys_4 = []
    average_cp_scores = []
    average_rf_scores = []

    print("finetuning type: " + ft_type)
    for seed in [1, 2, 3, 4, 5]:
        pred_file = _sample_file("pred", ft_type, seed)
        clf_file = _sample_file("clf", ft_type, seed)

        # chemprop regressor prediction: fraction with pX >= 6 and mean pX
        df = pandas.read_csv(pred_file, sep=",", index_col=0)
        pred_values = df.loc[:, ["pX"]].values
        total_active = np.sum(pred_values >= 6., axis=0)
        active_list_cpr.append(total_active / len(pred_values))
        average_cp_scores.append(np.mean(pred_values))

        # chemprop classifier prediction: fraction with score >= 0.5
        cpc_df = pandas.read_csv(clf_file, sep=",", index_col=0)
        cpc_values = cpc_df.loc[:, ["egfr"]].values
        cpc_active = np.sum(cpc_values >= .5, axis=0)
        active_list_cpc.append(cpc_active / len(cpc_values))

        # novel, unique and valid molecules predicted active (a count)
        sampled_smiles = df.index.values.tolist()
        novel_unique_mols = set(sampled_smiles).difference(training_set)
        nuv_mols = validate(list(novel_unique_mols))
        nuv_preds = df.loc[nuv_mols, :].values
        nuv_active_list.append(np.sum(nuv_preds >= 6., axis=0))

        # fingerprints of every parsable sampled molecule
        pred_mols = [Chem.MolFromSmiles(s) for s in sampled_smiles]
        pred_mols = [x for x in pred_mols if x is not None]
        pred_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 3, 2048) for x in pred_mols]
        pred_np_fps = _fingerprints_to_array(pred_fps)

        # random forest regressor predictions
        active_preds_r = clf_r.predict(pred_np_fps)
        active_list_rfr.append(np.sum(active_preds_r >= 6., axis=0) / len(active_preds_r))
        average_rf_scores.append(np.mean(np.array(active_preds_r)))

        # random forest classifier predictions
        active_preds_c = clf_c.predict(pred_np_fps)
        active_list_rfc.append(np.sum(active_preds_c, axis=0) / len(active_preds_c))

        # novelty using tanimoto similarity, from Jin et al. 2020
        novelty_9, novelty_4 = _novelty(pred_fps, true_fps, [0.9, 0.4])
        noveltys_9.append(novelty_9)
        noveltys_4.append(novelty_4)

        # diversity
        diversitys.append(_diversity(pred_fps))

        # histogram of the random-forest activity scores for this run
        plt.hist(active_preds_r, bins=hist_bins)
        plt.xlabel('Activity Score')
        plt.ylabel('Number of Molecules')
        plt.savefig(_histogram_file(ft_type, seed))
        plt.close()

    actives_cpr = np.array(active_list_cpr)
    actives_rfc = np.array(active_list_rfc)
    actives_cpc = np.array(active_list_cpc)
    actives_rfr = np.array(active_list_rfr)
    nuv_actives = np.array(nuv_active_list)
    diversitys = np.array(diversitys)
    noveltys_9 = np.array(noveltys_9)
    noveltys_4 = np.array(noveltys_4)
    average_cp = np.array(average_cp_scores)
    average_rf = np.array(average_rf_scores)

    _report("CP Classification Active", actives_cpc)
    _report("Classification Active", actives_rfc)
    _report("Regression Active", actives_cpr)
    _report("RF Regression Active", actives_rfr)
    _report("NUV Active", nuv_actives)
    _report("Diversity", diversitys)
    _report("Novelty 4", noveltys_4)
    _report("Average RF", average_rf)
    _report("Average CP", average_cp)

    # One LaTeX table row fragment; the column order is fixed.
    latex_columns = (actives_cpc, actives_cpr, average_cp,
                     actives_rfc, actives_rfr, average_rf,
                     diversitys, noveltys_4)
    print("".join(met2str(np.mean(col), np.std(col)) for col in latex_columns))


