import itertools
import time

import torch
import torchvision

from data import prepare_val_loader, AverageMeter, accuracy
from densenet import densenet201


def _evaluate(model, loader, max_batches=4):
    """Evaluate ``model`` on at most ``max_batches`` batches from ``loader``.

    Prints ``iteration <idx>`` for every batch as demo progress output.

    Args:
        model: a ``torch.nn.Module`` classifier; it is switched to eval mode.
        loader: an iterable of ``(inputs, targets)`` batches.
        max_batches: maximum number of batches to evaluate.

    Returns:
        ``(avg_top1, n_images)``. ``avg_top1`` is the sample-weighted running
        average kept by ``AverageMeter`` over the per-batch ``accuracy``
        values. ``n_images`` is the number of images actually evaluated, which
        is smaller than expected if ``loader`` runs out early.
    """
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    n_images = 0
    with torch.no_grad():
        # islice stops after max_batches items without pulling an extra batch
        # from the loader.
        for idx, (x, y) in enumerate(itertools.islice(loader, max_batches)):
            print('iteration', idx)
            out = model(x)
            acc1 = accuracy(out, y, topk=(1,))
            top1.update(acc1[0], x.size(0))
            n_images += x.size(0)
    return top1.avg, n_images


def _report(avg_top1, n_images):
    """Print the average top-1 accuracy over the evaluated images."""
    # Report the real image count instead of a hard-coded "40", which was
    # only correct for batch_size=10 and a loader with at least 4 batches.
    print(f'Average accuracy on {n_images} images {float(avg_top1):.1f}%')


if __name__ == '__main__':
    print("Simulate a victim")
    print("This demo has been tested with pytorch>=1.4.0 torchvision>=0.5.0\n")

    loader = prepare_val_loader(batch_size=10, num_workers=1)

    print("Run an network downloaded from torchvision")
    time.sleep(1)
    model = torchvision.models.densenet201(pretrained=True)
    _report(*_evaluate(model, loader))

    print("\nRun a stego network infected with malware")
    time.sleep(1)
    model = densenet201()
    # NOTE(security): torch.load unpickles the file, and unpickling can execute
    # arbitrary code. Only load checkpoints from a trusted source. Newer torch
    # (>=1.13) offers weights_only=True, but the versions this demo targets
    # (>=1.4) do not.
    # map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only
    # machine. The model is evaluated on CPU in either case.
    state_dict = torch.load('mal_model.pth', map_location='cpu')
    model.load_state_dict(state_dict)
    _report(*_evaluate(model, loader))
