import os

import fire
import torch
from lora_diffusion import (
    DEFAULT_TARGET_REPLACE,
    TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    UNET_DEFAULT_TARGET_REPLACE,
    convert_loras_to_safeloras_with_embeds,
)

# LoRA target modules to patch, keyed by the model name that convert() derives
# from each input file name; names not listed here fall back to
# DEFAULT_TARGET_REPLACE inside convert().
_target_by_name = dict(
    unet=UNET_DEFAULT_TARGET_REPLACE,
    text_encoder=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
)


def convert(*paths, outpath, overwrite=False, **settings):
    """
    Converts one or more Lora and/or Textual Embedding pytorch files into a
    single safetensor file.

    Pass all the input paths as arguments. Whether they are Textual Embedding
    or Lora models will be auto-detected: a file that unpickles to a dict is
    treated as Textual Inversion embeds (token -> tensor), anything else as
    Lora weights.

    For Lora models, their name will be taken from the path, i.e.
        "lora_weight.pt" => unet
        "lora_weight.text_encoder.pt" => text_encoder
    (the second-to-last dot-separated part of the file name, or "unet" when
    the file name contains only one dot).

    You can also set target_modules and/or rank by providing an argument prefixed
    by the name, e.g. --unet.rank 8. Lora models otherwise default to rank 4 and
    to the target modules registered for their name (DEFAULT_TARGET_REPLACE for
    names without a specific entry).

    So a complete example might be something like:

    ```
    python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
    ```

    Note that a glob like ``lora_weight.*`` also matches the output file once
    it exists, so re-running with --overwrite would try to load it as input.

    Raises:
        ValueError: if no input paths are given; if ``outpath`` exists and
            ``overwrite`` is not set; if an option is not of the form
            ``<name>.target_modules`` or ``<name>.rank``; or if two inputs
            would be stored under the same Lora name or embedding token.
    """
    if not paths:
        raise ValueError("No input paths given; nothing to convert")

    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    # fire forwards every unrecognised --flag into **settings, so validate the
    # names up front: otherwise a typo such as --unet.rnak would be silently
    # ignored and the default value used instead.
    known_options = {"target_modules", "rank"}
    for key in settings:
        _, dot, option = key.partition(".")
        if not dot or option not in known_options:
            raise ValueError(
                f"Unknown option --{key}; expected --<name>.target_modules "
                "or --<name>.rank"
            )

    modelmap = {}
    embeds = {}
    used_settings = set()

    for path in paths:
        # NOTE: torch.load unpickles arbitrary Python objects - only convert
        # files from trusted sources. map_location="cpu" lets files saved with
        # GPU tensors be converted on a machine without CUDA.
        data = torch.load(path, map_location="cpu")

        if isinstance(data, dict):
            print(f"Loading textual inversion embeds {data.keys()} from {path}")
            # A plain update() would silently drop an earlier file's tensor.
            clashes = embeds.keys() & data.keys()
            if clashes:
                raise ValueError(
                    f"Embedding token(s) {sorted(clashes, key=str)} in {path} "
                    "were already loaded from an earlier file"
                )
            embeds.update(data)
            continue

        name_parts = os.path.basename(path).split(".")
        name = name_parts[-2] if len(name_parts) > 2 else "unet"

        # modelmap holds one entry per name; a second file would replace the first.
        if name in modelmap:
            raise ValueError(
                f"Both {modelmap[name][0]} and {path} map to Lora name "
                f"'{name}'; only one file per name can be converted"
            )

        model_settings = {
            "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
            "rank": 4,
        }

        # Per-model overrides: an option "<name>.<setting>" replaces <setting>.
        prefix = f"{name}."
        for key, value in settings.items():
            if key.startswith(prefix):
                model_settings[key[len(prefix) :]] = value
                used_settings.add(key)

        print(f"Loading Lora for {name} from {path} with settings {model_settings}")

        modelmap[name] = (
            path,
            model_settings["target_modules"],
            model_settings["rank"],
        )

    unused = sorted(settings.keys() - used_settings)
    if unused:
        print(
            f"Warning: options {unused} did not match any loaded Lora model "
            "and were ignored"
        )

    convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)


def main():
    """Command-line entry point: expose ``convert`` as a CLI through python-fire."""
    fire.Fire(component=convert)


if __name__ == "__main__":
    # Supports ``python -m lora_diffusion.cli_pt_to_safetensors ...`` invocation.
    main()
