import math

import numpy as np

from utils.transforms import transform_preds
import kornia
from kornia.geometry.transform import warp_perspective
import torch


def augment_with_matrix(aug_matrix, input_tensor):
    '''Warp a batch of images or heatmaps with per-sample transformation matrices.

    Input:
        aug_matrix: tensor of transformation matrices of shape (B, 2, 3)
            for an affine warp, or (B, 3, 3) for a perspective warp
        input_tensor: input tensor of shape (B, C, H, W)
    Return:
        warped: tensor of shape (B, C, H, W); the output keeps the input's
            spatial size (H, W)
    Raises:
        ValueError: if aug_matrix.shape[1] is neither 2 nor 3
    '''
    _, c, h, w = input_tensor.shape

    # Pick the warp from the matrix shape: 2x3 -> affine, 3x3 -> homography.
    if aug_matrix.shape[1] == 2:
        warp_fn = kornia.geometry.transform.warp_affine
    elif aug_matrix.shape[1] == 3:
        warp_fn = warp_perspective
    else:
        raise ValueError('Incorrect shape of augmentation matrix')

    if c == 3:
        # This is a color image: warp all channels in one call.
        return warp_fn(input_tensor, aug_matrix, dsize=(h, w))

    # This is heatmaps: warp each channel separately and re-stack along dim 1.
    channels = input_tensor.split(split_size=1, dim=1)
    return torch.cat([warp_fn(ch, aug_matrix, dsize=(h, w)) for ch in channels],
                     dim=1)


def get_max_preds(batch_heatmaps):
    '''
    Locate the highest-scoring pixel of every heatmap.

    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    Returns:
        coords: float32 array (batch_size, num_joints, 2) holding (x, y) of
            each peak; entries whose peak value is not > 0 are zeroed
        peak_vals: array (batch_size, num_joints, 1) with each peak value
    '''
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    n_batch, n_joints = batch_heatmaps.shape[0], batch_heatmaps.shape[1]
    hm_width = batch_heatmaps.shape[3]
    flat = batch_heatmaps.reshape((n_batch, n_joints, -1))

    peak_vals = np.amax(flat, 2).reshape((n_batch, n_joints, 1))
    # Flat argmax index, converted to float32 before splitting into x / y.
    peak_idx = np.argmax(flat, 2).reshape((n_batch, n_joints, 1)).astype(np.float32)

    xs = peak_idx % hm_width
    ys = np.floor(peak_idx / hm_width)
    coords = np.concatenate((xs, ys), axis=2)

    # Zero out joints whose best score is not strictly positive.
    keep = np.tile(np.greater(peak_vals, 0.0), (1, 1, 2)).astype(np.float32)
    coords *= keep
    return coords, peak_vals


def get_final_preds(config, batch_heatmaps, center, scale):
    '''
    Turn a batch of heatmaps into keypoint predictions.

    Input:
        config: config object; reads config.TEST.POST_PROCESS and
            config.DATASET.CENTER_SCALE
        batch_heatmaps: numpy.ndarray of shape (B, J, H, W)
        center, scale: per-sample values forwarded to transform_preds when
            config.DATASET.CENTER_SCALE is set (their format is defined by
            transform_preds)
    Return:
        preds: (B, J, 2) array of (x, y) coordinates; in heatmap pixel units
            unless the CENTER_SCALE mapping via transform_preds is applied
        maxvals: (B, J, 1) peak heatmap values from get_max_preds
    '''
    coords, maxvals = get_max_preds(batch_heatmaps)
    hm_h, hm_w = batch_heatmaps.shape[2], batch_heatmaps.shape[3]

    if config.TEST.POST_PROCESS:
        # Round each peak to integer pixel coordinates.
        px = np.floor(coords[:, :, 0] + 0.5).astype(np.int64)
        py = np.floor(coords[:, :, 1] + 0.5).astype(np.int64)
        # Only refine peaks whose 4-neighbourhood lies inside the map.
        # NOTE(review): the lower bound is `> 1`, not `> 0`, so peaks on
        # row/column 1 are skipped even though their neighbours exist;
        # kept as-is to preserve existing results.
        refine = (px > 1) & (px < hm_w - 1) & (py > 1) & (py < hm_h - 1)
        n_idx, p_idx = np.nonzero(refine)
        x, y = px[refine], py[refine]
        dx = (batch_heatmaps[n_idx, p_idx, y, x + 1]
              - batch_heatmaps[n_idx, p_idx, y, x - 1])
        dy = (batch_heatmaps[n_idx, p_idx, y + 1, x]
              - batch_heatmaps[n_idx, p_idx, y - 1, x])
        # Shift a quarter pixel toward the higher neighbour on each axis.
        coords[n_idx, p_idx] += np.sign(np.stack([dx, dy], axis=-1)) * .25

    preds = coords.copy()

    # Map heatmap coordinates back through transform_preds if requested.
    if config.DATASET.CENTER_SCALE:
        for i, sample_coords in enumerate(coords):
            preds[i] = transform_preds(
                sample_coords, center[i], scale[i], [hm_w, hm_h]
            )

    return preds, maxvals
