import fire
import torch
from lenet import mnist_dataset, load_lenet, get_device
from vgg import cifar10_dataset, load_vgg


def score_model(model_dir, data_loc, model_type):
    """Evaluate a saved model on its validation set and print the accuracy.

    The validation loader and the model are chosen by ``model_type``:
    ``"lenet"`` uses ``mnist_dataset`` / ``load_lenet`` and ``"vgg"`` uses
    ``cifar10_dataset`` / ``load_vgg``. The number of images tested and the
    fraction classified correctly are printed to stdout; nothing is returned.

    Args:
        model_dir: Location handed to ``load_lenet`` / ``load_vgg``.
        data_loc: Location handed to ``mnist_dataset`` / ``cifar10_dataset``.
        model_type: Either ``"lenet"`` or ``"vgg"``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    # Fail fast on a bad model type, before any device, dataset or model work.
    if model_type not in ("lenet", "vgg"):
        raise ValueError(f"Unknown model type {model_type}")

    device = get_device()
    # Each dataset helper returns a pair; only the second element (the loader
    # iterated below) is needed here.
    if model_type == "lenet":
        _, valloader = mnist_dataset(data_loc)
        model = load_lenet(model_dir).to(device)
    else:
        _, valloader = cifar10_dataset(data_loc)
        model = load_vgg(model_dir).to(device)

    # Switch layers such as dropout / batch-norm to inference behaviour so the
    # reported accuracy does not depend on training-mode randomness.
    model.eval()

    # The VGG model receives image batches unchanged; for LeNet every sample
    # is flattened to a single vector first.
    flatten_inputs = model_type != "vgg"

    correct_count, all_count = 0, 0
    # No gradients are needed for scoring; one context covers the whole loop.
    with torch.no_grad():
        for images, labels in valloader:
            if flatten_inputs:
                images = images.view(images.shape[0], -1)
            images, labels = images.to(device), labels.to(device)

            # The predicted class is the index of the largest output on dim 1.
            pred_labels = model(images).argmax(dim=1)
            correct_count += (pred_labels == labels).sum().item()
            all_count += pred_labels.shape[0]

    print("Number Of Images Tested =", all_count)
    if all_count == 0:
        # An empty loader would otherwise end in a ZeroDivisionError.
        print("\nModel Accuracy = undefined (no validation samples)")
        return
    print("\nModel Accuracy =", (correct_count/all_count))


if __name__ == "__main__":
    # Run as a script: hand score_model to Fire, which builds the
    # command-line interface from the function's parameters.
    fire.Fire(component=score_model)
