#DEPENDENCIES
import argparse
import sys

import torch
import torch.nn as nn

from torchvision import datasets, transforms
#from kymatio.torch import Scattering2D
import os
import pickle
import numpy as np
import scipy.stats as stats
import math
import logging



from opacus import PrivacyEngine

import torch.nn.functional as F

import math
import opacus.privacy_analysis as tf_privacy


from copy import deepcopy

import lira_training

def logit_scale(confs):
    """Apply LiRA's logit scaling ``log(p) - log(1 - p)`` element-wise.

    Args:
        confs: Scalar or array-like of confidences in ``[0, 1]``.

    Returns:
        A float64 numpy array with the same shape as ``confs``.
    """
    confs = np.asarray(confs, dtype=np.float64)
    # The tiny offset keeps p == 0 and p == 1 finite (roughly -/+103.6)
    # instead of producing -inf/+inf.
    eps = 1e-45
    return np.log(confs + eps) - np.log(1.0 - confs + eps)


def lira_score(obs_scaled, scaled_in, scaled_out):
    """Return the LiRA likelihood ratio for one data point.

    Fits a Gaussian to the scaled confidences of shadow models that trained on
    the point (IN) and of those that did not (OUT), and returns
    ``pdf_in(obs_scaled) / pdf_out(obs_scaled)``.

    Args:
        obs_scaled: Logit-scaled confidence of the target model on the point.
        scaled_in: Logit-scaled confidences from IN shadow models.
        scaled_out: Logit-scaled confidences from OUT shadow models.

    Returns:
        The ratio as a numpy float. It is nan when either group is empty. It is
        also nan when a group has zero variance, because scipy's ``norm.pdf``
        returns nan for a zero scale. It is inf when ``pdf_out`` underflows to 0.
    """
    scaled_in = np.asarray(scaled_in, dtype=np.float64)
    scaled_out = np.asarray(scaled_out, dtype=np.float64)
    if scaled_in.size == 0 or scaled_out.size == 0:
        # A Gaussian cannot be fitted to an empty group.
        return np.nan
    # scipy.stats.norm.pdf expects a standard deviation, not a variance.
    pdf_in = stats.norm.pdf(obs_scaled, np.mean(scaled_in), math.sqrt(np.var(scaled_in)))
    pdf_out = stats.norm.pdf(obs_scaled, np.mean(scaled_out), math.sqrt(np.var(scaled_out)))
    return pdf_in / pdf_out


def model_confidences(model, inputs, labels, batch_size=256):
    """Return the softmax probability the model assigns to each input's label.

    The model runs in eval mode under ``torch.no_grad()``, so dropout and
    batch-norm layers behave deterministically and no autograd graph is built.
    Its previous train/eval mode is restored afterwards.

    Args:
        model: ``torch.nn.Module`` mapping a batch of inputs to class scores
            of shape ``(batch, num_classes)``.
        inputs: Tensor of inputs stacked along dim 0, on the model's device.
        labels: int64 tensor of class indices, one per input, on the same device.
        batch_size: Number of inputs per forward pass.

    Returns:
        A float64 numpy array of shape ``(len(inputs),)``.
    """
    was_training = model.training
    model.eval()
    chunks = []
    try:
        with torch.no_grad():
            for start in range(0, inputs.shape[0], batch_size):
                logits = model(inputs[start:start + batch_size])
                probs = nn.functional.softmax(logits, dim=1)
                batch_labels = labels[start:start + batch_size].unsqueeze(1)
                chunks.append(probs.gather(1, batch_labels).squeeze(1))
    finally:
        model.train(was_training)
    if not chunks:
        return np.empty(0, dtype=np.float64)
    # Up-cast on the host so downstream arithmetic runs in float64, the same
    # precision as the Python floats produced by ``.item()``.
    return torch.cat(chunks).cpu().numpy().astype(np.float64)


if __name__ == "__main__":

    # Use the first GPU when available; fall back to CPU so the script also
    # runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument('--P_x', default=0.5, type=float)
    parser.add_argument('--target_epsilon', default=3.0, type=float)
    parser.add_argument('--dataset', default='mnist', choices=['cifar10', 'fmnist', 'mnist', 'svhn_ext'])
    parser.add_argument('--Trial', default=0, type=int)
    args = parser.parse_args()

    P_x = args.P_x
    dataset = args.dataset
    Trial = args.Trial
    # NOTE: --target_epsilon is accepted for CLI compatibility but is not used
    # here; every checkpoint path below names a "no_dp" model.

    n_scores = 1000   # score the first n_scores training points
    n_shadow = 20     # shadow models are indexed 0..n_shadow-1
    shadow_P_x = 0.5  # shadow checkpoints/samplings are always read from the P_x=0.5 runs

    Sampling_Folder = './lira_samplings'
    Models_Folder = "./lira_models"
    target_model_path = os.path.join(
        Models_Folder, f"model_{dataset}_target_{int(P_x*10)}_{Trial}_0_no_dp.pt")

    # Get dataset (the test split is not needed for scoring).
    train_data_before, _ = lira_training.get_data(dataset)

    # Build the architecture once; every checkpoint is loaded into it in turn.
    # K is 3 when the raw data array is 4-D, else 1 (presumably colour vs.
    # greyscale images).
    K = 3 if len(train_data_before.data.shape) == 4 else 1
    model = lira_training.CNNS[dataset](K, input_norm=None, num_groups=81, size=None)
    model.to(device)

    # Fetch the points once; every model is evaluated on the same tensors.
    samples = [train_data_before[i] for i in range(n_scores)]
    inputs = torch.stack([x for x, _ in samples]).to(device)
    labels = torch.as_tensor([int(y) for _, y in samples], device=device)

    print("Computing Scores")

    # Observed (target-model) confidences, logit-scaled.
    model.load_state_dict(torch.load(target_model_path, map_location=device))
    obs_scaled = logit_scale(model_confidences(model, inputs, labels))

    # Shadow-model confidences plus IN/OUT membership, one row per shadow
    # model. Each checkpoint and sampling file is loaded exactly once.
    shadow_confs = np.empty((n_shadow, n_scores), dtype=np.float64)
    membership = np.empty((n_shadow, n_scores), dtype=bool)
    for j in range(n_shadow):
        print(f"Shadow model {j} out of {n_shadow - 1}")
        shadow_model_path = os.path.join(
            Models_Folder, f"model_{dataset}_shadow_{int(shadow_P_x*10)}_{Trial}_{j}_no_dp.pt")
        sampling_path = os.path.join(
            Sampling_Folder, f"sampling_{dataset}_shadow_{int(shadow_P_x*10)}_{Trial}_{j}_no_dp.npy")
        # A truthy entry i marks point i as IN for this shadow model.
        membership[j] = np.load(sampling_path)[:n_scores].astype(bool)
        model.load_state_dict(torch.load(shadow_model_path, map_location=device))
        shadow_confs[j] = model_confidences(model, inputs, labels)
    shadow_scaled = logit_scale(shadow_confs)

    # Per-point likelihood ratio between the IN and OUT Gaussians.
    scores = []
    for i in range(n_scores):
        if i % 100 == 0:
            print(f"On {i} out of {n_scores-1}")
        is_in = membership[:, i]
        scores.append(lira_score(obs_scaled[i], shadow_scaled[is_in, i], shadow_scaled[~is_in, i]))

    # Save the scores.
    print("Saving Scores")
    scores_np = np.array(scores)

    Scores_Folder = './lira_scores'
    os.makedirs(Scores_Folder, exist_ok=True)

    scores_path = os.path.join(
        Scores_Folder, f"scores_{n_scores}_{dataset}_{int(P_x*10)}_{Trial}_no_dp")
    # np.save appends the ".npy" extension to this path.
    np.save(scores_path, scores_np)





            
            



