from typing import Any, Dict, List

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME


class PipelineCallback(ConfigMixin):
    """
    Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
    custom callbacks and ensures that all callbacks have a consistent interface.

    Please implement the following:
        `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
            include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        `callback_fn`: This method defines the core functionality of your callback.

    Args:
        cutoff_step_ratio (`float`, *optional*, defaults to `1.0`):
            A number between 0.0 and 1.0 (inclusive). The integers `0` and `1` are accepted as well; booleans are
            rejected. The base class only validates this value; subclasses read it back through `self.config`.
        cutoff_step_index (`int`, *optional*):
            Alternative to `cutoff_step_ratio`. It is not validated here beyond the "exactly one" rule below.

    Exactly one of `cutoff_step_ratio` and `cutoff_step_index` must be non-`None`. Because `cutoff_step_ratio`
    defaults to `1.0`, pass `cutoff_step_ratio=None` explicitly when you want to use `cutoff_step_index`.

    Raises:
        ValueError: If both or neither of the two cutoff arguments are set, or if `cutoff_step_ratio` is not a
            number within `[0.0, 1.0]`.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
        super().__init__()

        ratio_given = cutoff_step_ratio is not None
        index_given = cutoff_step_index is not None

        # Both flags equal means either "both set" or "none set"; each is invalid.
        if ratio_given == index_given:
            raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

        if ratio_given:
            # Accept plain ints (0 or 1) in addition to floats, but keep rejecting bools, which are an `int`
            # subclass and are almost certainly a caller mistake here.
            is_number = isinstance(cutoff_step_ratio, (int, float)) and not isinstance(cutoff_step_ratio, bool)
            if not is_number or not (0.0 <= cutoff_step_ratio <= 1.0):
                raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

    @property
    def tensor_inputs(self) -> List[str]:
        """Names of the tensor inputs the callback needs; must be provided by subclasses."""
        raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

    def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
        """Core logic of the callback; must be implemented by subclasses."""
        raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """Forward all arguments positionally to `callback_fn` and return its result."""
        return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)


class MultiPipelineCallbacks:
    """
    Bundles several pipeline callbacks behind a single callable. It is built from a list of PipelineCallback objects
    and exposes the same `tensor_inputs` / `__call__` surface as an individual callback.
    """

    def __init__(self, callbacks: List[PipelineCallback]):
        self.callbacks = callbacks

    @property
    def tensor_inputs(self) -> List[str]:
        """
        Concatenation of every wrapped callback's `tensor_inputs`, in callback order. Duplicates are kept.
        """
        collected = []
        for callback in self.callbacks:
            collected.extend(callback.tensor_inputs)
        return collected

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Runs the callbacks one after another, feeding each one the `callback_kwargs` returned by the previous one,
        and returns whatever the last callback returned (or the input unchanged when there are no callbacks).
        """
        current_kwargs = callback_kwargs
        for callback in self.callbacks:
            current_kwargs = callback(pipeline, step_index, timestep, current_kwargs)

        return current_kwargs


class SDCFGCutoffCallback(PipelineCallback):
    """
    Callback for Stable Diffusion Pipelines that disables the CFG once the cutoff step is reached. The cutoff step
    is configured through `cutoff_step_ratio` or `cutoff_step_index`.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 at the cutoff step.
    """

    tensor_inputs = ["prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        ratio = self.config.cutoff_step_ratio
        explicit_index = self.config.cutoff_step_index

        # An explicit index takes precedence; otherwise the ratio is scaled by the pipeline's number of timesteps.
        if explicit_index is not None:
            target_step = explicit_index
        else:
            target_step = int(pipeline.num_timesteps * ratio)

        # Only the cutoff step itself triggers the change; every other step passes through untouched.
        if step_index != target_step:
            return callback_kwargs

        embeds_key = self.tensor_inputs[0]
        # "-1" denotes the embeddings for conditional text tokens; the slice keeps the leading dimension.
        conditional_embeds = callback_kwargs[embeds_key][-1:]

        pipeline._guidance_scale = 0.0

        callback_kwargs[embeds_key] = conditional_embeds
        return callback_kwargs


class SDXLCFGCutoffCallback(PipelineCallback):
    """
    Callback for the base Stable Diffusion XL Pipelines that disables the CFG once the cutoff step is reached. The
    cutoff step is configured through `cutoff_step_ratio` or `cutoff_step_index`.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 at the cutoff step.
    """

    tensor_inputs = [
        "prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
    ]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        ratio = self.config.cutoff_step_ratio
        explicit_index = self.config.cutoff_step_index

        # An explicit index takes precedence; otherwise the ratio is scaled by the pipeline's number of timesteps.
        if explicit_index is not None:
            target_step = explicit_index
        else:
            target_step = int(pipeline.num_timesteps * ratio)

        # Only the cutoff step itself triggers the change; every other step passes through untouched.
        if step_index != target_step:
            return callback_kwargs

        # Slice all three inputs before touching the pipeline or `callback_kwargs`, so nothing is modified if one
        # of them is missing. "-1" denotes the conditional entry: text token embeddings (position 0), pooled text
        # embeddings (position 1) and the added time vector (position 2).
        staged = []
        for position in range(3):
            key = self.tensor_inputs[position]
            staged.append((key, callback_kwargs[key][-1:]))

        pipeline._guidance_scale = 0.0

        for key, conditional_only in staged:
            callback_kwargs[key] = conditional_only

        return callback_kwargs


class SDXLControlnetCFGCutoffCallback(PipelineCallback):
    """
    Callback for the Controlnet Stable Diffusion XL Pipelines that disables the CFG once the cutoff step is reached.
    The cutoff step is configured through `cutoff_step_ratio` or `cutoff_step_index`.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 at the cutoff step.
    """

    tensor_inputs = [
        "prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "image",
    ]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        ratio = self.config.cutoff_step_ratio
        explicit_index = self.config.cutoff_step_index

        # An explicit index takes precedence; otherwise the ratio is scaled by the pipeline's number of timesteps.
        if explicit_index is not None:
            target_step = explicit_index
        else:
            target_step = int(pipeline.num_timesteps * ratio)

        # Only the cutoff step itself triggers the change; every other step passes through untouched.
        if step_index != target_step:
            return callback_kwargs

        # Slice all four inputs before touching the pipeline or `callback_kwargs`, so nothing is modified if one
        # of them is missing. "-1" denotes the conditional entry: text token embeddings (position 0), pooled text
        # embeddings (position 1), the added time vector (position 2) and, for Controlnet, the image (position 3).
        staged = []
        for position in range(4):
            key = self.tensor_inputs[position]
            staged.append((key, callback_kwargs[key][-1:]))

        pipeline._guidance_scale = 0.0

        for key, conditional_only in staged:
            callback_kwargs[key] = conditional_only

        return callback_kwargs


class IPAdapterScaleCutoffCallback(PipelineCallback):
    """
    Callback for any pipeline that inherits `IPAdapterMixin`. Once the cutoff step is reached (configured through
    `cutoff_step_ratio` or `cutoff_step_index`), it sets the IP Adapter scale to `0.0`.

    Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 at the cutoff step.
    """

    tensor_inputs = []

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        ratio = self.config.cutoff_step_ratio
        explicit_index = self.config.cutoff_step_index

        # An explicit index takes precedence; otherwise the ratio is scaled by the pipeline's number of timesteps.
        if explicit_index is not None:
            target_step = explicit_index
        else:
            target_step = int(pipeline.num_timesteps * ratio)

        # No tensors are rewritten here; `callback_kwargs` is always handed back as received.
        if step_index != target_step:
            return callback_kwargs

        pipeline.set_ip_adapter_scale(0.0)
        return callback_kwargs


class SD3CFGCutoffCallback(PipelineCallback):
    """
    Callback for Stable Diffusion 3 Pipelines that disables the CFG once the cutoff step is reached. The cutoff step
    is configured through `cutoff_step_ratio` or `cutoff_step_index`.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 at the cutoff step.
    """

    tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        ratio = self.config.cutoff_step_ratio
        explicit_index = self.config.cutoff_step_index

        # An explicit index takes precedence; otherwise the ratio is scaled by the pipeline's number of timesteps.
        if explicit_index is not None:
            target_step = explicit_index
        else:
            target_step = int(pipeline.num_timesteps * ratio)

        # Only the cutoff step itself triggers the change; every other step passes through untouched.
        if step_index != target_step:
            return callback_kwargs

        embeds_key = self.tensor_inputs[0]
        # "-1" denotes the embeddings for conditional text tokens.
        conditional_embeds = callback_kwargs[embeds_key][-1:]

        pooled_key = self.tensor_inputs[1]
        # "-1" denotes the embeddings for conditional pooled text tokens.
        conditional_pooled = callback_kwargs[pooled_key][-1:]

        pipeline._guidance_scale = 0.0

        callback_kwargs[embeds_key] = conditional_embeds
        callback_kwargs[pooled_key] = conditional_pooled
        return callback_kwargs
