#!/usr/bin/env python3
import argparse
import json
import os
import subprocess
import sys

from dti.constants import DIFFUSERS_MODELS
from dti.utils import find_free_port


def parse_arguments(argv=None):
    """Parse and post-process the command-line options for the experiment.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the normal CLI behaviour.

    Returns:
        The parsed ``argparse.Namespace``, augmented with:
          * ``model`` -- replaced by the ``DIFFUSERS_MODELS`` entry for the
            chosen short model key,
          * ``resolution`` -- training resolution in pixels (512 or 1024),
          * ``runname`` / ``expname`` -- ``"dti-<model key>"``.

    Raises:
        ValueError: If the model key has no entry in ``DIFFUSERS_MODELS``.
    """
    parser = argparse.ArgumentParser(description="Run textual inversion experiment")
    parser.add_argument("-g", "--gpu", type=str, default="7")
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        choices=("sana1.5_1.6b", "sana1.5_4.8b"),
        default="sana1.5_1.6b",
    )
    parser.add_argument("--instances", type=str, nargs="+", default=None)
    parser.add_argument("--max_train_steps", type=int, default=1000)
    parser.add_argument("--lr", type=float, default=5e-3)
    parser.add_argument("--batch_size", type=int, default=4)

    # Forwarded to the training script as --token_scale / --kappa.
    parser.add_argument("--scale", type=str, default="max")
    parser.add_argument("--kappa", type=float, default=0.05)
    # Optional suffix appended to the output directory and the run name.
    parser.add_argument("--desc", type=str, default=None)
    args = parser.parse_args(argv)

    # Keep the short key: it names the run and is the lookup key below.
    model = args.model.lower()
    resolved = DIFFUSERS_MODELS.get(model)
    if resolved is None:
        # Report the requested key; the resolved value is None at this point.
        raise ValueError(f"Model {model} not found in DIFFUSERS_MODELS.")
    args.model = resolved

    # NOTE(review): "sana_600m_512" is not among the --model choices, so the
    # 512px branch is currently unreachable. The original compared the
    # resolved value; the short key is checked as well since the name looks
    # like a DIFFUSERS_MODELS key.
    args.resolution = 512 if "sana_600m_512" in (model, resolved) else 1024

    args.runname = f"dti-{model}"
    args.expname = f"dti-{model}"
    return args


def _load_instances(path, instances=None):
    """Load the instance-name -> metadata mapping from JSON, optionally filtered.

    Args:
        path: JSON file whose top level maps instance names to metadata dicts.
        instances: If given, only these instance names are kept. Names not
            present in the file are reported on stderr and skipped.

    Returns:
        A dict in the file's key order, restricted to ``instances`` if given.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if instances is None:
        return data

    wanted = set(instances)
    missing = sorted(wanted - data.keys())
    if missing:
        print(
            f"WARNING: requested instances not found in {path}: {', '.join(missing)}",
            file=sys.stderr,
        )
    # A comprehension keeps the original key order, as the in-place `del` did.
    return {key: value for key, value in data.items() if key in wanted}


def _build_train_cmd(args, outdir, name, metadata, run_name_desc, val_freq):
    """Return the ``scripts/train_sana.py`` argument list for one instance."""
    return [
        "scripts/train_sana.py",
        f"--pretrained_model_name_or_path={args.model}",
        f"--train_data_dir={metadata['path']}",
        f"--output_dir=./{outdir}/{name}",
        "--learnable_property=object",
        f"--placeholder_token=<{name}>",
        "--num_vectors=1",
        f"--initializer_token={metadata['initialization']}",
        f"--resolution={args.resolution}",
        f"--save_steps={val_freq}",
        f"--validation_steps={val_freq}",
        # Literal "{}" on purpose: not an f-string. Presumably the training
        # script substitutes the placeholder token here -- TODO confirm.
        "--validation_prompt=a {} on a beach",
        f"--train_batch_size={args.batch_size}",
        "--gradient_accumulation_steps=1",
        f"--learning_rate={args.lr}",
        "--scale_lr",
        f"--max_train_steps={args.max_train_steps}",
        # "--report_to=wandb",
        f"--run_name={args.runname}-{name}{run_name_desc}",
        f"--token_scale={args.scale}",
        f"--kappa={args.kappa}",
        "--seed=42",
        "--zero_pad",
    ]


def _run_evaluation(outdir):
    """Run ``scripts/evaluate.py`` on *outdir* for the default and complex prompt sets."""
    # sys.executable guarantees the same interpreter/venv as this script,
    # unlike whatever "python" resolves to on PATH.
    base_cmd = [sys.executable, "scripts/evaluate.py", f"-e={outdir}"]
    subprocess.run(base_cmd)
    subprocess.run(base_cmd + ["--out_dir=images_complex", "--prompt_set=complex"])


def main():
    """Train one textual-inversion run per dataset instance, then evaluate.

    For every instance in the dataset JSON (optionally filtered by
    ``--instances``), this writes the training arguments to
    ``<outdir>/<name>/cmd.txt`` and launches the training script via
    ``accelerate launch``. Training failures are reported on stderr but do
    not stop the remaining instances or the final evaluation.
    """
    args = parse_arguments()

    # Alternative instance lists: data/ti.json, data/selected_data.json.
    data = _load_instances("data/dreambooth.json", args.instances)

    run_name_desc = f"-{args.desc}" if args.desc is not None else ""
    outdir = f"output/{args.expname}{run_name_desc}"
    os.makedirs(outdir, exist_ok=True)

    val_freq = 100

    # Inherited by every child process (training and evaluation).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    failed = []
    for name, metadata in data.items():
        cmd = _build_train_cmd(args, outdir, name, metadata, run_name_desc, val_freq)

        # Save the training-script arguments for reproducibility.
        instance_dir = f"{outdir}/{name}"
        os.makedirs(instance_dir, exist_ok=True)
        with open(f"{instance_dir}/cmd.txt", "w", encoding="utf-8") as file:
            file.write("\n".join(cmd))

        # Probe for a free port right before each launch. Runs can take hours,
        # so a port probed once at startup may be taken by the time it is used.
        accelerate_cmd = [
            "accelerate",
            "launch",
            "--mixed_precision=bf16",
            "--num_processes=1",
            "--num_machines=1",
            "--dynamo_backend=no",
            f"--main_process_port={find_free_port()}",
        ]
        result = subprocess.run(accelerate_cmd + cmd)
        if result.returncode != 0:
            failed.append(name)

    if failed:
        print(
            f"WARNING: training failed for {len(failed)} instance(s): {', '.join(failed)}",
            file=sys.stderr,
        )

    _run_evaluation(outdir)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
