import math

import numpy as np
import torch

from unrealpose.utils.transforms import transform_preds


class AverageMeter:
    """Computes and stores the average and current value.

    Attributes (all reset to 0 by ``reset``):
        val: the value passed to the most recent ``update`` call.
        sum: running sum of ``val * n`` over all updates.
        count: running total of ``n`` over all updates.
        avg: ``sum / count``; keeps its previous value while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the value of ``n`` samples and refresh the average.

        Args:
            val: the (mean) value for this update.
            n: number of samples ``val`` stands for (weight in the average).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard: an update with n=0 before any samples would otherwise raise
        # ZeroDivisionError; leave avg unchanged until there is a sample.
        if self.count:
            self.avg = self.sum / self.count


def my_transform(imgs, mean, std):
    """
    Batched equivalent of ToTensor + Normalize.

    imgs: numpy array [N, H, W, C]; cast to float32 and divided by 255.
    mean, std: scalars or per-channel values; 1-D values are broadcast over
        the channel axis of the [N, C, H, W] output.
    Return:
        imgs: torch.tensor [N, C, H, W], computed as (imgs / 255 - mean) / std
    Raises:
        ValueError: if any std entry is zero after conversion to the tensor dtype.
    """
    nchw = imgs.transpose((0, 3, 1, 2))
    batch = torch.from_numpy(nchw).float().div(255)

    def _stat_tensor(stat):
        # Match the image tensor's dtype/device so the in-place ops below are valid.
        return torch.as_tensor(stat, dtype=batch.dtype, device=batch.device)

    def _channel_broadcast(stat):
        # A per-channel vector [C] becomes [1, C, 1, 1] to broadcast over N, H, W.
        if stat.ndim != 1:
            return stat
        return stat.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

    mean_t = _stat_tensor(mean)
    std_t = _stat_tensor(std)
    if torch.any(std_t == 0):
        raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(batch.dtype))

    return batch.sub_(_channel_broadcast(mean_t)).div_(_channel_broadcast(std_t))


# def my_transform_numpy(imgs, mean, std):
#     """
#     Rewrite the pytorch transforms to include batch processing
#     imgs: numpy array [N, H, W, C]
#     mean: a list of len 3
#     std: a list of len 3
#     Return:
#         imgs: numpy [N, C, H, W]
#     """

#     imgs = imgs.transpose((0, 3, 1, 2))
#     imgs = imgs.astype(np.float32) / 255

#     mean = np.array(mean, dtype=np.float32)
#     std = np.array(std, dtype=np.float32)

#     mean = mean[None, :, None, None]
#     std = std[None, :, None, None]
#     imgs = (imgs - mean) / std
#     return imgs


def my_transform_np_outer(mean, std):
    """
    Build a batched numpy normalizer with mean/std precomputed once.

    mean: a list of len 3 (per-channel mean, applied after dividing by 255)
    std: a list of len 3 (per-channel std)
    Return:
        a callable mapping imgs [N, H, W, C] -> float32 numpy [N, C, H, W]
    """
    # Copy into float32 and lay out as [1, C, 1, 1] for NCHW broadcasting.
    mean_nchw = np.array(mean, dtype=np.float32)[None, :, None, None]
    std_nchw = np.array(std, dtype=np.float32)[None, :, None, None]

    def my_transform_np(imgs):
        """
        Batched ToTensor + Normalize in numpy.
        imgs: numpy array [N, H, W, C]
        Return:
            imgs: numpy [N, C, H, W], (imgs / 255 - mean) / std
        """
        scaled = imgs.transpose((0, 3, 1, 2)).astype(np.float32) / 255
        return (scaled - mean_nchw) / std_nchw

    return my_transform_np

def get_max_preds(batch_heatmaps):
    '''
    Locate the peak of every heatmap.

    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    Return:
        preds: float32 [batch_size, num_joints, 2] holding (x, y) of each peak;
            rows whose peak value is not > 0 are zeroed.
        maxvals: [batch_size, num_joints, 1] peak value of each heatmap.
    '''
    n_batch, n_joints = batch_heatmaps.shape[0], batch_heatmaps.shape[1]
    hm_width = batch_heatmaps.shape[3]

    # Flatten each heatmap so one argmax/amax finds its peak.
    flat = batch_heatmaps.reshape((n_batch, n_joints, -1))
    flat_idx = np.argmax(flat, 2)
    maxvals = np.amax(flat, 2).reshape((n_batch, n_joints, 1))

    # Convert the flat index to (column, row) in float32.
    flat_idx = flat_idx.astype(np.float32)
    xs = flat_idx % hm_width
    ys = np.floor(flat_idx / hm_width)
    preds = np.stack([xs, ys], axis=-1)

    # Zero the coordinates of joints whose peak is not strictly positive.
    preds *= maxvals > 0.0
    return preds, maxvals

def get_final_preds(batch_heatmaps, center, scale, rot):
    """Decode heatmap peaks and map them back through ``transform_preds``.

    Args:
        batch_heatmaps: heatmaps [batch_size, num_joints, height, width], as
            consumed by ``get_max_preds``.
        center, scale, rot: per-sample parameters indexed by batch position
            and forwarded unchanged to ``transform_preds``.

    Returns:
        (preds, maxvals): ``preds`` has the shape/dtype of the heatmap-space
        coordinates [batch_size, num_joints, 2]; ``maxvals`` comes straight
        from ``get_max_preds``.
    """
    heatmap_coords, maxvals = get_max_preds(batch_heatmaps)
    hm_h = batch_heatmaps.shape[2]
    hm_w = batch_heatmaps.shape[3]

    preds = np.empty_like(heatmap_coords)
    for i, sample_coords in enumerate(heatmap_coords):
        # The heatmap size is passed to transform_preds as [width, height].
        preds[i] = transform_preds(sample_coords, center[i], scale[i], rot[i],
                                   [hm_w, hm_h])
    return preds, maxvals


# def get_bbox_from_kp2d_batch(kp2d, margin=0):
#     """
#     kp2d: [J, 2] or [C, J, 2]
#     Return:
#         a list of [[x, y, w, h], [x, y, w, h] ...]
#     """
#     res = []

#     if kp2d.ndim == 2:
#         kp2d = kp2d[None, ...]

#     if kp2d.ndim == 3:
#         for kp2d_arr in kp2d:
#             xmin, xmax = min(kp2d_arr[:, 0]), max(kp2d_arr[:, 0])
#             ymin, ymax = min(kp2d_arr[:, 1]), max(kp2d_arr[:, 1])
#             w = xmax - xmin
#             h = ymax - ymin
#             bb = [xmin - w * margin, ymin - h * margin, w * (1 + 2 * margin), h * (1 + 2 * margin)]
#             bb = [int(v) for v in bb]
#             res.append(bb)
#     return res


def get_bbox_from_kp2d_batch(kp2d, margin=0):
    """
    Compute axis-aligned bounding boxes around sets of 2D keypoints.

    kp2d: array-like [J, 2] or [N, J, 2] of (x, y) keypoints; lists/tuples
        are accepted and converted with ``np.asarray``.
    margin: fraction of the box width/height added on each side.
    Return:
        bb_array: np.int32 [N, 4], [x1, y1, w, h]; N is 1 for a [J, 2] input.
            The int cast truncates toward zero.
    Raises:
        ValueError: if kp2d is not 2-D or 3-D.
    """
    kp2d = np.asarray(kp2d)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if kp2d.ndim not in (2, 3):
        raise ValueError(
            'kp2d must have shape [J, 2] or [N, J, 2], got ndim={}'.format(kp2d.ndim))

    if kp2d.ndim == 2:
        kp2d = kp2d[None, ...]

    x1y1 = np.amin(kp2d, axis=1)  # [N, 2]
    x2y2 = np.amax(kp2d, axis=1)  # [N, 2]
    wh = x2y2 - x1y1  # [N, 2]

    # Grow the box by `margin` * size on every side.
    x1y1 = x1y1 - wh * margin
    wh = wh * (1 + 2 * margin)
    bb_array = np.concatenate([x1y1, wh], axis=1).astype(np.int32)  # [N, 4]
    return bb_array
