# -*- coding: utf-8 -*-
"""
"""
import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torchvision import datasets, transforms
from tqdm import tqdm

from model.vgg import VGG
from model.resnet import resnet56, resnet110
from model.densenet import DenseNet
from utl import dataset_loader


# Training settings
dataset_options = ['cifar10', 'cifar100','imagenet32']
# Values accepted by --model / --opt; they must match the dispatch chains
# that build the model and optimizer below.
model_options = ['Densenet', 'VGG', 'resnet56', 'resnet110']
optimizer_options = ['sgd', 'sgdm', 'sgdn', 'adam', 'Adadelta']

parser = argparse.ArgumentParser(description=' train')
parser.add_argument('--dataset', '-d', default='cifar10',
                    choices=dataset_options)
parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                    help='input batch size for training (default: 256)')
# NOTE(review): --test-batch-size is parsed but never passed to
# dataset_loader, so it currently has no effect.
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=400, metavar='N',
                    help='number of epochs to train (default: 400)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--gpu', type=int, default=0, metavar='G',
                    help='gpu device number')
# `choices` rejects unsupported names at parse time instead of failing
# later with a NameError when no branch assigns the optimizer / model.
parser.add_argument('--opt', type=str, default='sgd', choices=optimizer_options,
                    help='optimizer (default: sgd)')
parser.add_argument('--model', type=str, default='resnet56', choices=model_options,
                    help='model')

args = parser.parse_args()

# Use the requested GPU when CUDA is available; otherwise fall back to the
# CPU so the script still runs (slowly) without a GPU. --gpu is ignored then.
device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
torch.manual_seed(args.seed)

# dataset_loader is project-local; as used here it returns the train loader,
# the test loader and the number of classes.
train_loader, test_loader, data_classes = dataset_loader(args.dataset, args.batch_size)
if args.model == 'Densenet':
    # Depth-100 configuration: 3 dense blocks of (100 - 4) // 6 = 16 layers
    # each (presumably DenseNet-BC; confirm against model.densenet).
    model = DenseNet(growth_rate=12, block_config=[(100 - 4) // 6 for _ in range(3)],
                     num_classes=data_classes, small_inputs=True, efficient=True)
elif args.model == 'VGG':
    model = VGG('VGG16', num_classes=data_classes)
elif args.model == 'resnet56':
    model = resnet56(num_classes=data_classes)
elif args.model == 'resnet110':
    model = resnet110(num_classes=data_classes)
else:
    # Without this, an unsupported name leaves `model` unbound and the
    # script dies with a confusing NameError on the next line.
    raise ValueError('Unknown model: ' + str(args.model))

model.to(device)
criterion = nn.CrossEntropyLoss()

# All SGD variants use weight decay 5e-4; 'sgdm' adds momentum 0.9 and
# 'sgdn' adds Nesterov momentum on top of that.
if args.opt == 'sgdm':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9)
elif args.opt == 'adam':
    # NOTE(review): Adam ignores --lr and always uses 1e-3.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
elif args.opt == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0005)
elif args.opt == 'Adadelta':
    # Adadelta ignores --lr and uses lr=1.0; the training loop also skips
    # scheduler.step() for this optimizer.
    optimizer = optim.Adadelta(model.parameters(), lr=1.0)
elif args.opt == 'sgdn':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0005, momentum=0.9, nesterov=True)
else:
    # Without this, an unsupported name leaves `optimizer` unbound and the
    # script dies with a NameError when the scheduler is built.
    raise ValueError('Unknown optimizer: ' + str(args.opt))

# Multiply the learning rate by 0.1 at each milestone epoch.
if args.dataset == 'imagenet32':
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)
else:
    scheduler = MultiStepLR(optimizer, milestones=[50, 100], gamma=0.1)


def train(epoch, optimizer):
    """Run one training epoch over ``train_loader`` and return its metrics.

    Uses the module-level ``model``, ``criterion``, ``device`` and
    ``train_loader``. The tqdm progress bar shows the running mean batch
    loss and the running accuracy after every batch.

    Args:
        epoch: Epoch number; used only for the progress-bar label.
        optimizer: Optimizer whose ``step()`` is called once per batch.

    Returns:
        ``(mean_batch_loss, accuracy)`` for the epoch: the average of the
        per-batch losses and the fraction of training samples whose
        arg-max prediction equals the target.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    correct = 0
    total = 0
    loss_sum = 0.
    num_batches = 0
    progress_bar = tqdm(train_loader, desc='Epoch ' + str(epoch))

    for data, target in progress_bar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        num_batches += 1
        loss_sum += loss.item()
        pred = torch.max(output.detach(), 1)[1]
        total += target.size(0)
        # Accumulate as a Python int; .item() syncs once per batch, as the
        # running accuracy shown on the progress bar needs the value anyway.
        correct += (pred == target).sum().item()
        progress_bar.set_postfix(
            losserror='%.5f' % (loss_sum / num_batches),
            acc='%.3f' % (correct / total))

    if num_batches == 0:
        # Previously this surfaced as an UnboundLocalError on `batch_idx`.
        raise ValueError('train_loader yielded no batches')
    return loss_sum / num_batches, correct / total


def test(loader):
    """Compute and print the top-1 accuracy of ``model`` on ``loader``.

    Switches the module-level ``model`` to eval mode (BatchNorm then uses its
    running mean/var) and runs it on ``device`` without gradient tracking.

    Args:
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.
    """
    model.eval()
    num_right = 0.
    num_seen = 0.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(inputs).max(1)[1]
            num_seen += labels.size(0)
            num_right += (predicted == labels).sum().float().item()
        val_acc = num_right / num_seen
    print('\nTest set:  Accuracy: {}\n'.format(val_acc))
    return val_acc

def evaluteTop5(loader, k=5):
    """Compute and print the top-``k`` accuracy (default top-5) of ``model``.

    A sample counts as correct when its label is among the ``k`` largest
    logits produced by the module-level ``model`` (run on ``device`` in eval
    mode, without gradient tracking). ``k`` is clamped to the number of
    output classes, so models with fewer than ``k`` outputs do not make
    ``topk`` fail. Only samples actually yielded by ``loader`` are counted,
    so the result stays correct with ``drop_last=True`` or subset samplers.

    Args:
        loader: Iterable of ``(inputs, labels)`` batches; the model output is
            expected to be ``(batch, num_classes)`` and labels 1-D class ids.
        k: How many top predictions may contain the label.

    Returns:
        Fraction of evaluated samples whose label is in the top ``k``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            topk = min(k, logits.size(1))
            _, pred = logits.topk(topk, dim=1, largest=True, sorted=True)
            # Compare (batch, 1) labels against (batch, topk) indices; topk
            # indices are distinct, so each row contributes at most one match.
            correct += torch.eq(pred, y.view(-1, 1)).sum().item()
            total += y.size(0)
    val_acc = correct / total
    print('\nTest set: Top{} Accuracy: {}\n'.format(k, val_acc))
    return val_acc


# Per-epoch metrics log, one line per epoch, columns separated by two spaces:
#   imagenet32: train_loss  test_top1_acc  test_top5_acc  train_acc
#   otherwise:  train_loss  test_top1_acc  train_acc
log_dir = 'largelearningrate/file/'
# NOTE(review): checkpoints are written to a different directory than the
# metrics log; confirm this split is intentional.
checkpoint_dir = 'largelr/file/'
log_path = (log_dir + args.dataset + '_' + args.model + str(args.epochs) + 'epochs'
            + str(args.lr) + str(args.opt) + '.txt')
# Create output directories up front: the first checkpoint save happens only
# after a full training epoch, so a missing directory would otherwise crash
# the run late.
os.makedirs(log_dir, exist_ok=True)
if args.dataset == 'imagenet32':
    os.makedirs(checkpoint_dir, exist_ok=True)

with open(log_path, 'w') as log_file:
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train(epoch, optimizer)
        var_acc = test(test_loader)
        columns = [str(train_loss), '%.4f' % var_acc]
        if args.dataset == 'imagenet32':
            var_acc5 = evaluteTop5(test_loader)
            columns.append('%.4f' % var_acc5)
        columns.append('%.4f' % train_acc)
        log_file.write('  '.join(columns) + '\n')
        # Flush every epoch so the log survives a crash or kill mid-run.
        log_file.flush()
        if args.dataset == 'imagenet32':
            torch.save(model.state_dict(),
                       checkpoint_dir + str(args.opt) + str(args.lr) + str(epoch) + '.pkl')
        if args.opt != 'Adadelta':
            # The scheduler starts at last_epoch=0, so a bare step() after
            # epoch `epoch` reaches the same last_epoch as the deprecated
            # step(epoch) form, without the deprecation warning.
            scheduler.step()

