
import numpy as np
import torch
from torch import Tensor
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as ECFPbitVec
from rdkit.DataStructs import BulkTanimotoSimilarity
from rdkit import Chem
import math



def logits_to_pred(logits_N_K_C: Tensor, return_prob: bool = True, return_uncertainty: bool = True):
    """ Get the probabilities/class vector and sample uncertainty from the logits

    The input is treated as log-probabilities (it is exponentiated to obtain
    probabilities); by naming convention it is indexed (N samples, K ensemble
    members, C classes) — TODO confirm with callers.

    Args:
        logits_N_K_C: per-member log-probabilities.
        return_prob: if True, return the ensemble-mean class probabilities
            (mean over axis 1); otherwise return the argmax class index of
            those mean probabilities.
        return_uncertainty: if True, additionally return
            ``mean_sample_entropy(logits_N_K_C)``.

    Returns:
        ``y_hat`` alone, or ``(y_hat, uncertainty)`` when ``return_uncertainty``.
    """
    mean_probs_N_C = torch.mean(torch.exp(logits_N_K_C), dim=1)
    y_hat = mean_probs_N_C if return_prob else torch.argmax(mean_probs_N_C, dim=1)

    if not return_uncertainty:
        return y_hat

    # Only computed when the caller asks for it; previously it was always
    # computed and then discarded.
    return y_hat, mean_sample_entropy(logits_N_K_C)


def logit_mean(logits_N_K_C: Tensor, dim: int, keepdim: bool = False) -> Tensor:
    """Average probabilities in log space along ``dim`` - Kirsch et al., 2019, NeurIPS.

    Computes ``log(mean(exp(x)))`` as ``logsumexp(x) - log(size)`` so the
    exponentials never have to be materialised (numerically stable).
    """
    member_count = logits_N_K_C.shape[dim]
    log_total = torch.logsumexp(logits_N_K_C, dim=dim, keepdim=keepdim)
    return log_total - math.log(member_count)


def entropy(logits_N_K_C: Tensor, dim: int, keepdim: bool = False) -> Tensor:
    """Calculates the Shannon Entropy ``-sum(p * log p)`` along ``dim``.

    The input is assumed to be log-probabilities (``p = exp(logits)``). The
    products are cast to float64 before summation, so the result is float64.

    Entries whose log-probability is ``-inf`` (probability exactly 0)
    contribute 0, following the convention ``0 * log 0 = 0``; evaluated
    literally, ``exp(-inf) * -inf`` would be NaN and poison the whole sum.
    """
    p_log_p = torch.exp(logits_N_K_C) * logits_N_K_C
    p_log_p = p_log_p.masked_fill(torch.isneginf(logits_N_K_C), 0.0)

    return -torch.sum(p_log_p.double(), dim=dim, keepdim=keepdim)


def mean_sample_entropy(logits_N_K_C: Tensor, dim: int = -1, keepdim: bool = False) -> Tensor:
    """Expected (member-averaged) entropy per sample - Kirsch et al., 2019, NeurIPS.

    First takes the entropy of every member's distribution along ``dim`` (the
    last axis by default), then averages those entropies over axis 1 — by
    naming convention the K ensemble-member axis (TODO confirm with callers).
    """
    per_member_entropy = entropy(logits_N_K_C, dim=dim, keepdim=keepdim)
    return per_member_entropy.mean(dim=1)


def get_mutual_information(logits_N_K_C: Tensor) -> Tensor:
    """ Calculates the Mutual Information - Kirsch et al., 2019, NeurIPS

    I = H[mean_k p_k] - mean_k H[p_k]

    i.e. the entropy of the ensemble-averaged prediction minus the average
    entropy of the individual member predictions. It is large when the
    members disagree with each other. The ensemble average is taken over
    axis 1 and entropies over the last axis.
    """
    # Expected entropy: each member's predictive entropy, averaged over the
    # members (high when the members are individually uncertain).
    expected_entropy_N = mean_sample_entropy(logits_N_K_C)

    # Entropy of the mean prediction: average the members' probabilities in
    # log space first, then take the entropy (total predictive uncertainty).
    entropy_of_mean_N = entropy(logit_mean(logits_N_K_C, dim=1), dim=-1)

    return entropy_of_mean_N - expected_entropy_N


















def uncertainty_std_low_pos(logits_N_K_C, smiles_screen, n, smiles_hit):
    """Pick the n molecules whose class-1 logit varies least across axis 1.

    The spread is ``torch.std`` of ``logits_N_K_C[:, :, 1]`` over axis 1 (the
    ensemble axis by naming convention). ``smiles_hit`` is accepted for the
    shared acquisition interface and is not used.
    """
    positive_logits = logits_N_K_C[:, :, 1]
    spread = positive_logits.std(dim=1)
    least_spread_first = spread.argsort(descending=False)
    return smiles_screen[least_spread_first[:n].cpu()]


def uncertainty_low_pos(logits_N_K_C, smiles_screen, n, smiles_hit):
    """Pick the n molecules with the smallest class-1 entropy term.

    For class 1 only, computes ``-p * log p`` (with ``p = exp(logit)``) in
    float64, averages it over axis 1, and returns the n lowest-scoring
    molecules. ``smiles_hit`` is accepted for the shared acquisition
    interface and is not used.
    """
    log_p_pos = logits_N_K_C[:, :, 1]
    neg_p_log_p = (log_p_pos.exp() * log_p_pos).double().neg()
    mean_term_N = neg_p_log_p.mean(dim=1)
    lowest_first = mean_term_N.argsort(descending=False)
    return smiles_screen[lowest_first[:n].cpu()]


def upper_confidence_bound(logits_N_K_C, smiles_screen, n, smiles_hit, ratio=0.2):
    """Rank molecules by ``mean + ratio * std`` of the class-1 logit over axis 1.

    Returns the n highest-scoring molecules. ``smiles_hit`` is accepted for
    the shared acquisition interface and is not used.
    """
    positive_logits = logits_N_K_C[:, :, 1]
    score_N = positive_logits.mean(dim=1) + ratio * positive_logits.std(dim=1)
    highest_first = torch.argsort(score_N, descending=True)
    return smiles_screen[highest_first[:n].cpu()]


def thompson_sampling(logits_N_K_C, smiles_screen, n, smiles_hit, seed=0):
    """Thompson sampling over ensemble members.

    For every molecule, one index along axis 1 (the ensemble member, by naming
    convention) is drawn uniformly at random with a seeded NumPy generator.
    That member's class-1 logit becomes the molecule's score, and the n
    highest-scoring molecules are returned. ``smiles_hit`` is accepted for the
    shared acquisition interface and is not used.

    The member index range is taken from ``logits_N_K_C.shape[1]``. It used to
    be hard-coded to 10, which raised IndexError for smaller ensembles and
    never sampled members >= 10 for larger ones. For K == 10 the draws are
    identical to before.
    """
    n_samples, n_members = logits_N_K_C.shape[0], logits_N_K_C.shape[1]
    device = logits_N_K_C.device

    rng = np.random.default_rng(seed)
    member_idx = torch.as_tensor(rng.integers(0, n_members, n_samples), device=device)
    row_idx = torch.arange(n_samples, device=device)

    # Advanced indexing picks logits[i, member_idx[i], 1] for every row at once.
    sampled_logits_N = logits_N_K_C[row_idx, member_idx, 1]

    picks_idx = torch.argsort(sampled_logits_N, descending=True)[:n]

    return smiles_screen[picks_idx.cpu()]











def greedy(logits_N_K_C, smiles_screen, n, smiles_hit):
    """Exploit: return the n molecules with the highest mean class-1 probability.

    Probabilities are ``exp(logits)`` averaged over axis 1. ``smiles_hit`` is
    accepted for the shared acquisition interface and is not used.
    """
    hit_prob_N = logits_N_K_C.exp().mean(dim=1)[:, 1]
    best_first = hit_prob_N.argsort(descending=True)
    return smiles_screen[best_first[:n].cpu()]


def uncertainty(logits_N_K_C, smiles_screen, n, smiles_hit):
    """ Get the n molecules with the highest mean predictive entropy.

    For each molecule, the Shannon entropy of every member's class
    distribution (last axis, with ``p = exp(logits)``) is computed in
    float64. These entropies are averaged over axis 1, and the n molecules
    with the largest mean entropy are returned. ``smiles_hit`` is accepted
    for the shared acquisition interface and is not used.

    Zero-probability entries (log-prob ``-inf``) contribute 0, following the
    convention ``0 * log 0 = 0``. Evaluated literally they produce NaN, and
    NaN sorts first under a descending argsort, so a fully certain molecule
    would have been selected first.
    """
    p_log_p = torch.exp(logits_N_K_C) * logits_N_K_C
    p_log_p = p_log_p.masked_fill(torch.isneginf(logits_N_K_C), 0.0)

    sample_entropies_N_K = -torch.sum(p_log_p.double(), dim=-1, keepdim=False)
    entropy_mean_N = torch.mean(sample_entropies_N_K, dim=1)

    picks_idx = torch.argsort(entropy_mean_N, descending=True)[:n]

    return smiles_screen[picks_idx.cpu()]


def mutual_information(logits_N_K_C, smiles_screen, n, smiles_hit):
    """ Get the n molecules with the lowest Mutual Information.

    Scores every molecule with ``get_mutual_information`` and returns the n
    smallest scores (ascending order). ``smiles_hit`` is accepted for the
    shared acquisition interface and is not used.

    NOTE(review): BALD-style acquisition usually selects the *highest* mutual
    information; this function sorts ascending, as its name/docstring state.
    Confirm this is the intended behaviour.
    """
    mi_N = get_mutual_information(logits_N_K_C)
    lowest_first = torch.argsort(mi_N, descending=False)
    return smiles_screen[lowest_first[:n].cpu()]


def similarity(logits_N_K_C, smiles_screen, n, smiles_hit, radius=2, nBits=1024):
    """ Pick the n screen molecules most similar to any known hit.

        1. Compute the Morgan (ECFP) bit-vector Tanimoto similarity of every
           screen SMILES to every hit SMILES.
        2. Take the n screen SMILES with the highest similarity to any hit.

    The per-molecule maximum is accumulated hit by hit in a float64 vector
    instead of building a (hits x screen) float16 matrix. Memory is therefore
    O(len(smiles_screen)), and similarities are no longer rounded to float16,
    which could merge distinct scores into spurious ties.
    ``logits_N_K_C`` is accepted for the shared acquisition interface and is
    not used.

    Raises:
        ValueError: if ``smiles_hit`` is empty (no hit to compare against).
    """
    if len(smiles_hit) == 0:
        raise ValueError("similarity acquisition requires at least one hit SMILES")

    fp_smiles = [ECFPbitVec(Chem.MolFromSmiles(smi), radius=radius, nBits=nBits) for smi in smiles_screen]

    # Tanimoto similarity is >= 0, so 0 is a valid starting point for the max.
    best_sim = np.zeros(len(smiles_screen), dtype=np.float64)
    for smi in smiles_hit:
        fp_hit = ECFPbitVec(Chem.MolFromSmiles(smi), radius=radius, nBits=nBits)
        np.maximum(best_sim, BulkTanimotoSimilarity(fp_hit, fp_smiles), out=best_sim)

    # get the n highest similarity smiles to any hit
    picks_idx = np.argsort(best_sim)[::-1][:n]

    return smiles_screen[picks_idx]


random_rng = np.random.default_rng(seed=0)
def random(logits_N_K_C, smiles_screen, n, smiles_hit):
    """ Select n distinct random samples (without replacement).

    Draws from the module-level ``random_rng``, so successive calls continue
    one seeded stream. If ``n`` exceeds the number of screen molecules, all
    of them are returned in random order. ``logits_N_K_C`` and ``smiles_hit``
    are accepted for the shared acquisition interface and are not used.
    """
    n_pick = min(n, len(smiles_screen))

    # Sample WITHOUT replacement: integers() could draw the same index twice
    # and waste part of the acquisition budget on duplicate molecules.
    picks_idx = random_rng.choice(len(smiles_screen), size=n_pick, replace=False)

    return smiles_screen[picks_idx]












# Registry mapping acquisition names (as passed to ``acquire``) to their functions.
acquisition_dict = dict(
    greedy=greedy,
    mi=mutual_information,
    uncertainty=uncertainty,
    similarity=similarity,
    random=random,

    uncertainty_low_pos=uncertainty_low_pos,
    uncertainty_std_low_pos=uncertainty_std_low_pos,
    ucb=upper_confidence_bound,
    ts=thompson_sampling,
)

def acquire(acquisition, logits_N_K_C, smiles_screen, n, smiles_hit, **kwargs):
    """Run the acquisition function registered under ``acquisition``.

    Args:
        acquisition: key into ``acquisition_dict`` (e.g. 'greedy', 'mi', 'ucb').
        logits_N_K_C, smiles_screen, n, smiles_hit: forwarded by keyword.
        **kwargs: optional extra keyword arguments forwarded to the selected
            function, e.g. ``ratio`` for 'ucb', ``seed`` for 'ts', or
            ``radius``/``nBits`` for 'similarity'. These previously could not
            be set through ``acquire``.

    Returns:
        Whatever the selected acquisition function returns.

    Raises:
        KeyError: if ``acquisition`` is not a registered name; the message
            lists the valid names.
    """
    try:
        acquisition_fn = acquisition_dict[acquisition]
    except KeyError:
        raise KeyError(
            f"Unknown acquisition '{acquisition}'; expected one of {sorted(acquisition_dict)}"
        ) from None

    return acquisition_fn(logits_N_K_C=logits_N_K_C, smiles_screen=smiles_screen, n=n,
                          smiles_hit=smiles_hit, **kwargs)













