# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger
from diffusers.utils.testing_utils import CaptureLogger, torch_device


# Module logger used by the test hooks below for their debug messages; the invocation-order
# tests fetch the same logger via get_logger(__name__) and capture its output.
logger = get_logger(__name__)  # pylint: disable=invalid-name


class DummyBlock(torch.nn.Module):
    """Tiny two-layer MLP used as a test fixture: ``Linear -> ReLU -> Linear``.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the intermediate projection.
        out_features: Size of the last dimension of the output.
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
        super().__init__()

        # Registration order matters for parameter initialisation under a fixed seed.
        self.proj_in = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.proj_out = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` in, apply ReLU, and project back out."""
        return self.proj_out(self.activation(self.proj_in(x)))


class DummyModel(torch.nn.Module):
    """Small test network: input projection, a stack of ``DummyBlock``s, and an output projection.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width used by the input projection and by every ``DummyBlock``.
        out_features: Size of the last dimension of the output.
        num_layers: Number of ``DummyBlock``s stacked between the two projections.
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        stacked = [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
        self.blocks = torch.nn.ModuleList(stacked)
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the input projection + ReLU, every block in order, then the output projection."""
        hidden = self.activation(self.linear_1(x))
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.linear_2(hidden)


class AddHook(ModelHook):
    """Test hook that adds ``value`` to every tensor positional argument before the forward call.

    Non-tensor positional arguments and all keyword arguments pass through unchanged, and the
    module output is returned as-is. Both hook points emit a debug log line so tests can check
    invocation order.

    Args:
        value: Constant added to each tensor positional argument.
    """

    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        logger.debug("AddHook pre_forward")
        # Materialise a tuple rather than returning a generator: a generator is evaluated lazily
        # and can only be iterated once, so any caller that inspects ``args`` before unpacking it
        # would forward an empty argument list.
        args = tuple((x + self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("AddHook post_forward")
        return output


class MultiplyHook(ModelHook):
    """Test hook that multiplies every tensor positional argument by ``value`` before the forward call.

    Non-tensor positional arguments and all keyword arguments pass through unchanged, and the
    module output is returned as-is. Both hook points emit a debug log line so tests can check
    invocation order.

    Args:
        value: Factor applied to each tensor positional argument.
    """

    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("MultiplyHook pre_forward")
        # Return a concrete tuple: a generator would be single-use and lazily evaluated.
        args = tuple((x * self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        logger.debug("MultiplyHook post_forward")
        return output

    def __repr__(self):
        return f"MultiplyHook(value={self.value})"


class StatefulAddHook(ModelHook):
    """Test hook whose additive offset grows by one on every forward call.

    On the n-th call since construction or the last ``reset_state`` (counting from 0), every
    tensor positional argument has ``value + n`` added to it. It defines no ``post_forward``.

    Args:
        value: Base offset used on the first call.
    """

    # Flags this hook as carrying state; presumably the registry uses it to decide which hooks
    # ``reset_stateful_hooks`` resets — confirm against HookRegistry.
    _is_stateful = True

    def __init__(self, value: int):
        super().__init__()
        self.value = value
        # Number of forward calls observed since construction or the last reset.
        self.increment = 0

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("StatefulAddHook pre_forward")
        add_value = self.value + self.increment
        self.increment += 1
        # Materialise a tuple so the returned args are reusable and not a one-shot generator.
        args = tuple((x + add_value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def reset_state(self, module):
        """Restart the call counter so the next forward uses ``value`` as the offset again."""
        self.increment = 0


class SkipLayerHook(ModelHook):
    """Test hook that can replace the wrapped module's forward with an identity.

    When ``skip_layer`` is true, ``new_forward`` hands back the first positional argument
    untouched; otherwise it calls the module's original forward through ``self.fn_ref``
    (presumably populated by the hook registry when the hook is registered).

    Args:
        skip_layer: Whether the wrapped module's computation should be bypassed.
    """

    def __init__(self, skip_layer: bool):
        super().__init__()
        self.skip_layer = skip_layer

    def pre_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook pre_forward")
        return args, kwargs

    def new_forward(self, module, *args, **kwargs):
        logger.debug("SkipLayerHook new_forward")
        if not self.skip_layer:
            return self.fn_ref.original_forward(*args, **kwargs)
        # Identity: return the first positional input without running the module.
        return args[0]

    def post_forward(self, module, output):
        logger.debug("SkipLayerHook post_forward")
        return output


class HookTests(unittest.TestCase):
    """Tests for HookRegistry registration/removal, stateful-hook reset, forward semantics and hook invocation order."""

    in_features = 4
    hidden_features = 8
    out_features = 4
    num_layers = 2

    def setUp(self):
        params = self.get_module_parameters()
        self.model = DummyModel(**params)
        self.model.to(torch_device)

    def tearDown(self):
        super().tearDown()

        del self.model
        gc.collect()
        free_memory()

    def get_module_parameters(self):
        """Return the constructor kwargs for ``DummyModel`` built from the class attributes."""
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "num_layers": self.num_layers,
        }

    def get_generator(self):
        """Seed the global RNG with 0 and return the default (CPU) generator."""
        return torch.manual_seed(0)

    def _get_random_input(self) -> torch.Tensor:
        """Return a reproducible ``(1, in_features)`` random tensor on ``torch_device``.

        The sample is drawn on CPU with the CPU generator from ``get_generator`` and then moved,
        because ``torch.randn`` rejects a generator whose device differs from the requested one.
        """
        return torch.randn(1, self.in_features, generator=self.get_generator()).to(torch_device)

    @staticmethod
    def _normalize_log(text: str) -> str:
        """Drop spaces and newlines so only the sequence of log message text is compared."""
        return text.replace(" ", "").replace("\n", "")

    def _enable_debug_logging(self):
        """Set this module's logger to DEBUG for the current test and restore its level afterwards."""
        test_logger = get_logger(__name__)
        # Restore the previous level on cleanup so DEBUG output does not leak into other tests.
        self.addCleanup(test_logger.setLevel, test_logger.level)
        test_logger.setLevel("DEBUG")
        return test_logger

    def _assert_invocation_log(self, test_logger, inputs, expected_messages):
        """Run one forward pass on ``inputs`` and assert the captured log equals ``expected_messages`` in order."""
        with CaptureLogger(test_logger) as cap_logger:
            self.model(inputs)
        self.assertEqual(self._normalize_log(cap_logger.out), self._normalize_log("".join(expected_messages)))

    def test_hook_registry(self):
        """Registering hooks records them in order and shows them in ``repr``; removal updates both."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        registry_repr = repr(registry)
        expected_repr = "HookRegistry(\n  (0) add_hook - AddHook\n  (1) multiply_hook - MultiplyHook(value=2)\n)"

        self.assertEqual(len(registry.hooks), 2)
        self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
        self.assertEqual(registry_repr, expected_repr)

        registry.remove_hook("add_hook")

        self.assertEqual(len(registry.hooks), 1)
        self.assertEqual(registry._hook_order, ["multiply_hook"])

    def test_stateful_hook(self):
        """``reset_stateful_hooks`` rewinds a stateful hook so the next forward matches the very first one."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "stateful_add_hook")

        self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)

        inputs = self._get_random_input()
        num_repeats = 3

        # Keep the first call's output (increment == 0) to compare against after the reset.
        output1 = self.model(inputs)
        for _ in range(num_repeats - 1):
            self.model(inputs)

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)

        registry.reset_stateful_hooks()
        output2 = self.model(inputs)

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
        self.assertTrue(torch.allclose(output1, output2))

    def test_inference(self):
        """Hooked forward equals an unhooked forward on manually transformed input.

        With AddHook(1) registered before MultiplyHook(2), the effective input is ``x * 2 + 1``.
        """
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        inputs = self._get_random_input()
        output1 = self.model(inputs).mean().detach().cpu().item()

        registry.remove_hook("multiply_hook")
        new_inputs = inputs * 2
        output2 = self.model(new_inputs).mean().detach().cpu().item()

        registry.remove_hook("add_hook")
        new_inputs = inputs * 2 + 1
        output3 = self.model(new_inputs).mean().detach().cpu().item()

        self.assertAlmostEqual(output1, output2, places=5)
        self.assertAlmostEqual(output1, output3, places=5)

    def test_skip_layer_hook(self):
        """Skipping the whole model returns the (all-zero) input; not skipping runs the model."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")

        inputs = torch.zeros(1, self.in_features, device=torch_device)
        output = self.model(inputs).mean().detach().cpu().item()
        self.assertEqual(output, 0.0)

        registry.remove_hook("skip_layer_hook")
        registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
        output = self.model(inputs).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

    def test_skip_layer_internal_block(self):
        """Skipping an inner submodule only works when its input and output shapes agree."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
        inputs = torch.zeros(1, self.in_features, device=torch_device)

        # Skipping linear_1 feeds an in_features-wide tensor into blocks expecting hidden_features.
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        with self.assertRaises(RuntimeError) as cm:
            self.model(inputs).mean().detach().cpu().item()
        self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))

        registry.remove_hook("skip_layer_hook")
        output = self.model(inputs).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

        # blocks[1] maps hidden_features -> hidden_features, so skipping it is shape-safe.
        registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        output = self.model(inputs).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

    def test_invocation_order_stateful_first(self):
        """Later-registered hooks run their pre_forward first and post_forward last (stateful hook registered first)."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "add_hook")
        registry.register_hook(AddHook(2), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        inputs = self._get_random_input()
        test_logger = self._enable_debug_logging()

        # StatefulAddHook defines no post_forward, so it only appears on the way in.
        self._assert_invocation_log(
            test_logger,
            inputs,
            [
                "MultiplyHook pre_forward",
                "AddHook pre_forward",
                "StatefulAddHook pre_forward",
                "AddHook post_forward",
                "MultiplyHook post_forward",
            ],
        )

        registry.remove_hook("add_hook")
        self._assert_invocation_log(
            test_logger,
            inputs,
            [
                "MultiplyHook pre_forward",
                "AddHook pre_forward",
                "AddHook post_forward",
                "MultiplyHook post_forward",
            ],
        )

    def test_invocation_order_stateful_middle(self):
        """Invocation order stays consistent when the stateful hook sits between two stateless ones."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(2), "add_hook")
        registry.register_hook(StatefulAddHook(1), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        inputs = self._get_random_input()
        test_logger = self._enable_debug_logging()

        self._assert_invocation_log(
            test_logger,
            inputs,
            [
                "MultiplyHook pre_forward",
                "StatefulAddHook pre_forward",
                "AddHook pre_forward",
                "AddHook post_forward",
                "MultiplyHook post_forward",
            ],
        )

        registry.remove_hook("add_hook")
        self._assert_invocation_log(
            test_logger,
            inputs,
            [
                "MultiplyHook pre_forward",
                "StatefulAddHook pre_forward",
                "MultiplyHook post_forward",
            ],
        )

        registry.remove_hook("add_hook_2")
        self._assert_invocation_log(
            test_logger,
            inputs,
            ["MultiplyHook pre_forward", "MultiplyHook post_forward"],
        )

    def test_invocation_order_stateful_last(self):
        """Invocation order stays consistent when the stateful hook is registered last (outermost)."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")
        registry.register_hook(StatefulAddHook(3), "add_hook_2")

        inputs = self._get_random_input()
        test_logger = self._enable_debug_logging()

        self._assert_invocation_log(
            test_logger,
            inputs,
            [
                "StatefulAddHook pre_forward",
                "MultiplyHook pre_forward",
                "AddHook pre_forward",
                "AddHook post_forward",
                "MultiplyHook post_forward",
            ],
        )

        registry.remove_hook("add_hook")
        self._assert_invocation_log(
            test_logger,
            inputs,
            [
                "StatefulAddHook pre_forward",
                "MultiplyHook pre_forward",
                "MultiplyHook post_forward",
            ],
        )
