# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest
from typing import List

import numpy as np
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    FluxTransformer2DModel,
    TorchAoConfig,
)
from diffusers.models.attention_processor import Attention
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    is_torch_available,
    is_torchao_available,
    nightly,
    numpy_cosine_similarity_distance,
    require_torch,
    require_torch_gpu,
    require_torchao_version_greater_or_equal,
    slow,
    torch_device,
)


enable_full_determinism()


if is_torch_available():
    import torch
    import torch.nn as nn

    from ..utils import LoRALayer, get_memory_consumption_stat


if is_torchao_available():
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
    from torchao.quantization.quant_primitives import MappingType
    from torchao.utils import get_model_size_in_bytes


@require_torch
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoConfigTest(unittest.TestCase):
    def test_to_dict(self):
        """
        Makes sure the config format is properly set
        """
        quantization_config = TorchAoConfig("int4_weight_only")
        torchao_orig_config = quantization_config.to_dict()

        for key in torchao_orig_config:
            self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])

    def test_post_init_check(self):
        """
        Test kwargs validations in TorchAoConfig
        """
        _ = TorchAoConfig("int4_weight_only")
        with self.assertRaisesRegex(ValueError, "is not supported yet"):
            _ = TorchAoConfig("uint8")

        with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
            _ = TorchAoConfig("int4_weight_only", group_size1=32)

    def test_repr(self):
        """
        Check that there is no error in the repr
        """
        quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
        expected_repr = """TorchAoConfig {
            "modules_to_not_convert": [
                "conv"
            ],
            "quant_method": "torchao",
            "quant_type": "int4_weight_only",
            "quant_type_kwargs": {
                "group_size": 8
            }
        }""".replace(" ", "").replace("\n", "")
        quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
        self.assertEqual(quantization_repr, expected_repr)

        quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
        expected_repr = """TorchAoConfig {
            "modules_to_not_convert": null,
            "quant_method": "torchao",
            "quant_type": "int4dq",
            "quant_type_kwargs": {
                "act_mapping_type": "SYMMETRIC",
                "group_size": 64
            }
        }""".replace(" ", "").replace("\n", "")
        quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
        self.assertEqual(quantization_repr, expected_repr)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoTest(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_components(
        self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
    ):
        transformer = FluxTransformer2DModel.from_pretrained(
            model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )
        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
        text_encoder_2 = T5EncoderModel.from_pretrained(
            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator().manual_seed(seed)

        inputs = {
            "prompt": "an astronaut riding a horse in space",
            "height": 32,
            "width": 32,
            "num_inference_steps": 2,
            "output_type": "np",
            "generator": generator,
        }

        return inputs

    def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        torch.manual_seed(seed)
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
            device, dtype=torch.bfloat16
        )

        torch.manual_seed(seed)
        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

        timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": image_ids,
            "timestep": timestep,
        }

    def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str):
        components = self.get_dummy_components(quantization_config, model_id)
        pipe = FluxPipeline(**components)
        pipe.to(device=torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]
        output_slice = output[-1, -1, -3:, -3:].flatten()

        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

    def test_quantization(self):
        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
            # fmt: off
            QUANTIZATION_TYPES_TO_TEST = [
                ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
                ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
                ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
                ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
                ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
                ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
            ]

            if TorchAoConfig._is_cuda_capability_atleast_8_9():
                QUANTIZATION_TYPES_TO_TEST.extend([
                    ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
                    ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
                    # =====
                    # The following lead to an internal torch error:
                    #    RuntimeError: mat2 shape (32x4 must be divisible by 16
                    # Skip these for now; TODO(aryan): investigate later
                    # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
                    # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
                    # =====
                    # Cutlass fails to initialize for below
                    # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
                    # =====
                    ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
                    ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
                ])
            # fmt: on

            for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
                quant_kwargs = {}
                if quantization_name in ["uint4wo", "uint7wo"]:
                    # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
                    quant_kwargs.update({"group_size": 16})
                quantization_config = TorchAoConfig(
                    quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
                )
                self._test_quant_type(quantization_config, expected_slice, model_id)

    def test_int4wo_quant_bfloat16_conversion(self):
        """
        Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
        """
        quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )

        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, AffineQuantizedTensor))
        self.assertEqual(weight.quant_min, 0)
        self.assertEqual(weight.quant_max, 15)

    def test_device_map(self):
        """
        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
        correctly set (in the `hf_device_map` attribute of the model).
        """
        custom_device_map_dict = {
            "time_text_embed": torch_device,
            "context_embedder": torch_device,
            "x_embedder": torch_device,
            "transformer_blocks.0": "cpu",
            "single_transformer_blocks.0": "disk",
            "norm_out": torch_device,
            "proj_out": "cpu",
        }
        device_maps = ["auto", custom_device_map_dict]

        inputs = self.get_dummy_tensor_inputs(torch_device)
        # requires with different expected slices since models are different due to offload (we don't quantize modules offloaded to cpu/disk)
        expected_slice_auto = np.array(
            [
                0.34179688,
                -0.03613281,
                0.01428223,
                -0.22949219,
                -0.49609375,
                0.4375,
                -0.1640625,
                -0.66015625,
                0.43164062,
            ]
        )
        expected_slice_offload = np.array(
            [0.34375, -0.03515625, 0.0123291, -0.22753906, -0.49414062, 0.4375, -0.16308594, -0.66015625, 0.43554688]
        )
        for device_map in device_maps:
            if device_map == "auto":
                expected_slice = expected_slice_auto
            else:
                expected_slice = expected_slice_offload
            with tempfile.TemporaryDirectory() as offload_folder:
                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
                quantized_model = FluxTransformer2DModel.from_pretrained(
                    "hf-internal-testing/tiny-flux-pipe",
                    subfolder="transformer",
                    quantization_config=quantization_config,
                    device_map=device_map,
                    torch_dtype=torch.bfloat16,
                    offload_folder=offload_folder,
                )

                weight = quantized_model.transformer_blocks[0].ff.net[2].weight

                # Note that when performing cpu/disk offload, the offloaded weights are not quantized, only the weights on the gpu.
                # This is not the case when the model are already quantized
                if "transformer_blocks.0" in device_map:
                    self.assertTrue(isinstance(weight, nn.Parameter))
                else:
                    self.assertTrue(isinstance(weight, AffineQuantizedTensor))

                output = quantized_model(**inputs)[0]
                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
                self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

            with tempfile.TemporaryDirectory() as offload_folder:
                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
                quantized_model = FluxTransformer2DModel.from_pretrained(
                    "hf-internal-testing/tiny-flux-sharded",
                    subfolder="transformer",
                    quantization_config=quantization_config,
                    device_map=device_map,
                    torch_dtype=torch.bfloat16,
                    offload_folder=offload_folder,
                )

                weight = quantized_model.transformer_blocks[0].ff.net[2].weight
                if "transformer_blocks.0" in device_map:
                    self.assertTrue(isinstance(weight, nn.Parameter))
                else:
                    self.assertTrue(isinstance(weight, AffineQuantizedTensor))

                output = quantized_model(**inputs)[0]
                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
                self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

    def test_modules_to_not_convert(self):
        quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
        quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )

        unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
        self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
        self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
        self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)

        quantized_layer = quantized_model_with_not_convert.proj_out
        self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))

        quantization_config = TorchAoConfig("int8_weight_only")
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )

        size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
        size_quantized = get_model_size_in_bytes(quantized_model)

        self.assertTrue(size_quantized < size_quantized_with_not_convert)

    def test_training(self):
        quantization_config = TorchAoConfig("int8_weight_only")
        quantized_model = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            # freeze the model as only adapter layers will be trained
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if isinstance(module, Attention):
                module.to_q = LoRALayer(module.to_q, rank=4)
                module.to_k = LoRALayer(module.to_k, rank=4)
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_tensor_inputs(torch_device)
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)
                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)

    @nightly
    def test_torch_compile(self):
        r"""Test that verifies if torch.compile works with torchao quantization."""
        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
            quantization_config = TorchAoConfig("int8_weight_only")
            components = self.get_dummy_components(quantization_config, model_id=model_id)
            pipe = FluxPipeline(**components)
            pipe.to(device=torch_device)

            inputs = self.get_dummy_inputs(torch_device)
            normal_output = pipe(**inputs)[0].flatten()[-32:]

            pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False)
            inputs = self.get_dummy_inputs(torch_device)
            compile_output = pipe(**inputs)[0].flatten()[-32:]

            # Note: Seems to require higher tolerance
            self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))

    def test_memory_footprint(self):
        r"""
        A simple test to check if the model conversion has been done correctly by checking on the
        memory footprint of the converted model and the class type of the linear layers of the converted models
        """
        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
            transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
            transformer_int4wo_gs32 = self.get_dummy_components(
                TorchAoConfig("int4wo", group_size=32), model_id=model_id
            )["transformer"]
            transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
            transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]

            # Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64
            for block in transformer_int4wo.transformer_blocks:
                self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
                self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))

            # Will quantize all the linear layers except x_embedder
            for name, module in transformer_int4wo_gs32.named_modules():
                if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
                    self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))

            # Will quantize all the linear layers
            for module in transformer_int8wo.modules():
                if isinstance(module, nn.Linear):
                    self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))

            total_int4wo = get_model_size_in_bytes(transformer_int4wo)
            total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
            total_int8wo = get_model_size_in_bytes(transformer_int8wo)
            total_bf16 = get_model_size_in_bytes(transformer_bf16)

            # TODO: refactor to align with other quantization tests
            # Latter has smaller group size, so more groups -> more scales and zero points
            self.assertTrue(total_int4wo < total_int4wo_gs32)
            # int8 quantizes more layers compare to int4 with default group size
            self.assertTrue(total_int8wo < total_int4wo)
            # int4wo does not quantize too many layers because of default group size, but for the layers it does
            # there is additional overhead of scales and zero points
            self.assertTrue(total_bf16 < total_int4wo)

    def test_model_memory_usage(self):
        model_id = "hf-internal-testing/tiny-flux-pipe"
        expected_memory_saving_ratio = 2.0

        inputs = self.get_dummy_tensor_inputs(device=torch_device)

        transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
        transformer_bf16.to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
        del transformer_bf16

        transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
        transformer_int8wo.to(torch_device)
        quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
        assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio

    def test_wrong_config(self):
        with self.assertRaises(ValueError):
            self.get_dummy_components(TorchAoConfig("int42"))

    def test_sequential_cpu_offload(self):
        r"""
        A test that checks if inference runs as expected when sequential cpu offloading is enabled.
        """
        quantization_config = TorchAoConfig("int8wo")
        components = self.get_dummy_components(quantization_config)
        pipe = FluxPipeline(**components)
        pipe.enable_sequential_cpu_offload()

        inputs = self.get_dummy_inputs(torch_device)
        _ = pipe(**inputs)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoSerializationTest(unittest.TestCase):
    model_name = "hf-internal-testing/tiny-flux-pipe"

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
        quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
        quantized_model = FluxTransformer2DModel.from_pretrained(
            self.model_name,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )
        return quantized_model.to(device)

    def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        torch.manual_seed(seed)
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
            device, dtype=torch.bfloat16
        )
        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
        image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
        timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": image_ids,
            "timestep": timestep,
        }

    def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
        inputs = self.get_dummy_tensor_inputs(torch_device)
        output = quantized_model(**inputs)[0]
        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

    def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
        quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)

        with tempfile.TemporaryDirectory() as tmp_dir:
            quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
            loaded_quantized_model = FluxTransformer2DModel.from_pretrained(
                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
            ).to(device=torch_device)

        inputs = self.get_dummy_tensor_inputs(torch_device)
        output = loaded_quantized_model(**inputs)[0]

        output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
        self.assertTrue(
            isinstance(
                loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)
            )
        )
        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

    def test_int_a8w8_cuda(self):
        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
        device = "cuda"
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

    def test_int_a16w8_cuda(self):
        quant_method, quant_method_kwargs = "int8_weight_only", {}
        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
        device = "cuda"
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

    def test_int_a8w8_cpu(self):
        quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
        expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
        device = "cpu"
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

    def test_int_a16w8_cpu(self):
        quant_method, quant_method_kwargs = "int8_weight_only", {}
        expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
        device = "cpu"
        self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
        self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)


# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_components(self, quantization_config: TorchAoConfig):
        # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
        cache_dir = None
        model_id = "black-forest-labs/FLUX.1-dev"
        transformer = FluxTransformer2DModel.from_pretrained(
            model_id,
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir
        )
        text_encoder_2 = T5EncoderModel.from_pretrained(
            model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir
        )
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir)
        tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir)
        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir)
        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator().manual_seed(seed)

        inputs = {
            "prompt": "an astronaut riding a horse in space",
            "height": 512,
            "width": 512,
            "num_inference_steps": 20,
            "output_type": "np",
            "generator": generator,
        }

        return inputs

    def _test_quant_type(self, quantization_config, expected_slice):
        components = self.get_dummy_components(quantization_config)
        pipe = FluxPipeline(**components)
        pipe.enable_model_cpu_offload()

        weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight
        self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0].flatten()
        output_slice = np.concatenate((output[:16], output[-16:]))
        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

    def test_quantization(self):
        # fmt: off
        QUANTIZATION_TYPES_TO_TEST = [
            ("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
            ("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
        ]

        if TorchAoConfig._is_cuda_capability_atleast_8_9():
            QUANTIZATION_TYPES_TO_TEST.extend([
                ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
                ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
            ])
        # fmt: on

        for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
            quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
            self._test_quant_type(quantization_config, expected_slice)
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def test_serialization_int8wo(self):
        quantization_config = TorchAoConfig("int8wo")
        components = self.get_dummy_components(quantization_config)
        pipe = FluxPipeline(**components)
        pipe.enable_model_cpu_offload()

        weight = pipe.transformer.x_embedder.weight
        self.assertTrue(isinstance(weight, AffineQuantizedTensor))

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0].flatten()[:128]

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False)
            pipe.remove_all_hooks()
            del pipe.transformer
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            transformer = FluxTransformer2DModel.from_pretrained(
                tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
            )
            pipe.transformer = transformer
            pipe.enable_model_cpu_offload()

        weight = transformer.x_embedder.weight
        self.assertTrue(isinstance(weight, AffineQuantizedTensor))

        loaded_output = pipe(**inputs)[0].flatten()[:128]
        # Seems to require higher tolerance depending on which machine it is being run.
        # A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of
        # 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04,
        # on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here.
        self.assertTrue(np.allclose(output, loaded_output, atol=0.06))

    def test_memory_footprint_int4wo(self):
        # The original checkpoints are in bf16 and about 24 GB
        expected_memory_in_gb = 6.0
        quantization_config = TorchAoConfig("int4wo")
        cache_dir = None
        transformer = FluxTransformer2DModel.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
        self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb)

    def test_memory_footprint_int8wo(self):
        # The original checkpoints are in bf16 and about 24 GB
        expected_memory_in_gb = 12.0
        quantization_config = TorchAoConfig("int8wo")
        cache_dir = None
        transformer = FluxTransformer2DModel.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3
        self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb)


@require_torch
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_inputs(self, device: torch.device, seed: int = 0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator().manual_seed(seed)

        inputs = {
            "prompt": "an astronaut riding a horse in space",
            "height": 512,
            "width": 512,
            "num_inference_steps": 20,
            "output_type": "np",
            "generator": generator,
        }

        return inputs

    def test_transformer_int8wo(self):
        # fmt: off
        expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703])
        # fmt: on

        # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
        cache_dir = None
        transformer = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer",
            torch_dtype=torch.bfloat16,
            use_safetensors=False,
            cache_dir=cache_dir,
        )
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir
        )
        pipe.enable_model_cpu_offload()

        # Verify that all linear layer weights are quantized
        for name, module in pipe.transformer.named_modules():
            if isinstance(module, nn.Linear):
                self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))

        # Verify outputs match expected slice
        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0].flatten()
        output_slice = np.concatenate((output[:16], output[-16:]))
        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
