# attribution.py
# Core logic for computing Shapley NEAR scores from changes in the entropy of the model's next-token prediction

import torch
import torch.nn.functional as F
import numpy as np
from transformers import PreTrainedModel, PreTrainedTokenizer

def compute_entropy(logits: torch.Tensor, token_index: int) -> float:
    """Return the Shannon entropy (in nats) of the softmax distribution at one position.

    Args:
        logits: Tensor indexed first by position, with class/vocabulary logits
            along the last dimension (e.g. ``(seq_len, vocab)``).
        token_index: Position to read; negative values count from the end.

    Returns:
        ``-sum(p * log p)`` over the last dimension, as a Python float.
    """
    # Upcast so half-precision logits keep full accuracy through softmax/log.
    token_logits = logits[token_index].float()
    probs = F.softmax(token_logits, dim=-1)
    log_probs = F.log_softmax(token_logits, dim=-1)
    # Zero-probability classes (e.g. -inf logits) contribute 0 by the
    # convention 0 * log 0 = 0; the raw product 0 * -inf would be NaN.
    plogp = torch.where(probs > 0, probs * log_probs, torch.zeros_like(probs))
    entropy = -torch.sum(plogp, dim=-1)
    return entropy.item()

def _sentence_shapley_values(model, tokenizer, sentences, question, M):
    """Estimate one Shapley value per sentence by Monte Carlo permutation sampling.

    The value of a coalition of sentences is its information gain
    ``H_null - H(coalition)``. ``H`` is the entropy of the model's prediction at
    the final token of ``"<coalition sentences> <question>"``. ``H_null`` is the
    same quantity for the empty coalition (question only).

    Args:
        model: Model whose forward output exposes ``.logits``; presumably
            ``(batch, seq, vocab)`` — TODO confirm for non-causal models.
        tokenizer: Tokenizer matching ``model``.
        sentences: Non-empty list of context sentences (the Shapley "players").
        question: Question text appended after the context.
        M: Number of random permutations to sample (>= 1).

    Returns:
        List of ``len(sentences)`` Shapley estimates, in sentence order.
    """
    # Entropy per coalition, keyed by the sorted tuple of sentence indices.
    # Each step's "previous" coalition was the prior step's "new" one, so
    # memoizing avoids re-running the model on prompts already evaluated.
    entropy_cache = {}

    def coalition_entropy(coalition):
        """Return the last-token entropy for ``question`` preceded by ``coalition``."""
        if coalition not in entropy_cache:
            # Sentences appear in document order (coalition is sorted), so the
            # value depends only on WHICH sentences are present, as Shapley
            # attribution requires of a set function.
            prefix = '. '.join(sentences[i] for i in coalition) + '.' if coalition else ""
            inputs = tokenizer(prefix + " " + question, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model(**inputs)
            # The question always ends the prompt, so the final position (-1)
            # is the prediction point for every coalition. A fixed offset such
            # as ``question_length - 1`` would land inside the context instead.
            entropy_cache[coalition] = compute_entropy(outputs.logits[0], token_index=-1)
        return entropy_cache[coalition]

    n = len(sentences)
    totals = [0.0] * n
    for _ in range(M):
        members = []
        entropy_prev = coalition_entropy(())  # empty coalition == question-only baseline
        for idx in np.random.permutation(n):
            members.append(int(idx))
            entropy_new = coalition_entropy(tuple(sorted(members)))
            # Marginal information gain of sentence idx:
            # (H_null - H_new) - (H_null - H_prev) == H_prev - H_new.
            totals[idx] += entropy_prev - entropy_new
            entropy_prev = entropy_new
    # Permutations are drawn uniformly, so the plain average of marginal gains
    # is the Shapley estimate. The |S|!(n-|S|-1)!/n! weight belongs only to the
    # subset-enumeration formula; applying it here would double-weight.
    return [t / M for t in totals]


def compute_shapley_near_score(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, context: str, question: str, M: int = 50) -> float:
    """Return the mean Shapley NEAR score of the sentences in ``context``.

    ``context`` is split on ``'.'`` into sentences. Each sentence's Shapley value
    is its average marginal reduction in the model's last-token prediction
    entropy for ``question``, estimated from ``M`` random permutations.

    NOTE: by the Shapley efficiency property, the per-sentence values sum to
    ``H_null - H_full`` (question alone vs. question after all sentences). The
    returned mean therefore equals ``(H_null - H_full) / n`` up to float
    rounding; ``M`` only affects how that total is split across sentences.

    Args:
        model: Model whose forward output exposes ``.logits``.
        tokenizer: Tokenizer matching ``model``.
        context: Passage to attribute; split into sentences on ``'.'``.
        question: Question text appended after the (partial) context.
        M: Number of permutations to sample; must be >= 1.

    Returns:
        Mean of the per-sentence Shapley estimates, or ``0.0`` when ``context``
        contains no sentences.

    Raises:
        ValueError: If ``M < 1``.
    """
    if M < 1:
        raise ValueError(f"M must be at least 1, got {M}")

    sentences = [s.strip() for s in context.split('.') if s.strip()]
    if not sentences:
        # Nothing to attribute; also avoids dividing by zero below.
        return 0.0

    shapley_values = _sentence_shapley_values(model, tokenizer, sentences, question, M)
    return sum(shapley_values) / len(sentences)
