import numpy as np
import torch
import torch.nn.functional as F

def initial_selection_by_influence(label: np.ndarray, influence: np.ndarray, K: int=10):
    """Select the K most influential points, extended until a positive label is included.

    Points are ranked by ``influence`` in descending order and the first ``K``
    are taken. If none of those has ``label >= 1``, the selection is extended
    along the same ranking up to and including the first point whose label is
    ``>= 1``.

    Args:
        label: Per-point labels, indexed by point position.
        influence: Per-point influence scores, one per entry of ``label``.
        K: Number of top-ranked points to take before the positive-label check.

    Returns:
        List of selected point indices (numpy integer scalars) in ranking order.

    Raises:
        ValueError: If the top-K points contain no label ``>= 1`` and no point
            in the whole ranking has ``label >= 1``.
    """
    label = np.asarray(label)
    ranking = np.argsort(influence)[::-1]
    selected_points = list(ranking[:K])
    # Guard the empty selection: ``.max()`` on a zero-size array raises.
    if selected_points and label[selected_points].max() >= 1:
        return selected_points
    # The top-K points hold no positive. Find the first positive along the
    # ranking in one vectorized pass instead of re-scanning a growing list.
    positive_ranks = np.flatnonzero(label[ranking] >= 1)
    if positive_ranks.size == 0:
        raise ValueError(
            "initial_selection_by_influence: no point has label >= 1, "
            "cannot include a positive example"
        )
    return list(ranking[: positive_ranks[0] + 1])

def estimate_prob(logits_orig, poisoned_label, labels, found_influences):
    """Trace the probability of ``poisoned_label`` as influence steps are applied.

    The logits are copied (detached, float32, on CPU), so ``logits_orig`` is not
    modified. A competing class ``alt_idx`` is chosen once up front: the
    highest-logit class other than ``poisoned_label`` in the original logits.
    Each non-zero influence value ``infl`` then moves ``eps * infl`` of logit
    from ``poisoned_label`` to ``alt_idx``, where ``eps = 1 / len(labels)``.
    The updates accumulate. After every step (including zero steps), the
    softmax probability of ``poisoned_label`` is recorded via
    ``logit_to_prob_change``.

    Args:
        logits_orig: Logit tensor indexed by class; presumably 1-D of shape
            (C,) -- TODO confirm with callers.
        poisoned_label: Index of the class whose probability is tracked.
        labels: Sized collection; only its length is used, to set ``eps``.
        found_influences: Iterable of influence values, applied in order.

    Returns:
        List of floats, one probability per entry of ``found_influences``.

    Raises:
        ValueError: If there are fewer than two classes (no competing class
            exists) or ``labels`` is empty (``eps`` would be undefined).
    """
    logits_cur = logits_orig.clone().detach().float().cpu()
    C = logits_cur.numel()
    if C < 2:
        raise ValueError("estimate_prob needs logits for at least two classes")
    n = len(labels)
    if n == 0:
        raise ValueError("estimate_prob needs a non-empty labels collection")

    # Strongest competitor of the poisoned class, fixed from the original logits.
    mask = torch.ones(C, dtype=torch.bool)
    mask[poisoned_label] = False
    alt_idx = torch.arange(C)[mask][torch.argmax(logits_cur[mask])].item()

    eps = 1.0 / float(n)
    prob_trace = []
    for infl in found_influences:
        if infl != 0.0:
            # Shift logit mass: down for the poisoned class, up for its competitor.
            logits_cur[poisoned_label] += -eps * float(infl)
            logits_cur[alt_idx]        +=  eps * float(infl)
        p = logit_to_prob_change(logits_cur, 0.0, poisoned_label)
        prob_trace.append(p)
    return prob_trace

def logit_to_prob_change(logits, delta_logit, class_idx):
    """Return the softmax probability of ``class_idx`` after shifting its logit.

    A copy of ``logits`` gets ``delta_logit`` added at position ``class_idx``.
    The input tensor is left untouched. Softmax is taken over the last
    dimension (presumably a 1-D per-class logit vector -- confirm with callers).
    Returns 0.0 when either ``logits`` or ``class_idx`` is None.
    """
    if logits is not None and class_idx is not None:
        bumped = logits.clone()
        bumped[class_idx] += delta_logit
        # Subtract the max before softmax for numerical stability.
        stabilized = bumped - bumped.max()
        probs = F.softmax(stabilized, dim=-1)
        return probs[class_idx].item()
    return 0.0

@torch.no_grad()
def evaluate_test_acc(model, testloader, device, topk=(1, 5)):
    """Compute mean cross-entropy loss and top-k accuracy over a data loader.

    Puts ``model`` in eval mode (and leaves it there) and runs it without
    gradient tracking on every ``(inputs, labels)`` batch from ``testloader``.
    Each batch is moved to ``device`` first.

    Args:
        model: Module whose output for a batch is a (B, C) logit tensor
            (the code indexes ``outputs.size(1)`` as the class count).
        testloader: Iterable of ``(inputs, labels)`` batches.
        device: Target device for each batch.
        topk: The k values to report. A k larger than the number of classes
            is clamped to the class count, so it reports 100% accuracy
            instead of raising.

    Returns:
        Dict with ``"loss"`` (mean per-sample cross-entropy) and
        ``"top{k}"`` (accuracy in percent) for each k in ``topk``. Every value
        is NaN when the loader yields no samples.
    """
    model.eval()

    max_requested_k = max(topk, default=1)
    total_loss = 0.0
    total_samples = 0
    correct_topk = [0.0 for _ in topk]

    for inputs, labels in testloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        # Sum (not mean * B) so batches of any size are weighted exactly per sample.
        total_loss += F.cross_entropy(outputs, labels, reduction="sum").item()
        total_samples += labels.size(0)

        # Tensor.topk raises if k exceeds the class dimension, so clamp it.
        maxk = min(max_requested_k, outputs.size(1))
        _, pred = outputs.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()  # (maxk, B)
        correct = pred.eq(labels.view(1, -1).expand_as(pred))

        for i, k in enumerate(topk):
            correct_topk[i] += correct[:k].any(dim=0).float().sum().item()

    if total_samples == 0:
        # Avoid division by zero; report every metric as undefined.
        avg_loss = float("nan")
        accs = [float("nan") for _ in correct_topk]
    else:
        avg_loss = total_loss / total_samples
        accs = [100.0 * c / total_samples for c in correct_topk]

    out = {"loss": avg_loss}
    for i, k in enumerate(topk):
        out[f"top{k}"] = accs[i]
    return out