from typing import List, Optional, Callable

import torch
import torch.nn.functional as F
from config import RunConfig
from constants import OUT_INDEX, STRUCT_INDEX, STYLE_INDEX
from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
from utils import attention_utils
from utils.adain import masked_adain, adain
from utils.model_utils import get_stable_diffusion_model
from utils.segmentation import Segmentor

run_config = None

class AppearanceTransferModel:

    def __init__(self, config: RunConfig, pipe: Optional[CrossImageAttentionStableDiffusionPipeline] = None):
        self.config = config
        global run_config
        run_config = config
        self.pipe = get_stable_diffusion_model() if pipe is None else pipe
        self.register_attention_control()
        self.segmentor = Segmentor(prompt=config.prompt, object_nouns=[config.object_noun])
        self.latents_app, self.latents_struct = None, None
        self.zs_app, self.zs_struct = None, None
        self.image_app_mask_32, self.image_app_mask_64 = None, None
        self.image_struct_mask_32, self.image_struct_mask_64 = None, None
        self.enable_edit = False
        self.step = 0

    def set_latents(self, latents_app: torch.Tensor, latents_struct: torch.Tensor):
        self.latents_app = latents_app
        self.latents_struct = latents_struct

    def set_noise(self, zs_app: torch.Tensor, zs_struct: torch.Tensor):
        self.zs_app = zs_app
        self.zs_struct = zs_struct

    def set_masks(self, masks: List[torch.Tensor]):
        self.image_app_mask_32, self.image_struct_mask_32, self.image_app_mask_64, self.image_struct_mask_64 = masks

    def get_adain_callback(self):

        def callback(st: int, timestep: int, latents: torch.FloatTensor) -> Callable:
            self.step = st
            # Compute the masks using prompt mixing self-segmentation and use the masks for AdaIN operation
            if self.config.use_masked_adain and self.step == self.config.adain_range.start:
                masks = self.segmentor.get_object_masks()
                self.set_masks(masks)
            # Apply AdaIN operation using the computed masks
            if self.config.adain_range.start <= self.step < self.config.adain_range.end:
                if self.config.use_masked_adain:
                    latents[0] = masked_adain(latents[0], latents[1], self.image_struct_mask_64, self.image_app_mask_64)
                else:
                    # 1: Only synthesis 2: mixture mode
                    if run_config.mix_style is True:
                        alpha = run_config.alpha
                        latents[0] = alpha*adain(latents[0], latents[1])+(1-alpha)*adain(latents[0], latents[2])
                    else:
                        latents[0] = adain(latents[0], latents[1])


        return callback

    def register_attention_control(self):

        model_self = self

        class AttentionProcessor:

            def __init__(self, place_in_unet: str):
                self.place_in_unet = place_in_unet
                if not hasattr(F, "scaled_dot_product_attention"):
                    raise ImportError("AttnProcessor2_0 requires torch 2.0, to use it, please upgrade torch to 2.0.")

            def __call__(self,
                         attn,
                         hidden_states: torch.Tensor,
                         encoder_hidden_states: Optional[torch.Tensor] = None,
                         attention_mask=None,
                         temb=None,
                         perform_swap: bool = False):

                residual = hidden_states

                if attn.spatial_norm is not None:
                    hidden_states = attn.spatial_norm(hidden_states, temb)

                input_ndim = hidden_states.ndim

                if input_ndim == 4:
                    batch_size, channel, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                batch_size, sequence_length, _ = (
                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                )

                if attention_mask is not None:
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                    attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                if attn.group_norm is not None:
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                query = attn.to_q(hidden_states)

                is_cross = encoder_hidden_states is not None
                if not is_cross:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                inner_dim = key.shape[-1]
                head_dim = inner_dim // attn.heads
                head_kv_dim = head_dim
                should_mix = False

                # Potentially apply our cross image attention operation
                # To do so, we need to be in a self-attention layer in the decoder part of the denoising network
                attn_control = None
                if perform_swap and not is_cross and "up" in self.place_in_unet and model_self.enable_edit:
                    if attention_utils.should_mix_keys_and_values(model_self, hidden_states):
                        coef = run_config.interpolation #torch.rand(1).cuda()
                        if model_self.step <0:
                            key[OUT_INDEX] = key[STYLE_INDEX]
                            value[OUT_INDEX] = value[STYLE_INDEX]

                            key[3] = key[STYLE_INDEX]
                            value[3] = value[STYLE_INDEX]
                        
                        else:

                            # sigle style
                            if run_config.mix_style is False:
                                tmp_key = torch.cat([key[OUT_INDEX],
                                                    coef*key[OUT_INDEX]+(1-coef)*key[STYLE_INDEX]],dim=0)
                                
                                tmp_value = torch.cat([value[OUT_INDEX], 
                                                       coef*value[3]+(1-coef)*value[STYLE_INDEX]],dim=0)

                                tmp_key_2 = torch.cat([key[3], 
                                                       coef*key[3]+(1-coef)*key[STYLE_INDEX]],dim=0)
                                
                                tmp_value_2 = torch.cat([value[3], 
                                                         coef*value[3]+(1-coef)*value[STYLE_INDEX]],dim=0)

                                key = torch.cat([key,key],dim=1)
                                value=torch.cat([value,value],dim=1)
                                L = key[STYLE_INDEX].shape[0]
                                attn_control = L

                            # style mixing
                            else:
                                tmp_key = torch.cat([1.0*key[OUT_INDEX],
                                                     coef*key[OUT_INDEX]+(1-coef)*key[STYLE_INDEX],
                                                     coef*key[OUT_INDEX]+(1-coef)*key[STRUCT_INDEX] ],dim=0)
                                
                                tmp_value = torch.cat([1.0*value[OUT_INDEX], 
                                                       coef*value[OUT_INDEX]+(1-coef)*value[STYLE_INDEX], 
                                                       coef*value[OUT_INDEX]+(1-coef)*value[STRUCT_INDEX]],dim=0)

                                tmp_key_2 = torch.cat([key[3], 
                                                       coef*key[3]+(1-coef)*key[STYLE_INDEX],
                                                     coef*key[3]+(1-coef)*key[STRUCT_INDEX]],dim=0)
                                tmp_value_2 = torch.cat([value[3], 
                                                        coef*value[3]+(1-coef)*value[STYLE_INDEX], 
                                                        coef*value[3]+(1-coef)*value[STRUCT_INDEX]],dim=0)

                                key = torch.cat([key,key,key],dim=1)
                                value=torch.cat([value,value,value],dim=1)
                                L = key[STYLE_INDEX].shape[0]
                                attn_control = L

                            key[OUT_INDEX] = tmp_key
                            value[OUT_INDEX] = tmp_value
                            key[3] = tmp_key_2
                            value[3] = tmp_value_2
                            #query[OUT_INDEX] = 0.04*query[STRUCT_INDEX] + 0.96*query[OUT_INDEX]#+0.2*query[STYLE_INDEX]

                query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                key = key.view(batch_size, -1, attn.heads, head_kv_dim).transpose(1, 2)
                value = value.view(batch_size, -1, attn.heads, head_kv_dim).transpose(1, 2)

                # Compute the cross attention and apply our contrasting operation
                hidden_states, attn_weight = attention_utils.compute_scaled_dot_product_attention(
                    query, key, value,
                    edit_map=perform_swap and model_self.enable_edit and should_mix,
                    is_cross=is_cross,
                    contrast_strength=model_self.config.contrast_strength,
                    attn_control=attn_control
                )

                # Update attention map for segmentation
                if model_self.config.use_masked_adain and model_self.step == model_self.config.adain_range.start - 1:
                    model_self.segmentor.update_attention(attn_weight, is_cross)

                hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                hidden_states = hidden_states.to(query[OUT_INDEX].dtype)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)

                if input_ndim == 4:
                    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                if attn.residual_connection:
                    hidden_states = hidden_states + residual

                hidden_states = hidden_states / attn.rescale_output_factor

                return hidden_states

        def register_recr(net_, count, place_in_unet):
            if net_.__class__.__name__ == 'ResnetBlock2D':
                pass
            if net_.__class__.__name__ == 'Attention':
                net_.set_processor(AttentionProcessor(place_in_unet + f"_{count + 1}"))
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        cross_att_count = 0
        sub_nets = self.pipe.unet.named_children()
        for net in sub_nets:
            if "down" in net[0]:
                cross_att_count += register_recr(net[1], 0, "down")
            elif "up" in net[0]:
                cross_att_count += register_recr(net[1], 0, "up")
            elif "mid" in net[0]:
                cross_att_count += register_recr(net[1], 0, "mid")
