# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Helper functions for visualizing outputs """

from dataclasses import dataclass
from utils.typing import *

import matplotlib
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import rgb_to_grayscale
from jaxtyping import Bool, Float
from torch import Tensor

from utils import colors

# Closed set of colormap names accepted by the helpers in this module.
Colormaps = Literal[
    "default",
    "turbo",
    "viridis",
    "magma",
    "inferno",
    "cividis",
    "gray",
    "pca",
]


@dataclass(frozen=True)
class ColormapOptions:
    """Immutable bundle of settings controlling how a colormap is applied."""

    colormap: Colormaps = "default"
    """Name of the colormap to apply."""
    normalize: bool = False
    """If True, rescale the input tensor image before colouring it."""
    colormap_min: float = 0
    """Lower bound of the range the values are mapped onto."""
    colormap_max: float = 1
    """Upper bound of the range the values are mapped onto."""
    invert: bool = False
    """If True, flip the values (``1 - x``) before the colormap lookup."""


def apply_colormap(
    image: Float[Tensor, "*bs channels"],
    colormap_options: ColormapOptions = ColormapOptions(),
    eps: float = 1e-9,
) -> Float[Tensor, "*bs rgb"]:
    """
    Turn a tensor image into an RGB visualisation.

    The strategy is picked from the size of the last dimension and the dtype:
    three channels are returned untouched as RGB, a single floating-point
    channel goes through a scalar colormap, boolean images get a two-colour
    map, and anything with more than three channels is reduced to three with PCA.

    Args:
        image: Input tensor image.
        colormap_options: Colormap name, normalisation, output range and
            inversion settings (only used for single-channel float images).
        eps: Epsilon value for numerical stability when normalising.

    Returns:
        Tensor with the colormap applied.

    Raises:
        NotImplementedError: If none of the supported layouts matches.
    """
    num_channels = image.shape[-1]

    # Three channels are taken to be RGB already.
    if num_channels == 3:
        return image

    # Single floating-point channel, e.g. a rendered depth map.
    if num_channels == 1 and torch.is_floating_point(image):
        values = image
        if colormap_options.normalize:
            # Shift to start at zero, then divide by the new maximum.
            values = values - torch.min(values)
            values = values / (torch.max(values) + eps)
        span = colormap_options.colormap_max - colormap_options.colormap_min
        values = torch.clip(values * span + colormap_options.colormap_min, 0, 1)
        if colormap_options.invert:
            values = 1 - values
        return apply_float_colormap(values, colormap=colormap_options.colormap)

    # Boolean masks.
    if image.dtype == torch.bool:
        return apply_boolean_colormap(image)

    # High-dimensional features.
    if num_channels > 3:
        return apply_pca_colormap(image)

    raise NotImplementedError


def apply_float_colormap(
    image: Float[Tensor, "*bs 1"], colormap: Colormaps = "viridis"
) -> Float[Tensor, "*bs rgb"]:
    """Convert single channel to a color image.

    NaNs are replaced by 0 first. For ``"gray"`` the single channel is simply
    tiled three times (values are not clamped on this path). For every other
    name the values are clamped to [0, 1], quantised to 256 levels and looked
    up in the matplotlib colormap's ``.colors`` table.

    Args:
        image: Single channel image with any number of leading batch dimensions.
        colormap: Colormap for image. ``"default"`` is an alias for ``"turbo"``.

    Returns:
        Tensor: Colored image with colors in [0, 1]
    """
    if colormap == "default":
        colormap = "turbo"

    image = torch.nan_to_num(image, nan=0.0)
    if colormap == "gray":
        # Tile only the trailing channel axis so the call works for any number
        # of leading batch dimensions, not just (H, W, 1).
        return image.repeat(*([1] * (image.dim() - 1)), 3)

    # After nan_to_num + clamp every value lies in [0, 1], so the quantised
    # indices are guaranteed to fall in [0, 255]; no further range check needed.
    image_long = (image.clamp(0, 1) * 255).long()
    lookup_table = torch.tensor(
        matplotlib.colormaps[colormap].colors, device=image.device
    )
    return lookup_table[image_long[..., 0]]


def apply_depth_colormap(
    depth: Float[Tensor, "*bs 1"],
    accumulation: Optional[Float[Tensor, "*bs 1"]] = None,
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
    colormap_options: ColormapOptions = ColormapOptions(),
) -> Float[Tensor, "*bs rgb"]:
    """Converts a depth image to color for easier analysis.

    Depth is linearly rescaled so that ``near_plane`` maps to 0 and
    ``far_plane`` maps to 1, clipped to [0, 1] and then coloured with
    ``apply_colormap``.

    Args:
        depth: Depth image.
        accumulation: Ray accumulation used for masking vis. When given, the
            coloured image is blended towards white where accumulation is low.
        near_plane: Closest depth to consider. If None, use min image value.
        far_plane: Furthest depth to consider. If None, use max image value.
        colormap_options: Colormap settings forwarded to ``apply_colormap``.

    Returns:
        Colored depth image with colors in [0, 1]
    """

    # Compare against None explicitly: 0.0 is a legitimate plane distance and
    # must not be silently replaced by the image min/max.
    if near_plane is None:
        near_plane = float(torch.min(depth))
    if far_plane is None:
        far_plane = float(torch.max(depth))

    # The small constant avoids a division by zero when both planes coincide.
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0, 1)

    colored_image = apply_colormap(depth, colormap_options=colormap_options)

    if accumulation is not None:
        # Composite over a white background using accumulation as alpha.
        colored_image = colored_image * accumulation + (1 - accumulation)

    return colored_image


def apply_boolean_colormap(
    image: Bool[Tensor, "*bs 1"],
    true_color: Float[Tensor, "*bs rgb"] = colors.WHITE,
    false_color: Float[Tensor, "*bs rgb"] = colors.BLACK,
) -> Float[Tensor, "*bs rgb"]:
    """Converts a boolean image to a two-colour RGB image.

    Args:
        image: Boolean image.
        true_color: Color to use for True.
        false_color: Color to use for False.

    Returns:
        Colored boolean image, allocated on the same device as ``image``.
    """
    mask = image[..., 0]

    # Allocate the output on the mask's device so a mask that lives on an
    # accelerator can index it; the colours are moved there for the same reason.
    colored_image = torch.ones(image.shape[:-1] + (3,), device=image.device)
    colored_image[mask, :] = torch.as_tensor(true_color, device=image.device)
    colored_image[~mask, :] = torch.as_tensor(false_color, device=image.device)
    return colored_image


def apply_pca_colormap(image: Float[Tensor, "*bs dim"]) -> Float[Tensor, "*bs rgb"]:
    """Convert feature image to 3-channel RGB via PCA. The first three principal
    components are used for the color channels, with outlier rejection per-channel

    Each projected channel is rescaled so that its inlier range maps to [0, 1];
    values outside that range are clamped. The input tensor is not modified.

    Args:
        image: image of arbitrary vectors

    Returns:
        Tensor: Colored image
    """
    original_shape = image.shape
    # reshape (rather than view) so non-contiguous inputs, e.g. tensors with
    # transposed batch axes, are accepted as well.
    flat = image.reshape(-1, original_shape[-1])
    _, _, v = torch.pca_lowrank(flat)
    projected = torch.matmul(flat, v[..., :3])

    # Per-channel distance from the median, in units of the median absolute deviation.
    d = torch.abs(projected - torch.median(projected, dim=0).values)
    mdev = torch.median(d, dim=0).values
    s = d / mdev
    m = 3.0  # hyperparam: points this many median deviations out are outliers

    for channel in range(3):
        inliers = projected[s[:, channel] < m, channel]
        low, high = inliers.min(), inliers.max()
        # Map the inlier range onto [0, 1].
        projected[:, channel] = (projected[:, channel] - low) / (high - low)

    projected = torch.clamp(projected, 0, 1)

    # Sanity check on the 8-bit quantised range.
    image_long = (projected * 255).long()
    image_long_min = torch.min(image_long)
    image_long_max = torch.max(image_long)
    assert image_long_min >= 0, f"the min value is {image_long_min}"
    assert image_long_max <= 255, f"the max value is {image_long_max}"
    return projected.view(*original_shape[:-1], 3)

def apply_mask_to_images(masks, rgb_imgs):
    """Black out the image pixels that fall outside a low-resolution mask.

    Each mask is upsampled to the image resolution with nearest-neighbour
    interpolation; pixels whose upsampled mask value is > 0.5 keep their
    original value and all other pixels are set to zero (black).

    Args:
        masks: Tensor of shape (batch, mask_h, mask_w), e.g. (batch, 64, 64).
            It is cast to the dtype and device of ``rgb_imgs``.
        rgb_imgs: Channels-last images of shape (batch, height, width, channels),
            e.g. (batch, 512, 512, 3).

    Returns:
        Tensor with the same shape, dtype and device as ``rgb_imgs``.

    Raises:
        ValueError: If the tensors have the wrong rank or their batch sizes differ.
    """
    if masks.dim() != 3:
        raise ValueError(
            f"masks must have shape (batch, h, w), got {tuple(masks.shape)}"
        )
    if rgb_imgs.dim() != 4:
        raise ValueError(
            f"rgb_imgs must have shape (batch, H, W, C), got {tuple(rgb_imgs.shape)}"
        )
    if masks.shape[0] != rgb_imgs.shape[0]:
        raise ValueError(
            f"batch size mismatch: {masks.shape[0]} masks vs {rgb_imgs.shape[0]} images"
        )

    height, width = rgb_imgs.shape[1:3]

    masks = masks.to(dtype=rgb_imgs.dtype, device=rgb_imgs.device)
    # interpolate expects (batch, channels, h, w); one channel is enough because
    # the comparison result broadcasts over the colour channels below.
    masks = F.interpolate(masks.unsqueeze(1), size=(height, width), mode="nearest")
    keep = masks.permute(0, 2, 3, 1) > 0.5  # -> (batch, H, W, 1)

    # Keep the original pixel where the mask is set, use black elsewhere.
    return torch.where(keep, rgb_imgs, torch.zeros_like(rgb_imgs))