import argparse
import logging
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

from dataset.imagenet import DATASET_GETTERS
from utils import AverageMeter, accuracy

# Module-level logger named after this module; the logging format and level
# are configured by logging.basicConfig() inside main().
logger = logging.getLogger(__name__)

def set_seed(args):
    """Seed Python, NumPy and PyTorch RNGs from ``args.seed``.

    Args:
        args: namespace with an integer ``seed`` attribute. An optional
            ``n_gpu`` attribute gives the number of GPUs in use; when it is
            absent the count is taken from ``torch.cuda.device_count()``.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # ``n_gpu`` is not guaranteed to be set by the caller; reading it directly
    # raised AttributeError, so fall back to asking PyTorch.
    n_gpu = getattr(args, "n_gpu", None)
    if n_gpu is None:
        n_gpu = torch.cuda.device_count()
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

def main():
    """Evaluate a saved checkpoint on the test split.

    Parses the command line, builds the test ``DataLoader`` and a ResNet50,
    restores the model weights (and the EMA weights when EMA is enabled) from
    ``--checkpoint``, then calls :func:`test`, which logs top-1/top-5 accuracy.

    Raises:
        FileNotFoundError: if ``--checkpoint`` is not an existing file.
    """
    parser = argparse.ArgumentParser(description='PyTorch FixMatch Training')
    parser.add_argument("--cam", action="store_true",
                        help="use cam for cutout augmentation")
    parser.add_argument('--num_workers', type=int, default=8,
                        help='number of workers')
    parser.add_argument('--dataset', default='imagenet', type=str,
                        help='dataset name')
    parser.add_argument('--num_labeled', type=int, default=100000,
                        help='number of labeled data')
    parser.add_argument('--batch_size', default=256, type=int,
                        help='train batchsize')
    parser.add_argument('--use_ema', action='store_true', default=True,
                        help='use EMA model')
    # ``--use_ema`` is store_true with default=True, so by itself it can never
    # turn EMA off; ``--no_ema`` is the off switch (existing invocations keep
    # their behavior because the default stays True).
    parser.add_argument('--no_ema', dest='use_ema', action='store_false',
                        help='evaluate the raw model instead of the EMA model')
    parser.add_argument('--ema_decay', default=0.999, type=float,
                        help='EMA decay rate')
    parser.add_argument('--checkpoint', default='', type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--seed', default=None, type=int,
                        help="random seed")
    parser.add_argument('--no_progress', action='store_true',
                        help="don't use progress bar")

    args = parser.parse_args()

    # ImageNet-100K
    args.num_classes = 1000

    def create_model(args):
        # ResNet50 for ImageNet-100K
        import models.resnet50 as models
        model = models.build_ResNet50(num_classes=args.num_classes)
        logger.info("Total params: {:.2f}M".format(
            sum(p.numel() for p in model.parameters())/1e6))
        return model

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    args.device = device
    # set_seed() consults args.n_gpu, which no command-line option provides.
    args.n_gpu = torch.cuda.device_count()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)

    logger.warning(f"device: {args.device}, ")

    logger.info(dict(args._get_kwargs()))

    # Validate the checkpoint path before the (potentially slow) dataset and
    # model construction. An explicit raise is used instead of ``assert`` so
    # the check survives ``python -O``.
    if not os.path.isfile(args.checkpoint):
        raise FileNotFoundError(
            f"Error: no checkpoint file found at '{args.checkpoint}'")

    if args.seed is not None:
        set_seed(args)

    # Only the test split is needed for evaluation.
    _, _, test_dataset = DATASET_GETTERS[args.dataset](args, './data')

    test_loader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    model = create_model(args)
    model.to(args.device)

    if args.use_ema:
        from models.ema import ModelEMA
        ema_model = ModelEMA(args, model, args.ema_decay)

    logger.info("==> Loading checkpoint..")
    args.out = os.path.dirname(args.checkpoint)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # host (the device selection above explicitly allows a CPU fallback).
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.checkpoint, map_location=args.device)
    record_best_acc = checkpoint['best_acc']
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    if args.use_ema:
        ema_model.ema.load_state_dict(checkpoint['ema_state_dict'])

    logger.info("***** Running Evaluation *****")
    logger.info(f"  Task = {args.dataset}@{args.num_labeled}")
    logger.info(f"  Batch size = {args.batch_size}")
    logger.info(f"  Record test accuracy = {record_best_acc}")

    # Evaluate the EMA copy of the weights when EMA is enabled, otherwise the
    # raw model.
    test_model = ema_model.ema if args.use_ema else model

    test(args, test_loader, test_model)


def test(args, test_loader, model):
    """Run ``model`` over ``test_loader`` and log top-1 / top-5 accuracy.

    Args:
        args: namespace providing ``device`` (where batches are moved) and
            ``no_progress`` (disables the tqdm progress bar when true).
        test_loader: iterable yielding ``(inputs, targets)`` batches.
        model: network to evaluate; it is switched to eval mode and run
            under ``torch.no_grad()``.

    Returns:
        Tuple ``(losses.avg, top1.avg)`` taken from the running meters.
    """
    show_bar = not args.no_progress

    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    # Wrap the loader in a progress bar unless the caller opted out.
    batches = tqdm(test_loader) if show_bar else test_loader
    status = "Test Iter: {batch:4}/{iter:4}. Loss: {loss:.4f}. top1: {top1:.2f}. top5: {top5:.2f}. "

    model.eval()
    with torch.no_grad():
        for step, (images, labels) in enumerate(batches, start=1):
            images = images.to(args.device)
            labels = labels.to(args.device)

            logits = model(images)
            batch_loss = F.cross_entropy(logits, labels)
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))

            # Each meter update is weighted by the number of samples in the batch.
            n_samples = images.shape[0]
            loss_meter.update(batch_loss.item(), n_samples)
            top1_meter.update(acc1.item(), n_samples)
            top5_meter.update(acc5.item(), n_samples)

            if show_bar:
                batches.set_description(status.format(
                    batch=step,
                    iter=len(batches),
                    loss=loss_meter.avg,
                    top1=top1_meter.avg,
                    top5=top5_meter.avg,
                ))
        if show_bar:
            batches.close()

    logger.info("top-1 acc: {:.2f}".format(top1_meter.avg))
    logger.info("top-5 acc: {:.2f}".format(top5_meter.avg))
    return loss_meter.avg, top1_meter.avg


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
