from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import randn_tensor
import torch
import copy
from typing import Optional, Union, Tuple, List, Callable, Dict, Any
from diffusers import (
    ControlNetModel,
)
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.controlnet import ControlNetOutput
import copy
from diffusers.models.controlnet import ControlNetConditioningEmbedding
from diffusers import StableDiffusionControlNetPipeline, StableDiffusionControlNetInpaintPipeline
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from numpy import ndarray
from torch import FloatTensor, Generator
import numpy as np
import random
from PIL import Image
from tqdm import tqdm
import copy

class randomTailor():
    """Cut a large latent into overlapping square tiles and blend per-tile results back.

    The last dimension of the latent is split into ``width // 64`` tiles per side
    (at least one).  The origin of every tile is always sampled.  Tiles that are not
    in the last row or column additionally get ``sample_per_tile - 1`` randomly
    shifted copies; tiles in the last row/column get half-tile shifted copies along
    the direction that still has room.  ``tail_results`` averages all tile outputs
    with a centre-weighted mask.

    NOTE(review): only the last dimension of ``init_large`` is used to lay out the
    tiles, so the latent is assumed to be square with a side that is a multiple of
    the tile count -- confirm against callers.  Positions that no tile covers come
    out of ``tail_results`` as NaN (0 / 0).
    """

    def __init__(self, init_large, sample_per_tile=4, min_shift_s=None, device=None) -> None:
        """Build the tile layout for ``init_large``.

        Args:
            init_large: Large latent tensor; ``shape[-1]`` defines the tile grid.
            sample_per_tile: Number of shift positions per axis inside a tile; also
                ``sample_per_tile - 1`` extra shifted tiles are drawn per tile.
            min_shift_s: Shift step in latent pixels.  Defaults to
                ``tile size / sample_per_tile``.
            device: ``'cpu'`` (string or ``torch.device``) keeps everything on the
                CPU; any other value, including ``None``, selects ``'cuda'``.
        """
        # str() makes torch.device('cpu') behave like the plain string 'cpu'.
        self.device = 'cpu' if str(device) == 'cpu' else 'cuda'
        self.lls = list(init_large.shape[:])
        self.large_latent_width = init_large.shape[-1]
        # A latent narrower than 64 is treated as one tile instead of dividing by zero.
        self.num_tile_w = max(1, self.large_latent_width // 64)
        self.sample_per_tile = sample_per_tile
        self.tile_hw = self.large_latent_width // self.num_tile_w
        if min_shift_s is None:
            self.min_shift = self.large_latent_width / self.num_tile_w / sample_per_tile
        else:
            self.min_shift = min_shift_s
        self.weight_mask = self.init_weight_mask()
        self.sample_points = []
        # All (a, b) shift multipliers in [0, sample_per_tile), first component varying
        # fastest; the leading [0, 0] entry (no shift) is dropped because the unshifted
        # tile is always added separately.
        self.tile_all_sample_points = [
            [col, row] for row in range(sample_per_tile) for col in range(sample_per_tile)
        ][1:]
        self.chosen_samples = self.generate_sample_points()
        self.chosen_samples_control = self.generate_control_tiles_cor()
        self.large_latent = init_large

    def init_weight_mask(self):
        """Return a ``(tile_hw, tile_hw)`` blending mask scaled to ``[0.5, 1]``.

        The raw weight at ``[j, i]`` is ``min(|i'|, |j'|)`` where ``i'``/``j'`` run over
        ``range(-tile_hw // 2, tile_hw // 2)`` and negative values index from the end,
        which puts the largest weight at the tile centre.
        """
        axis_weight = torch.zeros(self.tile_hw)
        for offset in range(-self.tile_hw // 2, self.tile_hw // 2):
            # Negative offsets wrap around to the second half of the axis.
            axis_weight[offset] = abs(offset)
        msk_tensor = torch.min(axis_weight.unsqueeze(0), axis_weight.unsqueeze(1))
        lowest, highest = torch.min(msk_tensor), torch.max(msk_tensor)
        msk_tensor = (1 - 0.5) * (msk_tensor - lowest) / (highest - lowest) + 0.5
        return msk_tensor.to(torch.device(self.device))

    def tail_results(self, samples, out_mask=False):
        """Blend per-tile tensors into one large tensor using the weight mask.

        Args:
            samples: One tensor per entry of ``self.chosen_samples`` (same order),
                each with the spatial size of a tile.
            out_mask: When true, also return a mask that is 1 where no tile
                contributed and 0 elsewhere.

        Returns:
            The weighted average as a bfloat16 tensor on ``self.device``; with
            ``out_mask`` a ``(result, mask)`` tuple of such tensors.
        """
        result_large = torch.zeros(self.lls)
        result_large_weight = torch.zeros(self.lls)
        # Broadcast the 2-D mask over the batch/channel dims of a tile.
        weight_mask = torch.ones_like(samples[0]) * self.weight_mask
        tile_weight = weight_mask.to(device=result_large_weight.device, dtype=result_large_weight.dtype)
        sample_w = samples[0].shape[-1]
        for (i, j), s in zip(self.chosen_samples, samples):
            rows, cols = slice(i, i + sample_w), slice(j, j + sample_w)
            # Accumulate directly into the covered window; the accumulators live on
            # the CPU, so the tile contribution is moved there first.
            result_large_weight[:, :, rows, cols] += tile_weight
            result_large[:, :, rows, cols] += (s * weight_mask).to(
                device=result_large.device, dtype=result_large.dtype
            )
        result_large = torch.div(result_large, result_large_weight)
        if out_mask:
            out_weight_mask = result_large_weight
            out_weight_mask[out_weight_mask > 0] = 1
            out_weight_mask = 1 - out_weight_mask
            return (
                result_large.to(self.device).type(torch.bfloat16),
                out_weight_mask.to(self.device).type(torch.bfloat16),
            )
        return result_large.to(self.device).type(torch.bfloat16)

    def generate_sample_points(self):
        """Return the list of ``(row, col)`` top-left corners of all tiles to process."""
        grids = []
        last = self.num_tile_w - 1
        # Never ask for more random shifts than there are shift candidates.
        num_random = min(max(self.sample_per_tile - 1, 0), len(self.tile_all_sample_points))
        for ti in range(self.num_tile_w):
            for tj in range(self.num_tile_w):
                top, left = ti * self.tile_hw, tj * self.tile_hw
                grids.append((top, left))
                if ti == last or tj == last:
                    # Border tiles can only slide along the axis that still has room.
                    if ti != last:
                        grids.append((int(top + 1 / 2 * self.tile_hw), left))
                    if tj != last:
                        grids.append((top, int(left + 1 / 2 * self.tile_hw)))
                else:
                    # Draw distinct random shifts (in units of min_shift) for this tile.
                    shifts = random.sample(self.tile_all_sample_points, num_random)
                    grids += [
                        (int(top + rr[0] * self.min_shift), int(left + rr[1] * self.min_shift))
                        for rr in shifts
                    ]
        return grids

    def generate_control_tiles_cor(self, control_img_width=None):
        """Scale ``self.chosen_samples`` to control-image coordinates.

        Args:
            control_img_width: Width of the control image; defaults to 8 times the
                latent width.

        Returns:
            Tuple of ``(row, col)`` integer pairs, one per chosen sample.
        """
        if control_img_width is None:
            control_img_width = self.large_latent_width * 8
        return tuple(
            (int(ii * control_img_width / self.large_latent_width),
             int(jj * control_img_width / self.large_latent_width))
            for ii, jj in self.chosen_samples
        )
    

class PWTT(StableDiffusionControlNetPipeline):
    """ControlNet pipeline that denoises one large latent tile by tile with two conditions.

    At every timestep the large latent is cut into the overlapping tiles chosen by
    ``randomTailor``; each tile is denoised with its own scheduler copy, conditioned
    on the matching crops of ``image`` and ``image_H``, and the tile results are
    blended back into the large latent.
    """

    def __init__(self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, controlnet, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True):
        # The flag is forwarded by keyword so it cannot bind to a different positional
        # slot should the parent signature carry extra optional parameters before it.
        super().__init__(vae, text_encoder, tokenizer, unet, controlnet, scheduler, safety_checker, feature_extractor, requires_safety_checker=requires_safety_checker)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image = None,
        image_H = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        large_latent_height:int = None,
        large_latent_width:int = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt: Prompt or list of prompts; determines the batch size unless
                ``prompt_embeds`` is given instead.
            image: First conditioning image. It is sliced as ``image[:, :, r:r+h, c:c+w]``
                per tile, so it must be indexable with four dimensions (presumably a
                ``(B, C, H, W)`` tensor at 8x the latent resolution -- confirm with callers).
            image_H: Second conditioning image, cropped the same way and passed to the
                controlnet as ``controlnet_cond2``.
            height, width: Forwarded to ``prepare_image`` for the first tile; afterwards
                they are overwritten with the prepared tile's spatial size.
            large_latent_height, large_latent_width: Forwarded as height/width to
                ``prepare_latents`` to create the large latent.
            latents: Optional initial large latent forwarded to ``prepare_latents``.
            return_dict: Accepted for interface compatibility; NOTE(review): it is not
                consulted, a ``StableDiffusionPipelineOutput`` is always returned.
            callback, callback_steps: ``callback(i, t, large_latents)`` is invoked on
                progress updates where ``i % callback_steps == 0``.
            Remaining arguments are used as in the denoising code below (prompt
            encoding, classifier-free guidance, controlnet scaling/scheduling).

        Returns:
            ``StableDiffusionPipelineOutput`` with the decoded images and
            ``nsfw_content_detected=None`` (no safety check is run here).

        Note:
            No input validation (``check_inputs``) is performed.
        """
        controlnet = self.controlnet
        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
                control_guidance_end
            ]

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        large_latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            large_latent_height,
            large_latent_width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Place the tiler on the pipeline's execution device type instead of assuming CUDA.
        tailor = randomTailor(large_latents, device=torch.device(device).type)
        # One scheduler copy per tile so multistep solvers (e.g. UniPC, DPM-Solver)
        # keep separate per-tile history.
        schedulers = [copy.deepcopy(self.scheduler) for _ in tailor.chosen_samples]

        # 7. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                all_tile_latents = []
                pbar = tqdm(range(len(tailor.chosen_samples)))
                for (ii, jj), (ii_c, jj_c), sch, pi, in zip(tailor.chosen_samples, tailor.chosen_samples_control, schedulers, pbar, ):
                    pbar.set_description('Processing {}/{}'.format(pi+1, len(tailor.chosen_samples)))
                    if isinstance(controlnet, ControlNetModel):
                        # Crop both conditioning images to the pixel window of this tile
                        # (tile size in latent units times 8).
                        image_ROI_p = self.prepare_image(
                            image=image[:, :, ii_c:ii_c+tailor.tile_hw*8, jj_c:jj_c+tailor.tile_hw*8],
                            width=width,
                            height=height,
                            batch_size=batch_size * num_images_per_prompt,
                            num_images_per_prompt=num_images_per_prompt,
                            device=device,
                            dtype=controlnet.dtype,
                            do_classifier_free_guidance=do_classifier_free_guidance,
                            guess_mode=guess_mode,
                        )  # prepare cond1
                        image_H_p = self.prepare_image(
                            image=image_H[:, :, ii_c:ii_c+tailor.tile_hw*8, jj_c:jj_c+tailor.tile_hw*8],
                            width=width,
                            height=height,
                            batch_size=batch_size * num_images_per_prompt,
                            num_images_per_prompt=num_images_per_prompt,
                            device=device,
                            dtype=controlnet.dtype,
                            do_classifier_free_guidance=do_classifier_free_guidance,
                            guess_mode=guess_mode,
                        )  # prepare cond2
                        height, width = image_ROI_p.shape[-2:]
                    else:
                        # Explicit raise so the check also fires when asserts are disabled (-O).
                        raise AssertionError('re-check conditioning input')

                    # expand the latents if we are doing classifier free guidance
                    latents = large_latents[:, :, ii:ii+tailor.tile_hw, jj:jj+tailor.tile_hw]
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = sch.scale_model_input(latent_model_input, t)

                    # controlnet(s) inference
                    if guess_mode and do_classifier_free_guidance:
                        # Infer ControlNet only for the conditional batch.
                        control_model_input = latents
                        control_model_input = sch.scale_model_input(control_model_input, t)
                        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                    else:
                        control_model_input = latent_model_input
                        controlnet_prompt_embeds = prompt_embeds

                    if isinstance(controlnet_keep[i], list):
                        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                    else:
                        controlnet_cond_scale = controlnet_conditioning_scale
                        if isinstance(controlnet_cond_scale, list):
                            controlnet_cond_scale = controlnet_cond_scale[0]
                        cond_scale = controlnet_cond_scale * controlnet_keep[i]

                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        control_model_input,
                        t,
                        encoder_hidden_states=controlnet_prompt_embeds,
                        controlnet_cond=image_ROI_p,
                        controlnet_cond2=image_H_p,
                        conditioning_scale=cond_scale,
                        guess_mode=guess_mode,
                        return_dict=False,
                    )

                    if guess_mode and do_classifier_free_guidance:
                        # Inferred ControlNet only for the conditional batch.
                        # To apply the output of ControlNet to both the unconditional and conditional batches,
                        # add 0 to the unconditional batch to keep it unchanged.
                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                        mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    xx = sch.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                    all_tile_latents.append(xx)

                # Blend all denoised tiles back into the large latent for the next timestep.
                large_latents = tailor.tail_results(all_tile_latents)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, large_latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        image = self.vae.decode(large_latents / self.vae.config.scaling_factor, return_dict=False)[0]
        do_denormalize = [True] * image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)

class RandomSlideUpscaleX4Pipeline(StableDiffusionControlNetPipeline):
    """ControlNet pipeline that denoises one large latent tile by tile.

    At every timestep the large latent is cut into the overlapping tiles chosen by
    ``randomTailor``; each tile is denoised with its own scheduler copy, conditioned
    on the matching crop of ``image``, and the tile results are blended back into
    the large latent.
    """

    def __init__(self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, controlnet, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True):
        # The flag is forwarded by keyword so it cannot bind to a different positional
        # slot should the parent signature carry extra optional parameters before it.
        super().__init__(vae, text_encoder, tokenizer, unet, controlnet, scheduler, safety_checker, feature_extractor, requires_safety_checker=requires_safety_checker)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        large_latent_height:int = None,
        large_latent_width:int = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt: Prompt or list of prompts; determines the batch size unless
                ``prompt_embeds`` is given instead.
            image: Conditioning image. It is sliced as ``image[:, :, r:r+h, c:c+w]`` per
                tile, so it must be indexable with four dimensions (presumably a
                ``(B, C, H, W)`` tensor at 8x the latent resolution -- confirm with callers).
            height, width: Forwarded to ``prepare_image`` for the first tile; afterwards
                they are overwritten with the prepared tile's spatial size.
            large_latent_height, large_latent_width: Forwarded as height/width to
                ``prepare_latents`` to create the large latent.
            latents: Optional initial large latent forwarded to ``prepare_latents``.
            return_dict: Accepted for interface compatibility; NOTE(review): it is not
                consulted, a ``StableDiffusionPipelineOutput`` is always returned.
            callback, callback_steps: ``callback(i, t, large_latents)`` is invoked on
                progress updates where ``i % callback_steps == 0``.
            Remaining arguments are used as in the denoising code below (prompt
            encoding, classifier-free guidance, controlnet scaling/scheduling).

        Returns:
            ``StableDiffusionPipelineOutput`` with the decoded images and
            ``nsfw_content_detected=None`` (no safety check is run here).

        Note:
            No input validation (``check_inputs``) is performed.
        """
        controlnet = self.controlnet
        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
                control_guidance_end
            ]

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        large_latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            large_latent_height,
            large_latent_width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Place the tiler on the pipeline's execution device type instead of assuming CUDA.
        tailor = randomTailor(large_latents, device=torch.device(device).type)
        # One scheduler copy per tile so multistep solvers (e.g. UniPC, DPM-Solver)
        # keep separate per-tile history.
        schedulers = [copy.deepcopy(self.scheduler) for _ in tailor.chosen_samples]

        # 7. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                all_tile_latents = []
                pbar = tqdm(range(len(tailor.chosen_samples)))
                for (ii, jj), (ii_c, jj_c), sch, pi, in zip(tailor.chosen_samples, tailor.chosen_samples_control, schedulers, pbar, ):
                    pbar.set_description('Processing {}/{}'.format(pi+1, len(tailor.chosen_samples)))
                    if isinstance(controlnet, ControlNetModel):
                        # Crop the conditioning image to the pixel window of this tile
                        # (tile size in latent units times 8).
                        image_ROI_p = self.prepare_image(
                            image=image[:, :, ii_c:ii_c+tailor.tile_hw*8, jj_c:jj_c+tailor.tile_hw*8],
                            width=width,
                            height=height,
                            batch_size=batch_size * num_images_per_prompt,
                            num_images_per_prompt=num_images_per_prompt,
                            device=device,
                            dtype=controlnet.dtype,
                            do_classifier_free_guidance=do_classifier_free_guidance,
                            guess_mode=guess_mode,
                        )  # prepare cond1
                        height, width = image_ROI_p.shape[-2:]
                    else:
                        # Explicit raise so the check also fires when asserts are disabled (-O).
                        raise AssertionError('re-check conditioning input')

                    # expand the latents if we are doing classifier free guidance
                    latents = large_latents[:, :, ii:ii+tailor.tile_hw, jj:jj+tailor.tile_hw]
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = sch.scale_model_input(latent_model_input, t)

                    # controlnet(s) inference
                    if guess_mode and do_classifier_free_guidance:
                        # Infer ControlNet only for the conditional batch.
                        control_model_input = latents
                        control_model_input = sch.scale_model_input(control_model_input, t)
                        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                    else:
                        control_model_input = latent_model_input
                        controlnet_prompt_embeds = prompt_embeds

                    if isinstance(controlnet_keep[i], list):
                        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                    else:
                        controlnet_cond_scale = controlnet_conditioning_scale
                        if isinstance(controlnet_cond_scale, list):
                            controlnet_cond_scale = controlnet_cond_scale[0]
                        cond_scale = controlnet_cond_scale * controlnet_keep[i]

                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        control_model_input,
                        t,
                        encoder_hidden_states=controlnet_prompt_embeds,
                        controlnet_cond=image_ROI_p,
                        conditioning_scale=cond_scale,
                        guess_mode=guess_mode,
                        return_dict=False,
                    )

                    if guess_mode and do_classifier_free_guidance:
                        # Inferred ControlNet only for the conditional batch.
                        # To apply the output of ControlNet to both the unconditional and conditional batches,
                        # add 0 to the unconditional batch to keep it unchanged.
                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                        mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    xx = sch.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                    all_tile_latents.append(xx)

                # Blend all denoised tiles back into the large latent for the next timestep.
                large_latents = tailor.tail_results(all_tile_latents)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, large_latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        image = self.vae.decode(large_latents / self.vae.config.scaling_factor, return_dict=False)[0]
        do_denormalize = [True] * image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)

class RandomSlideUpscaleInpaintPipeline(StableDiffusionControlNetInpaintPipeline):
    """Tile-wise ControlNet pipeline that denoises one large latent canvas.

    At every timestep the canvas is cut into tiles at the positions held by a
    ``randomTailor``.  Each tile is denoised with its own copy of the
    scheduler and the tiles are merged back with ``randomTailor.tail_results``.
    Within a timestep, every tile after the first is blended with the merge of
    the tiles already processed before it is stored.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        """Forward all components to the parent pipeline unchanged."""
        # Forward by keyword: newer diffusers releases accept an optional
        # `image_encoder` argument ahead of `requires_safety_checker`, so
        # positional forwarding could bind the flag to the wrong parameter.
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

    def prepare_init_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Create (or adopt) the initial noise for the whole canvas.

        Args:
            batch_size: Number of samples in the batch.
            num_channels_latents: Channel count of the latent tensor.
            height: Canvas height; divided by ``self.vae_scale_factor`` to
                obtain the latent height.
            width: Canvas width; divided by ``self.vae_scale_factor`` to
                obtain the latent width.
            dtype: dtype used when fresh noise is sampled.
            device: Device the latents are created on / moved to.
            generator: A ``torch.Generator`` or a list with one generator per
                sample.
            latents: Optional pre-made latents.  When given they are only moved
                to ``device``; their shape and dtype are not checked.

        Returns:
            The latents multiplied by ``self.scheduler.init_noise_sigma``.

        Raises:
            ValueError: If ``generator`` is a list whose length differs from
                ``batch_size``.
        """
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        """Preprocess a control image and expand it to the effective batch.

        Args:
            image: Control image accepted by ``self.control_image_processor``.
            width: Target width passed to the processor.
            height: Target height passed to the processor.
            batch_size: Repeat count used when ``image`` holds a single sample.
            num_images_per_prompt: Repeat count used when ``image`` already
                holds one sample per prompt.
            device: Device of the returned tensor.
            dtype: dtype of the returned tensor.
            do_classifier_free_guidance: Whether guidance is active.
            guess_mode: Whether ControlNet guess mode is active.

        Returns:
            The processed image tensor, duplicated along the batch axis when
            guidance is active and guess mode is off.
        """
        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
        image_batch_size = image.shape[0]

        # A single image is shared by the whole batch; otherwise the image
        # batch is assumed to match the prompt batch.
        repeat_by = batch_size if image_batch_size == 1 else num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image = None,
        image_H = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        large_latent_height:int = None,
        large_latent_width:int = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        strength=1.0,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt: Prompt string or list of prompt strings. May be omitted
                when `prompt_embeds` is given.
            image: Full-size first control input. It is indexed as
                `image[:, :, rows, cols]` to crop one region per tile, and the
                processed crop is passed to the controlnet as `controlnet_cond`.
            image_H: Full-size second control input, cropped the same way and
                passed to the controlnet as `controlnet_cond2`.
            height, width: Target size handed to the control image processor
                for each crop. After a tile is prepared they are overwritten
                with the spatial size of the processed crop.
            large_latent_height, large_latent_width: Size of the whole canvas.
                Each is divided by `self.vae_scale_factor` to obtain the latent
                canvas shape.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight; guidance is
                enabled when the value is greater than 1.
            negative_prompt: Negative prompt(s) forwarded to `encode_prompt`.
            num_images_per_prompt: Number of images generated per prompt.
            eta: Forwarded to `prepare_extra_step_kwargs`.
            generator: Generator(s) used to sample the initial noise.
            latents: Optional initial latents for the whole canvas.
            prompt_embeds, negative_prompt_embeds: Optional precomputed text
                embeddings forwarded to `encode_prompt`.
            output_type: Output format passed to the image processor.
            return_dict: Accepted for API compatibility but currently ignored;
                a `StableDiffusionPipelineOutput` is always returned.
            callback: Optional callable invoked as
                `callback(step_index, timestep, large_latents)`.
            callback_steps: The callback fires when
                `step_index % callback_steps == 0`.
            cross_attention_kwargs: Forwarded to the UNet; its `"scale"` entry,
                if present, is used as the text-encoder LoRA scale.
            controlnet_conditioning_scale: ControlNet output scale. When a list
                is given only its first element is used.
            guess_mode: Enables ControlNet guess mode; also enabled when the
                controlnet config sets `global_pool_conditions`.
            control_guidance_start, control_guidance_end: Fractions of the
                schedule between which the controlnet is applied.
            strength: Currently unused.

        Returns:
            `StableDiffusionPipelineOutput` with the decoded images and
            `nsfw_content_detected=None`.

        Raises:
            NotImplementedError: If `self.controlnet` is not a
                `ControlNetModel` instance.
            ValueError: If neither `prompt` nor `prompt_embeds` is given, or if
                `image`, `image_H` or the canvas size is missing.
        """
        controlnet = self.controlnet

        # 0. Fail fast on unsupported or missing inputs (previously a bare
        # `assert False` inside the denoising loop, which `python -O` strips).
        if not isinstance(controlnet, ControlNetModel):
            raise NotImplementedError(
                "Only a single ControlNetModel is supported by this pipeline; re-check conditioning input"
            )
        if image is None or image_H is None:
            raise ValueError("Both `image` and `image_H` control inputs are required.")
        if large_latent_height is None or large_latent_width is None:
            raise ValueError("`large_latent_height` and `large_latent_width` are required.")

        # 1. Align format for control guidance: both ends become lists.
        start_is_list = isinstance(control_guidance_start, list)
        end_is_list = isinstance(control_guidance_end, list)
        if not start_is_list and end_is_list:
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif start_is_list and not end_is_list:
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not start_is_list and not end_is_list:
            control_guidance_start, control_guidance_end = [control_guidance_start], [control_guidance_end]

        # 2. Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        elif prompt_embeds is not None:
            batch_size = prompt_embeds.shape[0]
        else:
            raise ValueError("Provide either `prompt` (a string or a list of strings) or `prompt_embeds`.")

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        guess_mode = guess_mode or controlnet.config.global_pool_conditions

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables for the whole canvas
        num_channels_latents = self.unet.config.in_channels
        large_latents = self.prepare_init_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            large_latent_height,
            large_latent_width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        # randomTailor only distinguishes 'cpu' from everything else, so follow
        # the device the canvas actually lives on instead of hard-coding CUDA.
        tailor_device = 'cpu' if large_latents.device.type == 'cpu' else 'cuda'
        tailor = randomTailor(large_latents, device=tailor_device)

        # One scheduler copy per tile so that stateful multistep solvers
        # (e.g. UniPC, DPM-Solver) keep a separate history for every tile.
        schedulers = [copy.deepcopy(self.scheduler) for _ in tailor.chosen_samples]

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Per-timestep factor (1.0 or 0.0) stating whether the controlnet is kept
        num_steps = len(timesteps)
        controlnet_keep = []
        for i in range(num_steps):
            keeps = [
                1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0])

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # The conditioning scale depends only on the timestep.
                cond_scale = controlnet_conditioning_scale
                if isinstance(cond_scale, list):
                    cond_scale = cond_scale[0]
                cond_scale = cond_scale * controlnet_keep[i]

                all_tile_latents = []
                num_tiles = len(tailor.chosen_samples)
                # Wrapping the zip itself lets tqdm close the bar once the
                # tiles are exhausted.
                tile_bar = tqdm(
                    zip(tailor.chosen_samples, tailor.chosen_samples_control, schedulers),
                    total=num_tiles,
                )
                for nn, ((ii, jj), (ii_c, jj_c), sch) in enumerate(tile_bar):
                    tile_bar.set_description('Processing {}/{}'.format(nn + 1, num_tiles))

                    tile_hw = tailor.tile_hw
                    # The control crop spans 8x the latent tile size.
                    tile_px = tile_hw * 8
                    ctrl_rows = slice(ii_c, ii_c + tile_px)
                    ctrl_cols = slice(jj_c, jj_c + tile_px)
                    control_kwargs = dict(
                        width=width,
                        height=height,
                        batch_size=batch_size * num_images_per_prompt,
                        num_images_per_prompt=num_images_per_prompt,
                        device=device,
                        dtype=controlnet.dtype,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        guess_mode=guess_mode,
                    )
                    image_ROI_p = self.prepare_image(image=image[:, :, ctrl_rows, ctrl_cols], **control_kwargs)
                    image_H_p = self.prepare_image(image=image_H[:, :, ctrl_rows, ctrl_cols], **control_kwargs)
                    height, width = image_ROI_p.shape[-2:]

                    tile_rows = slice(ii, ii + tile_hw)
                    tile_cols = slice(jj, jj + tile_hw)

                    # expand the latents if we are doing classifier free guidance
                    tile_input = large_latents[:, :, tile_rows, tile_cols]
                    latent_model_input = torch.cat([tile_input] * 2) if do_classifier_free_guidance else tile_input
                    latent_model_input = sch.scale_model_input(latent_model_input, t)

                    # controlnet inference
                    if guess_mode and do_classifier_free_guidance:
                        # Infer ControlNet only for the conditional batch.
                        control_model_input = sch.scale_model_input(tile_input, t)
                        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                    else:
                        control_model_input = latent_model_input
                        controlnet_prompt_embeds = prompt_embeds

                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        control_model_input,
                        t,
                        encoder_hidden_states=controlnet_prompt_embeds,
                        controlnet_cond=image_ROI_p,
                        controlnet_cond2=image_H_p,
                        conditioning_scale=cond_scale,
                        guess_mode=guess_mode,
                        return_dict=False,
                    )

                    if guess_mode and do_classifier_free_guidance:
                        # ControlNet was inferred only for the conditional batch.
                        # To apply its output to both the unconditional and conditional batches,
                        # add 0 to the unconditional batch to keep it unchanged.
                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                        mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    tile_latent = sch.step(noise_pred, t, tile_input, **extra_step_kwargs, return_dict=False)[0]

                    if nn == 0:
                        all_tile_latents.append(tile_latent)
                        continue

                    # Blend the new tile with the merge of the tiles already
                    # processed in this timestep.
                    # NOTE(review): the meaning of the mask returned with
                    # `out_mask=True` is defined in randomTailor.tail_results - confirm there.
                    large_latents_m, large_mask = tailor.tail_results(all_tile_latents, out_mask=True)
                    mask_is_one = large_mask == 1
                    # Where the mask equals 1, fall back to the canvas values
                    # from the start of this timestep.
                    large_latents_m[mask_is_one] = large_latents[mask_is_one]
                    init_latents_proper = large_latents_m[:, :, tile_rows, tile_cols]
                    tile_mask_latent = large_mask[:, :, tile_rows, tile_cols]

                    masked_tile_latent = tile_mask_latent * init_latents_proper + (1 - tile_mask_latent) * tile_latent
                    all_tile_latents.append(masked_tile_latent)

                large_latents = tailor.tail_results(all_tile_latents)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, large_latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        image = self.vae.decode(large_latents / self.vae.config.scaling_factor, return_dict=False)[0]
        do_denormalize = [True] * image.shape[0]
        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()
        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)

if __name__ == '__main__':
    # Smoke test for randomTailor: build a tailor from a dummy canvas, then
    # merge one constant tile per sampled position.
    # randomTailor falls back to 'cuda' unless told 'cpu', so pick explicitly.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # The constructor reads `.shape` from its first argument, so it needs a
    # tensor rather than a bare width.
    tx = torch.ones([2, 2, 256, 256], device=device)
    xx = randomTailor(tx, device=device)
    print(xx.chosen_samples)
    print(len(xx.chosen_samples))
    ti = [torch.ones(2, 2, xx.tile_hw, xx.tile_hw, device=device) for _ in xx.chosen_samples]
    # tail_results takes the list of tile latents, as in the pipelines above.
    xx.tail_results(ti)
    print('done')
