import embedder
from lmms_eval.utils import simple_parse_args_string
from lmms_eval.tasks import initialize_tasks, include_path, get_task_dict, ConfigurableTask
from lmms_eval.api.registry import ALL_TASKS, GROUP_REGISTRY

import argparse
import os

import torch.distributed as dist


def rank0_print(*args):
    """Print ``args`` once per job rather than once per process.

    When ``torch.distributed`` is not initialized, this behaves like a plain
    ``print(*args)``. When it is initialized, only the process with rank 0
    prints, and the output is prefixed with ``"Rank 0: "``.
    """
    if not dist.is_initialized():
        print(*args)
        return
    current_rank = dist.get_rank()
    if current_rank == 0:
        print(f"Rank {current_rank}: ", *args)


def parse_arguments():
    """Parse the command-line options of the embedding script.

    Returns:
        argparse.Namespace with the string attributes ``name``,
        ``output_path``, ``tasks``, ``data_path``, ``image_folder`` and
        ``embedder_kwargs``.

    Raises:
        SystemExit: via argparse when ``--name`` or ``--output_path`` is
            missing (both are used unconditionally by the entry point).
    """
    parser = argparse.ArgumentParser(
        description="Compute task embeddings with an embedder class from the `embedder` module."
    )
    # Required: without them the script previously died later with an opaque
    # TypeError (getattr(embedder, None) / os.path.join(None, ...)).
    parser.add_argument(
        "--name",
        type=str,
        required=True,
        help="Name of the embedder class looked up in the `embedder` module.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Directory checked for cached <task>_embed.npy files; also passed to the embedder.",
    )
    parser.add_argument(
        "--tasks",
        type=str,
        required=False,
        default="",
        help="Comma-separated task names, or 'all' for every registered non-group task.",
    )
    # NOTE(review): --data_path and --image_folder are parsed but not read
    # anywhere in this script; kept for CLI compatibility.
    parser.add_argument("--data_path", type=str, required=False, default="")
    parser.add_argument("--image_folder", type=str, required=False, default="")
    parser.add_argument(
        "--embedder_kwargs",
        type=str,
        default="",
        help="Extra constructor kwargs for the embedder, parsed with simple_parse_args_string.",
    )

    return parser.parse_args()


def _resolve_tasks(tasks_arg):
    """Turn the raw ``--tasks`` value into an ordered list of task names.

    ``"all"`` (case-insensitive, surrounding whitespace ignored) expands to
    every registered task that is not a group. Any other value is split on
    commas. Whitespace around each name is stripped, empty entries (from an
    empty string or stray commas) are dropped, and duplicates are removed
    while keeping first-seen order.
    """
    if tasks_arg.lower().strip() == "all":
        initialize_tasks()
        # Groups are pruned from the global registry itself, as this script
        # always did, so only concrete tasks remain. Iterate over a copy
        # because ALL_TASKS is mutated inside the loop.
        for task in list(ALL_TASKS):
            if task in GROUP_REGISTRY:
                ALL_TASKS.remove(task)
        return list(ALL_TASKS)
    stripped = (name.strip() for name in tasks_arg.split(","))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(name for name in stripped if name))


def _filter_cached(tasks, output_path):
    """Return, in order, the tasks without a ``<task>_embed.npy`` in *output_path*."""
    pending = []
    for task in tasks:
        if os.path.exists(os.path.join(output_path, f"{task}_embed.npy")):
            rank0_print(f"Task {task} exists in cache folder, load from cache")
        else:
            pending.append(task)
    return pending


def main():
    """Embed every requested task that is not already cached in ``--output_path``.

    The embedder class is looked up by ``--name`` in the ``embedder`` module
    and constructed with ``name``, ``output_path`` and the parsed
    ``--embedder_kwargs``. Its ``embed_task`` is then called once per
    remaining task.
    """
    args = parse_arguments()
    embedder_name = args.name
    output_path = args.output_path

    tasks = _filter_cached(_resolve_tasks(args.tasks), output_path)
    rank0_print(f"Tasks : {tasks}")
    if not tasks:
        # Skip constructing the embedder entirely when there is no work to do.
        rank0_print("No uncached tasks to embed; nothing to do.")
        return

    embedder_kwargs = simple_parse_args_string(args.embedder_kwargs)
    embedder_cls = getattr(embedder, embedder_name)
    embedder_obj = embedder_cls(name=embedder_name, output_path=output_path, **embedder_kwargs)
    for task in tasks:
        embedder_obj.embed_task(task)


if __name__ == "__main__":
    main()
