import numpy as np
import os
import pandas as pd
import random

# Seed the two global RNGs this script relies on: Python's `random`
# (used by `random.sample` when drawing reference samples) and NumPy's
# legacy global generator (used by `np.random.rand` when flipping ranker labels).
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)

from tdc.single_pred import ADME

from rdkit import Chem
from rdkit.Chem import rdMolDescriptors

from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error
from scipy.optimize import minimize
from scipy.special import expit  # sigmoid

# Random seeds; each one yields a different train/test re-split.
# Add or remove random seeds as needed.
random_seed_list = [0, 1, 2, 3, 4]

# The 9 regression datasets from TDC ADME. Add or remove datasets as needed.
dataset_name_list = [
    'PPBR_AZ',
    'Clearance_Hepatocyte_AZ',
    'Clearance_Microsome_AZ',
    'Caco2_Wang',
    'VDss_Lombardo',
    'Half_Life_Obach',
    'HydrationFreeEnergy_FreeSolv',
    'Lipophilicity_AstraZeneca',
    'Solubility_AqSolDB',
]

# Simulated ranker accuracies for the oracle sweep, in steps of 0.025
# (0.5: every comparison is flipped with probability 1/2, i.e. random guessing;
# 1.0: the ranker is always right). Add or remove values as needed.
ranker_accuracy_list = [
    0.5, 0.525, 0.55, 0.575, 0.6, 0.625, 0.65,
    0.675, 0.7, 0.725, 0.75, 0.775, 0.80, 0.825,
    0.85, 0.875, 0.90, 0.925, 0.95, 0.975, 1.0,
]

# Number of reference (training) samples each test sample is compared against
# in the oracle sweep. Add or remove values as needed.
# NOTE: values must not exceed the size of the re-sampled training split (50 rows).
k_list = [3, 5, 7, 10, 20, 30]

### A single evaluation for a single dataset would look like this:
# random_seed_list = [0]
# dataset_name_list = ['PPBR_AZ']
# ranker_accuracy_list = [0.5] # ranker accuracy = 50%
# k_list = [3] # use 3 reference samples for each test sample

def estimate_score_from_threshold_comparisons(y_is_greater_than, y_is_lower_than):
    """Estimate a latent score from "greater than" / "lower than" threshold comparisons.

    The comparison model is logistic with unit scale:
    P(score > y) = sigmoid(score - y). The score is found by maximum likelihood
    with a trust-region Newton method. Its variance is the inverse of the
    negative-log-likelihood Hessian at the optimum (a Laplace approximation).

    Args:
        y_is_greater_than: Thresholds y the score was judged to exceed
            (any 1-D sequence or array, possibly empty).
        y_is_lower_than: Thresholds y the score was judged to be below
            (any 1-D sequence or array, possibly empty).

    Returns:
        (s_hat, var): the maximum-likelihood score, and its variance. The
        variance is ``np.inf`` when the Hessian is not positive, e.g. when
        there are no comparisons at all.

    NOTE(review): the sigmoid has a fixed unit temperature, so the sharpness
    of the likelihood depends on the units of the targets. Confirm this is
    intended for wide-range targets.
    """
    # Convert to float arrays so lists, tuples and NumPy arrays all work. The
    # original list `+` would add NumPy arrays element-wise instead of
    # concatenating them.
    above = np.asarray(y_is_greater_than, dtype=float).ravel()
    below = np.asarray(y_is_lower_than, dtype=float).ravel()
    all_thresholds = np.concatenate([above, below])

    def neg_log_likelihood(s):
        s = s[0]
        # -log(sigmoid(x)) == logaddexp(0, -x) and -log(1 - sigmoid(x)) == logaddexp(0, x).
        # This form stays finite when the sigmoid saturates to exactly 0 or 1.
        return np.logaddexp(0.0, above - s).sum() + np.logaddexp(0.0, s - below).sum()

    def grad(s):
        s = s[0]
        # 1 - sigmoid(s - y) is written as sigmoid(y - s) for accuracy.
        g = -expit(above - s).sum() + expit(s - below).sum()
        return np.array([g])

    def hess(s):
        p = expit(s[0] - all_thresholds)
        return np.array([[np.sum(p * (1.0 - p))]])

    # Start at the mean of all thresholds (0.0 when there are none).
    s_init = np.array([all_thresholds.mean() if all_thresholds.size else 0.0])
    result = minimize(neg_log_likelihood, s_init, jac=grad, hess=hess, method='trust-ncg')
    s_hat = result.x[0]

    # Inverse Hessian = variance
    curvature = hess([s_hat])[0, 0]
    var = 1.0 / curvature if curvature > 0 else np.inf

    return s_hat, var

def rankrefine_bayes(y_reg, y_reg_sigma, y_is_higher_than, y_is_lower_than, c_scale = 5):
    """Fuse a regressor prediction with a ranking-derived estimate by inverse-variance weighting.

    The ranking-derived estimate and its variance come from
    ``estimate_score_from_threshold_comparisons``.

    Args:
        y_reg: Regressor point prediction (scalar).
        y_reg_sigma: Regressor uncertainty as a standard deviation (scalar); it is squared here.
        y_is_higher_than: Reference values y_i for which the ranker predicts y_rank > y_i.
        y_is_lower_than: Reference values y_i for which the ranker predicts y_rank < y_i.
        c_scale: Floor on the ranker variance, as a multiple of the regressor
            variance. It stops the ranking estimate from being over-confident.

    Returns:
        (y_refined, y_rank_estimate, y_rank_sigma2): the fused prediction, the
        ranking-only estimate, and the (floored) ranking variance.
    """
    y_reg_sigma2 = y_reg_sigma ** 2
    y_rank_estimate, y_rank_sigma2 = estimate_score_from_threshold_comparisons(y_is_higher_than, y_is_lower_than)
    y_rank_sigma2 = max(y_rank_sigma2, c_scale * y_reg_sigma2)  # Avoid overconfidence

    if y_reg_sigma2 == 0:
        # A zero-variance regressor takes all of the weight. The general formula
        # below would evaluate 0 * inf = NaN here, or raise for a Python float.
        return y_reg, y_rank_estimate, y_rank_sigma2

    sigma_post2 = 1 / (1 / y_reg_sigma2 + 1 / y_rank_sigma2)
    y_refined = sigma_post2 * (y_reg / y_reg_sigma2 + y_rank_estimate / y_rank_sigma2)

    return y_refined, y_rank_estimate, y_rank_sigma2

def projection_approach(y_reg, y_is_higher_than, y_is_lower_than):
    """Project ``y_reg`` onto the interval implied by the ranker's comparisons.

    The feasible interval is [max(y_is_higher_than), min(y_is_lower_than)].
    A missing side is unbounded. The Euclidean projection onto an interval is
    just clipping. If there are no comparisons at all, or the bounds
    contradict each other (lower > upper), ``y_reg`` is returned unchanged.
    """
    has_lower = len(y_is_higher_than) > 0
    has_upper = len(y_is_lower_than) > 0
    if not (has_lower or has_upper):
        return y_reg

    lo = np.max(y_is_higher_than).item() if has_lower else -np.inf
    hi = np.min(y_is_lower_than).item() if has_upper else np.inf
    if lo > hi:
        # Inconsistent constraints: there is no feasible point to project onto.
        return y_reg

    return np.clip(y_reg, lo, hi)

def resplit_dfs(train_df, val_df, test_df, n_samples = 50, rand_seed = 0):
    """Shrink the training split to ``n_samples`` random rows.

    The training rows that are not kept are appended after ``test_df``.

    Returns:
        dict with keys 'train' (the sampled rows, in sampling order), 'test'
        (``test_df`` followed by the leftover training rows) and 'valid'
        (``val_df``, passed through unchanged).
    """
    chosen_labels = train_df.sample(n_samples, random_state=rand_seed).index
    leftover_train = train_df.drop(index=chosen_labels)

    resplit = {}
    resplit['train'] = train_df.loc[chosen_labels]
    resplit['test'] = pd.concat([test_df, leftover_train], axis=0)
    resplit['valid'] = val_df
    return resplit

def featurize_df(df_dict, dataset_name, radius=4, n_bits=2048):
    """Convert each split's SMILES strings into Morgan fingerprint bit vectors with RDKit.

    Args:
        df_dict: Mapping with 'train', 'test' and 'valid' tables. Each has a
            'Drug' column of SMILES strings and a 'Y' target column.
        dataset_name: TDC dataset name. Targets of 'Half_Life_Obach' are
            log10-transformed.
        radius: Morgan fingerprint radius (default 4, the previously hard-coded value).
        n_bits: Fingerprint length in bits (default 2048, the previously hard-coded value).

    Returns:
        {'train' | 'test' | 'valid': {'fingerprints': ndarray, 'y': targets}}.
        ``fingerprints`` holds one row of ``n_bits`` bits per molecule.

    Raises:
        ValueError: if RDKit cannot parse one of the SMILES strings.
    """
    processed_dict = {}
    for split in ['train', 'test', 'valid']:
        smiles = df_dict[split]['Drug']
        y = df_dict[split]['Y']
        if dataset_name in ['Half_Life_Obach']:
            # NOTE(review): presumably because half-lives span orders of magnitude — confirm.
            y = np.log10(y)

        fingerprints = []
        for smi in smiles:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                # MolFromSmiles returns None on parse failure. Passing None on
                # to the fingerprint function fails with an opaque argument error.
                raise ValueError(f"RDKit could not parse SMILES {smi!r} in the {split!r} split of {dataset_name}")
            fingerprints.append(rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits))

        processed_dict[split] = {'fingerprints': np.array(fingerprints), 'y': y}
    return processed_dict

def fit_and_pred_RF(processed_dict, regressor):
    """Fit ``regressor`` on the training fingerprints and predict the test split.

    The per-sample uncertainty is the standard deviation (ddof=0) of the
    individual ensemble members' predictions. These are read from
    ``regressor.estimators_``, so ``regressor`` must be an ensemble that
    exposes its members that way, e.g. scikit-learn's RandomForestRegressor.

    Returns:
        (y_train, y_test, y_hat_test, y_hat_test_stdev)
    """
    train_split = processed_dict['train']
    y_train = train_split['y']
    regressor.fit(train_split['fingerprints'], y_train)

    test_split = processed_dict['test']
    X_test, y_test = test_split['fingerprints'], test_split['y']
    y_hat_test = regressor.predict(X_test)

    # Rows: ensemble members; columns: test samples.
    member_predictions = np.array([member.predict(X_test) for member in regressor.estimators_])
    y_hat_test_stdev = member_predictions.std(axis=0)

    return y_train, y_test, y_hat_test, y_hat_test_stdev

def rankrefine_bayes_with_oracle_ranker(y_train, n_samples, prediction_values, prediction_uncertainties, prediction_labels, ranker_accuracy, c_scale=3.0):
    """Refine regressor predictions using a simulated ("oracle") pairwise ranker.

    Steps:
    1. Draw ``n_samples`` reference labels from ``y_train`` without replacement,
       using Python's global ``random``.
    2. Compare every test sample's true label against each reference. Ties count
       as "lower than".
    3. Keep each comparison with probability ``ranker_accuracy`` and flip it
       otherwise, using NumPy's global RNG.
    4. Fuse the noisy comparisons with the regressor output via ``rankrefine_bayes``.

    Args:
        y_train: Training labels (pandas Series or 1-D array-like); the reference pool.
        n_samples: Number of reference labels to draw (k).
        prediction_values: Regressor predictions for the test samples.
        prediction_uncertainties: Regressor standard deviations for the test samples.
        prediction_labels: True test labels (Series or 1-D array-like), used only to simulate the ranker.
        ranker_accuracy: Probability that a simulated comparison is correct.
        c_scale: Forwarded to ``rankrefine_bayes``. The default 3.0 matches the
            previously hard-coded value.

    Returns:
        (y_refined, y_rank_estimates): an ndarray of refined predictions and a
        list of ranking-only estimates, both in test-sample order.
    """
    sampled_train_indices = random.sample(range(len(y_train)), n_samples)
    # Positional selection; equivalent to y_train.iloc[...] for a Series.
    reference_labels = np.asarray(y_train)[sampled_train_indices]
    true_labels = np.asarray(prediction_labels)

    # is_higher[i, j]: does test sample i's true label exceed reference j?
    # Shape: (n_test, n_samples).
    is_higher = true_labels[:, None] > reference_labels

    # Randomly flip the labels based on ranker accuracy
    keep = np.random.rand(*is_higher.shape) < ranker_accuracy
    is_higher = np.where(keep, is_higher, ~is_higher)

    y_refined_list = []
    y_rank_estimate_list = []
    for test_id, row in enumerate(is_higher):
        # Boolean masks replace the per-element pandas .iloc lookups (same values, same order).
        y_refined, y_rank_estimate, _ = rankrefine_bayes(
            y_reg=prediction_values[test_id],
            y_reg_sigma=prediction_uncertainties[test_id],
            y_is_higher_than=reference_labels[row].tolist(),
            y_is_lower_than=reference_labels[~row].tolist(),
            c_scale=c_scale,
        )
        y_refined_list.append(y_refined)
        y_rank_estimate_list.append(y_rank_estimate)

    return np.array(y_refined_list), y_rank_estimate_list


def _normalised_beta_curves(test_mae_dict, test_non_refined_mae_list):
    """Average refined test MAEs over seeds and divide by the mean non-refined MAE.

    The resulting ratio ("beta") is below 1 when refinement reduced the error.

    Returns:
        (mean_beta, std_beta, baseline_std_beta). The first two map each k in
        ``k_list`` to an array over ``ranker_accuracy_list``. The last is the
        seed-to-seed std of the non-refined MAE on the same normalised scale.
    """
    baseline_mae = np.mean(test_non_refined_mae_list)
    mean_beta, std_beta = {}, {}
    for n_samples in k_list:
        # Rows: random seeds; columns: ranker accuracies.
        per_seed = np.array([test_mae_dict[seed][n_samples] for seed in random_seed_list])
        mean_beta[n_samples] = per_seed.mean(axis=0) / baseline_mae
        std_beta[n_samples] = per_seed.std(axis=0) / baseline_mae
    baseline_std_beta = np.std(test_non_refined_mae_list) / baseline_mae
    return mean_beta, std_beta, baseline_std_beta


def _plot_beta_curves(mean_beta, std_beta, baseline_std_beta, out_path):
    """Plot beta vs. ranker accuracy for every k, with ±1 std bands, and save the figure to ``out_path``."""
    import matplotlib.pyplot as plt

    acc_min, acc_max = min(ranker_accuracy_list), max(ranker_accuracy_list)
    fig, ax = plt.subplots(figsize=(10, 6))
    for n_samples in k_list:
        mean, std = mean_beta[n_samples], std_beta[n_samples]
        ax.plot(ranker_accuracy_list, mean, label=f"k = {n_samples}", marker='o', markersize=4)
        ax.fill_between(ranker_accuracy_list, mean - std, mean + std, alpha=0.1)

    # After normalisation the non-refined baseline is 1 by construction. Its
    # band spans the swept accuracy range, which was previously hard-coded as 0.5-1.0.
    ax.axhline(y=1.0, color='r', linestyle='--', label="Non-refined MAE")
    ax.fill_betweenx([1.0 - baseline_std_beta, 1.0 + baseline_std_beta], acc_min, acc_max, color='r', alpha=0.1)

    ax.set_xlabel("Ranker Accuracy")
    ax.set_xticks(ranker_accuracy_list)
    ax.set_ylabel(r"$\beta$")
    ax.legend()
    ax.grid()
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


# Oracle sweep for each dataset: re-split, fit an RF, refine its test predictions
# with simulated rankers of varying accuracy and k, then plot the normalised MAE.
for dataset_name in dataset_name_list:
    data = ADME(name = dataset_name)
    splits = data.get_split()

    train_df = splits['train']
    test_df = splits['test']
    val_df = splits['valid']

    save_dir = os.path.join(".", "results_rankrefine", dataset_name)
    os.makedirs(save_dir, exist_ok=True)

    # test_mae_dict[seed][k] holds one refined test MAE per entry of ranker_accuracy_list, in order.
    test_mae_dict = {s: {n: [] for n in k_list} for s in random_seed_list}
    test_non_refined_mae_list = []

    for rand_seed in random_seed_list:
        # Re-seed the global RNGs (reference sampling and ranker label flipping)
        # so that each (dataset, seed) run is reproducible on its own. Otherwise
        # results would depend on which datasets/seeds ran before it.
        random.seed(rand_seed)
        np.random.seed(rand_seed)

        # Sample 50 rows from the training set, concat the rest to test set
        df_dict = resplit_dfs(train_df, val_df, test_df, n_samples=50, rand_seed=rand_seed)
        processed_dict = featurize_df(df_dict, dataset_name)

        model = RandomForestRegressor(n_estimators=100, random_state=rand_seed, verbose=1)
        y_train, y_test, y_hat_test, y_hat_test_stdev = fit_and_pred_RF(processed_dict, model)

        test_non_refined_mae_list.append(mean_absolute_error(y_test, y_hat_test))

        # Adjust predictions based on ranker labels, ranker is x% accurate
        for _ra in ranker_accuracy_list:
            print(f'Processing Dataset: {dataset_name}, Random Seed: {rand_seed}, Ranker Accuracy: {_ra}...')
            for n_samples in k_list:
                y_refined_test, _ = rankrefine_bayes_with_oracle_ranker(y_train, n_samples, y_hat_test, y_hat_test_stdev, y_test, _ra)
                test_mae_dict[rand_seed][n_samples].append(mean_absolute_error(y_test, y_refined_test))

    mean_beta, std_beta, baseline_std_beta = _normalised_beta_curves(test_mae_dict, test_non_refined_mae_list)
    _plot_beta_curves(mean_beta, std_beta, baseline_std_beta, os.path.join(save_dir, "result_test.png"))