from typing import Union
from datasets import Dataset
import numpy as np

from .al_strategy_utils import filter_by_metric, filter_by_uncertainty
from ..utils.transformers_dataset import TransformersDataset
from ..utils.get_embeddings import get_embeddings
from sklearn.preprocessing import MinMaxScaler

from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForMaskedLM,
)

def normalize(arr):
    """Min-max scale ``arr`` so that its values span the range [0, 1].

    Args:
        arr: Non-empty array-like of numbers.

    Returns:
        A numpy array with the same shape as ``arr``. If every element of
        ``arr`` is identical, an array of zeros is returned instead of the
        NaN values a plain ``0 / 0`` division would produce.

    Raises:
        ValueError: If ``arr`` is empty (raised by ``np.min``).
    """
    arr = np.asarray(arr)
    lowest = np.min(arr)
    span = np.max(arr) - lowest
    if span == 0:
        # Constant input: divide by 1 so the (all-zero) numerator is returned
        # as-is rather than triggering a division by zero.
        span = 1
    return (arr - lowest) / span

def _assign_uncertainty_buckets(uncertainty, num_buckets):
    """Return, for every sample, the index of its equal-width uncertainty stratum."""
    lowest = np.min(uncertainty)
    span = np.max(uncertainty) - lowest
    # Interior edges only: everything below the first edge lands in bucket 0
    # and everything at or above the last edge lands in the final bucket, so
    # the most uncertain samples are always assigned to a stratum.
    inner_edges = lowest + np.arange(1, num_buckets) * span / num_buckets
    return np.searchsorted(inner_edges, uncertainty, side="right")


def _scaled_centroid_distances(bucket_embeddings):
    """Return each row's distance to the bucket centroid, min-max scaled per bucket."""
    from sklearn.cluster import KMeans

    kmeans = KMeans(n_clusters=1).fit(bucket_embeddings)
    distances = kmeans.transform(bucket_embeddings).reshape(-1, 1)
    # ravel() turns the (n, 1) scaler output into plain per-sample scalars.
    return MinMaxScaler().fit_transform(distances).ravel()


def huds(
    model,
    X_pool: Union[np.ndarray, Dataset, TransformersDataset],
    n_instances: int,
    **kwargs,
):
    """Select pool instances by a weighted mix of uncertainty and diversity.

    The uncertainty of a sample is the negated ``"sequences_scores"`` value
    returned by ``model.generate``, min-max scaled over the pool. The pool is
    split into equal-width uncertainty strata; inside each stratum the
    diversity of a sample is its distance to the stratum's embedding
    centroid, min-max scaled within the stratum. The final score is
    ``weighting * diversity + (1 - weighting) * uncertainty``.

    Args:
        model: Object whose ``generate(X_pool, to_numpy=True)`` returns a
            mapping containing ``"sequences_scores"``.
        X_pool: Unlabeled pool to query from.
        n_instances: Number of instances to select.
        **kwargs: Optional settings. Recognised keys:
            ``weighting`` (default ``0.5``): weight of the diversity term.
            ``embeddings_model_name`` (default ``"bert-base-uncased"``):
            checkpoint loaded with ``AutoModel`` / ``AutoTokenizer``.
            ``num_buckets`` (default ``10``): number of uncertainty strata.
            ``device`` (default: ``"cuda"`` when available, else ``"cpu"``).
            Any other keys are ignored.

    Returns:
        A tuple ``(query_idx, query, scores)``: the indices of the selected
        instances (highest score first), the selected subset of ``X_pool``,
        and the per-instance hybrid scores for the whole pool.

    Raises:
        ValueError: If ``num_buckets`` is smaller than 1.
    """
    weighting = float(kwargs.get("weighting", 0.5))
    print("Weighting: ", weighting)
    embeddings_model_name = kwargs.get("embeddings_model_name", "bert-base-uncased")
    num_buckets = int(kwargs.get("num_buckets", 10))
    if num_buckets < 1:
        raise ValueError("num_buckets must be at least 1")

    device = kwargs.get("device")
    if device is None:
        # Fall back to CPU instead of crashing on hosts without a GPU.
        import torch

        device = "cuda" if torch.cuda.is_available() else "cpu"

    generate_output = model.generate(X_pool, to_numpy=True)
    uncertainty = -np.asarray(generate_output["sequences_scores"])
    lowest = np.min(uncertainty)
    span = np.max(uncertainty) - lowest
    # A pool with identical scores has zero span; divide by 1 to avoid NaNs.
    uncertainty = (uncertainty - lowest) / (span if span > 0 else 1.0)

    # We first get the embeddings
    embeddings_model = AutoModel.from_pretrained(embeddings_model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(embeddings_model_name)
    if tokenizer.model_max_length > 5000:
        # Cap an implausibly large reported limit at 512 tokens.
        tokenizer.model_max_length = 512

    print("NMT: Getting embeddings")
    embeddings = get_embeddings(
        embeddings_model,
        X_pool,
        prepare_model=False,
        data_is_tokenized=False,
        tokenizer=tokenizer,
        use_averaging=True,
        use_automodel=False,
        to_numpy=True,
        task="nmt",
        text_name="en",
        label_name="de",
    )
    embeddings = np.asarray(embeddings)

    # Now we perform stratified sampling on uncertainty scores
    print("NMT: Performing stratified sampling")
    bucket_ids = _assign_uncertainty_buckets(uncertainty, num_buckets)
    # Progress markers below are kept so the log output stays unchanged.
    print("Index to scores sorted")
    print("Buckets")

    print("NMT: Performing kmeans clustering with weighted uncertainty")
    print("With normalized distances")
    hybrid_scores = np.zeros(len(uncertainty))
    for bucket_id in range(num_buckets):
        members = np.flatnonzero(bucket_ids == bucket_id)
        if members.size == 0:
            continue
        diversity = _scaled_centroid_distances(embeddings[members])
        hybrid_scores[members] = (weighting * diversity) + (
            (1 - weighting) * uncertainty[members]
        )

    print("uncertainty estimates")

    # sort in descending order
    # highest hybrid scores are at the beginning of the array
    query_idx = np.argsort(-hybrid_scores)[:n_instances]

    if isinstance(X_pool, np.ndarray):
        # Plain arrays have no ``select``; index them directly.
        query = X_pool[query_idx]
    else:
        query = X_pool.select(query_idx)

    return query_idx, query, hybrid_scores
