import os
import sys
sys.path.insert(0, '..')
import argparse
import numpy as np
import cv2
from PIL import Image
import logging
import time

import torch
import torch.utils.data
import torchvision.transforms as transforms

from functions import *
from network import StylizedFacePoint

# To silence library warnings during evaluation, uncomment:
# import warnings
# warnings.filterwarnings("ignore")

if __name__ == '__main__':
    # Evaluate a trained StylizedFacePoint checkpoint on a test split.
    # Prints timing, mean NME, failure rate and AUC (the latter two via
    # compute_fr_and_auc from `functions`), and logs the metrics to
    # ./logs/<data_name>/<experiment_name>/test.log.
    parser = argparse.ArgumentParser(description="Test StylizedFacePoint")
    parser.add_argument("--experiment_name", type=str, default="Exp1", help="the name of experiment")
    parser.add_argument("--data_name", type=str, default="FLSC", help="the name of dataset")
    parser.add_argument("--dataset_url", type=str, default="data/", help="the path of dataset")
    parser.add_argument("--test_labels", type=str, default="test.txt", help="the name of test labels file in dataset")
    parser.add_argument("--test_images", type=str, default="images_test", help="the name of test images folder in dataset")
    parser.add_argument("--num_lms", type=int, default=98, help="the number of landmarks in dataset")
    parser.add_argument("--num_nb", type=int, default=3, help="the number of neighbor landmarks")
    parser.add_argument("--input_size", type=int, default=256, help="the size of input images")
    parser.add_argument("--net_stride", type=int, default=16, help="the stride of network")
    parser.add_argument("--nstack", type=int, default=4, help="the number of stages in stacked hourglass network")
    parser.add_argument("--use_gpu", action="store_true", help="use gpu for model evaluation")
    parser.add_argument("--gpu_id", type=int, default=0, help="the index of gpu(if use)")
    # NOTE(review): the default pair 60/72 presumably marks the outer eye
    # corners of the 98-point layout; other landmark schemes need other indices.
    parser.add_argument("--norm_indices", type=int, nargs=2, default=[60, 72],
                        help="the two landmark indices whose ground-truth distance normalizes the NME")
    args = parser.parse_args()

    norm_indices = args.norm_indices
    if min(norm_indices) < 0 or max(norm_indices) >= args.num_lms:
        parser.error("--norm_indices must lie in [0, num_lms)")

    # makedirs also creates './logs' and './logs/<data_name>' when missing.
    log_dir = os.path.join('./logs', args.data_name, args.experiment_name)
    os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(filename=os.path.join(log_dir, 'test.log'), level=logging.INFO)

    save_dir = os.path.join('./snapshots', args.data_name, args.experiment_name)

    meanface_indices, reverse_index1, reverse_index2, max_len = get_meanface(os.path.join(args.dataset_url, args.data_name, 'meanface.txt'), args.num_nb)

    if args.use_gpu and torch.cuda.is_available():
        device = torch.device("cuda:{}".format(args.gpu_id))
    else:
        if args.use_gpu:
            # Make the CPU fallback visible instead of silently switching devices.
            print('CUDA is not available; falling back to CPU.')
            logging.warning('CUDA is not available; falling back to CPU.')
        device = torch.device("cpu")
    use_cuda = device.type == 'cuda'

    net = StylizedFacePoint(args, device)
    net = net.to(device)

    weight_file = os.path.join(save_dir, "best_final.pth")
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only run.
    state_dict = torch.load(weight_file, map_location=device)
    net.load_state_dict(state_dict)
    # Inference mode: fixes normalization statistics / disables dropout if present.
    net.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    preprocess = transforms.Compose([transforms.Resize((args.input_size, args.input_size)), transforms.ToTensor(), normalize])

    labels = get_label(args.dataset_url, args.data_name, args.test_labels)
    if len(labels) == 0:
        sys.exit('No test labels found; nothing to evaluate.')

    nmes = []
    time_all = 0.0
    for label in labels:
        image_name = label[0]

        # A label row may carry one extra trailing value after the
        # 2 * num_lms coordinates; drop it so only coordinates remain.
        if len(label[1]) == 2 * args.num_lms + 1:
            lms_gt = label[1][:-1]
        else:
            lms_gt = label[1]

        lms_gt_2d = lms_gt.reshape(-1, 2)
        norm = np.linalg.norm(lms_gt_2d[norm_indices[0]] - lms_gt_2d[norm_indices[1]])

        image_path = os.path.join(args.dataset_url, args.data_name, args.test_images, image_name)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising on missing/unreadable files.
            raise FileNotFoundError('Could not read test image: {}'.format(image_path))

        # OpenCV loads channels as BGR; reverse them to RGB for PIL/torchvision.
        inputs = Image.fromarray(image[:, :, ::-1].astype('uint8'), 'RGB')
        inputs = preprocess(inputs).unsqueeze(0)
        inputs = inputs.to(device)

        with torch.no_grad():
            if use_cuda:
                # Start the clock only after pending GPU work has finished.
                torch.cuda.synchronize(device)
            t1 = time.perf_counter()
            lms_pred_x, lms_pred_y, lms_pred_nb_x, lms_pred_nb_y, outputs_cls, max_cls = forward_net(net, inputs, args.input_size, args.net_stride, args.num_lms, args.num_nb)

            # Merge neighbor predictions: for every landmark, average its own
            # prediction with the predictions other landmarks made for it.
            tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(args.num_lms, max_len)
            tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(args.num_lms, max_len)
            tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1, 1)
            tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1, 1)

            # Interleave into a flat [x0, y0, x1, y1, ...] vector.
            lms_pred = torch.cat((tmp_x, tmp_y), dim=1).flatten()

            if use_cuda:
                # CUDA kernels run asynchronously; wait so the measured time is real.
                torch.cuda.synchronize(device)
            t2 = time.perf_counter()
        time_all += t2 - t1

        lms_pred = lms_pred.cpu().numpy()

        nme = compute_nme(lms_pred, lms_gt, norm)
        nmes.append(nme)

    print('Total inference time:', time_all)
    print('Image num:', len(labels))
    print('Average inference time:', time_all / len(labels))

    mean_nme = np.mean(nmes)
    print('nme: {}'.format(mean_nme))
    logging.info('nme: {}'.format(mean_nme))

    fr, auc = compute_fr_and_auc(nmes)
    print('fr : {}'.format(fr))
    logging.info('fr : {}'.format(fr))
    print('auc: {}'.format(auc))
    logging.info('auc: {}'.format(auc))