#DEPENDENCIES
import argparse
import sys

import torch
import torch.nn as nn

from torchvision import datasets, transforms
#from kymatio.torch import Scattering2D
import os
import pickle
import numpy as np
import scipy.stats as stats
import math
import logging



from opacus import PrivacyEngine

import torch.nn.functional as F

import math
import opacus.privacy_analysis as tf_privacy


from copy import deepcopy

import lira_training

import seaborn as sns 
import matplotlib
from matplotlib import pyplot as plt
matplotlib.rcParams.update({'font.size': 18})

def get_losses(target_model, test_dataset, n_samples, device, indices=None):
    """Compute per-example cross-entropy losses of ``target_model`` on ``test_dataset``.

    Args:
        target_model: Module mapping a batch of inputs to class logits.
        test_dataset: Indexable dataset whose items unpack as ``(data, target)``.
            ``data`` must be a tensor (it is moved with ``.to(device)``).
            ``target`` is presumably a scalar class label (TODO confirm against
            ``lira_training.get_data``).
        n_samples: Number of indices drawn uniformly *with replacement* via
            ``np.random.randint`` when ``indices`` is None.
        device: Device the inputs and labels are moved to.
        indices: Optional explicit sequence of dataset indices. When given, it
            is used in order and ``n_samples`` is ignored, so ``result[k]`` is
            the loss of ``test_dataset[indices[k]]``.

    Returns:
        1-D numpy array with one loss value per evaluated example.
    """
    if indices is None:
        indices = np.random.randint(len(test_dataset), size=n_samples)

    losses = []
    # Only the scalar loss values are needed, so skip building autograd graphs.
    with torch.no_grad():
        for ind in indices:
            data, target = test_dataset[ind]
            output = target_model(data.to(device).unsqueeze(0))
            label = torch.tensor([target], device=device)
            losses.append(F.cross_entropy(output, label).item())

    return np.array(losses)


def get_shadow_losses(index, data, target, model, dataset, target_epsilon, P_x_shadow, Trial, Models_Folder, Sampling_Folder, device, n_shadow=20):
    """Collect losses on one example from the shadow models that did NOT train on it.

    For each shadow model ``j`` in ``range(n_shadow)``, the sampling vector
    ``{Sampling_Folder}/sampling_..._{j}.npy`` is read. A zero at ``index``
    means the example was excluded from that shadow model's training set.
    Only those shadow models are loaded and evaluated.

    Args:
        index: Position of the example in the sampling vectors.
        data: Input tensor for the example (a batch dimension is added here).
        target: Label for the example (presumably an int class index).
        model: Module whose weights are overwritten with each evaluated shadow
            checkpoint; it is left holding the last one loaded.
        dataset, target_epsilon, P_x_shadow, Trial: Components of the
            checkpoint/sampling file names.
        Models_Folder: Directory holding ``model_*.pt`` state dicts.
        Sampling_Folder: Directory holding ``sampling_*.npy`` vectors.
        device: Device used for inference and for loading checkpoints.
        n_shadow: Number of shadow models to scan (defaults to the original 20).

    Returns:
        1-D numpy array of cross-entropy losses, one per non-member shadow model
        (possibly empty).
    """
    # Convert the example once. Re-wrapping `target` inside the loop (as done
    # previously) turned it into a (1, 1) tensor on the second non-member model,
    # which F.cross_entropy rejects.
    data = data.to(device).unsqueeze(0)
    target = torch.tensor([target], device=device)

    tag = f"{dataset}_{int(target_epsilon)}_shadow_{int(P_x_shadow*10)}_{Trial}"

    losses = []
    for j in range(n_shadow):
        sampling = np.load(f"{Sampling_Folder}/sampling_{tag}_{j}.npy")
        # Skip shadow models that trained on this point *before* paying for
        # the checkpoint load.
        if sampling[index] != 0:
            continue

        shadow_model_path = f"{Models_Folder}/model_{tag}_{j}.pt"
        model.load_state_dict(torch.load(shadow_model_path, map_location=device))

        with torch.no_grad():
            output = model(data)
            losses.append(F.cross_entropy(output, target).item())

    return np.array(losses)

def get_threshold(losses, alpha):
    """Return the empirical ``alpha``-quantile of ``losses``.

    The losses are sorted ascending and the element at position
    ``floor(alpha * len(losses))`` is returned. For ``alpha == 1.0`` that
    position is clamped to the last element instead of overrunning the array.

    Args:
        losses: Non-empty sequence/array of loss values.
        alpha: Quantile level in ``[0, 1]``.

    Returns:
        The selected loss value.

    Raises:
        ValueError: If ``losses`` is empty or ``alpha`` lies outside ``[0, 1]``.
    """
    sorted_losses = np.sort(np.asarray(losses))
    n = len(sorted_losses)
    if n == 0:
        raise ValueError("get_threshold requires at least one loss value")
    if not 0.0 <= alpha <= 1.0:
        # A negative alpha would otherwise wrap around to the top of the list.
        raise ValueError(f"alpha must lie in [0, 1], got {alpha}")

    # int() floors; clamp so alpha == 1.0 selects the maximum, not index n.
    return sorted_losses[min(int(alpha * n), n - 1)]

def run_attack_precision(alpha, model, train_dataset, n_points, device, dataset, target_epsilon, P_x, Trial, Models_Folder, Sampling_Folder, use_var_sampling):
    """Run the per-example loss-threshold attack on one target model; return its precision.

    For each of the first ``n_points`` training examples:

    * The target model's loss on the example is computed.
    * A threshold is taken as the ``alpha``-quantile of the losses of the shadow
      models that did not train on it.
    * The example is predicted to be a member if its target loss is
      ``<=`` that threshold.

    Examples with no non-member shadow model are predicted non-members.

    Args:
        alpha: Quantile level passed to ``get_threshold``.
        model: Module reused for both the target and the shadow checkpoints;
            its weights and train/eval mode are modified.
        train_dataset: Indexable dataset of ``(data, target)`` items.
        n_points: Number of leading examples to attack.
        device: Inference / checkpoint-loading device.
        dataset, target_epsilon, P_x, Trial: Components of the file names.
        Models_Folder: Directory holding ``model_*.pt`` checkpoints.
        Sampling_Folder: Directory holding ``sampling_*.npy`` membership vectors.
        use_var_sampling: ``'yes'`` to use shadow models sampled with ``P_x``;
            otherwise the shadow models sampled with probability 0.5 are used.

    Returns:
        Precision = true positives / predicted positives, or 0 when nothing is
        predicted positive.
    """
    name = f"{dataset}_{int(target_epsilon)}_target_{int(P_x*10)}_{Trial}_0"
    target_model_path = f"{Models_Folder}/model_{name}.pt"
    # Previously hard-coded to 'lira_samplings/', ignoring Sampling_Folder.
    target_sampling = np.load(f"{Sampling_Folder}/sampling_{name}.npy")[:n_points]

    # Inference only; eval() matters if the architecture has dropout/batch-norm
    # (not visible from here).
    model.eval()
    model.load_state_dict(torch.load(target_model_path, map_location=device))

    print("Computing Target Losses")
    # Losses must be for examples 0..n_points-1 *in order*: they are compared
    # with the per-index shadow thresholds and target_sampling[i] below.
    target_losses = np.empty(n_points)
    with torch.no_grad():
        for i in range(n_points):
            data, target = train_dataset[i]
            output = model(data.to(device).unsqueeze(0))
            label = torch.tensor([target], device=device)
            target_losses[i] = F.cross_entropy(output, label).item()

    # Loop-invariant: which shadow-model family to compare against.
    P_x_shadow = P_x if use_var_sampling == 'yes' else 0.5

    positive = np.zeros(n_points, dtype=int)
    print("Computing threshold per point")
    for i in range(n_points):
        data, target = train_dataset[i]
        shadow_losses = get_shadow_losses(i, data, target, model, dataset, target_epsilon, P_x_shadow, Trial, Models_Folder, Sampling_Folder, device)

        # Points with no non-member shadow model stay predicted as non-members.
        if len(shadow_losses) > 0 and target_losses[i] <= get_threshold(shadow_losses, alpha):
            positive[i] = 1

    n_positive = np.sum(positive)
    if n_positive == 0:
        return 0

    return np.sum(positive * target_sampling) / n_positive




if __name__ == "__main__":

    # Prefer the first GPU but fall back to CPU rather than crashing without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument('--dataset', default = 'mnist', choices=['cifar10', 'fmnist', 'mnist', 'svhn_ext'])
    parser.add_argument('--n_points', type = int, default = 1000)
    parser.add_argument('--alpha', type = float, default = 0.1)
    parser.add_argument('--use_var_sampling', default = 'no', choices = ['yes','no'])

    args = parser.parse_args()
    dataset = args.dataset
    n_points = args.n_points
    alpha = args.alpha
    use_var_sampling = args.use_var_sampling

    if use_var_sampling == 'yes':
        epsilons = [3.0, 7.0]
    else:
        epsilons = [1.0, 3.0, 9.0, 15.0]
    P_x = [0.1, 0.3, 0.5, 0.7, 0.9]
    Trials = [1, 2, 3, 4, 5]

    Sampling_Folder = './lira_samplings'
    Models_Folder = "./lira_models"

    # Get dataset (the test split is not used by this attack).
    train_dataset, _ = lira_training.get_data(dataset)

    # Model skeleton into which every checkpoint is loaded.
    input_norm = None
    size = None
    num_groups = 81
    # 4-D .data (N, H, W, C) / (N, C, H, W) -> colour images; otherwise greyscale.
    K = 3 if len(train_dataset.data.shape) == 4 else 1

    model = lira_training.CNNS[dataset](K, input_norm=input_norm, num_groups=num_groups, size=size)
    model.to(device)

    prec_trials = []
    for Trial in Trials:
        print(f"On Trial {Trial}")
        prec_epsilons = []
        for eps in epsilons:
            prec_samples = []
            for p_x in P_x:
                prec = run_attack_precision(alpha, model, train_dataset, n_points, device, dataset, eps, p_x, Trial, Models_Folder, Sampling_Folder, use_var_sampling)
                print(f"Precision {prec}")
                prec_samples.append(prec)
            prec_epsilons.append(prec_samples)
        prec_trials.append(prec_epsilons)

    # Shape: (len(Trials), len(epsilons), len(P_x)).
    prec_trials_np = np.array(prec_trials)
    prec_trials_np_avg = np.mean(prec_trials_np, axis = 0)
    n_trials = len(Trials)
    # 95% normal-approximation CI of the mean: use the sample std (ddof=1) and
    # the actual number of trials rather than a hard-coded 5.
    if n_trials > 1:
        prec_trials_np_std = np.std(prec_trials_np, axis = 0, ddof = 1)
    else:
        prec_trials_np_std = np.zeros_like(prec_trials_np_avg)
    prec_trials_np_ci = (1.96 / math.sqrt(n_trials)) * prec_trials_np_std

    print("Plotting")

    for i, eps in enumerate(epsilons):
        # Raw string: "\e" is an invalid escape sequence in a normal literal.
        plt.plot(P_x, prec_trials_np_avg[i], label = rf"$\epsilon = {eps}$")
        plt.fill_between(P_x, prec_trials_np_avg[i] - prec_trials_np_ci[i], prec_trials_np_avg[i] + prec_trials_np_ci[i], alpha = 0.1)

    # The y = x line is the reference the per-epsilon curves are compared
    # against; presumably the precision of predicting every point a member.
    plt.plot(P_x, P_x, label = 'baseline', linestyle='dashed')

    plt.xlabel("Sampling Probability $P_{x^*}(1)$")
    plt.ylabel("Positive Accuracy")
    plt.title(f"Attack R on {dataset} with Alpha {alpha}")

    plt.legend()

    print("Saving Figure")
    plt.tight_layout()
    # Create the output directory so a long run does not fail at the last step.
    os.makedirs('Plots', exist_ok=True)
    plt.savefig(f'Plots/{dataset}_Attack_R_{alpha}_{use_var_sampling}_POS_ACC_5_trials.pdf')