"""
Diffusion model wrappers and implementations for watermarking experiments.

This module provides wrappers around diffusion models (Stable Diffusion) using
diffusers' built-in schedulers and sampling methods.
"""

import random
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from diffusers import StableDiffusionPipeline
from torch import Tensor, nn
from tqdm import tqdm


from .vae import decode
from ..utils.utils import cleanup_cuda_memory
from ..utils.config import Config
from ..utils import utils


class DiffusionModel(nn.Module):
    """Bundle the pieces of a Stable Diffusion pipeline used for sampling.

    Holds the pipeline's UNet (moved to the configured device, with gradients
    disabled and in eval mode), its text encoder, tokenizer and scheduler.
    """

    def __init__(self, config: Config, pipe: Optional[StableDiffusionPipeline] = None):
        """
        Args:
            config: Project configuration; this constructor reads
                ``config.diffusion.stable_diffusion_model_id``,
                ``config.diffusion.torch_dtype`` and ``config.get_device()``.
            pipe: An already constructed pipeline to reuse. When ``None`` a
                pipeline is loaded with ``from_pretrained`` from the model id
                in ``config``.
        """
        super().__init__()
        self.config = config

        if pipe is not None:
            self.pipe = pipe
        else:
            diffusion_cfg = self.config.diffusion
            model_id = diffusion_cfg.stable_diffusion_model_id
            self.pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=diffusion_cfg.torch_dtype,
            )
            print('Loaded Stable Diffusion model:', model_id)

        self.device = self.config.get_device()

        # The UNet is only used for inference here: no gradients, eval mode.
        unet = self.pipe.unet.to(self.device)
        unet = unet.requires_grad_(False)
        self.unet = unet.eval()

        self.text_encoder = self.pipe.text_encoder.to(self.device)
        self.tokenizer = self.pipe.tokenizer

        # Keep whatever scheduler the pipeline ships with.
        self.scheduler = self.pipe.scheduler

        # Square root of the scheduler's cumulative alpha products, on device.
        alphas_cumprod = self.scheduler.alphas_cumprod
        self.sacc = alphas_cumprod.sqrt().to(self.device)

        self.unet.compile()

    @torch.no_grad()
    def get_conditioning(self, prompts: List[str]) -> torch.Tensor:
        """Tokenize ``prompts`` and return the text encoder's first output.

        Prompts are padded/truncated to the tokenizer's ``model_max_length``.
        Runs under ``torch.no_grad()``.

        Args:
            prompts: Text prompts to encode.

        Returns:
            The first element of the text encoder's output for the token ids.
        """
        tokenizer = self.tokenizer
        encoded = tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        token_ids = encoded.input_ids.to(self.device)
        encoder_output = self.text_encoder(token_ids)
        return encoder_output[0]


class ModelWrapper(nn.Module):
    """Run classifier-free-guided sampling with a wrapped model's UNet and scheduler.

    The wrapped ``model`` must expose ``config``, ``scheduler``, ``unet`` and
    ``get_conditioning(prompts)``; those are the only members used here.
    """

    def __init__(self, model: nn.Module, steps: Optional[int] = None, cfg_scale: Optional[float] = None):
        """
        Args:
            model: Model to wrap (see class docstring for the members used).
            steps: Number of scheduler steps; defaults to
                ``model.config.sampling.default_steps``.
            cfg_scale: Classifier-free guidance scale; defaults to
                ``model.config.sampling.cfg_scale``.
        """
        super().__init__()
        self.config = model.config
        if steps is None:
            steps = self.config.sampling.default_steps
        if cfg_scale is None:
            cfg_scale = self.config.sampling.cfg_scale
        self.model = model
        self.steps = steps
        self.cfg_scale = cfg_scale
        self.base_shape = self.config.base_latent_shape
        self.device = self.config.get_device()

        # Share the wrapped model's scheduler instance (not a copy), so
        # set_timesteps() calls made here are visible through the model too.
        self.scheduler = self.model.scheduler
        cleanup_cuda_memory()

    def predict_latent(self, positive_prompts: List[str], negative_prompts: List[str],
                       x_t: Optional[torch.Tensor] = None,
                       x_ts_ret: bool = False) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Denoise latents with classifier-free guidance.

        Args:
            positive_prompts: Conditional prompts, one per sample.
            negative_prompts: Unconditional/negative prompts, one per sample.
            x_t: Optional starting noise whose first dimension equals the
                number of prompts. When ``None``, noise is drawn with
                ``torch.randn``.
            x_ts_ret: When True, return the list of latents produced after
                every scheduler step instead of only the final latents.

        Returns:
            The final latents, or the per-step list of latents when
            ``x_ts_ret`` is True.

        Raises:
            ValueError: If the prompt lists differ in length, or ``x_t`` has a
                batch size different from the number of prompts.
        """
        batch_size = len(positive_prompts)
        if len(negative_prompts) != batch_size:
            raise ValueError(
                f"positive_prompts ({batch_size}) and negative_prompts "
                f"({len(negative_prompts)}) must have the same length"
            )
        if x_t is not None and x_t.shape[0] != batch_size:
            raise ValueError(
                f"x_t has batch size {x_t.shape[0]} but {batch_size} prompts were given"
            )

        # Text embeddings for both guidance branches.
        positive_embeds = self.model.get_conditioning(positive_prompts)
        negative_embeds = self.model.get_conditioning(negative_prompts)

        if x_t is None:
            # NOTE(review): the latent shape is hard-coded; self.base_shape is
            # stored in __init__ but not used here. Confirm whether it should
            # drive this shape.
            latents = torch.randn(
                (batch_size, 4, 64, 64),
                device=self.device,
                dtype=positive_embeds.dtype,
            )
        else:
            latents = x_t.to(device=self.device, dtype=positive_embeds.dtype)

        # Configure the schedule, then scale the starting noise as the
        # scheduler expects.
        self.scheduler.set_timesteps(self.steps, device=self.device)
        latents = latents * self.scheduler.init_noise_sigma

        # Loop-invariant: unconditional embeddings first, then conditional,
        # matching the order of the chunk(2) split below.
        prompt_embeds = torch.cat([negative_embeds, positive_embeds])

        x_ts: List[torch.Tensor] = []
        for timestep in self.scheduler.timesteps:
            # Duplicate the latents so one UNet call covers both branches.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)

            noise_pred = self.model.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=prompt_embeds,
                return_dict=False,
            )[0]

            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
            # Only keep intermediates when the caller asked for them.
            if x_ts_ret:
                x_ts.append(latents)

        if x_ts_ret:
            return x_ts
        return latents

    def predict_image(self,
                      positive_prompts: List[str],
                      negative_prompts: List[str],
                      **kwargs) -> Any:
        """
        Predict final decoded images.

        Args:
            positive_prompts: List of positive text prompts
            negative_prompts: List of negative text prompts
            **kwargs: Additional arguments for predict_latent (``x_ts_ret``
                must not be truthy: a list of latents cannot be decoded here)

        Returns:
            Decoded images

        Raises:
            ValueError: If ``x_ts_ret`` is requested.
        """
        # Fail before sampling rather than after: decoding needs one tensor.
        if kwargs.get("x_ts_ret"):
            raise ValueError("predict_image does not support x_ts_ret=True; use predict_latent instead")
        latents = self.predict_latent(positive_prompts, negative_prompts, **kwargs)
        # The latents are cast to float16 before being handed to decode().
        return decode(latents.to(torch.float16))

    def generate_full(self, batch_size: int, noise: Optional[torch.Tensor] = None,
                      x_ts_ret: bool = False, prompts: Optional[List[str]] = None,
                      positive_prompts: Optional[List[str]] = None,
                      negative_prompts: Optional[List[str]] = None, ret_prompts: bool = False):
        """Sample latents, filling in prompts that were not supplied.

        Args:
            batch_size: Number of prompts to draw from ``prompts``; only used
                when ``positive_prompts`` is missing or empty.
            noise: Optional starting noise forwarded to ``predict_latent``.
            x_ts_ret: Forwarded to ``predict_latent``.
            prompts: Pool to sample positive prompts from (with replacement,
                via the ``random`` module) when ``positive_prompts`` is
                missing or empty.
            positive_prompts: Explicit positive prompts.
            negative_prompts: Explicit negative prompts; defaults to one empty
                string per positive prompt.
            ret_prompts: When True, also return the prompts that were used.

        Returns:
            The result of ``predict_latent``, or the tuple
            ``(latents, positive_prompts, negative_prompts)`` when
            ``ret_prompts`` is True.

        Raises:
            ValueError: If neither ``positive_prompts`` nor a non-empty
                ``prompts`` pool is given.
        """
        if not positive_prompts:
            if prompts is None or len(prompts) == 0:
                raise ValueError(
                    "generate_full needs either positive_prompts or a non-empty prompts pool"
                )
            positive_prompts = [random.choice(prompts) for _ in range(batch_size)]

        if not negative_prompts:
            negative_prompts = [''] * len(positive_prompts)

        latents = self.predict_latent(positive_prompts, negative_prompts, noise, x_ts_ret)
        if ret_prompts:
            return latents, positive_prompts, negative_prompts
        return latents
