import os
import numpy as np
import torch
from typing import Tuple, List, Optional, Dict, Any
from loguru import logger
from pathlib import Path
import argparse
from src.embeddings.data_types import EmbeddingDataset, EmbeddingPair
def find_positive_negative_indices(
    dataset,
    dataset_name,
    corpus_emb_1: np.ndarray,
    corpus_emb_2: np.ndarray,
    query_emb_1: np.ndarray,
    query_emb_2: np.ndarray,
    max_attempts: int = 100,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[List[Tuple[int, int, int]], Tuple[List[int], List[int], List[int]], List[int]]:
    """Build (query, positive, negative) index triplets for two embedding spaces.

    For every query in ``dataset.qrels`` the first listed positive is used as the
    positive. A negative is drawn uniformly at random from the corpus until one is
    found that is not a positive of that query and is at least as far from the
    query as the positive in BOTH embedding spaces (Euclidean distance).

    Args:
        dataset: Object exposing ``qrels`` (query id -> mapping whose keys are
            positive corpus ids) and ``corpus_ids2index`` (corpus id -> row index
            into the corpus embedding arrays).
        dataset_name: Currently unused; kept for interface compatibility.
        corpus_emb_1, corpus_emb_2: Corpus embeddings for the two models, indexed
            by row. Negatives are drawn from ``range(len(corpus_emb_1))``.
        query_emb_1, query_emb_2: Query embeddings; row ``i`` is taken to be the
            ``i``-th query in ``dataset.qrels`` iteration order (assumption — TODO
            confirm against the code that builds these arrays).
        max_attempts: Number of rejected draws tolerated before giving up on a
            query (``max_attempts + 1`` draws in total).
        rng: Optional NumPy ``Generator``; when ``None`` the global
            ``np.random`` state is used.

    Returns:
        ``(triplets, (q_indices, p_indices, n_indices), mask_list)`` where
        ``triplets`` is a list of ``(q_index, p_index, n_index)``, the middle tuple
        holds the same data column-wise, and ``mask_list`` contains the query rows
        that were skipped (unknown positive id or no positives at all).
        If no valid negative is found within the attempt budget, the triplet falls
        back to ``n_index == p_index`` and a warning is logged.
    """
    qrel_dict = dataset.qrels
    corpus_ids2index = dataset.corpus_ids2index

    # (query row, chosen positive row, set of all positive rows for that query)
    q_p_index_list = []
    mask_list = []
    for q_index, query_id in enumerate(qrel_dict):
        pos_id_list = list(qrel_dict[query_id].keys())
        try:
            pos_index_list = [corpus_ids2index[pos_id] for pos_id in pos_id_list]
        except KeyError:
            logger.debug(f"KeyError: {pos_id_list}")
            mask_list.append(q_index)
            continue
        if not pos_index_list:
            # Without this guard pos_index_list[0] would raise IndexError.
            logger.debug(f"No positives for query {query_id}")
            mask_list.append(q_index)
            continue
        q_p_index_list.append((q_index, pos_index_list[0], set(pos_index_list)))

    randint = rng.integers if rng is not None else np.random.randint
    n_corpus = len(corpus_emb_1)

    q_p_n_index_list = []
    for q_index, p_index, positive_rows in q_p_index_list:
        q_p_distance_emb1 = np.linalg.norm(query_emb_1[q_index] - corpus_emb_1[p_index])
        q_p_distance_emb2 = np.linalg.norm(query_emb_2[q_index] - corpus_emb_2[p_index])
        for _ in range(max_attempts + 1):
            candidate = int(randint(0, n_corpus))
            # A positive (including p_index itself) must never be used as a
            # negative: its distance ties with the positive and would be accepted.
            if candidate in positive_rows:
                continue
            q_n_distance_emb1 = np.linalg.norm(query_emb_1[q_index] - corpus_emb_1[candidate])
            q_n_distance_emb2 = np.linalg.norm(query_emb_2[q_index] - corpus_emb_2[candidate])
            # Reject "hard" negatives that are closer than the positive in either space.
            if q_p_distance_emb1 > q_n_distance_emb1 or q_p_distance_emb2 > q_n_distance_emb2:
                continue
            n_index = candidate
            break
        else:
            logger.warning(f"Failed to find a negative sample for {q_index} and {p_index}")
            n_index = p_index
        q_p_n_index_list.append((q_index, p_index, n_index))

    q_index_list = [q for q, _, _ in q_p_n_index_list]
    p_index_list = [p for _, p, _ in q_p_n_index_list]
    n_index_list = [n for _, _, n in q_p_n_index_list]
    return q_p_n_index_list, (q_index_list, p_index_list, n_index_list), mask_list
def preprocess_embeddings_and_indices(args: argparse.Namespace, dataset: EmbeddingDataset) -> EmbeddingDataset:
    """Attach per-query positive indices to ``dataset.pair`` and drop unusable queries.

    Calls ``find_positive_negative_indices`` on ``dataset.metadata`` and the
    embeddings in ``dataset.pair``. Queries reported in the returned mask list are
    removed from ``dataset.pair.query_emb_1`` / ``query_emb_2``. After this call,
    ``dataset.pair.p_index_list[i]`` is the positive corpus row for the ``i``-th
    remaining query row.

    Args:
        args: Parsed CLI arguments; ``args.test_dataset`` is forwarded as the
            dataset name.
        dataset: Dataset whose ``pair`` attribute is mutated in place.

    Returns:
        The same ``dataset`` object, after mutation.
    """
    _, (_, p_index_list, _), mask_list = find_positive_negative_indices(
        dataset.metadata,
        args.test_dataset,
        dataset.pair.corpus_emb_1,
        dataset.pair.corpus_emb_2,
        dataset.pair.query_emb_1,
        dataset.pair.query_emb_2,
    )
    if mask_list:
        logger.warning(f"remove {len(mask_list)} queries from query_emb_1 and query_emb_2")
        masked = set(mask_list)  # O(1) membership instead of scanning the list per row
        # Build the keep-mask from the same array that is filtered below
        # (the original read dataset.test here, unlike every other line).
        select_index = [i for i in range(len(dataset.pair.query_emb_1)) if i not in masked]
        dataset.pair.query_emb_1 = dataset.pair.query_emb_1[select_index]
        dataset.pair.query_emb_2 = dataset.pair.query_emb_2[select_index]
    dataset.pair.p_index_list = p_index_list
    return dataset
