import argparse
import os

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

def is_unet_key(key):
    """Return True when ``key`` contains none of the VAE / text-encoder markers.

    The markers checked are ``first_stage_model``, ``cond_stage_model`` and
    ``conditioner.`` (the last one is for SDXL).
    """
    non_unet_markers = ("first_stage_model", "cond_stage_model", "conditioner.")
    return all(marker not in key for marker in non_unet_markers)


# (old prefix, new prefix) pairs: keys starting with the old prefix are rewritten
# to start with the new one, which inserts the "text_model." path component.
TEXT_ENCODER_KEY_REPLACEMENTS = [
    ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
    ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
    ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
]


# support for models with different text encoder keys
def replace_text_encoder_key(key):
    """Rewrite ``key`` if it starts with one of the old text-encoder prefixes.

    Returns:
        A ``(replaced, key)`` tuple. ``replaced`` is True and ``key`` carries the
        new prefix when the first matching entry of
        ``TEXT_ENCODER_KEY_REPLACEMENTS`` applies; otherwise ``replaced`` is
        False and ``key`` is returned unchanged.
    """
    matching_pair = next(
        ((old_prefix, new_prefix) for old_prefix, new_prefix in TEXT_ENCODER_KEY_REPLACEMENTS if key.startswith(old_prefix)),
        None,
    )
    if matching_pair is None:
        return False, key

    old_prefix, new_prefix = matching_pair
    remainder = key[len(old_prefix) :]
    return True, new_prefix + remainder


def _precision_to_dtype(precision):
    """Map a precision option ("fp16", "bf16", or anything else) to a torch dtype.

    Any value other than "fp16" or "bf16" maps to ``torch.float``.
    """
    if precision == "fp16":
        return torch.float16
    if precision == "bf16":
        return torch.bfloat16
    return torch.float


def merge(args):
    """Merge several safetensors checkpoints into one weighted sum and save it.

    The key set of the first model defines the keys of the output. Every other
    model contributes ``ratio * tensor`` for each of those keys it contains;
    keys it lacks are made up for by adding the first model's tensor with that
    model's ratio instead. With ``args.unet_only`` set, keys for which
    ``is_unet_key`` is False are copied from the first model unmerged.
    Text-encoder keys are normalized with ``replace_text_encoder_key`` before
    any comparison.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``models``, ``ratios``, ``output``, ``unet_only``, ``device``,
            ``precision``, ``saving_precision`` and ``show_skipped``.

    Raises:
        SystemExit: with status 1 if a model path does not end with
            "safetensors" or does not point to an existing file.
        AssertionError: if ``args.ratios`` is given but its length differs from
            the number of models.
    """
    dtype = _precision_to_dtype(args.precision)
    save_dtype = _precision_to_dtype(args.saving_precision)

    # check if all models are safetensors
    for model in args.models:
        if not model.endswith("safetensors"):
            logger.error(f"Model {model} is not a safetensors model")
            # the site-provided exit() is meant for interactive use and would report success (status 0)
            raise SystemExit(1)
        if not os.path.isfile(model):
            logger.error(f"Model {model} does not exist")
            raise SystemExit(1)

    assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"

    # load and merge
    ratio = 1.0 / len(args.models)  # default
    supplementary_key_ratios = {}  # [key] = ratio, for keys not in all models, add later

    merged_sd = None
    first_model_keys = set()  # check missing keys in other models
    for i, model in enumerate(args.models):
        if args.ratios is not None:
            ratio = args.ratios[i]

        logger.info(f"Loading model {model}, ratio = {ratio}...")

        if merged_sd is None:
            # load first model
            merged_sd = {}
            with safe_open(model, framework="pt", device=args.device) as f:
                for key in tqdm(f.keys()):
                    _, new_key = replace_text_encoder_key(key)
                    first_model_keys.add(new_key)

                    if args.unet_only and not is_unet_key(new_key):
                        # use first model's value for VAE or TextEncoder; the tensor
                        # is read later, so it is not loaded here
                        supplementary_key_ratios[new_key] = 1.0
                        continue

                    # first model's value * ratio
                    merged_sd[new_key] = ratio * f.get_tensor(key).to(dtype)

            logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
            continue

        # load other models
        # Collect this model's keys *after* renaming: merged_sd is keyed by the
        # renamed form, so comparing against raw names would report every
        # old-layout text-encoder key as missing and count it twice.
        model_keys = set()
        with safe_open(model, framework="pt", device=args.device) as f:
            for key in tqdm(f.keys()):
                _, new_key = replace_text_encoder_key(key)
                model_keys.add(new_key)

                if new_key not in merged_sd:
                    if args.show_skipped and new_key not in first_model_keys:
                        logger.info(f"Skip: {new_key}")
                    continue

                value = f.get_tensor(key)  # original key
                merged_sd[new_key] = merged_sd[new_key] + ratio * value.to(dtype)

        # enumerate keys not in this model
        for key in merged_sd:
            if key in model_keys:
                continue
            logger.warning(f"Key {key} not in model {model}, use first model's value")
            supplementary_key_ratios[key] = supplementary_key_ratios.get(key, 0.0) + ratio

    # add supplementary keys' value (including VAE and TextEncoder)
    if supplementary_key_ratios:
        logger.info("add first model's value")
        with safe_open(args.models[0], framework="pt", device=args.device) as f:
            for key in tqdm(f.keys()):
                _, new_key = replace_text_encoder_key(key)
                if new_key not in supplementary_key_ratios:
                    continue

                extra_ratio = supplementary_key_ratios[new_key]
                if is_unet_key(new_key):  # not VAE or TextEncoder
                    logger.warning(f"Key {new_key} not in all models, ratio = {extra_ratio}")

                contribution = extra_ratio * f.get_tensor(key).to(dtype)  # original key

                if new_key in merged_sd:
                    merged_sd[new_key] = merged_sd[new_key] + contribution
                else:
                    merged_sd[new_key] = contribution

    # save
    output_file = args.output
    if not output_file.endswith(".safetensors"):
        output_file = output_file + ".safetensors"

    logger.info(f"Saving to {output_file}...")

    # convert to save_dtype; replace entries one at a time so each full-precision
    # tensor can be released as soon as it has been converted
    for k in merged_sd.keys():
        merged_sd[k] = merged_sd[k].to(save_dtype)

    save_file(merged_sd, output_file)

    logger.info("Done!")


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to merge().
    parser = argparse.ArgumentParser(description="Merge models")
    # --models and --output are mandatory: without them merge() would fail with
    # an unhelpful error (for --output only after the whole merge has been computed).
    parser.add_argument("--models", nargs="+", type=str, required=True, help="Models to merge")
    parser.add_argument("--output", type=str, required=True, help="Output model")
    parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0")
    parser.add_argument("--unet_only", action="store_true", help="Only merge unet")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu")
    parser.add_argument(
        "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float"
    )
    parser.add_argument(
        "--saving_precision",
        type=str,
        default="float",
        choices=["float", "fp16", "bf16"],
        help="Saving precision, default is float",
    )
    parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)")

    args = parser.parse_args()
    merge(args)
