# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, List
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import CfgNode
from detectron2.structures import Instances

from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix

from .embed_utils import PackedCseAnnotations
from .mask import extract_data_for_mask_loss_from_matches


def _create_pixel_dist_matrix(grid_size: int) -> torch.Tensor:
    """
    Build the matrix of pairwise squared Euclidean distances between the
    cells of a square `grid_size` x `grid_size` pixel grid.

    Grid cells are enumerated in row-major order, i.e. the flattened index `i`
    corresponds to the cell with
        row = i // grid_size
        col = i % grid_size

    Args:
        grid_size (int): side length of the square pixel grid
    Return:
        The result of `squared_euclidean_distance_matrix` applied to the
        float tensor of cell coordinates of shape
        [grid_size * grid_size, 2] (paired with itself); entry [i, j] is
        expected to be the squared distance between cells `i` and `j`.
    """
    rows = torch.arange(grid_size)
    cols = torch.arange(grid_size)
    # `cartesian_prod` lists all [row, col] pairs in row-major order, which is
    # exactly what stacking an "ij"-indexed meshgrid and flattening produces,
    # without calling `torch.meshgrid` with its deprecated implicit indexing
    pix_coords = torch.cartesian_prod(rows, cols).float()
    return squared_euclidean_distance_matrix(pix_coords, pix_coords)


def _sample_fg_pixels_randperm(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
    """
    Sample foreground pixel indices uniformly without replacement by means of
    a random permutation.

    Args:
        fg_mask (tensor): foreground mask of arbitrary shape; every non-zero
            entry is treated as a foreground pixel
        sample_size (int): maximum number of pixels to sample; if it is
            non-positive or not smaller than the number of foreground pixels,
            all foreground pixels are returned
    Return:
        1D int64 tensor with indices into the flattened mask; it contains
        either all foreground indices (in ascending order) or `sample_size`
        distinct foreground indices in random order
    """
    fg_pixel_indices = fg_mask.reshape((-1,)).nonzero(as_tuple=True)[0]
    # count the indices themselves rather than summing the mask: the two
    # differ whenever the mask holds values other than 0 / 1, and the
    # permutation below must stay within the bounds of `fg_pixel_indices`
    num_pixels = fg_pixel_indices.numel()
    if (sample_size <= 0) or (num_pixels <= sample_size):
        return fg_pixel_indices
    sample_indices = torch.randperm(num_pixels, device=fg_mask.device)[:sample_size]
    return fg_pixel_indices[sample_indices]


def _sample_fg_pixels_multinomial(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
    """
    Sample foreground pixel indices uniformly without replacement by means of
    multinomial sampling.

    Args:
        fg_mask (tensor): foreground mask of arbitrary shape; every non-zero
            entry is treated as a foreground pixel
        sample_size (int): maximum number of pixels to sample; if it is
            non-positive or not smaller than the number of foreground pixels,
            all foreground pixels are returned
    Return:
        1D int64 tensor with indices into the flattened mask; it contains
        either all foreground indices (in ascending order) or `sample_size`
        distinct foreground indices
    """
    # binarize first, so that both the foreground count and the sampling
    # weights are independent of the actual non-zero values stored in the mask
    is_fg_flattened = fg_mask.reshape((-1,)) != 0
    num_pixels = int(is_fg_flattened.sum().item())
    if (sample_size <= 0) or (num_pixels <= sample_size):
        return is_fg_flattened.nonzero(as_tuple=True)[0]
    # equal weights for all foreground pixels => uniform sampling
    return is_fg_flattened.float().multinomial(sample_size, replacement=False)  # pyre-ignore[16]


class PixToShapeCycleLoss(nn.Module):
    """
    Cycle loss for pixel-vertex correspondence

    For every instance and every considered mesh, a subset of foreground
    pixels is sampled and soft pixel-to-vertex and vertex-to-pixel assignments
    are computed from pixel / vertex embedding similarities. Their product is
    a soft pixel-to-pixel "cycle" assignment, which is weighted by the squared
    pixel distances, so that cycles ending far away from the pixel they
    started at are penalized.
    """

    def __init__(self, cfg: CfgNode):
        super().__init__()
        cse_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE
        cycle_cfg = cse_cfg.PIX_TO_SHAPE_CYCLE_LOSS
        # names of all meshes the embedders are configured for
        self.shape_names = list(cse_cfg.EMBEDDERS.keys())
        # embedding space dimensionality D
        self.embed_size = cse_cfg.EMBED_SIZE
        # order `p` of the norm applied to the distance-weighted cycle matrix
        self.norm_p = cycle_cfg.NORM_P
        # if set, the loss is computed for all configured meshes, otherwise
        # only for meshes that occur in the GT annotations
        self.use_all_meshes_not_gt_only = cycle_cfg.USE_ALL_MESHES_NOT_GT_ONLY
        # max number of foreground pixels sampled per (instance, mesh) pair
        self.num_pixels_to_sample = cycle_cfg.NUM_PIXELS_TO_SAMPLE
        # read from the config; not used by the computations in this class
        self.pix_sigma = cycle_cfg.PIXEL_SIGMA
        # softmax temperatures for the two directions of the cycle
        self.temperature_pix_to_vertex = cycle_cfg.TEMPERATURE_PIXEL_TO_VERTEX
        self.temperature_vertex_to_pix = cycle_cfg.TEMPERATURE_VERTEX_TO_PIXEL
        # pairwise squared distances between the pixels of the S x S output
        # grid; kept as a plain attribute and moved to the right device
        # lazily in `forward`
        self.pixel_dists = _create_pixel_dist_matrix(cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE)

    def forward(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        embedder: nn.Module,
    ):
        """
        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
                * coarse_segm - passed on to the extraction of mask loss data
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Scalar tensor: the mean of the cycle losses over all
            (instance, mesh) pairs, or a zero that is still connected to the
            pixel embeddings if there are no such pairs
        """
        pix_embeds = densepose_predictor_outputs.embedding
        if self.pixel_dists.device != pix_embeds.device:
            # should normally be done only once
            self.pixel_dists = self.pixel_dists.to(device=pix_embeds.device)
        with torch.no_grad():
            mask_loss_data = extract_data_for_mask_loss_from_matches(
                proposals_with_gt, densepose_predictor_outputs.coarse_segm
            )
        # GT masks - tensor of shape [N, S, S] of int64
        masks_gt = mask_loss_data.masks_gt.long()  # pyre-ignore[16]
        assert len(pix_embeds) == len(masks_gt), (
            f"Number of instances with embeddings {len(pix_embeds)} != "
            f"number of instances with GT masks {len(masks_gt)}"
        )
        mesh_names = (
            self.shape_names
            if self.use_all_meshes_not_gt_only
            else [
                MeshCatalog.get_mesh_name(mesh_id.item())
                for mesh_id in packed_annotations.vertex_mesh_ids_gt.unique()  # pyre-ignore[16]
            ]
        )
        # Vertex embeddings depend on the mesh only, so they are computed on
        # first use and shared between instances instead of being recomputed
        # for every instance.
        # NOTE(review): assumes `embedder` returns the same values when called
        # repeatedly with the same mesh name within one forward pass - confirm
        vertex_embeddings_by_mesh = {}
        losses = []
        for pixel_embeddings, mask_gt in zip(pix_embeds, masks_gt):
            # pixel_embeddings [D, S, S]
            # mask_gt [S, S]
            # flattened pixel embeddings [D, S * S]
            pixel_embeddings_flattened = pixel_embeddings.reshape((self.embed_size, -1))
            for mesh_name in mesh_names:
                if mesh_name not in vertex_embeddings_by_mesh:
                    vertex_embeddings_by_mesh[mesh_name] = embedder(mesh_name)
                mesh_vertex_embeddings = vertex_embeddings_by_mesh[mesh_name]
                # pixel indices [M]; a fresh sample is drawn for every
                # (instance, mesh) pair
                pixel_indices_flattened = _sample_fg_pixels_randperm(
                    mask_gt, self.num_pixels_to_sample
                )
                # pixel distances [M, M]: entry [i, j] is the distance between
                # the i-th and the j-th sampled pixels; broadcasting a column
                # and a row of indices selects the same entries as indexing
                # with a meshgrid of the indices; `self.pixel_dists` already
                # resides on the device of the embeddings (see above)
                pixel_dists = self.pixel_dists[
                    pixel_indices_flattened[:, None], pixel_indices_flattened[None, :]
                ]
                # pixel embeddings [M, D]
                pixel_embeddings_sampled = normalize_embeddings(
                    pixel_embeddings_flattened[:, pixel_indices_flattened].T
                )
                # pixel-vertex similarity [M, K]
                sim_matrix = pixel_embeddings_sampled.mm(  # pyre-ignore[16]
                    mesh_vertex_embeddings.T
                )
                # soft assignment pixel -> vertex [M, K], rows sum up to 1
                c_pix_vertex = F.softmax(sim_matrix / self.temperature_pix_to_vertex, dim=1)
                # soft assignment vertex -> pixel [K, M], rows sum up to 1
                c_vertex_pix = F.softmax(sim_matrix.T / self.temperature_vertex_to_pix, dim=1)
                # soft assignment pixel -> vertex -> pixel [M, M]
                c_cycle = c_pix_vertex.mm(c_vertex_pix)
                loss_cycle = torch.norm(pixel_dists * c_cycle, p=self.norm_p)
                losses.append(loss_cycle)

        if not losses:
            # no instances or no meshes: zero loss that keeps the pixel
            # embeddings in the computation graph
            return pix_embeds.sum() * 0
        return torch.stack(losses, dim=0).mean()

    def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module):
        """
        Produce a zero-valued loss that is nevertheless computed from the
        vertex embeddings of all meshes of the embedder and from the predicted
        pixel embeddings, so that all of them are part of the computation
        graph.

        Args:
            densepose_predictor_outputs: predictor outputs, `embedding`
                attribute is used
            embedder (nn.Module): module that computes vertex embeddings for
                different meshes, `mesh_names` attribute is used
        Return:
            Scalar tensor: the mean of the zeroed-out sums
        """
        losses = [
            embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names  # pyre-ignore[29]
        ]
        losses.append(densepose_predictor_outputs.embedding.sum() * 0)
        return torch.mean(torch.stack(losses))
