import gc
import random

from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from sklearn.metrics import f1_score
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import StandardScaler

import cfg


def sample(df, sampling_method, iteration, original_model, y_pred, y_true, x_data_original, nr_samples):
    """Tag the next batch of unlabelled rows as training data for ``iteration``.

    For ``iteration == 0`` the frame is returned untouched. Otherwise the tag
    column of this iteration is created from the previous one (validation tags
    become training tags), a binomially drawn share (p=0.95) of ``nr_samples``
    is selected by ``sampling_method`` and the remainder uniformly at random.

    Args:
        df: DataFrame holding one tag column per iteration; modified in place.
        sampling_method: One of 'random', 'ratio_max', 'badge', 'bald',
            'multilabel_simple_crw', 'kmeans', 'beal'. Any other value only
            triggers the random remainder.
        iteration: Index of the iteration whose tag column is filled.
        original_model: Keras model; a weight-identical clone is handed to the
            model-based strategies.
        y_pred: Predictions passed on to the strategies that use them.
        y_true: Labels passed on to the strategies that use them.
        x_data_original: Feature array passed on to the strategies that use it.
        nr_samples: Total number of rows to move into the training set.

    Returns:
        The DataFrame with the new tag column.
    """
    # initial training set already exists
    if iteration == 0:
        return df

    # get last and next tag col
    last_tag_col = cfg.get_iteration_col(iteration - 1)
    tag_col = cfg.get_iteration_col(iteration)

    # prepare next tag_col: carry tags over, former validation rows become training rows
    df[tag_col] = df[last_tag_col]
    df[tag_col] = df[tag_col].replace(cfg.tag_validate, cfg.tag_train)

    # split the budget between the chosen method and plain random sampling
    nr_samp_method = np.random.binomial(n=nr_samples, p=0.95)
    nr_samp_random = nr_samples - nr_samp_method

    # clone model so the strategies never touch the caller's instance
    model = tf.keras.models.clone_model(original_model)
    model.set_weights(original_model.get_weights())

    # Count through a boolean mask: value_counts()[tag] raises KeyError as soon
    # as no unlabelled row is left.
    nr_unlabelled = int((df[tag_col] == cfg.tag_unlabelled).sum())

    # fall back to random sampling when the budget covers the whole pool
    if sampling_method == 'random' or nr_samples >= nr_unlabelled:
        df = sampling_random(df, tag_col, nr_samp_method)
    elif sampling_method == 'ratio_max':
        df = sampling_ratio_max(y_pred, df, tag_col, nr_samp_method)
    elif sampling_method == 'badge':
        df = sampling_badge(model, x_data_original, y_pred, df, tag_col, nr_samp_method)
    elif sampling_method == 'bald':
        df = sampling_bald(model, x_data_original, y_true, df, tag_col, nr_samp_method)
    elif sampling_method == 'multilabel_simple_crw':
        df = sampling_multilabel_simple_crw(y_true, y_pred, df, tag_col, nr_samp_method, iteration)
    elif sampling_method == 'kmeans':
        df = sampling_kmeans(x_data_original, df, tag_col, nr_samp_method)
    elif sampling_method == 'beal':
        df = sampling_beal(model, x_data_original, y_true, df, tag_col, nr_samp_method)

    # add percentage random sampling
    if nr_samp_random != 0:
        df = sampling_random(df, tag_col, nr_samp_random)

    # NOTE: the validation split (get_validation_data) is currently disabled here.

    # clear tensorflow session
    del model
    tf.keras.backend.clear_session()
    tf.compat.v1.reset_default_graph()
    gc.collect()
    return df


def sampling_random(df, tag_col, nr_samples):
    """Move up to ``nr_samples`` randomly drawn unlabelled rows into the training set.

    Args:
        df: DataFrame with the tag column; modified in place.
        tag_col: Name of the tag column of the current iteration.
        nr_samples: Number of rows to draw; capped at the number of
            unlabelled rows.

    Returns:
        The same DataFrame with the drawn rows tagged as training data.
    """
    # candidates are the index labels of all rows still tagged as unlabelled
    is_unlabelled = df[tag_col] == cfg.tag_unlabelled
    candidates = list(df[is_unlabelled].index)

    # never ask for more rows than are available
    nr_to_draw = min(len(candidates), nr_samples)
    drawn = random.sample(candidates, nr_to_draw)

    df.loc[drawn, tag_col] = cfg.tag_train
    return df


def sampling_ratio_max(y_pred, df, tag_col, nr_samp_method):
    """Select the unlabelled rows whose most uncertain label is closest to 0.5.

    The per-label score is ``1 / (0.5 + |p - 0.5|) - 1`` (1 at p=0.5, 0 at
    p in {0, 1}); a row's score is the maximum over its labels.

    Args:
        y_pred: 2-D array of predicted probabilities, one row per df row.
        df: DataFrame with the tag column; modified in place. Index labels are
            used as row positions of ``y_pred``, so a default 0..n-1 index is
            assumed.
        tag_col: Name of the tag column of the current iteration.
        nr_samp_method: Number of rows to select.

    Returns:
        The DataFrame with the selected rows tagged as training data.
    """
    y_uncertainty_array = 1 / (0.5 + np.abs(y_pred - 0.5)) - 1

    # get max score over the labels of each row
    y_uncertainty_score = np.max(y_uncertainty_array, axis=1)

    # rows that are not unlabelled get a score below the valid range [0, 1]
    indices_not_unlabelled = df[df[tag_col] != cfg.tag_unlabelled].index
    y_uncertainty_score[indices_not_unlabelled] = -1
    high_score_indices = np.argsort(-y_uncertainty_score)[:nr_samp_method]

    # Safety net: if anything other than unlabelled rows was selected, move
    # all unlabelled rows to training instead.
    if (df.loc[high_score_indices, tag_col] != cfg.tag_unlabelled).any():
        # restrict the assignment to tag_col; a bare df[mask] = value would
        # overwrite every column of the matching rows
        df.loc[df[tag_col] == cfg.tag_unlabelled, tag_col] = cfg.tag_train
    else:
        df.loc[high_score_indices, tag_col] = cfg.tag_train

    return df


def _kmeans_pp_seeding(embeddings, nr_select):
    """Return row positions of ``embeddings`` chosen by k-means++ seeding.

    The first position is drawn uniformly. Every further position is drawn
    with probability proportional to the squared L2 distance between a row
    and its nearest already selected row.
    """
    nr_rows = embeddings.shape[0]
    selected = []
    if nr_rows == 0:
        return selected

    # distance of every row to its nearest selected row, updated incrementally
    # so each step costs O(rows * dim) instead of O(rows * selected * dim)
    min_dist = None
    for _ in range(nr_select):
        if min_dist is None:
            # select first index randomly (uniform distribution)
            chosen = np.random.choice(nr_rows)
        else:
            # float64 keeps the normalised weights within np.random.choice's tolerance
            weights = min_dist.astype(np.float64) ** 2
            total = np.sum(weights)
            if total > 0:
                chosen = np.random.choice(nr_rows, p=weights / total)
            else:
                # every row coincides with a selected embedding: the weights
                # would be 0/0, so draw uniformly from the unselected rows
                remaining = np.setdiff1d(np.arange(nr_rows), selected)
                if remaining.size == 0:
                    break
                chosen = np.random.choice(remaining)

        selected.append(chosen)
        dist_to_chosen = np.linalg.norm(embeddings - embeddings[chosen], axis=1)
        min_dist = dist_to_chosen if min_dist is None else np.minimum(min_dist, dist_to_chosen)

    return selected


def sampling_badge(model, x_data_original, y_pred, df, tag_col, nr_samp_method):
    """Select unlabelled rows by k-means++ seeding on last-layer gradient embeddings.

    For every unlabelled row the binary cross-entropy against the thresholded
    prediction (hypothetical label) is differentiated w.r.t. the trainable
    variables of the model's last layer; the flattened gradients form the
    embedding in which rows are selected.

    Args:
        model: Keras model used for the forward pass and the gradients.
        x_data_original: 2-D feature array, one row per df row.
        y_pred: 2-D array of predicted probabilities, one row per df row.
        df: DataFrame with the tag column; modified in place. Index labels are
            used as row positions, so a default 0..n-1 index is assumed.
        tag_col: Name of the tag column of the current iteration.
        nr_samp_method: Number of rows to select.

    Returns:
        The DataFrame with the selected rows tagged as training data.
    """
    # nothing to select: skip the gradient computation entirely
    if nr_samp_method <= 0:
        return df

    # use only the unlabelled part of the data
    indices_unlabelled = df[df[tag_col] == cfg.tag_unlabelled].index
    df_unlabelled = df.iloc[indices_unlabelled]
    x_data_original = x_data_original[indices_unlabelled, :]
    y_pred = y_pred[indices_unlabelled, :]

    # 1. Compute hypothetical label
    y_hypo = (y_pred > 0.5).astype(int)

    # 2. Compute gradient embedding for each sample
    # Convert data to tensors
    x_data_tensor = tf.convert_to_tensor(x_data_original)
    y_hypo_tensor = tf.convert_to_tensor(y_hypo)

    # Model function to compute gradients for each example
    def model_fn(arg):
        inputs, labels = arg
        # re-add the batch dimension removed by vectorized_map
        inputs = tf.expand_dims(inputs, axis=0)
        labels = tf.expand_dims(labels, axis=0)
        with tf.GradientTape() as tape:
            tape.watch(x_data_tensor)  # Watch the input data
            y_pred_tensor = model(inputs)  # Forward pass through the model
            loss = tf.keras.losses.binary_crossentropy(labels, y_pred_tensor)
        return tape.gradient(loss, model.layers[-1].trainable_variables)

    # Compute per-example gradients using vectorized map
    per_example_gradients = tf.vectorized_map(model_fn, (x_data_tensor, y_hypo_tensor))
    # flatten each variable's gradient per example and join them into one embedding
    flat_gradients = [tf.reshape(grad, [grad.shape[0], -1]) for grad in per_example_gradients]
    gradient_embeddings = tf.concat(flat_gradients, axis=-1).numpy()

    # 3. Select samples based on the k-Means++ seeding algorithm
    indices_selected = _kmeans_pp_seeding(gradient_embeddings, nr_samp_method)

    # add selected samples to training set
    indices_df = df_unlabelled.iloc[indices_selected].index
    df.loc[indices_df, tag_col] = cfg.tag_train
    return df


def sampling_bald(model, x_data_original, y_true, df, tag_col, nr_samp_method):
    """Select the unlabelled rows with the highest BALD score.

    The score is the entropy of the mean prediction of ``k`` models minus the
    mean of the per-model entropies, both summed over the label axis.

    Args:
        model: Keras model; used as first ensemble member and as template for
            the additional members trained by ``get_n_predictions``.
        x_data_original: 2-D feature array, one row per df row.
        y_true: Label array, one row per df row.
        df: DataFrame with the tag column; modified in place. Index labels are
            used as row positions, so a default 0..n-1 index is assumed.
        tag_col: Name of the tag column of the current iteration.
        nr_samp_method: Number of rows to select.

    Returns:
        The DataFrame with the selected rows tagged as training data.
    """
    # Without this guard the slice [-0:] below would select every unlabelled
    # row; returning early also avoids training the ensemble for nothing.
    if nr_samp_method <= 0:
        return df

    # approximate p(w) with k models
    k = 10

    # get training data of the current iteration
    array_tag = df[tag_col].to_numpy()
    x_train = x_data_original[array_tag == cfg.tag_train]
    y_train = y_true[array_tag == cfg.tag_train]

    # use only the unlabelled part of the data
    indices_unlabelled = df[df[tag_col] == cfg.tag_unlabelled].index
    df_unlabelled = df.iloc[indices_unlabelled]
    x_unlabelled = x_data_original[indices_unlabelled, :]

    # predictions of all k models, clipped to keep the logarithms finite
    probs = get_n_predictions(k, x_train, y_train, model, x_unlabelled)
    probs = np.clip(probs, 1e-15, 1 - 1e-15)

    # entropy of the mean prediction (mean over models first, then entropy)
    pb = np.mean(probs, axis=0)
    H_y_cond_x_Dtrain = -np.sum(pb * np.log(pb), axis=1)

    # mean of the per-model entropies (entropy first, then mean over models)
    H_y_cond_x_theta = -np.sum(probs * np.log(probs), axis=2)
    E_H_y_cond_x_theta = np.mean(H_y_cond_x_theta, axis=0)

    # compute BALD score
    BALD_scores = H_y_cond_x_Dtrain - E_H_y_cond_x_theta

    # positions of the nr_samp_method highest scores
    indices_selected = np.argsort(BALD_scores)[-nr_samp_method:]
    # add selected samples to training set
    indices_df = df_unlabelled.iloc[indices_selected].index
    df.loc[indices_df, tag_col] = cfg.tag_train
    return df


def sampling_multilabel_simple_crw(y_true, y_pred, df, tag_col, nr_samp_method, iteration):
    """Select uncertain unlabelled rows with per-class quotas weighted by (1 - F1).

    Per-class F1 is measured on the rows tagged as validation data in the
    previous iteration; classes with a lower F1 receive a larger share of
    ``nr_samp_method``. For each class in turn, the rows with the highest
    uncertainty score for that class are taken and excluded from later classes.

    Args:
        y_true: 2-D label array, one row per df row.
        y_pred: 2-D array of predicted probabilities, one row per df row.
        df: DataFrame with the tag columns; modified in place. Index labels are
            used as row positions, so a default 0..n-1 index is assumed.
        tag_col: Name of the tag column of the current iteration.
        nr_samp_method: Number of rows to select.
        iteration: Current iteration; ``iteration - 1`` provides the
            validation rows.

    Returns:
        The DataFrame with the selected rows tagged as training data.
    """
    # uncertainty per row and class; rows that are not unlabelled can never win
    scores = 1 / (0.5 + np.abs(y_pred - 0.5)) - 1
    not_unlabelled = df[tag_col] != cfg.tag_unlabelled
    scores[not_unlabelled, :] = -np.inf

    # per-class F1 on the validation rows of the previous iteration
    previous_tag_col = cfg.get_iteration_col(iteration - 1)
    validate_index = df[df[previous_tag_col] == cfg.tag_validate].index
    truth_val = y_true[validate_index]
    pred_val = (y_pred[validate_index] >= 0.5).astype(int)
    per_class_f1 = []
    for class_pos in range(truth_val.shape[1]):
        per_class_f1.append(f1_score(truth_val[:, class_pos], pred_val[:, class_pos], zero_division=1.0))
    f1_scores = np.clip(np.array(per_class_f1), 0.01, 0.99)

    # quota per class: proportional to 1 - F1, the last class takes what is left
    class_weights = (1 - f1_scores) / np.sum(1 - f1_scores)
    last_pos = len(class_weights) - 1
    quotas = []
    for class_pos, weight in enumerate(class_weights):
        if class_pos == 0:
            quotas.append(int(nr_samp_method * weight))
        elif class_pos == last_pos:
            quotas.append(max(0, nr_samp_method - np.sum(quotas)))
        else:
            quotas.append(min(int(nr_samp_method * weight), nr_samp_method - np.sum(quotas)))

    # class by class: take the top-scoring rows, then block them for later classes
    picked = [np.array([], dtype=int)]
    for class_pos, quota in enumerate(quotas):
        if quota <= 0:
            continue
        top_rows = np.argsort(scores[:, class_pos])[-quota:]
        picked.append(top_rows)
        scores[top_rows, :] = -np.inf
    indices_df = np.concatenate(picked)

    # add training samples
    df.loc[indices_df, tag_col] = cfg.tag_train
    return df


def sampling_kmeans(x_data_original, df, tag_col, nr_samp_method, n_components=5):
    """Select the unlabelled rows closest to k-means cluster centres.

    The unlabelled features are standardised, reduced with PCA and clustered
    into one cluster per requested sample; the row nearest to each centre is
    selected. If several centres share the same nearest row, the shortfall is
    filled by random sampling.

    Args:
        x_data_original: 2-D feature array, one row per df row.
        df: DataFrame with the tag column; modified in place. Index labels are
            used as row positions, so a default 0..n-1 index is assumed.
        tag_col: Name of the tag column of the current iteration.
        nr_samp_method: Number of rows to select.
        n_components: Number of PCA components; clipped to the shape of the
            unlabelled data.

    Returns:
        The DataFrame with the selected rows tagged as training data.
    """
    # use only the unlabelled part of the data
    indices_unlabelled = df[df[tag_col] == cfg.tag_unlabelled].index
    nr_unlabelled = len(indices_unlabelled)

    # KMeans rejects n_clusters=0 and the scaler rejects an empty pool
    if nr_samp_method <= 0 or nr_unlabelled == 0:
        return df

    df_unlabelled = df.iloc[indices_unlabelled]
    x_unlabelled = x_data_original[indices_unlabelled, :]

    # scale the data
    scaler = StandardScaler()
    x_unlabelled_scaled = scaler.fit_transform(x_unlabelled)

    # PCA cannot extract more components than there are rows or features
    n_components = min(n_components, x_unlabelled_scaled.shape[0], x_unlabelled_scaled.shape[1])
    pca = PCA(n_components=n_components)
    x_unlabelled_pca = pca.fit_transform(x_unlabelled_scaled)

    # kmeans with as many clusters as samples to select (at most one per row)
    n_clusters = min(nr_samp_method, nr_unlabelled)
    kmeans = KMeans(n_clusters=n_clusters, n_init=1)
    kmeans.fit(x_unlabelled_pca)

    # Calculate the distance between each sample and the cluster centers
    distances = kmeans.transform(x_unlabelled_pca)

    # centroids: Find the index of the sample nearest to each cluster center
    centroid_indices = distances.argmin(axis=0)

    # add selected samples to training set
    indices_df = df_unlabelled.iloc[centroid_indices].index
    df.loc[indices_df, tag_col] = cfg.tag_train

    # if samples were selected multiple times, select randomly other samples
    nr_unique = len(np.unique(centroid_indices))
    if nr_samp_method > nr_unique:
        # sampling_random tags the rows in place on the same DataFrame
        sampling_random(df, tag_col, nr_samp_method - nr_unique)

    return df


def sampling_beal(model, x_data_original, y_true, df, tag_col, nr_samp_method):
    """Select the unlabelled rows with the lowest expected confidence score.

    Predictions of an ensemble of 10 models are mapped to relevance scores
    ``2 ** p - 1``. The mean relevance is ranked along the sample axis, each
    model's relevance is divided by that rank and summed over the labels, and
    the result is averaged over the models.

    Args:
        model: Keras model; used as first ensemble member and as template for
            the additional members trained by ``get_n_predictions``.
        x_data_original: 2-D feature array, one row per df row.
        y_true: Label array, one row per df row.
        df: DataFrame with the tag column; modified in place. Index labels are
            used as row positions, so a default 0..n-1 index is assumed.
        tag_col: Name of the tag column of the current iteration.
        nr_samp_method: Number of rows to select.

    Returns:
        The DataFrame with the selected rows tagged as training data.
    """
    # number of models approximating p(w)
    nr_models = 10

    # training data of the current iteration
    is_train = df[tag_col].to_numpy() == cfg.tag_train
    x_train = x_data_original[is_train]
    y_train = y_true[is_train]

    # pool of unlabelled rows
    unlabelled_index = df[df[tag_col] == cfg.tag_unlabelled].index
    unlabelled_rows = df.iloc[unlabelled_index]
    x_pool = x_data_original[unlabelled_index, :]

    # 1. predictions of every ensemble member on the pool
    member_probs = get_n_predictions(nr_models, x_train, y_train, model, x_pool)

    # 2. relevance score per member, sample and label, plus its mean over members
    relevance = 2 ** member_probs - 1
    relevance_mean = np.mean(relevance, axis=0)

    # 3. rank of the mean relevance along the sample axis (highest value -> rank 1)
    ascending_rank = np.argsort(np.argsort(relevance_mean, axis=0), axis=0)
    rank_positions = relevance_mean.shape[0] - ascending_rank

    # 4. expected confidence: the ranks broadcast over the member axis
    confidence_per_member = np.sum(relevance / rank_positions, axis=2)
    expected_confidence = np.mean(confidence_per_member, axis=0)

    # 5. take the nr_samp_method rows with the lowest expected confidence
    least_confident = np.argsort(expected_confidence)[:nr_samp_method]
    selected_labels = unlabelled_rows.iloc[least_confident].index
    df.loc[selected_labels, tag_col] = cfg.tag_train
    return df


def get_n_predictions(n, x_train_numpy, y_train_numpy, model, x_test_numpy):
    """Return the predictions of ``model`` plus ``n - 1`` freshly trained clones.

    Every clone copies the architecture of ``model`` via ``clone_model`` and is
    trained on the given training data with early stopping on the validation
    loss of a 20 % validation split.

    Args:
        n: Total number of models whose predictions are returned.
        x_train_numpy: Training features for the clones.
        y_train_numpy: Training labels for the clones.
        model: Keras model; its own predictions come first in the result.
        x_test_numpy: Features to predict on.

    Returns:
        Array stacking the outputs of all models along a new first axis.
    """
    # tensors shared by all ensemble members
    train_inputs = tf.convert_to_tensor(x_train_numpy, dtype=tf.float32)
    train_labels = tf.convert_to_tensor(y_train_numpy, dtype=tf.int32)
    test_inputs = tf.convert_to_tensor(x_test_numpy, dtype=tf.float32)

    stopper = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        min_delta=0.1,
        mode='auto',
        verbose=0,
    )

    # the given model is member 0, the others are trained from scratch
    ensemble = [model]
    for _ in range(n - 1):
        member = tf.keras.models.clone_model(model)
        member.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
        member.fit(
            train_inputs,
            train_labels,
            epochs=1000,
            batch_size=64,
            shuffle=True,
            validation_split=0.2,
            callbacks=[stopper],
            verbose=0,
        )
        ensemble.append(member)

    probs = np.array([member(test_inputs) for member in ensemble])

    # drop the references, then clear the tensorflow session
    del train_inputs, train_labels, test_inputs, ensemble
    tf.keras.backend.clear_session()
    tf.compat.v1.reset_default_graph()
    gc.collect()

    return probs


"""
WORK IN PROGRESS FOR MULTI-LABEL CLASSIFIERS
def sampling_batchbald(model, x_data_original, y_pred, y_true, df, tag_col, nr_samp_method):
    # approximate p(w) with k models
    k = 10
    models = [model]
    # get training and validation data
    array_tag = df[tag_col].to_numpy()
    x_train = x_data_original[array_tag == cfg.tag_train]
    y_train = y_true[array_tag == cfg.tag_train]
    # convert data to tf inputs
    x_train = tf.convert_to_tensor(x_train, dtype=tf.float32)
    y_train = tf.convert_to_tensor(y_train, dtype=tf.int32)
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True,
                                                      min_delta=0.1, mode='auto', verbose=0)
    for _ in range(k-1):
        model_iter = tf.keras.models.clone_model(model)
        model_iter.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
        model_iter.fit(x_train, y_train, epochs=1000, batch_size=64, shuffle=True, validation_split=0.2,
                  callbacks=[early_stopping], verbose=0)
        models.append(model_iter)
        print(k)

    # use only the unlabelled part of the data
    indices_unlabelled = df[df[tag_col] == cfg.tag_unlabelled].index
    df_unlabelled = df.iloc[indices_unlabelled]
    x_unlabelled = x_data_original[indices_unlabelled, :]
    y_unlabelled = y_true[indices_unlabelled, :]

    # init list selected samples
    indices_selected = []
    # get predictions from k models
    predictions = np.array([model.predict(x_unlabelled) for model in models])
    predictions = np.clip(predictions, 1e-15, 1-1e-15)
    # compute the entropy for all predictions of all models
    y_unlabelled_array = np.tile(y_unlabelled, (k, 1, 1))
    entropy = -y_unlabelled_array * np.log(predictions) - (1 - y_unlabelled_array) * np.log(1 - predictions)
    # compute the mean entropy across k models
    entropy_mean_models = np.mean(entropy, axis=0)
    # use the maximum value
    for _ in range(nr_samp_method):
        # compute E_p(w) [H(y_1, ..., y_n | w)]
        if indices_selected:
            mean_entropy_sampled_set = np.mean(entropy_mean_models[indices_selected], axis=0)
            E_pw_H = entropy_mean_models + mean_entropy_sampled_set
        else:
            E_pw_H = entropy_mean_models

        # compute H(y_1, ..., y_n)
        #ITERATE OVER DIFFERENT POSSIBLE Y_TRUES
        if indices_selected:
            m = min(len(indices_selected), 100)
            for m_iter in range(m):
                P_1_n_minus_1 =
                P_n_T = 

            predictions_sampled_set = predictions[:, indices_selected, :]
            predictions_sampled_set_multiplied = np.prod(predictions_sampled_set, axis=1)
            p_y_1_to_n = predictions * np.expand_dims(predictions_sampled_set_multiplied, axis=1)
        else:
            p_y_1_to_n = predictions
        H_y = -1/k * np.sum(p_y_1_to_n, axis=0) * np.log(1/k * np.sum(p_y_1_to_n, axis=0))
    return df
"""


def get_validation_data(df, iteration, y_true, validation_split=0.2):
    """Re-tag a share of the training rows of ``iteration`` as validation rows.

    A stratified split on ``y_true`` is attempted first; when scikit-learn
    rejects it, the validation rows are drawn uniformly at random instead.

    Args:
        df: DataFrame with the tag column of ``iteration``; modified in place.
        iteration: Iteration whose tag column is updated.
        y_true: Labels, one row per df row; used for stratification.
        validation_split: Fraction of the training rows to tag as validation.

    Returns:
        The DataFrame with the chosen rows tagged as validation data.
    """
    tag_col = cfg.get_iteration_col(iteration)

    # filter training rows (Tag df and y array)
    is_training = df[tag_col] == cfg.tag_train
    df_training = df[is_training]
    y_true_training = y_true[is_training]

    # try stratified sampling, if not possible do random sampling
    try:
        _, val_df = train_test_split(df_training,
                                     test_size=validation_split,
                                     stratify=y_true_training,
                                     random_state=1)
    except ValueError:
        # train_test_split signals an impossible split with ValueError; a bare
        # except would also swallow KeyboardInterrupt and SystemExit
        val_df = df_training.sample(frac=validation_split)

    # change label from training to validation
    df.loc[val_df.index, tag_col] = cfg.tag_validate

    return df