from models.base_ann import BaseANN
from utils.transform import Transformer

class FusedANN:
    """Wrap a base ANN index so vectors are transformed together with attributes.

    Every build/query call first runs its inputs through ``Transformer``
    (``single_transform`` or ``multi_transform``) and then hands the result to
    the wrapped ``base_ann`` object.
    """

    def __init__(self, base_ann: BaseANN, alpha=10.0, beta=2.0):
        """Store the wrapped index and the single-attribute parameters.

        Args:
            base_ann: Index object; its ``build_index`` and ``query`` methods
                are called by this class.
            alpha: Forwarded to ``Transformer.single_transform`` by the
                ``*_single`` methods.
            beta: Forwarded to ``Transformer.single_transform`` by the
                ``*_single`` methods.
        """
        self.base_ann = base_ann
        self.alpha = alpha
        self.beta = beta

    def build_single(self, vectors, attributes):
        """Transform ``vectors`` with one attribute set and build the index.

        Uses the ``alpha`` and ``beta`` stored on the instance.
        """
        fused_vectors = Transformer.single_transform(
            vectors,
            attributes,
            self.alpha,
            self.beta,
        )
        self.base_ann.build_index(fused_vectors)

    def query_single(self, vector, attribute, top_k=10):
        """Query the index with one vector and one attribute.

        ``vector`` and ``attribute`` are each reshaped to a single row with
        ``reshape(1, -1)``, so both must provide a ``reshape`` method.

        Returns:
            Whatever ``base_ann.query`` returns for the transformed query.
        """
        vector_row = vector.reshape(1, -1)
        attribute_row = attribute.reshape(1, -1)
        fused_batch = Transformer.single_transform(
            vector_row,
            attribute_row,
            self.alpha,
            self.beta,
        )
        # Only element 0 of the transformed result is used as the query.
        fused_query = fused_batch[0]
        return self.base_ann.query(fused_query, top_k)

    def build_multi(self, vectors, attributes_list, alpha_list, beta_list):
        """Transform ``vectors`` with several attribute sets and build the index.

        The instance-level ``alpha``/``beta`` are not used here; the
        caller-supplied ``alpha_list`` and ``beta_list`` are forwarded instead.
        """
        fused_vectors = Transformer.multi_transform(
            vectors,
            attributes_list,
            alpha_list,
            beta_list,
        )
        self.base_ann.build_index(fused_vectors)

    def query_multi(self, vector, attributes_list, alpha_list, beta_list, top_k=10):
        """Query the index with one vector and several attributes.

        ``vector`` is reshaped to a single row with ``reshape(1, -1)``.

        Returns:
            Whatever ``base_ann.query`` returns for the transformed query.
        """
        vector_row = vector.reshape(1, -1)
        # NOTE(review): unlike query_single, the attributes are forwarded as
        # given (no reshape) -- confirm multi_transform accepts them this way.
        fused_batch = Transformer.multi_transform(
            vector_row,
            attributes_list,
            alpha_list,
            beta_list,
        )
        # Only element 0 of the transformed result is used as the query.
        fused_query = fused_batch[0]
        return self.base_ann.query(fused_query, top_k)