
import os
import shutil


import pytest
#import keras
import skorch
import torch
import numpy as np
import pickle

from torch import nn
from torch.nn import functional as F

from mia.estimators import AttackModelBundle
from mia.estimators import ShadowModelBundle
from mia.estimators import prepare_attack_data
from mia.serialization import BaseModelSerializer



from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.utils import shuffle
from skorch.callbacks import EpochScoring
from miatf2 import get_data,get_mnist_data,get_fmnist_data
import argparse
from sklearn.metrics import f1_score
import pdb

# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases; torch.set_default_dtype(torch.float32) is the suggested
# replacement — confirm against the installed version before switching.
# Make float32 CPU tensors the default type for newly created tensors.
torch.set_default_tensor_type("torch.FloatTensor")



# Dataset selector: 'cifar10' (3-channel 32x32 images) or a 1-channel 28x28
# set ('mnist' / 'fmnist'). The image-geometry constants below derive from it.
#DATASET = "cifar10"
DATASET = "fmnist"
WIDTH = 32 if DATASET== 'cifar10' else 28
HEIGHT = 32 if DATASET =='cifar10' else 28
CHANNELS = 3 if DATASET=='cifar10' else 1
# Output width of the classifier nets and input width of AttackNet.
NUM_CLASSES = 10
# NOTE(review): NUM_SHADOWS and SHADOW_DATASET_SIZE are not read in this
# file's visible code; demo() takes args.num_shadows /
# args.shadow_dataset_size from the command line instead.
NUM_SHADOWS = 3
# Place skorch models on the GPU when one is available.
USE_CUDA = torch.cuda.is_available()

SHADOW_DATASET_SIZE = 4000


def to_categorical(y, num_classes):
    """Return the one-hot ``uint8`` encoding of the integer labels *y*.

    Row ``k`` of a ``num_classes`` x ``num_classes`` identity matrix is the
    encoding of label ``k``, so the result has shape
    ``np.shape(y) + (num_classes,)``.
    """
    identity = np.identity(num_classes, dtype="uint8")
    return identity[y]



class CrossEntropyOneHot:
    """Cross-entropy loss that accepts one-hot targets.

    The target's class axis is dim 1; it is collapsed to class indices
    (position of the maximum) and passed to ``nn.CrossEntropyLoss``
    together with the raw scores *out*.
    """

    def __init__(self):
        self.cel = nn.CrossEntropyLoss()

    def __call__(self, out, target):
        class_idx = target.max(dim=1).indices
        return self.cel(out, class_idx)

class NLLLossOneHot:
    """Negative log-likelihood loss that accepts one-hot targets.

    *out* is fed to ``nn.NLLLoss``, which expects log-probabilities. The
    target's class axis (dim 1) is reduced to class indices first.
    """

    def __init__(self):
        self.nll = nn.NLLLoss()

    def __call__(self, out, target):
        class_idx = target.max(dim=1).indices
        return self.nll(out, class_idx)


class ShadowNet(nn.Module):
    """LeNet-style CNN wrapped by ``torch_shadow_model_fn`` (the cifar10 path).

    The hard-coded flatten size ``16 * 5 * 5`` matches a 32x32 input:
    32 -conv3-> 30 -pool-> 15 -conv5-> 11 -pool-> 5.
    """

    def __init__(self):
        super().__init__()
        # Keep registration order: it fixes parameter-initialisation order
        # and the state_dict layout.
        self.conv1 = nn.Conv2d(CHANNELS, 6, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, NUM_CLASSES)

    def forward(self, x, **kwargs):
        """Return unnormalised class scores; extra keyword args are ignored."""
        del kwargs  # Unused.

        # NOTE(review): dims are passed as (WIDTH, HEIGHT); Conv2d treats the
        # last two as (H, W). Equivalent here because WIDTH == HEIGHT.
        feats = x.view(-1, CHANNELS, WIDTH, HEIGHT)
        for conv in (self.conv1, self.conv2):
            feats = self.pool(F.relu(conv(feats)))
        feats = feats.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            feats = F.relu(hidden(feats))
        return self.fc3(feats)

class MShadowNet(nn.Module):
    """LeNet-style CNN wrapped by ``mtorch_shadow_model_fn`` (mnist/fmnist path).

    The hard-coded flatten size ``16 * 4 * 4`` matches a 28x28 input:
    28 -conv3-> 26 -pool-> 13 -conv5-> 9 -pool-> 4.
    """

    def __init__(self):
        super().__init__()
        # Keep registration order: it fixes parameter-initialisation order
        # and the state_dict layout.
        self.conv1 = nn.Conv2d(CHANNELS, 6, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, NUM_CLASSES)

    def forward(self, x, **kwargs):
        """Return unnormalised class scores; extra keyword args are ignored."""
        del kwargs  # Unused.

        # NOTE(review): dims are passed as (WIDTH, HEIGHT); Conv2d treats the
        # last two as (H, W). Equivalent here because WIDTH == HEIGHT.
        feats = x.view(-1, CHANNELS, WIDTH, HEIGHT)
        for conv in (self.conv1, self.conv2):
            feats = self.pool(F.relu(conv(feats)))
        feats = feats.view(-1, 16 * 4 * 4)
        for hidden in (self.fc1, self.fc2):
            feats = F.relu(hidden(feats))
        return self.fc3(feats)




class CNNMnist(nn.Module):
    """Two-conv CNN for 1-channel 28x28 images (not referenced elsewhere in this file).

    The ``320`` input width of ``fc1`` is 20 channels x 4 x 4, i.e. a 28x28
    input after two conv5 + 2x2-max-pool stages.

    Args:
        num_classes: width of the output layer. ``None`` (the default) uses
            the module-level ``NUM_CLASSES``. The value is resolved at
            construction time, not read from the CLI-only global ``args``,
            so the class can be instantiated when this file is imported.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def forward(self, x, **kwargs):
        """Return unnormalised class scores; extra keyword args are ignored."""
        del kwargs  # Unused.
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return self.fc2(x)





def torch_shadow_model_fn():
    """Build an unfitted skorch classifier around ``ShadowNet``.

    Adam at lr 0.01 for 100 epochs, one-hot cross-entropy loss, no internal
    train/validation split, on CUDA when ``USE_CUDA`` is set.
    """
    # NOTE(review): EpochScoring scores the validation split by default, and
    # train_split=None leaves none — confirm 'auc' is ever recorded.
    auc_scoring = EpochScoring(scoring='roc_auc', lower_is_better=False)
    return skorch.NeuralNetClassifier(
        module=ShadowNet,
        criterion=CrossEntropyOneHot,
        optimizer=torch.optim.Adam,
        lr=0.01,
        max_epochs=100,
        device='cuda' if USE_CUDA else 'cpu',
        callbacks=[('auc', auc_scoring)],
        train_split=None,
    )


def mtorch_shadow_model_fn():
    """Build an unfitted skorch classifier around ``MShadowNet``.

    Same training setup as ``torch_shadow_model_fn`` (Adam, lr 0.01,
    100 epochs, one-hot cross-entropy, no internal split) but for the
    28x28 grayscale network.
    """
    # NOTE(review): EpochScoring scores the validation split by default, and
    # train_split=None leaves none — confirm 'auc' is ever recorded.
    auc_scoring = EpochScoring(scoring='roc_auc', lower_is_better=False)
    return skorch.NeuralNetClassifier(
        module=MShadowNet,
        criterion=CrossEntropyOneHot,
        optimizer=torch.optim.Adam,
        lr=0.01,
        max_epochs=100,
        device='cuda' if USE_CUDA else 'cpu',
        callbacks=[('auc', auc_scoring)],
        train_split=None,
    )





class AttackNet(nn.Module):
    """MLP mapping a ``NUM_CLASSES``-wide feature row to one sigmoid score.

    Only wrapped by ``torch_attack_model_fn``; ``demo()`` builds its attack
    models from ``LogisticRegression`` instead.
    """

    def __init__(self):
        super().__init__()
        # Keep registration order: it fixes parameter-initialisation order
        # and the state_dict layout.
        self.fc1 = nn.Linear(NUM_CLASSES, 128)
        self.drop1 = nn.Dropout(p=0.3)
        self.fc2 = nn.Linear(128, 64)
        self.drop2 = nn.Dropout(p=0.2)
        self.fc3 = nn.Linear(64, 64)
        self.last_linear = nn.Linear(64, 1)

    def forward(self, x, **kwargs):
        """Return a ``(N, 1)`` tensor of values in (0, 1); extra kwargs ignored."""
        del kwargs  # Unused.

        hidden = x
        # Dropout is applied to the linear output before the ReLU.
        for linear, dropout in ((self.fc1, self.drop1), (self.fc2, self.drop2)):
            hidden = F.relu(dropout(linear(hidden)))
        hidden = F.relu(self.fc3(hidden))
        return torch.sigmoid(self.last_linear(hidden))



def torch_attack_model_fn():
    """Build an unfitted skorch classifier around ``AttackNet`` (BCE loss, Adam).

    No learning rate is passed, so skorch's own default applies. This
    factory is not used by ``demo()``.
    """
    # NOTE(review): AttackNet emits a single sigmoid column, while
    # NeuralNetClassifier.predict presumably takes an argmax over axis 1 —
    # verify; skorch.NeuralNetBinaryClassifier may be the intended wrapper.
    # NOTE(review): with train_split=None the validation-based AUC callback
    # likely has nothing to score — confirm.
    auc_scoring = EpochScoring(scoring='roc_auc', lower_is_better=False)
    return skorch.NeuralNetClassifier(
        module=AttackNet,
        criterion=nn.BCELoss,
        optimizer=torch.optim.Adam,
        max_epochs=100,
        device='cuda' if USE_CUDA else 'cpu',
        callbacks=[('auc', auc_scoring)],
        train_split=None,
    )



def _load_dataset(dataset):
    """Load ``((X_train, y_train), (X_test, y_test))`` for *dataset*.

    For the grayscale sets ('mnist', 'fmnist') a channel axis is inserted at
    position 1 with ``np.expand_dims``.

    Raises:
        ValueError: if *dataset* is not 'cifar10', 'mnist' or 'fmnist'.
    """
    if dataset == 'cifar10':
        (X_train, y_train), (X_test, y_test) = get_data()
        return (X_train, y_train), (X_test, y_test)

    loaders = {'mnist': get_mnist_data, 'fmnist': get_fmnist_data}
    if dataset not in loaders:
        raise ValueError(
            "unsupported DATASET %r; expected 'cifar10', 'mnist' or 'fmnist'"
            % (dataset,)
        )
    (X_train, y_train), (X_test, y_test) = loaders[dataset]()
    return (
        (np.expand_dims(X_train, 1), y_train),
        (np.expand_dims(X_test, 1), y_test),
    )


def _make_model_fn(dataset, max_epochs):
    """Return a zero-argument factory for the target/shadow model of *dataset*.

    The factory wraps ``torch_shadow_model_fn`` (cifar10) or
    ``mtorch_shadow_model_fn`` (mnist/fmnist) and sets ``max_epochs`` on the
    estimator. skorch forwards extra ``fit()`` keyword arguments to the
    module's ``forward``, where these nets discard them, so a Keras-style
    ``epochs=`` passed to ``fit()`` would have no effect.

    Raises:
        ValueError: if *dataset* is not 'cifar10', 'mnist' or 'fmnist'.
    """
    if dataset == 'cifar10':
        base_fn = torch_shadow_model_fn
    elif dataset in ('mnist', 'fmnist'):
        base_fn = mtorch_shadow_model_fn
    else:
        raise ValueError(
            "unsupported DATASET %r; expected 'cifar10', 'mnist' or 'fmnist'"
            % (dataset,)
        )

    def model_fn():
        model = base_fn()
        model.set_params(max_epochs=max_epochs)
        return model

    return model_fn


def demo(args):
    """Run an end-to-end membership-inference attack on ``DATASET``.

    1. Train the target model on the first 1000 training examples.
    2. Train ``args.num_shadows`` shadow models on test examples 1000:4000,
       which the target never saw.
    3. Fit per-class ``LogisticRegression`` attack models on shadow outputs.
    4. Attack the target with ``args.attack_test_dataset_size`` member
       (train) and non-member (test) examples and report accuracy / F1.

    Args:
        args: parsed CLI namespace providing ``target_epochs``,
            ``shadow_dataset_size``, ``attack_test_dataset_size``,
            ``num_shadows`` and ``num_classes``.

    Returns:
        dict with keys ``"correct"`` and ``"wrong"``: the true class indices
        of attack-test examples whose membership was guessed right / wrong.

    Raises:
        ValueError: if ``DATASET`` is unsupported, or if
            ``args.attack_test_dataset_size`` exceeds the target's training
            size (non-members would be labelled as members, and the "out"
            examples would overlap the attacker's shadow data).
    """
    num_classes = args.num_classes
    target_train_size = 1000
    n_attack = args.attack_test_dataset_size
    if n_attack > target_train_size:
        raise ValueError(
            "attack_test_dataset_size must be <= %d: only the first %d "
            "training examples are members of the target's training set"
            % (target_train_size, target_train_size)
        )

    (X_train, y_train), (X_test, y_test) = _load_dataset(DATASET)
    model_fn = _make_model_fn(DATASET, args.target_epochs)

    # Train the target model.
    print("Training the target model...")
    target_model = model_fn()
    target_model.fit(X_train[:target_train_size], y_train[:target_train_size])

    train_accuracy = target_model.score(
        X_train[:target_train_size], np.argmax(y_train[:target_train_size], axis=1)
    )
    test_accuracy = target_model.score(
        X_test[:target_train_size], np.argmax(y_test[:target_train_size], axis=1)
    )
    print("target model train accuracy:", train_accuracy)
    print("target model test accuracy:", test_accuracy)

    # Shadow models share the target's architecture and training setup.
    smb = ShadowModelBundle(
        model_fn,
        shadow_dataset_size=args.shadow_dataset_size,
        num_models=args.num_shadows,
    )

    # Attacker data comes from test rows 1000:4000. These were not in the
    # target's training set and are disjoint from the "out" rows
    # X_test[:n_attack] used below, since n_attack <= 1000.
    attacker_X_train, attacker_X_test, attacker_y_train, _ = train_test_split(
        X_test[1000:4000], y_test[1000:4000], test_size=0.1
    )
    print(attacker_X_train.shape, attacker_X_test.shape)

    print("Training the shadow models...")
    X_shadow, y_shadow = smb.fit_transform(attacker_X_train, attacker_y_train)

    # Assumes each row of X_shadow is [model outputs (num_classes) |
    # one-hot true label (num_classes)] — the layout the attack data uses.
    shadow_predict = np.argmax(X_shadow[:, :num_classes], axis=1)
    shadow_label = np.argmax(X_shadow[:, num_classes:], axis=1)
    print("shadow accuracy:", np.mean(shadow_predict == shadow_label))

    # ShadowModelBundle output is already in the AttackModelBundle format.
    amb = AttackModelBundle(LogisticRegression, num_classes=num_classes)

    print("Training the attack models...")
    X_shadow, y_shadow = shuffle(X_shadow, y_shadow)
    amb.fit(X_shadow, y_shadow)

    # Members (target training rows) vs. non-members (unseen test rows).
    data_in = X_train[:n_attack], y_train[:n_attack]
    data_out = X_test[:n_attack], y_test[:n_attack]
    attack_test_data, real_membership_labels = prepare_attack_data(
        target_model, data_in, data_out
    )

    attack_guesses = amb.predict(attack_test_data)
    hits = attack_guesses == real_membership_labels
    attack_accuracy = np.mean(hits)

    print("attack accuracy: ", attack_accuracy)
    print("attack F1: ", f1_score(real_membership_labels, attack_guesses))

    # Per-example true class, to see which classes leak membership.
    test_labels = np.argmax(attack_test_data[:, num_classes:2 * num_classes], axis=1)
    return {"correct": test_labels[hits], "wrong": test_labels[~hits]}




if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # (flag, default, help) for each command-line option; all parse as int.
    int_options = (
        ('--target_epochs', 100, "number of epochs to train target model"),
        ('--shadow_dataset_size', 1000, "size of data to train shadow model"),
        ('--attack_test_dataset_size', 1000, "size of data for attack test"),
        ('--num_shadows', 3, "num of shadow model"),
        ('--num_classes', 10, "num of classes"),
        ('--attack_epochs', 12, "num of epochs to train attack model"),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)

    args = parser.parse_args()
    demo(args)






