

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from sklearn.metrics.pairwise import cosine_similarity

# Load the SPLADE tokenizer and masked-LM model once, at import time; the
# module-level `tokenizer` and `model` are read by encode_splade() below.
model_name = "naver/splade-cocondenser-selfdistil"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
# Put the model in evaluation mode (disables training-only behavior such as
# dropout) so repeated encodings of the same text are deterministic.
model.eval()

def encode_splade(texts):
	"""Encode texts into SPLADE sparse lexical vectors.

	Each text is tokenized and run through the masked-LM head; the per-token
	logits are turned into term weights with ``log(1 + relu(x))`` and then
	max-pooled over the sequence dimension.

	Args:
		texts: List of strings to encode as one padded batch.

	Returns:
		A NumPy array of shape ``[batch, vocab]`` holding one non-negative
		weight vector per input text.
	"""
	inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
	with torch.no_grad():
		logits = model(**inputs).logits  # [batch, seq_len, vocab]

	# ReLU + log(1 + x) saturation; every activation is >= 0.
	activations = torch.log1p(torch.relu(logits))

	# Zero out padded positions before pooling. Without the mask, logits
	# computed at [PAD] positions could win the max, making a text's vector
	# depend on how much padding the rest of the batch forced onto it.
	# Because activations are non-negative, a zero is neutral for the max.
	mask = inputs["attention_mask"].unsqueeze(-1).to(activations.dtype)
	weights = (activations * mask).max(dim=1).values  # [batch, vocab]
	return weights.cpu().numpy()

class SPLADE:
	"""Attribute a sentence to candidate quotes using SPLADE encodings.

	Quotes are ranked by the cosine similarity between the SPLADE vector of
	the sentence and the SPLADE vector of each quote.
	"""

	def __init__(self, threshold=0.1):
		"""Store the similarity cut-off.

		Args:
			threshold: Minimum cosine similarity (0 to 1) a quote must reach
				to be reported as a match.
		"""
		self.threshold = threshold

	def attribute(self, sentence, quotes, top_k=3):
		"""Find the quotes most similar to ``sentence``.

		Args:
			sentence: The text to attribute.
			quotes: List of candidate quote strings.
			top_k: Maximum number of matches to keep before building the
				result. Only the best two are ever returned, so values above
				2 do not change the output; values <= 0 yield no matches.

		Returns:
			A 3-tuple of lists of indices into ``quotes``:
			``([], best_one, best_two)``, where ``best_one`` holds at most
			the single best match and ``best_two`` at most the two best, in
			descending similarity order. The first element is always empty
			(NOTE(review): presumably kept for interface compatibility with
			callers -- confirm).
		"""
		# Nothing to rank: return early instead of pushing an empty batch
		# through the encoder and cosine_similarity, which rejects arrays
		# with zero samples.
		if not quotes:
			return [], [], []

		query_vec = encode_splade([sentence])
		doc_vecs = encode_splade(quotes)
		sims = cosine_similarity(query_vec, doc_vecs)[0]

		# Highest similarity first; sorted() is stable, so tied quotes keep
		# their original relative order.
		ranked = sorted(enumerate(sims), key=lambda pair: pair[1], reverse=True)

		# Keep only matches at or above the threshold, capped at top_k.
		# max(..., 0) stops a zero or negative top_k from leaking a match.
		ret = [idx for idx, score in ranked if score >= self.threshold]
		ret = ret[:max(top_k, 0)]

		return [], ret[:1], ret[:2]