from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from tqdm import tqdm
from config import Range
from torchvision import transforms
from PIL import Image
from utils.position_embedding import PositionalEncodingPermute2D

def affine(sd_feat_src, sd_feat_trg, use_cos_affine=False):
    """Return, for every target feature vector, its matched source feature.

    Args:
        sd_feat_src: 2-D tensor of source feature vectors, one vector per row.
        sd_feat_trg: 2-D tensor of target feature vectors, one vector per row,
            with the same feature width as ``sd_feat_src``.
        use_cos_affine: When ``False`` (default) each target row is matched to
            the single source row with the smallest ``torch.cdist`` distance.
            When ``True`` the dot-product similarities are L1-normalised over
            the source axis and the ``k`` best source rows are blended, each
            scaled by its normalised similarity.

    Returns:
        Tensor with one row per target vector, holding the matched (or blended)
        source features.
    """
    if use_cos_affine:
        similarity = sd_feat_src @ sd_feat_trg.transpose(-2, -1)
        weights = torch.nn.functional.normalize(similarity, p=1.0, dim=0, eps=1e-12)
        num_src, num_trg = similarity.shape[0], similarity.shape[1]
        # One neighbour per 100 target vectors, but never zero (an empty top-k
        # averages to NaN) and never more than there are source rows (top-k
        # would raise).
        k = min(max(num_trg // 100, 1), num_src)
        top_weights, top_indices = torch.topk(weights, k, largest=True, dim=0)
        blended = sd_feat_src[top_indices] * top_weights.unsqueeze(-1)
        return blended.mean(dim=0)

    distances = torch.cdist(sd_feat_src, sd_feat_trg)
    # Nearest source row for every target column.
    _, indices = torch.min(distances, dim=0)
    return sd_feat_src[indices]

def new_affine(ft_pe_app, ft_pe_struct, ft_app, use_cos_affine=False):
    """Match structure vectors to appearance vectors and gather appearance features.

    The correspondence is computed between ``ft_pe_app`` and ``ft_pe_struct``,
    but the returned rows are gathered from ``ft_app``.

    Args:
        ft_pe_app: 2-D tensor of appearance vectors used for matching, one per row.
        ft_pe_struct: 2-D tensor of structure vectors used for matching, one per
            row, with the same feature width as ``ft_pe_app``.
        ft_app: 2-D tensor to gather from; it is indexed with row indices of
            ``ft_pe_app`` and therefore needs at least as many rows.
        use_cos_affine: When ``False`` (default) each structure row is matched to
            the appearance row with the smallest ``torch.cdist`` distance. When
            ``True`` the dot-product similarities are L1-normalised over the
            appearance axis and the ``k`` best rows of ``ft_app`` are blended,
            each scaled by its normalised similarity.

    Returns:
        Tensor with one row per structure vector, holding the matched (or
        blended) rows of ``ft_app``.
    """
    if use_cos_affine:
        similarity = ft_pe_app @ ft_pe_struct.transpose(-2, -1)
        weights = torch.nn.functional.normalize(similarity, p=1.0, dim=0, eps=1e-12)
        num_app, num_struct = similarity.shape[0], similarity.shape[1]
        # One neighbour per 100 structure vectors, but never zero (an empty
        # top-k averages to NaN) and never more than there are appearance rows
        # (top-k would raise).
        k = min(max(num_struct // 100, 1), num_app)
        top_weights, top_indices = torch.topk(weights, k, largest=True, dim=0)
        blended = ft_app[top_indices] * top_weights.unsqueeze(-1)
        return blended.mean(dim=0)

    distances = torch.cdist(ft_pe_app, ft_pe_struct)
    # Nearest appearance row for every structure column.
    _, indices = torch.min(distances, dim=0)
    return ft_app[indices]

class CrossImageAttentionStableDiffusionPipeline(StableDiffusionPipeline):
    """ A modification of the standard StableDiffusionPipeline to incorporate our cross-image attention.

    ``__call__`` replays pre-computed noise maps (``zs``) through a DDPM-style
    update and steers the first noise prediction of the batch with two gradient
    signals: an attention-map loss against a reference UNet pass, and a feature
    similarity loss computed by ``feat_guidance``.

    NOTE(review): ``self.estimator`` is used by ``feat_guidance`` but is never
    assigned in this class; it is presumably attached by the caller -- confirm.
    """

    def __call__(
            self,
            prompt: Union[str, List[str]] = None,
            height: Optional[int] = None,
            width: Optional[int] = None,
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            num_images_per_prompt: Optional[int] = 1,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.FloatTensor] = None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            output_type: Optional[str] = "pil",
            return_dict: bool = True,
            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
            callback_steps: int = 1,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            guidance_rescale: float = 0.0,
            swap_guidance_scale: float = 1.0,
            cross_image_attention_range: Range = Range(10, 90),
            # DDPM addition
            zs: Optional[List[torch.Tensor]] = None,
            model = None,
            # guidance addition
            config = None,
            mask_style = None,
            mask_struct = None,
    ):
        """Run the guided denoising loop and decode the final latents.

        The standard arguments are handled as in the body below; the additions are:

        Args:
            zs: Per-latent noise maps; ``zs[k]`` is indexed by step inside
                ``perform_ddpm_step`` and ``zs[0].shape[0]`` decides how many of
                the trailing scheduler timesteps are replayed. Required.
            model: Object exposing ``segmentor.get_object_masks()``; only used
                when no masks are passed in.
            config: Guidance configuration. The attributes read here are
                ``attention_guidance_type`` ("cross" or "self"),
                ``cross_guidance_type``, ``swap_kv``, ``start_guidance_timestep``,
                ``cross_energy_scale`` and ``cross_guidance``; ``feat_guidance``
                reads further ones. Required.
            mask_style: Optional appearance-image mask accepted by the
                torchvision ``Resize``/``ToTensor`` transforms.
            mask_struct: Optional structure-image mask, same format.

        ``guidance_rescale``, ``swap_guidance_scale`` and
        ``cross_image_attention_range`` are accepted but not read in this method.

        Returns:
            ``StableDiffusionPipelineOutput`` or, when ``return_dict`` is false,
            the tuple ``(image, has_nsfw_concept)``.

        Raises:
            ValueError: If ``zs`` or ``config`` is missing, or
                ``config.attention_guidance_type`` is not "cross" or "self".
        """
        if zs is None:
            raise ValueError("`zs` must be provided: the denoising loop replays these noise maps.")
        if config is None:
            raise ValueError("`config` must be provided: it carries the guidance settings.")
        # Fail before any UNet pass instead of hitting an unbound name mid-loop.
        if config.attention_guidance_type == "cross":
            attn_type = "attn2"
        elif config.attention_guidance_type == "self":
            attn_type = "attn1"
        else:
            raise ValueError(f"Unknown attention guidance type {config.attention_guidance_type}")

        # Guidance gradients are taken w.r.t. latents only; keep the UNet weights frozen.
        for param in self.unet.parameters():
            param.requires_grad = False
            param.grad = None

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        with torch.no_grad():
            prompt_embeds = self._encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
            )

        # 4. Prepare timesteps: only the trailing steps covered by `zs` are run.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps[-zs[0].shape[0]:]
        t_to_idx = {int(v): k for k, v in enumerate(timesteps)}

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare the object masks, if they were supplied by the caller.
        if mask_struct is not None and mask_style is not None:
            transform = transforms.Compose([
                transforms.Resize((64, 64)),
                transforms.ToTensor()
            ])
            mask_struct = transform(mask_struct).to(device)
            mask_style = transform(mask_style).to(device)
            src_mask = F.interpolate(mask_style.unsqueeze(0), (128,128)).squeeze()
            tar_mask = F.interpolate(mask_struct.unsqueeze(0), (128,128)).squeeze()
        else:
            # Filled in lazily from `model.segmentor` once guidance starts.
            src_mask = None
            tar_mask = None

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        cos = nn.CosineSimilarity(dim=1)

        for t in tqdm(timesteps):
            i = t_to_idx[int(t)]

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Reference pass without swapping: provides the noise prediction and
            # the attention maps the guided pass is pulled towards.
            with torch.no_grad():
                noise_pred_no_swap = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs={'perform_swap': False},
                    return_dict=False,
                )[0]
                d_ref_t2attn = {}
                for name, module in self._attention_modules(attn_type):
                    # Entry 5 pairs with `x5` below, which is handed to
                    # `feat_guidance` as the structure latent.
                    d_ref_t2attn[name] = module.attn_probs[5].detach().cpu()
                    del module.attn_probs

            # The six-way split needs a batch of six; this presumably means three
            # latents (output, appearance, structure) doubled by classifier-free
            # guidance, so x3..x5 are their text-conditioned copies -- TODO confirm.
            x0,x1,x2,x3,x4,x5 = latent_model_input.chunk(6,dim=0)
            x3 = x3.detach().requires_grad_(True)
            latent_model_input = torch.cat([x0,x1,x2,x3,x4,x5],dim=0)

            # Guided pass: run with gradients so the attention loss can be
            # differentiated with respect to x3.
            noise_pred_swap = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds.detach(),
                cross_attention_kwargs={'perform_swap': config.swap_kv},
                return_dict=False,
            )[0]

            # compute cross-attention loss
            cross_attn_loss = 0
            for name, module in self._attention_modules(attn_type):
                curr = module.attn_probs[3].to(device)
                ref = d_ref_t2attn[name].detach().to(device)
                if config.cross_guidance_type=="cos":
                    curr_flatten = curr.view(curr.shape[1], -1).permute(1,0)
                    ref_flatten = ref.view(ref.shape[1], -1).permute(1,0)
                    sim = (cos(curr_flatten, ref_flatten)+1.)/2.
                    cross_attn_loss = cross_attn_loss + 1/(1+4*sim.mean())
                else:
                    cross_attn_loss = cross_attn_loss + ((curr-ref)**2).sum((1,2)).mean(0)
                del module.attn_probs

            if do_classifier_free_guidance:
                noise_no_swap_pred_uncond, noise_no_swap_pred_text = noise_pred_no_swap.chunk(2)

                noise_pred = noise_no_swap_pred_uncond + guidance_scale * (
                        noise_no_swap_pred_text - noise_no_swap_pred_uncond)

            else:
                noise_pred = noise_pred_swap

            if t < config.start_guidance_timestep:
                if src_mask is None:
                    print("Get the mask ....")
                    mask_style_32, mask_struct_32, mask_style_64, mask_struct_64 = model.segmentor.get_object_masks()
                    mask_style_32 = mask_style_32.to(torch.float32)
                    mask_struct_32 = mask_struct_32.to(torch.float32)
                    src_mask = F.interpolate(mask_style_32.unsqueeze(0).unsqueeze(0), (128,128)).squeeze()
                    tar_mask = F.interpolate(mask_struct_32.unsqueeze(0).unsqueeze(0), (128,128)).squeeze()

                feat_guidances = self.feat_guidance(
                    t=t,
                    text_embeddings=prompt_embeds,
                    out_latent=x3,
                    app_latent=x4,
                    struct_latent=x5,
                    struct_mask=tar_mask,
                    app_mask=src_mask,
                    config=config,
                )
                feat_guidances = feat_guidances.detach()
                cross_guidances = torch.autograd.grad(cross_attn_loss*config.cross_energy_scale, x3)[0]
                cross_guidances = cross_guidances.detach()
                # Only the first prediction of the batch is steered.
                noise_pred[:1] = noise_pred[:1] + feat_guidances + cross_guidances * config.cross_guidance
                torch.cuda.empty_cache()

            # Advance every latent with its own stored noise maps.
            latents = torch.stack([
                self.perform_ddpm_step(t_to_idx, zs[latent_idx].detach(), latents[latent_idx].detach(), t, noise_pred[latent_idx].detach(), eta)
                for latent_idx in range(latents.shape[0])
            ])

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents.detach())

        # 8. Decode
        with torch.no_grad():
            if not output_type == "latent":
                image = self.vae.decode(latents.detach() / self.vae.config.scaling_factor, return_dict=False)[0]
                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            else:
                image = latents
                has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image.detach(), output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def _attention_modules(self, attn_type):
        """Yield ``(name, module)`` for UNet modules of class ``Attention`` whose name contains ``attn_type``."""
        for name, module in self.unet.named_modules():
            if type(module).__name__ == "Attention" and attn_type in name:
                yield name, module

    def _extract_up_features(self, latent, t, embeddings):
        """Run the estimator on ``latent`` and resize each returned ``up_ft`` map to 128x128."""
        up_ft = self.estimator(
                    sample=latent,
                    timestep=t,
                    up_ft_indices=[1,2],
                    encoder_hidden_states=embeddings)['up_ft']
        # Updated by position because the container type returned by the
        # estimator is not visible from this file.
        for f_id in range(len(up_ft)):
            up_ft[f_id] = F.interpolate(up_ft[f_id], (128, 128))
        return up_ft

    @staticmethod
    def _masked_vectors(feat, mask):
        """Return the feature vectors of ``feat`` at the positions selected by ``mask``, one vector per row."""
        channels = feat.shape[1]
        # Boolean indexing flattens channel-major, so view as (C, N) then transpose to (N, C).
        return feat[mask.repeat(1, channels, 1, 1)].view(channels, -1).permute(1, 0)

    def feat_guidance(self, t, text_embeddings, out_latent, app_latent, struct_latent, struct_mask, app_mask, config):
        """Compute the clamped feature-similarity gradient with respect to ``out_latent``.

        Args:
            t: Current timestep, forwarded to the estimator.
            text_embeddings: Prompt embeddings; slices ``[3:4]``, ``[4:5]`` and
                ``[5:]`` are paired with the output, appearance and structure
                latents respectively.
            out_latent: Latent being edited; detached and re-enabled for grad here.
            app_latent: Appearance reference latent.
            struct_latent: Structure reference latent.
            struct_mask: Mask of the object in the structure image (thresholded at 0.5).
            app_mask: Mask of the object in the appearance image (thresholded at 0.5).
            config: Reads ``feat_guidance_type`` ("app", "app_new_affine",
                "struct" or "app_struct"), ``pe_scale``, ``w_global``, ``w_app``,
                ``w_struct``, ``w_background``, ``bg_affine``, ``energy_scale``
                and ``bg_energy_scale``.

        Returns:
            Detached gradient tensor, clamped to [-1, 1].

        Raises:
            ValueError: If ``config.feat_guidance_type`` is not one of the known types.
        """
        guidance_type = config.feat_guidance_type
        if guidance_type not in ("app", "app_new_affine", "struct", "app_struct"):
            raise ValueError(
                    f"Unknown feature guidance type {config.feat_guidance_type}"
                )

        cos = nn.CosineSimilarity(dim=1)

        # Foreground / background selections are the same for every feature level.
        mask_cur = struct_mask > 0.5
        mask_tar = app_mask > 0.5
        mask_struct_bg = struct_mask < 0.5
        mask_app_bg = app_mask < 0.5
        use_bg = bool(config.bg_affine and mask_app_bg.sum()!=0 and mask_struct_bg.sum()!=0)

        # Structure features are required by every mode except plain "app", and
        # by the background term. Positional-encoding features are only read by
        # "app_struct" and the background term.
        needs_struct = guidance_type != "app" or use_bg
        needs_pe = guidance_type == "app_struct" or use_bg

        # extract the feature
        up_ft_struct = None
        with torch.no_grad():
            if needs_struct:
                up_ft_struct = self._extract_up_features(struct_latent, t, text_embeddings[5:])
            up_ft_app = self._extract_up_features(app_latent, t, text_embeddings[4:5])
        out_latent = out_latent.detach().requires_grad_(True)
        up_ft_out = self._extract_up_features(out_latent, t, text_embeddings[3:4])

        # compute sim loss, accumulated over all feature levels
        loss_edit = 0
        loss_bg = 0
        for f_id in range(len(up_ft_app)):
            ft_out = up_ft_out[f_id]
            ft_app = up_ft_app[f_id]
            ft_struct = up_ft_struct[f_id] if up_ft_struct is not None else None

            if needs_pe:
                p_enc_2d = PositionalEncodingPermute2D(ft_out.shape[1])
                app_ft_add_pe = ft_app + p_enc_2d(ft_app) * config.pe_scale
                struct_ft_add_pe = ft_struct + p_enc_2d(ft_struct) * config.pe_scale

            up_ft_cur_vec = self._masked_vectors(ft_out, mask_cur)
            up_ft_app_vec = self._masked_vectors(ft_app, mask_tar)

            sim_struct = None
            if guidance_type=="app":
                affine_vector = affine(up_ft_app_vec, up_ft_cur_vec)
                sim_app = (cos(up_ft_cur_vec, affine_vector)+1.)/2.
            elif guidance_type=="app_new_affine":
                up_ft_struct_vec = self._masked_vectors(ft_struct, mask_cur)
                affine_vector = affine(up_ft_app_vec, up_ft_struct_vec)
                sim_app = (cos(up_ft_cur_vec, affine_vector)+1.)/2.
            elif guidance_type=="struct":
                up_ft_struct_vec = self._masked_vectors(ft_struct, mask_cur)
                # Weighted with `w_app` below, like the other single-term modes.
                sim_app = (cos(up_ft_cur_vec, up_ft_struct_vec)+1.)/2.
            else:  # "app_struct"
                up_ft_struct_vec = self._masked_vectors(ft_struct, mask_cur)
                ft_pe_struct_vec = self._masked_vectors(struct_ft_add_pe, mask_cur)
                ft_pe_app_vec = self._masked_vectors(app_ft_add_pe, mask_tar)
                affine_vector = new_affine(ft_pe_app_vec, ft_pe_struct_vec, up_ft_app_vec)
                sim_app = (cos(up_ft_cur_vec, affine_vector)+1.)/2.
                sim_struct = (cos(up_ft_cur_vec, up_ft_struct_vec)+1.)/2.

            # compute global feature loss on the mean foreground vectors
            up_ft_cur_full = up_ft_cur_vec.mean(0, keepdim=True)
            up_ft_app_full = up_ft_app_vec.mean(0, keepdim=True)
            sim_full = (cos(up_ft_cur_full, up_ft_app_full)+1.)/2.

            loss_edit = loss_edit + config.w_global/(1+4*sim_full.mean()) + config.w_app/(1+4*sim_app.mean())
            if sim_struct is not None:
                loss_edit = loss_edit + config.w_struct/(1+4*sim_struct.mean())

            if use_bg:
                up_ft_out_vec_bg = self._masked_vectors(ft_out, mask_struct_bg)
                up_ft_app_vec_bg = self._masked_vectors(ft_app, mask_app_bg)
                ft_pe_struct_vec_bg = self._masked_vectors(struct_ft_add_pe, mask_struct_bg)
                ft_pe_app_vec_bg = self._masked_vectors(app_ft_add_pe, mask_app_bg)

                affine_vector_bg = new_affine(ft_pe_app_vec_bg, ft_pe_struct_vec_bg, up_ft_app_vec_bg)
                sim_bg = (cos(up_ft_out_vec_bg, affine_vector_bg)+1.)/2.
                loss_bg = loss_bg + config.w_background/(1+4*sim_bg.mean())

        # The losses are cumulative, so one backward pass per loss after the loop
        # gives the gradient that the last loop iteration used to produce.
        mask_latent = F.interpolate(struct_mask.repeat(1,4,1,1), size=(64, 64), mode='bilinear', align_corners=False)
        cond_grad_front = torch.autograd.grad(loss_edit*config.energy_scale, out_latent, retain_graph=use_bg)[0]
        cond_grad_edit = cond_grad_front * (mask_latent > 0.5)
        if use_bg:
            cond_grad_bg = torch.autograd.grad(loss_bg*config.bg_energy_scale, out_latent)[0]
            cond_grad_edit = cond_grad_edit + cond_grad_bg * (mask_latent < 0.5)

        cond_grad_edit = torch.clamp(cond_grad_edit, min=-1.0, max=1.0)
        guidance = cond_grad_edit.detach()
        self.estimator.zero_grad()
        return guidance

    def perform_ddpm_step(self, t_to_idx, zs, latents, t, noise_pred, eta):
        """Advance ``latents`` by one reverse step using the stored noise map.

        Args:
            t_to_idx: Mapping from integer timestep to index into ``zs``.
            zs: Noise maps indexed by step, or ``None`` to draw fresh noise.
            latents: Current sample.
            t: Current timestep.
            noise_pred: Predicted noise for ``latents`` at ``t``.
            eta: Noise scale; no noise is added when it is not positive.

        Returns:
            The sample for the previous timestep.
        """
        idx = t_to_idx[int(t)]
        z = zs[idx] if zs is not None else None
        # 1. get previous step value (=t-1)
        prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        # 2. compute alphas, betas
        alpha_prod_t = self.scheduler.alphas_cumprod[t]
        if prev_timestep >= 0:
            alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep]
        else:
            alpha_prod_t_prev = self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_original_sample = (latents - beta_prod_t ** 0.5 * noise_pred) / alpha_prod_t ** 0.5
        # 4. compute variance: "sigma_t(eta)" -> see formula (16)
        variance = self.get_variance(t)
        # 5. compute "direction pointing to x_t" of formula (12).
        # NOTE: this subtracts `eta * variance`, not `(eta * sqrt(variance)) ** 2`;
        # the two agree only for eta in {0, 1}.
        pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** 0.5 * noise_pred
        # 6. compute x_t without "random noise" of formula (12)
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        # 7. Add noise if eta > 0
        if eta > 0:
            if z is None:
                z = torch.randn(noise_pred.shape, device=self.device)
            prev_sample = prev_sample + eta * variance ** 0.5 * z
        return prev_sample

    def get_variance(self, timestep):
        """Return ``(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)`` for ``timestep``.

        ``a_t`` and ``a_prev`` are the scheduler's cumulative alpha products at
        ``timestep`` and at the preceding inference timestep; the scheduler's
        ``final_alpha_cumprod`` stands in for ``a_prev`` when that preceding
        timestep is negative.
        """
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        if prev_timestep >= 0:
            alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep]
        else:
            alpha_prod_t_prev = self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        return (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
