
import math
import logging

import numpy as np
import torchvision
import cv2


# Module-level logger for this file (not referenced by the functions shown
# in this chunk).
logger = logging.getLogger(__name__)


# COCO keypoint table: joint index -> [joint name, colour triple].
# 17 entries (indices 0-16); drawing code looks colours up by joint index.
# Colour triples are written in RGB order -- the colour names in the trailing
# comments only hold under that reading. OpenCV interprets a drawing colour in
# the target image's channel order, so reverse a triple before drawing it on a
# BGR image.
KEYPOINT_COCO = {
    0: ["nose",           (255, 204, 0)], # Bright yellow/gold
    1: ["left_eye",       (0, 153, 255)], # bright blue
    2: ["right_eye",      (255, 102, 102)], # light red/salmon
    3: ["left_ear",       (0, 102, 204)], # slightly darker blue
    4: ["right_ear",      (204, 51, 51)], # slightly darker red
    5: ["left_shoulder",  (0, 204, 102)], # vibrant green
    6: ["right_shoulder", (255, 153, 51)], # warm orange
    7: ["left_elbow",     (0, 255, 0)],   # Pure green
    8: ["right_elbow",    (255, 128, 0)], # Pure orange
    9: ["left_wrist",     (51, 255, 51)], # Lighter green
    10: ["right_wrist",   (255, 178, 102)], # Lighter orange
    11: ["left_hip",      (0, 153, 51)], # darker green
    12: ["right_hip",     (204, 102, 0)], # darker orange
    13: ["left_knee",     (102, 0, 204)], # rich purple
    14: ["right_knee",    (204, 0, 102)], # deep rose/magenta
    15: ["left_ankle",    (153, 51, 255)], # Lighter purple
    16: ["right_ankle",   (255, 51, 153)] # Lighter rose/magenta
}

# MPII keypoint table: joint index -> [joint name, colour triple].
# 16 entries (indices 0-15); where a joint also exists in KEYPOINT_COCO the
# same colour is reused, and the head/torso joints borrow the COCO face colours.
# Colour triples are in RGB order (see the trailing colour names); reverse a
# triple before drawing it with OpenCV on a BGR image.
# NOTE(review): "upper neck" and "head top" use spaces while every other name
# uses underscores -- check whether anything matches on these names.
KEYPOINT_MPII = {
    0:  ["right_ankle",    (255, 51, 153)], # Lighter rose/magenta
    1:  ["right_knee",     (204, 0, 102)], # deep rose/magenta
    2:  ["right_hip",      (204, 102, 0)], # darker orange
    3:  ["left_hip",       (0, 153, 51)], # darker green
    4:  ["left_knee",      (102, 0, 204)], # rich purple
    5:  ["left_ankle",     (153, 51, 255)], # Lighter purple
    6:  ["pelvis",         (204, 51, 51)], # slightly darker red
    7:  ["thorax",         (0, 102, 204)], # slightly darker blue
    8:  ["upper neck",     (0, 153, 255)], # bright blue
    9:  ["head top",       (255, 204, 0)], # Bright yellow/gold
    10: ["right_wrist",    (255, 178, 102)], # Lighter orange
    11: ["right_elbow",    (255, 128, 0)], # Pure orange
    12: ["right_shoulder", (255, 153, 51)], # warm orange
    13: ["left_shoulder",  (0, 204, 102)], # vibrant green
    14: ["left_elbow",     (0, 255, 0)],   # Pure green
    15: ["left_wrist",     (51, 255, 51)] # Lighter green
}

# COCO limb table: (joint index, joint index) -> limb colour triple.
# Indices refer to KEYPOINT_COCO; a limb is a line segment between the two
# joints. Colours are in RGB order (see the trailing colour names).
SKELETON_COCO = {
    (15, 13): (120, 51, 204),  # Knee-Ankle (Left), Medium purple
    (13, 11): (80, 0, 180),    # Hip-Knee (Left), Darker purple
    (16, 14): (255, 51, 120),  # Knee-Ankle (Right), Medium rose/magenta
    (14, 12): (180, 0, 80),    # Hip-Knee (Right), Darker rose/magenta
    (11, 12): (120, 120, 120), # Left Hip-Right Hip, neutral gray
    (5, 7):   (0, 153, 0),     # Shoulder-Elbow (Left), Darker green
    (6, 8):   (204, 102, 0),   # Shoulder-Elbow (Right), Darker orange
    (7, 9):   (51, 204, 51),   # Elbow-Wrist (Left), Medium green
    (8, 10):  (255, 153, 51),  # Elbow-Wrist (Right), Medium orange
    (0, 1):   (100, 180, 255), # Nose-Left Eye, Light blue
    (0, 2):   (255, 180, 180), # Nose-Right Eye, Light red
    (1, 3):   (50, 150, 220),  # Left Eye-Left Ear, Medium blue
    (2, 4):   (220, 150, 50),  # Right Eye-Right Ear, Medium red/orange
    (3, 5):   (50, 150, 220),  # Left Ear-Left shoulder, Medium blue
    (4, 6):   (220, 150, 50),  # Right Ear-Right shoulder, Medium red/orange
    (5, 6):   (150, 150, 150), # Left Shoulder-Right Shoulder, Neutral gray
    (11, 5):  (50, 180, 70),   # Left Shoulder-Left Hip, Greenish-gray
    (12, 6):  (180, 100, 50),  # Right Shoulder-Right Hip, Orangeish-gray
}

# MPII limb table: (joint index, joint index) -> limb colour triple.
# Indices refer to KEYPOINT_MPII. Limb colours mirror the corresponding
# SKELETON_COCO limbs where one exists. Colours are in RGB order.
SKELETON_MPII = {
    (4, 5):   (120, 51, 204),  # Knee-Ankle (Left), Medium purple
    (3, 4):   (80, 0, 180),    # Hip-Knee (Left), Darker purple
    (0, 1):   (255, 51, 120),  # Knee-Ankle (Right), Medium rose/magenta
    (1, 2):   (180, 0, 80),    # Hip-Knee (Right), Darker rose/magenta
    (2, 6):   (180, 100, 50),  # Right Hip-Pelvis, Orangeish-gray
    (3, 6):   (50, 180, 70),   # Left Hip-Pelvis, Greenish-gray
    (6, 7):   (120, 120, 120), # Pelvis-Thorax, neutral gray
    (13, 14): (0, 153, 0),     # Shoulder-Elbow (Left), Darker green
    (12, 11): (204, 102, 0),   # Shoulder-Elbow (Right), Darker orange
    (14, 15): (51, 204, 51),   # Elbow-Wrist (Left), Medium green
    (11, 10): (255, 153, 51),  # Elbow-Wrist (Right), Medium orange
    (8, 9):   (100, 180, 255), # upper neck-head top, Light blue
    (7, 8):   (255, 180, 180), # thorax-upper neck, Light red
    (7, 13):  (50, 150, 220),  # thorax-left_shoulder, Medium blue
    (7, 12):  (220, 150, 50)   # thorax-right_shoulder, Medium red/orange
}


def save_batch_image_with_joints(batch_image, batch_joints, batch_joints_vis,
                                 file_name, nrow=8, padding=2):
    """Tile a batch into one image grid, mark visible joints, and save it.

    Args:
        batch_image: image batch indexed ``[N, C, H, W]`` in RGB channel
            order; it is swapped to BGR for OpenCV and rescaled by
            ``make_grid(normalize=True)``.
        batch_joints: per-image joints; ``batch_joints[k][j][0:2]`` is the
            (x, y) pixel position of joint ``j`` inside image ``k``. The
            caller's array is NOT modified -- grid offsets are applied to
            local values only.
        batch_joints_vis: per-image visibility; joint ``j`` of image ``k`` is
            drawn only when ``batch_joints_vis[k][j][0]`` is truthy.
        file_name: output path passed to ``cv2.imwrite``.
        nrow: images per grid row (forwarded to ``make_grid``).
        padding: pixels between grid cells (forwarded to ``make_grid``).
    """
    batch_image = batch_image[:, [2, 1, 0], :, :]  # rgb -> bgr
    grid = torchvision.utils.make_grid(batch_image, nrow, padding, True)
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    ndarr = ndarr.copy()  # writable buffer for the cv2 drawing calls

    nmaps = batch_image.size(0)
    xmaps = min(nrow, nmaps)  # images per grid row, row-major like make_grid
    height = int(batch_image.size(2) + padding)
    width = int(batch_image.size(3) + padding)
    for k in range(nmaps):
        row, col = divmod(k, xmaps)
        # Pixel offset of image k's top-left corner inside the grid.
        off_x = col * width + padding
        off_y = row * height + padding
        for joint, joint_vis in zip(batch_joints[k], batch_joints_vis[k]):
            if joint_vis[0]:
                # Compute grid coordinates locally; writing them back into
                # ``joint`` would shift the caller's joints in place.
                center = (int(off_x + joint[0]), int(off_y + joint[1]))
                cv2.circle(ndarr, center, 1, [0, 255, 255], -1)
    cv2.imwrite(file_name, ndarr)


def save_batch_image_with_skeleton(config, batch_image, batch_joints, batch_joints_vis,
                                   file_name, nrow=8, padding=2):
    """Tile a batch into one image grid, draw limbs and keypoints, and save it.

    Args:
        config: experiment config; ``config.DATASET.DATASET`` ('coco' or
            'mpii') selects the skeleton and keypoint colour tables.
        batch_image: image batch indexed ``[N, C, H, W]`` in RGB channel
            order; it is swapped to BGR for OpenCV and rescaled by
            ``make_grid(normalize=True)``.
        batch_joints: per-image joints; ``batch_joints[k][j][0:2]`` is the
            (x, y) pixel position of joint ``j`` inside image ``k``. Not
            modified.
        batch_joints_vis: per-image visibility; ``batch_joints_vis[k][j][0]``
            gates joint ``j`` (a limb needs both end joints > 0).
        file_name: output path passed to ``cv2.imwrite``.
        nrow: images per grid row (forwarded to ``make_grid``).
        padding: pixels between grid cells (forwarded to ``make_grid``).

    Raises:
        ValueError: if no skeleton is defined for ``config.DATASET.DATASET``.
    """
    dataset = config.DATASET.DATASET
    if dataset == 'coco':
        skeleton, keypoints = SKELETON_COCO, KEYPOINT_COCO
    elif dataset == 'mpii':
        skeleton, keypoints = SKELETON_MPII, KEYPOINT_MPII
    else:
        # A bare ``assert 'message'`` is always true (non-empty string), so it
        # never fired and the later skeleton lookup hit an unbound name.
        raise ValueError(
            'Skeleton is not defined for dataset {!r}!'.format(dataset))

    # The tables hold RGB triples, but ``ndarr`` below is BGR and cv2 draws in
    # the image's channel order, so reverse each triple once up front.
    limb_colors = {pair: tuple(reversed(rgb)) for pair, rgb in skeleton.items()}
    joint_colors = {idx: tuple(reversed(entry[1]))
                    for idx, entry in keypoints.items()}

    batch_image = batch_image[:, [2, 1, 0], :, :]  # rgb -> bgr
    grid = torchvision.utils.make_grid(batch_image, nrow, padding, True)
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    ndarr = ndarr.copy()  # writable buffer for the cv2 drawing calls

    nmaps = batch_image.size(0)
    xmaps = min(nrow, nmaps)  # images per grid row, row-major like make_grid
    height = int(batch_image.size(2) + padding)
    width = int(batch_image.size(3) + padding)
    for k in range(nmaps):
        row, col = divmod(k, xmaps)
        # Pixel offset of image k's top-left corner inside the grid.
        off_x = col * width + padding
        off_y = row * height + padding
        joints = batch_joints[k]
        joints_vis = batch_joints_vis[k]

        for (idx1, idx2), color in limb_colors.items():  # draw skeleton
            if joints_vis[idx1][0] > 0 and joints_vis[idx2][0] > 0:
                pt1 = (int(off_x + joints[idx1][0]), int(off_y + joints[idx1][1]))
                pt2 = (int(off_x + joints[idx2][0]), int(off_y + joints[idx2][1]))
                cv2.line(ndarr, pt1, pt2, color, thickness=2)

        for j, (joint, joint_vis) in enumerate(zip(joints, joints_vis)):  # draw keypoints
            if joint_vis[0]:
                center = (int(off_x + joint[0]), int(off_y + joint[1]))
                cv2.circle(ndarr, center, 2, joint_colors[j], -1)

    cv2.imwrite(file_name, ndarr)


def save_batch_embed(batch_image, batch_embed, file_name,
                        embed_height=64, embed_width=48, normalize=True):
    """Save a grid pairing each resized image with a colour map of its embedding.

    Each grid row holds up to four (image, embedding) pairs side by side. The
    image is resized to ``embed_width`` x ``embed_height``. The embedding is
    averaged over dim 2, min-max scaled per sample to [0, 1], and rendered
    with ``cv2.COLORMAP_JET``.

    Args:
        batch_image: image batch indexed ``[N, C, H, W]`` in RGB channel order
            (swapped to BGR for OpenCV). With ``normalize`` it is first min-max
            scaled over the whole batch; otherwise values are expected in
            [0, 1] (they are multiplied by 255 and clamped).
        batch_embed: tensor whose dim 2 is averaged away; each sample must then
            hold ``embed_height * embed_width`` values. Presumably
            ``[N, H*W, C]`` with row-major spatial order -- TODO confirm.
        file_name: output path passed to ``cv2.imwrite``.
        embed_height: pixel height of every grid cell.
        embed_width: pixel width of each half (image / embedding) of a cell.
        normalize: whether to min-max scale ``batch_image`` first.
    """
    if normalize:
        batch_image = batch_image.clone()  # don't rescale the caller's tensor
        img_min = float(batch_image.min())
        img_max = float(batch_image.max())
        batch_image.add_(-img_min).div_(img_max - img_min + 1e-5)
    batch_image = batch_image[:, [2, 1, 0], :, :]  # rgb -> bgr

    batch_size = batch_embed.size(0)
    # Visualisation only: detach so no autograd graph is built.
    embed_avg = batch_embed.detach().mean(dim=2)  # average along the channel
    embed_avg = embed_avg - embed_avg.min(1, keepdim=True)[0]
    # Scale each sample to [0, 1]; the floor keeps a constant embedding
    # (range 0) from producing 0/0 = NaN.
    embed_avg = embed_avg / embed_avg.max(1, keepdim=True)[0].clamp(min=1e-5)
    embed_avg = embed_avg.reshape([batch_size, embed_height, embed_width])

    # Each grid row holds n_col (image, embedding) pairs. Round the row count
    # up so a batch that is not a multiple of n_col still fits.
    n_col = 4
    n_row = int(math.ceil(float(batch_size) / n_col))
    grid_image = np.zeros((n_row * embed_height, 2 * n_col * embed_width, 3),
                          dtype=np.uint8)

    for i in range(batch_size):
        image = (batch_image[i].mul(255).clamp(0, 255).byte()
                 .permute(1, 2, 0).cpu().numpy())
        resized_image = cv2.resize(image, (int(embed_width), int(embed_height)))

        embed_u8 = embed_avg[i].mul(255).clamp(0, 255).byte().cpu().numpy()
        colored_embed = cv2.applyColorMap(embed_u8, cv2.COLORMAP_JET)

        row, col = divmod(i, n_col)
        top = embed_height * row
        bottom = top + embed_height
        left = embed_width * (2 * col)  # image in the left half of the cell
        grid_image[top:bottom, left:left + embed_width, :] = resized_image
        left += embed_width             # embedding in the right half
        grid_image[top:bottom, left:left + embed_width, :] = colored_embed

    cv2.imwrite(file_name, grid_image)


def save_debug_images(config, input, meta, joints_pred, prefix):
    """Write ground-truth and/or predicted joint overlays for one batch.

    Nothing happens unless ``config.DEBUG.DEBUG`` is truthy. Then
    ``config.DEBUG.SAVE_BATCH_IMAGES_GT`` writes ``<prefix>_gt.jpg`` from
    ``meta['joints']``, and ``config.DEBUG.SAVE_BATCH_IMAGES_PRED`` writes
    ``<prefix>_pred.jpg`` from ``joints_pred`` (converted via
    ``.cpu().numpy()``). Both overlays use ``meta['joints_vis']`` for
    visibility.
    """
    dbg = config.DEBUG
    if not dbg.DEBUG:
        return

    if dbg.SAVE_BATCH_IMAGES_GT:
        gt_path = '{}_gt.jpg'.format(prefix)
        save_batch_image_with_joints(input, meta['joints'], meta['joints_vis'],
                                     gt_path)

    if dbg.SAVE_BATCH_IMAGES_PRED:
        pred_joints = joints_pred.cpu().numpy()
        pred_path = '{}_pred.jpg'.format(prefix)
        save_batch_image_with_joints(input, pred_joints, meta['joints_vis'],
                                     pred_path)


def save_pred_images(config, input, meta, joints_pred, prefix):
    """Write a predicted-skeleton overlay for one batch when debugging.

    Only acts when both ``config.DEBUG.DEBUG`` and
    ``config.DEBUG.SAVE_BATCH_IMAGES_PRED`` are truthy. It then saves
    ``<prefix>_pred.jpg`` with ``joints_pred`` (via ``.cpu().numpy()``) drawn
    as a skeleton, using ``meta['joints_vis']`` for visibility.
    """
    dbg = config.DEBUG
    if dbg.DEBUG and dbg.SAVE_BATCH_IMAGES_PRED:
        pred_joints = joints_pred.cpu().numpy()
        pred_path = '{}_pred.jpg'.format(prefix)
        save_batch_image_with_skeleton(config, input, pred_joints,
                                       meta['joints_vis'], pred_path)

