#!/usr/bin/env python3
import argparse
import json
import os
import subprocess

from dti.utils import find_free_port


def parse_args():
    """Parse the command line for the SDXL style textual-inversion launcher.

    Returns:
        argparse.Namespace holding the parsed options (``gpu``,
        ``resolution``, ``instances``, ``desc``, ``max_train_steps``, ``lr``,
        ``min_lr``, ``kappa``) plus three fixed attributes that are not
        exposed on the command line: ``model``, ``expname`` and ``runname``.
    """
    parser = argparse.ArgumentParser(description="Run textual inversion experiment")

    # Each entry is (flag names, add_argument keyword arguments); the order
    # here is the order in which the options are registered.
    option_table = (
        (("-g", "--gpu"), {"type": str, "default": "7"}),
        (("-r", "--resolution"), {"type": int, "default": 768}),
        (("--instances",), {"type": str, "nargs": "+", "default": None}),
        (("--desc",), {"type": str, "default": None}),
        (("--max_train_steps",), {"type": int, "default": 200}),
        (("--lr",), {"type": float, "default": 2e-2}),
        (("--min_lr",), {"type": float, "default": 2e-2}),
        (("--kappa",), {"type": float, "default": 0.1}),
    )
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    parsed = parser.parse_args()

    # Experiment-wide constants attached to the namespace for convenience.
    fixed_settings = {
        "model": "stabilityai/stable-diffusion-xl-base-1.0",
        "expname": "dti-sdxl-style",
        "runname": "dti-sdxl-style",
    }
    for attr, value in fixed_settings.items():
        setattr(parsed, attr, value)
    return parsed


def main():
    """Train one textual-inversion run per style instance, then evaluate.

    Steps:
      1. Load the instance table from ``data/styledrop.json`` and, when
         ``--instances`` is given, keep only the requested keys.
      2. For every instance, write the training-script arguments to
         ``<outdir>/<name>/cmd.txt`` and launch ``scripts/train_sdxl.py``
         through ``accelerate launch``.
      3. Run ``scripts/evaluate.py`` on the output directory.

    A failing subprocess does not abort the sweep; its exit code is printed
    and the remaining runs (and the evaluation) still execute.

    Raises:
        KeyError: if a name passed via ``--instances`` is not a key of the
            data file.
    """
    args = parse_args()

    data_file = "data/styledrop.json"
    with open(data_file, "r", encoding="utf-8") as f:
        full_data = json.load(f)

    if args.instances is not None:
        unknown = [key for key in args.instances if key not in full_data]
        if unknown:
            raise KeyError(f"Unknown instance(s) {unknown}; not found in {data_file}")
        data = {key: full_data[key] for key in args.instances}
    else:
        data = full_data

    # "sdrpinit" is the default tag; an explicit --desc takes precedence
    # instead of being silently overwritten.
    desc = "sdrpinit" if args.desc is None else args.desc
    outdir = f"output/{args.expname}-{desc}"
    run_name_desc = f"-{desc}"

    os.makedirs(outdir, exist_ok=True)

    # Used for both --save_steps and --validation_steps.
    val_step = 40

    # Child processes inherit the environment, so this restricts the GPUs
    # visible to the training and evaluation subprocesses.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    accelerate_cmd = [
        "accelerate",
        "launch",
        "--mixed_precision=bf16",
        "--num_processes=1",
        "--num_machines=1",
        "--dynamo_backend=no",
        f"--main_process_port={find_free_port()}",
    ]

    failed = []
    for name, sample in data.items():
        init_token = sample["initialization"]

        cmd = [
            "scripts/train_sdxl.py",
            f"--pretrained_model_name_or_path={args.model}",
            f"--train_data_dir={data_file}",
            f"--instance={name}",
            f"--output_dir=./{outdir}/{name}",
            "--learnable_property=object",
            f"--placeholder_token=<{name}>",
            f"--initializer_token={init_token}",
            "--mixed_precision=bf16",
            f"--resolution={args.resolution}",
            f"--save_steps={val_step}",
            f"--validation_steps={val_step}",
            # The "{}" is passed through literally (this is not an f-string).
            "--validation_prompt=a teddy bear in {} style",
            "--train_batch_size=4",
            "--gradient_accumulation_steps=1",
            f"--max_train_steps={args.max_train_steps}",
            f"--learning_rate={args.lr}",  # RSGD.
            "--scale_lr",
            "--lr_scheduler=constant",
            f"--lr_eta_min={args.min_lr}",
            # "--report_to=wandb",
            f"--run_name={args.runname}-{name}{run_name_desc}",
            "--seed=42",
            f"--kappa={args.kappa}",
            "--zero_pad",
            "--init_method=token",  # token, random, mean
        ]

        # Keep a record of the exact training arguments next to the outputs.
        run_dir = f"{outdir}/{name}"
        os.makedirs(run_dir, exist_ok=True)
        with open(f"{run_dir}/cmd.txt", "w", encoding="utf-8") as file:
            file.write("\n".join(cmd))

        result = subprocess.run(accelerate_cmd + cmd)
        if result.returncode != 0:
            # Report but keep going so one bad instance does not stop the sweep.
            failed.append(name)
            print(f"[warn] training for '{name}' exited with code {result.returncode}")

    if failed:
        print(f"[warn] {len(failed)} training run(s) failed: {', '.join(failed)}")

    # Evaluation.
    eval_cmd = [
        "python",
        "scripts/evaluate.py",
        f"-e={outdir}",
        "--prompt_set=style",
        "--train_data=data/qwen_sdrp.json",
    ]
    eval_result = subprocess.run(eval_cmd)
    if eval_result.returncode != 0:
        print(f"[warn] evaluation exited with code {eval_result.returncode}")


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
