import argparse
import pdb
import sys 
import pickle
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
import torchvision.models as models
from advertorch.utils import NormalizeByChannelMeanStd

from models.resnetv2 import ResNet18
from models.wideresnet import WideResNet
from models.vgg import VGG
from datasets import *
import utils

# Command-line interface for the training script. Flag names, types and
# defaults are read by main() through the parsed ``args`` namespace.
parser = argparse.ArgumentParser(description='PyTorch Cifar10 Training')

########################## base setting ##########################
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset')
# Help text previously read 'dataset' (copy-paste slip); it describes the model.
parser.add_argument('--arch', type=str, default='resnet18',
                    help='model architecture: resnet18, wideresnet, otherwise the value is passed to VGG')
parser.add_argument('--print_freq', default=50, type=int, help='print frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--save_dir', help='The directory used to save the trained models', default='adv', type=str)
parser.add_argument('--resume', action="store_true", help="resume from checkpoint")
parser.add_argument('--seed', default=None, type=int, help='random seed')
parser.add_argument('--width_factor', default=10, type=int, help='width-factor of wideresnet')
parser.add_argument('--result_file', help='file name for result', default='result.pkl', type=str)

########################## training setting ##########################
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
# Parsed in main() by splitting on ',' and converting each piece to int.
parser.add_argument('--decreasing_lr', default='50,150', help='decreasing strategy')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run')

# Best validation accuracy seen so far; main() rebinds it via ``global``.
best_prec1 = 0

def _build_model(arch, num_classes, width_factor):
    """Instantiate the network selected by ``arch`` with ``num_classes`` outputs.

    ``'resnet18'`` and ``'wideresnet'`` are handled explicitly; any other
    value is forwarded unchanged to ``VGG`` as its first argument.
    """
    if arch == 'resnet18':
        return ResNet18(num_classes=num_classes)
    if arch == 'wideresnet':
        return WideResNet(34, num_classes, widen_factor=width_factor, dropRate=0.0)
    return VGG(arch, num_classes=num_classes)


def _prepare_dataset(opts):
    """Build the model and data loaders for ``opts.dataset``.

    Returns:
        Tuple ``(model, train_loader, val_loader, test_loader)``. The model
        has a ``normal`` attribute holding the per-channel normalization
        layer for the chosen dataset.

    Raises:
        ValueError: if ``opts.dataset`` is not one of the supported names.
    """
    if opts.dataset == 'cifar10':
        print('training on cifar10 dataset')
        num_classes = 10
        mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
        make_loaders = cifar10_dataloaders
    elif opts.dataset == 'cifar100':
        print('training on cifar100 dataset')
        num_classes = 100
        mean, std = [0.5071, 0.4866, 0.4409], [0.2009, 0.1984, 0.2023]
        make_loaders = cifar100_dataloaders
    elif opts.dataset == 'tinyimagenet':
        print('training on tiny-imagenet dataset')
        num_classes = 200
        mean, std = [0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]
        make_loaders = tiny_imagenet_dataloaders
    else:
        # Fail here with a clear message instead of an UnboundLocalError
        # later when the (never created) model is first used.
        raise ValueError('dataset not support: {!r}'.format(opts.dataset))

    model = _build_model(opts.arch, num_classes, opts.width_factor)
    model.normal = NormalizeByChannelMeanStd(mean=mean, std=std)
    train_loader, val_loader, test_loader = make_loaders(
        train_batch_size=opts.batch_size,
        test_batch_size=opts.batch_size,
        data_dir=opts.data,
    )
    return model, train_loader, val_loader, test_loader


def main():
    """Parse the command line and run the full training loop.

    Each epoch trains on the training loader, evaluates on the validation
    and test loaders, writes a checkpoint (plus a "best" checkpoint when
    validation accuracy improves), refreshes the accuracy plot
    ``net_train.png`` and dumps the accuracy history to
    ``args.result_file`` inside ``args.save_dir``.

    Raises:
        ValueError: if ``--dataset`` names an unsupported dataset.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    torch.cuda.set_device(int(args.gpu))

    # Compare against None so that an explicit ``--seed 0`` is honoured.
    if args.seed is not None:
        print('set random seed = ', args.seed)
        utils.setup_seed(args.seed)

    ########################## prepare dataset ##########################
    model, train_loader, val_loader, test_loader = _prepare_dataset(args)
    model.cuda()

    ########################## optimizer and scheduler ##########################
    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)

    ########################## resume ##########################
    # Must come after the optimizer and scheduler exist: their state is
    # restored from the checkpoint as well.
    start_epoch = 0
    if args.resume:
        print('resume from checkpoint')
        checkpoint = torch.load(os.path.join(args.save_dir, 'checkpoint.pt'), map_location=torch.device('cuda:' + str(args.gpu)))
        best_prec1 = checkpoint['best_prec1']
        # Checkpoints are written with ``epoch + 1``, i.e. the next epoch to run.
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    ########################## record result ##########################
    all_result = {}
    all_train_acc = []
    all_val_sa = []
    all_test_sa = []

    os.makedirs(args.save_dir, exist_ok=True)

    ########################## training process ##########################
    for epoch in range(start_epoch, args.epochs):

        print(optimizer.state_dict()['param_groups'][0]['lr'])
        train_loss, train_acc = utils.train_epoch(train_loader, model, criterion, optimizer, epoch)

        all_train_acc.append(train_acc)
        scheduler.step()

        ###validation###
        val_loss, val_sa = utils.validate(val_loader, model, criterion)
        test_loss, test_sa = utils.validate(test_loader, model, criterion)

        all_val_sa.append(val_sa)
        all_test_sa.append(test_sa)

        # Model selection is driven by validation accuracy only.
        is_sa_best = val_sa > best_prec1
        best_prec1 = max(val_sa, best_prec1)

        if is_sa_best:
            utils.save_checkpoint(
                args.save_dir,
                epoch + 1,
                sa_best=True,
                state_dict=model.state_dict(),
                best_prec1=best_prec1,
                optimizer=optimizer.state_dict(),
                scheduler=scheduler.state_dict(),
            )

        utils.save_checkpoint(
            args.save_dir,
            epoch + 1,
            sa_best=False,
            state_dict=model.state_dict(),
            best_prec1=best_prec1,
            optimizer=optimizer.state_dict(),
            scheduler=scheduler.state_dict(),
        )

        plt.plot(all_train_acc, label='train_acc')
        plt.plot(all_test_sa, label='SA')
        plt.legend()
        plt.savefig(os.path.join(args.save_dir, 'net_train.png'))
        plt.close()

        all_result['train'] = all_train_acc
        all_result['test_sa'] = all_test_sa
        all_result['val_sa'] = all_val_sa

        # Context manager closes the file every epoch instead of leaking handles.
        with open(os.path.join(args.save_dir, args.result_file), 'wb') as result_fh:
            pickle.dump(all_result, result_fh)



# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


