

import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
from utils import *


import shutil


import pytest
import skorch


from mia.estimators import AttackModelBundle
from mia.estimators import ShadowModelBundleIBP
from mia.estimators import prepare_attack_data
from mia.serialization import BaseModelSerializer


from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.utils import shuffle
from skorch.callbacks import EpochScoring
from sklearn.metrics import f1_score


import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset,TensorDataset


from models import IBPCNN_V3,IBPCNN_V3_nobayes
from ibp_torch import Server_V3
from miatf2 import get_data,get_mnist_data,get_fmnist_data
import math
import argparse

import pdb



# Dataset that demo() loads; supported values are "cifar10", "mnist" and "fmnist".
DATASET = "cifar10"
# Use the GPU whenever PyTorch reports one as available.
USE_CUDA = torch.cuda.is_available()
device = 'cuda' if USE_CUDA else 'cpu'



class CrossEntropyOneHot:
    """Cross-entropy criterion for one-hot encoded targets.

    Each target row is reduced to a class index with ``max(dim=1)`` and the
    result is fed to a default-configured :class:`torch.nn.CrossEntropyLoss`.
    """

    def __init__(self):
        self.cel = nn.CrossEntropyLoss()

    def __call__(self, out, target):
        class_index = target.max(dim=1)[1]
        return self.cel(out, class_index)

class NLLLossOneHot:
    """Negative log-likelihood criterion for one-hot encoded targets.

    Expects ``out`` to already hold log-probabilities. The class index of
    each target row is taken with ``max(dim=1)`` and passed to a
    default-configured :class:`torch.nn.NLLLoss`.
    """

    def __init__(self):
        self.nll = nn.NLLLoss()

    def __call__(self, out, target):
        class_index = target.max(dim=1)[1]
        return self.nll(out, class_index)





class ShadowNet(nn.Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages then three linear layers.

    The ``16 * 5 * 5`` flatten size matches a 32x32 input
    (3x3 conv -> 30, pool -> 15, 5x5 conv -> 11, pool -> 5).
    ``CHANNELS``, ``WIDTH``, ``HEIGHT`` and ``NUM_CLASSES`` are module-level
    names not defined in this section (presumably from ``utils``; confirm).
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept: it fixes parameter registration order
        # and the RNG draws used for initialisation.
        self.conv1 = nn.Conv2d(CHANNELS, 6, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, NUM_CLASSES)

    def forward(self, x, **kwargs):
        del kwargs  # Unused.

        h = x.view(-1, CHANNELS, WIDTH, HEIGHT)
        for conv in (self.conv1, self.conv2):
            h = self.pool(F.relu(conv(h)))
        h = h.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            h = F.relu(hidden(h))
        return self.fc3(h)


class FactorizedShadowNet(nn.Module):
    """LeNet-style CNN, labelled "Factorized LeNet".

    Despite the name, no factorisation is applied: the layers and forward
    pass are the same as ``ShadowNet`` (two conv/ReLU/max-pool stages, then
    three linear layers sized for a 16x5x5 feature map). ``CHANNELS``,
    ``WIDTH``, ``HEIGHT`` and ``NUM_CLASSES`` come from outside this section
    (presumably ``utils``; confirm).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(CHANNELS, 6, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, NUM_CLASSES)

    def forward(self, x, **kwargs):
        del kwargs  # Unused.

        out = x.view(-1, CHANNELS, WIDTH, HEIGHT)
        out = F.relu(self.conv1(out))
        out = self.pool(out)
        out = F.relu(self.conv2(out))
        out = self.pool(out)
        out = out.view(-1, 16 * 5 * 5)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)




def torch_shadow_model_fn(model):
    """Wrap ``model`` in a skorch ``NeuralNetClassifier`` for shadow training.

    The classifier trains with Adam (lr 0.01) for 100 epochs using the
    one-hot cross-entropy criterion. It records ROC-AUC every epoch under
    the callback name ``'auc'`` and has no internal validation split
    (``train_split=None``). It runs on CUDA when ``USE_CUDA`` is set.

    :param model: module class or instance handed to skorch as ``module``.
    :returns: the unfitted ``skorch.NeuralNetClassifier``.
    """
    auc_scoring = EpochScoring(scoring='roc_auc', lower_is_better=False)
    return skorch.NeuralNetClassifier(
        module=model,
        optimizer=torch.optim.Adam,
        device='cuda' if USE_CUDA else 'cpu',
        max_epochs=100,
        lr=0.01,
        criterion=CrossEntropyOneHot,
        callbacks=[('auc', auc_scoring)],
        train_split=None,
    )


def torch_attack_model_fn():
    """Build a skorch ``NeuralNetClassifier`` around ``AttackNet``.

    The classifier trains with Adam for 100 epochs on ``nn.BCELoss``. It
    records ROC-AUC every epoch under the callback name ``'auc'`` and has no
    validation split.
    NOTE(review): ``AttackNet`` is not defined in this section; presumably it
    comes from ``from utils import *``. Confirm.

    :returns: the unfitted ``skorch.NeuralNetClassifier``.
    """
    auc_scoring = EpochScoring(scoring='roc_auc', lower_is_better=False)
    return skorch.NeuralNetClassifier(
        module=AttackNet,
        max_epochs=100,
        optimizer=torch.optim.Adam,
        device='cuda' if USE_CUDA else 'cpu',
        criterion=nn.BCELoss,
        callbacks=[('auc', auc_scoring)],
        train_split=None,
    )



def prepare_attack_dataibp(model, data_in, data_out, batch_size=100, device=None):
    """
    Prepare the data in the attack model format.

    The model is evaluated under ``torch.no_grad()`` in mini-batches of
    ``batch_size`` samples. The last batch may be smaller, so every sample
    is scored. Raw outputs are turned into probabilities with a softmax over
    dim 1, and each probability row is concatenated with its label row.

    :param model: Classifier; a callable returning one score row per input.
    :param (X, y) data_in: Data used for training
    :param (X, y) data_out: Data not used for training
    :param int batch_size: Samples per forward pass (default 100).
    :param device: Device the input batches are moved to. Defaults to the
        device of the model's first parameter. If the model exposes no
        parameters, CUDA is used when available and CPU otherwise.

    :returns: (X, y) for the attack classifier. Each row of X is
        ``[probabilities..., label...]``; y is 1 for ``data_in`` rows and
        0 for ``data_out`` rows.
    """
    X_in, y_in = data_in
    X_out, y_out = data_out

    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _predict_proba(X):
        # Step by batch_size so a trailing partial batch is scored instead of
        # dropped; otherwise prediction and label row counts diverge.
        outputs = []
        with torch.no_grad():
            for start in range(0, len(X), batch_size):
                batch_input = torch.Tensor(X[start:start + batch_size]).to(device)
                outputs.append(model(batch_input))
        return F.softmax(torch.cat(outputs), dim=1).cpu().numpy()

    y_hat_in = _predict_proba(X_in)
    y_hat_out = _predict_proba(X_out)

    labels = np.hstack([np.ones(y_in.shape[0]), np.zeros(y_out.shape[0])])
    data = np.vstack([np.c_[y_hat_in, y_in], np.c_[y_hat_out, y_out]])
    return data, labels


def prepare_attack_dataibpfull(model, data_in, data_out, batch_size=100, device=None):
    """
    Prepare the data in the attack model format.

    For models whose forward pass returns a tuple whose first element holds
    log-probabilities (e.g. ``(log_probs, kld_binary, kld_v, kld_r)``). The
    model is evaluated under ``torch.no_grad()`` in mini-batches of
    ``batch_size`` samples; the last batch may be smaller, so every sample
    is scored. ``exp`` of the first output gives the probability rows, and
    each is concatenated with its label row.

    :param model: Classifier returning a tuple; element 0 is treated as
        log-probabilities.
    :param (X, y) data_in: Data used for training
    :param (X, y) data_out: Data not used for training
    :param int batch_size: Samples per forward pass (default 100).
    :param device: Device the input batches are moved to. Defaults to the
        device of the model's first parameter. If the model exposes no
        parameters, CUDA is used when available and CPU otherwise.

    :returns: (X, y) for the attack classifier. Each row of X is
        ``[probabilities..., label...]``; y is 1 for ``data_in`` rows and
        0 for ``data_out`` rows.
    """
    X_in, y_in = data_in
    X_out, y_out = data_out

    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _predict_proba(X):
        # Step by batch_size so a trailing partial batch is scored instead of
        # dropped; otherwise prediction and label row counts diverge.
        outputs = []
        with torch.no_grad():
            for start in range(0, len(X), batch_size):
                batch_input = torch.Tensor(X[start:start + batch_size]).to(device)
                outputs.append(torch.exp(model(batch_input)[0]))
        return torch.cat(outputs).cpu().numpy()

    y_hat_in = _predict_proba(X_in)
    y_hat_out = _predict_proba(X_out)

    labels = np.hstack([np.ones(y_in.shape[0]), np.zeros(y_out.shape[0])])
    data = np.vstack([np.c_[y_hat_in, y_in], np.c_[y_hat_out, y_out]])
    return data, labels






# Number of leading training samples the target model is trained on; the
# attack-test "member" rows are drawn from this prefix.
_TARGET_TRAIN_SIZE = 1000
# Test-set rows used as the attacker's shadow-model pool. The attack-test
# "non-member" rows are taken from before this slice, so they do not overlap it.
_SHADOW_POOL = slice(1000, 4000)


def _load_dataset():
    """Return ``(X_train, y_train), (X_test, y_test)`` for ``DATASET`` with the channel axis at position 1."""
    if DATASET == 'cifar10':
        (X_train, y_train), (X_test, y_test) = get_data()
        # Channels-last -> channels-first for the PyTorch conv layers.
        X_train = np.transpose(X_train, axes=[0, 3, 1, 2])
        X_test = np.transpose(X_test, axes=[0, 3, 1, 2])
    elif DATASET in ('mnist', 'fmnist'):
        loader = get_mnist_data if DATASET == 'mnist' else get_fmnist_data
        (X_train, y_train), (X_test, y_test) = loader()
        # Insert a singleton channel axis.
        X_train = np.expand_dims(X_train, 1)
        X_test = np.expand_dims(X_test, 1)
    else:
        raise ValueError(f"Unsupported DATASET: {DATASET!r}")
    return (X_train, y_train), (X_test, y_test)


def _train_target_model(model, trainloader, args):
    """Train ``model`` in place on NLL plus weighted KL terms, printing per-epoch train accuracy."""
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.NLLLoss().to(device)

    for epo in range(args.target_epochs):
        correct = 0
        total = 0
        for batch_idx, (images, labels) in enumerate(trainloader):
            n_batch = images.shape[0]
            images = images.to(device)
            labels = labels.to(device).long()

            model.zero_grad()
            log_probs, kld_binary, kld_v, kld_r = model(images)

            # NLL plus the three KL terms, each weighted by its CLI
            # coefficient and divided by the batch size.
            loss = criterion(log_probs, labels)
            loss += args.klb * kld_binary * 1.0 / n_batch
            loss += args.klv * kld_v * 1.0 / n_batch
            loss += args.klr * kld_r * 1.0 / n_batch
            if math.isnan(loss.item()):
                raise FloatingPointError(
                    f"NaN training loss in epoch {epo}, batch {batch_idx}")

            loss.backward()
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

            # Training accuracy from this batch's predictions.
            _, pred_labels = torch.max(log_probs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)
        print("Epoch " + str(epo) + " accuracy: ", correct / total)


def _report_shadow_accuracy(X_shadow, num_classes, shadow_dataset_size, num_shadows):
    """Print overall and per-shadow-model train/test accuracy of the shadow attack rows."""
    # Columns [:num_classes] are predicted probabilities; the rest are labels.
    hits = (np.argmax(X_shadow[:, :num_classes], axis=1)
            == np.argmax(X_shadow[:, num_classes:], axis=1))
    print("mean shadow accuracy:", np.mean(hits))

    # NOTE(review): assumes ShadowModelBundleIBP stacks, per shadow model,
    # shadow_dataset_size "in" rows followed by shadow_dataset_size "out"
    # rows. Confirm against the mia implementation.
    size = shadow_dataset_size
    for k in range(num_shadows):
        start = 2 * k * size
        if start + 2 * size > len(hits):
            break
        print(f"shadow model {k + 1} train accuracy:", np.mean(hits[start:start + size]))
        print(f"shadow model {k + 1} test accuracy:", np.mean(hits[start + size:start + 2 * size]))


def demo(args):
    """Run the IBP membership-inference experiment end to end.

    1. Load ``DATASET`` and train an ``IBPCNN_V3`` target model on the first
       ``_TARGET_TRAIN_SIZE`` training samples.
    2. Build the attack test set from ``args.attack_test_dataset_size``
       training rows (members) and test rows (non-members). Save it under
       ``./miaresults``.
    3. Train ``args.num_shadows`` ``IBPCNN_V3_nobayes`` shadow models on the
       ``_SHADOW_POOL`` test rows and report their accuracy.
    4. Fit an ``AttackModelBundle`` of ``LogisticRegression`` models on the
       shadow outputs, print its accuracy on the target model and pickle the
       guesses to ``../save/mia_results/ibpall.pkl``.

    :param args: parsed command-line namespace (see the ``__main__`` block).
    :raises ValueError: if ``DATASET`` is unknown, or if
        ``args.attack_test_dataset_size`` would take member rows the target
        never trained on or non-member rows overlapping the shadow pool.
    :raises FloatingPointError: if the target training loss becomes NaN.
    """
    max_attack_size = min(_TARGET_TRAIN_SIZE, _SHADOW_POOL.start)
    if args.attack_test_dataset_size > max_attack_size:
        raise ValueError(
            f"attack_test_dataset_size={args.attack_test_dataset_size} exceeds "
            f"{max_attack_size}: member rows would include samples the target "
            "model never trained on, or non-member rows would overlap the "
            "shadow pool")

    (X_train, y_train), (X_test, y_test) = _load_dataset()

    # Only the first _TARGET_TRAIN_SIZE samples are used to train the target.
    tensor_x_train = torch.Tensor(X_train[:_TARGET_TRAIN_SIZE])
    tensor_y_train = torch.Tensor(np.argmax(y_train[:_TARGET_TRAIN_SIZE], axis=1))
    trainloader = DataLoader(TensorDataset(tensor_x_train, tensor_y_train),
                             batch_size=args.local_bs, shuffle=True)

    a_prior = args.a_prior
    lambda_prior = args.lambda_prior
    lambda_post = args.lambda_post
    p_threshold = args.p_threshold

    conv_arch = [(args.num_channels, 6, 5), (6, 16, 5)]

    model_arch = [('conv', conv_arch[0][0] * conv_arch[0][2] ** 2, args.conv_truncation, conv_arch[0][1]),
                  ('conv', conv_arch[1][0] * conv_arch[1][2] ** 2, args.conv_truncation, conv_arch[1][1]),
                  ('mlp', 16 * 5 * 5 + 1, 80, 120),
                  ('mlp', 121, 50, 84),
                  ('last', 84, None, args.num_classes)]

    global_model = Server_V3(model_arch,
                             a_prior,
                             lambda_prior,
                             lambda_post,
                             p_threshold)

    local_model_weights = global_model.send_weights(-1)
    local_bayes = None

    # Shared constructor arguments for the target and shadow IBP networks.
    ibp_args = (model_arch,
                conv_arch,
                local_model_weights,
                global_model.a_prior,
                global_model.lambda_post,
                global_model.lambda_prior,
                global_model.p_threshold,
                local_bayes)

    # First train the target model and store its attack test data.
    model = IBPCNN_V3(*ibp_args)
    _train_target_model(model, trainloader, args)
    model.eval()

    data_in = X_train[:args.attack_test_dataset_size], y_train[:args.attack_test_dataset_size]
    data_out = X_test[:args.attack_test_dataset_size], y_test[:args.attack_test_dataset_size]

    attack_test_data, real_membership_labels = prepare_attack_dataibpfull(
        model, data_in, data_out
    )
    os.makedirs("./miaresults", exist_ok=True)
    np.save("./miaresults/target_attack_test_data", attack_test_data)
    np.save("./miaresults/target_membership", real_membership_labels)

    attacker_X_train, attacker_X_test, attacker_y_train, attacker_y_test = train_test_split(
        X_test[_SHADOW_POOL], y_test[_SHADOW_POOL], test_size=0.1
    )
    print(attacker_X_train.shape, attacker_X_test.shape)

    print("Training the shadow models...")
    smb = ShadowModelBundleIBP(
        torch_shadow_model_fn,
        shadow_dataset_size=args.shadow_dataset_size,
        num_models=args.num_shadows,
        model=IBPCNN_V3_nobayes,
        model_args=ibp_args,
    )

    X_shadow, y_shadow = smb.fit_transform(
        attacker_X_train,
        attacker_y_train,
        fit_kwargs=dict(
            epochs=args.target_epochs,
            verbose=True,
            validation_data=(attacker_X_test, attacker_y_test),
        ),
    )
    # Reported before shuffling, while rows are still grouped per shadow model.
    _report_shadow_accuracy(X_shadow, args.num_classes,
                            args.shadow_dataset_size, args.num_shadows)

    amb = AttackModelBundle(LogisticRegression, num_classes=args.num_classes)
    X_shadow, y_shadow = shuffle(X_shadow, y_shadow)
    amb.fit(X_shadow, y_shadow)

    # Compute the attack accuracy on the target model's rows.
    attack_guesses = amb.predict(attack_test_data)
    attack_accuracy = np.mean(attack_guesses == real_membership_labels)
    print(attack_accuracy)

    n_cls = args.num_classes
    testlabels = np.argmax(attack_test_data[:, n_cls:2 * n_cls], axis=1)
    os.makedirs("../save/mia_results", exist_ok=True)
    with open("../save/mia_results/ibpall.pkl", 'wb') as f:
        pickle.dump([attack_guesses, real_membership_labels, testlabels], f)



if __name__ == "__main__":

    # (flag, type, default, help) for every command-line option, listed in
    # registration order so the parsed namespace and --help are unchanged.
    cli_options = [
        ('--target_epochs', int, 100, "number of epochs to train target model"),
        ('--shadow_dataset_size', int, 1000, "size of data to train shadow model"),
        ('--attack_test_dataset_size', int, 1000, "size of data for attack test"),
        ('--num_shadows', int, 3, "num of shadow model"),
        ('--num_classes', int, 10, "num of classes"),
        ('--attack_epochs', int, 12, "num of epochs to train attack model"),
        ('--num_channels', int, 3, "num of channels"),
        ('--a_prior', float, 50.0, 'a prior'),
        ('--lambda_prior', float, 0.66, 'lambda prior'),
        ('--lambda_post', float, 0.66, 'lambda post'),
        ('--p_threshold', float, 0.4, 'factor threshold'),
        ('--truncation', int, 64, 'IBP Truncation parameter'),
        ('--conv_truncation', int, 16, 'IBP conv Truncation parameter'),
        ('--iid_level', int, 2, 'how many classes each client has'),
        ('--lr', float, 0.01, 'learning rate'),
        ('--klr', float, 0.1, 'weight for kl r'),
        ('--klb', float, 0.1, 'weight for kl binary'),
        ('--klv', float, 0.1, 'weight for kl v'),
        ('--clip', float, 0.0, 'gradient clip number'),
        ('--local_bs', int, 64, "local batch size: B"),
    ]

    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value, help=help_text)

    args = parser.parse_args()

    demo(args)






