from sklearn.linear_model import LogisticRegression
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

# Selective Synaptic Dampening (SSD) - https://arxiv.org/abs/2308.07707
# https://github.com/if-loops/selective-synaptic-dampening/blob/main/src/metrics.py#L54
def get_membership_attack_prob(retain_loader, forget_loader, test_loader, model):
    """Run the confidence- and entropy-based membership-inference attacks.

    For each of two scalar features (the softmax probability of the target
    class, and the entropy of the softmax distribution) a class-balanced
    logistic-regression attacker is trained to separate retain samples
    (label 1) from test samples (label 0). It is then applied to the forget
    samples.

    Args:
        retain_loader: Loader whose dataset holds the retained training data.
        forget_loader: Loader whose dataset holds the data to be forgotten.
        test_loader: Loader whose dataset holds held-out (non-member) data.
        model: Model under evaluation.

    Returns:
        dict with keys ``"confidence"`` and ``"entropy"``. Each value is the mean
        of the respective attacker's predicted labels on the forget set, i.e.
        the fraction of forget samples it classifies as members (label 1).
    """
    X_f, _, X_r, Y_r, X_ef, _, X_er, Y_er = get_membership_attack_data(
        retain_loader, forget_loader, test_loader, model
    )
    # Confidence attack: fit on retain(1) vs test(0) confidences, apply to forget.
    conf_mia = _attack_member_rate(X_r, Y_r, X_f)
    # Entropy attack: fit on retain(1) vs test(0) entropies, apply to forget.
    entr_mia = _attack_member_rate(X_er, Y_er, X_ef)

    return {
        "confidence": conf_mia,
        "entropy": entr_mia,
    }


def _attack_member_rate(X_train, Y_train, X_target):
    """Fit a balanced logistic-regression attacker; return its mean prediction on X_target."""
    clf_kwargs = {"class_weight": "balanced", "solver": "lbfgs"}
    # `multi_class` was deprecated in scikit-learn 1.5 and later removed. Pass it
    # only where the installed version still accepts it, so older versions give
    # the same results as before and newer versions do not fail.
    if "multi_class" in LogisticRegression().get_params():
        clf_kwargs["multi_class"] = "multinomial"
    clf = LogisticRegression(**clf_kwargs)
    clf.fit(X_train, Y_train)
    return clf.predict(X_target).mean()


def get_membership_attack_data(retain_loader, forget_loader, test_loader, model):
    """Assemble attack features and labels for the confidence and entropy MIAs.

    Softmax outputs and targets are gathered with ``collect_prob`` for the
    retain, forget and test sets. Two scalar features are derived per sample:
    - the probability assigned to the sample's target class ("confidence");
    - the entropy of the predicted distribution.
    The attacker's training set labels retain samples 1 and test samples 0.
    Forget samples are all labelled 1.

    Returns:
        Tuple ``(X_f, Y_f, X_r, Y_r, X_ef, Y_ef, X_er, Y_er)`` of NumPy arrays.
        ``X_*`` are ``(n, 1)`` feature columns and ``Y_*`` are 1-D float label
        vectors.
        - ``X_r``/``Y_r`` and ``X_er``/``Y_er``: retain-then-test confidence
          and entropy training data.
        - ``X_f``/``Y_f`` and ``X_ef``/``Y_ef``: forget-set confidence and
          entropy features with all-ones labels.
    """
    stages = (
        ("Collecting probabilities retain...", retain_loader),
        ("Collecting probabilities forget...", forget_loader),
        ("Collecting probabilities test...", test_loader),
    )
    outputs = []
    for message, loader in stages:
        print(message)
        outputs.append(collect_prob(loader, model))
    (r_prob, r_tgt), (f_prob, f_tgt), (t_prob, t_tgt) = outputs

    def as_column(tensor):
        # Host-side NumPy copy shaped as a single-feature column for sklearn.
        return tensor.cpu().numpy().reshape(-1, 1)

    def target_confidence(prob, tgt):
        # Pick, per row, the probability at the target index -> shape (n, 1).
        return torch.gather(prob, 1, tgt.view(-1, 1))

    r_entr, f_entr, t_entr = entropy(r_prob), entropy(f_prob), entropy(t_prob)
    r_conf = target_confidence(r_prob, r_tgt)
    f_conf = target_confidence(f_prob, f_tgt)
    t_conf = target_confidence(t_prob, t_tgt)

    n_retain, n_forget, n_test = r_prob.shape[0], f_prob.shape[0], t_prob.shape[0]

    # Attacker training data: retain rows first, then test rows.
    X_r = as_column(torch.cat([r_conf, t_conf]))
    X_er = as_column(torch.cat([r_entr, t_entr]))
    Y_r = np.concatenate([np.ones(n_retain), np.zeros(n_test)])
    Y_er = Y_r.copy()

    # Attack targets: the forget set, labelled as members.
    X_f = as_column(f_conf)
    X_ef = as_column(f_entr)
    Y_f = np.ones(n_forget)
    Y_ef = np.ones(n_forget)

    return X_f, Y_f, X_r, Y_r, X_ef, Y_ef, X_er, Y_er

def entropy(p, dim=-1, keepdim=False):
    """Return the Shannon entropy ``-sum(p * log p)`` (natural log) of ``p`` along ``dim``.

    Entries that are not strictly positive (including NaN, for which ``p > 0``
    is False) contribute 0, following the convention ``0 * log 0 = 0``.

    Args:
        p: Tensor of probabilities. It is presumably non-negative and
            normalised along ``dim``; this is not checked.
        dim: Dimension to reduce over.
        keepdim: Whether to keep the reduced dimension with size 1.

    Returns:
        Tensor of entropies in nats.
    """
    positive = p > 0
    # Take the log of a masked copy so that non-positive entries never produce
    # log(0) = -inf. Masking only the product (as in `where(p > 0, p * p.log(), 0)`)
    # gives the right forward value but NaN gradients (0 * -inf) via the unused branch.
    safe_log = torch.where(positive, p, torch.ones_like(p)).log()
    terms = torch.where(positive, p * safe_log, torch.zeros_like(p))
    return -terms.sum(dim=dim, keepdim=keepdim)

def collect_prob(data_loader, model, *, batch_size=500, num_workers=4, progress=True):
    """Run ``model`` over a loader's dataset; collect softmax probabilities and targets.

    The dataset behind ``data_loader`` is re-wrapped in a fresh, unshuffled
    ``DataLoader``, so outputs come back in dataset order regardless of how the
    caller's loader was configured. Each batch is expected to be a
    ``(data, target)`` pair.

    When ``model`` is a ``torch.nn.Module`` it is run in eval mode. Dropout is
    then disabled, and BatchNorm uses and leaves untouched its running
    statistics. Every submodule's original train/eval flag is restored
    afterwards, even if an exception is raised.

    Args:
        data_loader: Loader whose ``.dataset`` is evaluated.
        model: Callable mapping a batch of inputs to logits.
        batch_size: Batch size of the internal loader.
        num_workers: Worker processes used by the internal loader.
        progress: Whether to wrap iteration in a tqdm progress bar.

    Returns:
        ``(prob, targets)``. ``prob`` is the softmax over the last dimension of
        the model output for every sample, concatenated along dim 0. ``targets``
        are the matching targets. Both are on the device inputs were moved to.
    """
    device = _model_device(model)
    loader = torch.utils.data.DataLoader(
        data_loader.dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    batches = tqdm(loader) if progress else loader

    # Remember each submodule's own mode so mixed train/eval setups
    # (e.g. frozen BatchNorm) are restored exactly, not flattened by model.train().
    is_module = isinstance(model, torch.nn.Module)
    saved_modes = [(m, m.training) for m in model.modules()] if is_module else []
    if is_module:
        model.eval()

    prob = []
    targets = []
    try:
        with torch.no_grad():
            for data, target in batches:
                data = data.to(device)
                target = target.to(device)
                output = model(data)
                prob.append(F.softmax(output, dim=-1))
                targets.append(target)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return torch.cat(prob), torch.cat(targets)


def _model_device(model):
    """Return the device of the model's first parameter, else CUDA if available, else CPU."""
    params = model.parameters() if isinstance(model, torch.nn.Module) else iter(())
    first_param = next(params, None)
    if first_param is not None:
        return first_param.device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

