import torch
import numpy as np
import cv2
import torchvision.utils as ttf
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from fastai.basics import *


def _make_pil_to_normalized_tensor():
    """Build a fresh image -> tensor pipeline that ends in ``Normalize(0.5, 0.5)``.

    ``ToTensor`` turns an HWC PIL image / ndarray into a CHW float tensor
    (8-bit inputs are scaled to [0, 1]); ``Normalize(mean=0.5, std=0.5)`` then
    applies ``(x - 0.5) / 0.5`` to every channel.
    """
    steps = [transforms.ToTensor(), transforms.Normalize(mean=0.5, std=0.5)]
    return transforms.Compose(steps)


# Pipeline used by the draw_* helpers in this module when return_pt=True.
transform = _make_pil_to_normalized_tensor()

# Same pipeline under a second public name; built as its own instance rather
# than aliasing ``transform``.
pil2tensor = _make_pil_to_normalized_tensor()

# Square box template on a 512x512 canvas. Corner order is
# (x1, y1), (x1, y2), (x2, y2), (x2, y1), the same order get_box_lm4p returns.
mean_box_lm4p_512 = np.array(
    [[80, 80], [80, 432], [432, 432], [432, 80]], dtype=np.float32
)

# 5-point face template for a 256-pixel canvas.
# The base coordinates are remapped as x -> (x + 8) * 2 + 16 and y -> y * 2.
# NOTE(review): the base values look like the common 96x112 five-point
# alignment template (the +8 would widen it to 112) — confirm. Only x receives
# the extra +16 offset, so the template is centred horizontally but not
# vertically.
mean_face_lm5p_256 = (
    (
        np.array(
            [
                [30.2946, 51.6963],  # left eye pupil
                [65.5318, 51.5014],  # right eye pupil
                [48.0252, 71.7366],  # nose tip
                [33.5493, 92.3655],  # left mouth corner
                [62.7299, 92.2041],  # right mouth corner
            ],
            dtype=np.float64,
        )
        + np.array([8.0, 0.0])
    )
    * 2.0
    + np.array([16.0, 0.0])
).astype(np.float32)


def get_box_lm4p(pts):
    """Return the corners of a square centred on the bounding box of ``pts``.

    The square shares its centre with the axis-aligned bounding box of the
    points, and its side equals the longer of the box's width and height.

    Args:
        pts: NumPy array indexable as ``pts[:, 0]`` (x) and ``pts[:, 1]`` (y);
            any further columns are ignored.

    Returns:
        float32 array of shape (4, 2) ordered
        (x1, y1), (x1, y2), (x2, y2), (x2, y1).
    """
    xs = pts[:, 0]
    ys = pts[:, 1]
    left, right = np.min(xs), np.max(xs)
    top, bottom = np.min(ys), np.max(ys)

    cx = (left + right) * 0.5
    cy = (top + bottom) * 0.5
    half_side = 0.5 * max(right - left, bottom - top)

    x_lo, x_hi = cx - half_side, cx + half_side
    y_lo, y_hi = cy - half_side, cy + half_side

    corners = [[x_lo, y_lo], [x_lo, y_hi], [x_hi, y_hi], [x_hi, y_lo]]
    return np.array(corners, dtype=np.float32)


def get_affine_transform(target_face_lm5p, mean_lm5p):
    """Estimate the least-squares similarity transform from ``target_face_lm5p`` to ``mean_lm5p``.

    The transform has four unknowns ``(a, b, tx, ty)`` and the form::

        [[a, -b, tx],
         [b,  a, ty]]

    i.e. uniform scale, rotation and translation. The unknowns minimise
    ``sum_i ||M @ [x_i, y_i, 1] - [u_i, v_i]||^2`` over all correspondences
    ``(x_i, y_i) -> (u_i, v_i)``. The 4x4 normal equations are assembled
    below and solved with ``cv2.solve(..., cv2.DECOMP_SVD)``.

    Args:
        target_face_lm5p: array-like of source points, shape (N, >=2); only
            the first two columns are used. Named for the 5-point landmark
            case, but any N >= 2 works (previously the count was fixed at 5).
        mean_lm5p: array-like of destination points, shape (N, >=2), in the
            same order as ``target_face_lm5p``.

    Returns:
        ``np.ndarray`` of shape (2, 3), dtype float64.

    Raises:
        ValueError: if the two point sets differ in length, or contain fewer
            than two points.
    """
    # Accumulate in float64 even for float32 landmarks.
    src = np.asarray(target_face_lm5p, dtype=np.float64)[:, :2]
    dst = np.asarray(mean_lm5p, dtype=np.float64)[:, :2]
    if len(src) != len(dst):
        raise ValueError(
            f"point count mismatch: {len(src)} source vs {len(dst)} destination points"
        )
    count = len(src)
    if count < 2:
        raise ValueError(f"need at least 2 point pairs, got {count}")

    x, y = src[:, 0], src[:, 1]
    u, v = dst[:, 0], dst[:, 1]
    sum_sq = np.sum(x * x + y * y)
    sum_x = np.sum(x)
    sum_y = np.sum(y)

    # Normal equations A @ [a, b, tx, ty]^T = B of the least-squares problem.
    # The diagonal translation terms are the number of correspondences (this
    # used to be hard-coded as 5).
    A = np.array(
        [
            [sum_sq, 0.0, sum_x, sum_y],
            [0.0, sum_sq, -sum_y, sum_x],
            [sum_x, -sum_y, count, 0.0],
            [sum_y, sum_x, 0.0, count],
        ],
        dtype=np.float64,
    )
    B = np.array(
        [
            np.sum(x * u + y * v),
            np.sum(x * v - y * u),
            np.sum(u),
            np.sum(v),
        ],
        dtype=np.float64,
    )

    # DECOMP_SVD tolerates a singular A (e.g. coincident points).
    _, solution = cv2.solve(A, B, flags=cv2.DECOMP_SVD)
    # cv2.solve returns a column vector; flatten so each unknown is a scalar
    # instead of assigning size-1 arrays into scalar slots.
    a, b, tx, ty = solution.ravel()
    return np.array([[a, -b, tx], [b, a, ty]], dtype=np.float64)


def transformation_from_points(points1, points2):
    """Estimate a 2x3 similarity transform that maps ``points1`` onto ``points2``.

    This is orthogonal Procrustes:
    - Both point sets are centred and divided by their standard deviation,
      computed over all x and y values together.
    - The rotation comes from the SVD of the normalised cross-covariance.
    - The scale is the ratio of the two standard deviations.

    Args:
        points1: sequence of source points; only ``point[0]`` and ``point[1]``
            of each item are used.
        points2: corresponding destination points, in the same order.

    Returns:
        ``np.ndarray`` of shape (2, 3), float64, such that
        ``M[:, :2] @ p + M[:, 2]`` approximates the matching point in
        ``points2`` for each ``p`` in ``points1``.
    """
    # Plain ndarrays + ``@`` instead of the discouraged np.matrix class.
    src = np.array([[p[0], p[1]] for p in points1], dtype=np.float64)
    dst = np.array([[p[0], p[1]] for p in points2], dtype=np.float64)

    c1 = src.mean(axis=0)
    c2 = dst.mean(axis=0)
    src -= c1
    dst -= c2
    s1 = src.std()
    s2 = dst.std()
    src /= s1
    dst /= s2

    U, _, Vt = np.linalg.svd(src.T @ dst)
    # Reflection guard (Kabsch/Umeyama): U @ Vt can have det -1 for mirrored
    # or degenerate correspondences. Flip the axis of the smallest singular
    # value so the result is a proper rotation, not a mirror.
    if np.linalg.det(U @ Vt) < 0:
        U[:, -1] *= -1
    R = (U @ Vt).T

    scale = s2 / s1
    translation = c2 - scale * (R @ c1)
    return np.hstack([scale * R, translation[:, None]])


def convert_batch_to_nprgb(batch, nrow):
    """Tile a batch of images normalised to [-1, 1] into one uint8 HWC grid.

    Args:
        batch: image tensor accepted by ``torchvision.utils.make_grid``
            (e.g. B x C x H x W). It is presumably in the [-1, 1] range
            produced by ``Normalize(mean=0.5, std=0.5)``; confirm for
            other callers.
        nrow: number of images per grid row.

    Returns:
        ``np.ndarray`` of dtype uint8, laid out as (H, W, C).
    """
    # detach(): Tensor.numpy() raises on tensors that require grad.
    grid_tensor = ttf.make_grid(batch.detach() * 0.5 + 0.5, nrow=nrow)
    # Clamp first: casting out-of-range floats to uint8 does not saturate
    # (the result is undefined / platform dependent), so overshoot would wrap.
    grid_tensor = grid_tensor.clamp(0.0, 1.0)
    im_rgb = (255 * grid_tensor.permute(1, 2, 0).cpu().numpy()).astype("uint8")
    return im_rgb


def _landmarks_to_pixels(pts):
    """Convert a landmark tensor to a NumPy int array of rounded pixel coordinates."""
    # detach() is required: .numpy() refuses tensors that require grad, and
    # .cpu() returns the same (grad-tracking) tensor when it is already on CPU.
    return pts.detach().cpu().numpy().round().astype(int)


def _append_gaze_points(pts68, gaze):
    """Return ``pts68`` with one gaze point per eye appended along dim 1 (68 -> 70 points).

    The eyes are spanned by landmarks 36/39 and 42/45. Each gaze point is
    ``eye_center + offset * eye_width``, where ``eye_width`` is the corner-to-corner
    distance. ``gaze[:, :2]`` is the offset for the 36/39 eye and ``gaze[:, 2:]``
    is the offset for the 42/45 eye.
    """
    with torch.no_grad():
        left_eye1, left_eye2 = pts68[:, 36], pts68[:, 39]
        right_eye1, right_eye2 = pts68[:, 42], pts68[:, 45]

        left_eye_length = torch.sqrt(
            torch.sum((left_eye2 - left_eye1) ** 2, dim=1, keepdim=True)
        )
        right_eye_length = torch.sqrt(
            torch.sum((right_eye2 - right_eye1) ** 2, dim=1, keepdim=True)
        )
        left_eye_center = (left_eye2 + left_eye1) * 0.5
        right_eye_center = (right_eye2 + right_eye1) * 0.5

        left_gaze = gaze[:, :2] * left_eye_length + left_eye_center
        right_gaze = gaze[:, 2:] * right_eye_length + right_eye_center
        return torch.cat(
            [pts68, left_gaze.view(-1, 1, 2), right_gaze.view(-1, 1, 2)], dim=1
        )


def _draw_and_warp(landmarks, warp_mat256_np, dst_size, im_list, return_pt, radius):
    """Draw integer ``landmarks`` as filled circles, then warp each canvas to dst_size x dst_size.

    Args:
        landmarks: int array indexed ``[image, point, (x, y)]``.
        warp_mat256_np: per-image 2x3 matrices. They are passed with
            ``WARP_INVERSE_MAP``, so OpenCV treats each as the dst -> src mapping.
        dst_size: side length of the square output images.
        im_list: optional images to draw on (copied first); if None, black
            256x256x3 uint8 canvases are used.
        return_pt: if True, return a tensor batch built with the module-level
            ``transform``; otherwise return a list of PIL images.
        radius: circle radius in source-canvas pixels.
    """
    num_images, num_points = landmarks.shape[0], landmarks.shape[1]

    # One colour per landmark index, sampled evenly from matplotlib's "rainbow" map.
    rgba = plt.get_cmap("rainbow")(np.linspace(0, 1, num_points))
    colors = (255 * rgba).astype(int)[:, 0:3].tolist()

    if im_list is None:
        canvases = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(num_images)]
    else:
        # np.array copies, so the caller's images are not drawn on.
        canvases = [np.array(x) for x in im_list]

    rendered = []
    for idx in range(num_images):
        image = canvases[idx]
        for (x, y), color in zip(landmarks[idx], colors):
            image = cv2.circle(
                image,
                (int(x), int(y)),
                radius=radius,
                # NOTE(review): channels are reversed from the colormap's RGB, but
                # the canvas is later read as RGB by Image.fromarray. Kept as-is in
                # case downstream consumers depend on this colour encoding.
                color=(color[2], color[1], color[0]),
                thickness=-1,
            )

        dst_image = cv2.warpAffine(
            image,
            warp_mat256_np[idx],
            (dst_size, dst_size),
            flags=(cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP),
            borderMode=cv2.BORDER_CONSTANT,
        )
        rendered.append(Image.fromarray(dst_image))

    if return_pt:
        tensor_list = [transform(x).view(1, 3, dst_size, dst_size) for x in rendered]
        return torch.cat(tensor_list, dim=0)
    return rendered


def draw_pts3_batch(
    pts3, gaze, warp_mat256_np, dst_size, im_list=None, return_pt=False
):
    """Render a batch of landmark sets (radius 2) and warp each to ``dst_size``.

    Args:
        pts3: landmark tensor indexed ``[image, point, (x, y)]``; presumably in
            the 256x256 source canvas space (the default canvas size) — confirm.
        gaze: unused; accepted for signature parity with ``draw_pts70_batch``.
        warp_mat256_np: per-image 2x3 matrices, applied with ``WARP_INVERSE_MAP``.
        dst_size: side length of the square output images.
        im_list: optional images to draw on; black canvases if None.
        return_pt: return a tensor batch instead of a list of PIL images.
    """
    landmarks = _landmarks_to_pixels(pts3)
    return _draw_and_warp(
        landmarks, warp_mat256_np, dst_size, im_list, return_pt, radius=2
    )


def draw_pts70_batch(
    pts68, gaze, warp_mat256_np, dst_size, im_list=None, return_pt=False
):
    """Render 68 landmarks plus two gaze points (radius 2) and warp each image to ``dst_size``.

    Args:
        pts68: landmark tensor indexed ``[image, point, (x, y)]`` with at least
            46 points (indices 36/39/42/45 are read as eye corners).
        gaze: tensor whose columns ``0:2`` / ``2:`` are per-eye offsets, in units
            of the eye's corner-to-corner width (presumably shape (B, 4)).
        warp_mat256_np: per-image 2x3 matrices, applied with ``WARP_INVERSE_MAP``.
        dst_size: side length of the square output images.
        im_list: optional images to draw on; black canvases if None.
        return_pt: return a tensor batch instead of a list of PIL images.
    """
    landmarks = _landmarks_to_pixels(_append_gaze_points(pts68, gaze))
    return _draw_and_warp(
        landmarks, warp_mat256_np, dst_size, im_list, return_pt, radius=2
    )


def draw_pts70_batch_obvious(
    pts68, gaze, warp_mat256_np, dst_size, im_list=None, return_pt=False
):
    """Same as ``draw_pts70_batch`` but with larger (radius 6) landmark dots.

    Previously this variant crashed on grad-tracking inputs because it called
    ``.numpy()`` without detaching; it now detaches like the other helpers.
    """
    landmarks = _landmarks_to_pixels(_append_gaze_points(pts68, gaze))
    return _draw_and_warp(
        landmarks, warp_mat256_np, dst_size, im_list, return_pt, radius=6
    )
