import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from thop import profile

from models import MixPath_A, MixPath_B
from utils import AvgrageMeter, accuracy


def get_args(argv=None):
    """Parse command-line arguments for evaluating a pretrained MixPath model.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with ``model``, ``model_path``, ``data_dir``,
        ``batch_size`` and ``gpu_id``. The namespace is also printed.
    """
    parser = argparse.ArgumentParser("Transfer Learning")
    # Restricting the choices makes argparse reject unknown names up front
    # instead of failing later with an unbound ``model`` variable.
    parser.add_argument('--model', type=str, required=True,
                        choices=['mixpath_a', 'mixpath_b'],
                        help='model to evaluate')
    parser.add_argument('--model_path', type=str, required=True,
                        help='path to the checkpoint to load')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='dataset root (must contain a val/ subdirectory)')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    # Not required: the previous ``required=True`` made the default unreachable.
    parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
    args = parser.parse_args(argv)
    print(args)
    return args


def validate(epoch, val_data, device, model):
    """Evaluate ``model`` on ``val_data`` without gradient tracking.

    Args:
        epoch: Unused; kept for call-site compatibility.
        val_data: Iterable yielding ``(inputs, targets)`` batches.
        device: torch.device each batch is moved to.
        model: Network to evaluate. It is switched to eval mode as a side
            effect and left in that mode.

    Returns:
        Tuple ``(top1, top5, mean_loss)``.
        - ``top1`` and ``top5`` are the ``.avg`` of two ``AvgrageMeter``
          instances, each updated per batch with the batch size as weight.
        - ``mean_loss`` is the cross-entropy loss averaged over batches.
          It is 0.0 when ``val_data`` yields no batches.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss().to(device)
    val_top1 = AvgrageMeter()
    val_top5 = AvgrageMeter()
    val_loss = 0.0
    # Count batches explicitly: relying on the last ``enumerate`` index
    # raised UnboundLocalError when the loader was empty.
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in val_data:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, targets).item()
            # NOTE(review): top-5 presumably needs >= 5 output classes; confirm
            # against utils.accuracy.
            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
            n = inputs.size(0)
            val_top1.update(prec1.item(), n)
            val_top5.update(prec5.item(), n)
            num_batches += 1

    mean_loss = val_loss / num_batches if num_batches else 0.0
    return val_top1.avg, val_top5.avg, mean_loss


def _select_device(gpu_id):
    """Return the CUDA device (after selecting ``gpu_id``) if available, else the CPU."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    torch.cuda.set_device(gpu_id)
    cudnn.benchmark = True
    cudnn.enabled = True
    return torch.device("cuda")


def _load_model(name, checkpoint_path, device):
    """Build the MixPath variant ``name``, load its ``model_state`` weights, move it to ``device``.

    Raises:
        ValueError: if ``name`` is not a known model.
    """
    constructors = {'mixpath_a': MixPath_A, 'mixpath_b': MixPath_B}
    if name not in constructors:
        raise ValueError('unknown model {!r}; expected one of {}'.format(
            name, sorted(constructors)))
    model = constructors[name]()
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
    # hosts; the model is moved to ``device`` afterwards.
    state = torch.load(checkpoint_path, map_location='cpu')['model_state']
    model.load_state_dict(state)
    return model.to(device)


def _build_val_loader(data_dir, batch_size, pin_memory):
    """Return an unshuffled DataLoader over ``<data_dir>/val`` with resize/center-crop/normalize preprocessing."""
    # These match the commonly used ImageNet per-channel statistics.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    valset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=val_transform)
    return torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False,
                                       num_workers=8, pin_memory=pin_memory)


def main():
    """Evaluate a pretrained MixPath checkpoint on an ImageFolder validation set.

    Prints the parsed arguments, the model's cost as reported by thop, and
    the validation loss plus top-1/top-5 accuracy.
    """
    args = get_args()
    device = _select_device(args.gpu_id)
    model = _load_model(args.model, args.model_path, device)

    # Defensive: ensure eval mode before thop's dummy forward pass so random
    # input cannot touch BatchNorm running statistics (recent thop versions
    # also do this internally).
    model.eval()
    # Build the dummy input on ``device`` so profiling also works on CPU-only hosts.
    dummy = torch.randn(1, 3, 224, 224, device=device)
    # NOTE(review): thop's docs describe this count as MACs; the "flops" label is kept as-is.
    flops, params = profile(model, inputs=(dummy,), verbose=False)
    print('flops: {}M, params: {}M'.format(flops / 1e6, params / 1e6))

    # Pinned host memory only helps host-to-GPU copies.
    valloader = _build_val_loader(args.data_dir, args.batch_size,
                                  pin_memory=device.type == 'cuda')
    val_top1, val_top5, val_obj = validate(0, val_data=valloader, device=device, model=model)
    print('val: loss={}, top1={}, top5={}'.format(val_obj, val_top1, val_top5))


# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
