
import os.path
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from sklearn.metrics.pairwise import cosine_similarity
from joblib import load, dump

# Load the SPLADE checkpoint and its tokenizer once, at import time; every
# call to encode_splade() below reuses these module-level objects.
# NOTE(review): from_pretrained may fetch the weights from the Hugging Face
# Hub on first use, so importing this module can hit the network.
model_name = "naver/splade-cocondenser-selfdistil"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The masked-LM head is used because SPLADE derives its term weights from the
# per-token vocabulary logits.
model = AutoModelForMaskedLM.from_pretrained(model_name)
# Switch to inference mode (e.g. disables dropout); gradient tracking is turned
# off separately with torch.no_grad() when encoding.
model.eval()

def encode_splade(texts):
	"""Encode texts into dense NumPy arrays of SPLADE lexical term weights.

	Args:
		texts: a string or a list of strings, passed straight to the tokenizer.

	Returns:
		A NumPy array of shape [batch, vocab]. Each entry is the max over the
		non-padding token positions of log(1 + ReLU(logit)) for that
		vocabulary term, so every weight is >= 0.
	"""
	if isinstance(texts, (list, tuple)) and len(texts) == 0:
		# Short-circuit an empty batch instead of feeding an empty tensor to
		# the tokenizer/model; keep the vocabulary width so callers can stack.
		return torch.zeros((0, model.config.vocab_size)).numpy()
	inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
	# Keep the inputs on the same device as the model (a no-op on CPU).
	inputs = inputs.to(model.device)
	with torch.no_grad():
		logits = model(**inputs).logits  # [batch, seq_len, vocab]
	# SPLADE activation: log(1 + ReLU(x)), which is always non-negative.
	activations = torch.log1p(torch.relu(logits))
	# Zero out padding positions so they cannot win the max-pool. With
	# padding=True, shorter texts in a batch would otherwise inherit term
	# weights from [PAD] logits. Because activations are >= 0, a masked 0
	# never beats a real token.
	mask = inputs["attention_mask"].unsqueeze(-1).to(activations.dtype)
	weights = (activations * mask).max(dim=1).values  # [batch, vocab]
	return weights.cpu().numpy()

class SPLADE:
	"""Match query sentences against a fixed corpus using SPLADE vectors.

	The corpus is encoded once at construction. Each query is encoded on
	demand and scored against every corpus row with cosine similarity.
	"""

	# Encoded corpus: one row per sentence given to __init__ (None until set).
	doc_vecs = None

	def __init__(self, sentences, threshold=0.1):
		"""Encode *sentences* as the searchable corpus.

		Args:
			sentences: corpus texts; returned indices are positions in this list.
			threshold: minimum cosine similarity a match must reach. SPLADE
				weights are non-negative, so scores fall in [0, 1].
		"""
		self.threshold = threshold
		self.doc_vecs = encode_splade(sentences)

	@staticmethod
	def _rank_matches(sims, threshold, top_k):
		"""Return up to *top_k* indices whose score is >= *threshold*, best first.

		The sort is stable, so ties keep ascending index order.
		"""
		if top_k <= 0:
			return []
		ranked = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)
		matches = []
		for idx in ranked:
			# Scores are in descending order, so the first miss ends the search.
			if sims[idx] < threshold:
				break
			matches.append(idx)
			if len(matches) == top_k:
				break
		return matches

	def attribute_alreadyHaveTheQuotes(self, sentence, top_k=3):
		"""Find the corpus sentences most similar to *sentence*.

		Args:
			sentence: the query text.
			top_k: maximum number of matches to collect before slicing.

		Returns:
			A 3-tuple ``([], best1, best2)``. ``best1`` and ``best2`` hold the
			first one and the first two corpus indices (best first) whose
			similarity is at least ``self.threshold``. Either list may be
			shorter than that, or empty.
		"""
		# NOTE(review): the fixed 3-tuple shape with an always-empty first slot
		# looks like an interface shared with other attributors. Only the top
		# two matches are ever returned, so top_k > 2 has no visible effect.
		# Confirm against callers before changing it.
		if self.doc_vecs is None or len(self.doc_vecs) == 0:
			# Nothing to match against; skip encoding the query entirely.
			return [], [], []
		query_vec = encode_splade([sentence])
		sims = cosine_similarity(query_vec, self.doc_vecs)[0]
		matches = self._rank_matches(sims, self.threshold, top_k)
		return [], matches[:1], matches[:2]


