from __future__ import print_function
import random

import time
import argparse
import os
import sys
import shutil
import pprint

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.model import WideResnet, WideResnetLarge

from utils import accuracy

from utils import AverageMeter

def get_params(dataset):
    """Return ``(n_classes, wresnet_k, wresnet_n)`` for a supported dataset.

    Args:
        dataset: one of 'CIFAR10', 'CIFAR100', 'SVHN', 'STL10', 'TinyImageNet'.

    Returns:
        Tuple of (number of classes, WideResNet width factor, WideResNet depth).

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # (n_classes, wresnet_k, wresnet_n) per dataset. Every dataset currently uses
    # WRN-28-2; kept per-entry so one dataset can be changed independently.
    params_by_dataset = {
        'CIFAR10': (10, 2, 28),
        'SVHN': (10, 2, 28),
        'CIFAR100': (100, 2, 28),
        'STL10': (10, 2, 28),
        'TinyImageNet': (200, 2, 28),
    }
    try:
        n_classes, wresnet_k, wresnet_n = params_by_dataset[dataset]
    except KeyError:
        raise ValueError(
            'Unknown dataset {!r}; expected one of {}'.format(
                dataset, ', '.join(sorted(params_by_dataset)))) from None

    return n_classes, wresnet_k, wresnet_n


def set_model(n_classes, wresnet_k, wresnet_n, stl=False, large=False):
    '''
    Build the classifier and its loss function, both moved to the GPU.

    n_classes, wresnet_k, wresnet_n --> forwarded to the network constructor as
        n_classes, k and n respectively
    stl == True --> STL-10 variant of WideResnet (not used when large is True)
    large == True --> WideResnetLarge is used instead of WideResnet

    Returns (model, criterion); the model is left in train mode and the
    criterion is an nn.CrossEntropyLoss.
    '''
    arch_kwargs = dict(n_classes=n_classes, k=wresnet_k, n=wresnet_n)

    if not large:
        # Only the regular WideResnet accepts the stl switch.
        arch_kwargs['stl'] = stl
        net = WideResnet(**arch_kwargs)
    else:
        net = WideResnetLarge(**arch_kwargs)

    net.train()
    net.cuda()
    loss_fn = nn.CrossEntropyLoss().cuda()
    return net, loss_fn


def get_data(args):
    """Build the test-split data loader for ``args.dataset``.

    Args:
        args: namespace providing ``dataset`` (one of 'CIFAR10', 'CIFAR100',
            'SVHN', 'STL10', 'TinyImageNet') and ``valbatchsize``.

    Returns:
        The loader returned by the selected dataset module's ``get_test_loader``
        (built with ``num_workers=2``).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    # Imported inside the branch so only the selected dataset's module is loaded.
    if args.dataset in ('CIFAR10', 'CIFAR100'):
        from datasets.cifar import get_test_loader
    elif args.dataset == 'SVHN':
        from datasets.svhn import get_test_loader
    elif args.dataset == 'STL10':
        from datasets.stl10 import get_test_loader
    elif args.dataset == 'TinyImageNet':
        from datasets.tiny_imagenet import get_test_loader
    else:
        raise ValueError('Unknown dataset {!r}; expected one of CIFAR10, CIFAR100, '
                         'SVHN, STL10, TinyImageNet'.format(args.dataset))

    dltest = get_test_loader(dataset=args.dataset, batch_size=args.valbatchsize, num_workers=2)

    return dltest


def evaluate(model, dataloader, criterion):
    """Evaluate ``model`` on every batch of ``dataloader`` and print the metrics.

    Each batch's mean metrics are weighted by the batch size. A smaller final
    batch therefore contributes in proportion to its size, so the result is a
    true per-sample average.

    Args:
        model: network whose forward pass returns a ``(logits, _)`` pair. It is
            switched to eval mode and left there.
        dataloader: iterable of ``(images, labels)`` batches. Both are moved to
            the GPU with ``.cuda()``.
        criterion: loss applied as ``criterion(logits, labels)``. It is assumed
            to return the batch mean, as ``nn.CrossEntropyLoss()`` does by
            default.

    Returns:
        ``(top1, top5, loss)`` averaged over all samples. top1 and top5 are in
        the same units that ``accuracy`` reports per batch. Assumes ``accuracy``
        returns per-batch means (e.g. percentages) -- TODO confirm.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()

    n_samples = 0
    loss_total = 0.0
    top1_total = 0.0
    top5_total = 0.0

    with torch.no_grad():
        for ims, lbs in dataloader:
            ims = ims.cuda()
            lbs = lbs.cuda()
            logits, _ = model(ims)
            loss = criterion(logits, lbs)
            scores = torch.softmax(logits, dim=1)
            top1, top5 = accuracy(scores, lbs, (1, 5))

            # Weight each per-batch mean by its batch size so a short final
            # batch does not count as much as a full one.
            batch_size = lbs.size(0)
            n_samples += batch_size
            loss_total += loss.item() * batch_size
            top1_total += top1.item() * batch_size
            top5_total += top5.item() * batch_size

    if n_samples == 0:
        raise ValueError('evaluate() received a dataloader with no samples')

    top1_avg = top1_total / n_samples
    top5_avg = top5_total / n_samples
    loss_avg = loss_total / n_samples

    print("Test. Top1: {:.4f}. Top5: {:.4f}. Loss: {:.4f}.".format(top1_avg, top5_avg, loss_avg))

    return top1_avg, top5_avg, loss_avg


#############################################################################################
# Options
#############################################################################################
parser = argparse.ArgumentParser(description='Semi-supervised Learning')

# Device selection; main() exports this value as CUDA_VISIBLE_DEVICES.
parser.add_argument(
    '--gpu_ids',
    type=str,
    default='0',
    help='gpu_ids: e.g. 0  0,1,2  0,2',
)

# Backbone architecture.
# NOTE(review): main() overwrites wresnet_k / wresnet_n with get_params(), so the two
# --wresnet-* flags currently have no effect on the model that is built.
parser.add_argument('--backbone', default='wideresnet', type=str, help='Wideresnet')
parser.add_argument('--wresnet-k', type=int, default=2, help='width factor of wide resnet')
parser.add_argument('--wresnet-n', type=int, default=28, help='depth of wide resnet')
parser.add_argument(
    '--large_model',
    action='store_true',
    help='default is False. If True, using WideResnetLarge model',
)

# Data.
parser.add_argument(
    '--dataset',
    default='CIFAR10',
    type=str,
    help='CIFAR10, CIFAR100, SVHN, STL10, or TinyImageNet',
)
parser.add_argument('--valbatchsize', type=int, default=100, help='validation batch size')

# Reproducibility.
parser.add_argument(
    '--seed',
    default=-1,
    type=int,
    help='seed for random behaviors, no seed if negtive',
)

# Model weights to evaluate.
parser.add_argument('--checkpoint', default='', type=str, help='checkpoint to run evaluation')
#############################################################################################


def main():
    """Evaluate a saved checkpoint on the test split of ``--dataset``.

    Steps:
      1. Parse the CLI options.
      2. Restrict the visible GPUs.
      3. Optionally seed the RNGs.
      4. Build the model.
      5. Load ``--checkpoint`` into it.
      6. Run ``evaluate``, which prints the test metrics.
    """
    args = parser.parse_args()
    # Fail fast with a usage message instead of an opaque torch.load('') error.
    if not args.checkpoint:
        parser.error('--checkpoint is required')

    # Takes effect only if set before CUDA is first initialised in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids

    # Global seeding. --seed is documented as "no seed if negative", so 0 is a
    # valid seed and must not be skipped.
    if args.seed >= 0:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

    # NOTE(review): this overwrites any --wresnet-k / --wresnet-n passed on the
    # command line with the per-dataset values from get_params().
    args.n_classes, args.wresnet_k, args.wresnet_n = get_params(args.dataset)

    print(pprint.pformat(args))

    print('Loading model {}...'.format(args.backbone))
    if args.dataset in ('STL10', 'TinyImageNet'):
        # These datasets use the stl variant; --large_model is not applied to them.
        model, criterion = set_model(args.n_classes, args.wresnet_k, args.wresnet_n, stl=True)
    else:
        model, criterion = set_model(args.n_classes, args.wresnet_k, args.wresnet_n,
                                     large=args.large_model)
    print("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))

    # map_location='cpu' avoids failures when the checkpoint was saved from a GPU
    # index that is not visible here. load_state_dict then copies the tensors
    # into the (already .cuda()) model parameters.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(args.checkpoint, map_location='cpu')
    model.load_state_dict(state)
    print('{} loaded'.format(args.checkpoint))

    dltest = get_data(args)

    print('start validation.')
    # evaluate() prints the metrics itself; its return value is not needed here.
    evaluate(model, dltest, criterion)


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
