import numbers

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


def psnr(gt_images, pred_images):
    """Return the batch-averaged peak signal-to-noise ratio (in dB).

    For every sample, the peak value is the maximum of that sample's
    ground-truth image, and the PSNR is
    ``20 * log10(peak) - 10 * log10(mse)``. The per-sample values are then
    averaged over the batch. A sample whose prediction matches exactly has
    ``mse == 0`` and therefore contributes ``inf``.

    Arguments:
    gt_images: torch.Tensor BxHxWxC
        Ground truth images
    pred_images: torch.Tensor BxHxWxC
        Predicted images

    Returns:
    torch.Tensor
        0-d tensor holding the mean PSNR over the batch.
    """
    # Per-sample peak: largest value in each ground-truth image.
    peak = gt_images.flatten(1).max(dim=1).values

    squared_error = F.mse_loss(gt_images, pred_images, reduction="none")
    per_sample_mse = squared_error.flatten(1).mean(dim=1)

    per_sample_psnr = 20.0 * torch.log10(peak) - 10.0 * torch.log10(per_sample_mse)
    return per_sample_psnr.mean()


def iou(
    occ1: torch.Tensor,
    occ2: torch.Tensor,
    *,
    threshold: float = 0.0,
    empty_value: float = 1.0,
) -> torch.Tensor:
    """Computes the mean Intersection over Union (IoU) of two sets of occupancy values.

    Values are binarised as ``value >= threshold``. With the default threshold
    of 0 the inputs are treated as signed scores in ``(-inf, inf)``, where
    non-negative means "occupied".

    Inputs with two or more dimensions are treated as a batch. The first
    dimension is the sample index, and all remaining dimensions are flattened,
    so one IoU is computed per sample and the results are averaged. 1-D inputs
    are treated as a single sample.

    The formula used is the following:

    .. math::
        \\text{IoU} = \\frac{|A \\cap B|}{|A \\cup B|}

    If both sets are empty for a sample (``|A \\cup B| = 0``), the IoU is
    undefined (0/0). Such a sample is assigned ``empty_value`` rather than NaN,
    because a single NaN would otherwise turn the whole batch mean into NaN.

    :param occ1: first set of occupancy values
    :type occ1: torch.Tensor
    :param occ2: second set of occupancy values
    :type occ2: torch.Tensor
    :param threshold: values ``>= threshold`` count as occupied
    :type threshold: float
    :param empty_value: IoU assigned to samples where both sets are empty
    :type empty_value: float

    :return: IoU averaged over the batch, as a 0-d tensor
    :rtype: torch.Tensor
    """
    # Put all data in second dimension
    # Also works for 1-dimensional data
    if occ1.ndim >= 2:
        occ1 = occ1.flatten(1)
    if occ2.ndim >= 2:
        occ2 = occ2.flatten(1)

    # Convert to boolean occupancy
    occupied1 = occ1 >= threshold
    occupied2 = occ2 >= threshold

    area_union = (occupied1 | occupied2).float().sum(dim=-1)
    area_intersect = (occupied1 & occupied2).float().sum(dim=-1)

    # clamp_min avoids evaluating 0/0; those entries are replaced by empty_value.
    per_sample_iou = torch.where(
        area_union > 0,
        area_intersect / area_union.clamp_min(1.0),
        torch.full_like(area_union, empty_value),
    )

    return torch.mean(per_sample_iou)


def make_grid(
    grid_dims: tuple[int, ...] | list[int],
    batch_size: int,
    coord_range: tuple[float, float] | list[float] | list[list[float]] = (-1, 1),
    flatten: bool = True,
) -> torch.Tensor:
    """Create a batch of regularly spaced coordinate grids.

    Dimension ``i`` is sampled with
    ``torch.linspace(low_i, high_i, grid_dims[i])``, so both endpoints are
    included. The per-dimension samples are combined with
    ``torch.meshgrid(..., indexing="ij")``.

    :param grid_dims: number of samples along each dimension.
    :param batch_size: number of times the grid is repeated along a new
        leading axis.
    :param coord_range: either a single ``(low, high)`` pair shared by every
        dimension, or one ``(low, high)`` pair per dimension. Ints and floats
        are both accepted.
    :param flatten: if True, collapse all grid axes into a single point axis.
    :return: tensor of shape ``(batch_size, prod(grid_dims), len(grid_dims))``
        when ``flatten`` is True, otherwise
        ``(batch_size, *grid_dims, len(grid_dims))``. The last axis holds the
        coordinates in the same order as ``grid_dims``.
    """
    # A single scalar pair applies to every dimension. The check uses
    # numbers.Real rather than int so that float ranges such as (-0.5, 0.5)
    # are recognised as one pair. A plain int check would misroute them to
    # the per-dimension branch, where indexing a float raises TypeError.
    if isinstance(coord_range[0], numbers.Real):
        coord_range = [coord_range] * len(grid_dims)
    linspaces = [
        torch.linspace(coord_range[i][0], coord_range[i][1], dim)
        for i, dim in enumerate(grid_dims)
    ]
    grid = torch.stack(torch.meshgrid(*linspaces, indexing="ij"), dim=-1)
    if flatten:
        grid = rearrange(grid, "... c -> (...) c")
    batch_grid = repeat(grid, "... -> b ...", b=batch_size)
    return batch_grid


def make_image_grid(
    shape: tuple[int] | list[int],
    batch_size: int,
    coord_range: tuple[int] | list[int] = (-1, 1),
    flatten: bool = True,
) -> torch.Tensor:
    """Build a batch of 2-D pixel-coordinate grids laid out like a row-wise
    flattened image, starting from the top-left corner.

    Pairing the flattened grid with the image's pixel values lets you draw a
    scatter plot of the image.

    Axis directions:
    - ``x`` runs from ``coord_range[0]`` to ``coord_range[1]`` over
      ``shape[0]`` samples, left to right.
    - ``y`` runs from ``coord_range[1]`` down to ``coord_range[0]`` over
      ``shape[1]`` samples, top to bottom.

    Layout:
    - Because the mesh uses ``indexing="xy"``, the unflattened per-sample grid
      has shape ``(shape[1], shape[0], 2)``.
    - Entry ``[row, col]`` is ``(x[col], y[row])``.
    - With ``flatten=True`` the per-sample grid becomes
      ``(shape[1] * shape[0], 2)``.
    - The result is repeated ``batch_size`` times along a new leading axis.

    NOTE(review): ``shape[0]`` is used as the number of columns (the width).
    If callers pass an ``(H, W)`` tensor shape, non-square images come out
    transposed. Confirm the intended convention.
    """
    start, stop = coord_range[0], coord_range[1]
    x_coords = torch.linspace(start, stop, shape[0])
    # Reversed so the first row is the top of the image (largest y).
    y_coords = torch.linspace(stop, start, shape[1])
    mesh_x, mesh_y = torch.meshgrid(x_coords, y_coords, indexing="xy")
    grid = torch.stack((mesh_x, mesh_y), dim=-1)
    if flatten:
        grid = rearrange(grid, "... c -> (...) c")
    return repeat(grid, "... -> b ...", b=batch_size)
