import logging
import numpy as np
import random
import time

from scipy.optimize import linear_sum_assignment
from sklearn import preprocessing
from sklearn.metrics.pairwise import pairwise_distances
from typing import Dict, Tuple, Optional

from methods.cluster_mmr import (
    partition_clustering,
    partition_random,
    get_full_partitions,
    select_partitions_mmr,
    mmr_on_partitions
)

# Module logger. basicConfig configures the root logger at import time
# (it is a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format="%(name)s: %(message)s",
    encoding="utf-8",
)

# Partitioning strategies understood by pre_process_for_three_mmr.
PACKING_ALGO_KMEANS = "packing_k_means"              # clusters from partition_clustering as-is
PACKING_ALGO_BALANCED = "packing_balanced"           # size-capped via optimal point-to-slot assignment
PACKING_ALGO_BALANCED_FAST = "packing_balanced_fast"  # size-balanced via get_full_partitions
PACKING_ALGO_RANDOM = "packing_random"               # partitions from partition_random


def pre_process_for_three_mmr(data_info: Dict, n_clusters: int, algorithm: str) -> Dict:
    """Build the normalized corpus matrix and partition it for three-level MMR.

    Args:
        data_info: must contain ``"text_embeddings"``, a sequence whose items
            hold an embedding vector at index 1.
        n_clusters: requested number of partitions; ``1`` puts every item into
            a single partition (and ``algorithm`` is then not consulted).
        algorithm: one of the ``PACKING_ALGO_*`` constants.

    Returns:
        dict with
        ``"x"``: the embeddings as a 2D array (n items, d features), each row
        scaled to unit L2 norm;
        ``"partition_results"``: ``"partitions"`` (list of index arrays into
        ``x``) and ``"centers"`` (one row per partition);
        ``"stats"``: the requested ``n_clusters``, ``max_radius`` (largest
        item-to-own-center Euclidean distance) and ``max_size`` (largest
        partition size).

    Raises:
        ValueError: if ``algorithm`` is unknown (and ``n_clusters != 1``).
    """
    text_embeddings = data_info["text_embeddings"]

    # corpus embeddings as a single data matrix, rows scaled to unit L2 norm
    x = np.array([item[1] for item in text_embeddings])  # 2D array (n items, d features)
    x = preprocessing.normalize(x)

    if n_clusters == 1:
        partition_results = {
            "partitions": [np.arange(x.shape[0])],
            "centers": np.mean(x, axis=0).reshape(1, -1),
        }
    elif algorithm == PACKING_ALGO_KMEANS:
        partition_results = partition_clustering(x, n_clusters)
    elif algorithm == PACKING_ALGO_BALANCED:
        partition_results = partition_clustering(x, n_clusters)
        # use the number of clusters actually returned rather than assuming n_clusters
        n_clusters_actual = len(partition_results["partitions"])
        kmeans_centers = partition_results["centers"]  # 2D array (n clusters, d features)
        partitions = _balanced_partitions(x, kmeans_centers, n_clusters_actual)
        partition_results["partitions"] = partitions
        # a cluster left empty by the balancing keeps its original center instead of NaN
        partition_results["centers"] = _partition_centers(x, partitions, fallback=kmeans_centers)
    elif algorithm == PACKING_ALGO_BALANCED_FAST:
        partition_results = partition_clustering(x, n_clusters)
        # request every cluster that was actually returned; range(n_clusters) would
        # reference non-existent clusters if fewer came back
        all_clusters = list(range(len(partition_results["partitions"])))
        partitions = get_full_partitions(x, partition_results, all_clusters)
        partition_results["partitions"] = partitions
        partition_results["centers"] = _partition_centers(x, partitions)
    elif algorithm == PACKING_ALGO_RANDOM:
        partition_results = partition_random(x, n_clusters)
    else:
        raise ValueError(f"Unknown algorithm: {algorithm}")

    max_radius, max_size = _partition_stats(
        x, partition_results["partitions"], partition_results["centers"]
    )
    stats = {"n_clusters": n_clusters, "max_radius": max_radius, "max_size": max_size}
    return {"x": x, "partition_results": partition_results, "stats": stats}


def _balanced_partitions(x: np.ndarray, centers: np.ndarray, n_clusters_actual: int) -> list:
    """Reassign every row of ``x`` so no cluster holds more than ceil(n / k) points.

    Each center is replicated ``cluster_size`` times ("slots").
    ``linear_sum_assignment`` then maps each point to a distinct slot, minimizing
    the total point-to-center Euclidean distance. This caps every cluster at
    ``cluster_size`` points.

    Returns one index array per cluster ``0 .. n_clusters_actual - 1``.
    """
    n_points = x.shape[0]
    # ceil(n / k): the fewest slots per cluster that still fit every point. The
    # former ``n // k + 1`` added k spare slots when n % k == 0, allowing uneven clusters.
    cluster_size = -(-n_points // n_clusters_actual)
    slots = np.repeat(centers, cluster_size, axis=0)  # (n_centers * cluster_size, d features)
    d = pairwise_distances(x, slots, metric="euclidean")
    # rows <= columns, so row_ind is 0..n-1 in order and col_ind[i] is the slot of point i
    _, slot_idx = linear_sum_assignment(d)
    cluster_assignment = np.asarray(slot_idx) // cluster_size  # 1D array: n items
    return [np.where(cluster_assignment == c)[0] for c in range(n_clusters_actual)]


def _partition_centers(x: np.ndarray, partitions, fallback: Optional[np.ndarray] = None) -> np.ndarray:
    """Return the mean of each partition's rows of ``x`` (one row per partition).

    An empty partition has no mean. Its row is copied from ``fallback`` when that
    row exists; otherwise it stays NaN, which is what ``np.mean`` of an empty slice
    yields.
    """
    centers = np.full((len(partitions), x.shape[1]), np.nan)
    for c, idx_in_cluster_c in enumerate(partitions):
        if len(idx_in_cluster_c) > 0:
            centers[c, :] = np.mean(x[idx_in_cluster_c, :], axis=0)
        elif fallback is not None and c < len(fallback):
            centers[c, :] = fallback[c, :]
    return centers


def _partition_stats(x: np.ndarray, partitions, centers: np.ndarray) -> Tuple:
    """Return (max item-to-own-center Euclidean distance, max partition size).

    Each value stays -1 when no partition contributes to it. Empty partitions
    are skipped for the radius, because ``np.max`` of an empty array raises.
    """
    max_radius = -1
    max_size = -1
    for c, idx_in_cluster_c in enumerate(partitions):
        size = len(idx_in_cluster_c)
        max_size = max(max_size, size)
        if size == 0:
            continue
        mu = centers[c, :].reshape(1, -1)
        d = pairwise_distances(x[idx_in_cluster_c, :], mu, metric="euclidean")
        max_radius = max(max_radius, np.max(d))
    return max_radius, max_size


def retrieve_with_three_level_mmr(
        data_info: Dict,
        query: str,
        k: int,
        m: int = 10,
        lambda_clusters: float = 0.9,
        lambda_points: float = 0.9,
        k_within: Optional[int] = None,
        add_max_q_to_union: bool = True
) -> Tuple:
    """Select ``k`` corpus entries for ``query`` using three-level MMR.

    Steps:
    1. Score each item as ``2 - Euclidean distance`` between unit vectors, so a
       higher score means closer to the query.
    2. Choose up to ``m`` partitions:
       - all of them when there are at most ``m``;
       - uniformly at random when ``lambda_clusters < 0``;
       - otherwise by ``select_partitions_mmr``.
    3. Let ``mmr_on_partitions`` pick the items inside the chosen partitions.

    Args:
        data_info: needs ``"emb_model"`` (providing ``embed_query``),
            ``"corpus"`` and ``"pre_proc"``. ``"pre_proc"`` must have ``"x"``
            and ``"partition_results"`` keys, as produced by
            ``pre_process_for_three_mmr``.
        query: query text to embed.
        k: number of items to return (0 <= k <= number of items).
        m: number of partitions to search.
        lambda_clusters: MMR trade-off for partition selection; negative means
            random selection.
        lambda_points: MMR trade-off for item selection.
        k_within: forwarded to ``mmr_on_partitions``.
        add_max_q_to_union: forwarded to ``mmr_on_partitions``.

    Returns:
        (list of selected ``corpus`` entries, stats dict from
        ``mmr_on_partitions`` with ``"core_time"`` added: seconds spent on
        partition selection plus item selection).

    Raises:
        ValueError: if ``k`` is negative or exceeds the number of items, or if
            the query/item distances show the embeddings are not unit-normalized.
    """
    emb_model = data_info["emb_model"]
    corpus = data_info["corpus"]
    x = data_info["pre_proc"]["x"]
    partition_results = data_info["pre_proc"]["partition_results"]

    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k > x.shape[0]:
        raise ValueError(f"cannot select more items ({k}) than available: {x.shape[0]}")

    # obtain the quality score = 2 - Euclidean distance to query
    x_query = np.array(emb_model.embed_query(query))
    x_query = np.expand_dims(x_query, axis=0)  # 2D array (1 item, d features)
    x_query = preprocessing.normalize(x_query)

    distances = pairwise_distances(x, x_query, metric="euclidean")[:, 0]  # 1D array (n items)
    max_distance = np.max(distances)
    # Unit-normalized vectors are at most 2 apart. The tolerance absorbs float
    # round-off (antipodal vectors can give 2.0000000000000004). The negated
    # comparison also rejects NaN. An explicit raise survives ``python -O``,
    # which strips asserts.
    distance_tolerance = 1e-6
    if not max_distance <= 2 + distance_tolerance:
        raise ValueError(f"not properly scaled, max_distance: {max_distance}")
    score = 2 - distances

    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
    num_partitions = len(partition_results["partitions"])
    if num_partitions <= m:
        # nothing to choose between: with fewer than m partitions, random.sample would raise
        selected_cluster_idx = list(range(num_partitions))
    elif lambda_clusters < 0:
        selected_cluster_idx = random.sample(range(num_partitions), k=m)
    else:
        selected_cluster_idx = select_partitions_mmr(score, partition_results, m, lambda_clusters)
    partitions = [partition_results["partitions"][c] for c in selected_cluster_idx]

    # indices of the k highest-scoring items; argpartition(score, -0)[-0:] would return
    # every index, so k == 0 is special-cased
    top_q = set(np.argpartition(score, -k)[-k:]) if k > 0 else set()

    selected_points_idx, stats = mmr_on_partitions(
        x=x,
        q=score,
        k=k,
        top_q=top_q,
        lambda_param=lambda_points,
        partitions=partitions,
        k_within=k_within,
        add_max_q_to_union=add_max_q_to_union
    )
    stats["core_time"] = time.perf_counter() - start_time

    return [corpus[i] for i in selected_points_idx], stats
