import time
import os
import torch
import numpy as np
from model import EnsembleModel
from data.datamgr import SetDataManager
from utils import *
import argparse
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, RobertaModel


def test(model, params):
    """Evaluate ``model`` on the test split and persist the results.

    Loads weights from ``<params.model_dir>/<params.model_name>`` into
    ``model``, runs ``model.module.test_loop`` over the test loader, prints
    the accuracy with its confidence interval, and saves both a log dict
    (``test_log``) and the model state (``test_model.tar``) into
    ``params.model_dir``.

    Args:
        model: network whose underlying model is reachable as
            ``model.module`` (the ``__main__`` block wraps it in
            ``nn.DataParallel``); it must expose
            ``text_encoder.roberta_model`` and ``test_loop``.
        params: run configuration; this function reads ``dataset``,
            ``data_dir``, ``image_size``, ``n_way``, ``n_shot``, ``n_query``,
            ``model_dir`` and ``model_name``, and passes the whole object on
            to ``test_loop``.
    """
    test_log = {
        'params': vars(params),
        # Kept for compatibility with existing log readers; never filled here.
        'test_accs': [],
    }
    start_time = time.time()

    # Build the episodic test loader. The trailing 1000 is presumably the
    # number of sampled test episodes -- confirm against SetDataManager.
    test_mgr = SetDataManager(params.dataset, params.data_dir, params.image_size,
                              params.n_way, params.n_shot, params.n_query, 1000)
    test_loader = test_mgr.get_data_loader('test', False)

    # Restore the trained weights.
    model_file = os.path.join(params.model_dir, params.model_name)
    print(model_file)
    load_model(model, model_file)

    # Freeze the RoBERTa text encoder so no gradients are computed for it.
    for param in model.module.text_encoder.roberta_model.parameters():
        param.requires_grad_(False)

    model.eval()

    _, test_acc, test_ci = model.module.test_loop(test_loader, params)
    test_log['test_acc'] = test_acc
    # Persist the confidence interval too; previously it was only printed.
    test_log['test_ci'] = test_ci

    # Elapsed wall-clock time of the whole evaluation (not a remaining-time
    # estimate, which the old message wrongly implied).
    test_time = time.time() - start_time
    print(f'[ log ] test took {test_time / 3600:.3f} h\n')

    torch.save(test_log, os.path.join(params.model_dir, 'test_log'))
    outfile = os.path.join(params.model_dir, 'test_model.tar')
    torch.save({'state': model.state_dict()}, outfile)

    print("Test:")
    print(f'[final] epo:{"best":>3} | {test_acc:.4f} +- {test_ci:.3f}')


if __name__ == '__main__':
    params = setup_run()

    # Build the pretrained RoBERTa text encoder used inside the ensemble.
    pretrained_name = 'roberta-base'
    roberta_config = AutoConfig.from_pretrained(pretrained_name)
    roberta_tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    # The config must be passed by keyword: extra positional arguments to
    # from_pretrained are forwarded to RobertaModel.__init__ (where the
    # second slot is ``add_pooling_layer``), not used as the model config.
    roberta_model = RobertaModel.from_pretrained(pretrained_name, config=roberta_config)

    model = EnsembleModel(roberta_model, roberta_tokenizer, roberta_config, model_dict[params.backbone])
    # test() reaches the wrapped model through ``model.module``.
    model = nn.DataParallel(model, device_ids=params.device_ids)
    model.cuda()

    test(model, params)