#!/usr/bin/env python3
import argparse
import json
import os
import subprocess


def parse_arguments():
    """Parse command-line options for the textual-inversion experiment.

    Besides the raw CLI flags, the returned namespace carries derived fields:

    * ``model``      -- the full Hugging Face model ID resolved from the alias
    * ``resolution`` -- the training resolution associated with that model
    * ``expname`` / ``runname`` -- ``"dti-sd<alias>"`` (alias lower-cased)

    Exits with a usage error (via ``parser.error``) if ``--model`` is not a
    known alias or ``--total_steps`` is not positive.
    """
    # Short (case-insensitive) alias -> (Hugging Face model ID, training resolution).
    known_models = {
        "1.5": ("stable-diffusion-v1-5/stable-diffusion-v1-5", 512),
        "2.1": ("stabilityai/stable-diffusion-2-1", 768),
        "2.1base": ("stabilityai/stable-diffusion-2-1-base", 512),
    }

    parser = argparse.ArgumentParser(description="Run textual inversion experiment")
    parser.add_argument("-g", "--gpu", type=str, default="7",
                        help="comma-separated GPU ids to expose to training")
    parser.add_argument("-m", "--model", type=str, default="1.5",
                        help=f"model alias, one of {list(known_models)} (case-insensitive)")
    parser.add_argument("--instances", type=str, nargs="+", default=None,
                        help="dataset keys to train on (default: all)")
    parser.add_argument("--desc", type=str, default=None,
                        help="suffix appended to the output directory and run names")

    parser.add_argument("--total_steps", type=int, default=500)
    parser.add_argument("--lr", type=float, default=5e-3)
    parser.add_argument("--min_lr", type=float, default=5e-3)

    parser.add_argument("--reparam", type=str, default="true")
    parser.add_argument("--scale", type=str, default="max")
    parser.add_argument("--kappa", type=float, default=0.1)
    parser.add_argument("--kappa_min", type=float, default=1e-5)
    args = parser.parse_args()

    model = args.model.lower()
    if model not in known_models:
        # Without this check `args.resolution` would never be set and the
        # caller would fail much later with an AttributeError.
        parser.error(
            f"unknown --model {args.model!r}; expected one of {list(known_models)}"
        )
    if args.total_steps < 1:
        parser.error("--total_steps must be a positive integer")

    args.model, args.resolution = known_models[model]
    args.expname = f"dti-sd{model}"
    args.runname = f"dti-sd{model}"
    return args


def _load_instances(data_file, instances):
    """Load the instance table from *data_file*, optionally keeping only *instances*."""
    with open(data_file, "r") as f:
        full_data = json.load(f)
    if instances is None:
        return full_data
    missing = [key for key in instances if key not in full_data]
    if missing:
        raise SystemExit(f"Unknown instance(s) {missing} in {data_file}")
    return {key: full_data[key] for key in instances}


def _build_train_cmd(args, name, metadata, outdir, run_name_desc, val_step):
    """Return the scripts/train_sd.py argument list (without torchrun prefix) for one instance."""
    data_path = metadata["path"]
    # NOTE: the class name is the default initializer token;
    # metadata["initialization"] is an alternative some datasets provide.
    init_token = metadata["class"]
    return [
        "scripts/train_sd.py",
        f"--pretrained_model_name_or_path={args.model}",
        f"--train_data_dir={data_path}",
        f"--output_dir=./{outdir}/{name}",
        "--learnable_property=object",
        f"--placeholder_token=<{name}>",
        f"--initializer_token={init_token}",
        "--mixed_precision=no",
        f"--resolution={args.resolution}",
        f"--save_steps={val_step}",
        f"--validation_steps={val_step}",
        # Literal "{}" (not an f-string); presumably substituted with the
        # placeholder token by train_sd.py -- verify there.
        "--validation_prompt=a {} in the jungle",
        "--train_batch_size=4",
        "--gradient_accumulation_steps=1",
        f"--max_train_steps={args.total_steps}",
        f"--learning_rate={args.lr}",  # RSGD.
        "--scale_lr",
        "--lr_scheduler=constant",
        f"--lr_eta_min={args.min_lr}",
        # "--report_to=wandb",
        f"--run_name={args.runname}-{name}{run_name_desc}",
        "--seed=42",
        "--zero_pad",
        "--init_method=token",  # token, random, mean
        f"--init_scale={args.scale}",
        f"--kappa={args.kappa}",
        f"--kappa_min={args.kappa_min}",
        f"--reparameterize={args.reparam}",
    ]


def _evaluate_checkpoints(outdir, total_steps, val_step):
    """Run scripts/evaluate.py on *outdir* once per checkpoint step, newest first."""
    # Steps: total_steps, total_steps - val_step, ... while >= val_step.
    # NOTE(review): when total_steps is not a multiple of val_step these steps
    # need not coincide with the --save_steps multiples used in training -- confirm.
    for ckpt in range(total_steps, val_step - 1, -val_step):
        # Each token must be its own argv element: without shell=True a single
        # "python scripts/evaluate.py" element is looked up as one executable.
        cmd = ["python", "scripts/evaluate.py", "-e", outdir, "--checkpoint", str(ckpt)]
        result = subprocess.run(cmd)
        if result.returncode != 0:
            print(f"WARNING: evaluation of checkpoint {ckpt} exited with code {result.returncode}")


def main():
    """Train one textual-inversion embedding per instance, then evaluate checkpoints.

    For each selected instance of the dataset JSON, launches
    ``scripts/train_sd.py`` under ``torchrun`` with one process per GPU listed
    in ``--gpu``, and records the training arguments in
    ``<outdir>/<name>/cmd.txt``. Afterwards runs ``scripts/evaluate.py`` on the
    whole output directory for each checkpoint step.
    """
    args = parse_arguments()

    # Alternative datasets: "data/ti.json", "data/selected_data.json".
    data = _load_instances("data/dreambooth.json", args.instances)

    desc_suffix = "" if args.desc is None else f"-{args.desc}"
    outdir = f"output/{args.expname}{desc_suffix}"
    os.makedirs(outdir, exist_ok=True)

    # Save/validate five times per run. Clamp to 1 so a tiny --total_steps
    # cannot yield 0 (a zero step would also make the evaluation range() raise).
    val_step = max(1, args.total_steps // 5)

    num_gpu = len(args.gpu.split(","))
    # Child processes inherit os.environ, so torchrun only sees these GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    torchrun_cmd = [
        "torchrun",
        "--rdzv-backend=c10d",
        "--rdzv-endpoint=localhost:0",
        f"--nproc-per-node={num_gpu}",
    ]

    for name, metadata in data.items():
        cmd = _build_train_cmd(args, name, metadata, outdir, desc_suffix, val_step)

        # Record the exact training arguments next to the run's outputs.
        run_dir = f"{outdir}/{name}"
        os.makedirs(run_dir, exist_ok=True)
        with open(f"{run_dir}/cmd.txt", "w") as file:
            file.write("\n".join(cmd))

        result = subprocess.run(torchrun_cmd + cmd)
        if result.returncode != 0:
            # Keep going with the remaining instances, but do not fail silently.
            print(f"WARNING: training for '{name}' exited with code {result.returncode}")

    _evaluate_checkpoints(outdir, args.total_steps, val_step)


if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0.
    raise SystemExit(main())
