from typing import Optional
from collections import OrderedDict
from diffusers import MotionAdapter
from safetensors import safe_open

from .convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint

def load_motion_modules(unet, motion_adapter: "Optional[MotionAdapter]") -> None:
    '''
    Copy motion-module weights from ``motion_adapter`` into ``unet`` in place.

    For every down block, up block (matched by index) and the mid block, the
    adapter block's ``motion_modules.state_dict()`` is loaded into the
    corresponding ``unet`` block's ``motion_modules`` with ``strict=True``.

    Args:
        unet: model exposing ``down_blocks``, ``up_blocks`` and ``mid_block``,
            each with a ``motion_modules`` attribute.
        motion_adapter: adapter with the same block structure, or ``None``;
            ``None`` means there is nothing to load and ``unet`` is left untouched.
    '''
    if motion_adapter is None:
        # The annotation admits None; treat it as "no adapter" instead of
        # failing with an AttributeError on ``None.down_blocks``.
        return

    for i, down_block in enumerate(motion_adapter.down_blocks):
        unet.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict(), strict=True)
    for i, up_block in enumerate(motion_adapter.up_blocks):
        unet.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict(), strict=True)

    unet.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict(), strict=True)

def t2i_adapter_map_keys(state_dict, num_res_blocks=2):
    '''
    Remap keys of a flat T2I-Adapter checkpoint to a nested ``adapter.*`` layout.

    Key rules (``n`` is the integer after ``body.``):
      * ``body.<n>.<block>.<rest>`` ->
        ``adapter.body.<n // num_res_blocks>.resnets.<n % num_res_blocks>.<block>.<rest>``
      * ``body.<n>.<block>.<rest>`` where ``<block>`` contains ``in_conv`` ->
        ``adapter.body.<n // num_res_blocks>.<block>.<rest>``
      * keys starting with ``conv_in`` -> ``adapter.conv_in.<last component>``
      * every other key is copied unchanged.

    ``<rest>`` keeps *all* components after ``<block>``, so nested parameter
    names (e.g. ``block2.conv.weight`` vs ``block2.norm.weight``) stay distinct
    instead of collapsing onto one key and overwriting each other.

    Args:
        state_dict: mapping of checkpoint keys to values (values are not copied).
        num_res_blocks: number of resnets grouped under each ``adapter.body.<i>``.

    Returns:
        A new dict with remapped keys.

    Raises:
        ValueError: if ``num_res_blocks`` is not positive, or a ``body.`` key has
            fewer than four dot-separated components or a non-integer index.

    Usage:
        checkpoint_path = 'path_to_checkpoint.pth'
        state_dict = torch.load(checkpoint_path)
        new_state_dict = t2i_adapter_map_keys(state_dict)
        t2iadapter = T2IAdapter(config)
        t2iadapter.load_state_dict(new_state_dict, strict=False)
    '''
    if num_res_blocks <= 0:
        raise ValueError(f"num_res_blocks must be positive, got {num_res_blocks}")

    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("body."):
            parts = key.split(".")
            if len(parts) < 4:
                raise ValueError(f"Unexpected T2I-Adapter body key: {key!r}")
            body_idx = int(parts[1])
            block_type = parts[2]
            param = ".".join(parts[3:])

            # Flat body index -> (resnet group, position within the group)
            resnet_idx, new_body_idx = divmod(body_idx, num_res_blocks)

            if "in_conv" in block_type:
                new_key = f"adapter.body.{resnet_idx}.{block_type}.{param}"
            else:
                new_key = f"adapter.body.{resnet_idx}.resnets.{new_body_idx}.{block_type}.{param}"
        elif key.startswith("conv_in"):
            new_key = "adapter.conv_in." + key.split(".")[-1]
        else:
            new_key = key

        new_state_dict[new_key] = value

    return new_state_dict

def remap_state_dict(state_dict, mapping):
    '''
    Build a new dict keyed by the targets of ``mapping``.

    For each ``(source, target)`` pair in ``mapping``, ``state_dict[source]`` is
    stored under ``target``. Sources absent from ``state_dict`` are skipped, and
    keys of ``state_dict`` not mentioned in ``mapping`` are dropped. Entries are
    inserted in ``mapping``'s iteration order.
    '''
    return {
        target: state_dict[source]
        for source, target in mapping.items()
        if source in state_dict
    }

def print_state_dict_key_comparison(dict1, dict2, dict1_name="Adapter", dict2_name="Modified Denoising UNet"):
    '''
    Print the keys two state dicts share side by side, then the keys unique to each.

    This helper was previously also named ``compare_state_dicts``. A later
    definition of that name in this module (the shape checker) replaced it at
    import time, so it could not be called under the old name.

    Args:
        dict1, dict2: mappings whose keys are compared (values are ignored).
        dict1_name, dict2_name: labels used in the printed headers; the defaults
            reproduce the previously hard-coded headers.

    Usage:
        state_dict_adapter = adapter.down_blocks[0].motion_modules.state_dict()
        state_dict_denoising_unet = denoising_unet.down_blocks[0].motion_modules.state_dict()
        print_state_dict_key_comparison(state_dict_adapter, state_dict_denoising_unet, dict1_name='Adapter', dict2_name='UNet-Motion')
    '''
    keys1 = set(dict1.keys())
    keys2 = set(dict2.keys())
    common_keys = keys1 & keys2

    print(f"{dict1_name} Keys\t|\t{dict2_name} Keys")
    print("-" * 40)

    for key in sorted(common_keys):
        print(f"{key}\t|\t{key}")

    for name, unique_keys in ((dict1_name, keys1 - common_keys), (dict2_name, keys2 - common_keys)):
        if unique_keys:
            print(f"\n{name} Only Keys:")
            for key in sorted(unique_keys):
                print(key)

def compare_state_dicts(state_dict1, state_dict2):
    '''
    Check that two state dicts have the same keys and the same ``.shape`` per key.

    Only shapes are compared; tensor contents are NOT checked. A human-readable
    report is printed, with keys listed in sorted order.

    Returns:
        True if the key sets are equal and every key's shapes match, else False.
    '''
    keys1 = set(state_dict1.keys())
    keys2 = set(state_dict2.keys())

    only_in_2 = keys2 - keys1
    only_in_1 = keys1 - keys2

    if only_in_1 or only_in_2:
        print("The keys do not match between the two state dictionaries.")
        if only_in_2:
            print(f"Keys in state_dict2 but not in state_dict1: {sorted(only_in_2, key=str)}")
        if only_in_1:
            print(f"Keys in state_dict1 but not in state_dict2: {sorted(only_in_1, key=str)}")
        return False

    # Keep only the shapes, not the tensors themselves, for the report.
    shape_mismatches = {}
    for key in sorted(keys1, key=str):
        shape1 = state_dict1[key].shape
        shape2 = state_dict2[key].shape
        if shape1 != shape2:
            shape_mismatches[key] = (shape1, shape2)

    if shape_mismatches:
        print("Some keys have mismatched shapes between the two state dictionaries.")
        for key, (shape1, shape2) in shape_mismatches.items():
            print(f"Mismatch in key: {key}")
            print(f"Shape in state_dict1 {shape1} | state_dict2 {shape2}")
        return False

    print("The keys and shapes match between the two state dictionaries.")
    return True

def _rename_vae_mid_attention_keys(checkpoint):
    '''Rename VAE mid-block attention keys (key/query/value/proj_attn) to to_k/to_q/to_v/to_out.0, in place.'''
    # Checked in this order; the first matching substring wins.
    renames = (("key", "to_k"), ("query", "to_q"), ("value", "to_v"), ("proj_attn", "to_out.0"))
    for name in list(checkpoint.keys()):
        if "encoder.mid_block.attentions" not in name and "decoder.mid_block.attentions" not in name:
            continue
        for old, new in renames:
            if old in name:
                checkpoint[name.replace(old, new)] = checkpoint.pop(name)
                break


def _rename_unet_attn2_keys(checkpoint):
    '''Return a copy of ``checkpoint`` with every ``attn2`` in a key replaced by ``attn1_7``.'''
    # NOTE(review): presumably the target UNet exposes the checkpoint's attn2
    # weights under a module named ``attn1_7``; confirm against the model definition.
    return OrderedDict(
        (name.replace("attn2", "attn1_7"), value) for name, value in checkpoint.items()
    )


def load_dreambooth_weights(pipeline, dreambooth_model_path, dtype, device):
    '''
    Load a single-file ``.safetensors`` DreamBooth/LDM checkpoint into ``pipeline``.

    The checkpoint's VAE and UNet weights are converted with the ``convert_ldm_*``
    helpers and loaded into ``pipeline.vae`` (strict) and ``pipeline.unet``
    (``strict=False``; missing keys are printed). ``pipeline.text_encoder`` is
    replaced by the converted CLIP model. Each component is moved with
    ``.to(device, dtype=dtype)``.

    Args:
        pipeline: object exposing ``vae``, ``unet`` and ``text_encoder``.
        dreambooth_model_path: checkpoint path; an empty string (or None) skips loading.
        dtype: dtype passed to ``.to(...)``.
        device: device passed to ``.to(...)``.

    Returns:
        ``pipeline``, in every case, including when nothing is loaded.
    '''
    if not dreambooth_model_path:
        # Nothing to load: still hand the pipeline back so callers can reassign it.
        return pipeline

    print(f"[INFO] loading dreambooth model from {dreambooth_model_path}....")
    # Read every tensor, then let the file close before the (slower) conversion
    # work; this mirrors how safetensors.torch.load_file uses safe_open.
    with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
        dreambooth_state_dict = {name: f.get_tensor(name) for name in f.keys()}

    converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
    _rename_vae_mid_attention_keys(converted_vae_checkpoint)
    pipeline.vae.load_state_dict(converted_vae_checkpoint)
    pipeline.vae.to(device, dtype=dtype)

    converted_unet_checkpoint = _rename_unet_attn2_keys(
        convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
    )
    missing_keys, _unexpected_keys = pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
    print(f'###### Missing keys: {missing_keys}')
    pipeline.unet.to(device, dtype=dtype)

    pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
    pipeline.text_encoder.to(device, dtype=dtype)

    print(f"[INFO] Loaded dreambooth model from {dreambooth_model_path}....")
    return pipeline
