# ------------------------------------------------------------------------------
# pose.pytorch
# Copyright (c) 2018-present Microsoft
# Licensed under The Apache-2.0 License [see LICENSE for details]
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np
from scipy.special import softmax

from utils.transforms import transform_preds


def get_max_preds(batch_heatmaps):
    """Locate the peak of every heatmap with a flat argmax.

    Args:
        batch_heatmaps: numpy.ndarray of shape
            [batch_size, num_joints, height, width].

    Returns:
        (preds, maxvals):
            preds   -- float32 array [batch_size, num_joints, 2] holding the
                       (x, y) pixel position of each peak. Joints whose peak
                       value is not strictly positive are reported as (0, 0).
            maxvals -- array [batch_size, num_joints, 1] of the peak values.
    """
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    n_batch, n_joints, _, hm_width = batch_heatmaps.shape
    flat_maps = batch_heatmaps.reshape((n_batch, n_joints, -1))

    peak_idx = np.argmax(flat_maps, 2).reshape((n_batch, n_joints, 1))
    peak_val = np.amax(flat_maps, 2).reshape((n_batch, n_joints, 1))

    # Duplicate the flat index into two channels, then split it into
    # column (x) and row (y).
    coords = np.tile(peak_idx, (1, 1, 2)).astype(np.float32)
    coords[:, :, 0] = coords[:, :, 0] % hm_width
    coords[:, :, 1] = np.floor(coords[:, :, 1] / hm_width)

    # Zero out locations whose maximum is not positive.
    is_positive = np.tile(peak_val > 0.0, (1, 1, 2)).astype(np.float32)
    coords *= is_positive
    return coords, peak_val


def get_final_preds(config, batch_heatmaps, center, scale):
    """Decode heatmaps into joint coordinates via per-sample back-transform.

    Peaks are found with ``get_max_preds``. When ``config.TEST.POST_PROCESS``
    is truthy, each peak is shifted by a quarter pixel along x and y toward
    the higher of its two neighbours. There is no shift along an axis whose
    neighbours are equal. The heatmap-space coordinates are then mapped back
    with ``transform_preds`` using each sample's ``center`` and ``scale``.
    That is presumably the original-image frame; confirm against
    utils.transforms.

    Args:
        config: experiment config; only ``config.TEST.POST_PROCESS`` is read.
        batch_heatmaps: numpy.ndarray [batch_size, num_joints, height, width].
        center: per-sample box centres, indexed by batch position. numpy
            arrays and lists/tuples are used as-is. Anything else is assumed
            to be tensor-like and converted with ``.cpu().numpy()``.
        scale: per-sample box scales, handled the same way as ``center``.

    Returns:
        (preds, maxvals):
            preds   -- [batch_size, num_joints, 2] back-transformed (x, y).
            maxvals -- [batch_size, num_joints, 1] peak heatmap values.
    """
    coords, maxvals = get_max_preds(batch_heatmaps)

    heatmap_height = batch_heatmaps.shape[2]
    heatmap_width = batch_heatmaps.shape[3]

    # post-processing: quarter-pixel shift toward the higher neighbour
    if config.TEST.POST_PROCESS:
        for n in range(coords.shape[0]):
            for p in range(coords.shape[1]):
                hm = batch_heatmaps[n][p]
                px = int(math.floor(coords[n][p][0] + 0.5))
                py = int(math.floor(coords[n][p][1] + 0.5))
                # Both neighbours must lie inside the map, i.e.
                # 1 <= px <= width - 2 and 1 <= py <= height - 2.
                if 0 < px < heatmap_width - 1 and 0 < py < heatmap_height - 1:
                    diff = np.array([hm[py][px + 1] - hm[py][px - 1],
                                     hm[py + 1][px] - hm[py - 1][px]])
                    coords[n][p] += np.sign(diff) * .25

    preds = coords.copy()

    # Convert center and scale independently, so a tensor scale is handled
    # even when center is already numpy, and vice versa.
    if not isinstance(center, (list, tuple, np.ndarray)):
        center = center.cpu().numpy()
    if not isinstance(scale, (list, tuple, np.ndarray)):
        scale = scale.cpu().numpy()

    # Transform back
    for i in range(coords.shape[0]):
        preds[i] = transform_preds(coords[i], center[i], scale[i],
                                   [heatmap_width, heatmap_height])

    return preds, maxvals


def gumbel_softmax(x, axis=None, t=1.0):
    """
    Temperature-scaled softmax over ``axis`` and all trailing dimensions.

    Note: despite the name, no Gumbel noise is sampled. This is a plain,
    numerically stable softmax of ``x / t``.

    Every dimension from ``axis`` to the end is flattened and normalised
    jointly, so in the usual case ``axis`` should be the last dim.

    :param x: array-like of logits.
    :param axis: first dimension to normalise over. Defaults to the last
        dimension; negative values count from the end.
    :param t: temperature (> 0). Smaller values give a sharper distribution.
    :return: array with the same shape as ``x``, non-negative and summing to
        1 over the normalised dimensions. A slice whose logits are all -inf
        yields zeros.
    """
    x = np.asarray(x)
    origin_shape = x.shape
    if axis is None:
        axis = len(origin_shape) - 1
    if axis < 0:
        # Normalise so the reshape below and the reduction agree on the axis.
        axis += len(origin_shape)
    flat = x.reshape(list(origin_shape[:axis]) + [-1]) / t

    # Subtract the per-slice max before exponentiating so large logits cannot
    # overflow to inf (which would give inf/inf = nan). A non-finite max
    # (e.g. an all -inf slice, or an empty one) is replaced by 0.
    shift = np.max(flat, axis=axis, keepdims=True, initial=-np.inf)
    shift = np.where(np.isfinite(shift), shift, 0.0)
    v_exp = np.exp(flat - shift)
    axis_sum = np.sum(v_exp, axis=axis, keepdims=True)

    # Where the sum is 0 (all -inf logits), return 0 rather than 0/0.
    g_softmaxed = np.divide(v_exp, axis_sum, out=np.zeros_like(v_exp),
                            where=axis_sum > 0)
    return g_softmaxed.reshape(origin_shape)



def get_final_preds_by_softmaxed_aggregation(config, batch_heatmaps, center, scale, temperature=1.0):
    """Soft-argmax decoding: expected (x, y) under a softmax of each heatmap.

    Each joint's heatmap is turned into a probability map with
    ``gumbel_softmax`` over all of its pixels. The coordinate is the
    probability-weighted mean pixel position, which is then mapped back with
    ``transform_preds``.

    Args:
        config: not used; kept so the signature parallels ``get_final_preds``.
        batch_heatmaps: numpy.ndarray [batch_size, num_joints, height, width].
        center: per-sample box centres, indexed by batch position. numpy
            arrays and lists/tuples are used as-is. Anything else is assumed
            to be tensor-like and converted with ``.cpu().numpy()``.
        scale: per-sample box scales, handled the same way as ``center``.
        temperature: softmax temperature passed to ``gumbel_softmax``.

    Returns:
        (preds, maxvals):
            preds   -- [batch_size, num_joints, 2] back-transformed (x, y).
            maxvals -- [batch_size, num_joints, 1] raw heatmap maxima (not
                       softmax probabilities).
    """
    batch_size, num_joints, height, width = batch_heatmaps.shape
    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
    maxvals = np.amax(heatmaps_reshaped, 2).reshape((batch_size, num_joints, 1))

    # Probability over all height*width pixels of each joint's map.
    sm_vals = gumbel_softmax(heatmaps_reshaped, axis=2, t=temperature)
    sm_vals = sm_vals.reshape((batch_size, num_joints, height, width))

    # Expected coordinate = sum over pixels of prob(pixel) * pixel index.
    ys = np.arange(height, dtype=np.float64).reshape((height, 1))
    xs = np.arange(width, dtype=np.float64).reshape((1, width))
    y = (sm_vals * ys).sum(axis=(2, 3))
    x = (sm_vals * xs).sum(axis=(2, 3))
    coords = np.stack([x, y], axis=2)

    # Same input normalisation as get_final_preds.
    if not isinstance(center, (list, tuple, np.ndarray)):
        center = center.cpu().numpy()
    if not isinstance(scale, (list, tuple, np.ndarray)):
        scale = scale.cpu().numpy()

    preds = np.zeros_like(coords)
    for i in range(coords.shape[0]):
        preds[i] = transform_preds(coords[i], center[i], scale[i],
                                   [width, height])

    return preds, maxvals