#!/usr/bin/env python3
import argparse
import json
import os
import subprocess

from dti.utils import find_free_port


def parse_arguments():
    """Parse command-line options for the SDXL LoRA + textual-inversion launcher.

    Besides the parsed CLI flags, the returned namespace carries three fixed
    attributes: ``model`` (the SDXL base checkpoint id), ``expname`` (used for
    the output directory) and ``runname`` (used as the run-name prefix).

    Returns:
        argparse.Namespace: parsed arguments plus the fixed attributes above.
    """
    parser = argparse.ArgumentParser(
        description="Run LoRA + textual-inversion (DTI) experiments on SDXL",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-g", "--gpu", type=str, default="7",
        help="value exported as CUDA_VISIBLE_DEVICES for the child processes",
    )
    parser.add_argument(
        "-r", "--resolution", type=int, default=768,
        help="training resolution; 1024 uses batch size 1 without accumulation",
    )
    parser.add_argument(
        "--instances", type=str, nargs="+", default=None,
        help="dataset keys to train on (all entries of the dataset file if omitted)",
    )
    parser.add_argument(
        "--desc", type=str, default=None,
        help="optional suffix appended to the output directory and run names",
    )

    parser.add_argument(
        "--max_train_steps", type=int, default=300,
        help="number of optimization steps per instance",
    )
    parser.add_argument(
        "--emb_lr", type=float, default=8e-2,
        help="learning rate for the placeholder token embedding",
    )
    parser.add_argument(
        "--lr", type=float, default=5e-5,
        help="learning rate passed to the trainer as --learning_rate",
    )
    parser.add_argument(
        "--dco_beta", type=float, default=0.0,
        help="value forwarded to the trainer as --dco_beta",
    )

    parser.add_argument(
        "--kappa", type=float, default=0.1,
        help="currently unused: parsed but not forwarded to the training script",
    )
    parser.add_argument(
        "--rank", type=int, default=4,
        help="LoRA rank forwarded to the trainer as --lora_rank",
    )
    args = parser.parse_args()

    # Fixed experiment settings; not exposed on the command line.
    args.model = "stabilityai/stable-diffusion-xl-base-1.0"
    args.expname = "lora_dti-sdxl"
    args.runname = "lora_dti-sdxl"
    return args


def main():
    """Train one LoRA + textual-inversion run per dataset instance, then evaluate.

    Loads instance metadata (each entry must provide ``path`` and ``class``)
    from ``data/dreambooth.json``. For every selected instance it writes the
    training argument list to ``<outdir>/<name>/cmd.txt`` and launches
    ``scripts/train_lora.py`` via ``accelerate launch``. Finally it runs
    ``scripts/evaluate.py`` on the whole output directory.

    Raises:
        KeyError: if ``--instances`` names a key absent from the dataset file.
    """
    args = parse_arguments()

    # Alternative dataset manifests: "data/ti.json", "data/selected_data.json".
    data_file = "data/dreambooth.json"
    with open(data_file, "r", encoding="utf-8") as f:
        full_data = json.load(f)

    if args.instances is None:
        data = full_data
    else:
        missing = [key for key in args.instances if key not in full_data]
        if missing:
            raise KeyError(
                f"Unknown instance(s) {missing} in {data_file}; "
                f"available: {sorted(full_data)}"
            )
        data = {key: full_data[key] for key in args.instances}
    print(data)

    # The optional description suffixes both the output directory and run names.
    run_name_desc = f"-{args.desc}" if args.desc is not None else ""
    outdir = f"output/{args.expname}{run_name_desc}"
    os.makedirs(outdir, exist_ok=True)

    # Resolution 1024 trains with batch 1 and no accumulation; other
    # resolutions use batch 2 with 2 accumulation steps.
    if args.resolution == 1024:
        batch_size, num_accum = 1, 1
    else:
        batch_size, num_accum = 2, 2
    # Save/validate about 5 times per run. Clamp to 1 so that
    # max_train_steps < 5 does not pass a step interval of 0 to the trainer.
    val_freq = max(1, args.max_train_steps // 5)

    # Exported here so every child process (training and evaluation) inherits it.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    for name, metadata in data.items():
        data_path = metadata["path"]
        cls = metadata["class"]
        cmd = [
            "scripts/train_lora.py",
            f"--pretrained_model_name_or_path={args.model}",
            f"--train_data_dir={data_path}",
            f"--output_dir=./{outdir}/{name}",
            "--learnable_property=object",
            f"--placeholder_token=<{name}>",
            "--num_vectors=1",
            f"--initializer_token={cls}",
            f"--resolution={args.resolution}",
            f"--save_steps={val_freq}",
            f"--validation_steps={val_freq}",
            # Deliberately not an f-string: the literal "{}" is left for the
            # training script to fill in (presumably with the placeholder token).
            "--validation_prompt=a {} with Japanese modern city street in the background",
            f"--train_batch_size={batch_size}",
            f"--gradient_accumulation_steps={num_accum}",
            f"--max_train_steps={args.max_train_steps}",
            f"--emb_learning_rate={args.emb_lr}",
            f"--learning_rate={args.lr}",
            "--lr_scheduler=constant_with_warmup",
            "--lr_warmup_steps=100",
            # "--report_to=wandb",
            f"--run_name={args.runname}-{name}{run_name_desc}",
            f"--lora_rank={args.rank}",
            f"--dco_beta={args.dco_beta}",
        ]

        # Record the exact training arguments next to the run's outputs.
        run_dir = f"{outdir}/{name}"
        os.makedirs(run_dir, exist_ok=True)
        with open(f"{run_dir}/cmd.txt", "w", encoding="utf-8") as file:
            file.write("\n".join(cmd))

        # Pick a fresh port per launch: a port chosen before earlier
        # (possibly long) runs may have been taken by something else since.
        accelerate_cmd = [
            "accelerate",
            "launch",
            "--mixed_precision=bf16",
            "--num_processes=1",
            "--num_machines=1",
            "--dynamo_backend=no",
            f"--main_process_port={find_free_port()}",
        ]
        result = subprocess.run(accelerate_cmd + cmd)
        if result.returncode != 0:
            # Keep going with the remaining instances, but make the failure visible.
            print(
                f"WARNING: training for {name!r} exited with code {result.returncode}"
            )

    # Evaluation over every run in the output directory.
    eval_cmd = [
        "python",
        "scripts/evaluate.py",
        f"-e={outdir}",
    ]
    subprocess.run(eval_cmd)


# Script entry point: only runs when executed directly, not when imported.
if __name__ == "__main__":
    main()
