import numpy as np
import faiss
import faiss.contrib.torch_utils
from prettytable import PrettyTable

def visualize_matches(query_path, ref_path, dataset_name, idx, matched=True, output_dir='./match_results'):
    """Save a side-by-side figure of a query image and a reference image.

    The figure is written to
    ``<output_dir>/<dataset_name>_query_<idx>_<matched|not_matched>.png``.
    ``output_dir`` is created if it does not already exist.

    Args:
        query_path: Path to the query image (shown on the left).
        ref_path: Path to the reference image (shown on the right).
        dataset_name: Prefix used in the output filename.
        idx: Query index used in the output filename.
        matched: Selects the ``matched`` / ``not_matched`` filename suffix.
        output_dir: Directory the figure is saved into.
    """
    import os

    import matplotlib.pyplot as plt
    from PIL import Image

    # Load both images fully into memory and release the file handles
    # immediately (Image.copy() forces the lazy decode before the file closes).
    with Image.open(query_path) as q_img, Image.open(ref_path) as r_img:
        query_img = q_img.copy()
        ref_img = r_img.copy()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    status = "matched" if matched else "not_matched"
    save_path = os.path.join(output_dir, f'{dataset_name}_query_{idx}_{status}.png')

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    try:
        for ax, img, title in ((axes[0], query_img, 'Query Image'),
                               (axes[1], ref_img, 'Reference Image')):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Close this specific figure (not just "the current one"), even when
        # saving fails, so repeated calls do not accumulate open figures.
        plt.close(fig)

    

def get_validation_recalls(r_list, q_list, r_paths, q_paths, k_values, gt, print_results=True, faiss_gpu=False, dataset_name='dataset without name ?', testing=False):
    """Compute Recall@K of queries against a reference database with a flat L2 FAISS index.

    Args:
        r_list: Reference descriptors, 2-D ``(num_refs, dim)``. Presumably
            float32 NumPy arrays or torch tensors (``faiss.contrib.torch_utils``
            is imported) -- TODO confirm with callers.
        q_list: Query descriptors with the same dimensionality as ``r_list``.
        r_paths: Reference image paths. Currently unused; kept for API compatibility.
        q_paths: Query image paths. Currently unused; kept for API compatibility.
        k_values: The K values to evaluate, e.g. ``[1, 5, 10]``. The search
            retrieves ``max(k_values)`` neighbours per query. Any order is accepted.
        gt: Ground truth; ``gt[i]`` is a collection of reference indices that
            count as a correct match for query ``i``.
        print_results: If True, print a table of Recall@K percentages.
        faiss_gpu: If True, build the index on GPU 0 with float16 storage.
        dataset_name: Name shown in the printed table title.
        testing: If True, skip recall computation and return the raw predictions.

    Returns:
        If ``testing`` is True, the predicted reference indices returned by
        ``faiss_index.search`` (one row of ``max(k_values)`` indices per query).
        Otherwise, a dict mapping each K in ``k_values`` to the fraction of
        queries (0..1) with at least one correct reference among their top-K
        predictions.
    """
    embed_size = r_list.shape[1]

    # Build an exact (brute-force) L2 index over the reference descriptors.
    if faiss_gpu:
        res = faiss.StandardGpuResources()
        flat_config = faiss.GpuIndexFlatConfig()
        flat_config.useFloat16 = True
        flat_config.device = 0
        faiss_index = faiss.GpuIndexFlatL2(res, embed_size, flat_config)
    else:
        faiss_index = faiss.IndexFlatL2(embed_size)

    # Add the references, then retrieve the max(k_values) nearest references per query.
    faiss_index.add(r_list)
    _, predictions = faiss_index.search(q_list, max(k_values))
    if testing:
        return predictions

    # For each query, find the rank of the first correct prediction. The query
    # counts as a hit for every K strictly greater than that rank. Comparing
    # against each K independently keeps the result correct even when k_values
    # is not sorted ascending.
    k_array = np.asarray(k_values)
    correct_at_k = np.zeros(len(k_values))
    for q_idx, pred in enumerate(predictions):
        hits = np.isin(pred, gt[q_idx])  # np.in1d is deprecated since NumPy 2.0
        if hits.any():
            first_hit_rank = int(np.argmax(hits))
            correct_at_k += first_hit_rank < k_array

    correct_at_k = correct_at_k / len(predictions)
    d = dict(zip(k_values, correct_at_k))

    if print_results:
        print()  # print a new line
        table = PrettyTable()
        table.field_names = ['K'] + [str(k) for k in k_values]
        table.add_row(['Recall@K'] + [f'{100 * v:.2f}' for v in correct_at_k])
        print(table.get_string(title=f"Performances on {dataset_name}"))

    return d

