import torch
import torch.nn.functional as F

def is_torch2_available():
    """Return True when ``torch.nn.functional`` exposes ``scaled_dot_product_attention``.

    Used to choose between the PyTorch 2.x (``*2_0``) attention processors and the
    legacy ones.
    """
    try:
        F.scaled_dot_product_attention
    except AttributeError:
        return False
    return True

if is_torch2_available():
    from modules.attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
    from modules.attention_processor import (
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
    from modules.attention_processor_decoupled import IPAttnProcessorDecoupled2_0 as IPAttnProcessorDecoupled
else:
    from modules.attention_processor import AttnProcessor, IPAttnProcessor
    from modules.attention_processor_decoupled import IPAttnProcessorDecoupled

from modules.image_projection import MLPProjModel, Resampler, ImageProjection, ProjPlusModel, MLPProjModelCLIP


def init_proj(instance):
    """Build the image-projection model(s) for ``instance.ip_mode``.

    Always reads ``instance.ip_mode``, ``instance.unet.config.cross_attention_dim``,
    ``instance.device`` and ``instance.vae.dtype``. Depending on the mode it also
    reads ``instance.num_tokens``, ``instance.image_encoder.config.hidden_size``
    (modes that consume CLIP image features) and ``instance.resampler_depth``
    (``'faceid-decoupled'`` only).

    Returns:
        The projection module, moved to ``instance.device`` with the VAE's dtype.
        For ``'faceid-decoupled'`` a tuple ``(id_projection, clip_resampler)`` is
        returned instead; in that mode ``instance.num_tokens`` is indexed as
        ``[0]`` (ID branch) and ``[1]`` (CLIP branch).

    Raises:
        ValueError: if ``instance.ip_mode`` is not a supported mode.
    """
    mode = instance.ip_mode
    supported_modes = (
        'faceid', 'faceid-lora', 'portrait', 'full_face', 'plus', 'vanilla',
        'faceid-plus', 'faceid-plus-lora', 'faceid-decoupled',
    )
    # Validate first so an unsupported mode fails before any model attribute is read.
    if mode not in supported_modes:
        raise ValueError(
            f"Not supported type of IP-Adapter Image Projection Model: {mode!r}"
        )

    cross_attention_dim = instance.unet.config.cross_attention_dim
    device, dtype = instance.device, instance.vae.dtype
    # Value passed as ``id_embeddings_dim`` to every face-ID projection below.
    id_embeddings_dim = 512

    if mode in ('faceid', 'faceid-lora', 'portrait'):
        return MLPProjModel(
            cross_attention_dim=cross_attention_dim,
            id_embeddings_dim=id_embeddings_dim,
            num_tokens=instance.num_tokens,
        ).to(device, dtype=dtype)

    if mode == 'full_face':
        return MLPProjModelCLIP(
            cross_attention_dim=cross_attention_dim,
            clip_embeddings_dim=instance.image_encoder.config.hidden_size,
        ).to(device, dtype=dtype)

    if mode in ('plus', 'vanilla'):
        # Both projector classes are constructed with the same Resampler-style kwargs.
        img_projector = Resampler if mode == 'plus' else ImageProjection
        return img_projector(
            dim=cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=instance.num_tokens,
            embedding_dim=instance.image_encoder.config.hidden_size,
            output_dim=cross_attention_dim,
            ff_mult=4,
        ).to(device, dtype=dtype)

    if mode in ('faceid-plus', 'faceid-plus-lora'):
        return ProjPlusModel(
            cross_attention_dim=cross_attention_dim,
            id_embeddings_dim=id_embeddings_dim,
            clip_embeddings_dim=instance.image_encoder.config.hidden_size,
            num_tokens=instance.num_tokens,
        ).to(device, dtype=dtype)

    # Only 'faceid-decoupled' remains: one ID projection plus one CLIP resampler.
    id_projection = MLPProjModel(
        cross_attention_dim=cross_attention_dim,
        id_embeddings_dim=id_embeddings_dim,
        num_tokens=instance.num_tokens[0],
    ).to(device, dtype=dtype)

    clip_resampler = Resampler(
        dim=cross_attention_dim,
        depth=instance.resampler_depth,
        dim_head=64,
        heads=12,
        num_queries=instance.num_tokens[1],
        embedding_dim=instance.image_encoder.config.hidden_size,
        output_dim=cross_attention_dim,
        ff_mult=4,
    ).to(device, dtype=dtype)

    return id_projection, clip_resampler

def set_ip_adapter(instance, device, store_attn=False, hook_attn_key=None, store_qk=False, hook_qk_key=None):
    """Install attention processors on ``instance.unet``, injecting IP-Adapter layers.

    For every processor name in ``instance.unet.attn_processors`` the first
    matching rule wins:

    1. name ends with ``f"{hook_attn_key}.processor"`` and ``store_attn`` is truthy
       -> ``AttnProcessor(store_attn_maps=store_attn)``;
    2. name ends with ``f"{hook_qk_key}.processor"`` and ``store_qk`` is truthy
       -> ``AttnProcessor(store_qk=store_qk)``;
    3. name does not end with ``"attn1_7.processor"`` -> plain ``AttnProcessor()``;
    4. otherwise an IP-Adapter processor is built (decoupled for
       ``ip_mode == 'faceid-decoupled'``). Its ``to_k_ip*`` / ``to_v_ip*``
       parameters are frozen.

    Args:
        instance: object exposing ``unet``. IP layers additionally need
            ``ip_mode``, ``num_tokens`` and ``n_cond``.
        device: device the IP-Adapter processors are moved to.
        store_attn, hook_attn_key: enable attention-map storing on matching processors.
        store_qk, hook_qk_key: enable query/key storing on matching processors.

    Raises:
        ValueError: if an IP processor's name does not start with ``mid_block``,
            ``up_blocks`` or ``down_blocks``, so its hidden size cannot be derived.
    """
    # Suffix of the processor names that receive IP-Adapter layers.
    ipa_keyname = "attn1_7.processor"

    attn_procs = {}
    hook_count = 0
    for name in instance.unet.attn_processors.keys():
        if name.endswith(f"{hook_attn_key}.processor") and store_attn:
            attn_procs[name] = AttnProcessor(store_attn_maps=store_attn)
            hook_count += 1
        elif name.endswith(f"{hook_qk_key}.processor") and store_qk:
            attn_procs[name] = AttnProcessor(store_qk=store_qk)
            hook_count += 1
        elif not name.endswith(ipa_keyname):
            attn_procs[name] = AttnProcessor()
        else:
            attn_procs[name] = _build_frozen_ip_processor(instance, name, device)

    if hook_count > 0:
        print(f"Set {hook_count} {hook_attn_key}s with Attn Storing")
    instance.unet.set_attn_processor(attn_procs)


def _unet_block_hidden_size(unet_config, name):
    """Return the ``block_out_channels`` width of the UNet block owning processor ``name``."""
    channels = unet_config.block_out_channels
    if name.startswith("mid_block"):
        return channels[-1]
    if name.startswith("up_blocks"):
        # Parse the whole index component ("up_blocks.12...."), not a single character.
        block_id = int(name.split(".")[1])
        # Up blocks walk the channel list in reverse order.
        return list(reversed(channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name.split(".")[1])
        return channels[block_id]
    raise ValueError(f"Cannot infer hidden size for attention processor {name!r}")


def _build_frozen_ip_processor(instance, name, device):
    """Create the IP-Adapter processor for ``name`` with its IP key/value layers frozen."""
    hidden_size = _unet_block_hidden_size(instance.unet.config, name)
    cross_attention_dim = instance.unet.config.cross_attention_dim

    if instance.ip_mode in ['faceid-decoupled']:
        # NOTE(review): init_proj indexes num_tokens as [0]/[1] in this mode. If it
        # is a list, ``* n_cond`` repeats the list rather than scaling each entry;
        # confirm that is what IPAttnProcessorDecoupled expects.
        processor = IPAttnProcessorDecoupled(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            scale=[0.7, 0.3],
            num_tokens=instance.num_tokens * instance.n_cond,
        )
        frozen_layers = ("to_k_ip_1", "to_v_ip_1", "to_k_ip_2", "to_v_ip_2")
    else:
        processor = IPAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            scale=1.0,
            num_tokens=instance.num_tokens * instance.n_cond,
        )
        frozen_layers = ("to_k_ip", "to_v_ip")

    # NOTE(review): always float16 here, whereas init_proj uses instance.vae.dtype.
    processor = processor.to(device, dtype=torch.float16)

    # Freeze the parameters in the IP key/value projection layers.
    for layer_name in frozen_layers:
        for param in getattr(processor, layer_name).parameters():
            param.requires_grad = False
    return processor

def load_ip_adapter(instance):
    """Load IP-Adapter weights from ``instance.ip_ckpt`` into the projections and UNet.

    The checkpoint dict holds the projection weights under ``"image_proj"``
    (``"image_proj_1"`` / ``"image_proj_2"`` for ``'faceid-decoupled'``) and the
    attention weights under ``"ip_adapter"``. The ``"ip_adapter"`` keys are
    expected to start with ids 1, 3, 5, ..., one per IP processor, in the order
    those processors appear in ``instance.unet.attn_processors``. They are
    re-keyed to each processor's position and loaded strictly into all processors.

    Raises:
        KeyError: if a checkpoint id has no matching IP processor in the UNet.
        ValueError: if an ``"ip_adapter"`` key names no known IP layer.
    """
    print(f"[INFO] loading IP-Adapter checkpoints from {instance.ip_ckpt}...")
    # NOTE(security): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    state_dict = torch.load(instance.ip_ckpt, map_location="cpu")

    if instance.ip_mode in ['faceid-decoupled']:
        instance.image_proj_model_1.load_state_dict(state_dict["image_proj_1"])
        instance.image_proj_model_2.load_state_dict(state_dict["image_proj_2"])
        ip_cls = IPAttnProcessorDecoupled
        # Checked in order by substring match against each checkpoint key.
        layer_names = ("to_k_ip_1", "to_v_ip_1", "to_k_ip_2", "to_v_ip_2")
    else:
        instance.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        ip_cls = IPAttnProcessor
        layer_names = ("to_k_ip", "to_v_ip")

    # Step 1: map checkpoint ids (1, 3, 5, ...) to processor positions in the UNet.
    processors = list(instance.unet.attn_processors.values())
    ip_positions = [pos for pos, proc in enumerate(processors) if isinstance(proc, ip_cls)]
    id_to_pos = {2 * n + 1: pos for n, pos in enumerate(ip_positions)}

    # Step 2: re-key the checkpoint so it lines up with the processor ModuleList.
    new_state_dict = _remap_ip_adapter_keys(state_dict["ip_adapter"], id_to_pos, layer_names)

    ip_layers = torch.nn.ModuleList(processors)
    ip_layers.load_state_dict(new_state_dict, strict=True)
    print("[INFO] loaded IP-Adapter checkpoints...")


def _remap_ip_adapter_keys(ip_state, id_to_pos, layer_names):
    """Re-key ``"<ckpt_id>.<layer>[...].<param>"`` entries to ``"<model_pos>.<layer>.<param>"``.

    ``layer_names`` is tried in order with a substring test. The trailing key
    component (e.g. ``weight`` / ``bias``) is preserved, so a bias entry can
    never overwrite the weight of the same layer.
    """
    remapped = {}
    for key, value in ip_state.items():
        ckpt_id = int(key.split(".")[0])
        try:
            model_pos = id_to_pos[ckpt_id]
        except KeyError:
            raise KeyError(
                f"Checkpoint IP layer id {ckpt_id} has no matching IP processor in the "
                f"UNet (known ids: {sorted(id_to_pos)})"
            ) from None
        layer = next((layer_name for layer_name in layer_names if layer_name in key), None)
        if layer is None:
            raise ValueError("Unexpected key format in state_dict.")
        param_name = key.rsplit(".", 1)[-1]
        remapped[f"{model_pos}.{layer}.{param_name}"] = value
    return remapped

def set_ipa_scale(instance, scale):
    """Set ``scale`` on every IP-Adapter attention processor in ``instance.unet``.

    Both ``IPAttnProcessor`` and ``IPAttnProcessorDecoupled`` instances are
    updated. Other processors are left untouched. Prints how many were changed.
    """
    ip_processors = [
        proc
        for proc in instance.unet.attn_processors.values()
        if isinstance(proc, (IPAttnProcessor, IPAttnProcessorDecoupled))
    ]
    for proc in ip_processors:
        proc.scale = scale
    print(f"[INFO] Set {len(ip_processors)} IPAttnProcessor(s) scale to {scale}")
