import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput, StableDiffusionSafetyChecker

from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    deprecate,
    logging,
    replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from pipe.nethook import TraceDict

class AnyResStableDiffusionPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline that generates in two denoising stages.

    Stage 1 runs the regular denoising loop on low-resolution latents that
    are derived from a noise tensor prepared at twice the requested size.
    Stage 2 decodes the stage-1 result, edits it in pixel space (zoom into
    the ``lpos`` box, or paste the image into a noise canvas when
    ``outpainting`` is set), re-encodes it, re-noises it at ``timesteps[K]``
    and denoises again over ``timesteps[K:]``.
    """

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        K: int = 25,
        lpos: Optional[List[int]] = None,
        noise_scale: Optional[bool] = True,
        vae_sr: Optional[bool] = False,
        guided_lr: Optional[bool] = False,
        outpainting: Optional[bool] = False,
    ):
        """Run the two-stage generation.

        The standard Stable Diffusion arguments (``prompt`` through
        ``guidance_rescale``) are forwarded to the inherited helpers
        (``check_inputs``, ``encode_prompt``, ``prepare_latents``, ...).

        Args:
            latents: Optional initial noise. It is handed to
                ``prepare_latents`` together with ``height * 2`` and
                ``width * 2``, i.e. it is treated as the double-resolution
                noise from which the stage-1 latents are derived.
            K: Index into the scheduler timesteps at which stage 2 re-noises
                the re-encoded image and resumes denoising. Negative values
                count from the end.
            lpos: Required. Four integers ``(top, left, box_height,
                box_width)`` in pixel coordinates; they index the last two
                dimensions of the decoded image. For the double-resolution
                noise tensor the values are divided by 4.
            noise_scale: When truthy, the stage-2 noise is taken from the
                initial double-resolution noise (cropped at ``lpos`` or
                downsampled by 2 for outpainting). Otherwise fresh Gaussian
                noise is drawn with ``np.random``.
            vae_sr: When truthy, the final decode runs with the feature hook
                returned by ``vae_sr_fn`` installed on the VAE through
                ``TraceDict``.
            guided_lr: When truthy, a second copy of the stage-1 latents is
                denoised in stage 2 using the 2x-downsampled noise
                prediction. Not available together with ``outpainting``.
            outpainting: Selects the outpainting variant of both stages.

        Returns:
            ``StableDiffusionPipelineOutput`` (or a ``(images, nsfw)`` tuple
            when ``return_dict`` is false). ``images`` holds, in order, the
            final result, the stage-1 result, the re-encoded image before
            re-noising and, if ``guided_lr`` is set, the guided low-res
            result. For list outputs (PIL) the groups are concatenated into
            one flat list; for array/tensor/latent outputs a list with one
            entry per group is returned.

        Raises:
            ValueError: If ``lpos`` is missing or too short, if ``K`` is
                outside the timestep schedule, or if ``guided_lr`` is
                combined with ``outpainting``.
        """
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )
        # Every code path below indexes `lpos`; fail before any expensive work.
        if lpos is None or len(lpos) < 4:
            raise ValueError("`lpos` must be a sequence of four integers: (top, left, height, width).")
        if guided_lr and outpainting:
            # The guided low-res latents are only initialised on the zoom-in path.
            raise ValueError("`guided_lr` is not supported together with `outpainting`.")

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        # For classifier free guidance the unconditional and text embeddings are
        # concatenated into a single batch to avoid doing two forward passes.
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        if K < 0:
            # Keep Python-style negative indexing, but as a proper position so that
            # `timesteps[K:K + 1]` is never empty (it was for K == -1).
            K += len(timesteps)
        if not 0 <= K < len(timesteps):
            raise ValueError(f"`K` must index into the {len(timesteps)} scheduler timesteps, got {K}.")

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels

        latents_hr = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height * 2,
            width * 2,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        if not outpainting:
            latents = F.interpolate(
                latents_hr, size=(latents_hr.shape[-2] // 2, latents_hr.shape[-1] // 2), mode='nearest'
            )
        else:
            p = [l // 4 for l in lpos]
            latents = latents_hr[:, :, p[0]:p[0] + p[2], p[1]:p[1] + p[3]]

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Stage-1 denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                noise_pred = self._anyres_predict_noise(
                    latents, t, prompt_embeds, guidance_scale, guidance_rescale, cross_attention_kwargs
                )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Build the stage-2 starting point in pixel space and re-encode it
        image = self._anyres_decode(latents)
        latents_lr = latents
        latents_lr_guided = None

        if not outpainting:
            # Select the `lpos` box of the decoded image through a boolean mask.
            low_res_mask = image.new_zeros((1, 1, *image.shape[-2:]))
            p = list(lpos)
            low_res_mask[:, :, p[0]:p[0] + p[2], p[1]:p[1] + p[3]] = 1
            b, c, h, w = image.shape
            mask = low_res_mask.repeat(b, c, 1, 1).bool()

            # Upsample to twice the decoded size when the box is at least `width` wide.
            outsize = (h * 2, w * 2) if lpos[-1] >= width else (h, w)
            image_hr = F.interpolate(
                image[mask].reshape(-1, c, lpos[2], lpos[3]),
                size=outsize, mode='bicubic'
            )

            latents = self.vae.encode(image_hr)[0].sample() * self.vae.config.scaling_factor
            latents_mid = latents.clone()

            if noise_scale:
                p = [l // 4 for l in lpos]
                noise = latents_hr[:, :, p[0]:p[0] + p[2], p[1]:p[1] + p[3]]
            else:
                # NOTE(review): drawn from NumPy's global RNG, so `generator` does not
                # control this noise.
                noise = torch.from_numpy(np.random.randn(*latents.shape)).to(dtype=latents.dtype, device=device)
            # TODO: fix scheduler
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            latents = self.scheduler.add_noise(latents, noise, timesteps[K:K + 1])

            if guided_lr:
                noise_lr_guided = F.interpolate(
                    noise, size=(noise.shape[-2] // 2, noise.shape[-1] // 2), mode='nearest'
                )
                latents_lr_guided = self.scheduler.add_noise(latents_lr, noise_lr_guided, timesteps[K:K + 1])
        else:
            # Paste the resized stage-1 image into a Gaussian-noise canvas of the same size.
            image_out = torch.randn_like(image)
            p = list(lpos)
            image_inner = F.interpolate(
                image,
                size=(lpos[2], lpos[3]), mode='bicubic'
            )
            image_out[:, :, p[0]:p[0] + p[2], p[1]:p[1] + p[3]] = image_inner

            latents = self.vae.encode(image_out)[0].sample() * self.vae.config.scaling_factor
            # Kept so that the "mid" image can be decoded below, as on the zoom-in path.
            latents_mid = latents.clone()

            if noise_scale:
                noise = F.interpolate(
                    latents_hr, size=(latents_hr.shape[-2] // 2, latents_hr.shape[-1] // 2), mode='nearest'
                )
            else:
                # NOTE(review): drawn from NumPy's global RNG, so `generator` does not
                # control this noise.
                noise = torch.from_numpy(np.random.randn(*latents.shape)).to(dtype=latents.dtype, device=device)
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            latents = self.scheduler.add_noise(latents, noise, timesteps[K:K + 1])

        # 9. Stage-2 denoising loop over the remaining timesteps
        refine_timesteps = timesteps[K:]
        with self.progress_bar(total=len(refine_timesteps)) as progress_bar:
            for i, t in enumerate(refine_timesteps):
                noise_pred = self._anyres_predict_noise(
                    latents, t, prompt_embeds, guidance_scale, guidance_rescale, cross_attention_kwargs
                )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if guided_lr:
                    noise_lr_pred = F.interpolate(
                        noise_pred, size=(noise.shape[-2] // 2, noise.shape[-1] // 2), mode='nearest'
                    )
                    latents_lr_guided = self.scheduler.step(
                        noise_lr_pred, t, latents_lr_guided, **extra_step_kwargs, return_dict=False
                    )[0]

                progress_bar.update()
                # call the callback, if provided; `i` is relative to the start of stage 2
                is_last = i == len(refine_timesteps) - 1
                if is_last or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 10. Decode (or pass through) every result group
        if output_type != "latent":
            image_lr_guided = self._anyres_decode(latents_lr_guided) if guided_lr else None

            if vae_sr:
                low_res_mask = latents.new_zeros((1, 1, latents.shape[-2] * 8, latents.shape[-1] * 8))
                p = list(lpos)
                low_res_mask[:, :, p[0]:p[0] + p[2], p[1]:p[1] + p[3]] = 1
                local_sr, layers = self.vae_sr_fn(low_res_mask, lpos)
                # The hook expects the batch to be [high-res ; low-res], hence the concatenation.
                with TraceDict(
                    self.vae,
                    layers=layers,
                    retain_output=True,
                    edit_output=local_sr,
                    stop=False,
                ):
                    latents = torch.cat((latents, latents_lr), dim=0)
                    image = self._anyres_decode(latents)
            else:
                image = self._anyres_decode(latents)
            image_lr = self._anyres_decode(latents_lr)
            image_mid = self._anyres_decode(latents_mid)
            # The safety checker is intentionally not run here.
            has_nsfw_concept = [False] * image.shape[0]
        else:
            image = latents
            image_lr = latents_lr
            image_mid = latents_mid
            image_lr_guided = latents_lr_guided
            has_nsfw_concept = None

        groups = [image, image_lr, image_mid]
        if guided_lr:
            groups.append(image_lr_guided)

        # Nothing is ever flagged, so every sample of every group is denormalised.
        processed = [
            self.image_processor.postprocess(
                group, output_type=output_type, do_denormalize=[True] * group.shape[0]
            )
            for group in groups
        ]
        if all(isinstance(group, list) for group in processed):
            # PIL output: one flat list, same ordering as before.
            image = [img for group in processed for img in group]
        else:
            # Arrays/tensors must not be combined with `+` (that would add pixel values),
            # and the groups may differ in spatial size, so keep them as separate entries.
            image = processed

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def _anyres_predict_noise(self, latents, t, prompt_embeds, guidance_scale, guidance_rescale, cross_attention_kwargs):
        """Run the UNet for one timestep and apply classifier-free guidance."""
        do_classifier_free_guidance = guidance_scale > 1.0

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if guidance_rescale > 0.0:
                # Imported lazily: the module-level imports do not provide it, and the
                # default configuration (`guidance_rescale == 0`) never needs it.
                from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg

                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

        return noise_pred

    def _anyres_decode(self, latents):
        """Decode latents with the VAE after undoing the latent scaling factor."""
        return self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

    @staticmethod
    def vae_sr_fn(low_res_mask, lpos):
        """Build the VAE-decoder output hook used when ``vae_sr`` is enabled.

        Args:
            low_res_mask: ``(1, 1, H, W)`` tensor that is 1 inside the
                ``lpos`` box and 0 elsewhere.
            lpos: ``(top, left, box_height, box_width)`` in the coordinate
                system of ``low_res_mask``.

        Returns:
            ``(local_sr, layers)``: the hook ``local_sr(x, layer)`` and the
            list of decoder layer names it should be attached to.
        """
        vae_decoder = [
            'decoder.up_blocks.0.resnets.0',
            'decoder.up_blocks.0.resnets.1',
            'decoder.up_blocks.0.resnets.2',
            'decoder.up_blocks.1.resnets.0',
            'decoder.up_blocks.1.resnets.1',
            'decoder.up_blocks.1.resnets.2',
            'decoder.up_blocks.2.resnets.0',
            'decoder.up_blocks.2.resnets.1',
            'decoder.up_blocks.2.resnets.2',
            'decoder.up_blocks.3.resnets.0',
            'decoder.up_blocks.3.resnets.1',
            'decoder.up_blocks.3.resnets.2',
        ]
        layers = vae_decoder

        def local_sr(x, layer):
            """Blend the downsized high-res features into the low-res half of the batch."""
            if layer not in vae_decoder:
                return x

            is_tuple = isinstance(x, tuple)
            feat = x[0] if is_tuple else x
            _, _, h0, w0 = low_res_mask.shape
            b, c, h, w = feat.shape

            # Bring the pixel-space mask to the resolution of this feature map.
            mask = F.interpolate(
                low_res_mask.float(),
                size=(h, w), mode='nearest').bool()
            mask = mask.repeat(b // 2, c, 1, 1)

            # The batch is expected to be [high-res ; low-res].
            hr, lr = feat.chunk(2)
            hr_down = F.interpolate(hr, size=(int(lpos[2] * h / h0), int(lpos[3] * w / w0)), mode='bicubic')

            # Match the per-channel mean/std of the shrunken high-res features to
            # those of the low-res features inside the box.
            lr_patch = lr[mask].reshape(*hr_down.shape)
            lr_mean = torch.mean(lr_patch, dim=(2, 3), keepdim=True)
            lr_std = torch.std(lr_patch, dim=(2, 3), keepdim=True)
            hr_mean = torch.mean(hr_down, dim=(2, 3), keepdim=True)
            hr_std = torch.std(hr_down, dim=(2, 3), keepdim=True)
            hr_down = (hr_down - hr_mean) / hr_std * lr_std + lr_mean

            # Box in feature-map coordinates (all four values use the height ratio).
            p = [int(l / h0 * h) for l in lpos]
            lr, _ = patch_smoothing(lr, hr_down, p)

            out = torch.cat((hr, lr), dim=0)

            return (out, ) + x[1:] if is_tuple else out

        return local_sr, layers
    
def patch_smoothing(x, patch, lpos, kernel_size=31):
    """Paste ``patch`` into ``x`` at ``lpos`` and blend it with a blurred box mask.

    ``patch`` is bicubically resized to the box size and written into a copy
    of ``x``. That copy is then mixed with ``x`` using a weight map obtained
    by box-filtering the binary box mask, so the weight of the pasted content
    at each position is the fraction of the surrounding
    ``kernel_size x kernel_size`` window that lies inside the box. Outside the
    box the pasted copy equals ``x``, so the blend leaves those values
    unchanged up to rounding.

    Args:
        x: ``(B, C, H, W)`` tensor to paste into.
        patch: ``(B, C, h, w)`` tensor; resized to ``(lpos[2], lpos[3])``.
        lpos: ``(top, left, box_height, box_width)`` in the coordinates of ``x``.
        kernel_size: Side of the box filter used to blur the mask. Must be a
            positive odd number so the weight map keeps the size of ``x``.

    Returns:
        ``(output, mask_weight)``, both shaped like ``x``.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"`kernel_size` must be a positive odd integer, got {kernel_size}.")

    top, left, box_h, box_w = lpos[0], lpos[1], lpos[2], lpos[3]
    rows = slice(top, top + box_h)
    cols = slice(left, left + box_w)

    mask = torch.zeros_like(x)
    mask[:, :, rows, cols] = 1

    # Box-blur one mask channel and fan it out to every channel of `x`:
    # the transposed-conv weight has 1 input channel and C output channels.
    channels = x.shape[1]
    box_filter = x.new_ones([1, channels, kernel_size, kernel_size]) / (kernel_size ** 2)
    mask_weight = F.conv_transpose2d(
        mask[:, 0:1, :, :].detach(), box_filter, bias=None, stride=1, padding=kernel_size // 2
    )

    # Start from `x` and overwrite only the box with the resized patch.
    output = x.clone()
    output[:, :, rows, cols] = F.interpolate(patch, size=(box_h, box_w), mode='bicubic')

    output = output * mask_weight + x * (1 - mask_weight)

    return output, mask_weight

@torch.no_grad()
def vae_encode_decode(pipe, image):
    """Round-trip an image through the pipeline's VAE and return the first result.

    The input is mapped from ``[0, 1]`` to ``[-1, 1]`` (``image * 2 - 1``),
    encoded, sampled, decoded, post-processed to PIL and converted back to a
    float tensor in ``[0, 1]``.

    Args:
        pipe: Pipeline exposing ``vae``, ``image_processor`` and
            ``_execution_device``.
        image: ``(B, C, H, W)`` tensor; presumably valued in ``[0, 1]`` given
            the ``* 2 - 1`` rescaling.

    Returns:
        ``(1, C, H, W)`` float32 CPU tensor holding the reconstruction of the
        first batch element only.
    """
    # Follow the pipeline's device and the VAE's dtype instead of forcing
    # CUDA/fp16, so CPU-only hosts and non-fp16 VAEs work too.
    device = pipe._execution_device
    dtype = pipe.vae.dtype
    scaling_factor = pipe.vae.config.scaling_factor

    image_in = image.to(device=device, dtype=dtype)

    latents = pipe.vae.encode(image_in * 2 - 1)[0].sample() * scaling_factor
    decoded = pipe.vae.decode(latents / scaling_factor, return_dict=False)[0]
    pil_images = pipe.image_processor.postprocess(
        decoded, output_type='pil', do_denormalize=[True] * latents.shape[0]
    )

    # HWC uint8 -> 1xCxHxW; the division runs in the VAE dtype, as before.
    recons = torch.from_numpy(np.array(pil_images[0])).permute(2, 0, 1).unsqueeze(0)
    recons = recons.to(device=device, dtype=dtype) / 255

    return recons.cpu().float()