import argparse

from diffusers.utils import is_safetensors_available


if is_safetensors_available():
    import safetensors.torch
else:
    raise ImportError("Please install `safetensors`.")

from diffusers import AutoencoderTiny


"""
Example - From the diffusers root directory:

Download the weights:
```sh
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors
$ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors
```

Convert the model:
```sh
$ python scripts/convert_tiny_autoencoder_to_diffusers.py \
    --encoder_ckpt_path  taesd_encoder.safetensors \
    --decoder_ckpt_path taesd_decoder.safetensors \
    --dump_path taesd-diffusers
```
"""

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--encoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the encoder ckpt.",
    )
    parser.add_argument(
        "--decoder_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the decoder ckpt.",
    )
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format."
    )
    args = parser.parse_args()

    print("Loading the original state_dicts of the encoder and the decoder...")
    encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path)
    decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path)

    print("Populating the state_dicts in the diffusers format...")
    tiny_autoencoder = AutoencoderTiny()
    new_state_dict = {}

    # Modify the encoder state dict.
    for k in encoder_state_dict:
        new_state_dict.update({f"encoder.layers.{k}": encoder_state_dict[k]})

    # Modify the decoder state dict.
    for k in decoder_state_dict:
        layer_id = int(k.split(".")[0]) - 1
        new_k = str(layer_id) + "." + ".".join(k.split(".")[1:])
        new_state_dict.update({f"decoder.layers.{new_k}": decoder_state_dict[k]})

    # Assertion tests with the original implementation can be found here:
    # https://gist.github.com/sayakpaul/337b0988f08bd2cf2b248206f760e28f
    tiny_autoencoder.load_state_dict(new_state_dict)
    print("Population successful, serializing...")
    tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors)
