import argparse
from llm_inference.tasks import TASKS
from loguru import logger
import upath
from tqdm.auto import tqdm


def main(tasks, output_dir):
  """Export the dataset of every requested task to a Parquet file.

  Each name in ``tasks`` is looked up in ``TASKS``. Names that are not
  registered are reported through the logger and skipped; the remaining
  tasks are still processed. For a registered task the dataset is loaded
  with ``task.load_dataset()``, its row count and column names are logged,
  and it is written to ``<output_dir>/<task_name>.parquet``.

  Args:
    tasks: Sized collection of task names (``len()`` is taken for the
      progress bar).
    output_dir: Destination directory; passed to ``upath.UPath``.
  """
  export_root = upath.UPath(output_dir)
  progress = tqdm(tasks, total=len(tasks))
  for name in progress:
    if name in TASKS:
      spec = TASKS[name]
      logger.info(f"Loading dataset for task '{name}'")
      data = spec.load_dataset()
      logger.info("num_rows: {}, columns: {}", len(data), data.column_names)
      logger.info(f"Exporting {name} to {export_root}")
      # One Parquet file per task, named after the task.
      target = export_root / f"{name}.parquet"
      data.to_parquet(target.as_posix())
    else:
      # Unknown task: report it and move on to the next one.
      logger.error(f"Task {name} not found")


if __name__ == "__main__":
  # CLI entry point: either list the registered task names or export the
  # requested tasks as Parquet files into --output-dir.
  parser = argparse.ArgumentParser()

  # Exactly one of --list-tasks / --tasks must be given.
  group = parser.add_mutually_exclusive_group(required=True)
  group.add_argument("--list-tasks", action="store_true", help="List available tasks")
  group.add_argument(
    "--tasks", type=str, help="Comma-separated list of tasks to export"
  )
  parser.add_argument(
    "--output-dir",
    type=str,
    help="Directory to write the <task>.parquet files into (required with --tasks)",
  )

  args = parser.parse_args()

  if args.list_tasks:
    logger.info("Available tasks:\n{}", list(TASKS.keys()))
  else:
    # The group is required, so reaching this branch means --tasks was
    # supplied (possibly as an empty string). Validate here rather than
    # relying on the truthiness of args.tasks, which let `--tasks ""`
    # slip past the --output-dir check.
    if not args.output_dir:
      parser.error("--output-dir is required when using --tasks")

    # Tolerate spaces around commas and stray/trailing commas.
    task_names = [name.strip() for name in args.tasks.split(",") if name.strip()]
    if not task_names:
      parser.error("--tasks must name at least one task")

    main(task_names, args.output_dir)
