import torch
import torch.nn as nn
from torch.nn import functional as F


def get_coord(x, other_axis, axis_size):
    """Soft-argmax coordinate of each map along one spatial axis.

    ``x`` is averaged over ``other_axis`` to get a 1-D profile per map. The
    profile is softmax-normalised into a distribution over ``axis_size``
    positions, and the expectation of an evenly spaced [-1, 1] grid under
    that distribution is returned.

    Args:
        x: 4-D heatmap tensor (B, NMAP, H, W); dim 2 after the mean must be
            the remaining spatial axis.
        other_axis: spatial axis to average out (3 -> y coordinate,
            2 -> x coordinate).
        axis_size: length of the remaining spatial axis.

    Returns:
        (coord, prob): coord is (B, NMAP) in [-1, 1]; prob is the
        (B, NMAP, axis_size) softmax distribution.
    """
    profile = x.mean(dim=other_axis)                      # B,NMAP,axis_size
    prob = F.softmax(profile, dim=2)                      # B,NMAP,axis_size
    grid = torch.linspace(-1.0, 1.0, axis_size).to(x.device)
    grid = grid.view(1, 1, axis_size)                     # broadcast over B,NMAP
    coord = (prob * grid).sum(dim=2)                      # B,NMAP
    return coord, prob


def get_gaussian_maps(mu, shape_hw, inv_std, mode='rot'):
    """Render one 2-D Gaussian-like blob per map, centred at ``mu``.

    Args:
        mu: [B, NMAPS, 2] tensor of centres ordered (y, x), in the same
            normalised [-1, 1] coordinates as the sampling grid.
        shape_hw: (H, W) size of the output maps.
        inv_std: inverse blob width (i.e. 1 / std). Offsets from the centre
            are scaled by it, so larger values give sharper blobs.
        mode: blob profile, with dy = y - mu_y and dx = x - mu_x:
            'rot'    -> exp(-(dy^2 + dx^2) * inv_std^2)
            'flat'   -> exp(-((dy^2 + dx^2) * inv_std^2 + 1e-5) ** 0.25)
            'ankush' -> exp(-sqrt(1e-4 + |dy * inv_std|))
                        * exp(-sqrt(1e-4 + |dx * inv_std|))
                        (separable; built as an outer product of 1-D profiles)

    Returns:
        [B, NMAPS, H, W] tensor.

    Raises:
        ValueError: if ``mode`` is not one of 'rot', 'flat', 'ankush'.
    """
    # Fail fast before allocating any coordinate grids.
    if mode not in ('rot', 'flat', 'ankush'):
        raise ValueError('Unknown mode: ' + str(mode))

    mu_y, mu_x = mu[:, :, 0:1], mu[:, :, 1:2]  # each [B, NMAPS, 1]

    y = torch.linspace(-1.0, 1.0, shape_hw[0]).to(mu.device)
    x = torch.linspace(-1.0, 1.0, shape_hw[1]).to(mu.device)

    if mode in ('rot', 'flat'):
        # [B, NMAPS, 1, 1] centres against [1, 1, H, 1] / [1, 1, 1, W] grids.
        mu_y, mu_x = torch.unsqueeze(mu_y, dim=-1), torch.unsqueeze(mu_x, dim=-1)

        y = y.view(1, 1, shape_hw[0], 1)
        x = x.view(1, 1, 1, shape_hw[1])

        g_y = (y - mu_y) ** 2
        g_x = (x - mu_x) ** 2
        dist = (g_y + g_x) * inv_std ** 2

        if mode == 'rot':
            g_yx = torch.exp(-dist)
        else:
            # The 1e-5 offset keeps the fractional power differentiable at 0.
            g_yx = torch.exp(-torch.pow(dist + 1e-5, 0.25))
    else:  # 'ankush'
        y = y.view(1, 1, shape_hw[0])
        x = x.view(1, 1, shape_hw[1])

        # 1-D profiles: [B, NMAPS, H] and [B, NMAPS, W].
        g_y = torch.exp(-torch.sqrt(1e-4 + torch.abs((mu_y - y) * inv_std)))
        g_x = torch.exp(-torch.sqrt(1e-4 + torch.abs((mu_x - x) * inv_std)))

        # Outer product [.., H, 1] @ [.., 1, W] -> [B, NMAPS, H, W].
        g_y = torch.unsqueeze(g_y, dim=3)
        g_x = torch.unsqueeze(g_x, dim=2)
        g_yx = torch.matmul(g_y, g_x)

    return g_yx


class KeypointSimilarityLoss(nn.Module):
    """Metric-learning loss over keypoint-pooled image features.

    For every keypoint heatmap channel, ``image_feat`` is weighted by that
    heatmap and globally max-pooled, giving one descriptor per
    (keypoint, sample). Descriptors are labelled with their keypoint index
    and scored by ``TripletLoss``.
    """

    def __init__(self, num_kp, tune_hm=True, gauss_std=0.1, gauss_mode='ankush',
                 use_gpu=True, hm_size=(64, 64)):
        """
        Args:
            num_kp: number of heatmap channels expected in ``forward``.
            tune_hm: if True, heatmaps are re-rendered as Gaussians centred on
                their soft-argmax before pooling.
            gauss_std: blob width for the re-rendered heatmaps (passed to
                ``get_gaussian_maps`` as ``1 / gauss_std``).
            gauss_mode: blob profile, see ``get_gaussian_maps``.
            use_gpu: retained for backward compatibility; labels are always
                created on the same device as the pooled features.
            hm_size: (H, W) of the re-rendered heatmaps. This must broadcast
                against the spatial size of ``image_feat``.
        """
        super(KeypointSimilarityLoss, self).__init__()
        self.gauss_std = gauss_std
        self.gauss_mode = gauss_mode
        self.tune_hm = tune_hm
        self.num_kp = num_kp
        self.use_gpu = use_gpu
        self.hm_size = tuple(hm_size)
        self.similarity_criterion = TripletLoss()

    def hm_tuner(self, x):
        """Replace each heatmap with a Gaussian centred on its soft-argmax.

        Args:
            x: (B, NMAP, H, W) heatmaps.
        Returns:
            (B, NMAP, hm_size[0], hm_size[1]) Gaussian maps.
        """
        gauss_y, _ = get_coord(x, 3, x.shape[2])  # B,NMAP
        gauss_x, _ = get_coord(x, 2, x.shape[3])  # B,NMAP
        gauss_mu = torch.stack([gauss_y, gauss_x], dim=2)  # B,NMAP,2 as (y, x)

        return get_gaussian_maps(gauss_mu, list(self.hm_size), 1.0 / self.gauss_std,
                                 mode=self.gauss_mode)

    def forward(self, image_feat, hm):
        """Compute the keypoint-similarity loss.

        Args:
            image_feat: (B, C, H, W) feature map.
            hm: (B, num_kp, h, w) keypoint heatmaps. After optional tuning,
                their spatial size must broadcast against ``image_feat``.
        Returns:
            The first element of the ``TripletLoss`` output.
        Raises:
            ValueError: if ``hm`` does not have ``num_kp`` channels.
        """
        bs = image_feat.size(0)
        if hm.size(1) != self.num_kp:
            raise ValueError('Expected %d heatmap channels, got %d'
                             % (self.num_kp, hm.size(1)))
        if self.tune_hm:
            hm = self.hm_tuner(hm)

        # One heatmap-weighted copy of the features per keypoint, stacked
        # keypoint-major along dim 0: rows [k*bs, (k+1)*bs) belong to keypoint k.
        feats = torch.cat([image_feat * hm_k for hm_k in torch.split(hm, 1, dim=1)], dim=0)
        # Global max pool over the full spatial extent -> one descriptor per row.
        feats = F.max_pool2d(feats, kernel_size=feats.size()[2:])
        feats = feats.view(feats.shape[0], -1)

        # Labels in the same keypoint-major order: [0]*bs + [1]*bs + ...
        gt_labels = torch.arange(self.num_kp, device=feats.device).repeat_interleave(bs)
        return self.similarity_criterion(feats, gt_labels)[0]



class JointsMSELoss(nn.Module):
    """0.5 * mean-squared error between predicted and target joint heatmaps.

    The loss is the per-joint MSE over (batch, pixels), averaged over joints.
    Every joint contributes the same number of elements, so this equals a
    single MSE over the whole (optionally weighted) tensor.
    """

    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight=None):
        """
        Args:
            output: (B, K, ...) predicted heatmaps; trailing dims are flattened.
            target: ground-truth heatmaps with the same number of elements.
            target_weight: per-joint weights, (B, K) or (B, K, 1), or anything
                that reshapes to (B, K, -1) and broadcasts over the flattened
                pixels. Required when ``use_target_weight`` is True.
        Returns:
            Scalar loss tensor.
        Raises:
            ValueError: if weighting is enabled but ``target_weight`` is None.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        # Keep the explicit (B, K, HW) layout. Squeezing per-joint slices would
        # also drop the batch axis when batch_size == 1.
        heatmaps_pred = output.reshape(batch_size, num_joints, -1)
        heatmaps_gt = target.reshape(batch_size, num_joints, -1)

        if self.use_target_weight:
            if target_weight is None:
                raise ValueError('target_weight is required when use_target_weight=True')
            weight = target_weight.reshape(batch_size, num_joints, -1)
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        return 0.5 * self.criterion(heatmaps_pred, heatmaps_gt)


class JointsOHKMMSELoss(nn.Module):
    """Heatmap MSE with online hard keypoint mining (OHKM).

    For each sample, the per-joint losses are computed and only the ``topk``
    largest are averaged. The result is then averaged over the batch.
    """

    def __init__(self, use_target_weight, topk=8):
        super(JointsOHKMMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk

    def ohkm(self, loss):
        """Average the hardest joints per sample, then average over the batch.

        Args:
            loss: (B, K) per-sample, per-joint losses.
        Returns:
            Scalar tensor. If ``topk`` exceeds K, all K joints are used.
        """
        k = min(self.topk, loss.size(1))
        topk_val, _ = torch.topk(loss, k=k, dim=1, sorted=False)
        return (topk_val.sum(dim=1) / k).mean()

    def forward(self, output, target, target_weight):
        """
        Args:
            output: (B, K, ...) predicted heatmaps; trailing dims are flattened.
            target: ground-truth heatmaps with the same number of elements.
            target_weight: per-joint weights reshapeable to (B, K, -1), e.g.
                (B, K) or (B, K, 1). Only used when ``use_target_weight``.
        Returns:
            Scalar OHKM loss tensor.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        # Keep the explicit (B, K, HW) layout. Squeezing per-joint slices would
        # also drop the batch axis when batch_size == 1.
        heatmaps_pred = output.reshape(batch_size, num_joints, -1)
        heatmaps_gt = target.reshape(batch_size, num_joints, -1)

        if self.use_target_weight:
            weight = target_weight.reshape(batch_size, num_joints, -1)
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        # Per-sample, per-joint MSE over pixels -> (B, K).
        loss = 0.5 * self.criterion(heatmaps_pred, heatmaps_gt).mean(dim=2)
        return self.ohkm(loss)


def normalize(x, axis=-1):
    """Scale ``x`` to unit L2 length along ``axis``.

    Args:
      x: input tensor.
      axis: dimension along which the L2 norm is taken (default: last).
    Returns:
      Tensor with the same shape as ``x``. The 1e-12 added to the norm keeps
      all-zero slices at zero instead of producing NaN.
    """
    length = torch.norm(x, 2, axis, keepdim=True)
    return x / (length + 1e-12)
