import argparse
import glob
import numpy as np
import os
import torch
import torchvision.transforms as transforms
from skimage import io

from basicsr.archs.dfdnet_arch import DFDNet
from basicsr.utils import imwrite, tensor2img

try:
    from facexlib.utils.face_restoration_helper import FaceRestoreHelper
except ImportError:
    print('Please install facexlib: pip install facexlib')

# TODO: this script needs updating, since FaceRestoreHelper has been changed since it was written.


def _get_single_part_location(landmarks, indices, min_half_len=16):
    """Compute the square bounding box of one facial part.

    The box is centred on the mean of the selected landmarks. Its half side
    length is half of the largest x/y extent of those landmarks, but never
    less than ``min_half_len``.

    Args:
        landmarks (ndarray): Landmark array indexable by a list of row
            indices, one (x, y) point per row (presumably shape (68, 2);
            the largest index used by the caller is 67).
        indices (list[int]): Rows of ``landmarks`` belonging to this part.
        min_half_len (int | float): Lower bound for the half side length.

    Returns:
        Tensor: Shape (1, 4) integer tensor ``[x1, y1, x2, y2]``, i.e. the two
        diagonal corners of the box.
    """
    part = landmarks[indices]
    center = np.mean(part, 0)  # (x, y)
    extent = np.max(np.max(part, 0) - np.min(part, 0))
    half_len = np.max((extent / 2, min_half_len))
    # "+ 1" on the top-left corner matches the original DFDNet box convention.
    box = np.hstack((center - half_len + 1, center + half_len)).astype(int)
    return torch.from_numpy(box).unsqueeze(0)


def get_part_location(landmarks, min_half_len=16):
    """Get part locations (left eye, right eye, nose, mouth) from landmarks.

    Args:
        landmarks (ndarray): 68-point landmark array, one (x, y) point per row.
        min_half_len (int | float): Minimum half side length of every part box.
            Default: 16.

    Returns:
        tuple[Tensor]: ``(loc_left_eye, loc_right_eye, loc_nose, loc_mouth)``.
        Each is a (1, 4) integer tensor whose four numbers form the two
        diagonal corners of the part box.
    """
    map_left_eye = list(range(17, 22)) + list(range(36, 42))
    map_right_eye = list(range(22, 27)) + list(range(42, 48))
    map_nose = list(range(29, 36))
    map_mouth = list(range(48, 68))

    return tuple(
        _get_single_part_location(landmarks, part_indices, min_half_len)
        for part_indices in (map_left_eye, map_right_eye, map_nose, map_mouth))


def str2bool(value):
    """Convert a command-line string such as ``'True'`` or ``'false'`` to a bool.

    ``argparse`` with ``type=bool`` maps every non-empty string, including
    ``'False'``, to ``True``, so such an option could never be switched off.
    This converter is used instead.

    Args:
        value (str | bool): Raw argument value (argparse passes a str; a bool
            is returned unchanged).

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` means
            ``sys.argv[1:]`` (the argparse default).

    Returns:
        argparse.Namespace: Parsed options. ``test_path`` is normalized with
        ``os.path.normpath``, so any number of trailing separators is removed
        and ``os.path.basename(test_path)`` is never empty for a folder path.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--upscale_factor', type=int, default=2)
    parser.add_argument(
        '--model_path',
        type=str,
        default='experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth')
    parser.add_argument(
        '--dict_path',
        type=str,
        default='experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth')
    parser.add_argument('--test_path', type=str, default='datasets/TestWhole')
    parser.add_argument('--upsample_num_times', type=int, default=1)
    parser.add_argument('--save_inverse_affine', action='store_true')
    parser.add_argument('--only_keep_largest', action='store_true')
    # The official codes use skimage.io to read the cropped images from disk
    # instead of directly using the intermediate results in the memory (as we
    # do). Such a different operation brings slight differences due to
    # skimage.io. For aligning with the official results, we could set the
    # official_adaption to True. Pass e.g. `--official_adaption False` to
    # disable it.
    parser.add_argument('--official_adaption', type=str2bool, default=True)

    # The following are the paths for dlib models
    parser.add_argument(
        '--detection_path',
        type=str,
        default='experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat')
    parser.add_argument(
        '--landmark5_path',
        type=str,
        default='experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat')
    parser.add_argument(
        '--landmark68_path',
        type=str,
        default='experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat')

    args = parser.parse_args(argv)
    # Strip any number of trailing separators (not just a single '/'), so the
    # result folder name derived from basename() is never empty.
    args.test_path = os.path.normpath(args.test_path)
    return args


if __name__ == '__main__':
    # We try to align with the official code, but there are still slight
    # differences: 1) we use dlib for 68-landmark detection; 2) the image
    # packages used (especially for reading and writing) are different.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args = parse_args()
    result_root = f'results/DFDNet/{os.path.basename(args.test_path)}'

    # set up the DFDNet
    net = DFDNet(64, dict_path=args.dict_path).to(device)
    checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
    net.load_state_dict(checkpoint['params'])
    net.eval()

    save_crop_root = os.path.join(result_root, 'cropped_faces')
    save_inverse_affine_root = os.path.join(result_root, 'inverse_affine')
    os.makedirs(save_inverse_affine_root, exist_ok=True)
    save_restore_root = os.path.join(result_root, 'restored_faces')
    save_final_root = os.path.join(result_root, 'final_results')

    face_helper = FaceRestoreHelper(args.upscale_factor, face_size=512)

    # scan all the jpg and png images; the folder part is escaped so that glob
    # metacharacters in it (e.g. '[') are matched literally
    img_pattern = os.path.join(glob.escape(args.test_path), '*.[jp][pn]g')
    for img_path in sorted(glob.glob(img_pattern)):
        img_name = os.path.basename(img_path)
        print(f'Processing {img_name} image ...')
        save_crop_path = os.path.join(save_crop_root, img_name)
        if args.save_inverse_affine:
            save_inverse_affine_path = os.path.join(save_inverse_affine_root, img_name)
        else:
            save_inverse_affine_path = None

        face_helper.init_dlib(args.detection_path, args.landmark5_path, args.landmark68_path)
        # detect faces
        num_det_faces = face_helper.detect_faces(
            img_path, upsample_num_times=args.upsample_num_times, only_keep_largest=args.only_keep_largest)
        # get 5 face landmarks for each face
        num_landmarks = face_helper.get_face_landmarks_5()
        print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.')
        # warp and crop each face
        face_helper.warp_crop_faces(save_crop_path, save_inverse_affine_path)

        if args.official_adaption:
            # Re-read the crops from disk like the official code does. The glob
            # assumes warp_crop_faces() saved them as '<stem>_<digits>.png'.
            crop_stem = os.path.splitext(save_crop_path)[0]
            crop_paths = sorted(glob.glob(f'{glob.escape(crop_stem)}_[0-9]*.png'))
            cropped_faces = [io.imread(crop_path) for crop_path in crop_paths]
        else:
            cropped_faces = face_helper.cropped_faces

        # get 68 landmarks for each cropped face
        num_landmarks = face_helper.get_face_landmarks_68()
        print(f'\tDetect {num_landmarks} faces for 68 landmarks.')

        face_helper.free_dlib_gpu_memory()

        print('\tFace restoration ...')
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        all_landmarks_68 = face_helper.all_landmarks_68
        if len(cropped_faces) != len(all_landmarks_68):
            raise RuntimeError(f'Number of cropped faces ({len(cropped_faces)}) does not match '
                               f'the number of 68-landmark results ({len(all_landmarks_68)}).')

        restore_stem = os.path.splitext(os.path.join(save_restore_root, img_name))[0]
        # face restoration for each cropped face
        for idx, (cropped_face, landmarks) in enumerate(zip(cropped_faces, all_landmarks_68)):
            if landmarks is None:
                print(f'Landmarks is None, skip cropped faces with idx {idx}.')
                # just copy the cropped faces to the restored faces.
                # skimage.io.imread returns RGB, while tensor2img() (rgb2bgr by
                # default) and the cv2-based imwrite()/pasting use BGR, so flip
                # the channels to keep this face consistent with the others.
                if args.official_adaption:
                    restored_face = np.ascontiguousarray(cropped_face[:, :, ::-1])
                else:
                    restored_face = cropped_face
            else:
                # prepare data
                part_locations = get_part_location(landmarks)
                cropped_face = transforms.ToTensor()(cropped_face)
                cropped_face = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(cropped_face)
                cropped_face = cropped_face.unsqueeze(0).to(device)

                try:
                    with torch.no_grad():
                        output = net(cropped_face, part_locations)
                        restored_face = tensor2img(output, min_max=(-1, 1))
                    del output
                    torch.cuda.empty_cache()
                except Exception as e:
                    # best effort: fall back to the (unrestored) input face
                    print(f'DFDNet inference fail: {e}')
                    restored_face = tensor2img(cropped_face, min_max=(-1, 1))

            save_path = f'{restore_stem}_{idx:02d}.png'
            imwrite(restored_face, save_path)
            face_helper.add_restored_face(restored_face)

        print('\tGenerate the final result ...')
        # paste each restored face to the input image
        face_helper.paste_faces_to_input_image(os.path.join(save_final_root, img_name))

        # clean all the intermediate results to process the next image
        face_helper.clean_all()

    print(f'\nAll results are saved in {result_root}')
