# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import PIL.Image
import torch

from diffusers.utils import load_image
from diffusers.utils.constants import (
    DECODE_ENDPOINT_FLUX,
    DECODE_ENDPOINT_SD_V1,
    DECODE_ENDPOINT_SD_XL,
    ENCODE_ENDPOINT_FLUX,
    ENCODE_ENDPOINT_SD_V1,
    ENCODE_ENDPOINT_SD_XL,
)
from diffusers.utils.remote_utils import (
    remote_decode,
    remote_encode,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    slow,
)


# diffusers testing helper that requests deterministic behaviour; it is called at
# import time so it is in effect for every test in this module.
enable_full_determinism()

# Reference image that the tests load (via `load_image`) and send to the remote encoder.
IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true"


class RemoteAutoencoderKLEncodeMixin:
    """Shared round-trip test for a remote VAE encode endpoint.

    Concrete ``unittest.TestCase`` subclasses fill in the class attributes below.
    The test encodes a reference image remotely, checks the latent shape, then
    decodes the latent with the matching remote decode endpoint and checks that
    the original image size is recovered.
    """

    # Number of latent channels the encoder is expected to return.
    channels: int = None
    # Remote encode endpoint under test.
    endpoint: str = None
    # Remote decode endpoint used to turn the latent back into an image.
    decode_endpoint: str = None
    # Declared by subclasses; not read by the tests in this mixin.
    dtype: torch.dtype = None
    # Latent scaling/shift, passed identically to remote_encode and remote_decode.
    scaling_factor: float = None
    shift_factor: float = None
    # Reference image; loaded lazily by get_dummy_inputs and cached on the instance.
    image: PIL.Image.Image = None

    def get_dummy_inputs(self):
        """Return the keyword arguments for ``remote_encode``.

        The reference image is loaded on first use and cached on ``self``.
        """
        if self.image is None:
            self.image = load_image(IMAGE)
        return {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }

    def test_image_input(self):
        """Encode the reference image, check the latent shape, and decode it back."""
        inputs = self.get_dummy_inputs()
        height, width = inputs["image"].height, inputs["image"].width
        output = remote_encode(**inputs)
        # The latent is expected to be 8x smaller than the image in each spatial dim.
        self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
        decoded = remote_decode(
            tensor=output,
            endpoint=self.decode_endpoint,
            scaling_factor=self.scaling_factor,
            shift_factor=self.shift_factor,
            image_format="png",
        )
        self.assertEqual(decoded.height, height)
        self.assertEqual(decoded.width, width)
        # TODO: also check values. encode->decode is lossy, so an exact pixel-slice
        # comparison between the input and decoded images is not meaningful; an
        # expected slice of the encoded latent may be the better target.


class RemoteAutoencoderKLSDv1Tests(RemoteAutoencoderKLEncodeMixin, unittest.TestCase):
    """Round-trip test against the Stable Diffusion v1 remote encode/decode endpoints."""

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1

    # Latent configuration.
    channels = 4
    scaling_factor = 0.18215
    shift_factor = None
    dtype = torch.float16


class RemoteAutoencoderKLSDXLTests(RemoteAutoencoderKLEncodeMixin, unittest.TestCase):
    """Round-trip test against the Stable Diffusion XL remote encode/decode endpoints."""

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL

    # Latent configuration.
    channels = 4
    scaling_factor = 0.13025
    shift_factor = None
    dtype = torch.float16


class RemoteAutoencoderKLFluxTests(RemoteAutoencoderKLEncodeMixin, unittest.TestCase):
    """Round-trip test against the Flux remote encode/decode endpoints."""

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX

    # Latent configuration (16 channels, with both scaling and shift applied).
    channels = 16
    scaling_factor = 0.3611
    shift_factor = 0.1159
    dtype = torch.bfloat16


class RemoteAutoencoderKLEncodeSlowTestMixin:
    """Slow round-trip tests of a remote VAE encode endpoint across many resolutions.

    Concrete ``unittest.TestCase`` subclasses fill in the endpoint and latent
    configuration below. ``test_multi_res`` encodes the reference image at every
    (height, width) pair drawn from ``resolutions``, checks the latent shape, and
    decodes it back to confirm the requested size is recovered.
    """

    # Number of latent channels the encoder is expected to return.
    channels: int = 4
    # Remote encode endpoint under test.
    endpoint: str = None
    # Remote decode endpoint used to turn the latent back into an image.
    decode_endpoint: str = None
    # Declared by subclasses; not read by the tests in this mixin.
    dtype: torch.dtype = None
    # Latent scaling/shift, passed identically to remote_encode and remote_decode.
    scaling_factor: float = None
    shift_factor: float = None
    # Reference image; loaded lazily by get_dummy_inputs and cached on the instance.
    image: PIL.Image.Image = None
    # Side lengths (pixels) used for both height and width in test_multi_res.
    # A tuple rather than a set so the iteration order is fixed and readable;
    # subclasses may override it to test a different grid.
    resolutions: tuple = (
        320,
        512,
        640,
        704,
        896,
        1024,
        1208,
        1384,
        1536,
        1608,
        1864,
        2048,
    )

    def get_dummy_inputs(self):
        """Return the keyword arguments for ``remote_encode``.

        The reference image is loaded on first use and cached on ``self``.
        """
        if self.image is None:
            self.image = load_image(IMAGE)
        return {
            "endpoint": self.endpoint,
            "image": self.image,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
        }

    def test_multi_res(self):
        """Round-trip the reference image at every (height, width) in ``resolutions``.

        Every size is produced by resizing the *original* reference image, so
        each encode sees exactly one resampling step. Previously the image was
        resized from the previous iteration's output, which compounded resampling
        error. Each pair runs in its own ``subTest``, so a failure names the
        offending size and the remaining sizes are still exercised.
        """
        inputs = self.get_dummy_inputs()
        source_image = inputs["image"]
        # Decoded images are saved (confirming they serialise as PNG) into a
        # temporary directory, so the test does not litter the working directory
        # with one file per resolution pair.
        with tempfile.TemporaryDirectory() as output_dir:
            for height in self.resolutions:
                for width in self.resolutions:
                    with self.subTest(height=height, width=width):
                        # PIL's resize takes (width, height).
                        resized = source_image.resize((width, height))
                        output = remote_encode(**{**inputs, "image": resized})
                        self.assertEqual(list(output.shape), [1, self.channels, height // 8, width // 8])
                        decoded = remote_decode(
                            tensor=output,
                            endpoint=self.decode_endpoint,
                            scaling_factor=self.scaling_factor,
                            shift_factor=self.shift_factor,
                            image_format="png",
                        )
                        self.assertEqual(decoded.height, height)
                        self.assertEqual(decoded.width, width)
                        decoded.save(os.path.join(output_dir, f"test_multi_res_{height}_{width}.png"))


@slow
class RemoteAutoencoderKLSDv1SlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    """Multi-resolution round-trip tests for the Stable Diffusion v1 remote endpoints."""

    # Declared explicitly (as the fast-test classes do) instead of relying on the
    # slow mixin's default, so the expected latent channel count is visible here.
    channels = 4
    endpoint = ENCODE_ENDPOINT_SD_V1
    decode_endpoint = DECODE_ENDPOINT_SD_V1
    dtype = torch.float16
    scaling_factor = 0.18215
    shift_factor = None


@slow
class RemoteAutoencoderKLSDXLSlowTests(
    RemoteAutoencoderKLEncodeSlowTestMixin,
    unittest.TestCase,
):
    """Multi-resolution round-trip tests for the Stable Diffusion XL remote endpoints."""

    # Declared explicitly (as the fast-test classes do) instead of relying on the
    # slow mixin's default, so the expected latent channel count is visible here.
    channels = 4
    endpoint = ENCODE_ENDPOINT_SD_XL
    decode_endpoint = DECODE_ENDPOINT_SD_XL
    dtype = torch.float16
    scaling_factor = 0.13025
    shift_factor = None


@slow
class RemoteAutoencoderKLFluxSlowTests(RemoteAutoencoderKLEncodeSlowTestMixin, unittest.TestCase):
    """Multi-resolution round-trip tests for the Flux remote endpoints."""

    # Endpoints under test.
    endpoint = ENCODE_ENDPOINT_FLUX
    decode_endpoint = DECODE_ENDPOINT_FLUX

    # Latent configuration (overrides the mixin's 4-channel default).
    channels = 16
    scaling_factor = 0.3611
    shift_factor = 0.1159
    dtype = torch.bfloat16
