import torch 
from model import NEURAL
import time
import argparse
import os
from dataset import return_data
import ROA

# ---------------------------------------------------------------------------
# Command-line options and run configuration.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('-gpu_id', type=int, default=0, help='id number of the gpu device')
parser.add_argument('-attack_eps', type=int, default=8, help='noise magnitude in adversarial training')
parser.add_argument('-dataset', type=str, default='real_data', help='name of dataset')
args = parser.parse_args()

gpu_id = args.gpu_id
# In this script attack_eps only selects the learning-rate / batch-size
# preset further down; it is not passed to the attacker.
attack_eps = args.attack_eps

# Use the requested CUDA device when one is available, otherwise run on CPU.
device = torch.device('cuda:%d' % gpu_id if torch.cuda.is_available() else 'cpu')

# Default optimisation hyper-parameters.
lr_rate = 0.005
batch_size = 256
n_iters = 50000  # nominal number of training iterations
# Path of a pre-trained checkpoint; only referenced by the (currently
# commented-out) checkpoint-loading code.
main_sensor_path = ''

# Directory that receives the checkpoints written during training.
# makedirs(exist_ok=True) avoids the check-then-create race of
# os.path.exists() + os.mkdir() when several runs start at the same time.
ckpt_root = 'adv_train_doa_real_5x5_ckpt/'
os.makedirs(ckpt_root, exist_ok=True)



def test(model, attacker):
    """Measure clean and adversarial accuracy of ``model`` on ``testloader``.

    For every batch of the module-level ``testloader`` an adversarial copy of
    the inputs is produced with ``attacker.exhaustive_search`` and both the
    clean and the adversarial inputs are classified. A summary line with both
    accuracies is printed.

    Args:
        model: classifier called as ``model(X)``; the class prediction is the
            argmax over dimension 1 of its output.
        attacker: object exposing ``exhaustive_search``; it is called with the
            same arguments as in the training loop.

    Returns:
        Adversarial accuracy as a float in ``[0, 1]``, or ``0.0`` when the
        loader yields no samples.
    """
    correct = 0
    correct_adv = 0
    tot = 0

    for inputs, targets in testloader:
        X = inputs.to(device)
        # Keep the labels on the same device as the inputs, exactly as the
        # training loop does, so the attacker and the comparisons below never
        # mix CPU and GPU tensors.
        GT = targets.to(device)

        # NOTE(review): the positional arguments presumably configure the
        # attack (step size, iterations, sticker size, strides) -- confirm
        # against the ROA implementation.
        X_adv = attacker.exhaustive_search(X, GT, 0.05, 30, 5, 5, 2, 2, False)

        # Only the predicted labels are needed here, so skip autograd
        # bookkeeping for the two evaluation forward passes.
        with torch.no_grad():
            Y = torch.argmax(model(X), dim=1)
            Y_adv = torch.argmax(model(X_adv), dim=1)

        # Vectorised counting instead of a per-sample Python loop.
        tot += len(Y)
        correct += int((Y == GT).sum().item())
        correct_adv += int((Y_adv == GT).sum().item())

    print('acc = %d/%d, adv_acc = %d/%d' % (correct, tot, correct_adv, tot))

    if tot == 0:
        # Empty loader: report zero accuracy rather than dividing by zero.
        return 0.0
    return correct_adv / tot



# ---- Data ------------------------------------------------------------------
print('[Data] Preparing .... ')
# Earlier DataMain-based loading, kept for reference:
# data = DataMain(batch_size=batch_size)
# data.data_set_up(istrain=True)
# data.greeting()
(
    trainloader,
    testloader,
    len_trainset,
    len_testset,
) = return_data(args.dataset, batch_size, image_size=32)
print('[Data] Done .... ')


# ---- Model -----------------------------------------------------------------
print('[Model] Preparing .... ')
model = NEURAL(n_class=8, n_channel=3)
# Optional warm start from main_sensor_path (currently disabled):
# ckpt = torch.load(main_sensor_path,map_location = 'cpu')
# model.load_state_dict(ckpt)
model = model.to(device)

# Evaluate once in eval mode before any training to get a baseline.
# NOTE(review): the message below says "Load successfully" although the
# checkpoint loading above is commented out.
model.eval()
attacker = ROA.ROA(base_classifier=model, size=32)
adv_acc = test(model=model, attacker=attacker)
print('[Load successfully] initial adv accuracy = %f' % adv_acc)
print('[Model] Done .... ')


# Preset for large attack_eps values: smaller learning rate and batch size.
# NOTE(review): batch_size is reassigned here after the data loaders were
# already created with the previous value, so this does not appear to affect
# them -- confirm whether that is intended.
if attack_eps >= 16:
    lr_rate, batch_size = 0.002, 200


# Objective and optimiser used by the training loop.
loss_f = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr_rate,
    momentum=0.9,
    weight_decay=1e-4,
)

st = time.time()  # wall-clock start, used to report elapsed minutes

model.train()

# Training-loop bookkeeping:
#   max_acc     - best adversarial accuracy returned by test() so far
#   save_n      - number of checkpoints written so far
#   stable_iter - consecutive evaluations without an improvement
#   global_iter - number of training batches processed
#   out         - set to True to leave the outer training loop
max_acc = save_n = stable_iter = global_iter = 0
out = False

# Adversarial training loop.
#
# Each batch is replaced by its adversarial version before the optimisation
# step. Every 1000 iterations the model is evaluated with test(); a checkpoint
# is written whenever the adversarial accuracy improves. Training stops after
# 10 consecutive evaluations without improvement, or once n_iters iterations
# have been run (previously n_iters was only printed, never enforced).
print('[Training] Starting ...')
while not out:
    for inputs, targets in trainloader:
        global_iter += 1

        X = inputs.to(device)
        GT = targets.to(device)

        # 5x5 sticker for DOA
        X_adv = attacker.exhaustive_search(X, GT, 0.05, 30, 5, 5, 2, 2, False)

        Y = model(X_adv)
        loss = loss_f(Y, GT)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if global_iter % 10 == 0:
            # .item() hands the formatter a plain Python float instead of a
            # tensor that is still attached to the autograd graph.
            print('[process: %d/%d] Loss = %f' % (global_iter, n_iters, loss.item()))

        if global_iter % 1000 == 0:  # 50
            # Elapsed time is only needed when it is reported.
            now = (time.time() - st) / 60.0
            print(' ### Eval ###')
            print('Time = %f minutes, Iter = %d/%d, Loss = %f' % (now, global_iter, n_iters, loss.item()))
            model.eval()

            score = test(model, attacker)

            if score > max_acc:
                print('[save]..')
                max_acc = score
                stable_iter = 0
                torch.save(model.state_dict(), ckpt_root+'model_' + str(save_n) + '_adv_acc=%f.ckpt'%(score))
                save_n += 1
            else:
                stable_iter += 1
                if stable_iter == 10:
                    print('Stable ... Training END ..')
                    out = True
                    break
            model.train()

        if global_iter >= n_iters:
            # Honour the iteration budget shown in the progress messages.
            print('Reached %d iterations ... Training END ..' % n_iters)
            out = True
            break

    if global_iter == 0:
        # An empty loader would otherwise make the outer loop spin forever.
        raise RuntimeError('trainloader produced no batches; nothing to train on')


print('[Training] Done ...')