
# -*- coding: utf-8 -*-
"""
Created on Thu Apr  9 02:34:41 2020

@author: linghm
"""

import shutil
from PIL import Image
from skimage import io
import os 
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import scipy
import scipy.misc
import argparse
import sys 
import numpy as np 
import matplotlib.pyplot as plt
import os
import math

import torchvision 
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler

import model
from metrics import eval_attack_model
from train import *
from resnet_cifar import *
from metrics import * 

# ---------------------------------------------------------------------------
# Command-line configuration. Parsed at import time; `args` is read as a
# global by adjust_learning_rate(), train() and the script body below.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='membership inference attack')
# NOTE(review): this flag uses an underscore while the others use dashes.
parser.add_argument('--num_classes', default=2, type=int, metavar='N', help='number of total classes to classify')
# Forwarded to en_preactresnet8_cifar (num_ensembles / noise_coef) and
# embedded in every checkpoint/result filename this script writes.
parser.add_argument('--num-ensembles', '--ne', default=1, type=int, metavar='N')
parser.add_argument('--noise-coef', '--nc', default=0.0, type=float, metavar='W', help='forward noise')

# SGD hyperparameters used for both the target and the shadow network.
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, metavar='W',
                    help='weight decay ')

# Base value of the step schedule in adjust_learning_rate(); the attack
# network's Adam optimizer is also constructed with this value.
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')


parser.add_argument('--epochs', default=100, type=int, metavar='N', help='epochs of training')
# Batch size of the training-split loaders (the test loader uses a fixed 100).
parser.add_argument('--batch-size', '--bs', default=128, type=int, metavar='N')
args = parser.parse_args()

# Passed as n_in to the attack MLP and as k to train_attacker /
# eval_attack_model; presumably the number of top posterior scores fed to
# the attacker — TODO confirm against those helpers.
k = 2   

print("Python: %s" % sys.version)
print("Pytorch: %s" % torch.__version__)

# determine device to run network on (runs on gpu if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



class IDC_train(Dataset):
    """Image-folder dataset for the IDC training split.

    ``save_path`` must contain one sub-directory per class; every regular file
    inside a class directory is treated as an image of that class.  Labels are
    assigned by *sorted* class-directory name, so the mapping is reproducible
    and agrees with any other folder indexed the same way (e.g. the test split).

    Attributes
    ----------
    n_classes     : int
                    number of class directories found under ``save_path``
    file_list     : list of str
                    image paths, shuffled with a fixed seed (42)
    people_to_idx : dict
                    class-directory name -> integer label
    """

    def __init__(self, save_path, transform=None):
        self.save_path = save_path
        self.transform = transform
        n_classes, file_list, class_to_label = self.index()
        self.n_classes = n_classes
        self.file_list = file_list
        self.people_to_idx = class_to_label

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th (shuffled) file."""
        img_path = self.file_list[idx]
        # The context manager guarantees the file handle is released;
        # convert() reads the pixel data before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # The parent directory name is the class. Using os.path keeps this
        # portable: paths are built with os.path.join, which uses '\\' on
        # Windows, so splitting on '/' would fail there.
        label = self.people_to_idx[os.path.basename(os.path.dirname(img_path))]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def index(self):
        """Scan ``save_path``.

        Returns
        -------
        (n_classes, file_list, class_to_idx)
            number of classes, shuffled list of image paths, and the
            class-name -> label mapping.
        """
        data_dir = self.save_path
        # os.listdir order is arbitrary: sort so that label assignment and the
        # seeded shuffle below are reproducible across machines and identical
        # for the train and test folders. Non-directories at the top level
        # (e.g. .DS_Store) are not classes and are skipped.
        class_list = sorted(
            name for name in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, name))
        )
        class_to_idx = {name: label for label, name in enumerate(class_list)}

        img_paths = []
        for name in class_list:
            class_dir = os.path.join(data_dir, name)
            img_paths.extend(
                os.path.join(class_dir, fname)
                for fname in sorted(os.listdir(class_dir))
                if os.path.isfile(os.path.join(class_dir, fname))
            )

        # NOTE: this reseeds NumPy's *global* RNG, as the original code did.
        seed = 42
        np.random.seed(seed)
        np.random.shuffle(img_paths)

        return len(class_list), img_paths, class_to_idx
    
class IDC_test(Dataset):
    """Image-folder dataset for the IDC test split.

    ``save_path`` must contain one sub-directory per class; every regular file
    inside a class directory is treated as an image of that class.  Labels are
    assigned by *sorted* class-directory name, so the mapping is reproducible
    and agrees with any other folder indexed the same way (e.g. the train split).

    Attributes
    ----------
    n_classes     : int
                    number of class directories found under ``save_path``
    file_list     : list of str
                    image paths, shuffled with a fixed seed (42)
    people_to_idx : dict
                    class-directory name -> integer label
    """

    def __init__(self, save_path, transform=None):
        self.save_path = save_path
        self.transform = transform
        n_classes, file_list, class_to_label = self.index()
        self.n_classes = n_classes
        self.file_list = file_list
        self.people_to_idx = class_to_label

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th (shuffled) file."""
        img_path = self.file_list[idx]
        # The context manager guarantees the file handle is released;
        # convert() reads the pixel data before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # The parent directory name is the class. Using os.path keeps this
        # portable: paths are built with os.path.join, which uses '\\' on
        # Windows, so splitting on '/' would fail there.
        label = self.people_to_idx[os.path.basename(os.path.dirname(img_path))]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def index(self):
        """Scan ``save_path``.

        Returns
        -------
        (n_classes, file_list, class_to_idx)
            number of classes, shuffled list of image paths, and the
            class-name -> label mapping.
        """
        data_dir = self.save_path
        # os.listdir order is arbitrary: sort so that label assignment and the
        # seeded shuffle below are reproducible across machines and identical
        # for the train and test folders. Non-directories at the top level
        # (e.g. .DS_Store) are not classes and are skipped.
        class_list = sorted(
            name for name in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, name))
        )
        class_to_idx = {name: label for label, name in enumerate(class_list)}

        img_paths = []
        for name in class_list:
            class_dir = os.path.join(data_dir, name)
            img_paths.extend(
                os.path.join(class_dir, fname)
                for fname in sorted(os.listdir(class_dir))
                if os.path.isfile(os.path.join(class_dir, fname))
            )

        # NOTE: this reseeds NumPy's *global* RNG, as the original code did.
        seed = 42
        np.random.seed(seed)
        np.random.shuffle(img_paths)

        return len(class_list), img_paths, class_to_idx
   

# Target (attacked) and shadow (mimic) networks share one architecture.
# They are constructed target-first, exactly as before, so any random weight
# initialisation inside the constructor draws from the RNG in the same order.
target_net_type = en_preactresnet8_cifar(
    num_ensembles=args.num_ensembles,
    noise_coef=args.noise_coef,
    num_classes=args.num_classes,
)
shadow_net_type = en_preactresnet8_cifar(
    num_ensembles=args.num_ensembles,
    noise_coef=args.noise_coef,
    num_classes=args.num_classes,
)

# Load IDC dataset: every image is resized to 32x32 and converted to a tensor.
train_save_path = './data/IDC_train'
test_save_path = './data/IDC_test'

trainset = IDC_train(
    save_path=train_save_path,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]),
)
# NOTE(review): trainloader appears unused in this script; training goes
# through the per-quarter loaders built from trainset further down.
trainloader = torch.utils.data.DataLoader(
    trainset,
    num_workers=0,
    batch_size=args.batch_size,
)

testset = IDC_test(
    save_path=test_save_path,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]),
)
# Held-out split, evaluated in fixed batches of 100.
testloader = torch.utils.data.DataLoader(
    testset,
    num_workers=0,
    batch_size=100,
)


def adjust_learning_rate(optimizer, epoch):
    """Apply a step-decay schedule to every parameter group of ``optimizer``.

    The rate starts at ``args.lr`` and is multiplied by 0.25 at epochs 20,
    40, 60 and 80; from epoch 80 onwards it stays at ``args.lr * 0.25**4``.
    """
    # Completed 20-epoch stages, clamped to [0, 4]; negative epochs behave
    # like the first stage.
    n_decays = min(max(epoch // 20, 0), 4)
    # 0.25**n is an exact power of two, so this equals the repeated
    # "* 0.25" products bit-for-bit.
    new_lr = args.lr * 0.25 ** n_decays
    for group in optimizer.param_groups:
        group['lr'] = new_lr
     


def train(model=None, data_loader=None, test_loader=None,
          optimizer=None, criterion=None, n_epochs=0,
          classes=None, verbose=False,flag=None):
    '''
    Train ``model`` for ``n_epochs`` epochs, report train/test accuracy at the
    end of every epoch, then save the final weights as a checkpoint.

    Parameters
    ----------
    model       : Module
                  PyTorch conforming nn.Module function
    data_loader : DataLoader
                  PyTorch dataloader yielding (data, labels) training batches
    test_loader : DataLoader
                  PyTorch dataloader used for the per-epoch test evaluation
    optimizer   : opt object
                  PyTorch conforming optimizer; its learning rate is reset
                  every epoch by adjust_learning_rate()
    criterion   : loss object
                  PyTorch conforming loss function
    n_epochs    : int
                  number of training epochs
    classes     : list
                  list of classes, forwarded to eval_target_model
    verbose     : boolean
                  flag for per-batch loss print statements
    flag        : str or None
                  'shadow' saves the checkpoint as shadow_checkpoint_*.pth;
                  any other value saves it as target_checkpoint_*.pth

    Returns
    -------
    (train_acc, test_acc)
                  values returned by eval_target_model after the last epoch,
                  or (None, None) when n_epochs is 0.
    '''
    # Bound up front so n_epochs == 0 returns instead of raising
    # UnboundLocalError at the final return.
    train_acc = test_acc = None

    for epoch in range(n_epochs):
        model.train()
        adjust_learning_rate(optimizer, epoch)
        for i, (data, labels) in enumerate(data_loader):
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(data), labels)
            loss.backward()
            optimizer.step()

            if verbose:
                print("[{}/{}][{}/{}] loss = {}"
                      .format(epoch, n_epochs, i,
                              len(data_loader), loss.item()))

        # evaluate performance on both sets at the end of each epoch
        print("[{}/{}]".format(epoch, n_epochs))
        print("Training:")
        train_acc = eval_target_model(model, data_loader, classes=classes)
        print("Test:")
        test_acc = eval_target_model(model, test_loader, classes=classes)

    fdir = 'result/IDC_modified/enresnet8_const_noise/SGD'
    # Don't rely on the caller having created the output directory.
    os.makedirs(fdir, exist_ok=True)
    prefix = 'shadow' if flag == 'shadow' else 'target'
    # Name the file after the epochs actually trained (n_epochs), not the
    # global CLI value, so direct calls with other epoch counts are labelled
    # correctly.
    filepath = os.path.join(
        fdir,
        '{}_checkpoint_ensemble{}_noise_{}_epoch_{}.pth'.format(
            prefix, args.num_ensembles, args.noise_coef, n_epochs),
    )
    torch.save({'state_dict': model.state_dict()}, filepath)

    return train_acc, test_acc

# In[ ]:

# Carve the (already shuffled) IDC training folder into four disjoint quarters:
#   shadow members | shadow non-members | target members | target non-members
# Integer division means the final quarter also absorbs the 0-3 leftover samples.
total_size = len(trainset)
split1 = total_size // 4
split2, split3 = 2 * split1, 3 * split1

indices = list(range(total_size))
shadow_train_idx, shadow_out_idx, target_train_idx, target_out_idx = (
    indices[start:stop]
    for start, stop in ((0, split1), (split1, split2), (split2, split3), (split3, total_size))
)

# Subset wrappers over trainset, one per quarter.
shadow_train_set, shadow_out_set, target_train_set, target_out_set = (
    torch.utils.data.Subset(trainset, part)
    for part in (shadow_train_idx, shadow_out_idx, target_train_idx, target_out_idx)
)

# Every quarter is iterated in a fixed (unshuffled) order.
shadow_train_loader, shadow_out_loader, target_train_loader, target_out_loader = (
    torch.utils.data.DataLoader(part, batch_size=args.batch_size, shuffle=False, num_workers=0)
    for part in (shadow_train_set, shadow_out_set, target_train_set, target_out_set)
)


# Target model: the network being attacked. Trained with SGD; train() resets
# its learning rate every epoch via adjust_learning_rate().
target_net = target_net_type.to(device)
target_loss = nn.CrossEntropyLoss()
target_optim = optim.SGD(target_net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# Alternative optimizer, kept disabled for experiments:
#target_optim = optim.Adam(target_net.parameters(), lr=args.lr)


# Shadow model: mimics the target network (same architecture, loss, optimizer
# and hyperparameters) but is trained on a disjoint quarter of the data.
shadow_net = shadow_net_type.to(device)

shadow_loss = nn.CrossEntropyLoss()
shadow_optim = optim.SGD(shadow_net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# Alternative optimizer, kept disabled for experiments:
#shadow_optim = optim.Adam(shadow_net.parameters(), lr=args.lr)


# Attack model: binary classifier deciding membership from n_in = k inputs.
# `model` here is the imported project module, not train()'s parameter.
# BCELoss expects probabilities, so mlleaks_mlp presumably ends in a sigmoid
# — TODO confirm.
# NOTE(review): Adam reuses args.lr (default 0.1, tuned for SGD), which is far
# above typical Adam learning rates; confirm this is intended.
attack_net = model.mlleaks_mlp(n_in=k).to(device)
attack_net.apply(model.weights_init)
attack_loss = nn.BCELoss()
attack_optim = optim.Adam(attack_net.parameters(), lr=args.lr)


# In[ ]:
# Output directory for this run's checkpoints and results. Defined once;
# exist_ok=True replaces the racy "check exists, then create" pattern.
fdir = 'result/IDC_modified/enresnet8_const_noise/SGD'
os.makedirs(fdir, exist_ok=True)

print('train')
# 1) Shadow model, trained on the shadow-member quarter.
train(shadow_net, shadow_train_loader, testloader, shadow_optim, shadow_loss, args.epochs, flag='shadow')

print('train_attacker')
# 2) Attack model, presumably fit on the shadow model's outputs for members
#    (shadow_train_loader) vs. non-members (shadow_out_loader) — see train_attacker.
train_attacker(attack_net, shadow_net, shadow_train_loader, shadow_out_loader, attack_optim, attack_loss, n_epochs=50, k=k, verbose=True)

print('train_target')
# 3) Target model (the one being attacked), trained on the target-member quarter.
train(target_net, target_train_loader, testloader, target_optim, target_loss, args.epochs, flag='target')



# In[ ]:

print('eval')
# Membership-inference evaluation: target_train_loader holds the target
# model's training members, target_out_loader holds samples it never trained on.
df_pr, df_roc, auc = eval_attack_model(attack_net, target_net, target_train_loader, target_out_loader, k)

# Accuracy of the target model on its own training data vs. the held-out split.
print("\nPerformance on training set: ")
train_accuracy = eval_target_model(target_net, target_train_loader, classes=None)

print("\nPerformance on test set: ")
test_accuracy = eval_target_model(target_net, testloader, classes=None)

# Persist the attack metrics and target accuracies alongside the checkpoints.
fdir = 'result/IDC_modified/enresnet8_const_noise/SGD'
filepath = os.path.join(
    fdir,
    'ensemble%s_noise_%s_epoch_%s.pth' % (args.num_ensembles, args.noise_coef, args.epochs),
)
state = dict(
    df_pr=df_pr,
    df_roc=df_roc,
    auc=auc,
    train_accuracy=train_accuracy,
    test_accuracy=test_accuracy,
)
torch.save(state, filepath)


