import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import sys
from typing import Any, Optional, Tuple, Union
sys.path.append("/home/liqingyu/xianxiaole/ConsistentID/LE-Adapter")
sys.path.append("/home/liqingyu/xianxiaole/ConsistentID/ControlNetPlus")
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.cm as cm

from torchvision import transforms
from PIL import Image
from transformers import CLIPImageProcessor
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, Blip2QFormerModel
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerEncoder
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions

from ip_adapter.ip_adapter_faceid import MLPProjModel
from ip_adapter.utils import is_torch2_available

from diffusers.models.attention import Attention

def visual_attn_map(vis_map, vis_token_index):
    """Render one token's attention map as a 512x512 rainbow-coloured image.

    Args:
        vis_map: 4-D attention tensor.  Entry ``1`` along the first axis is
            selected, the following axis is summed out, and the second-to-last
            axis is folded into a square grid, so its length must be a perfect
            square.  (Presumably laid out as ``(batch, heads, pixels, tokens)``
            with index 1 being the conditional half of a classifier-free
            guidance batch -- TODO confirm against callers.)
        vis_token_index: Index of the token whose map is rendered.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, resized to 512x512.

    Raises:
        ValueError: If ``vis_map`` is not 4-dimensional.  (Previously this
            surfaced as an ``UnboundLocalError`` further down.)
    """
    if vis_map.ndim != 4:
        raise ValueError(
            f"visual_attn_map expects a 4-D attention tensor, got shape {tuple(vis_map.shape)}"
        )

    side = int(vis_map.shape[-2] ** 0.5)
    token_map = (
        vis_map[1]
        .sum(dim=0)
        .transpose(0, 1)
        .reshape(-1, side, side)[vis_token_index]
        # detach: stored attention maps may still track gradients;
        # float: numpy has no bfloat16 and half precision normalises poorly.
        .detach()
        .float()
        .cpu()
        .numpy()
    )

    # Min-max normalise to 0..255; a constant map would otherwise divide 0 by 0.
    low = token_map.min()
    span = token_map.max() - low
    if span > 0:
        vis = ((token_map - low) / span * 255).astype(np.uint8)
    else:
        vis = np.zeros(token_map.shape, dtype=np.uint8)

    colormap = cm.get_cmap('rainbow')
    rainbow_img_array = colormap(vis / 255.0)
    # The colormap returns RGBA floats in [0, 1]; keep RGB and go back to bytes.
    rainbow_img_array = (rainbow_img_array[:, :, :3] * 255).astype(np.uint8)

    return Image.fromarray(rainbow_img_array).resize((512, 512)).convert("RGB")

def visual_rgb_img(img):
    """Convert a tensor into a viewable image after min-max normalisation.

    Args:
        img: A tensor.  A 3-D tensor is normalised to 0..255 and handed
            straight to ``Image.fromarray`` (so it presumably needs to be in
            height x width x channel order -- TODO confirm with callers).  A
            2-D tensor is normalised, coloured with the ``rainbow`` colormap
            and resized to 512x512.

    Returns:
        A ``PIL.Image.Image`` for 2-D and 3-D inputs.  For any other rank the
        detached numpy array is returned unchanged, as before.
    """

    def _to_uint8(values):
        # Min-max scale to 0..255.  A constant input has zero range, which used
        # to produce NaNs (0 / 0) before the uint8 cast; map it to all zeros.
        low = values.min()
        span = values.max() - low
        if span > 0:
            return ((values - low) / span * 255).astype(np.uint8)
        return np.zeros(values.shape, dtype=np.uint8)

    rgb_img = img.detach().cpu().numpy()
    rank = len(img.shape)

    if rank == 3:
        return Image.fromarray(_to_uint8(rgb_img))

    if rank == 2:
        scaled = _to_uint8(rgb_img)
        colormap = cm.get_cmap('rainbow')
        # The colormap returns RGBA floats in [0, 1]; keep RGB as bytes.
        rainbow_img_array = (colormap(scaled / 255.0)[:, :, :3] * 255).astype(np.uint8)
        return Image.fromarray(rainbow_img_array).resize((512, 512)).convert("RGB")

    return rgb_img

class TransformerProjmodel(Attention):
    """``Attention`` subclass that only exposes ``query_dim`` at construction.

    Every other ``Attention`` constructor argument keeps its library default.
    """

    def __init__(self, query_dim=None):
        super().__init__(query_dim=query_dim)
        

class TextProjModel(torch.nn.Module):
    """Project one embedding vector into several normalised context tokens.

    A single linear layer widens each ``clip_embeddings_dim`` vector to
    ``clip_extra_context_tokens * cross_attention_dim`` features, which are
    then split into ``clip_extra_context_tokens`` tokens and layer-normalised.
    """

    def __init__(self, cross_attention_dim=768, clip_embeddings_dim=768, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        widened_dim = clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, widened_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return tokens of shape ``(-1, clip_extra_context_tokens, cross_attention_dim)``."""
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(token_shape)
        return self.norm(tokens)


class MLPProjModel(torch.nn.Module):
    """Two-layer MLP mapping embeddings to the cross-attention width.

    The stack is Linear -> GELU -> Linear -> LayerNorm and is applied to the
    last dimension, so the leading dimensions of the input are preserved.

    NOTE: this definition shadows the ``MLPProjModel`` imported from
    ``ip_adapter.ip_adapter_faceid`` at the top of the file.
    """

    def __init__(self, cross_attention_dim=768, clip_embeddings_dim=768):
        super().__init__()

        layers = [
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, image_embeds):
        """Run the embeddings through the projection stack."""
        return self.proj(image_embeds)


class TPPortrait_Adapter(torch.nn.Module):
    """UNet wrapper that appends projected fine-grained text tokens to the prompt.

    The adapter owns three things: the wrapped ``unet``, a projection module
    (``ShapeText_Projection``) applied to the extra text embeddings, and
    ``adapter_modules`` -- a ``ModuleList`` over the UNet's attention
    processors so their weights can be saved and loaded together.

    Args:
        unet: UNet exposing ``attn_processors`` and accepting ControlNet-style
            ``down_block_additional_residuals`` / ``mid_block_additional_residual``.
        cross_attention_dim: Width used for both the input and the output of
            the projection module.
        ckpt_path: Optional checkpoint to load right after construction.
        pool_token: If true, use ``TextProjModel`` for the projection;
            otherwise use ``MLPProjModel``.
    """

    def __init__(self, unet, cross_attention_dim=768, ckpt_path=None, pool_token = False):
        super().__init__()
        self.unet = unet
        # TODO: replace MLP with Q-former's learnable queries
        if pool_token:
            self.ShapeText_Projection = TextProjModel(cross_attention_dim, cross_attention_dim)
        else:
            self.ShapeText_Projection = MLPProjModel(cross_attention_dim, cross_attention_dim)
        self.adapter_modules = torch.nn.ModuleList(self.unet.attn_processors.values())

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self,
        noisy_latents,
        timesteps,
        encoder_hidden_states,
        text_shape_embeds,
        down_block_res_samples=None,
        mid_block_res_sample=None,
        ):
        """Predict the noise residual for ``noisy_latents``.

        ``text_shape_embeds`` is projected and concatenated after
        ``encoder_hidden_states`` along dim 1 before the UNet is called.
        Optional ControlNet residuals are cast to the dtype of
        ``noisy_latents``; the previous code cast them to a ``weight_dtype``
        name that is not defined in this module.

        Returns:
            The first element of the UNet's tuple output.
        """
        text_shape_tokens = self.ShapeText_Projection(text_shape_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, text_shape_tokens], dim=1)

        residual_dtype = noisy_latents.dtype
        down_residuals = None
        if down_block_res_samples is not None:
            down_residuals = [sample.to(dtype=residual_dtype) for sample in down_block_res_samples]
        mid_residual = None
        if mid_block_res_sample is not None:
            mid_residual = mid_block_res_sample.to(dtype=residual_dtype)

        # Predict the noise residual
        noise_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states,
            down_block_additional_residuals=down_residuals,
            mid_block_additional_residual=mid_residual,
            return_dict = False,
            )[0]

        return noise_pred

    def load_from_checkpoint(self, ckpt_path: str):
        """Load projection and attention-processor weights from ``ckpt_path``.

        The checkpoint must contain the keys ``"ShapeText_Projection"`` and
        ``"TP_model"``.  A parameter checksum is taken before and after
        loading, and the method asserts that both checksums changed.
        """
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ShapeText_Projection.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for the projection model and adapter_modules
        self.ShapeText_Projection.load_state_dict(state_dict["ShapeText_Projection"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["TP_model"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ShapeText_Projection.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

    def finegrained_encoder_setting(self, text_encoder, tokenizer):
        """Attach the encoder and tokenizer that ``generate`` uses for the fine-grained prompt."""
        self.TP_Encoder = text_encoder
        self.TP_tokenizer = tokenizer

    @staticmethod
    def _expand_to_batch(embeds, batch_size):
        """Broadcast a batch-of-one embedding to ``batch_size`` rows; leave others untouched."""
        if embeds.shape[0] == 1 and batch_size > 1:
            return embeds.expand(batch_size, *embeds.shape[1:])
        return embeds

    def generate(
            self,
            pipe,
            faceid_embeds=None,
            global_prompt=None,
            finegrained_prompt=None,
            negative_prompt=None,
            scale=1.0,
            num_samples=1,
            all_tokens=True,
            seed=None,
            generator=None,
            guidance_scale=7.5,
            num_inference_steps=30,
            s_scale=1.0,
            **kwargs,
        ):
        """Run ``pipe`` with the global prompt extended by fine-grained tokens.

        Args:
            pipe: Pipeline providing ``encode_prompt``, ``device`` and a
                ``__call__`` that accepts ``prompt_embeds`` /
                ``negative_prompt_embeds``.
            faceid_embeds, scale, s_scale: Accepted for interface
                compatibility; not used by this method.
            global_prompt: Prompt encoded with ``pipe.encode_prompt``.
            finegrained_prompt: Prompt encoded with the encoder attached via
                ``finegrained_encoder_setting``; an empty string is used as
                its negative counterpart.
            negative_prompt: Negative prompt for the global encoder.
            num_samples: Passed to ``encode_prompt`` as
                ``num_images_per_prompt``.  A single fine-grained embedding is
                broadcast to the resulting batch size so the concatenation no
                longer fails when ``num_samples > 1``.
            all_tokens: Selects element ``1`` (true) or ``0`` (false) of the
                ``encode_text`` result.
            seed: Used to build a generator on ``pipe.device`` when
                ``generator`` is not supplied (previously ignored).
            generator: Explicit generator; takes precedence over ``seed``.
            guidance_scale, num_inference_steps, **kwargs: Forwarded to ``pipe``.

        Returns:
            The ``.images`` attribute of the pipeline output.
        """
        # Fail fast, before any prompt encoding work is done.
        assert (hasattr(self, "TP_Encoder") and hasattr(self, "TP_tokenizer")), "Enocder not exisit!"

        if generator is None and seed is not None:
            generator = torch.Generator(device=pipe.device).manual_seed(seed)

        with torch.inference_mode():
            global_prompt_embeds_, global_negative_prompt_embeds_ = pipe.encode_prompt(
                global_prompt,
                device=pipe.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

            output_index = 1 if all_tokens else 0
            finegrained_embeds_ = self.TP_Encoder.encode_text(self.TP_tokenizer(finegrained_prompt))[output_index]
            finegrained_negative_prompt_embeds_ = self.TP_Encoder.encode_text(self.TP_tokenizer(""))[output_index]

            finegrained_embeds_ = self.ShapeText_Projection(finegrained_embeds_.to(pipe.device, global_prompt_embeds_.dtype))
            finegrained_negative_prompt_embeds_ = self.ShapeText_Projection(finegrained_negative_prompt_embeds_.to(pipe.device, global_prompt_embeds_.dtype))

            # encode_prompt is asked for num_samples images per prompt while
            # the fine-grained encoder is called once; align the batch sizes.
            finegrained_embeds_ = self._expand_to_batch(
                finegrained_embeds_, global_prompt_embeds_.shape[0]
            )
            finegrained_negative_prompt_embeds_ = self._expand_to_batch(
                finegrained_negative_prompt_embeds_, global_negative_prompt_embeds_.shape[0]
            )

            prompt_embeds = torch.cat([global_prompt_embeds_, finegrained_embeds_], dim=1)
            negative_prompt_embeds = torch.cat([global_negative_prompt_embeds_, finegrained_negative_prompt_embeds_], dim=1)

        # num_images_per_prompt is deliberately not forwarded: the embeds
        # already carry num_samples rows.
        images = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images

def TPFace_set_attnprocessor(unet, lora_rank, num_tokens=77):
    """Install the custom attention processors on every attention layer of ``unet``.

    Layers whose processor name ends in ``attn1.processor`` receive an
    ``AttnProcessor2_0``.  Every other layer receives a
    ``LoRAIPAttnProcessor2_0`` whose ``to_k_ip`` / ``to_v_ip`` weights are
    initialised from that layer's existing ``to_k`` / ``to_v`` weights.

    Args:
        unet: Model exposing ``state_dict()``, ``attn_processors``,
            ``config.cross_attention_dim``, ``config.block_out_channels`` and
            ``set_attn_processor``.
        lora_rank: Forwarded as ``rank`` to ``LoRAIPAttnProcessor2_0``.
        num_tokens: Number of trailing context tokens routed through the
            extra key/value projections.

    Raises:
        ValueError: If a processor name does not start with ``mid_block``,
            ``up_blocks`` or ``down_blocks``.  Previously such a name reused
            the hidden size of the preceding layer, or failed with an
            ``UnboundLocalError`` if it came first.
    """
    unet_sd = unet.state_dict()
    block_channels = unet.config.block_out_channels
    attn_procs = {}

    for name in unet.attn_processors:
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

        if name.startswith("mid_block"):
            hidden_size = block_channels[-1]
        elif name.startswith("up_blocks"):
            # Names look like "up_blocks.<idx>. ..."; take the whole index field
            # rather than a single character.
            block_id = int(name.split(".")[1])
            hidden_size = list(reversed(block_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            hidden_size = block_channels[block_id]
        else:
            raise ValueError(f"Cannot infer hidden size for attention processor '{name}'")

        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            continue

        layer_name = name.split(".processor")[0]
        weights = {
            "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
            "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
        }
        processor = LoRAIPAttnProcessor2_0(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            num_tokens=num_tokens,
            rank=lora_rank,
        )
        # strict=False: only the two projection weights are supplied.
        processor.load_state_dict(weights, strict=False)
        attn_procs[name] = processor

    unet.set_attn_processor(attn_procs)


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor built on ``F.scaled_dot_product_attention`` (PyTorch 2.0+).

    Besides computing attention, each call stores the per-head query tensor on
    the processor as ``self.selfattn_query``.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        skip_connection = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, pixels, channels) and restored at the end.
        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Self-attention reuses the hidden states as context.
        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query, key, value = (
            proj.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
            for proj in (query, key, value)
        )

        self.selfattn_query = query

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        attended = attended.to(query.dtype)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(attended))

        if is_spatial:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip_connection

        return hidden_states / attn.rescale_output_factor


class LoRAIPAttnProcessor2_0(nn.Module):
    r"""
    Cross-attention processor with a second key/value branch for extra tokens.

    When the context holds at least ``77 + num_tokens`` tokens, the trailing
    ``num_tokens`` tokens are split off and attended to through the
    processor's own ``to_k_ip`` / ``to_v_ip`` projections; that result is
    added, weighted by ``self.scale``, to the ordinary attention output.
    Shorter contexts and self-attention calls take the ordinary path only.

    Each call also stores tensors on the processor for inspection:
    ``crosskey``, ``crossquery``, ``attn_map_global`` (softmax of the raw,
    unscaled ``Q @ K^T``, still attached to the autograd graph) and, when the
    extra branch runs, ``attn_map_TP`` (detached).

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`; falls back
            to ``hidden_size`` when omitted.
        rank (`int`, defaults to 4):
            Stored on the instance; no LoRA layers are created here.
        network_alpha (`int`, *optional*):
            Accepted for interface compatibility; unused.
        lora_scale (`float`, defaults to 1.0):
            Stored on the instance; unused by ``__call__``.
        scale (`float`, defaults to 1.0):
            Weight of the extra-token attention branch.
        num_tokens (`int`, defaults to 4):
            Number of trailing context tokens routed to the extra branch.
    """

    # Minimum number of leading context tokens that must remain after the
    # extra tokens are split off (presumably the CLIP text length -- TODO confirm).
    _MIN_TEXT_TOKENS = 77

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale
        self.num_tokens = num_tokens

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None, *args, **kwargs,
    ):
        """Compute attention for ``attn``; the ``scale`` argument is accepted but unused."""
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # scaled_dot_product_attention expects (batch, heads, source_length,
            # target_length); same reshape as AttnProcessor2_0 applies.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # Stays None unless extra tokens are actually split off.  This replaces
        # the later `end_pos >= 77` re-check, which raised UnboundLocalError
        # whenever encoder_hidden_states was None.
        ip_hidden_states = None

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            if end_pos >= self._MIN_TEXT_TOKENS:
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    encoder_hidden_states[:, end_pos:, :],
                )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        query = attn.to_q(hidden_states)

        # for text
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        self.crosskey = key
        self.crossquery = query

        # Kept exactly as before: unscaled logits, not detached.
        self.attn_map_global = (query @ key.transpose(-2, -1)).softmax(dim=-1)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            self.attn_map_TP = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1).clone().detach()

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
