import os
import sys
sys.path.insert(0, '..')
import argparse
import numpy as np
import cv2
from PIL import Image
import time

import torch
import torch.utils.data
import torchvision.transforms as transforms

from functions import *
from network import StylizedFacePoint

"""
import warnings
warnings.filterwarnings("ignore")
"""

if __name__ == '__main__':
    # Command-line entry point: run the StylizedFacePoint landmark detector on
    # a single image, save the predicted landmarks as a flat
    # [x0, y0, x1, y1, ...] .npy array and optionally write a visualization.
    parser = argparse.ArgumentParser(description="Detect landmarks")
    parser.add_argument("--image_path", type=str, required=True, help="the path of image to be detected")
    parser.add_argument("--experiment_name", type=str, default="Exp1", help="the name of experiment")
    parser.add_argument("--data_name", type=str, default="FLSC", help="the name of dataset")
    parser.add_argument("--dataset_url", type=str, default="data/", help="the path of dataset")
    parser.add_argument("--num_lms", type=int, default=98, help="the number of landmarks in dataset")
    parser.add_argument("--num_nb", type=int, default=3, help="the number of neighbor landmarks")
    parser.add_argument("--input_size", type=int, default=256, help="the size of input images")
    parser.add_argument("--net_stride", type=int, default=16, help="the stride of network")
    parser.add_argument("--nstack", type=int, default=4, help="the number of stages in stacked hourglass network")
    parser.add_argument("--use_gpu", action="store_true", help="if use gpu for model evaluation")
    parser.add_argument("--gpu_id", type=int, default=0, help="the index of gpu(if use)")
    parser.add_argument("--detect_num_lms", type=int, default=98, help="the number of landmarks in detected results(98/68)")
    parser.add_argument("--vis_res", action="store_true", help="if result visualization")
    parser.add_argument("--save_dir", type=str, default="result/", help="the save dir of visualization results")
    args = parser.parse_args()

    # Fail fast on a bad image path: cv2.imread signals failure by returning
    # None instead of raising, which would otherwise surface later as an
    # unrelated-looking TypeError.
    image = cv2.imread(args.image_path)
    if image is None:
        parser.error("cannot read image: {}".format(args.image_path))

    # makedirs also creates missing parent directories and tolerates an
    # already-existing directory.
    os.makedirs(args.save_dir, exist_ok=True)

    model_dir = os.path.join('./snapshots', args.data_name, args.experiment_name)

    meanface_path = os.path.join(args.dataset_url, args.data_name, 'meanface.txt')
    meanface_indices, reverse_index1, reverse_index2, max_len = get_meanface(meanface_path, args.num_nb)

    # Fall back to CPU when a GPU was requested but CUDA is unavailable.
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device("cuda:{}".format(args.gpu_id))
    else:
        device = torch.device("cpu")

    net = StylizedFacePoint(args, device)
    net = net.to(device)

    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded on whichever device was selected above.
    weight_file = os.path.join(model_dir, "best_final.pth")
    state_dict = torch.load(weight_file, map_location=device)
    net.load_state_dict(state_dict)
    # Inference only: make sure the network is in evaluation mode (harmless
    # if forward_net switches modes itself).
    net.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    preprocess = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.ToTensor(),
        normalize,
    ])

    # OpenCV loads images as BGR; reverse the channel axis to feed RGB to PIL.
    inputs = Image.fromarray(image[:, :, ::-1].astype('uint8'), 'RGB')
    inputs = preprocess(inputs).unsqueeze(0)
    inputs = inputs.to(device)

    with torch.no_grad():
        lms_pred_x, lms_pred_y, lms_pred_nb_x, lms_pred_nb_y, outputs_cls, max_cls = forward_net(
            net, inputs, args.input_size, args.net_stride, args.num_lms, args.num_nb)

        # merge neighbor predictions: average each landmark's own prediction
        # with the predictions gathered through the reverse indices
        tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(args.num_lms, max_len)
        tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(args.num_lms, max_len)
        tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1, 1)
        tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1, 1)

        lms_pred = torch.cat((tmp_x, tmp_y), dim=1).flatten()

    # Interleaved [x0, y0, x1, y1, ...] scaled by input_size, i.e. expressed
    # in the coordinate frame of the resized network input.
    lms_pred = lms_pred.cpu().numpy() * args.input_size

    if args.detect_num_lms == 68:
        # De-interleave the first 98 points and pass each axis through
        # retarget(), whose result is read back as 68 values per axis.
        x_98 = [lms_pred[2 * i] for i in range(98)]
        y_98 = [lms_pred[2 * i + 1] for i in range(98)]

        x_68 = retarget(x_98)
        y_68 = retarget(y_98)

        lms_pred = np.zeros(68 * 2)
        lms_pred[0::2] = [x_68[i] for i in range(68)]
        lms_pred[1::2] = [y_68[i] for i in range(68)]

    # The option is named --detect_num_lms; no "vis_num_lms" attribute exists.
    np.save(os.path.join(args.save_dir, 'landmarks_{}.npy'.format(args.detect_num_lms)), lms_pred)

    if args.vis_res:
        # The saved landmarks live in the input_size x input_size frame, while
        # `image` still has its original resolution; rescale per axis so the
        # points land on the right pixels (a no-op for input_size-square images).
        img_h, img_w = image.shape[:2]
        scale_x = img_w / args.input_size
        scale_y = img_h / args.input_size
        for x, y in lms_pred.reshape(-1, 2):
            cv2.circle(image, (int(x * scale_x), int(y * scale_y)), 1, (255, 0, 0), 2)
        cv2.imwrite(os.path.join(args.save_dir, 'vis_landmarks_{}.jpg'.format(args.detect_num_lms)), image)

    print("Detected Finished!")
    print("Landmark results are saved in {}".format(args.save_dir))
