import argparse
from utils import *
import torch.nn as nn
from thop import profile
from model import Network
from datetime import datetime
import torch.nn.functional as F
from torchsummary import summary
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter
from collections import OrderedDict

# Command-line options. Parsing happens at module import time, and the
# resulting `args` namespace is read as a global by main().
# NOTE(review): in the code visible here only model_path, seed, auxiliary, gpu,
# data_dir and batch_size are read directly; `args` is also handed wholesale to
# data_transforms_cifar(), which presumably consumes the augmentation flags.
parser = argparse.ArgumentParser('Train signal model')
parser.add_argument('--model_path', type=str, required=True, help='path to the model checkpoint to load')
parser.add_argument('--classes', type=int, default=10, help='num of classes')
parser.add_argument('--layers', type=int, default=12, help='num of MB_layers')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--epochs', type=int, default=600, help='num of epochs')
parser.add_argument('--seed', type=int, default=0, help='seed')
parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary loss')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--gpu', type=int, default=0, help='gpu id')
# ******************************* dataset *******************************#
parser.add_argument('--dataset', type=str, default='cifar10', help='[cifar10, imagenet]')
parser.add_argument('--data_dir', type=str, default='/home/work/dataset/cifar', help='dataset dir')
parser.add_argument('--colorjitter', action='store_true', default=False, help='use colorjitter')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--resize', action='store_true', default=False, help='use resize')

args = parser.parse_args()
# Echo the effective configuration so it is recorded in the run log.
print(args)


def validate(epoch, val_data, device, model):
    """Evaluate `model` on `val_data` and return accuracy and loss.

    Args:
        epoch: Unused; kept so existing callers keep working.
        val_data: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches and the loss criterion are moved to.
        model: Callable returning a pair; the first element is used as the
            logits and the second is ignored.

    Returns:
        A tuple ``(top1_avg, top5_avg, mean_loss)``. The first two are the
        ``.avg`` of meters updated with each batch's size as weight;
        ``mean_loss`` is the cross-entropy loss averaged over batches (not
        over samples).

    Raises:
        ValueError: If `val_data` yields no batches, since the averages are
            undefined in that case.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss().to(device)
    top1_meter = AvgrageMeter()
    top5_meter = AvgrageMeter()
    loss_sum = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in val_data:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = model(inputs)
            loss_sum += criterion(outputs, targets).item()
            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
            batch_size = inputs.size(0)
            top1_meter.update(prec1.item(), batch_size)
            top5_meter.update(prec5.item(), batch_size)
            num_batches += 1

    # Count batches explicitly rather than reusing the loop index: with an
    # empty loader the index is never bound and the division below would fail
    # with an unrelated-looking UnboundLocalError.
    if num_batches == 0:
        raise ValueError('validation data yielded no batches')

    return top1_meter.avg, top5_meter.avg, loss_sum / num_batches


def main(choice):
    """Load a checkpoint for the given architecture and evaluate it.

    Builds ``Network(choice=choice, auxiliary=args.auxiliary)``, loads the
    state dict stored at ``args.model_path`` (strictly), prints FLOPs and
    parameter counts measured on a single 3x32x32 input, then prints loss and
    top-1/top-5 accuracy on the CIFAR-10 test split found under
    ``args.data_dir``.

    Note that ``args.dataset`` is not consulted here: the CIFAR-10 test split
    is always used.

    Args:
        choice: Architecture description forwarded unchanged to ``Network``.
    """
    # seed
    set_seed(args.seed)

    # device
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        device = torch.device("cuda")
    else:
        device = torch.device('cpu')

    model = Network(choice=choice, auxiliary=args.auxiliary).to(device)
    # NOTE(security): torch.load is pickle-based and, depending on the
    # installed PyTorch version's `weights_only` default, may execute code
    # embedded in the file. Only load checkpoints from trusted sources.
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint, strict=True)
    # Profile after the weights are loaded so the strict state-dict check runs
    # against the model exactly as constructed.
    flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).to(device),), verbose=False)
    print('FLOPs: {}, params: {}'.format(flops / 1e6, params / 1e6))

    # Only the validation transform is needed for evaluation.
    _, valid_transform = data_transforms_cifar(args)
    valset = dset.CIFAR10(root=args.data_dir, train=False, download=False, transform=valid_transform)
    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine just produces a warning.
    valid_queue = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
                                              shuffle=False, pin_memory=use_cuda, num_workers=8)

    # validate
    val_top1, val_top5, val_obj = validate(0, val_data=valid_queue, device=device, model=model)

    print('val: loss={}, top1={}, top5={}'.format(val_obj, val_top1, val_top5))


if __name__ == '__main__':
    # Hard-coded architecture to evaluate: one entry per layer index (0-11),
    # each carrying a 'conv' list and a 'rate' value that Network interprets.
    # NOTE(review): the 12 entries match the default of --layers; confirm they
    # must stay in sync if that option is ever wired into Network.
    choice = OrderedDict(
        [(0, {'conv': [3, 1], 'rate': 0}), (1, {'conv': [1, 3], 'rate': 1}), (2, {'conv': [3, 0], 'rate': 0}),
         (3, {'conv': [2, 1], 'rate': 1}), (4, {'conv': [2], 'rate': 1}), (5, {'conv': [3, 0], 'rate': 1}),
         (6, {'conv': [2], 'rate': 1}), (7, {'conv': [0, 2], 'rate': 1}), (8, {'conv': [2, 0], 'rate': 1}),
         (9, {'conv': [0], 'rate': 0}), (10, {'conv': [0, 2], 'rate': 1}), (11, {'conv': [2], 'rate': 1})])
    print('choice_model:', choice)
    main(choice)
