import numpy as np
import time
from models.fused_ann import FusedANN
from models.base_ann import BaseANN
from datasets.dataset_loader import DatasetLoader

def run_ablation_study(top_k=10, ground_truth=None):
    """Run the (alpha, beta) ablation for FusedANN on SIFT1M and print Recall@k.

    For each (alpha, beta) pair a fresh ``FusedANN`` is built over the base
    vectors and attributes, every query is issued one at a time, and the mean
    recall over all queries is printed.

    Args:
        top_k: Number of neighbours requested per query, and the ``k`` in
            Recall@k. Must be a positive integer.
        ground_truth: Optional sequence aligned with the query set, where
            ``ground_truth[i]`` holds the ids of the true neighbours of query
            ``i`` (only the first ``top_k`` entries are used). The ids are
            assumed to be comparable with what ``FusedANN.query_single``
            returns. When omitted, the ids ``0 .. top_k - 1`` are used for
            every query as a stand-in, so the printed recall is NOT a real
            quality measure; a ``RuntimeWarning`` is emitted in that case.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    import warnings  # local import: only needed for the placeholder warning

    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

    placeholder_truth = None
    if ground_truth is None:
        warnings.warn(
            "run_ablation_study: no ground_truth supplied; recall is measured "
            "against the placeholder ids 0..top_k-1 and is not meaningful.",
            RuntimeWarning,
            stacklevel=2,
        )
        placeholder_truth = np.arange(top_k)

    vectors, attributes = DatasetLoader.load_dataset("SIFT1M")
    query_vectors, query_attrs = DatasetLoader.load_dataset("SIFT1M_query")

    alpha_beta_pairs = [(10, 2), (20, 2), (10, 4), (20, 4)]
    for alpha, beta in alpha_beta_pairs:
        # The previous code instantiated `SimpleANN`, a name this file never
        # imports or defines, so it raised NameError on the first iteration.
        # NOTE(review): `BaseANN` is the ANN class this file imports; confirm
        # it is concrete and is the intended underlying index for FusedANN.
        ann_algorithm = BaseANN()
        # A new index is built per pair because alpha/beta are constructor
        # arguments of FusedANN.
        fused_ann = FusedANN(ann_algorithm, alpha, beta)
        fused_ann.build_single(vectors, attributes)

        recalls = []
        for i, query_vector in enumerate(query_vectors):
            results = fused_ann.query_single(
                query_vector, query_attrs[i], top_k=top_k
            )
            if ground_truth is None:
                true_neighbors = placeholder_truth
            else:
                true_neighbors = ground_truth[i][:top_k]

            relevant = set(true_neighbors)
            if not relevant:
                # Recall is undefined when a query has no true neighbours;
                # leave it out of the average instead of dividing by zero.
                continue
            hits = relevant.intersection(results)
            recalls.append(len(hits) / len(relevant))

        # np.mean([]) would return nan *and* emit a numpy RuntimeWarning;
        # report nan directly when nothing was scored.
        mean_recall = np.mean(recalls) if recalls else float("nan")
        print(f"Ablation (alpha={alpha}, beta={beta}) - Recall@{top_k}: {mean_recall}")

# Script entry point: run the ablation study only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_ablation_study()