from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
from transfer_data import train_test_sets_generation, train_test_sets_generation_fashionmnist
from sklearn.metrics import roc_auc_score
from opacus import PrivacyEngine

import os
from itertools import count
import time
import random
import numpy as np
import math

from models.models import *
from models.preact_resnet import *

from torchvision.utils import save_image

# Restrict this process to GPU 0. This only takes effect if CUDA has not yet
# been initialised by torch (it is initialised lazily on first use).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

if not torch.cuda.is_available():
    print('cuda is required but cuda is not available')
    # Exit with a non-zero status so job schedulers / shell scripts see the
    # failure; a bare exit() would report success.
    raise SystemExit(1)

#== parser start
parser = argparse.ArgumentParser(description='PyTorch')

# (flag, type, default), registered in this order. argparse derives the
# attribute name from the flag, e.g. '--job-id' -> args.job_id.
# Non-string defaults are stored as given (e.g. --lr defaults to the int 100).
_ARG_SPECS = (
    # base setting 1: fixed
    ('--job-id', int, 1),
    ('--seed', int, 10000),
    # base setting 2: fixed
    ('--test-batch-size', int, 10000),
    ('--momentum', float, 0.9),
    ('--weight-decay', float, 1e-4),
    ('--data-path', str, './dataset/'),
    # experiment setting
    ('--dataset', str, 'fashionmnist_binary'),  # e.g. mnist_binary, mnist
    ('--data-aug', int, 0),
    ('--model', str, 'AUC_MLP'),                # e.g. LeNet, AUC_LeNet
    # method setting
    ('--lr', float, 100),
    ('--lr_c', float, 210),
    ('--batch-size', int, 256),
    ('--sigma', float, 1.256),
    ('--proj_c', float, 2),
    ('--grad_w_max', float, 0.00024),
    ('--grad_a_max', float, 0.0000197),
    ('--grad_b_max', float, 0.0000257),
    ('--grad_c_max', float, 0.0000434),
    ('--ssize', int, 64),
    # --method=0: standard
    # --method=1: q-SGD
    ('--method', int, 0),
)
for _flag, _type, _default in _ARG_SPECS:
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()
#== parser end
# Dataset directory: args.data_path is expected to end with a separator.
data_path = args.data_path + args.dataset
os.makedirs(data_path, exist_ok=True)

result_path = './results/'
os.makedirs(result_path, exist_ok=True)
# File stem encodes the run configuration:
# ./results/<dataset>_<data_aug>_<model>_<method>_<batch_size>[_<ssize>]_<job_id>
result_path += args.dataset + '_' + str(args.data_aug) + '_' + args.model
result_path += '_' + str(args.method) + '_' + str(args.batch_size)
if args.method != 0:
    result_path += '_' + str(args.ssize)
result_path += '_' + str(args.job_id)

# Run log (kept open for the whole run). It starts with a copy of this
# script minus its first line, so each result file archives the exact code.
filep = open(result_path + '.txt', 'w')
with open(__file__) as f:
    filep.write('\n'.join(f.read().split('\n')[1:]))
filep.write('\n\n')

out_str = str(args)
print(out_str)
filep.write(out_str + '\n')

if args.seed is None:
    # Not reachable with the parser default (10000); kept as a fallback.
    args.seed = 10000
# Seed every RNG the run may touch, numpy included, for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.enabled = True

out_str = 'initial seed = ' + str(args.seed)
print(out_str)
filep.write(out_str + '\n\n')

#===============================================================
#=== dataset setting
#===============================================================
kwargs = {}
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])
# Some datasets below replace these with explicit samplers; a DataLoader
# with a sampler must not also be given shuffle=True, hence Shuffle=False there.
train_Sampler = None
test_Sampler = None
Shuffle = True
if args.dataset == 'mnist':
    nh = 28
    nw = 28
    nc = 1
    num_class = 10
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=2),
            transforms.RandomAffine(15, scale=(0.85, 1.15)),
            transforms.ToTensor(),
        ])
    train_data = datasets.MNIST(data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.MNIST(data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'cifar10':
    nh = 32
    nw = 32
    nc = 3
    num_class = 10
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.CIFAR10(root=data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'cifar100':
    nh = 32
    nw = 32
    nc = 3
    num_class = 100
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    train_data = datasets.CIFAR100(root=data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.CIFAR100(root=data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'svhn':
    nh = 32
    nw = 32
    nc = 3
    num_class = 10
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    train_data = datasets.SVHN(data_path, split='train', download=True, transform=train_transform)
    test_data = datasets.SVHN(data_path, split='test', download=True, transform=test_transform)
elif args.dataset == 'fashionmnist':
    nh = 28
    nw = 28
    nc = 1
    num_class = 10
    end_epoch = 20
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    train_data = datasets.FashionMNIST(data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.FashionMNIST(data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'kmnist':
    nh = 28
    nw = 28
    nc = 1
    num_class = 10
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=2),
            transforms.ToTensor(),
        ])
    train_data = datasets.KMNIST(data_path, train=True, download=True, transform=train_transform)
    test_data = datasets.KMNIST(data_path, train=False, download=True, transform=test_transform)
elif args.dataset == 'semeion':
    nh = 16
    nw = 16
    nc = 1
    num_class = 10  # the digits from 0 to 9 (written by 80 people twice)
    end_epoch = 50
    if args.data_aug == 1:
        end_epoch = 200
        train_transform = transforms.Compose([
            transforms.RandomCrop(16, padding=1),
            transforms.RandomAffine(4, scale=(1.05, 1.05)),
            transforms.ToTensor(),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    # SEMEION has no official split: the same underlying samples are split
    # into train/test by a stored permutation (random_index.npy). A separate
    # dataset object is used for the test side so it gets test_transform
    # rather than the training augmentation.
    train_data = datasets.SEMEION(data_path, transform=train_transform, download=True)
    test_data = datasets.SEMEION(data_path, transform=test_transform, download=True)
    random_index = np.load(data_path + '/random_index.npy')
    train_size = 1000
    train_Sampler = SubsetRandomSampler(random_index[range(train_size)])
    test_Sampler = SubsetRandomSampler(random_index[range(train_size, len(test_data))])
    Shuffle = False
elif args.dataset == 'fakedata':
    nh = 24
    nw = 24
    nc = 3
    num_class = 10
    end_epoch = 50
    train_size = 1000
    test_size = 1000
    # One synthetic pool of train_size + test_size images, split by index.
    train_data = datasets.FakeData(size=train_size + test_size, image_size=(nc, nh, nw), num_classes=num_class, transform=train_transform)
    test_data = train_data
    train_Sampler = SubsetRandomSampler(range(train_size))
    test_Sampler = SubsetRandomSampler(range(train_size, len(test_data)))
    Shuffle = False
elif args.dataset == 'mnist_binary':
    nh = 28
    nw = 28
    nc = 1
    num_class = 1
    end_epoch = 10
    # n_train_pos / n_train_neg are used as per-class normalisers in the AUC loss.
    train_data, test_data, n_train_pos, n_train_neg = train_test_sets_generation(args.dataset, data_path)
    Shuffle = False
elif args.dataset == 'fashionmnist_binary':
    nh = 28
    nw = 28
    nc = 1
    num_class = 1
    end_epoch = 10
    # n_train_pos / n_train_neg are used as per-class normalisers in the AUC loss.
    train_data, test_data, n_train_pos, n_train_neg = train_test_sets_generation_fashionmnist(args.dataset, data_path)
    Shuffle = False
else:
    print('specify dataset')
    # Non-zero status: an unknown dataset is a configuration error.
    raise SystemExit(1)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,      sampler=train_Sampler, shuffle=Shuffle, **kwargs)
test_loader  = torch.utils.data.DataLoader(test_data,  batch_size=args.test_batch_size, sampler=test_Sampler,  shuffle=False,   **kwargs)

#===============================================================
#=== model setting
#===============================================================
# Build the network on the GPU. The AUC_* variants share the architecture of
# their plain counterparts; only the training objective differs.
if args.model in ('LeNet', 'AUC_LeNet'):
    model = LeNet(nc, nh, nw, num_class).cuda()
elif args.model == 'PreActResNet18':
    model = PreActResNet18(nc, num_class).cuda()
elif args.model in ('Linear', 'SVM', 'AUC_linear'):
    # Flattened input dimension.
    dx = nh * nw * nc
    model = Linear(dx, num_class).cuda()
elif args.model == 'AUC_MLP':
    # Flattened input dimension.
    dx = nh * nw * nc
    model = MLP(dx, num_class).cuda()
else:
    print('specify model')
    # Non-zero status: an unknown model is a configuration error.
    raise SystemExit(1)
    
#===============================================================
#=== utils def
#===============================================================
def lr_decay_func(optimizer, lr_decay=0.1):
    """Multiply the learning rate of every param group by ``lr_decay``.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry), e.g. a ``torch.optim.Optimizer``.
        lr_decay: multiplicative factor (previously ignored in favour of a
            hard-coded 0.1, which is still the default).

    Returns:
        The same optimizer, modified in place.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] *= lr_decay
    return optimizer
def lr_scheduler(optimizer, epoch, lr_decay=0.1, interval=10):
    """Step schedule: call ``lr_decay_func`` at fixed milestone epochs.

    Milestones depend on the global ``args.data_aug``: epochs 10 and 50
    without augmentation (0), epochs 10 and 100 with it (1); any other
    value never decays. ``interval`` is accepted but unused.

    Returns:
        The (possibly decayed) optimizer.
    """
    milestones_by_aug = {0: (10, 50), 1: (10, 100)}
    if epoch in milestones_by_aug.get(args.data_aug, ()):
        optimizer = lr_decay_func(optimizer, lr_decay=lr_decay)
    return optimizer
def lr_decay_func_AUC(optimizer, epoch):
    """Set each param group's learning rate to ``base_lr / epoch``.

    The base rate is the group's ``'lr'`` at the first call. It is remembered
    under the ``'base_lr_auc'`` key of the group. The previous
    implementation multiplied the *current* rate by ``1/epoch`` every epoch,
    which compounds to ``base_lr / epoch!``. That is inconsistent with the
    ``args.lr / epoch`` step used for the auxiliary AUC parameters in
    ``train``.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with
            ``'lr'``), e.g. a ``torch.optim.Optimizer``.
        epoch: 1-based epoch index.

    Returns:
        The same optimizer, modified in place.

    Raises:
        ValueError: if ``epoch < 1``.
    """
    if epoch < 1:
        raise ValueError('epoch must be >= 1, got {}'.format(epoch))
    for param_group in optimizer.param_groups:
        base_lr = param_group.setdefault('base_lr_auc', param_group['lr'])
        param_group['lr'] = base_lr / epoch
    return optimizer

class multiClassHingeLoss(nn.Module):
    def __init__(self):
        super(multiClassHingeLoss, self).__init__()
    def forward(self, output, y):
        index = torch.arange(0, y.size()[0]).long().cuda()
        output_y = output[index, y.data.cuda()].view(-1,1)
        loss = output - output_y + 1.0 
        loss[index, y.data.cuda()] = 0
        loss[loss < 0]=0
        loss = torch.sum(loss, dim=1) / output.size()[1]
        return loss 
hinge_loss = multiClassHingeLoss()
    
#===============================================================
#=== train optimization def
#===============================================================
# Auxiliary scalars of the square-loss AUC min-max objective used in train():
# para_a and para_b are updated by descent, para_c by (projected) ascent.
# All three start at zero on the GPU.
para_a, para_b, para_c = (torch.zeros(1, device='cuda') for _ in range(3))

# Plain SGD on the model weights; args.lr is only the starting rate, since
# train() rescales it at the beginning of every epoch.
# NOTE(review): args.momentum / args.weight_decay are parsed but not passed here.
optimizer = optim.SGD(params=model.parameters(), lr=args.lr)

####add dp
# privacy_engine = PrivacyEngine(
#     model,
#     sample_rate=args.batch_size/(n_train_pos+n_train_neg),
#     epochs = end_epoch,
#     target_epsilon = 50,
#     target_delta = 1e-5,
#     max_grad_norm=0.00098,  #0.00098
# )
# privacy_engine.attach(optimizer)
#
# print(f"Using sigma={privacy_engine.noise_multiplier} and C={0.00098}")
#### end dp

def train(epoch):
    """Run one epoch of noisy SGD on the square-loss AUC min-max objective.

    Each mini-batch proceeds as follows.

    1. Take an optimizer step on the batch-mean loss. The gradient is first
       perturbed with Gaussian noise of std
       ``args.sigma * args.grad_w_max / args.batch_size``.
    2. Update the auxiliary scalars by hand, each with its own Gaussian noise:
       - ``para_a`` and ``para_b`` by descent, with step ``args.lr / epoch``;
       - ``para_c`` by ascent, with step ``args.lr_c / epoch**(2/3)``,
         followed by projection onto the ball of radius ``args.proj_c``.

    NOTE(review): no per-sample gradient clipping is done here. The
    ``grad_*_max`` arguments only scale the noise, so any privacy guarantee
    relies on them being valid a-priori gradient bounds.

    Args:
        epoch: 1-based epoch index driving the step-size schedules.

    Side effects:
        Updates ``model`` in place and rebinds the module-level
        ``optimizer``, ``para_a``, ``para_b`` and ``para_c``.
    """
    global optimizer, para_a, para_b, para_c
    model.train()
    # optimizer = lr_scheduler(optimizer, epoch)
    optimizer = lr_decay_func_AUC(optimizer, epoch)

    is_binary = args.dataset in ('mnist_binary', 'fashionmnist_binary')
    is_auc_model = args.model in ('AUC_LeNet', 'AUC_MLP', 'AUC_linear')

    for x, y in train_loader:
        if is_binary:
            # Insert a channel axis at dim 1 (assumes the generator yields
            # (batch, H, W) images -- TODO confirm).
            x = x.cuda().unsqueeze(1).float()
            y = y.cuda()
            # 0/1 masks selecting positive and negative samples; labels are
            # expected to be +1 / -1.
            y_cp_pos = (y == 1).long()
            y_cp_neg = (y == -1).long()
            y = y.long()
        else:
            x = x.cuda()
            y = y.cuda()
        h1 = model(x)
        if args.model == 'SVM':
            cr_loss = hinge_loss(h1, y)
        elif is_auc_model:
            # Per-sample terms of the AUC objective: class membership comes
            # from the masks, each class is weighted by 1 / its train count.
            loss_1 = (1 / n_train_pos) * ((h1 - para_a) ** 2) * (y_cp_pos.view(-1, 1))
            loss_2 = (1 / n_train_neg) * ((h1 - para_b) ** 2) * (y_cp_neg.view(-1, 1))
            loss_3 = 2 * (1 + para_c) * ((1 / n_train_neg) * h1 * (y_cp_neg.view(-1, 1)) - (1 / n_train_pos) * h1 * (
                y_cp_pos.view(-1, 1))) - (1 / (n_train_neg + n_train_pos)) * (para_c ** 2)
            cr_loss = loss_1 + loss_2 + loss_3
        else:
            cr_loss = F.cross_entropy(h1, y, reduction='none')
        if args.method == 0:
            loss = torch.mean(cr_loss)
        else:
            print('specify method')
            raise SystemExit(1)

        optimizer.zero_grad()
        loss.backward()

        # Perturb the averaged gradient with Gaussian noise before stepping.
        # NOTE(review): the divisor is the nominal batch size, even for a
        # smaller final batch.
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                noise = torch.normal(0, args.sigma * args.grad_w_max, p.grad.shape, device='cuda')
                noise /= args.batch_size
                p.grad += noise

        optimizer.step()

        # Re-score the batch with the updated weights; no autograd graph needed.
        with torch.no_grad():
            h1_new = model(x)
        noise_a = torch.normal(0, args.sigma * args.grad_a_max, para_a.shape, device='cuda')
        noise_a /= args.batch_size
        noise_b = torch.normal(0, args.sigma * args.grad_b_max, para_b.shape, device='cuda')
        noise_b /= args.batch_size
        noise_c = torch.normal(0, args.sigma * args.grad_c_max, para_c.shape, device='cuda')
        noise_c /= args.batch_size
        para_a = para_a - args.lr * (1 / epoch) * (torch.mean((2 / n_train_pos) * (para_a - h1_new) * (y_cp_pos.view(-1, 1))) + noise_a)
        para_b = para_b - args.lr * (1 / epoch) * (torch.mean((2 / n_train_neg) * (para_b - h1_new) * (y_cp_neg.view(-1, 1))) + noise_b)
        para_c = para_c + args.lr_c * (1 / (epoch ** (2 / 3))) * (torch.mean(
            (2 / n_train_neg) * h1_new * (y_cp_neg.view(-1, 1))
            - (2 / n_train_pos) * h1_new * (y_cp_pos.view(-1, 1))
            - (2 / (n_train_neg + n_train_pos)) * para_c) + noise_c)
        # Project para_c back onto the ball of radius args.proj_c.
        c_norm = torch.norm(para_c).item()
        if c_norm > args.proj_c:
            para_c = para_c * (args.proj_c / c_norm)
    optimizer.zero_grad()


#===============================================================
#=== train/test output def
#===============================================================    
def output(data_loader):
    """Evaluate the model on ``data_loader``, log the result and return it.

    The loss is the sample-weighted mean over the whole loader. The AUC is
    computed once from the scores of every sample in the loader. The
    previous version averaged per-batch AUCs instead, which is biased for
    small batches and raises on a single-class batch.

    Args:
        data_loader: ``train_loader`` (evaluated with the model in train mode,
            as before) or ``test_loader`` (eval mode).

    Returns:
        ``(mean_loss, auc_percent)``.
    """
    if data_loader == train_loader:
        model.train()
    elif data_loader == test_loader:
        model.eval()
    is_binary = args.dataset in ('mnist_binary', 'fashionmnist_binary')
    is_auc_model = args.model in ('AUC_LeNet', 'AUC_MLP', 'AUC_linear')
    total_loss = 0
    total_size = 0
    all_labels = []
    all_scores = []
    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for x, y in data_loader:
            if is_binary:
                # Insert a channel axis at dim 1 (assumes (batch, H, W) input -- TODO confirm).
                x = x.cuda().unsqueeze(1).float()
                y = y.cuda()
                # 0/1 masks for positive (+1) and negative (-1) samples.
                y_cp_pos = (y == 1).long()
                y_cp_neg = (y == -1).long()
                y = y.long()
            else:
                x = x.cuda()
                y = y.cuda()

            h1 = model(x)
            if args.model == 'SVM':
                total_loss += torch.mean(hinge_loss(h1, y)).item() * y.size(0)
            elif is_auc_model:
                loss_1 = (1 / n_train_pos) * ((h1 - para_a) ** 2) * (y_cp_pos.view(-1, 1))
                loss_2 = (1 / n_train_neg) * ((h1 - para_b) ** 2) * (y_cp_neg.view(-1, 1))
                loss_3 = 2 * (1 + para_c) * ((1 / n_train_neg) * h1 * (y_cp_neg.view(-1, 1)) - (1 / n_train_pos) * h1 * (y_cp_pos.view(-1, 1))) - (1 / (n_train_neg + n_train_pos)) * (para_c ** 2)
                auc_loss = loss_1 + loss_2 + loss_3
                total_loss += torch.mean(auc_loss).item() * y.size(0)
            else:
                total_loss += F.cross_entropy(h1, y).item() * y.size(0)
            all_labels.append(y.cpu().numpy())
            # Keep one row per sample so batches concatenate cleanly.
            all_scores.append(h1.cpu().numpy().reshape(h1.size(0), -1))
            total_size += y.size(0)

    labels = np.concatenate(all_labels)
    scores = np.squeeze(np.concatenate(all_scores, axis=0))
    AUC = 100. * roc_auc_score(labels, scores)
    total_loss /= total_size
    if data_loader == train_loader:
        out_str = 'tr_l={} AUC={:.3f}:'.format(total_loss, AUC)
    elif data_loader == test_loader:
        out_str = 'te_l={} AUC={:.3f}:'.format(total_loss, AUC)
    print(out_str, end=' ')
    filep.write(out_str + ' ')
    return (total_loss, AUC)

#===============================================================
#=== start computation
#===============================================================    
#== for plot
# pl_result[epoch, split, metric]:
#   split 0 = train, 1 = test  -> metric 0 = loss, 1 = AUC (%)
#   split 2                    -> metric 0 = elapsed wall-clock seconds
pl_result = np.zeros((end_epoch + 1, 3, 2))
#== main loop start
time_start = time.time()
# Epoch 0 only evaluates the initial model; epochs 1..end_epoch train first.
for epoch in range(end_epoch + 1):
    out_str = str(epoch)
    print(out_str, end=' ')
    filep.write(out_str + ' ')
    if epoch >= 1:
        train(epoch)
    pl_result[epoch, 0, :] = output(train_loader)
    pl_result[epoch, 1, :] = output(test_loader)
    time_current = time.time() - time_start
    pl_result[epoch, 2, 0] = time_current
    np.save(result_path + '_' + 'pl', pl_result)
    out_str = 'time={:.1f}:'.format(time_current)
    print(out_str)
    filep.write(out_str + '\n')
    # Flush once per epoch so the log survives an interrupted run.
    filep.flush()