from typing import Any, Literal, Optional, Union

import kornia.color as K
import torch

from ito_vision.parametrizations.identity import IdentityParametrization


class IdentityLLIEParametrization(IdentityParametrization):
    """Identity parametrization restricted to ``"x0"`` that re-lights the prediction.

    Subclass of :class:`IdentityParametrization` that only accepts the ``"x0"``
    target. Its ``__call__`` runs the model and then passes the output through
    :meth:`adjust_lightness`. That method shifts the HSV value channel of each
    image so its spatial mean matches a caller-supplied ``target_lightness``.
    (LLIE presumably stands for low-light image enhancement.)
    """

    def __init__(
        self,
        target: Literal["score", "epsilon", "x0"] = "x0",
        loss_weight_type: Literal["t_inv", "var_inv", "uniform"] = "uniform",
        epsilon: float = 1e-4,
    ):
        """Initialise the base parametrization and enforce the ``"x0"`` target.

        Args:
            target: Forwarded to the base class; only ``"x0"`` is accepted here.
            loss_weight_type: Forwarded to the base class unchanged.
            epsilon: Forwarded to the base class unchanged.

        Raises:
            ValueError: If ``target`` resolves to anything other than ``"x0"``.
        """
        super().__init__(target, loss_weight_type, epsilon)

        if self.target != "x0":
            raise ValueError(
                f"IdentityLLIEParametrization only supports 'x0' target, "
                f"got {self.target!r}."
            )

    def adjust_lightness(
        self,
        img: torch.Tensor,
        target_lightness: Union[torch.Tensor, float],
    ) -> torch.Tensor:
        """Shift the HSV value channel of ``img`` so its per-image mean hits a target.

        The steps are:

        1. Map ``img`` from [-1, 1] to [0, 1] and clamp it.
        2. Convert it to HSV with kornia.
        3. Offset the value channel by a per-image constant, so that its mean
           over the last two (spatial) dims equals ``target_lightness``.
        4. Convert back to RGB and map to [-1, 1], clamping again.

        The shifted value channel is not clamped before ``hsv_to_rgb``; only
        the final RGB output is clamped.

        Args:
            img: RGB images in [-1, 1] with channels on dim -3. The expected
                layout is presumably (B, 3, H, W); TODO confirm for other
                layouts.
            target_lightness: Desired mean of the HSV value channel, on the
                same [0, 1] scale as that channel. It can take two forms:

                - A scalar (Python float or 0-d tensor), applied to every image.
                - One value per image, with shape (B,), (B, 1) or (B, 1, 1).

        Returns:
            The re-lit RGB images in [-1, 1], with the same shape as ``img``.
        """
        rgb = ((img + 1.0) * 0.5).clamp(0, 1)
        hue, sat, value = K.rgb_to_hsv(rgb).unbind(dim=-3)

        shift = torch.as_tensor(
            target_lightness, dtype=value.dtype, device=value.device
        )
        if shift.ndim > 0:
            # One target per image. Flatten to (B, 1, ..., 1) so it broadcasts
            # over every non-batch dim of ``value``. A plain double unsqueeze
            # of a (B, 1) tensor would instead broadcast to (B, B, H, W).
            shift = shift.reshape(-1, *([1] * (value.ndim - 1)))

        # Out of place: rebuild the HSV tensor rather than mutating a slice of
        # an autograd-tracked tensor.
        value = value - value.mean(dim=(-2, -1), keepdim=True) + shift
        relit = K.hsv_to_rgb(torch.stack((hue, sat, value), dim=-3))
        return (relit * 2.0 - 1.0).clamp(-1, 1)

    def __call__(
        self,
        model: torch.nn.Module,
        xt: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run ``model`` to predict x0, then re-light it to ``target_lightness``.

        Args:
            model: Called as ``model(xt, t, y, **kwargs)``. Its output is used
                as the x0 prediction and is expected to be in [-1, 1].
            xt: Forwarded positionally to ``model``.
            t: Forwarded positionally to ``model``.
            y: Forwarded positionally to ``model``.
            **kwargs: Forwarded unchanged to ``model``. Must contain
                ``target_lightness`` (see :meth:`adjust_lightness`). That entry
                is both forwarded to the model and used for the adjustment.

        Returns:
            The lightness-adjusted model output in [-1, 1].

        Raises:
            KeyError: If ``target_lightness`` is missing from ``kwargs``.
        """
        # Check before the (potentially expensive) forward pass.
        try:
            target_lightness = kwargs["target_lightness"]
        except KeyError:
            raise KeyError(
                f"{type(self).__name__} requires a 'target_lightness' "
                f"keyword argument."
            ) from None
        out = model(xt, t, y, **kwargs)
        return self.adjust_lightness(out, target_lightness)
