import os

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch.backends import cudnn as cudnn
from torch.utils.data import DataLoader

import eval_util
import util
from dataset.imagenet import MetaDataset, ImageNet, TieredImageNet
from meta_labeler import MetaLabeler
from models.util import create_model
from routines import parse_option


def _build_meta_testloader(opt):
    """Build the episodic test-split DataLoader for ``opt.dataset``.

    Args:
        opt: Parsed config. This helper reads ``dataset``, ``test_db_size`` and
            ``num_workers``; the dataset classes may read further fields.

    Returns:
        A ``DataLoader`` over a ``MetaDataset`` wrapping the test partition,
        with batch size 1, no shuffling and no dropped batches.

    Raises:
        NotImplementedError: If ``opt.dataset`` is not a supported dataset name.
    """
    dataset_classes = {
        'miniImageNet': ImageNet,
        'tieredImageNet': TieredImageNet,
    }
    try:
        dataset_cls = dataset_classes[opt.dataset]
    except KeyError:
        raise NotImplementedError(opt.dataset) from None

    # Pass ``opt`` by keyword for every dataset so all call sites agree.
    meta_dataset = MetaDataset(dataset_cls(opt, partition="test"), args=opt, db_size=opt.test_db_size)
    return DataLoader(
        meta_dataset,
        batch_size=1, shuffle=False, drop_last=False,
        num_workers=opt.num_workers)


@hydra.main(config_path="config", config_name="meta_eval.yaml")
def eval_main(opt: DictConfig):
    """Evaluate a pretrained ``MetaLabeler`` on 1-shot and then 5-shot test episodes.

    Args:
        opt: Hydra config loaded from ``config/meta_eval.yaml`` (plus CLI
            overrides) and post-processed by ``parse_option``.

    Side effects:
        Logs the parsed config and the accuracy of each evaluation via the
        logger returned by ``util.get_logger``.
    """
    # Allow parse_option / this function to add keys not declared in the YAML.
    OmegaConf.set_struct(opt, False)
    opt = parse_option(opt)

    logger = util.get_logger(opt.logger_name, file_name=f"{opt.logger_name}_{opt.model_name}")
    logger.info(opt)

    one_shot, five_shot = 1, 5
    # The loader is built for the 1-shot setting first and re-used for 5-shot below.
    opt.n_shots = one_shot
    meta_testloader = _build_meta_testloader(opt)

    backbone = create_model(opt.model, -1, dataset=opt.dataset)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        cudnn.benchmark = True

    model = MetaLabeler(backbone, opt, opt.feat_dim, intercept=False)
    # Only move to GPU when one exists, consistent with the cudnn guard above.
    # NOTE(review): eval_util.meta_test may itself assume CUDA — confirm before
    # relying on CPU-only evaluation.
    if use_cuda:
        model = model.cuda()

    tmp_path = os.path.join(opt.model_path, opt.pretrained_model)
    # map_location=None keeps the default GPU behaviour; on CPU-only hosts the
    # checkpoint tensors are remapped to CPU. torch.load unpickles the file, so
    # the checkpoint must come from a trusted source.
    state_dict = torch.load(tmp_path, map_location=None if use_cuda else "cpu")["model"]
    util.partial_reload(model, state_dict)

    # NOTE(review): the second value is named ``confidence`` but logged as "Std";
    # confirm which statistic eval_util.meta_test actually returns.
    mean, confidence = eval_util.meta_test(model, meta_testloader, opt)
    logger.info(f"1-shot Acc: {mean}, Std: {confidence}")

    # Switch the same dataset to 5-shot episodes. n_per_class grows by the
    # number of extra shots (presumably it counts support + query images per
    # class — verify against MetaDataset).
    meta_testloader.dataset.n_shots = five_shot
    meta_testloader.dataset.n_per_class += five_shot - one_shot
    mean, confidence = eval_util.meta_test(model, meta_testloader, opt)
    logger.info(f"5-shot Acc: {mean}, Std: {confidence}")


if __name__ == "__main__":
    # The @hydra.main decorator builds and injects ``opt`` from the config,
    # so no argument is passed here.
    eval_main()  # pylint: disable=no-value-for-parameter