
import torch
from prompt_graph.utils import get_args
import os
import numpy as np
import ipdb
import matplotlib.pyplot as plt
import copy
import scipy
from sklearn.metrics import auc, roc_curve
import torchmetrics
# Larger default font for the ROC figure drawn by membership_inference_attack.
plt.rcParams.update({'font.size': 18})


# Experiment configuration from prompt_graph.utils.get_args() (presumably
# parsed from the command line at import time). The functions below read it as
# a module global for the dataset name, shot count and model/prompt types.
args = get_args()

def sweep(score, x):
    """
    Summarize how well ``score`` separates the two classes in ``x``.

    The ROC curve is built on ``-score``, so *lower* scores count as evidence
    for the positive class. Returns ``(fpr, tpr, auc, acc)`` where ``acc`` is
    the best balanced accuracy over all ROC thresholds, i.e. the maximum of
    ``1 - (FPR + FNR) / 2``.
    """
    false_pos, true_pos, _thresholds = roc_curve(x, -score)
    false_neg = 1 - true_pos
    best_balanced_acc = np.max(1 - (false_pos + false_neg) / 2)
    area = auc(false_pos, true_pos)
    return false_pos, true_pos, area, best_balanced_acc

def _true_class_logits(data, num_classes):
    """Return, per row, log(p_true) - log(sum of the other class probabilities).

    ``data`` is a 2-D array. All columns except the last are reshaped into
    ``num_classes`` per-class outputs, which are treated as logits (a softmax
    is applied here). The last column is the ground-truth class label.
    """
    data = np.asarray(data)
    n_rows = data.shape[0]
    outputs = data[:, :-1].reshape(n_rows, num_classes).astype(np.float64)
    labels = data[:, -1].reshape(n_rows).astype(np.int64)

    # Numerically stable softmax: subtract the row max before exponentiating.
    probs = np.exp(outputs - np.max(outputs, axis=-1, keepdims=True))
    probs /= np.sum(probs, axis=-1, keepdims=True)

    rows = np.arange(n_rows)
    p_true = probs[rows, labels]  # fancy indexing copies, so the zeroing below is safe
    probs[rows, labels] = 0  # what remains in each row is the mass on wrong classes
    p_wrong = np.sum(probs, axis=-1)

    # The tiny epsilon keeps log() finite if a probability underflows to 0.
    return np.log(p_true + 1e-45) - np.log(p_wrong + 1e-45)

def _plot_roc(fpr, tpr, roc_auc, label):
    """Plot one ROC curve against the chance diagonal and show the figure."""
    plt.plot(fpr, tpr, label="{}, auc={:.3f}".format(label, roc_auc), linewidth=1.8)
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.plot([0, 1], [0, 1], ls="--", color="gray", linewidth=1.8)
    plt.subplots_adjust(bottom=0.18, left=0.18, top=0.96, right=0.96)
    plt.legend(prop={'size': 12}, loc='lower right')
    plt.show()

def membership_inference_attack(train_data, test_data, num_classes, device=None, *, label=None):
    """Run a Gaussian likelihood-ratio membership inference attack and plot its ROC.

    For every row, the attack statistic is ``log(p_true) - log(sum p_other)``,
    computed from the softmax of the stored outputs.

    The first half of the member rows (``train_data``) is used to fit one
    Gaussian, and the first half of the non-member rows (``test_data``) fits
    another. The remaining rows of both sets are then scored by how much more
    likely they are under the "member" Gaussian than under the "non-member" one.

    Args:
        train_data: 2-D array of member rows: per-class outputs followed by
            the label in the last column.
        test_data: non-member rows in the same layout.
        num_classes: number of per-class output columns per row.
        device: accepted for compatibility with callers that pass a torch
            device; unused (all computation here is NumPy).
        label: legend text for the ROC curve; defaults to ``args.dataset_name``.

    Returns:
        ``(fpr, tpr, roc_auc, acc)`` as returned by ``sweep``.

    Raises:
        ValueError: if either input has fewer than 2 rows, which would leave a
            calibration or evaluation half empty.
    """
    # `import scipy` alone may not load the stats submodule on older SciPy versions.
    import scipy.stats

    member_logits = _true_class_logits(train_data, num_classes)
    nonmember_logits = _true_class_logits(test_data, num_classes)
    if member_logits.shape[0] < 2 or nonmember_logits.shape[0] < 2:
        raise ValueError(
            "need at least 2 member and 2 non-member rows, got {} and {}".format(
                member_logits.shape[0], nonmember_logits.shape[0]))

    # First half of each group calibrates its Gaussian; the second half is evaluated.
    split_in = member_logits.shape[0] // 2
    split_out = nonmember_logits.shape[0] // 2
    calib_in, eval_in = member_logits[:split_in], member_logits[split_in:]
    calib_out, eval_out = nonmember_logits[:split_out], nonmember_logits[split_out:]
    mean_in, std_in = np.mean(calib_in), np.std(calib_in)
    mean_out, std_out = np.mean(calib_out), np.std(calib_out)

    eval_logits = np.concatenate((eval_in, eval_out))
    # Negative log-likelihood under each hypothesis, vectorized over all rows.
    # The 1e-30 guards against a zero standard deviation.
    nll_in = -scipy.stats.norm.logpdf(eval_logits, mean_in, std_in + 1e-30)
    nll_out = -scipy.stats.norm.logpdf(eval_logits, mean_out, std_out + 1e-30)
    # Lower score = more member-like; `sweep` negates it before building the ROC.
    scores = nll_in - nll_out
    answers = np.concatenate((np.ones(eval_in.shape[0], dtype=bool),
                              np.zeros(eval_out.shape[0], dtype=bool)))
    fpr, tpr, roc_auc, acc = sweep(scores, answers)

    _plot_roc(fpr, tpr, roc_auc, args.dataset_name if label is None else label)
    return fpr, tpr, roc_auc, acc

def main():
    """Load saved member / non-member model outputs and run the attack.

    Files are read from
    ``./dataspace/outs/<shot_num>shot/<dataset_name>_<pre_train_data>/{train,test}/``
    and named ``<pre_train_type>_<prompt_type>_<gnn_type>.npy``. Rows from the
    ``train`` file are passed as members and rows from ``test`` as non-members.
    """
    run_dir = "./dataspace/outs/{}shot/{}_{}".format(
        args.shot_num, args.dataset_name, args.pre_train_data)
    file_name = "{}_{}_{}.npy".format(
        args.pre_train_type, args.prompt_type, args.gnn_type)
    train_data = np.load(os.path.join(run_dir, "train", file_name))
    test_data = np.load(os.path.join(run_dir, "test", file_name))

    # Every column but the last holds one per-class output and the last column
    # is the label, so the class count follows from the data instead of being
    # hard-coded to one dataset's value.
    num_classes = train_data.shape[1] - 1
    if test_data.shape[1] != train_data.shape[1]:
        raise ValueError(
            "train/test outputs disagree on column count: {} vs {}".format(
                train_data.shape[1], test_data.shape[1]))

    membership_inference_attack(train_data, test_data, num_classes)
            
# Script entry point: run the attack only when executed directly, not on import.
if __name__ == "__main__":
    main()