#!/usr/bin/env python
# coding: utf-8


# In[ ]:

import argparse
import sys 
import numpy as np 
import matplotlib.pyplot as plt
import os
import math

import torch
import torchvision 
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler

sys.path.insert(0, '../../../Utils/')

import model
from metrics import eval_attack_model
from train import *
from model import *
from resnet_cifar import *
from metrics import * 
# Command-line interface. Help texts use argparse's ``%(default)s`` specifier
# so the documented default can never drift away from the real one.
parser = argparse.ArgumentParser(description='membership inference attack')

# Network options (forwarded to the network constructor further down).
parser.add_argument('--num_classes', default=10, type=int, metavar='N', help='number of total classes to classify')
parser.add_argument('--num-ensembles', '--ne', default=1, type=int, metavar='N',
                    help='num_ensembles value passed to the network constructor (default: %(default)s)')
parser.add_argument('--noise-coef', '--nc', default=0.0, type=float, metavar='W',
                    help='forward noise (default: %(default)s)')

# Optimisation options for the target and shadow networks.
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, metavar='W',
                    help='weight decay (default: %(default)s)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='epochs of training')
parser.add_argument('--batch-size', '--bs', default=128, type=int, metavar='N',
                    help='mini-batch size of the training-set loaders (default: %(default)s)')

args = parser.parse_args()

# Used as n_in of the attack MLP and passed as k to train_attacker /
# eval_attack_model. Presumably the number of top posterior values fed to the
# attack model -- confirm against the Utils package.
k = 3

print("Python: %s" % sys.version)
print("Pytorch: %s" % torch.__version__)

# determine device to run network on (runs on gpu if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# In[ ]:


# Target and shadow networks use the same architecture and the same
# command-line hyper-parameters. The target instance is built first.
_net_kwargs = dict(
    num_ensembles=args.num_ensembles,
    noise_coef=args.noise_coef,
    num_classes=args.num_classes,
)
target_net_type = en_preactresnet8_cifar(**_net_kwargs)
shadow_net_type = en_preactresnet8_cifar(**_net_kwargs)


def _make_transform():
    """Return a fresh ToTensor + Normalize pipeline (no augmentation)."""
    # Presumably the CIFAR-10 per-channel mean / std -- confirm if the
    # dataset is changed.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(channel_mean, channel_std),
    ])


# Training and test images currently go through identical preprocessing,
# but each split gets its own transform object.
train_transform = _make_transform()
test_transform = _make_transform()


# CIFAR-10: training split (downloaded into ./data when missing).
print('Loading cifar10')
trainset = torchvision.datasets.CIFAR10("./data", transform=train_transform, download=True)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=0,
)

# CIFAR-10: test split, iterated in a fixed order with batches of 32.
testset = torchvision.datasets.CIFAR10("./data", train=False, transform=test_transform, download=True)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
)


# The two bare string literals below hold disabled alternative dataset setups
# (CIFAR-100 and MNIST). As plain expression statements they have no effect at
# runtime; the active configuration is the CIFAR-10 code above.
# NOTE(review): enabling one of them presumably also needs a matching
# --num_classes value, and the result directory used at the end of the script
# is hard-coded to 'result/cifar10/...' -- confirm before switching datasets.

# Disabled alternative: CIFAR-100.
'''
#Cifar100
# load training set 
print('Loading cifar100')
trainset = torchvision.datasets.CIFAR100("./data", transform=train_transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0)

# load test set 
testset = torchvision.datasets.CIFAR100("./data", train=False, transform=test_transform, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=0)
'''
# Disabled alternative: MNIST, with its own transforms (Pad(2) + ToTensor).
'''
#MNIST
# load training set 
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Pad(2),  
    torchvision.transforms.ToTensor()
])

test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Pad(2),
    torchvision.transforms.ToTensor(),
])
trainset = torchvision.datasets.MNIST("./data", transform=train_transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=0)

# load test set 
testset = torchvision.datasets.MNIST("./data", train=False, transform=test_transform, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=0)
'''


def adjust_learning_rate(optimizer, epoch):
    '''
    Apply the step-wise learning-rate schedule for the given epoch.

    The base rate is the module-level ``args.lr``. It is used as is while
    ``epoch`` is below 80 and scaled by 0.1, 0.01 and 0.001 once ``epoch``
    has reached 80, 120 and 160 respectively.

    Parameters
    ----------
    optimizer : opt object
                object exposing ``param_groups``; the ``'lr'`` entry of
                every group is overwritten in place
    epoch     : int
                current epoch index
    '''
    scale_per_stage = (1, 0.1, 0.01, 0.001)

    # The number of milestones already reached selects the stage (0..3).
    stage = sum(epoch >= milestone for milestone in (80, 120, 160))
    new_lr = args.lr * scale_per_stage[stage]

    for group in optimizer.param_groups:
        group['lr'] = new_lr

def train(model=None, data_loader=None, test_loader=None,
          optimizer=None, criterion=None, n_epochs=0,
          classes=None, verbose=False):
    '''
    Function to train a model provided
    specified train/test sets and associated
    training parameters.

    Before every epoch the learning rate of ``optimizer`` is reset by
    ``adjust_learning_rate``; after every epoch the model is evaluated on
    both loaders with ``eval_target_model``.

    Parameters
    ----------
    model       : Module
                  PyTorch conforming nn.Module function
    data_loader : DataLoader
                  PyTorch dataloader function
    test_loader : DataLoader
                  PyTorch dataloader function
    optimizer   : opt object
                  PyTorch conforming optimizer function
    criterion   : loss object
                  PyTorch conforming loss function
    n_epochs    : int
                  number of training epochs
    classes     : list
                  list of classes (forwarded to eval_target_model)
    verbose     : boolean
                  flag for verbose print statements

    Returns
    -------
    (train_acc, test_acc) : tuple
                  values returned by eval_target_model for data_loader and
                  test_loader after the last epoch, or (None, None) when
                  no epoch was run
    '''
    # Defined up front so that n_epochs <= 0 returns cleanly instead of
    # raising UnboundLocalError at the return statement.
    train_acc = test_acc = None

    for epoch in range(n_epochs):
        # Re-enter training mode every epoch; the evaluation at the end of
        # the previous epoch may have switched the model to eval mode.
        model.train()

        adjust_learning_rate(optimizer, epoch)
        for i, (data, labels) in enumerate(data_loader):
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(data)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if verbose:
                print("[{}/{}][{}/{}] loss = {}"
                      .format(epoch, n_epochs, i,
                              len(data_loader), loss.item()))

        # evaluate performance on testset at the end of each epoch
        print("[{}/{}]".format(epoch, n_epochs))
        print("Training:")
        train_acc = eval_target_model(model, data_loader, classes=classes)
        print("Test:")
        test_acc = eval_target_model(model, test_loader, classes=classes)

    return train_acc, test_acc

# In[ ]:



# Split the training set into four parts: shadow-train, shadow-out,
# target-train and target-out. The first three hold total_size // 4 samples
# each; the last one takes whatever remains.
total_size = len(trainset)
split1 = total_size // 4
split2, split3 = 2 * split1, 3 * split1

# Fixed NumPy seed so the index permutation is the same on every run.
seed = 42
np.random.seed(seed)
indices = list(range(total_size))
np.random.shuffle(indices)

shadow_train_idx, shadow_out_idx, target_train_idx, target_out_idx = (
    indices[:split1],
    indices[split1:split2],
    indices[split2:split3],
    indices[split3:],
)

(shadow_train_sampler,
 shadow_out_sampler,
 target_train_sampler,
 target_out_sampler) = (
    SubsetRandomSampler(subset)
    for subset in (shadow_train_idx, shadow_out_idx,
                   target_train_idx, target_out_idx)
)


def _make_subset_loader(subset_sampler):
    """Wrap ``trainset`` in a DataLoader that draws from ``subset_sampler``."""
    return torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size,
        sampler=subset_sampler,
        num_workers=0,
    )


shadow_train_loader = _make_subset_loader(shadow_train_sampler)
shadow_out_loader = _make_subset_loader(shadow_out_sampler)

target_train_loader = _make_subset_loader(target_train_sampler)
target_out_loader = _make_subset_loader(target_out_sampler)


# Target and shadow networks are optimised with identically configured SGD
# (an Adam variant was tried here previously and is no longer used).
_sgd_kwargs = dict(lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

# the model being attacked
target_net = target_net_type.to(device)
target_loss = nn.CrossEntropyLoss()
target_optim = optim.SGD(target_net.parameters(), **_sgd_kwargs)

# shadow net mimics the target network
shadow_net = shadow_net_type.to(device)
shadow_loss = nn.CrossEntropyLoss()
shadow_optim = optim.SGD(shadow_net.parameters(), **_sgd_kwargs)

# attack net is a binary classifier to determine membership
attack_net = model.mlleaks_mlp(n_in=k).to(device)
attack_net.apply(model.weights_init)
attack_loss = nn.BCELoss()
# NOTE(review): Adam is given the same args.lr as the SGD optimizers
# (default 0.1) -- confirm this is intended for the attack model.
attack_optim = optim.Adam(attack_net.parameters(), lr=args.lr)


# In[ ]:





# Experiment pipeline: train the shadow model, fit the attack model on the
# shadow model's member / non-member loaders, train the target model, then
# evaluate the attack against the target and store the results.
print('train shadow')
train(shadow_net, shadow_train_loader, testloader, shadow_optim, shadow_loss, args.epochs)

# In[ ]:

print('train_attacker')
train_attacker(attack_net, shadow_net, shadow_train_loader, shadow_out_loader, attack_optim, attack_loss, n_epochs=50, k=k,verbose=True)  # originally 50


# In[ ]:

print('train_target')
train(target_net, target_train_loader, testloader, target_optim, target_loss, args.epochs)

# In[ ]:

print('eval')
df_pr,df_roc,auc= eval_attack_model(attack_net, target_net, target_train_loader, target_out_loader, k)

print("\nPerformance on training set: ")
train_accuracy = eval_target_model(target_net, target_train_loader, classes=None)

print("\nPerformance on test set: ")
test_accuracy = eval_target_model(target_net, testloader, classes=None)


# Create the output directory in a single call; exist_ok avoids the
# check-then-create race and the repeated path literal.
fdir = 'result/cifar10/enresnet8_const_noise/SGD'
os.makedirs(fdir, exist_ok=True)


# One result file per (ensembles, noise coefficient, epochs) configuration.
filepath = os.path.join(fdir, 'ensemble{}_noise{}_epoch{}.pth'.format(args.num_ensembles,args.noise_coef,args.epochs))

state={
            'df_pr': df_pr,
            'df_roc':df_roc,
            'auc':auc,
            
            'train_accuracy': train_accuracy,
            'test_accuracy':test_accuracy,
        }
torch.save(state, filepath)


