import argparse
import torch
import os
import os.path as osp
from tqdm import tqdm
from collections import OrderedDict
from dataloader_v2 import HandLanbanDataset

from utils.logger  import setup_logger
from utils.validation import compute_accurate_num,compute_horizontal_accurate_num,compute_vertical_accurate_num
from utils.distributed_train_util import process_input2device, process_output2device

from models.HandLabanNet_v31 import ResNet50_SPA
from models.HandLabanNet_v32 import ResNet152
from models.HandLabanNet_v33 import ResNet152_CMP
from models.HandLabanNet_v34 import MultiViewAttention
from models.HandLabanNet_v35 import PreMultiViewAttention
from models.HandLabanNet_v36 import PreMultiViewAttentionCMP

# Default locations are resolved against the directory the script is launched from.
root_dir = os.getcwd()
data_dir = osp.join(root_dir, 'data')
log_file_dir = osp.join(data_dir, 'logs/test')


def _build_parser():
    """Create the command-line parser for the evaluation script.

    Options are registered in a fixed order, so ``--help`` lists them in
    that same sequence.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    cli = argparse.ArgumentParser(description='Testing')
    option_specs = (
        ('--checkpoint_path', dict(type=str, required=True)),
        ('--batch_size', dict(type=int, default=1)),
        ('--num_workers', dict(type=int, default=4)),
        ('--clip_step', dict(type=int, default=10)),
        ('--log_file_path', dict(type=str, default=log_file_dir)),
        ('--data_file_path', dict(type=str, default=data_dir)),
        ('--epochs', dict(type=int, required=True)),
        ('--lr', dict(type=float, required=True)),
        ('--model', dict(type=str, required=True)),
        ('--loss', dict(type=str, required=True)),
        ('--split', dict(type=str, required=True)),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()

def _build_model(model_name):
    """Instantiate the network selected by ``--model``.

    The model is constructed but not moved to any device; the caller does
    that after the weights are loaded.

    Args:
        model_name: one of 'ResNet50', 'ResNet152', 'ResNet152_CMP',
            'MultiViewAttention', 'Debug', 'PreMultiViewAttention',
            'PreMultiViewAttentionCMP'.

    Returns:
        A freshly constructed model instance.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    model_classes = {
        'ResNet50': ResNet50_SPA,
        'ResNet152': ResNet152,
        'ResNet152_CMP': ResNet152_CMP,
        'MultiViewAttention': MultiViewAttention,
        # 'Debug' is an alias that runs PreMultiViewAttention.
        'Debug': PreMultiViewAttention,
        'PreMultiViewAttention': PreMultiViewAttention,
        'PreMultiViewAttentionCMP': PreMultiViewAttentionCMP,
    }
    try:
        model_class = model_classes[model_name]
    except KeyError:
        raise ValueError('Unknown model {!r}; expected one of: {}'.format(
            model_name, ', '.join(model_classes))) from None
    return model_class()


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from its keys.

    The checkpoints are presumably saved from a model wrapped in
    (Distributed)DataParallel, which names every parameter ``module.<name>``.
    The bare model expects ``<name>``. Keys that do not start with ``prefix``
    are kept unchanged instead of losing their first characters.

    Args:
        state_dict: mapping of parameter names to tensors.
        prefix: the key prefix to drop when present.

    Returns:
        OrderedDict with the same values and (possibly) shortened keys.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def main():
    """Evaluate a checkpoint on one dataset split and log its accuracies.

    Steps:
      1. Parse the command line.
      2. Build the test ``DataLoader`` and the model chosen by ``--model``.
      3. Load the checkpoint weights.
      4. Run inference without gradients.
      5. Log the total, horizontal and vertical Laban accuracies to a file
         under ``--log_file_path``.

    Raises:
        ValueError: if ``--model`` is unknown or the test split is empty.
    """
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs(args.log_file_path, exist_ok=True)
    # The log name encodes the run configuration, including the learning rate.
    log_path = osp.join(
        args.log_file_path,
        args.model + '_' + args.loss + '_epoch_' + str(args.epochs)
        + '_lr_' + str(args.lr) + '_test_result_' + '.log',
    )
    logger = setup_logger(output=log_path, name='Testing')

    logger.info('Start testing:')
    logger.info('Creating test dataset.')
    test_dataset = HandLanbanDataset(args.data_file_path, split=args.split, aug=False, clip_step=args.clip_step)
    num_samples = len(test_dataset)
    if num_samples == 0:
        # Fail fast instead of dividing by zero after loading the model.
        raise ValueError('The test dataset for split {!r} is empty.'.format(args.split))
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
    logger.info('The test dataset is created successfully. ')

    model = _build_model(args.model)
    logger.info('The test model is: {}'.format(args.model))
    logger.info('The test model is created successfully. ')

    # Load to CPU first so checkpoints saved on a GPU also open on CPU-only hosts.
    # The model is moved to `device` once the weights are in place.
    checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    logger.info('The checkpoint is: {}'.format(args.checkpoint_path))
    logger.info('The epochs of test mode: {}'.format(args.epochs))
    logger.info('the LR:{}'.format(args.lr))

    model.load_state_dict(_strip_module_prefix(checkpoint))
    model = model.to(device)
    logger.info('The model is load successfully')
    model.eval()

    total_correct_nums = 0
    total_correct_horizontal_nums = 0
    total_correct_vertical_nums = 0
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs = process_input2device(inputs=inputs, device=device)
            targets = process_output2device(targets=targets, device=device)
            outputs = model(inputs)
            total_correct_nums += compute_accurate_num(outputs=outputs, targets=targets)
            total_correct_horizontal_nums += compute_horizontal_accurate_num(outputs=outputs, targets=targets)
            total_correct_vertical_nums += compute_vertical_accurate_num(outputs=outputs, targets=targets)

    # NOTE(review): assumes the compute_*_accurate_num helpers return the number of
    # correct samples in a batch. The totals are therefore normalised by the sample
    # count, not the batch count; the two agree when batch_size == 1.
    acc_dict = {
        'total_laban_acc': float(total_correct_nums) / num_samples,
        'horizontal_laban_acc': float(total_correct_horizontal_nums) / num_samples,
        'vertical_laban_acc': float(total_correct_vertical_nums) / num_samples,
    }

    logger.info('The accuracy of test set is: total_laban_acc : {}, horizontal_laban_acc: {}, vertical_laban_acc: {}'
                .format(acc_dict['total_laban_acc'], acc_dict['horizontal_laban_acc'], acc_dict['vertical_laban_acc']))

if __name__ == '__main__':
    # Run the evaluation only when executed as a script, never on import.
    main()
    
    