import numpy as np
import time
from api.build_index import build_fused_index
from api.query_index import query_multi_filter
from models.base_ann import BaseANN
from datasets.dataset_loader import DatasetLoader

def run_multi_attribute_experiment():
    """Run the multi-attribute filtering benchmark and print its summary.

    Builds a fused index for "SIFT1M", then issues one
    ``query_multi_filter`` call per query vector loaded from
    "SIFT1M_query". Each query is filtered on two attribute groups taken
    from columns 0-15 and 16-31 of the query attribute array, with one
    alpha and one beta value per group. The mean recall@10 and the mean
    per-query latency (in milliseconds) are printed to stdout.

    NOTE(review): recall is computed against the same ids 0..9 for every
    query, which looks like a placeholder ground truth; the printed recall
    is only meaningful once real nearest-neighbor ids are plugged in.
    """
    # Single source of truth for k: the query size, the ground-truth size
    # and the recall denominator must all agree.
    top_k = 10

    # NOTE(review): SimpleANN is not among the imports visible at the top of
    # this file (only BaseANN is); confirm it is defined/imported elsewhere,
    # otherwise this line raises NameError.
    ann_algorithm = SimpleANN()

    # The loaded base vectors/attributes are not used below. The call is kept
    # in case loading has side effects the index build relies on -- TODO
    # confirm and drop if it does not.
    _vectors, _attrs = DatasetLoader.load_dataset("SIFT1M")
    fused_ann = build_fused_index("SIFT1M", ann_algorithm)

    query_vectors, query_attrs = DatasetLoader.load_dataset("SIFT1M_query")

    # One alpha/beta per attribute group, in the same order as the groups.
    alpha_list = [10.0, 20.0]
    beta_list = [2.0, 4.0]

    # Loop-invariant, so built once instead of once per query.
    true_neighbors = np.arange(top_k)

    recalls = []
    times = []
    for i, query_vector in enumerate(query_vectors):
        query_attributes_list = [query_attrs[i, :16], query_attrs[i, 16:32]]
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start = time.perf_counter()
        results = query_multi_filter(
            fused_ann,
            query_vector,
            query_attributes_list,
            alpha_list,
            beta_list,
            top_k=top_k,
        )
        times.append(time.perf_counter() - start)
        recall = len(set(results).intersection(true_neighbors)) / top_k
        recalls.append(recall)

    print("Multi-Attribute Filtering - Avg Recall@10:", np.mean(recalls))
    print("Multi-Attribute Filtering - Avg Query Time (ms):", np.mean(times)*1000)

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    run_multi_attribute_experiment()