#!/usr/bin/env python3
import argparse
import json
import os
import subprocess
import sys

from dti.utils import find_free_port


def parse_arguments(argv=None):
    """Parse command-line options for the SDXL textual-inversion launcher.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with all training/evaluation options, including
        ``model`` (pretrained checkpoint id), ``expname`` (output directory
        prefix) and ``runname`` (run-name prefix). The last three default to
        the values that used to be hard-coded here.
    """
    parser = argparse.ArgumentParser(description="Run textual inversion experiment")
    parser.add_argument("-g", "--gpu", type=str, default="7",
                        help="Value assigned to CUDA_VISIBLE_DEVICES.")
    parser.add_argument("-r", "--resolution", type=int, default=768,
                        help="Training resolution forwarded to the training script.")
    parser.add_argument("--instances", type=str, nargs="+", default=None,
                        help="Subset of instance names from the data JSON (default: all).")
    parser.add_argument("--desc", type=str, default=None,
                        help="Suffix appended (as '-<desc>') to the output dir and run names.")

    parser.add_argument("--max_train_steps", type=int, default=500,
                        help="Training steps; checkpoints/validation run every 1/5 of this.")
    parser.add_argument("--lr", type=float, default=2e-2,
                        help="Forwarded as --learning_rate.")
    parser.add_argument("--min_lr", type=float, default=2e-2,
                        help="Forwarded as --lr_eta_min.")

    parser.add_argument("--reparam", type=str, default="true",
                        help="Forwarded verbatim as --reparameterize.")
    parser.add_argument("--scale", type=str, default="max",
                        help="Forwarded as --init_scale.")
    parser.add_argument("--kappa", type=float, default=0.1,
                        help="Forwarded as --kappa.")
    parser.add_argument("--init_method", type=str, default="token",
                        help="Forwarded as --init_method (e.g. token, random, mean).")
    parser.add_argument("--train_magnitude", action="store_true",
                        help="Forward --train_magnitude together with --mag_lr_multiplier.")
    parser.add_argument("--mag_lr_multiplier", type=float, default=1.0,
                        help="Only used when --train_magnitude is set.")
    parser.add_argument("--adamw", action="store_true",
                        help="Forward --use_adam to the training script.")

    # Previously hard-coded; exposed as options with identical defaults.
    parser.add_argument("--model", type=str,
                        default="stabilityai/stable-diffusion-xl-base-1.0",
                        help="Pretrained model name or path.")
    parser.add_argument("--expname", type=str, default="dti-sdxl",
                        help="Experiment name; outputs go to output/<expname>[-<desc>].")
    parser.add_argument("--runname", type=str, default="dti-sdxl",
                        help="Prefix of the per-instance run name.")
    return parser.parse_args(argv)


def _load_instances(data_file, instances=None):
    """Load instance metadata from *data_file*, optionally keeping only *instances*.

    Args:
        data_file: Path to a JSON file mapping instance name -> metadata dict
            (each metadata dict needs the "path" and "class" keys read when
            building the training command).
        instances: Optional list of instance names to keep; ``None`` keeps all.

    Returns:
        Dict mapping instance name -> metadata (in *instances* order when
        given, otherwise in file order).

    Raises:
        KeyError: If any requested instance is missing from *data_file*.
    """
    with open(data_file, "r", encoding="utf-8") as f:
        full_data = json.load(f)
    if instances is None:
        return full_data
    missing = [key for key in instances if key not in full_data]
    if missing:
        raise KeyError(f"instance(s) {missing} not found in {data_file}")
    return {key: full_data[key] for key in instances}


def _build_train_command(args, name, metadata, outdir, val_step, run_name_desc):
    """Return the ``scripts/train_sdxl.py`` argument list for one instance.

    The list excludes the ``accelerate launch`` prefix; it is also exactly what
    gets written to ``<outdir>/<name>/cmd.txt``.
    """
    data_path = metadata["path"]
    # data_path = data_path.replace("dreambooth", "dreambooth_masked")
    init_token = metadata["class"]  # NOTE: default option -- initialize from the class.
    # init_token = metadata["initialization"]

    cmd = [
        "scripts/train_sdxl.py",
        f"--pretrained_model_name_or_path={args.model}",
        f"--train_data_dir={data_path}",
        # "--train_data_dir=data/dreambooth.json",
        # f"--instance={name}",
        f"--output_dir=./{outdir}/{name}",
        "--learnable_property=object",
        f"--placeholder_token=<{name}>",
        # "--num_vectors=2",
        f"--initializer_token={init_token}",
        "--mixed_precision=bf16",
        f"--resolution={args.resolution}",
        f"--save_steps={val_step}",
        f"--validation_steps={val_step}",
        # Deliberately not an f-string: the `{}` is presumably filled in by
        # train_sdxl.py (e.g. with the placeholder token) -- verify there.
        "--validation_prompt=a {} with Japanese modern city street in the background",
        "--train_batch_size=4",
        "--gradient_accumulation_steps=1",
        f"--max_train_steps={args.max_train_steps}",
        # "--gradient_accumulation_steps=5",  # to measure task affinity
        # "--learning_rate=5e-4",
        # "--scale_lr",
        # "--adam_beta1=0.0",  # default: 0.9
        # "--adam_weight_decay=0.0",  # default: 1e-2
        # "--lr_scheduler=cosine",
        f"--learning_rate={args.lr}",  # RSGD.
        "--scale_lr",
        "--lr_scheduler=constant",
        f"--lr_eta_min={args.min_lr}",
        # "--report_to=wandb",
        f"--run_name={args.runname}-{name}{run_name_desc}",
        "--seed=42",
        "--zero_pad",
        f"--init_method={args.init_method}",  # token, random, mean
        f"--init_scale={args.scale}",
        f"--kappa={args.kappa}",
        f"--reparameterize={args.reparam}",
    ]
    if args.adamw:
        cmd.append("--use_adam")
    if args.train_magnitude:
        cmd += [
            "--train_magnitude",
            f"--mag_lr_multiplier={args.mag_lr_multiplier}",
        ]
    return cmd


def _run(cmd):
    """Run *cmd* (argument list, no shell) and warn on stderr if it fails.

    Failures are reported but not raised, so one failing instance does not
    stop training/evaluation of the remaining ones.
    """
    result = subprocess.run(cmd)
    if result.returncode != 0:
        print(
            f"[warning] exit code {result.returncode}: {' '.join(map(str, cmd))}",
            file=sys.stderr,
        )
    return result


def main():
    """Train one textual-inversion embedding per instance, then evaluate them.

    For every selected instance in the data JSON, launches
    ``scripts/train_sdxl.py`` via ``accelerate launch`` with outputs under
    ``output/<expname>[-<desc>]/<instance>``, then runs ``scripts/evaluate.py``
    on the final checkpoint and on the "complex" prompt set.
    """
    args = parse_arguments()

    # data_file = "data/ti.json"
    data_file = "data/dreambooth.json"
    # data_file = "data/selected_data.json"
    data = _load_instances(data_file, args.instances)

    desc_suffix = f"-{args.desc}" if args.desc is not None else ""
    outdir = f"output/{args.expname}{desc_suffix}"
    os.makedirs(outdir, exist_ok=True)

    # Save/validate 5 times per run. Clamped to >= 1 because with fewer than 5
    # training steps the interval would otherwise be 0, which is not a valid
    # save/validation interval.
    val_step = max(1, args.max_train_steps // 5)

    # Child processes (training and evaluation) inherit this GPU selection.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    accelerate_cmd = [
        "accelerate",
        "launch",
        "--mixed_precision=bf16",
        "--num_processes=1",
        "--num_machines=1",
        "--dynamo_backend=no",
        # Port picked once; the training runs below are sequential.
        f"--main_process_port={find_free_port()}",
    ]

    for name, metadata in data.items():
        cmd = _build_train_command(args, name, metadata, outdir, val_step, desc_suffix)

        # Save the training command (one argument per line) for reproducibility.
        os.makedirs(f"{outdir}/{name}", exist_ok=True)
        with open(f"{outdir}/{name}/cmd.txt", "w", encoding="utf-8") as file:
            file.write("\n".join(cmd))

        _run(accelerate_cmd + cmd)

    # Evaluation. Both passes are restricted to the trained instances when
    # --instances was given. Only the final checkpoint is evaluated;
    # intermediate ones (saved every `val_step` steps) could be covered by
    # looping over range(args.max_train_steps, 0, -val_step).
    instance_args = [] if args.instances is None else ["--instances", *args.instances]
    _run([
        sys.executable, "scripts/evaluate.py",
        f"-e={outdir}",
        f"--checkpoint={args.max_train_steps}",
        *instance_args,
    ])
    _run([
        sys.executable, "scripts/evaluate.py",
        f"-e={outdir}",
        "--prompt_set=complex",
        "--out_dir=images_complex",
        *instance_args,
    ])


# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()
