from typing import List
import torch
from torchvision import transforms
from transformers import CLIPVisionModel, CLIPImageProcessor
from PIL import Image
import torch.nn.functional as F
import torch.nn as nn
import os

def is_torch2_available():
    """Return True when ``torch.nn.functional`` exposes ``scaled_dot_product_attention``.

    That fused attention kernel was added in PyTorch 2.0; the result selects which
    attention-processor variants are imported below.
    """
    try:
        F.scaled_dot_product_attention
    except AttributeError:
        return False
    return True

if is_torch2_available():
    from .attention_processor import SCAttnProcessor2_0 as SCAttnProcessor, AttnProcessor2_0 as AttnProcessor, adapterprocessor
    from .attention_processor import Multi_SCAttnProcessor2_0 as Multi_SCAttnProcessor
else:
    from .attention_processor import SCAttnProcessor, AttnProcessor

from .attention_processor import AttnProcessor_map, AttnProcessor_mask, adapterprocsser1_0 as adapterprocessor

from .attention_mul_ip import MixIT_AttnProcessor, KiVt_AttnProcessor, TMaskIP_AttnProcessor

from diffusers.models.attention_processor import Attention


class adapter(torch.nn.Module):
    """Cross-attention adapter: text tokens (queries) attend over image features.

    Queries have width 768; ``encoder_hidden_states`` have width 1024 (8 heads x 64
    dims). The attention output is passed through a LayerNorm over 768 features.
    ``to_v`` and ``to_out`` of the wrapped ``Attention`` are set to ``None``;
    presumably the processor installed on ``self.attn`` supplies its own value /
    output projections — confirm against that processor.
    """

    def __init__(self):
        super().__init__()
        hidden = 768
        self.attn = Attention(
            query_dim=hidden,
            cross_attention_dim=1024,
            heads=8,
            dim_head=64,
            dropout=0.0,
        )
        self.norm = torch.nn.LayerNorm(hidden)
        # Strip the value and output projections (original note: "for paral train").
        for attr_name in ("to_v", "to_out"):
            setattr(self.attn, attr_name, None)

    def forward(self, text_embeds, image_embeds):
        """Return ``norm(attn(text_embeds, encoder_hidden_states=image_embeds))``."""
        return self.norm(self.attn(text_embeds, encoder_hidden_states=image_embeds))


class SC_Adapter(torch.nn.Module):
    """SC-Adapter: turns reference images + concept text into extra prompt tokens.

    The constructor replaces every attention processor of ``unet``: self-attention
    ("attn1") layers get a plain ``AttnProcessor`` and cross-attention layers get a
    ``TMaskIP_AttnProcessor``. It also builds an ``adapter`` that maps concept-text
    embeddings plus CLIP hidden states to tokens that the ``generate*`` methods append
    to the encoded prompt. ``get_pipe`` must be called first: all generation helpers
    read ``self.pipe``.
    """

    def __init__(self, unet, image_encoder_path, device="cuda", dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype

        # CLIP vision tower; selected hidden states feed the adapter.
        self.image_encoder = CLIPVisionModel.from_pretrained(image_encoder_path).to(self.device, dtype=self.dtype)
        self.clip_image_processor = CLIPImageProcessor()
        # PIL mask -> tensor, used by get_image_embeds_mask. Created here (previously
        # only inside generate_mask) so that method also works when called directly.
        self.mask_trans = transforms.ToTensor()

        # Install SC layers in the UNet.
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = TMaskIP_AttnProcessor(
                    hidden_size=self._block_hidden_size(unet, name),
                    cross_attention_dim=cross_attention_dim,
                    scale=1,
                    use_orig_kv=False,
                ).to(self.device, dtype=self.dtype)
        unet.set_attn_processor(attn_procs)
        # Processors in unet.attn_processors order; this order defines the
        # "sc_layers.<idx>" keys used by load_sc_adapter.
        adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

        # Text->image adapter with its custom attention processor.
        self.sc_adapter = adapter()
        adapter_procs = adapterprocessor(query_dim=768, inner_dim=512, cross_attention_dim=1024)
        self.sc_adapter.attn.set_processor(adapter_procs)

        self.sc_layers = adapter_modules
        self.sc_adapter.to(self.device, dtype=self.dtype)
        self.sc_layers.to(self.device, dtype=self.dtype)

    @staticmethod
    def _block_hidden_size(unet, name):
        """Return the channel width of the UNet block owning attention processor ``name``.

        Raises:
            ValueError: ``name`` is not under ``mid_block``/``up_blocks``/``down_blocks``
                (previously this silently reused the previous layer's width).
        """
        channels = unet.config.block_out_channels
        if name.startswith("mid_block"):
            return channels[-1]
        if name.startswith("up_blocks"):
            return list(reversed(channels))[int(name.split(".")[1])]
        if name.startswith("down_blocks"):
            return channels[int(name.split(".")[1])]
        raise ValueError(f"Cannot infer hidden size for attention processor {name!r}")

    # ------------------------------------------------------------------ weights

    def _load_weights(self, adapter_path, layers_path, strict_adapter):
        """Load the adapter and SC-layer state dicts from disk (mapped to ``self.device``)."""
        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
        self.sc_adapter.load_state_dict(torch.load(adapter_path, map_location=self.device), strict=strict_adapter)
        self.sc_layers.load_state_dict(torch.load(layers_path, map_location=self.device))

    def load_sc_adapter(self, adapter_path, layers_path):
        """Load adapter and SC-layer weights, requiring an exact key match for both."""
        self._load_weights(adapter_path, layers_path, strict_adapter=True)

    def load_ori_sc_adapter(self, adapter_path, layers_path):
        """Like ``load_sc_adapter`` but tolerates missing/unexpected adapter keys (strict=False)."""
        self._load_weights(adapter_path, layers_path, strict_adapter=False)

    def load_encoder(self, path):
        """Load ``<path>/pytorch_model.bin`` into the image encoder (non-strict)."""
        self.image_encoder.load_state_dict(
            torch.load(os.path.join(path, "pytorch_model.bin"), map_location=self.device), strict=False)

    def forward(self, encoder_hidden_states, image_embeds):
        """Run the adapter: text embeddings attend over image features."""
        encoder_hidden_states = self.sc_adapter(text_embeds=encoder_hidden_states, image_embeds=image_embeds)
        return encoder_hidden_states

    def get_pipe(self, pipe):
        """Attach the diffusers pipeline used by the generation helpers."""
        self.pipe = pipe

    def set_scale(self, scale):
        """Set ``scale`` on every SC attention processor in ``self.pipe.unet``."""
        # __init__ installs TMaskIP_AttnProcessor. Matching only SCAttnProcessor skipped
        # those processors (unless TMaskIP subclasses it), so generate*(scale=...) was ignored.
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, (SCAttnProcessor, TMaskIP_AttnProcessor)):
                attn_processor.scale = scale

    # ------------------------------------------------------- generation helpers

    @staticmethod
    def _broadcast_prompts(pil_image, prompt, negative_prompt):
        """Repeat a non-list prompt / negative prompt once per reference image."""
        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts
        return prompt, negative_prompt

    @staticmethod
    def _repeat_for_samples(embeds, num_samples):
        """Tile a (batch, seq, dim) tensor to (batch * num_samples, seq, dim)."""
        bs_embed, seq_len, _ = embeds.shape
        return embeds.repeat(1, num_samples, 1).view(bs_embed * num_samples, seq_len, -1)

    def _append_image_tokens(self, prompt, negative_prompt, num_samples,
                             image_prompt_embeds, uncond_image_prompt_embeds):
        """Encode the prompts with CFG and append the image tokens along the sequence axis.

        Returns:
            ``(prompt_embeds, negative_prompt_embeds)`` ready for ``self.pipe``.
        """
        with torch.inference_mode():
            prompt_embeds = self.pipe._encode_prompt(
                prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True,
                negative_prompt=negative_prompt)
            # _encode_prompt with CFG returns [negative, positive] stacked on dim 0.
            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
        return prompt_embeds, negative_prompt_embeds

    def _sample(self, prompt_embeds, negative_prompt_embeds, seed, guidance_scale, num_inference_steps,
                height, width, **pipe_kwargs):
        """Run ``self.pipe`` on precomputed embeddings and return its ``.images``."""
        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        return self.pipe(
            height=height,
            width=width,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **pipe_kwargs,
        ).images

    def _run_pipe(self, prompt, negative_prompt, image_prompt_embeds, uncond_image_prompt_embeds,
                  num_samples, seed, guidance_scale, num_inference_steps, height, width, **pipe_kwargs):
        """Shared tail of generate / generate_control / generate_mask."""
        # NOTE(review): get_image_embeds already encodes the concept text with
        # num_images_per_prompt=num_samples, so for num_samples > 1 the embeds may be
        # expanded twice here — confirm batch sizes line up with the prompt embeds.
        image_prompt_embeds = self._repeat_for_samples(image_prompt_embeds, num_samples)
        uncond_image_prompt_embeds = self._repeat_for_samples(uncond_image_prompt_embeds, num_samples)
        prompt_embeds, negative_prompt_embeds = self._append_image_tokens(
            prompt, negative_prompt, num_samples, image_prompt_embeds, uncond_image_prompt_embeds)
        return self._sample(prompt_embeds, negative_prompt_embeds, seed, guidance_scale,
                            num_inference_steps, height, width, **pipe_kwargs)

    # ------------------------------------------------------------ public API

    def generate(
            self,
            pil_image,
            concept,
            uncond_concept=" ",
            prompt=" ",
            negative_prompt=" ",
            scale=1.0,
            num_samples=1,
            seed=None,
            guidance_scale=7.5,
            num_inference_steps=30,
            is_style=True,
            height=640,
            width=640,
    ):
        """Generate images conditioned on ``prompt`` plus adapter tokens from ``pil_image``.

        ``concept`` / ``uncond_concept`` are the texts fed to the adapter for the
        conditional / unconditional branch; ``is_style`` picks an all-ones (True) or
        all-zeros (False) image as the unconditional image input.
        """
        self.set_scale(scale)
        prompt, negative_prompt = self._broadcast_prompts(pil_image, prompt, negative_prompt)
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image, num_samples, concept, uncond_concept, is_style)
        return self._run_pipe(prompt, negative_prompt, image_prompt_embeds, uncond_image_prompt_embeds,
                              num_samples, seed, guidance_scale, num_inference_steps, height, width)

    def generate_control(
            self,
            pil_image,
            control_img,
            concept,
            uncond_concept=" ",
            prompt=" ",
            negative_prompt=" ",
            scale=1.0,
            num_samples=1,
            seed=None,
            guidance_scale=7.5,
            num_inference_steps=30,
            is_style=True,
            height=640,
            width=640,
    ):
        """Like ``generate`` but also passes ``control_img`` (ControlNet conditioning scale 1.0)."""
        self.set_scale(scale)
        prompt, negative_prompt = self._broadcast_prompts(pil_image, prompt, negative_prompt)
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image, num_samples, concept, uncond_concept, is_style)
        return self._run_pipe(prompt, negative_prompt, image_prompt_embeds, uncond_image_prompt_embeds,
                              num_samples, seed, guidance_scale, num_inference_steps, height, width,
                              image=control_img, controlnet_conditioning_scale=1.)

    def generate_mask(
            self,
            mask,
            pil_image,
            concept,
            uncond_concept=" ",
            prompt=" ",
            negative_prompt=" ",
            scale=1.0,
            num_samples=1,
            seed=None,
            guidance_scale=7.5,
            num_inference_steps=30,
            is_style=True,
            height=640,
            width=640,
    ):
        """Like ``generate`` but computes the adapter tokens with a binarised ``mask``."""
        self.set_scale(scale)
        prompt, negative_prompt = self._broadcast_prompts(pil_image, prompt, negative_prompt)
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_mask(
            pil_image, num_samples, concept, uncond_concept, is_style, mask)
        return self._run_pipe(prompt, negative_prompt, image_prompt_embeds, uncond_image_prompt_embeds,
                              num_samples, seed, guidance_scale, num_inference_steps, height, width)

    def generate_multi(
            self,
            pil_image_list,
            concept_list,
            uncond_concept=None,
            prompt=None,
            negative_prompt=None,
            scale=None,
            num_samples=1,
            seed=None,
            guidance_scale=7.5,
            num_inference_steps=30,
            is_style=True,
            height=640,
            width=640,
    ):
        """Generate from several reference images, weighting each image's tokens.

        ``scale[i]`` multiplies the adapter tokens of image ``i``; only the first
        ``len(scale)`` token groups are kept. ``None`` for ``uncond_concept``,
        ``prompt``, ``negative_prompt`` or ``scale`` means an empty list (the former defaults).
        """
        # Fresh lists per call instead of shared mutable defaults.
        uncond_concept = [] if uncond_concept is None else uncond_concept
        prompt = [] if prompt is None else prompt
        negative_prompt = [] if negative_prompt is None else negative_prompt
        scale = [] if scale is None else scale

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image_list, num_samples, concept_list, uncond_concept, is_style)
        image_prompt_embeds = image_prompt_embeds.view(1, -1, 768)
        # NOTE(review): assumes each reference image contributes 77 * 6 consecutive
        # tokens after the view above — confirm against the adapter processor.
        tokens_per_image = 77 * 6
        image_prompt_embeds = torch.cat(
            [weight * image_prompt_embeds[:, i * tokens_per_image:(i + 1) * tokens_per_image, :]
             for i, weight in enumerate(scale)],
            dim=1,
        )
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(1, -1, 768)

        prompt_embeds, negative_prompt_embeds = self._append_image_tokens(
            prompt, negative_prompt, num_samples, image_prompt_embeds, uncond_image_prompt_embeds)
        return self._sample(prompt_embeds, negative_prompt_embeds, seed, guidance_scale,
                            num_inference_steps, height, width)

    # --------------------------------------------------------- embedding helpers

    def _preprocess_images(self, pil_image):
        """CLIP-preprocess one PIL image or a sequence of them into one batch on ``self.device``."""
        images = [pil_image] if isinstance(pil_image, Image.Image) else pil_image
        return torch.cat(
            [self.clip_image_processor(images=pil, return_tensors="pt").pixel_values.to(self.device, dtype=self.dtype)
             for pil in images],
            dim=0,
        )

    def _binarize_masks(self, masks):
        """Convert one PIL mask or a list of them to a 0/1 tensor batch (threshold 0.5)."""
        if isinstance(masks, Image.Image):
            masks = [masks]
        binarized = []
        for mask in masks:
            tensor_image = self.mask_trans(mask).unsqueeze(0)
            binarized.append(torch.where(tensor_image > 0.5, torch.ones_like(tensor_image),
                                         torch.zeros_like(tensor_image)))
        return torch.cat(binarized, dim=0)

    def _encode_concept(self, text, uncond_text, num_samples):
        """Encode ``text`` / ``uncond_text`` with CFG; return ``(cond, uncond)`` embeddings."""
        prompt_embeds = self.pipe._encode_prompt(
            text, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True,
            negative_prompt=uncond_text)
        negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
        return prompt_embeds_, negative_prompt_embeds_

    def _clip_hidden_states(self, pixel_values):
        """Return every 4th CLIP hidden state, starting at index 4."""
        return self.image_encoder(pixel_values, output_hidden_states=True)['hidden_states'][4::4]

    @staticmethod
    def _uncond_pixels(clip_image, is_style):
        """Unconditional image input: all ones for style, all zeros otherwise."""
        return torch.ones_like(clip_image) if is_style else torch.zeros_like(clip_image)

    def get_attnmaps(
            self,
            pil_image,
            concept,
            name,
    ):
        """Run the adapter once with ``AttnProcessor_map(name)`` installed.

        Returns None; whatever ``AttnProcessor_map`` records is its side effect. The
        adapter's own processor is restored afterwards (previously it stayed replaced).
        """
        clip_image = self._preprocess_images(pil_image)
        attn = self.sc_adapter.attn
        previous_processor = attn.processor
        # Diagnostic pass only: no autograd graph is needed.
        with torch.no_grad():
            prompt_embeds_, _ = self._encode_concept(concept, "", 1)
            clip_image_embeds = self._clip_hidden_states(clip_image)
            attn.set_processor(AttnProcessor_map(name))
            try:
                self.sc_adapter(prompt_embeds_, clip_image_embeds)
            finally:
                attn.set_processor(previous_processor)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image, num_samples, text, uncond_text, is_style):
        """Compute conditional and unconditional adapter tokens for ``pil_image``.

        Returns:
            ``(image_prompt_embeds, uncond_image_prompt_embeds)``.
        """
        clip_image = self._preprocess_images(pil_image)
        prompt_embeds_, negative_prompt_embeds_ = self._encode_concept(text, uncond_text, num_samples)

        image_prompt_embeds = self.sc_adapter(prompt_embeds_, self._clip_hidden_states(clip_image))
        uncond_clip_image_embeds = self._clip_hidden_states(self._uncond_pixels(clip_image, is_style))
        uncond_image_prompt_embeds = self.sc_adapter(negative_prompt_embeds_, uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds

    @torch.inference_mode()
    def get_image_embeds_mask(self, pil_image, num_samples, text, uncond_text, is_style, masks):
        """Like ``get_image_embeds`` but runs the adapter with ``AttnProcessor_mask``.

        The mask processor is used for both branches and the adapter's previous
        processor is restored afterwards, so later unmasked calls are unaffected.
        """
        clip_image = self._preprocess_images(pil_image)
        mask_image = self._binarize_masks(masks)
        prompt_embeds_, negative_prompt_embeds_ = self._encode_concept(text, uncond_text, num_samples)
        clip_image_embeds = self._clip_hidden_states(clip_image)

        from .attention_processor import get_attention_scores_mask
        attn = self.sc_adapter.attn
        attn.get_attention_scores_mask = get_attention_scores_mask.__get__(attn)
        previous_processor = attn.processor
        attn.set_processor(AttnProcessor_mask(mask_image))
        try:
            image_prompt_embeds = self.sc_adapter(prompt_embeds_, clip_image_embeds)
            uncond_clip_image_embeds = self._clip_hidden_states(self._uncond_pixels(clip_image, is_style))
            uncond_image_prompt_embeds = self.sc_adapter(negative_prompt_embeds_, uncond_clip_image_embeds)
        finally:
            # set_processor drops a trained nn.Module processor when replacing it with a
            # non-module, so put the original back.
            attn.set_processor(previous_processor)
        return image_prompt_embeds, uncond_image_prompt_embeds