import glob
import os
from typing import List, Literal, Tuple

import dotenv
import pandas as pd
from tqdm import tqdm

from radar import utils
from radar.data import datamodel

# Runs at import time: load variables from a .env file (if one is found) into
# the process environment, so DATASET_FOLDER (read in load_task_instances) can
# be configured without exporting it in the shell.
dotenv.load_dotenv()


def load_task_instances(
    split: Literal["full", "tasks", "sizes"],
) -> Tuple[List[datamodel.TaskInstance], pd.DataFrame]:
    """Load all task instances of a dataset split and tabulate their metadata.

    The dataset root is read from the ``DATASET_FOLDER`` environment variable.
    Each split has its own folder under that root (``radar``, ``radar_tasks``
    or ``radar_sizes``). Every ``*.json`` file found exactly one level below
    the split folder is read with ``utils.read_json`` and passed as keyword
    arguments to ``datamodel.TaskInstance``.

    Args:
        split: Which split to load: ``"full"``, ``"tasks"`` or ``"sizes"``.

    Returns:
        A ``(task_instances, df_meta)`` pair: the list of loaded task
        instances, and a DataFrame with one row per instance (in the same
        order) holding its id, query, artifact types, query columns, table
        size statistics, perturbation note and answer.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
        FileNotFoundError: If ``DATASET_FOLDER`` is unset/empty or points to a
            path that does not exist. ``os.listdir`` raises the same error
            when the split folder itself is missing.
    """
    split_folder_names = {
        "full": "radar",
        "tasks": "radar_tasks",
        "sizes": "radar_sizes",
    }
    # Reject unknown splits up front; otherwise the folder would never be
    # resolved and the failure would surface later as an UnboundLocalError.
    if split not in split_folder_names:
        raise ValueError(
            f"Unknown split {split!r}; expected one of {sorted(split_folder_names)}"
        )

    base_data_folder = os.getenv("DATASET_FOLDER")
    # os.getenv returns None when the variable is unset, and os.path.exists
    # raises TypeError on None, so check for that explicitly.
    if not base_data_folder:
        raise FileNotFoundError(
            "Could not find dataset folder: the DATASET_FOLDER environment "
            "variable is not set"
        )
    if not os.path.exists(base_data_folder):
        raise FileNotFoundError(f"Could not find dataset folder: {base_data_folder}")

    task_folder = os.path.join(base_data_folder, split_folder_names[split])

    task_instances: List[datamodel.TaskInstance] = []
    # os.listdir and glob.glob return entries in arbitrary order; sort both so
    # the instance order (and the DataFrame row order) is reproducible.
    for subfolder in sorted(os.listdir(task_folder)):
        task_files = sorted(glob.glob(os.path.join(task_folder, subfolder, "*.json")))
        task_instances.extend(
            datamodel.TaskInstance(**utils.read_json(task_file))
            for task_file in task_files
        )

    df_rows = [
        {
            "task_id": task_instance.task_id,
            "query": task_instance.query,
            "artifact_types": task_instance.artifact_types,
            "query_cols": task_instance.query_cols,
            "table_num_tokens": task_instance.base_data_num_tokens,
            "table_token_bucket": task_instance.base_data_token_bucket,
            "table_num_cols": task_instance.num_cols,
            "table_num_rows": task_instance.num_rows,
            "perturbation_note": task_instance.perturbation_note,
            "answer": task_instance.answer,
        }
        for task_instance in tqdm(task_instances, desc="Loading task instances")
    ]
    df_meta = pd.DataFrame(df_rows)
    return task_instances, df_meta


if __name__ == "__main__":
    # Manual smoke test: load the "tasks" split and print its metadata table.
    _, meta_table = load_task_instances("tasks")
    print(meta_table)
