# attribution_ame.py
# AME-based Shapley NEAR via coalition sampling + L1 regression (no permutations)

import math
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer

def _compute_last_token_entropy(logits: torch.Tensor) -> float:
    """Return the Shannon entropy (in nats) of the final-position distribution.

    The distribution is the softmax over the last axis of ``logits[0, -1]``,
    i.e. the first batch element at the last sequence position.

    Args:
        logits: Tensor indexed as ``[batch, position, vocab]``; only
            ``logits[0, -1]`` is read.

    Returns:
        ``-sum(p * log p)`` as a Python float. Entries with zero probability
        (e.g. logits of ``-inf``) contribute 0 rather than NaN.
    """
    # Upcast so half-precision logits do not lose accuracy in the reduction.
    token_logits = logits[0, -1].float()        # shape: [V]
    probs = F.softmax(token_logits, dim=-1)
    log_probs = F.log_softmax(token_logits, dim=-1)
    # A logit of -inf yields p == 0 and log p == -inf, whose product is NaN.
    # Apply the usual convention 0 * log 0 = 0 for those entries.
    p_log_p = torch.where(probs > 0, probs * log_probs, torch.zeros_like(probs))
    return float(-p_log_p.sum().item())

def _encode(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, text: str) -> float:
    """Run ``text`` through ``model`` and return the last-token entropy.

    Despite the name, the result is a scalar entropy, not an encoding: the
    text is tokenized into PyTorch tensors, moved to ``model.device``, passed
    through the model without gradient tracking, and the output logits are
    handed to ``_compute_last_token_entropy``.
    """
    batch = tokenizer(text, return_tensors="pt").to(model.device)
    # Inference only: no autograd graph is needed for an entropy readout.
    with torch.no_grad():
        logits = model(**batch).logits
    return _compute_last_token_entropy(logits)

def _solve_lasso_ista(
    X: np.ndarray,
    y: np.ndarray,
    lambda_reg: float,
    n_iter: int = 2000,
) -> np.ndarray:
    """Minimise (1/(2M))||y - X w||_2^2 + lambda_reg * ||w||_1 with ISTA.

    Plain proximal gradient descent with a fixed step, run for ``n_iter``
    iterations starting from zero. ``X`` has shape (M, n) and ``y`` shape (M,).
    """
    M, n = X.shape
    XtX = X.T @ X / M
    Xty = X.T @ y / M
    # XtX is positive semi-definite, so its trace bounds the largest
    # eigenvalue (the gradient's Lipschitz constant) from above.
    lipschitz = float(np.trace(XtX))
    step = 1.0 / (lipschitz + 1e-8)
    thr = step * lambda_reg
    w = np.zeros(n, dtype=np.float32)
    for _ in range(n_iter):
        w = w - step * (XtX @ w - Xty)
        # Soft-threshold: proximal operator of the L1 penalty.
        w = np.sign(w) * np.maximum(np.abs(w) - thr, 0.0)
    return w


def compute_ame_near_score(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    context: str,
    question: str,
    M: int = 50,
    lambda_reg: float = 0.01,
    rng: np.random.Generator | None = None,
) -> float:
    """
    AME estimator for NEAR:
      1) sample M random non-empty coalitions (subsets) of context sentences,
      2) y_j = IG(S_j) = H_empty - H_{S_j},
      3) solve LASSO:  min_phi (1/(2M))||y - X phi||_2^2 + lambda * ||phi||_1,
      4) NEAR = (1/n) * sum_i phi_i_hat.

    Args:
        model: Causal LM whose output exposes ``.logits``.
        tokenizer: Tokenizer matching ``model``.
        context: Text split into sentences on ``'.'`` (a naive splitter;
            decimals and abbreviations are split too).
        question: Text appended after the selected context sentences.
        M: Number of sampled coalitions (rows of the design matrix).
        lambda_reg: L1 penalty weight of the LASSO objective.
        rng: NumPy random generator; a fresh default one is used if ``None``.

    Returns:
        Mean of the fitted per-sentence coefficients, or 0.0 when ``context``
        contains no sentences.

    Raises:
        ValueError: If ``context`` has sentences but ``M`` is not positive.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Sentence list
    sentences = [s.strip() for s in context.split('.') if s.strip()]
    n = len(sentences)
    if n == 0:
        # No context => NEAR is zero by construction
        return 0.0

    # Checked after the empty-context shortcut so that case keeps returning 0.0.
    if M < 1:
        raise ValueError(f"M must be a positive integer, got {M}")

    # Baseline entropy with empty context
    H_empty = _encode(model, tokenizer, " " + question)

    # Design matrix X (M x n) of 0/1 membership flags and response vector y (M,)
    X = np.zeros((M, n), dtype=np.float32)
    y = np.zeros(M, dtype=np.float32)

    # The same coalition is often drawn more than once (there are only 2^n - 1
    # of them); reuse its entropy instead of repeating the forward pass.
    # NOTE(review): assumes the forward pass is deterministic (no active
    # dropout) -- confirm the model is in eval mode.
    entropy_by_mask: dict[bytes, float] = {}

    for j in range(M):
        # Independent Bernoulli(0.5) flags, resampled while empty: this is
        # uniform over all non-empty subsets of the n sentences.
        mask = rng.integers(low=0, high=2, size=n, dtype=np.int8)
        while mask.sum() == 0:
            mask = rng.integers(low=0, high=2, size=n, dtype=np.int8)

        X[j, :] = mask

        key = mask.tobytes()
        H_S = entropy_by_mask.get(key)
        if H_S is None:
            # Selected sentences in original order, re-joined with ". ".
            chosen = [sent for sent, keep in zip(sentences, mask) if keep]
            text = ". ".join(chosen) + ". " + question
            H_S = _encode(model, tokenizer, text)
            entropy_by_mask[key] = H_S

        # Information gain for this coalition (relative to empty)
        y[j] = H_empty - H_S

    # Solve the LASSO. Prefer scikit-learn when it is installed; only a missing
    # package triggers the fallback, so genuine fit errors are not hidden.
    try:
        from sklearn.linear_model import Lasso
    except ImportError:
        phi_hat = _solve_lasso_ista(X, y, lambda_reg)
    else:
        # scikit's objective is (1/(2*n_samples))||y-Xw||^2 + alpha||w||_1,
        # which matches the docstring objective with alpha = lambda_reg.
        lasso = Lasso(alpha=lambda_reg, fit_intercept=False, max_iter=10000, tol=1e-6)
        lasso.fit(X, y)
        phi_hat = lasso.coef_.astype(np.float32)

    # NEAR estimate is the average contribution
    return float(phi_hat.mean())
