import logging
import numpy as np
import faiss
from typing import List, Tuple
from tqdm import trange
from sklearn.neighbors import KDTree
from sklearn.cluster import KMeans
from loguru import logger
def find_positive_negative_indices(
    dataset, 
    dataset_name: str, 
    corpus_emb_1: np.ndarray, 
    corpus_emb_2: np.ndarray, 
    query_emb_1: np.ndarray, 
    query_emb_2: np.ndarray
) -> Tuple[List[Tuple[int, int, int]], Tuple[List[int], List[int], List[int]], List[int]]:
    """Build (query, positive, negative) index triplets from ``dataset.qrels``.

    For every query in ``dataset.qrels`` the first judged document (mapped
    through ``dataset.corpus_ids2index``) is used as the positive. A negative
    is drawn uniformly at random from the corpus and accepted only if it is
    not one of the query's judged documents and is at least as far from the
    query as the positive under *both* embedding spaces.

    Args:
        dataset: Object exposing ``qrels`` (query id -> mapping whose keys are
            doc ids) and ``corpus_ids2index`` (doc id -> corpus row index).
        dataset_name: Not used by this function.
        corpus_emb_1, corpus_emb_2: Corpus embeddings indexed by corpus row.
        query_emb_1, query_emb_2: Query embeddings. NOTE(review): row ``i`` is
            assumed to correspond to the ``i``-th key of ``dataset.qrels`` —
            confirm against the caller.

    Returns:
        ``(triplets, (q_indices, p_indices, n_indices), mask_list)``, where
        ``mask_list`` holds the qrels positions of skipped queries: queries
        with a judged doc id missing from ``corpus_ids2index`` or with no
        judged documents at all. If no acceptable negative is found within
        101 random draws, the positive index is reused as the negative and a
        warning is logged.
    """
    max_draws = 101  # number of random draws before falling back
    qrel_dict = dataset.qrels
    corpus_size = len(corpus_emb_1)

    # (query index, first positive index, set of all positive indices)
    q_p_index_list = []
    mask_list = []
    for q_index, query_id in enumerate(qrel_dict):
        pos_id_list = list(qrel_dict[query_id].keys())
        try:
            pos_index_list = [dataset.corpus_ids2index[pos_id] for pos_id in pos_id_list]
        except KeyError:
            logger.debug(f"KeyError: {pos_id_list}")
            mask_list.append(q_index)
            continue
        if not pos_index_list:
            # No judged document -> no positive to anchor a triplet on.
            mask_list.append(q_index)
            continue
        q_p_index_list.append((q_index, pos_index_list[0], set(pos_index_list)))

    q_p_n_index_list = []
    for q_index, p_index, pos_set in q_p_index_list:
        q_vec_1 = query_emb_1[q_index]
        q_vec_2 = query_emb_2[q_index]
        q_p_distance_emb1 = np.linalg.norm(q_vec_1 - corpus_emb_1[p_index])
        q_p_distance_emb2 = np.linalg.norm(q_vec_2 - corpus_emb_2[p_index])
        for _ in range(max_draws):
            candidate = np.random.randint(0, corpus_size)
            # A judged document of this query can never serve as its negative
            # (the positive itself would otherwise pass with equal distance).
            if candidate in pos_set:
                continue
            q_n_distance_emb1 = np.linalg.norm(q_vec_1 - corpus_emb_1[candidate])
            q_n_distance_emb2 = np.linalg.norm(q_vec_2 - corpus_emb_2[candidate])
            if q_n_distance_emb1 >= q_p_distance_emb1 and q_n_distance_emb2 >= q_p_distance_emb2:
                n_index = candidate
                break
        else:
            logger.warning(f"Failed to find a negative sample for {q_index} and {p_index}")
            n_index = p_index
        q_p_n_index_list.append((q_index, p_index, n_index))

    q_index_list = [triplet[0] for triplet in q_p_n_index_list]
    p_index_list = [triplet[1] for triplet in q_p_n_index_list]
    n_index_list = [triplet[2] for triplet in q_p_n_index_list]
    return q_p_n_index_list, (q_index_list, p_index_list, n_index_list), mask_list
def split_data(
    corpus_size: int, 
    p_index_list: List[int], 
    d0_ratio: float = 1/3, 
    corpus_emb_2: np.ndarray = None, 
    query_emb_2: np.ndarray = None, 
    strategy: str = "random"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Partition ``range(corpus_size)`` into three disjoint index sets d0, d1, d2.

    The first half of ``p_index_list`` always goes to d1 and the second half to
    d2; positives are never placed in d0. d0 receives
    ``int(corpus_size * d0_ratio) - len(set(p_index_list))`` non-positive
    indices (clamped to ``[0, #non-positives]``). The remaining non-positives
    are shuffled and used to fill d1 and d2 to roughly equal total size.

    Args:
        corpus_size: Number of corpus rows.
        p_index_list: Positive corpus indices.
        d0_ratio: Fraction of the corpus used to size d0 (see above).
        corpus_emb_2, query_emb_2: Required only for ``strategy="closest_to_p"``.
        strategy: ``"random"`` shuffles the non-positives before taking d0.
            ``"closest_to_p"`` puts the non-positive 5-nearest L2 neighbours
            (via faiss) of every query first. Any other value keeps the
            non-positives in ascending index order.

    Returns:
        ``(d0, d1, d2)`` as sorted int64 arrays.

    Raises:
        ValueError: ``strategy="closest_to_p"`` without both embedding arrays.
    """
    p_index_set = set(p_index_list)
    # Positives are assigned to d1/d2 explicitly below, so they must be kept
    # out of the pool that feeds d0 and the random d1/d2 filler.
    non_p_indices = [i for i in range(corpus_size) if i not in p_index_set]
    p_index_mid = len(p_index_list) // 2
    p_index_list_1 = p_index_list[:p_index_mid]
    p_index_list_2 = p_index_list[p_index_mid:]

    # Clamp at 0: a negative size would slice from the END of the list.
    d0_target_size = int(corpus_size * d0_ratio) - len(p_index_set)
    d0_target_size = max(0, min(d0_target_size, len(non_p_indices)))

    if strategy == "random":
        np.random.shuffle(non_p_indices)
    elif strategy == "closest_to_p":
        if corpus_emb_2 is None or query_emb_2 is None:
            raise ValueError("strategy='closest_to_p' requires corpus_emb_2 and query_emb_2")
        index = faiss.IndexFlatL2(corpus_emb_2.shape[1])
        index.add(np.ascontiguousarray(corpus_emb_2, dtype=np.float32))
        _, neighbours = index.search(np.ascontiguousarray(query_emb_2, dtype=np.float32), 5)
        non_p_set = set(non_p_indices)
        # dict.fromkeys de-duplicates while keeping first-seen order; several
        # queries can share neighbours, and duplicates would leak across splits.
        nearest = dict.fromkeys(int(i) for i in neighbours.ravel() if int(i) in non_p_set)
        non_p_indices = list(nearest) + [i for i in non_p_indices if i not in nearest]

    d0 = non_p_indices[:d0_target_size]
    remaining_list = non_p_indices[d0_target_size:]
    np.random.shuffle(remaining_list)

    d1_target_size = (len(remaining_list) + len(p_index_list)) // 2
    n_fill_d1 = d1_target_size - len(p_index_list_1)  # always >= 0
    d1 = set(p_index_list_1) | set(remaining_list[:n_fill_d1])
    d2 = set(p_index_list_2) | set(remaining_list[n_fill_d1:])
    return (
        np.array(sorted(d0), dtype=np.int64),
        np.array(sorted(d1), dtype=np.int64),
        np.array(sorted(d2), dtype=np.int64),
    )
def hierarchical_kmeans_sampling(
    corpus_emb: np.ndarray, 
    p_index_list: List[int], 
    d0_ratio: float, 
    layer: int = 3, 
    branch_num: int = 3, 
    query_emb_2: np.ndarray = None, 
    corpus_emb_2: np.ndarray = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split corpus indices into (d0, d1, d2) via recursive k-means.

    Non-positive indices are clustered with ``KMeans(n_clusters=branch_num)``
    recursively down to depth ``layer``. Inside each leaf cluster a
    ``d0_ratio`` share goes to d0 and the rest is halved between d1 and d2.
    Any cluster with ``<= branch_num`` members goes entirely to d0. The first
    half of ``p_index_list`` is then added to d1 and the second half to d2;
    positives never enter d0.

    ``query_emb_2`` and ``corpus_emb_2`` are not used by this function.

    Returns:
        ``(d0, d1, d2)`` as int64 arrays. d0 is in cluster-traversal order;
        d1 and d2 are de-duplicated and sorted.
    """
    total_indices = np.arange(len(corpus_emb))
    p_index_set = set(p_index_list)
    non_p_indices = sorted(set(total_indices) - p_index_set)
    p_index_mid = len(p_index_list) // 2
    p_index_list_1 = p_index_list[:p_index_mid]
    p_index_list_2 = p_index_list[p_index_mid:]

    def recursive_cluster(indices, current_layer):
        # Always returns three plain lists so callers can concatenate safely.
        indices = np.asarray(indices, dtype=np.int64)
        if len(indices) <= branch_num:
            # Too small to split into branch_num clusters.
            return list(indices), [], []
        if current_layer >= layer:
            shuffled = indices.copy()
            np.random.shuffle(shuffled)
            d0_size = int(len(shuffled) * d0_ratio)
            remaining = shuffled[d0_size:]
            mid = len(remaining) // 2
            return list(shuffled[:d0_size]), list(remaining[:mid]), list(remaining[mid:])
        kmeans = KMeans(n_clusters=branch_num, random_state=0)
        cluster_labels = kmeans.fit_predict(corpus_emb[indices])
        d0_list, d1_list, d2_list = [], [], []
        for cluster_id in range(branch_num):
            d0_sub, d1_sub, d2_sub = recursive_cluster(
                indices[cluster_labels == cluster_id], current_layer + 1
            )
            d0_list.extend(d0_sub)
            d1_list.extend(d1_sub)
            d2_list.extend(d2_sub)
        return d0_list, d1_list, d2_list

    # Cluster only non-positives; positives are assigned to d1/d2 afterwards.
    d0_indices, d1_indices, d2_indices = recursive_cluster(non_p_indices, current_layer=0)
    d1_all = np.asarray(list(d1_indices) + list(p_index_list_1), dtype=np.int64)
    d2_all = np.asarray(list(d2_indices) + list(p_index_list_2), dtype=np.int64)
    return np.asarray(d0_indices, dtype=np.int64), np.unique(d1_all), np.unique(d2_all)
def fps_with_kdtree(
    corpus_emb: np.ndarray, 
    p_index_list: List[int], 
    d0_ratio: float, 
    layer: int = 3, 
    branch_num: int = 3, 
    query_emb_2: np.ndarray = None, 
    corpus_emb_2: np.ndarray = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split corpus indices into (d0, d1, d2) using farthest-point sampling for d0.

    d0 is built by farthest-point sampling (FPS) over the non-positive
    indices. FPS starts at a random point, then repeatedly adds the candidate
    whose Euclidean distance to its nearest already-selected point is largest.
    Its size is ``int(len(corpus_emb) * d0_ratio)``, clamped to
    ``[0, #non-positives]``. The unselected non-positives are shuffled and
    halved between d1 and d2. The first half of ``p_index_list`` is then added
    to d1 and the second half to d2.

    ``layer``, ``branch_num``, ``query_emb_2`` and ``corpus_emb_2`` are not used
    by this function.

    Returns:
        ``(d0, d1, d2)`` as int64 arrays. d0 is in FPS selection order; d1 and
        d2 are de-duplicated and sorted.
    """
    total_indices = np.arange(len(corpus_emb))
    p_index_set = set(p_index_list)
    non_p_indices = sorted(set(total_indices) - p_index_set)
    p_index_mid = len(p_index_list) // 2
    p_index_list_1 = p_index_list[:p_index_mid]
    p_index_list_2 = p_index_list[p_index_mid:]
    d0_target_size = int(len(total_indices) * d0_ratio)
    d0_target_size = max(0, min(d0_target_size, len(non_p_indices)))

    def fps_with_kdtree_inner(data, all_indices, sample_size):
        """FPS over ``all_indices``; returns (selected, rest_half_1, rest_half_2)."""
        candidates = np.asarray(all_indices, dtype=np.int64)
        selected_pos = []  # positions into ``candidates``, in selection order
        if sample_size > 0 and len(candidates) > 0:
            candidate_data = data[candidates]
            start = np.random.randint(len(candidates))
            selected_pos.append(start)
            # min_dist[j] = distance from candidate j to its nearest selected
            # point; updated incrementally instead of rebuilding a KD-tree.
            min_dist = np.linalg.norm(candidate_data - candidate_data[start], axis=1)
            min_dist[start] = -np.inf  # never re-select
            for _ in trange(1, min(sample_size, len(candidates))):
                farthest = int(np.argmax(min_dist))
                selected_pos.append(farthest)
                new_dist = np.linalg.norm(candidate_data - candidate_data[farthest], axis=1)
                np.minimum(min_dist, new_dist, out=min_dist)
                min_dist[farthest] = -np.inf
        is_selected = np.zeros(len(candidates), dtype=bool)
        is_selected[selected_pos] = True
        d0_indices = candidates[selected_pos]
        remaining = candidates[~is_selected]
        np.random.shuffle(remaining)
        mid = len(remaining) // 2
        return d0_indices, remaining[:mid], remaining[mid:]

    d0_indices, d1_indices, d2_indices = fps_with_kdtree_inner(corpus_emb, non_p_indices, d0_target_size)
    # Explicit int64 so an empty piece cannot turn the result into floats.
    d1_indices = np.concatenate([d1_indices, np.asarray(p_index_list_1, dtype=np.int64)])
    d2_indices = np.concatenate([d2_indices, np.asarray(p_index_list_2, dtype=np.int64)])
    return d0_indices, np.unique(d1_indices), np.unique(d2_indices)
