import gc
import tempfile
from io import BytesIO

import requests
import torch
from huggingface_hub import hf_hub_download, snapshot_download

from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.models.attention_processor import AttnProcessor

from ..testing_utils import (
    backend_empty_cache,
    nightly,
    numpy_cosine_similarity_distance,
    require_torch_accelerator,
    torch_device,
)


def download_single_file_checkpoint(repo_id, filename, tmpdir):
    """Fetch one checkpoint file from a Hub repository into a local directory.

    Args:
        repo_id: Hub repository identifier to download from.
        filename: Name of the file inside the repository.
        tmpdir: Directory passed to ``hf_hub_download`` as ``local_dir``.

    Returns:
        The local file path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id, filename=filename, local_dir=tmpdir)


def download_original_config(config_url, tmpdir, timeout=60):
    """Download an original (non-diffusers) YAML config and store it as ``<tmpdir>/config.yaml``.

    Args:
        config_url: URL of the raw config file.
        tmpdir: Directory in which ``config.yaml`` is written.
        timeout: Seconds to wait for the server before giving up. Without a timeout a stalled
            connection would hang the test run indefinitely.

    Returns:
        The path of the written ``config.yaml`` file.

    Raises:
        requests.HTTPError: If the server answers with an error status. Without this check an
            HTML error page would be silently saved as the config file.
    """
    response = requests.get(config_url, timeout=timeout)
    response.raise_for_status()

    path = f"{tmpdir}/config.yaml"
    with open(path, "wb") as f:
        f.write(response.content)

    return path


def download_diffusers_config(repo_id, tmpdir):
    """Snapshot only the config/text files of a Hub repository into ``tmpdir``.

    Only ``*.json`` and ``*.txt`` files are allowed, and common weight-file extensions are
    additionally listed in ``ignore_patterns``. This way no model weights are downloaded.

    Args:
        repo_id: Hub repository identifier.
        tmpdir: Directory passed to ``snapshot_download`` as ``local_dir``.

    Returns:
        The local snapshot path returned by ``snapshot_download``.
    """
    weight_patterns = [
        "**/*.ckpt",
        "*.ckpt",
        "**/*.bin",
        "*.bin",
        "**/*.pt",
        "*.pt",
        "**/*.safetensors",
        "*.safetensors",
    ]
    config_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]

    return snapshot_download(
        repo_id,
        ignore_patterns=weight_patterns,
        allow_patterns=config_patterns,
        local_dir=tmpdir,
    )


@nightly
@require_torch_accelerator
class SingleFileModelTesterMixin:
    """Shared tests comparing ``from_pretrained`` and ``from_single_file`` loading of one model class.

    Subclasses provide ``model_class``, ``repo_id`` and ``ckpt_path``. They may optionally set
    ``subfolder`` (used for ``from_pretrained`` only), ``torch_dtype`` (used for both loaders) and
    ``alternate_keys_ckpt_paths``.
    """

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def _get_loading_kwargs(self):
        """Build the keyword arguments for both loaders from the optional class attributes.

        Returns:
            ``(pretrained_kwargs, single_file_kwargs)``. ``subfolder`` is added only to the
            pretrained kwargs; ``torch_dtype`` is added to both. Missing or falsy attributes
            are omitted.
        """
        pretrained_kwargs = {}
        single_file_kwargs = {}

        subfolder = getattr(self, "subfolder", None)
        if subfolder:
            pretrained_kwargs["subfolder"] = subfolder

        torch_dtype = getattr(self, "torch_dtype", None)
        if torch_dtype:
            pretrained_kwargs["torch_dtype"] = torch_dtype
            single_file_kwargs["torch_dtype"] = torch_dtype

        return pretrained_kwargs, single_file_kwargs

    def test_single_file_model_config(self):
        pretrained_kwargs, single_file_kwargs = self._get_loading_kwargs()

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            # Report a missing key as an assertion failure instead of a bare KeyError.
            assert param_name in model.config, f"{param_name} is missing from the pretrained config"
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between pretrained loading and single file loading"
            )

    def test_single_file_model_parameters(self):
        pretrained_kwargs, single_file_kwargs = self._get_loading_kwargs()

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        state_dict = model.state_dict()
        state_dict_single_file = model_single_file.state_dict()

        assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
            "Model parameters keys differ between pretrained and single file loading"
        )

        for key, param in state_dict.items():
            param_single_file = state_dict_single_file[key]

            assert param.shape == param_single_file.shape, (
                f"Parameter shape mismatch for {key}: "
                f"pretrained {param.shape} vs single file {param_single_file.shape}"
            )

            # The max-difference message is only evaluated when the assertion fails.
            assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
                f"Parameter values differ for {key}: "
                f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
            )

    def test_checkpoint_altered_keys_loading(self):
        # Test loading with checkpoints that have altered keys; this only checks that loading succeeds.
        alternate_ckpt_paths = getattr(self, "alternate_keys_ckpt_paths", None)
        if not alternate_ckpt_paths:
            return

        # The kwargs do not depend on the checkpoint, so build them once.
        _, single_file_kwargs = self._get_loading_kwargs()

        for ckpt_path in alternate_ckpt_paths:
            backend_empty_cache(torch_device)

            model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)

            # Free each model before loading the next checkpoint.
            del model
            gc.collect()
            backend_empty_cache(torch_device)


class SDSingleFileTesterMixin:
    """Shared single-file loading tests for Stable Diffusion pipelines.

    Subclasses provide ``pipeline_class``, ``ckpt_path``, ``repo_id``, ``original_config`` and
    ``get_inputs``. ``single_file_kwargs`` may be overridden to pass extra arguments to
    ``from_single_file`` in the inference comparison test.

    Most tests accept optional preloaded ``pipe`` / ``single_file_pipe`` arguments so subclasses
    can reuse the checks with custom-loaded pipelines.
    """

    # Extra kwargs for ``from_single_file`` in the inference test. This class only reads it;
    # subclasses override it rather than mutate it.
    single_file_kwargs = {}

    def _compare_component_configs(self, pipe, single_file_pipe):
        """Assert that the components of ``single_file_pipe`` match those of ``pipe``.

        The text encoder configs are compared via ``config.to_dict()``. Every other non-optional
        component is checked for matching class and config values. Tokenizer, safety checker and
        feature extractor are skipped. ``pipe`` is not modified.
        """
        pretrained_text_encoder_config = pipe.text_encoder.config.to_dict()
        for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
            if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
                continue
            assert pretrained_text_encoder_config[param_name] == param_value

        PARAMS_TO_IGNORE = [
            "torch_dtype",
            "_name_or_path",
            "architectures",
            "_use_default_values",
            "_diffusers_version",
        ]
        # Read the components mapping once instead of on every access.
        pretrained_components = pipe.components
        for component_name, component in single_file_pipe.components.items():
            if component_name in single_file_pipe._optional_components:
                continue

            # skip testing transformer based components here
            # skip text encoders / safety checkers since they have already been tested
            if component_name in ["text_encoder", "tokenizer", "safety_checker", "feature_extractor"]:
                continue

            assert component_name in pretrained_components, f"single file {component_name} not found in pretrained pipeline"
            pretrained_component = pretrained_components[component_name]
            assert isinstance(component, pretrained_component.__class__), (
                f"single file {component.__class__.__name__} and pretrained {pretrained_component.__class__.__name__} are not the same"
            )

            pretrained_config = pretrained_component.config
            for param_name, param_value in component.config.items():
                if param_name in PARAMS_TO_IGNORE:
                    continue

                assert param_name in pretrained_config, (
                    f"single file {param_name} not found in pretrained {component_name} config"
                )
                expected_value = pretrained_config[param_name]

                # Some pretrained configs will set upcast attention to None.
                # In single file loading it defaults to the value in the class __init__, which is False.
                # Treat None as matching, using a local value so the pretrained config is not mutated.
                if param_name == "upcast_attention" and expected_value is None:
                    expected_value = param_value

                assert expected_value == param_value, (
                    f"single file {param_name}: {param_value} differs from pretrained {expected_value}"
                )

    def test_single_file_components(self, pipe=None, single_file_pipe=None):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_local_files_only(self, pipe=None, single_file_pipe=None):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention

        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path,
            original_config=self.original_config,
            safety_checker=None,
            upcast_attention=upcast_attention,
        )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_original_config = download_original_config(self.original_config, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path,
                original_config=local_original_config,
                safety_checker=None,
                upcast_attention=upcast_attention,
                local_files_only=True,
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
        sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None, **self.single_file_kwargs)
        sf_pipe.unet.set_attn_processor(AttnProcessor())
        sf_pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image_single_file = sf_pipe(**inputs).images[0]

        # Release the single-file pipeline before loading the second one to keep peak memory down.
        del sf_pipe
        gc.collect()
        backend_empty_cache(torch_device)

        pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        pipe.unet.set_attn_processor(AttnProcessor())
        pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images[0]

        # Cosine-similarity distance between the two flattened images.
        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < expected_max_diff, f"{image.flatten()} != {image_single_file.flatten()}"

    def test_single_file_components_with_diffusers_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, config=self.repo_id, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_setting_pipeline_dtype_to_fp16(
        self,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, torch_dtype=torch.float16
        )

        # Every torch module component must have been cast to fp16.
        for component in single_file_pipe.components.values():
            if not isinstance(component, torch.nn.Module):
                continue

            assert component.dtype == torch.float16


class SDXLSingleFileTesterMixin:
    """Shared single-file loading tests for Stable Diffusion XL pipelines, including refiners.

    Subclasses provide ``pipeline_class``, ``ckpt_path``, ``repo_id``, ``original_config`` and
    ``get_inputs``. Most tests accept optional preloaded ``pipe`` / ``single_file_pipe``
    arguments so subclasses can reuse the checks with custom-loaded pipelines.
    """

    def _compare_component_configs(self, pipe, single_file_pipe):
        """Assert that the components of ``single_file_pipe`` match those of ``pipe``.

        Both text encoder configs are compared via ``config.to_dict()``. The first text encoder
        is skipped when ``pipe.text_encoder`` is None, e.g. for refiner pipelines. Every other
        non-optional component is checked for matching class and config values. Tokenizers,
        safety checker and feature extractor are skipped. ``pipe`` is not modified.
        """
        text_encoder_keys_to_ignore = ["torch_dtype", "architectures", "_name_or_path"]

        # Skip testing the text_encoder for Refiner Pipelines
        if pipe.text_encoder is not None:
            pretrained_text_encoder_config = pipe.text_encoder.config.to_dict()
            for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
                if param_name in text_encoder_keys_to_ignore:
                    continue
                assert pretrained_text_encoder_config[param_name] == param_value

        pretrained_text_encoder_2_config = pipe.text_encoder_2.config.to_dict()
        for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
            if param_name in text_encoder_keys_to_ignore:
                continue
            assert pretrained_text_encoder_2_config[param_name] == param_value

        PARAMS_TO_IGNORE = [
            "torch_dtype",
            "_name_or_path",
            "architectures",
            "_use_default_values",
            "_diffusers_version",
        ]
        # Read the components mapping once instead of on every access.
        pretrained_components = pipe.components
        for component_name, component in single_file_pipe.components.items():
            if component_name in single_file_pipe._optional_components:
                continue

            # skip text encoders since they have already been tested
            if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
                continue

            # skip safety checker if it is not present in the pipeline
            if component_name in ["safety_checker", "feature_extractor"]:
                continue

            assert component_name in pretrained_components, f"single file {component_name} not found in pretrained pipeline"
            pretrained_component = pretrained_components[component_name]
            assert isinstance(component, pretrained_component.__class__), (
                f"single file {component.__class__.__name__} and pretrained {pretrained_component.__class__.__name__} are not the same"
            )

            pretrained_config = pretrained_component.config
            for param_name, param_value in component.config.items():
                if param_name in PARAMS_TO_IGNORE:
                    continue

                assert param_name in pretrained_config, (
                    f"single file {param_name} not found in pretrained {component_name} config"
                )
                expected_value = pretrained_config[param_name]

                # Some pretrained configs will set upcast attention to None.
                # In single file loading it defaults to the value in the class __init__, which is False.
                # Treat None as matching, using a local value so the pretrained config is not mutated.
                if param_name == "upcast_attention" and expected_value is None:
                    expected_value = param_value

                assert expected_value == param_value, (
                    f"single file {param_name}: {param_value} differs from pretrained {expected_value}"
                )

    def test_single_file_components(self, pipe=None, single_file_pipe=None):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path,
            original_config=self.original_config,
            safety_checker=None,
            upcast_attention=upcast_attention,
        )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_original_config = download_original_config(self.original_config, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path,
                original_config=local_original_config,
                upcast_attention=upcast_attention,
                safety_checker=None,
                local_files_only=True,
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, config=self.repo_id, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
        sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16, safety_checker=None)
        sf_pipe.unet.set_default_attn_processor()
        sf_pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image_single_file = sf_pipe(**inputs).images[0]

        # Release the single-file pipeline before loading the second one to keep peak memory down.
        del sf_pipe
        gc.collect()
        backend_empty_cache(torch_device)

        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16, safety_checker=None)
        pipe.unet.set_default_attn_processor()
        pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images[0]

        # Cosine-similarity distance between the two flattened images.
        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < expected_max_diff

    def test_single_file_setting_pipeline_dtype_to_fp16(
        self,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, torch_dtype=torch.float16
        )

        # Every torch module component must have been cast to fp16.
        for component in single_file_pipe.components.values():
            if not isinstance(component, torch.nn.Module):
                continue

            assert component.dtype == torch.float16
