import numpy as np
import dgl
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import KFold
from rdkit.Chem.Fingerprints import FingerprintMols
from rdkit import DataStructs
from rdkit import Chem

import pandas as pd
import pickle


def get_target_sets(dataloader):
    """Flatten a batched dataloader into per-sample graphs, labels and SMILES.

    Each item yielded by ``dataloader`` is unpacked as
    ``(batched_graph, label_tensor, smiles)``.  The batched graph is split
    with ``dgl.unbatch`` and every individual graph is re-wrapped with
    ``dgl.batch`` as a batch of size one.

    Returns:
        A tuple ``(graphs, labels, smiles)`` where ``graphs`` is a list of
        single-graph batches, ``labels`` is a NumPy array built from all label
        tensors, and ``smiles`` is a flat list of the SMILES entries.
    """
    single_graph_batches = []
    collected_labels = []
    collected_smiles = []

    for batched_graph, batch_labels, batch_smiles in dataloader:
        collected_labels.extend(batch_labels.detach().tolist())
        collected_smiles.extend(batch_smiles)
        # Re-batch each graph on its own so it can be consumed one at a time.
        single_graph_batches.extend(
            dgl.batch([graph]) for graph in dgl.unbatch(batched_graph)
        )

    return single_graph_batches, np.array(collected_labels), collected_smiles


def get_distribution(dist_values):
    """Fit a Gaussian kernel density estimate with a cross-validated bandwidth.

    The bandwidth is chosen by grid search over 25 log-spaced values in
    ``10 ** [-1.3, 1.5]`` using K-fold cross-validation.  Up to 10 folds are
    used; with fewer than 10 samples the number of folds is reduced to the
    number of samples so that small inputs do not fail inside ``KFold``.

    Args:
        dist_values: Array-like of samples, in the 2-D layout expected by
            ``KernelDensity.fit``.

    Returns:
        A tuple ``(kde, bandwidth)`` with the estimator fitted on all of
        ``dist_values`` and the selected bandwidth.

    Raises:
        ValueError: If fewer than two samples are supplied (cross-validation
            needs at least two folds).
    """
    n_samples = len(dist_values)
    if n_samples < 2:
        raise ValueError(
            "get_distribution needs at least 2 samples to cross-validate the "
            f"bandwidth, got {n_samples}"
        )

    params = {'bandwidth': 10 ** np.linspace(-1.3, 1.5, 25)}
    grid = GridSearchCV(
        KernelDensity(kernel='gaussian'),
        params,
        cv=KFold(n_splits=min(10, n_samples)),
    )
    grid.fit(dist_values)

    best_bandwidth = grid.best_params_['bandwidth']
    print(best_bandwidth)

    # GridSearchCV refits the best configuration on the full data by default
    # (refit=True), so the fitted estimator can be reused instead of fitting
    # a second KernelDensity with the same bandwidth.
    return grid.best_estimator_, best_bandwidth

def compute_weights(
    calibrate_dist_values, test_dist_values, energies=None
):
    """Estimate normalised density-ratio weights for the calibration samples.

    Two kernel density estimates are fitted with ``get_distribution``: one on
    the calibration values and one on the test values.  Each calibration
    sample is weighted by ``p_test(x) / p_calib(x)`` and the weights are
    normalised to sum to one.

    Args:
        calibrate_dist_values: Calibration samples (layout accepted by
            ``get_distribution``).
        test_dist_values: Test samples (same layout).
        energies: Unused; kept so existing callers remain valid.

    Returns:
        A tuple ``(weights, best_bandwidth)`` where ``weights`` has one entry
        per calibration sample and ``best_bandwidth`` is the bandwidth
        selected for the calibration density.
    """
    calib_density, best_bandwidth = get_distribution(calibrate_dist_values)
    test_density, _ = get_distribution(test_dist_values)

    # score_samples returns log-densities.  Form the ratio in log space and
    # shift by the maximum before exponentiating: this is mathematically the
    # same as exp(test) / exp(calib) after normalisation, but avoids both
    # terms underflowing to 0 (giving 0/0 = NaN) for low-density samples.
    log_ratio = (
        test_density.score_samples(calibrate_dist_values)
        - calib_density.score_samples(calibrate_dist_values)
    )
    weights = np.exp(log_ratio - np.max(log_ratio))
    weights = weights / np.sum(weights)

    return weights, best_bandwidth

class conformalPredictor:
    """Per-class conformal predictor for binary classification.

    The predictor is built from precomputed calibration and test data.  For
    each class a score cutoff is derived from the (optionally weighted)
    calibration scores of the samples whose true label is that class; a test
    sample's prediction set contains a class when its score for that class is
    at least the class cutoff.

    ``data`` must be a dict with ``"calib"`` and ``"test"`` entries, each a
    dict providing ``"predictions"``, ``"energies"``, ``"features"``,
    ``"labels"`` and ``"smiles"``.  ``predictions[i][c]`` is read as the score
    of sample ``i`` for class ``c`` (``c`` in ``{0, 1}``).
    """

    # Miscoverage levels calibrated when ``calibrate`` is not given ``alphas``.
    DEFAULT_ALPHAS = (0.4, 0.3, 0.2, 0.1, 0.05)

    # ``feat_type`` values understood by ``weighted_calibrate``.
    _FEAT_TYPES = ("features", "energies", "fp", "logistic")

    def __init__(self, data):
        self.batch_size = 64
        self.calib_pvalues = data["calib"]["predictions"]
        self.test_pvalues = data["test"]["predictions"]
        self.calib_energies = data["calib"]["energies"]
        self.test_energies = data["test"]["energies"]
        self.calib_features = data["calib"]["features"]
        self.test_features = data["test"]["features"]
        self.calib_labels = data["calib"]["labels"]
        self.test_labels = data["test"]["labels"]
        self.calib_smiles = data["calib"]["smiles"]
        self.test_smiles = data["test"]["smiles"]

    @staticmethod
    def _fingerprints(smiles_list):
        """Return an RDKit path fingerprint for every SMILES in ``smiles_list``.

        Raises:
            ValueError: If a SMILES string cannot be parsed by RDKit.
        """
        fingerprints = []
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles signals a parse failure by returning None;
                # fail here with the offending input instead of deep in RDKit.
                raise ValueError(f"could not parse SMILES: {smiles!r}")
            fingerprints.append(
                FingerprintMols.FingerprintMol(
                    mol,
                    minPath=1,
                    maxPath=7,
                    fpSize=2048,
                    bitsPerHash=2,
                    useHs=True,
                    tgtDensity=0.0,
                    minSize=128,
                )
            )
        return fingerprints

    def get_fp_dist_values(self, smiles_data_val, smiles_data_test, k=100):
        """Summarise each molecule by its mean top-``k`` Tanimoto similarity.

        Similarities are always measured against the calibration ("val")
        molecules.  For calibration molecules the diagonal of the similarity
        matrix is left at zero, so a molecule's similarity to itself does not
        contribute a value of one.

        Args:
            smiles_data_val: SMILES strings of the calibration molecules.
            smiles_data_test: SMILES strings of the test molecules.
            k: Number of largest similarities averaged per molecule.

        Returns:
            ``(top_k_sims_val, top_k_sims_test)``: two lists with one
            single-element list ``[mean_similarity]`` per molecule, i.e. an
            ``(n, 1)`` layout.

        Raises:
            ValueError: If ``k`` is smaller than one or a SMILES is invalid.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")

        fps_val = self._fingerprints(smiles_data_val)
        fps_test = self._fingerprints(smiles_data_test)
        nfps_val = len(fps_val)
        nfps_test = len(fps_test)

        # Symmetric calibration-vs-calibration matrix, filled one row/column
        # at a time from the lower triangle.
        sims_val_val = np.zeros((nfps_val, nfps_val))
        for i in range(1, nfps_val):
            sims = DataStructs.BulkTanimotoSimilarity(fps_val[i], fps_val[:i])
            sims_val_val[i, :i] = sims
            sims_val_val[:i, i] = sims

        # One row per test molecule, one column per calibration molecule.
        sims_val_test = np.zeros((nfps_test, nfps_val))
        for i, fingerprint in enumerate(fps_test):
            sims_val_test[i] = DataStructs.BulkTanimotoSimilarity(
                fingerprint, fps_val
            )

        top_k_sims_val = [[np.mean(np.sort(row)[-k:])] for row in sims_val_val]
        top_k_sims_test = [[np.mean(np.sort(row)[-k:])] for row in sims_val_test]

        return top_k_sims_val, top_k_sims_test

    def weighted_logistic_calibrate(self, calib_features, test_features):
        """Estimate calibration weights with a calibration-vs-test classifier.

        A logistic regression is trained to separate calibration samples
        (label 0) from test samples (label 1).  The returned weight of each
        calibration sample is the odds ``p / (1 - p)`` of it being classified
        as a test sample.

        Returns:
            1-D array with one (unnormalised) weight per calibration sample.
        """
        # Import the submodule explicitly: a bare ``import sklearn`` does not
        # guarantee that ``sklearn.linear_model`` has been loaded.
        from sklearn.linear_model import LogisticRegression

        model = LogisticRegression(penalty="l2", C=1.0, solver="liblinear")
        domain_labels = np.concatenate(
            (np.zeros(len(calib_features)), np.ones(len(test_features)))
        )
        features = np.concatenate((calib_features, test_features))
        model.fit(features, domain_labels)
        calib_probs = model.predict_proba(calib_features)[:, 1]
        return calib_probs / (1 - calib_probs)

    def _energy_weights(self, div_factor, test_indices):
        """Compute density-ratio weights on ``exp(energy / div_factor)``."""
        calib_dist_values = np.exp(self.calib_energies / div_factor)
        test_dist_values = np.exp(self.test_energies / div_factor)
        if test_indices is not None:
            test_dist_values = test_dist_values[test_indices]
        return compute_weights(calib_dist_values, test_dist_values)

    def weighted_calibrate(self, feat_type="energies", calibration_type="mondrian", test_indices=None):
        """Calibrate with weights that account for calibration/test shift.

        Args:
            feat_type: Which representation the weights are estimated on:
                ``"features"`` (stored feature vectors), ``"energies"``
                (exponentiated, rescaled energies), ``"fp"`` (mean top-k
                fingerprint similarity) or ``"logistic"`` (odds from a
                calibration-vs-test logistic regression).
            calibration_type: Forwarded unchanged to ``calibrate``.
            test_indices: Optional index selection restricting which test
                samples define the test distribution.  Not used for
                ``"logistic"``.

        Raises:
            ValueError: If ``feat_type`` is not one of the supported values.
        """
        if feat_type not in self._FEAT_TYPES:
            raise ValueError(
                f"unknown feat_type {feat_type!r}; expected one of {self._FEAT_TYPES}"
            )

        if feat_type == "logistic":
            weights = self.weighted_logistic_calibrate(
                self.calib_features, self.test_features
            )

        elif feat_type == "energies":
            div_factor = 10
            weights, best_bandwidth = self._energy_weights(div_factor, test_indices)

            # Retry with a different energy scaling when the selected
            # bandwidth equals 0.1 or 10.0.
            # NOTE(review): with the grid used by get_distribution
            # (10 ** linspace(-1.3, 1.5, 25)) neither 0.1 nor 10.0 is a grid
            # point, so these branches look unreachable - confirm the intent
            # (possibly "bandwidth hit the edge of the grid").
            if best_bandwidth == 0.1:
                div_factor -= 5
                weights, best_bandwidth = self._energy_weights(div_factor, test_indices)
            elif best_bandwidth == 10.0:
                div_factor += 5
                weights, best_bandwidth = self._energy_weights(div_factor, test_indices)

            print(div_factor)

        else:
            if feat_type == "features":
                calib_dist_values = self.calib_features
                test_dist_values = self.test_features
            else:  # "fp"
                calib_dist_values, test_dist_values = self.get_fp_dist_values(
                    self.calib_smiles,
                    self.test_smiles
                )
            if test_indices is not None:
                # asarray lets the list-of-lists returned for "fp" be indexed
                # with an index array as well.
                test_dist_values = np.asarray(test_dist_values)[test_indices]
            # Weights are computed once, after the optional test selection.
            weights, _ = compute_weights(calib_dist_values, test_dist_values)

        self.calibrate(weights=weights, calibration_type=calibration_type)

    def calibrate(
        self, weights=None, calibration_type="mondrian", alphas=None
    ):
        """Compute per-class score cutoffs for each miscoverage level.

        For each class ``c`` the calibration samples with true label ``c`` are
        sorted by their score for ``c``.  The cutoff for level ``alpha`` is
        the sorted score one position before the first sample at which the
        cumulative normalised weight reaches ``alpha`` (clamped to the
        smallest score).  Results are stored in ``self.conformal_cutoffs`` as
        ``{alpha: array([cutoff_class_0, cutoff_class_1])}``.

        Args:
            weights: Optional per-calibration-sample weights; uniform when
                omitted.
            calibration_type: Accepted for interface compatibility; the
                visible implementation always calibrates per class.
            alphas: Iterable of miscoverage levels; defaults to
                ``DEFAULT_ALPHAS``.

        Raises:
            ValueError: If a class has no calibration samples.
        """
        labels = np.array(self.calib_labels)
        pvalues = self.calib_pvalues

        if weights is None:
            weights = np.ones(len(labels))
        if alphas is None:
            alphas = self.DEFAULT_ALPHAS

        sorted_scores = []
        percentiles = []
        for cls in (0, 1):
            members = np.where(labels == cls)[0]
            if len(members) == 0:
                raise ValueError(
                    f"cannot calibrate: no calibration samples with label {cls}"
                )
            scores = np.array([pvalues[i][cls] for i in members])
            class_weights = np.array([weights[i] for i in members], dtype=float)

            order = np.argsort(scores)
            sorted_scores.append(scores[order])
            ordered_weights = class_weights[order]
            percentiles.append(
                ordered_weights.cumsum() / ordered_weights.sum() * 100
            )

        cutoffs_by_alpha = {}
        for alpha in alphas:
            # A fresh array per alpha so the stored entries do not alias.
            cutoffs = np.zeros(2)
            for cls in (0, 1):
                reached = np.flatnonzero(percentiles[cls] >= alpha * 100)
                first = reached[0] if reached.size else len(percentiles[cls])
                # Clamp at 0: without it ``first - 1`` would become -1 and
                # silently select the largest score.
                cutoffs[cls] = sorted_scores[cls][max(first - 1, 0)]
            cutoffs_by_alpha[alpha] = cutoffs

        self.conformal_cutoffs = cutoffs_by_alpha

    def predict_conformals(self, alpha):
        """Return the prediction sets of all test samples for level ``alpha``.

        Returns:
            Array of shape ``(n_test, 2)`` whose entry ``[i, c]`` is 1.0 when
            class ``c`` is in the prediction set of sample ``i`` and 0.0
            otherwise.

        Raises:
            KeyError: If ``alpha`` was not calibrated.
        """
        pvalues = self.test_pvalues
        conformal_cutoffs = self.conformal_cutoffs[alpha]
        conformal_predictions = np.zeros((len(pvalues), 2))
        for i, scores in enumerate(pvalues):
            conformal_predictions[i, 0] = scores[0] >= conformal_cutoffs[0]
            conformal_predictions[i, 1] = scores[1] >= conformal_cutoffs[1]

        return conformal_predictions

    def get_coverage(self, conformal_predictions, labels):
        """Return the fraction of samples whose true label is in their set.

        Returns:
            ``(total_coverage, class_0_coverage, class_1_coverage)``.  A class
            without samples reports a coverage of 0.0 and does not contribute
            to the total.
        """
        predictions = np.asarray(conformal_predictions)
        label_array = np.array(labels)

        class_coverage = []
        class_share = []
        for cls in (0, 1):
            members = np.where(label_array == cls)[0]
            if len(members):
                covered = np.count_nonzero(predictions[members, cls] == 1)
                class_coverage.append(covered / len(members))
            else:
                # Avoid a ZeroDivisionError when the class is absent.
                class_coverage.append(0.0)
            class_share.append(len(members) / len(labels))

        total_coverage = (
            class_share[0] * class_coverage[0]
            + class_share[1] * class_coverage[1]
        )

        return total_coverage, class_coverage[0], class_coverage[1]

    def get_mean_width(self, conformal_predictions, labels):
        """Return the mean prediction-set size overall and per true class.

        Returns:
            ``(total_width, class_0_width, class_1_width)``.
        """
        set_sizes = np.sum(conformal_predictions, axis=1)
        label_array = np.array(labels)
        label_0_indices = np.where(label_array == 0)[0]
        label_1_indices = np.where(label_array == 1)[0]

        total_width = np.mean(set_sizes)
        class_0_width = np.mean(set_sizes[label_0_indices])
        class_1_width = np.mean(set_sizes[label_1_indices])

        return total_width, class_0_width, class_1_width

    def get_efficiency(self, conformal_predictions, labels):
        """Return efficiency statistics of the prediction sets.

        Returns:
            ``(total_efficiency, class_0_efficiency, class_1_efficiency)``:
            the total is the fraction of samples whose set contains exactly
            the true label; the per-class value is the number of samples of
            that class whose set contains the true label divided by the number
            of sets that contain the class (0 when no set contains it).
        """
        singleton_correct = 0
        correct = [0, 0]
        marked = [0, 0]
        for idx in range(len(labels)):
            label = int(labels[idx])
            prediction = conformal_predictions[idx]
            if prediction[label] == True:
                if sum(prediction) == 1:
                    singleton_correct += 1
                # Mirrors the original split: label 0 vs. everything else.
                correct[0 if label == 0 else 1] += 1

            if prediction[0] == True:
                marked[0] += 1
            if prediction[1] == True:
                marked[1] += 1

        total_efficiency = singleton_correct / len(labels)
        # Explicit zero checks instead of a bare ``except`` around division.
        class_0_efficiency = correct[0] / marked[0] if marked[0] else 0
        class_1_efficiency = correct[1] / marked[1] if marked[1] else 0

        return total_efficiency, class_0_efficiency, class_1_efficiency

    def get_metrics(self, alpha, test_indices=None):
        """Collect coverage, efficiency and width metrics for level ``alpha``.

        Args:
            alpha: A miscoverage level present in ``self.conformal_cutoffs``.
            test_indices: Optional index selection restricting the test
                samples the metrics are computed on.

        Returns:
            Dict with keys ``coverage``, ``class_{0,1}_coverage``,
            ``error_rate``, ``efficiency``, ``class_{0,1}_efficiency``,
            ``mean_width``, ``class_{0,1}_width``, ``class_cutoffs_{0,1}`` and
            ``class_{0,1}_ratio``.
        """
        conformal_predictions = self.predict_conformals(alpha)
        metrics = {}
        if test_indices is None:
            labels = self.test_labels
        else:
            labels = self.test_labels[test_indices]
            conformal_predictions = conformal_predictions[test_indices]

        num_class_0 = len(np.where(np.array(labels) == 0)[0])
        num_class_1 = len(np.where(np.array(labels) == 1)[0])

        (
            metrics["coverage"],
            metrics["class_0_coverage"],
            metrics["class_1_coverage"],
        ) = self.get_coverage(conformal_predictions, labels)
        metrics["error_rate"] = 1 - metrics["coverage"]
        (
            metrics["efficiency"],
            metrics["class_0_efficiency"],
            metrics["class_1_efficiency"],
        ) = self.get_efficiency(conformal_predictions, labels)
        (
            metrics["mean_width"],
            metrics["class_0_width"],
            metrics["class_1_width"],
        ) = self.get_mean_width(conformal_predictions, labels)
        metrics["class_cutoffs_0"] = self.conformal_cutoffs[alpha][0]
        metrics["class_cutoffs_1"] = self.conformal_cutoffs[alpha][1]
        metrics["class_0_ratio"] = num_class_0 / len(labels)
        metrics["class_1_ratio"] = num_class_1 / len(labels)
        return metrics
 

def retrieve_model(model):
    """Extract the prediction model and the calibration/test sample sets.

    Reads ``model.model`` and the ``"valid_id"`` (calibration) and
    ``"test_id"`` (test) entries of ``model.dataloaders``, and flattens each
    dataloader with ``get_target_sets``.

    Returns:
        ``(prediction_model, calib, test)`` where ``calib`` and ``test`` are
        dicts with the keys ``"data"``, ``"labels"`` and ``"smiles"``.
    """
    prediction_model = model.model
    calibration_dataloader = model.dataloaders["valid_id"]
    test_dataloader = model.dataloaders["test_id"]

    def _collect(loader):
        # Name the three outputs of get_target_sets.
        graphs, labels, smiles = get_target_sets(loader)
        return {"data": graphs, "labels": labels, "smiles": smiles}

    # The test split is materialised before the calibration split.
    test_split = _collect(test_dataloader)
    calib_split = _collect(calibration_dataloader)
    return prediction_model, calib_split, test_split

# Example usage.  Guarded so that importing this module does not execute the
# pipeline (the unguarded version failed at import time because the
# placeholder ``model`` below is a plain string without ``.model``).
if __name__ == "__main__":
    model = ""  # load trained model here
    prediction_model, calibration_data, test_data = retrieve_model(model)
    data = {}
    data["test"] = test_data
    data["calib"] = calibration_data
    # NOTE(review): conformalPredictor.__init__ also reads "predictions",
    # "energies" and "features" from each split, which retrieve_model does not
    # provide - they presumably have to be computed with ``prediction_model``
    # and added to ``data`` before this point; confirm.
    cp = conformalPredictor(data)
    cp.calibrate()
    print(cp.get_metrics(0.1)) # get metrics for alpha = 0.1
