"""
Usage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Input
x: dataset, 2D array (n items, d features)
q: quality scores, 1D array (n items)

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
0. Partition the data
First, data normalization is recommended.
Next, use either partition_random(...) or partition_clustering(...).
If the data is large, run the clustering on a sample of the data.
Either function returns a <partition_results> dictionary (Dict).

1. Select partitions (clusters) using MMR
select_partitions_mmr(q, partition_results, num_partitions_to_select, lambda_for_partitions)

2. Finalize partitions
get_full_partitions(x, partition_results, selected_cluster_idx, by_radius)
When Step 0 used data sample for clustering, this step is needed to form partitions using full data.

3. MMR on partitioned data
mmr_on_partitions(x, q, k, top_q, lambda_param, partitions, n_jobs, k_within)
Uses MMR to select points within each selected partition.
Takes a union of selected points.
Applies MMR to produce final result.
"""
import logging
import numpy as np
import time

from joblib import Parallel, delayed
from typing import Dict, List, Set, Tuple, Optional
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import pairwise_distances
import faiss

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE: this configures the root logger as a side effect of importing the module
# (DEBUG level, "<logger name>: <message>" format).
logging.basicConfig(encoding="utf-8", format="%(name)s: %(message)s", level=logging.DEBUG)


# Accepted values for the `diversity_type` argument of the selection functions:
# "sum" aggregates the distances to the already selected items, "min" uses the smallest one.
MMR_DIVERSITY_TYPE_SUM = "sum"
MMR_DIVERSITY_TYPE_MIN = "min"


def partition_clustering(
        x: np.ndarray,
        n_partitions: int,
        batch_size: int = 1024,
        random_state: int = 321) -> Dict:
    """
    Split the data into partitions with mini-batch k-means

    Clusters that end up without any point are dropped (a warning is logged),
    so the result may describe fewer than n_partitions partitions.

    :param x: data to cluster, 2D array (n items, d features)
    :param n_partitions: number of clusters requested from MiniBatchKMeans
    :param batch_size: batch size for MiniBatchKMeans
    :param random_state: random state for MiniBatchKMeans
    :return: dictionary with keys "partitions" (list of 1D index arrays),
        "radii" (list, largest point-to-center distance of each partition) and
        "centers" (2D array, one row per kept partition)
    """
    model = MiniBatchKMeans(n_clusters=n_partitions, batch_size=batch_size, random_state=random_state)
    model.fit(x)
    labels = model.labels_
    centers = model.cluster_centers_

    keep = np.ones(n_partitions, dtype=bool)  # False marks clusters without points
    partitions = []
    radii = []
    for cluster_id in range(n_partitions):
        members = np.flatnonzero(labels == cluster_id)
        if members.size == 0:
            keep[cluster_id] = False
            continue
        partitions.append(members)
        center_row = centers[cluster_id, :].reshape(1, -1)  # pairwise_distances expects 2D input
        dist_to_center = pairwise_distances(x[members, :], center_row, metric="euclidean")
        radii.append(np.max(dist_to_center))

    if len(partitions) < n_partitions:
        logger.warning(f"empty clusters: {np.where(~keep)[0]}")
        # keep the rows of centers aligned with the compacted partitions / radii lists
        centers = centers[keep]

    return {
        "partitions": partitions,
        "radii": radii,
        "centers": centers,
    }


def partition_random(x: np.ndarray, n_partitions: int, random_state: int = 321) -> Dict:
    """
    Partition data randomly

    Every item receives a random partition label. When there are fewer items
    than partitions some labels stay unused; such empty partitions are dropped
    (a warning is logged), so that every returned partition has at least one
    item and a well-defined center.

    :param x: data to partition, 2D array (n items, d features)
    :param n_partitions: number of partitions, must be at least 1
    :param random_state: seed for the random label assignment
    :return: dictionary with keys "partitions" (list of 1D index arrays) and
        "centers" (2D array, mean of the items of each returned partition)
    :raises ValueError: if n_partitions is smaller than 1
    """
    if n_partitions < 1:
        raise ValueError(f"invalid n_partitions: {n_partitions}, should be at least 1")

    n_items = x.shape[0]
    rng = np.random.default_rng(random_state)
    # repeat each label often enough to cover all items, shuffle, then cut to n_items
    cluster_size = n_items // n_partitions + 1
    labels = np.repeat(np.arange(n_partitions), cluster_size)
    rng.shuffle(labels)
    labels = labels[:n_items]

    partitions = []
    centers = np.zeros((n_partitions, x.shape[1]))
    non_empty = np.full(n_partitions, True)
    for c in range(n_partitions):
        idx_in_cluster_c = np.flatnonzero(labels == c)
        if idx_in_cluster_c.size:
            partitions.append(idx_in_cluster_c)
            centers[c, :] = np.mean(x[idx_in_cluster_c, :], axis=0)
        else:
            # the mean of an empty slice would be NaN, so the partition is skipped instead
            non_empty[c] = False

    if len(partitions) < n_partitions:
        logging.getLogger(__name__).warning(f"empty partitions: {np.where(~non_empty)[0]}")
        # keep the rows of centers aligned with the compacted partitions list
        centers = centers[non_empty]

    result = {
        "partitions": partitions,  # each element is a 1D array with indexes of points in the partition
        "centers": centers  # 2D array (n returned partitions, d features)
    }
    return result


def greedy_selection(
        x: np.ndarray, q: np.ndarray, k: int, weight_q: float, weight_d: Optional[float] = None, diversity_type: str = MMR_DIVERSITY_TYPE_SUM
) -> List:
    """
    Greedy Selection method, similar to Maximal Marginal Relevance (MMR) selection

    The first pick is the item with the highest quality score. Every following pick
    maximizes weight_q * quality + weight_d * diversity, where diversity is computed
    from the Euclidean distances to the items selected so far.

    :param x: data to select from, 2D array (n items, d features)
    :param q: quality scores, 1D array: n items
    :param k: number of items to select, clipped to the number of items
    :param weight_q: weighting for quality (e.g., lambda in MMR), within [0, 1]
    :param weight_d: weighting for diversity (e.g., 1-lambda in MMR), within [0, 1];
        if None, weight_d = 1 - weight_q
    :param diversity_type: defining diversity using sum or min
    :return: list of selected item indices, in selection order
    :raises ValueError: on weights outside [0, 1], inconsistent shapes of x and q,
        empty data, or an unknown diversity type
    """
    if weight_d is None:
        weight_d = 1 - weight_q
    if weight_q < 0 or weight_q > 1:
        raise ValueError(f"invalid weight_q: {weight_q}, should be within [0, 1]")
    if weight_d < 0 or weight_d > 1:
        raise ValueError(f"invalid weight_d: {weight_d}, should be within [0, 1]")
    if x.shape[0] != q.shape[0]:
        raise ValueError(f"inconsistent shapes: x: {x.shape}, q: {q.shape}")
    if x.shape[0] == 0:
        raise ValueError("empty data")
    k = min(k, x.shape[0])  # cannot select more items than available

    # Penalty term preventing an item from being selected twice. It is always a float
    # array: with an integer q, zeros_like(q) could not hold -inf.
    t = np.zeros(q.shape, dtype=float)

    selected_docs = []
    for i in range(k):
        if selected_docs:
            s = x[selected_docs, :]  # selected items in rows, d features in columns
            d_xs = pairwise_distances(x, s, metric="euclidean")  # (n all items) x (n selected items)
            if diversity_type == MMR_DIVERSITY_TYPE_SUM:
                diversity = np.sum(d_xs, axis=1) / (i + 1)
            elif diversity_type == MMR_DIVERSITY_TYPE_MIN:
                diversity = np.min(d_xs, axis=1)  # (n all items)
            else:
                raise ValueError(f"invalid diversity type: {diversity_type}")
            mmr_scores = t + weight_q * q + weight_d * diversity
        else:
            mmr_scores = q  # start with the item with maximum quality score

        selected_idx = np.argmax(mmr_scores)
        t[selected_idx] = -np.inf  # do not select the same item twice
        selected_docs.append(selected_idx)

    return selected_docs


def _mmr_wrapper(partition_id: int, x: np.ndarray, q: np.ndarray, k: int, lambda_param: float, diversity_type: str) -> Tuple:
    """
    Run greedy selection on one partition and tag the result with the partition ID,
    so that results of parallel calls can be matched back to their partitions

    :param partition_id: partition ID, returned unchanged
    :param x: data from current partition, 2D array (n items, d features)
    :param q: quality scores of the items in x, 1D array: n items
    :param k: number of items to select
    :param lambda_param: weight of quality; the diversity weight defaults to 1 - lambda_param
    :param diversity_type: defining diversity using sum or min
    :return: partition ID, and list of selected item indices (w.r.t. current data)
    """
    picked = greedy_selection(x, q, k, weight_q=lambda_param, diversity_type=diversity_type)
    return partition_id, picked


def select_partitions_mmr(
        q: np.ndarray,
        partition_results: Dict,
        m: int,
        lambda_part: float = 0.9,
        diversity_type: str = MMR_DIVERSITY_TYPE_SUM
) -> List:
    """
    Use MMR to select partitions (clusters)

    Every partition is scored with the median quality of its items and represented by
    its center. Empty partitions are excluded from the selection: the median of an
    empty selection is NaN, and np.argmax would pick a NaN score first.

    :param q: quality scores, 1D array (n items)
    :param partition_results: dictionary with partitioning results, uses the keys
        "partitions" and "centers"
    :param m: number of partitions (clusters) to select
    :param lambda_part: balance quality and diversity in selecting partitions (clusters)
    :param diversity_type: defining diversity using sum or min
    :return: list of selected partition indices (positions in partition_results["partitions"])
    """
    partitions, centers = partition_results["partitions"], partition_results["centers"]

    # positions of the partitions that hold at least one item
    candidates = np.flatnonzero([len(partition_indices) > 0 for partition_indices in partitions])
    cluster_scores = np.array([np.median(q[partitions[i]]) for i in candidates])

    selected = greedy_selection(centers[candidates], cluster_scores, m,
                                weight_q=lambda_part, diversity_type=diversity_type)
    # greedy_selection indexes into the candidate list; map back to the original positions
    return [candidates[i] for i in selected]


def select_within_partitions(
    x: np.ndarray,
    q: np.ndarray,
    partitions: List,
    k: int,
    lambda_within: float = 0.9,
    n_jobs: int = 10,
    add_max_q_to_union: bool = True,
    diversity_type: str = MMR_DIVERSITY_TYPE_SUM
) -> Set:
    """
    Run MMR separately inside every partition and merge the picks into one set

    :param x: full dataset, 2D array: (n items, d features)
    :param q: quality scores (full data), 1D array: n items
    :param partitions: list, where each element is array of indices of items in a partition
    :param k: number of items to select in each partition
    :param lambda_within: balance quality and diversity in selecting points within each partition (cluster)
    :param n_jobs: number of MMR jobs to run in parallel
    :param add_max_q_to_union: if True, the item with the highest quality score is always part of the result
    :param diversity_type: defining diversity using sum or min
    :return: set of indices (w.r.t. the full data) of selected items
    """
    jobs = (
        delayed(_mmr_wrapper)(pid, x[members, :], q[members], k, lambda_within, diversity_type)
        for pid, members in enumerate(partitions)
    )
    per_partition = Parallel(n_jobs=n_jobs)(jobs)

    union = set()
    for pid, local_picks in per_partition:
        # local_picks index into the partition; translate them to indices of the full data
        global_picks = partitions[pid][local_picks]
        union |= set(global_picks)

    if add_max_q_to_union:
        # adding to a set is a no-op when the item is already present
        union.add(q.argmax())

    return union


def get_full_partitions(
        x: np.ndarray,
        partition_results: Dict,
        selected_cluster_idx: List,
        by_radius: bool = False
) -> List:
    """
    For each selected partition, find the closest points to partition centroid

    With by_radius=True a partition consists of all points whose Euclidean distance to
    the center does not exceed the stored radius. Otherwise it consists of the
    desired_size = n_items // n_centers + 1 points closest to the center; if that
    size reaches the number of items, the partition simply contains every item.

    :param x: full dataset, 2D array (n items, d features)
    :param partition_results: dictionary with partitioning results, uses the key "centers"
        and, when by_radius is set, the key "radii"
    :param selected_cluster_idx: list of selected partition indices
    :param by_radius: data selection based on partition radius VS desired size
    :return: list, each element is a 1D array with indexes of points in the partition
    """
    centers = partition_results["centers"]
    n_items = x.shape[0]
    partitions = []

    if by_radius:
        # select all points within estimated partition radius
        assert "radii" in partition_results, "cannot finalize by radius without pre-computed radii"
        radii = partition_results["radii"]
        for c in selected_cluster_idx:
            mu = centers[c, :].reshape(1, -1)  # reshape to 2D array
            distances = pairwise_distances(x, mu, metric="euclidean").ravel()
            partitions.append(np.where(distances <= radii[c])[0])
        return partitions

    # select top <desired_size> points closest to the center
    desired_size = n_items // centers.shape[0] + 1
    for c in selected_cluster_idx:
        if desired_size >= n_items:
            # np.argpartition requires kth < n_items; here every item belongs to the partition
            partitions.append(np.arange(n_items))
            continue
        mu = centers[c, :].reshape(1, -1)  # reshape to 2D array
        distances = pairwise_distances(x, mu, metric="euclidean").ravel()
        partitions.append(np.argpartition(distances, desired_size)[:desired_size])

    return partitions

def mmr_on_partitions(
        x: np.ndarray,
        q: np.ndarray,
        k: int,
        top_q: Optional[Set],
        lambda_param: float,
        partitions: List,
        n_jobs: int = 10,
        k_within: Optional[int] = None,
        add_max_q_to_union: bool = True,
        diversity_type: str = MMR_DIVERSITY_TYPE_SUM
) -> Tuple:
    """
    Uses MMR to select points within each selected partition.
    Takes a union of selected points (plus the items in top_q).
    Applies MMR to the union to produce final result.

    :param x: dataset, 2D array: (n items, d features)
    :param q: quality scores, 1D array: n items
    :param k: number of items to select
    :param top_q: indices of top quality items in [x] that are always added to the union;
        a set or any other iterable of indices, None adds nothing
    :param lambda_param: balance between quality and diversity of selection
    :param partitions: list, each element is a 1D array with indexes of points in the partition
    :param n_jobs: number of MMR jobs to run in parallel
    :param k_within: number of items to select within each partition (cluster), defaults to k
    :param add_max_q_to_union: always add the max quality point to the union
    :param diversity_type: either MMR_DIVERSITY_TYPE_SUM or MMR_DIVERSITY_TYPE_MIN
    :return: tuple of (list of selected item indices w.r.t. x, diagnostics dictionary with
        the keys "time_within", "time_union" in seconds and "size_union")
    """
    if k_within is None:
        k_within = k

    # perf_counter is a monotonic clock, suited for measuring durations
    start_time = time.perf_counter()
    collected_indices = select_within_partitions(x, q, partitions, k_within,
                                                 lambda_param, n_jobs, add_max_q_to_union, diversity_type)
    time_within = time.perf_counter() - start_time

    if top_q is not None:
        collected_indices.update(top_q)
    collected_indices = list(collected_indices)

    start_time = time.perf_counter()
    final_selection = greedy_selection(x[collected_indices, :], q[collected_indices], k,
                                       weight_q=lambda_param, weight_d=1 - lambda_param,
                                       diversity_type=diversity_type)
    # greedy_selection returns positions within the union; map them back to indices of x
    final_result = [collected_indices[i] for i in final_selection]
    time_union = time.perf_counter() - start_time

    diagnostics = {"time_within": time_within, "time_union": time_union, "size_union": len(collected_indices)}
    return final_result, diagnostics


def partition_clustering_by_faiss(
        x: np.ndarray,
        n_partitions: int,
        niter: int = 50,
        nredo: int = 4
) -> Dict:
    """
    Performs k-means clustering using FAISS to partition the input data.

    Clusters that end up without any point are dropped together with their centroid
    (a warning is logged), so the result may describe fewer than n_partitions
    partitions and every returned partition is non-empty.

    :param x: data to cluster, 2D array (n items, d features); a float32 C-contiguous
        copy is made when x is not already in that layout
    :param n_partitions: number of clusters requested from FAISS k-means
    :param niter: number of k-means iterations
    :param nredo: number of k-means restarts
    :return: dictionary with keys "partitions" (list of 1D index arrays) and
        "centers" (2D array, one row per returned partition)
    """
    x32 = np.ascontiguousarray(x, dtype=np.float32)

    kmeans = faiss.Kmeans(x32.shape[1], n_partitions, niter=niter, verbose=False, nredo=nredo)
    kmeans.train(x32)

    # get the label assignment: nearest centroid of every item
    _, cluster_assignment = kmeans.index.search(x32, 1)
    labels = cluster_assignment.ravel()

    # Group item indexes by label without a Python loop over the items:
    # a stable sort keeps the indexes inside every cluster in ascending order.
    order = np.argsort(labels, kind="stable")
    counts = np.bincount(labels, minlength=n_partitions)
    groups = np.split(order, np.cumsum(counts)[:-1])

    non_empty = counts > 0
    partitions = [group for group, keep in zip(groups, non_empty) if keep]
    centers = kmeans.centroids
    if len(partitions) < n_partitions:
        logger.warning(f"empty clusters: {np.where(~non_empty)[0]}")
        # keep the rows of centers aligned with the compacted partitions list
        centers = centers[non_empty]

    partition_results = {
        "partitions": partitions,  # each element is a 1D array with indexes of points in the partition
        "centers": centers  # 2D array (n returned partitions, d features)
    }
    return partition_results

def muss(
        x: np.ndarray,
        q: np.ndarray,
        k: int,
        k_within: int,
        m: int,
        n_partitions: int,
        lamb: float,
        lamb_c: float,
        n_jobs: int = 10,
        diversity_type: str = MMR_DIVERSITY_TYPE_SUM
) -> Tuple:
    """
    Run MUSS Algorithm 2 from the paper

    Steps: cluster the data with FAISS k-means, select m clusters with MMR, select
    k_within items inside every selected cluster, and run a final MMR over the union
    of these items and the k items with the highest quality scores.

    :param x: dataset, 2D array: (n items, d features)
    :param q: quality scores, 1D array: n items
    :param k: number of items to select
    :param k_within: number of items to select within each cluster
    :param m: number of partitions to be selected
    :param n_partitions: number of partitions to be clustered
    :param lamb: balance between quality and diversity of selection at item level
    :param lamb_c: balance between quality and diversity of selection at cluster level
    :param n_jobs: number of MMR jobs to run in parallel
    :param diversity_type: either MMR_DIVERSITY_TYPE_SUM or MMR_DIVERSITY_TYPE_MIN
    :return: list of selected item indices, diagnosis dictionary of running times
    """
    start_time = time.time()
    cluster_out = partition_clustering_by_faiss(x, n_partitions=n_partitions)
    clustering_time = np.round(time.time() - start_time, 2)

    selected_partition = select_partitions_mmr(q, partition_results=cluster_out, m=m,
                                               lambda_part=lamb_c, diversity_type=diversity_type)

    # index arrays of the selected clusters; a cluster without items cannot be processed
    # by the within-partition selection, so it is skipped
    selected_partition_idx = [cluster_out["partitions"][i] for i in selected_partition
                              if len(cluster_out["partitions"][i]) > 0]

    # mmr_on_partitions requires the set of top quality items (at most k of them)
    n_top = min(max(k, 0), q.shape[0])
    top_q = set(np.argsort(q)[::-1][:n_top])

    selected_idx, diagnostics = mmr_on_partitions(x, q, k=k, top_q=top_q, k_within=k_within,
                                                  lambda_param=lamb, n_jobs=n_jobs,
                                                  partitions=selected_partition_idx,
                                                  diversity_type=diversity_type)

    diagnostics["clustering time"] = clustering_time
    return selected_idx, diagnostics
