#DEPENDENCIES
import argparse
import sys

import torch
import torch.nn as nn

from torchvision import datasets, transforms
#from kymatio.torch import Scattering2D
import os
import pickle
import numpy as np
import scipy.stats as stats
import math
import logging



from opacus import PrivacyEngine

import torch.nn.functional as F

import math
import opacus.privacy_analysis as tf_privacy


from copy import deepcopy

import lira_training

import seaborn as sns 
import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams.update({'font.size': 18})


def get_losses(target_model, test_dataset, n_samples, device):
    """Return cross-entropy losses of ``target_model`` on randomly drawn samples.

    ``n_samples`` indices are drawn uniformly at random, with replacement,
    from ``range(len(test_dataset))``. Each selected item is evaluated on
    its own as a batch of size one.

    Args:
        target_model: Callable mapping a batched input tensor to logits.
        test_dataset: Indexable dataset whose items are ``(tensor, label)``.
        n_samples: Number of random draws to evaluate.
        device: Device the input and label tensors are moved to.

    Returns:
        1-D ``numpy.ndarray`` of length ``n_samples`` with one loss per draw.
    """
    inds = np.random.randint(len(test_dataset), size=n_samples)

    losses = []

    # Only the scalar loss values are needed, so skip building autograd graphs.
    with torch.no_grad():
        for ind in inds:
            data, target = test_dataset[ind]
            data = data.to(device)
            target = torch.tensor([target]).to(device)

            # The model expects a batch dimension; add one for the single sample.
            output = target_model(data.unsqueeze(0))
            loss = F.cross_entropy(output, target)

            losses.append(loss.item())

    return np.array(losses)

def get_threshold(losses, alpha):
    """Return the loss value at the ``alpha`` fraction of the sorted losses.

    The losses are sorted in ascending order and the element at index
    ``int(alpha * len(losses))`` is returned, clamped to the last element
    so that ``alpha == 1.0`` yields the maximum instead of running off the
    end of the array.

    Args:
        losses: Non-empty sequence or array of loss values.
        alpha: Fraction in ``[0, 1]`` selecting the position in sorted order.

    Returns:
        The selected loss value.

    Raises:
        ValueError: If ``losses`` is empty or ``alpha`` is outside ``[0, 1]``.
    """
    n_losses = len(losses)
    if n_losses == 0:
        raise ValueError("get_threshold requires at least one loss value")
    if not 0.0 <= alpha <= 1.0:
        # A negative alpha would otherwise silently index from the end.
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    sorted_losses = np.sort(losses)

    # alpha == 1.0 maps to index len(losses), one past the end; clamp it.
    index = min(int(alpha * n_losses), n_losses - 1)

    return sorted_losses[index]

def _losses_at_indices(target_model, dataset, indices, device):
    """Return the cross-entropy loss of ``dataset[i]`` for each ``i`` in ``indices``, in order."""
    losses = []
    with torch.no_grad():
        for idx in indices:
            data, target = dataset[idx]
            data = data.to(device)
            target = torch.tensor([target]).to(device)
            output = target_model(data.unsqueeze(0))
            losses.append(F.cross_entropy(output, target).item())
    return np.array(losses)


def run_attack_precision(threshold, sampling, target_model, train_dataset, n_points, device):
    """Compute the precision of a loss-threshold attack on the first ``n_points`` points.

    A point is predicted positive when its loss is ``<= threshold``. Precision
    is the fraction of predicted positives whose ``sampling`` entry is set.

    The points evaluated are ``train_dataset[0 .. n_points - 1]`` so that the
    k-th loss is compared with ``sampling[k]``. Drawing random indices here
    while slicing ``sampling[:n_points]`` would pair each loss with the label
    of an unrelated point.
    NOTE(review): this assumes ``sampling[k]`` is the indicator for
    ``train_dataset[k]`` -- confirm against the code that writes the sampling
    file.

    Args:
        threshold: Loss value at or below which a point is predicted positive.
        sampling: Array of 0/1 (or boolean) indicators.
        target_model: Callable mapping a batched input tensor to logits.
        train_dataset: Indexable dataset whose items are ``(tensor, label)``.
        n_points: Maximum number of leading points to evaluate.
        device: Device the input and label tensors are moved to.

    Returns:
        ``correct / positives``, or ``0`` when no point is predicted positive
        (including when there is nothing to evaluate).
    """
    # Never read past the end of either the dataset or the indicator array.
    n_eval = min(n_points, len(train_dataset), len(sampling))
    if n_eval <= 0:
        return 0

    losses = _losses_at_indices(target_model, train_dataset, range(n_eval), device)
    membership = np.asarray(sampling[:n_eval])

    positives = losses <= threshold
    correct = positives * membership

    print(f"Min train loss {np.min(losses)}")

    n_positives = np.sum(positives)
    n_correct = np.sum(correct)
    print(n_positives, n_correct)

    if n_positives == 0:
        return 0

    return n_correct / n_positives


if __name__ == "__main__":
    # Runs the loss-threshold attack for every (trial, epsilon, P_x) target
    # model checkpoint, then plots the mean precision per epsilon against the
    # sampling probability with a normal-approximation confidence band.

    # Use the first GPU when present; otherwise fall back to CPU so the
    # script still runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument('--dataset', default='mnist', choices=['cifar10', 'fmnist', 'mnist', 'svhn_ext'])
    parser.add_argument('--calibration_points', type=int, default=1000)

    args = parser.parse_args()
    dataset = args.dataset
    cal_points = args.calibration_points

    epsilons = [1.0, 3.0, 5.0, 9.0, 15.0]
    P_x = [1, 3, 5, 7, 9]
    Trials = [1, 2, 3, 4, 5]

    # x-axis position of each P_x entry, in the same order as P_x.
    sampling_probs = [0.1, 0.3, 0.5, 0.7, 0.9]

    # Number of training points scored by the attack for each target model.
    attack_points = 1000

    Sampling_Folder = './lira_samplings'
    Models_Folder = "./lira_models"
    Plots_Folder = 'Plots'

    # Get Dataset
    train_dataset, test_dataset = lira_training.get_data(args.dataset)

    # Get Model to load checkpoints
    input_norm = None
    size = None
    num_groups = int(81)
    # A 4-D data array is treated as 3-channel input, anything else as 1-channel.
    K = 3 if len(train_dataset.data.shape) == 4 else 1

    model = lira_training.CNNS[args.dataset](K, input_norm=input_norm, num_groups=num_groups, size=size)
    model.to(device)
    # The model is only used for inference here.
    model.eval()

    prec_trials = []

    # Fraction of the sorted calibration (test) losses used as the threshold.
    alpha = 0.1

    for Trial in Trials:

        print(f"On Trial {Trial}")
        prec_epsilons = []

        for eps in epsilons:

            prec_samples = []
            for i in P_x:

                target_model_path = os.path.join(
                    Models_Folder, f"model_{dataset}_{int(eps)}_target_{i}_{Trial}_0.pt"
                )
                sampling_path = os.path.join(
                    Sampling_Folder, f"sampling_{dataset}_{int(eps)}_target_{i}_{Trial}_0.npy"
                )
                sampling = np.load(sampling_path)

                # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
                model.load_state_dict(torch.load(target_model_path, map_location=device))

                test_losses = get_losses(model, test_dataset, cal_points, device)

                print(f"Max test loss {np.max(test_losses)}")
                print(f"Min test losses {np.min(test_losses)}")

                threshold = get_threshold(test_losses, alpha)

                print(f"Threshold {threshold}")

                prec = run_attack_precision(threshold, sampling, model, train_dataset, attack_points, device)
                prec_samples.append(prec)

            prec_epsilons.append(prec_samples)

        prec_trials.append(prec_epsilons)

    # Resulting array is indexed as (trial, epsilon, P_x); aggregate over trials.
    prec_trials_np = np.array(prec_trials)
    prec_trials_np_avg = np.mean(prec_trials_np, axis=0)
    prec_trials_np_std = np.std(prec_trials_np, axis=0)
    # 95% normal-approximation half-width; the divisor tracks the trial count.
    prec_trials_np_ci = (1.96 / math.sqrt(len(Trials))) * prec_trials_np_std

    print("Plotting")

    for i, eps in enumerate(epsilons):
        # Raw f-string so that "\e" in "\epsilon" is not read as an escape sequence.
        plt.plot(sampling_probs, prec_trials_np_avg[i], label=rf"$\epsilon = {eps}$")
        plt.fill_between(
            sampling_probs,
            prec_trials_np_avg[i] - prec_trials_np_ci[i],
            prec_trials_np_avg[i] + prec_trials_np_ci[i],
            alpha=0.1,
        )

    # Dashed diagonal: precision equal to the sampling probability.
    plt.plot(sampling_probs, sampling_probs, label='baseline', linestyle='dashed')

    plt.xlabel("Sampling Probability $P_{x^*}(1)$")
    plt.ylabel("Positive Accuracy")
    plt.title(f"Attack P on {dataset}")

    plt.legend()

    print("Saving Figure")
    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(Plots_Folder, exist_ok=True)
    plt.savefig(f'{Plots_Folder}/{dataset}_Attack_P_POS_ACC_{len(Trials)}_trials.pdf')











