import torch
from diffusers import DPMSolverMultistepScheduler, LMSDiscreteScheduler, DDIMScheduler
from diffusers import StableDiffusionPipeline, StableDiffusion3Pipeline


class StableDiffusionModel(torch.nn.Module):
    """Wrapper around a diffusers ``StableDiffusionPipeline``.

    The pipeline's ``vae``, ``unet``, ``tokenizer``, ``scheduler`` and
    ``text_encoder`` are exposed as attributes, and the pipeline itself is
    kept as ``self.pipeline``. Calling the model runs the pipeline's denoising
    loop (optionally stopping early) and decodes the last latent with the VAE.
    """

    def __init__(self, pretrained_model_name_or_path, *args, scheduler=None, **kwargs):
        """Load the pipeline and optionally replace its scheduler.

        Args:
            pretrained_model_name_or_path: Identifier or path handed to
                ``StableDiffusionPipeline.from_pretrained``.
            *args: Not supported; kept only for signature compatibility.
            scheduler: Optional scheduler name, case-insensitive: ``'LMS'``,
                ``'DPM'`` or ``'DDIM'``. A falsy value keeps the scheduler
                that comes with the pipeline.
            **kwargs: Extra keyword arguments forwarded to
                ``StableDiffusionPipeline.from_pretrained``.

        Raises:
            TypeError: If extra positional arguments are given.
            NotImplementedError: If ``scheduler`` is not a supported name.
        """
        if args:
            raise TypeError(
                f"{type(self).__name__}.__init__() got unexpected positional arguments: {args!r}"
            )
        # The keyword arguments are meant for the pipeline loader only;
        # torch.nn.Module.__init__ does not accept them.
        super().__init__()

        print("Loading Stable Diffusion Model: ", pretrained_model_name_or_path)

        # Load the base pipeline components
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path, requires_safety_checker=False, **kwargs
        )
        pipeline.safety_checker = None
        pipeline.requires_safety_checker = False

        # Load the scheduler
        if scheduler:
            scheduler_kwargs = dict(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                    num_train_timesteps=1000)
            scheduler_name = scheduler.upper()

            if scheduler_name == 'LMS':
                pipeline.scheduler = LMSDiscreteScheduler(**scheduler_kwargs)
            elif scheduler_name == 'DPM':
                pipeline.scheduler = DPMSolverMultistepScheduler(**scheduler_kwargs)
            elif scheduler_name == 'DDIM':
                # NOTE(review): the DDIM config always comes from the
                # CompVis/stable-diffusion-v1-4 repo, whatever model was
                # loaded above - confirm this is intended.
                pipeline.scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
            else:
                raise NotImplementedError(f"Scheduler {scheduler} is not supported!")

        self.vae = pipeline.vae
        self.unet = pipeline.unet
        self.tokenizer = pipeline.tokenizer
        self.scheduler = pipeline.scheduler
        self.text_encoder = pipeline.text_encoder

        self.pipeline = pipeline

    def encode_prompts(self, prompts, return_uncond_prompt_embeds=True):
        """Encode prompts with the pipeline's ``encode_prompt``.

        Args:
            prompts: Prompt(s) forwarded to ``pipeline.encode_prompt``.
            return_uncond_prompt_embeds: When true, the unconditional
                embeddings are concatenated in front of the prompt embeddings.

        Returns:
            ``torch.cat([uncond_prompt_embeds, prompt_embeds])`` when
            ``return_uncond_prompt_embeds`` is true, else ``prompt_embeds``.
        """
        # Only ask for the unconditional embeddings when they are returned.
        prompt_embeds, uncond_prompt_embeds = self.pipeline.encode_prompt(
            prompts,
            device=self.pipeline._execution_device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=bool(return_uncond_prompt_embeds),
        )
        if return_uncond_prompt_embeds:
            return torch.cat([uncond_prompt_embeds, prompt_embeds])
        return prompt_embeds

    def predict_noise(self, timestep_idx, latents, text_embeddings, guidance_scale=7.5, classifier_free_guidance=True):
        """Run the UNet once and optionally apply classifier-free guidance.

        Args:
            timestep_idx: Index into ``self.scheduler.timesteps``.
            latents: Latents passed through ``scheduler.scale_model_input``
                and then to the UNet.
            text_embeddings: Passed to the UNet as ``encoder_hidden_states``.
            guidance_scale: Weight of the (text - unconditional) difference.
            classifier_free_guidance: When true, the UNet output is split in
                two along dim 0 (unconditional first, text second) and the
                halves are combined with ``guidance_scale``.

        Returns:
            The guided noise prediction, or the raw UNet output when
            ``classifier_free_guidance`` is false.
        """
        timestep = self.scheduler.timesteps[timestep_idx]
        latents = self.scheduler.scale_model_input(latents, timestep)

        # Predict the noise residual
        noise_prediction = self.unet(
            latents, timestep, encoder_hidden_states=text_embeddings
        ).sample

        # Perform guidance
        if classifier_free_guidance:
            noise_prediction_uncond, noise_prediction_text = noise_prediction.chunk(2)
            noise_prediction = noise_prediction_uncond + guidance_scale * (
                noise_prediction_text - noise_prediction_uncond)

        return noise_prediction

    def __call__(self, prompts=None, prompt_embeddings=None, image_size: int = 512, end_at_timestep_idx: int = None,
                 n_inference_steps: int = 50, guidance_scale: float = 7.5, generator=None,
                 return_latents_per_step: bool = False):
        """Run the denoising loop and decode the last latent with the VAE.

        Args:
            prompts: Passed to the pipeline as ``prompt``.
            prompt_embeddings: Passed to the pipeline as ``prompt_embeds``.
            image_size: Used for both the height and the width.
            end_at_timestep_idx: Step index at which the loop is aborted; the
                step with this index is the last one that runs. ``None`` runs
                all steps.
            n_inference_steps: Passed to the pipeline as
                ``num_inference_steps``.
            guidance_scale: Passed to the pipeline unchanged.
            generator: Passed to the pipeline unchanged.
            return_latents_per_step: When true, also return the list of
                latents captured after every executed step.

        Returns:
            The VAE-decoded sample of the last captured latent, or the tuple
            ``(image, latents_per_step)`` when ``return_latents_per_step`` is
            true.

        Raises:
            AssertionError: If ``end_at_timestep_idx`` exceeds
                ``n_inference_steps``.
            IndexError: If the pipeline finished without reporting any step.
        """
        height, width = image_size, image_size

        # Validate the end_at_timestep_idx value
        if end_at_timestep_idx is not None:
            assert end_at_timestep_idx <= n_inference_steps, (
                "end_at_timestep_idx must be less than or equal to n_inference_steps")

        # Every step's latents when requested, otherwise only the newest one.
        latents_per_step = []

        def controlled_callback(pipe, step, timestep, callback_kwargs):
            # Invoked by the pipeline after each step; ``pipe`` is the
            # pipeline object itself.
            snapshot = callback_kwargs["latents"].detach().clone()
            if return_latents_per_step or not latents_per_step:
                latents_per_step.append(snapshot)
            else:
                # Only the last latent is needed: do not hoard earlier ones.
                latents_per_step[-1] = snapshot

            # Stop when end_at_timestep_idx is reached
            if end_at_timestep_idx is not None and step >= end_at_timestep_idx:
                raise StopIteration

            return {}  # Continue the loop

        # Remember the current schedule length so it can be restored. The
        # scheduler's timesteps may still be unset (None) if it has never
        # been configured, in which case there is nothing to restore.
        timesteps_before = getattr(self.scheduler, "timesteps", None)
        n_inference_steps_before = None if timesteps_before is None else len(timesteps_before)

        # Use the pipeline's __call__ method with the custom callback
        with torch.no_grad():
            try:
                # The pipeline's return value is discarded, so request raw
                # latents and skip its own VAE decode; decoding is done once
                # below.
                _ = self.pipeline(
                    prompt=prompts,
                    prompt_embeds=prompt_embeddings,
                    height=height,
                    width=width,
                    num_inference_steps=n_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                    output_type="latent",
                    callback_on_step_end=controlled_callback,
                )
            except StopIteration:
                print("Stopped the inference at timestep: ", end_at_timestep_idx)
            finally:
                if n_inference_steps_before is not None:
                    self.scheduler.set_timesteps(n_inference_steps_before)

        if not latents_per_step:
            raise IndexError("no latents were captured: the pipeline did not report any denoising step")

        # Undo the latent scaling before decoding; fall back to the constant
        # used previously when the VAE config does not provide the factor.
        scaling_factor = getattr(getattr(self.vae, "config", None), "scaling_factor", 0.18215)
        last_latent = 1 / scaling_factor * latents_per_step[-1]
        with torch.no_grad():
            image = self.vae.decode(last_latent).sample

        if return_latents_per_step:
            # return the last image and the latents of every executed step
            return image, latents_per_step

        # Return the last image
        return image
