from typing import Any, Optional, Union

import kornia.color as K
import torch

from ito_vision.parametrizations.ddbm import DDBMParametrization


class DDBMLLIEParametrization(DDBMParametrization):
    """DDBM parametrization that re-targets the brightness of its prediction.

    The denoised estimate is computed exactly as
    ``c_skip(t) * xt + c_out(t) * model(c_in(t) * xt, c_noise(t), y, **kwargs)``.
    Then the HSV *value* (V) channel of every image is shifted so that its
    spatial mean equals a caller-supplied ``target_lightness``.
    """

    def __init__(
        self,
        var_0: float,
        var_1: float,
        cov_01: float,
    ):
        """Forward the parameters unchanged to ``DDBMParametrization``.

        Args:
            var_0: Passed through to the base-class constructor.
            var_1: Passed through to the base-class constructor.
            cov_01: Passed through to the base-class constructor.
        """
        super().__init__(var_0, var_1, cov_01)

    def adjust_lightness(
        self, img: torch.Tensor, target_lightness: Union[torch.Tensor, float]
    ) -> torch.Tensor:
        """Shift each image's HSV value channel so its spatial mean hits a target.

        Hue and saturation are left untouched. Only the V channel is offset by
        ``target - mean(V)``, computed per image over the spatial dimensions.

        Args:
            img: RGB batch shaped ``(B, 3, H, W)`` with values in ``[-1, 1]``.
                Values outside that range are clamped first.
            target_lightness: Desired mean of the V channel on the ``[0, 1]``
                scale. Give either one value per image (any shape with ``B``
                elements, e.g. ``(B,)`` or ``(B, 1)``) or a single value shared
                by the whole batch (Python float or one-element tensor). It is
                moved to ``img``'s device and dtype.

        Returns:
            The adjusted RGB batch, same shape as ``img``, clamped to ``[-1, 1]``.
        """
        # kornia's colour conversions work on RGB in [0, 1].
        rgb = ((img + 1.0) * 0.5).clamp(0, 1)
        hsv = K.rgb_to_hsv(rgb)
        # Split instead of writing into ``hsv`` in place; keep a channel axis on V.
        hue_sat, value = hsv[:, :2], hsv[:, 2:]

        # Normalise the target to (B or 1, 1, 1, 1) on the right device/dtype so it
        # broadcasts against the per-image mean of shape (B, 1, 1, 1).
        target = torch.as_tensor(
            target_lightness, dtype=value.dtype, device=value.device
        ).reshape(-1, 1, 1, 1)
        value = value - value.mean(dim=(-2, -1), keepdim=True) + target

        # NOTE(review): V can now leave [0, 1] (kornia documents V in [0, 1]);
        # out-of-range results are only bounded by the final RGB clamp below,
        # which can shift hue/saturation of clipped pixels. Confirm intended.
        rgb = K.hsv_to_rgb(torch.cat((hue_sat, value), dim=1))
        return (rgb * 2.0 - 1.0).clamp(-1, 1)

    def __call__(
        self,
        model: torch.nn.Module,
        xt: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the preconditioned denoiser, then re-target the output brightness.

        Args:
            model: Network called as
                ``model(c_in(t) * xt, c_noise(t), y, **kwargs)``.
            xt: Current sample fed to the preconditioned model.
            t: Value passed to the ``c_skip``/``c_out``/``c_in``/``c_noise``
                coefficient functions.
            y: Optional extra input forwarded positionally to ``model``.
            **kwargs: Forwarded to ``model`` unchanged (including
                ``target_lightness``). Must contain ``target_lightness``; see
                :meth:`adjust_lightness` for accepted forms.

        Returns:
            The denoised estimate after :meth:`adjust_lightness`.

        Raises:
            KeyError: If ``target_lightness`` is not supplied in ``kwargs``.
        """
        # Validate before the model forward pass so a missing argument fails fast.
        if "target_lightness" not in kwargs:
            raise KeyError(
                "DDBMLLIEParametrization requires a 'target_lightness' "
                "keyword argument"
            )
        target_lightness = kwargs["target_lightness"]

        out = self.c_skip(t) * xt + self.c_out(t) * model(
            self.c_in(t) * xt, self.c_noise(t), y, **kwargs
        )

        return self.adjust_lightness(out, target_lightness)
