from argparse import ArgumentParser
from collections import OrderedDict

import matplotlib.pyplot as plt
import torch
import torchmetrics
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import vgg16
from torchvision.transforms import ToTensor


def parse_args():
    """Parse command-line options for the evaluation script.

    Returns:
        argparse.Namespace with attributes ``model_path`` (str),
        ``num_classes`` (int, default 30), ``data_root`` (str, default the
        sentinel ``'none'`` meaning "skip dataset evaluation"), ``gpu_id``
        (int, default 0) and ``eval_background`` (bool flag).

    Raises:
        SystemExit: if ``--model-path`` is missing or an argument is malformed
            (standard argparse behaviour).
    """
    parser = ArgumentParser()
    # Required: without a checkpoint there is nothing to evaluate, and letting
    # it default to None only fails later inside torch.load.
    parser.add_argument("--model-path", type=str, required=True,
                        help='path of stored model, end with .ckpt')
    parser.add_argument("--num-classes", type=int, default=30)
    # 'none' is a string sentinel checked by the caller to skip dataset eval.
    parser.add_argument("--data-root", type=str, help='path of evaluation dataset', default='none')
    parser.add_argument("--gpu-id", help='GPU device ID', type=int, default=0)
    parser.add_argument("--eval-background", action='store_true')
    return parser.parse_args()


def load_model(model_path, num_classes, device):
    """Build a VGG16 classifier and load weights from a training checkpoint.

    Args:
        model_path: Path to a checkpoint file containing a dict with a
            ``"state_dict"`` entry.
        num_classes: Number of output classes for the VGG16 head; must match
            the checkpoint's classifier shape.
        device: Torch device (or string such as ``"cuda:0"``) to place the
            model on.

    Returns:
        The VGG16 model on ``device``, in eval mode, with checkpoint weights.
    """
    # Deserialize onto the CPU so the load does not depend on the GPU the
    # checkpoint was saved from; load_state_dict copies into the on-device
    # parameters below.
    checkpoint = torch.load(model_path, map_location="cpu")
    model_checkpoint = checkpoint["state_dict"]
    model = vgg16(num_classes=num_classes).to(device).eval()
    # Drop the first dotted component of every key so it matches the bare
    # VGG16 parameter names. Presumably this is the attribute name of a
    # wrapping training module (the .ckpt suffix hints at Lightning) — TODO
    # confirm. A key without a dot maps to "" exactly as before.
    adjusted_model_checkpoint = OrderedDict(
        (key.partition(".")[2], value) for key, value in model_checkpoint.items()
    )
    model.load_state_dict(adjusted_model_checkpoint)
    return model


def eval_model_on_background(model, device, batch_size=128):
    """Plot per-class logit mean/variance of ``model`` on an all-zero input batch.

    Args:
        model: Classifier mapping a ``(N, 3, 224, 224)`` tensor to ``(N, C)``
            logits.
        device: Torch device (or string) the model lives on.
        batch_size: Number of identical all-zero images fed through the model.
            Must be at least 2 for the unbiased variance to be finite.

    Side effects:
        Opens a matplotlib window (``plt.show``) with one error-bar point per
        class.
    """
    background_tensor = torch.zeros(batch_size, 3, 224, 224, device=device)
    # Pure inference: skip building an autograd graph for a large VGG batch.
    with torch.no_grad():
        logits = model(background_tensor)
    var, mean = torch.var_mean(logits, dim=0)
    # Use the model's real output width instead of a hard-coded 30 so any
    # --num-classes value plots correctly.
    num_classes = mean.shape[0]
    # NOTE(review): the error bars show the *variance* (squared logit units),
    # not the std. Every input is identical, so for a deterministic model in
    # eval mode this is expected to be ~0.
    plt.errorbar(range(num_classes), mean.cpu().numpy(), var.cpu().numpy(),
                 linestyle='None', marker=".")
    plt.ylim(-3, 3)
    plt.show()


def eval_model_on_dataset(model, data_root, num_classes, device, batch_size=10, num_workers=1):
    """Compute and print multiclass accuracy of ``model`` on an ImageFolder dataset.

    Args:
        model: Classifier returning ``(N, num_classes)`` logits.
        data_root: Root directory laid out as ``data_root/<class>/<image>``.
        num_classes: Number of classes for the accuracy metric.
        device: Torch device (or string) the model lives on.
        batch_size: Images per forward pass.
        num_workers: DataLoader worker processes.

    Returns:
        The accuracy tensor from ``metric.compute()``. It is also printed.
    """
    metric = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
    # NOTE(review): only ToTensor is applied — no resize or normalization.
    # Batching needs every image in a batch to share one size, and the
    # preprocessing presumably has to match training. TODO confirm both.
    dataset = ImageFolder(root=data_root, transform=ToTensor())
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    with torch.no_grad():
        for x, y in dataloader:
            y_hat = model(x.to(device))
            # update() only accumulates state; calling metric(...) would also
            # compute and discard a per-batch accuracy every iteration.
            metric.update(y_hat.cpu(), y)
    acc = metric.compute()
    print(f"Accuracy on all data: {acc}")
    return acc


def main():
    """Entry point: load the checkpoint, then run whichever evaluations were requested.

    The background plot runs only with ``--eval-background``. Dataset accuracy
    runs only when ``--data-root`` differs from the ``'none'`` sentinel.
    """
    args = parse_args()
    # Every step runs on the same CUDA device, so build its name once.
    target_device = f"cuda:{args.gpu_id}"
    model = load_model(args.model_path, args.num_classes, target_device)

    if args.eval_background:
        eval_model_on_background(model, target_device)

    if args.data_root != "none":
        eval_model_on_dataset(model, args.data_root, args.num_classes, target_device)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()
