import scipy
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import auc, roc_curve

# Tiny positive constant (not referenced within this section; presumably a
# numerical-safety epsilon — confirm against the rest of the file).
small_delta = 1e-30
# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class AverageMeter(object):
    """Keep the latest value together with a running, count-weighted mean.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` observations.

        Args:
            val: the new value (e.g. a per-batch average).
            n: how many observations ``val`` stands for; it scales ``val``'s
                contribution to the running sum and is added to the count.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Mean over every observation seen since the last reset.
        self.avg = self.sum / self.count


def get_lira(sample_predictions, sample_labels):
    """Convert logits into LiRA's logit-scaled confidence score.

    For each row ``i`` with label ``y_i`` and softmax probabilities ``p_i``
    this returns::

        log(p_i[y_i] + 1e-45) - log(sum_{j != y_i} p_i[j] + 1e-45)

    Derived from: https://github.com/orientino/lira-pytorch/blob/main/score.py

    Args:
        sample_predictions: array-like of logits with classes on the last
            axis and samples on the first (presumably ``(N, num_classes)`` —
            confirm against callers).
        sample_labels: integer class labels; only the first ``N`` entries
            (``N`` = number of rows of ``sample_predictions``) are used.

    Returns:
        float64 NumPy array with one score per row.
    """
    # Promote to float64 *before* shifting and exponentiating so the whole
    # softmax runs in double precision; casting only after exp (as before)
    # let float32 inputs underflow / lose precision in the low-probability tail.
    logits = np.asarray(sample_predictions, dtype=np.float64)
    # Subtract the row max so exp() cannot overflow.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= np.sum(probs, axis=-1, keepdims=True)

    count = probs.shape[0]
    rows = np.arange(count)
    labels = sample_labels[:count]
    y_true = probs[rows, labels]
    # Zero the true-class entry so the row sum is the mass on wrong classes.
    probs[rows, labels] = 0
    y_wrong = np.sum(probs, axis=-1)

    # The 1e-45 floor keeps log() finite when a probability is exactly zero.
    return np.log(y_true + 1e-45) - np.log(y_wrong + 1e-45)


def generate_lira_offline(
        train_member_pred, train_member_label, train_nonmember_pred, train_nonmember_label,
        test_member_pred, test_member_label, test_nonmember_pred, test_nonmember_label,
        N=4
):
    """Offline LiRA: score test points against a Gaussian fit to non-member scores.

    The LiRA scores (see ``get_lira``) of ``train_nonmember_pred`` /
    ``train_nonmember_label`` define a single "out" Gaussian whose centre is
    their median and whose scale is their standard deviation (plus 1e-30).
    Every test member and test non-member is then scored by its log-density
    under that Gaussian.

    ``train_member_pred``, ``train_member_label`` and ``N`` are accepted but
    not used by this implementation.

    Returns:
        (mem_prediction, nonmem_prediction): Gaussian log-densities for the
        test members and test non-members. A *lower* value means the point is
        less typical of non-members (more member-like); callers that need
        "higher = member" must negate.
    """
    # Explicit submodule import: a bare ``import scipy`` does not expose
    # ``scipy.stats`` on older SciPy releases.
    from scipy.stats import norm

    out_scores = get_lira(train_nonmember_pred, train_nonmember_label)
    # Median as a robust centre estimate; the epsilon keeps the scale strictly
    # positive when every out-score is identical.
    loc_out = np.median(out_scores)
    scale_out = np.std(out_scores) + 1e-30

    test_member_score = get_lira(test_member_pred, test_member_label)
    test_nonmember_score = get_lira(test_nonmember_pred, test_nonmember_label)
    print(test_member_score.shape, test_nonmember_score.shape)

    mem_prediction = norm.logpdf(test_member_score, loc_out, scale_out)
    nonmem_prediction = norm.logpdf(test_nonmember_score, loc_out, scale_out)
    print(mem_prediction.shape, nonmem_prediction.shape)
    return mem_prediction, nonmem_prediction


def generate_global(keep, scores, check_keep, check_scores):
    """Flatten per-model scores and ground truth for a global-threshold attack.

    Pairs are taken in order from ``zip(check_keep, check_scores)``, so the
    shorter sequence bounds the iteration. Each pair contributes
    ``-scores_i.mean(1)`` to the predictions and the entries of ``answers_i``
    to the ground truth. ``keep`` and ``scores`` are accepted but not used.

    Returns:
        (prediction, answers): two flat Python lists of equal-order entries.
    """
    pairs = list(zip(check_keep, check_scores))
    prediction = [value for _, block in pairs for value in -block.mean(1)]
    answers = [flag for truth, _ in pairs for flag in truth]
    return prediction, answers


def sweep(mem_stat, nonmem_stat, reverse=True):
    """Evaluate an attack statistic with an ROC sweep.

    Members (``mem_stat``) are labelled 1 and non-members (``nonmem_stat``)
    0. The concatenated statistics are passed to ``roc_curve`` unchanged, so a
    *higher* statistic is treated as more member-like. ``reverse`` is
    accepted but currently ignored.

    Prints the AUC, the best balanced accuracy and the TPR at the last
    operating point with FPR below 0.1%.

    Returns:
        (fpr, tpr, auc, acc)
    """
    labels = np.concatenate(
        [
            np.ones(mem_stat.shape[0], dtype=float),
            np.zeros(nonmem_stat.shape[0], dtype=float),
        ],
        axis=0,
    )
    scores = np.concatenate([mem_stat, nonmem_stat], axis=0)
    fpr, tpr, _ = roc_curve(labels, scores)

    area = auc(fpr, tpr)
    # Balanced accuracy at each threshold: 1 - (FPR + FNR) / 2.
    balanced_acc = np.max(1 - (fpr + (1 - tpr)) / 2)
    # Assumes some operating point has FPR < 0.1% (roc_curve normally starts at FPR 0).
    last_low_fpr = np.nonzero(fpr < 0.001)[0][-1]
    tpr_at_low_fpr = tpr[last_low_fpr]
    print("Attack AUC %.4f, Accuracy %.4f, TPR@0.1%%FPR of %.4f" % (area, balanced_acc, tpr_at_low_fpr))
    return fpr, tpr, area, balanced_acc


def sweep_try(in_member_signals, out_member_signals):
    """Threshold-sweep ROC in which a *lower* signal is predicted as member.

    2000 evenly spaced thresholds between the global minimum and maximum of
    all signals are tried; at threshold ``t`` a point is predicted "member"
    when ``signal < t``.

    Args:
        in_member_signals: attack signals of true members (any shape; flattened).
        out_member_signals: attack signals of true non-members (any shape; flattened).

    Returns:
        (fpr, tpr, auc, acc): FPR/TPR per threshold sorted by FPR (ties by
        TPR), their trapezoidal AUC, and the best balanced accuracy taking
        either prediction direction.
    """
    in_member_signals = np.asarray(in_member_signals).reshape(-1, 1)
    out_member_signals = np.asarray(out_member_signals).reshape(-1, 1)

    all_signals = np.concatenate([in_member_signals, out_member_signals])
    thresholds = np.linspace(np.min(all_signals), np.max(all_signals), 2000)

    # (n, 1) < (2000,) broadcasts to (n, 2000): one column per threshold.
    member_hits = np.less(in_member_signals, thresholds)
    nonmember_hits = np.less(out_member_signals, thresholds)

    tpr = member_hits.sum(axis=0) / len(in_member_signals)
    fpr = nonmember_hits.sum(axis=0) / len(out_member_signals)

    # Order by FPR and break ties by TPR. A plain (unstable) argsort may permute
    # points that share an FPR, producing a non-monotone curve and a wrong
    # trapezoidal AUC.
    order = np.lexsort((tpr, fpr))
    fpr = fpr[order]
    tpr = tpr[order]

    acc = max(np.max(1 - (fpr + (1 - tpr)) / 2), np.max(1 - (tpr + (1 - fpr)) / 2))
    return fpr, tpr, auc(fpr, tpr), acc


def lira_attack(
        train_member_pred, train_member_label, test_member_pred, test_member_label,
        train_nonmember_pred, train_nonmember_label, test_nonmember_pred, test_nonmember_label,
        num_class=100, attack_epochs=150, batch_size=512, num_shadows=4
):
    """Run the offline LiRA membership-inference attack and report its ROC.

    Prints dataset sizes, the first labels of each split and the target
    classifier's accuracy on each split. It then scores the test points with
    ``generate_lira_offline`` and evaluates them with ``sweep``.

    ``num_class`` and ``attack_epochs`` are accepted but unused;
    ``batch_size`` is only printed; ``num_shadows`` is forwarded as ``N``.

    Returns:
        (acc, (fpr, tpr, auc)) from the ROC sweep, oriented so that a higher
        attack score means "member".
    """
    print("\n\nEvaluating direct single-query attacks :", len(train_member_pred), len(train_nonmember_pred),
          len(test_member_pred), len(test_nonmember_pred))
    print("batch_size", batch_size)
    print(train_member_label[:20])
    print(test_member_label[:20])
    print(test_nonmember_label[:20])
    print(train_nonmember_label[:20])
    print(
        'classifier acc on attack training set: {:.4f}, {:.4f}.\nclassifier acc on attack test set: {:.4f}, {:.4f}.'.format(
            np.mean(np.argmax(train_member_pred, axis=1) == train_member_label),
            np.mean(np.argmax(train_nonmember_pred, axis=1) == train_nonmember_label),
            np.mean(np.argmax(test_member_pred, axis=1) == test_member_label),
            np.mean(np.argmax(test_nonmember_pred, axis=1) == test_nonmember_label)))

    test_mem_stat, test_nonmem_stat = generate_lira_offline(
        train_member_pred, train_member_label, train_nonmember_pred, train_nonmember_label,
        test_member_pred, test_member_label, test_nonmember_pred, test_nonmember_label,
        N=num_shadows
    )
    # generate_lira_offline returns the Gaussian log-density under the
    # non-member distribution, where a LOWER value is more member-like. sweep()
    # hands scores straight to roc_curve, which treats a HIGHER score as the
    # positive (member) class. Negate so AUC / TPR@low-FPR are not inverted.
    lira_fpr, lira_tpr, lira_auc, lira_acc = sweep(-test_mem_stat, -test_nonmem_stat)
    return lira_acc, (lira_fpr, lira_tpr, lira_auc)
