"""Builds a consolidated list of tasks that were not successfully executed, for use in future runs."""

import argparse
import json
import os
from pathlib import Path
import sys
from typing import TypedDict
import pandas as pd
import glob
import re

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from utils.file_utils import (
    find_files,
    get_attribute_from_args_file,
    get_task_ids_from_csv,
    get_args,
    get_task_id_from_file_path,
    get_task_ids_from_dir,
    get_ids_from_tst_config_list,
)


def get_attempted_tasks(execution_dir):
    """Return the IDs of every task a single run attempted (or was meant to).

    IDs are unioned from:
      * the run's args (via ``get_args``): the explicit ``task_list`` when it is
        a non-empty list, otherwise the ``[start_idx, end_idx)`` slice of the
        task IDs found in ``test_config_base_dir`` (sorted numerically, both
        bounds clamped to the valid range);
      * ``attempted_tasks.txt`` in ``execution_dir``, when present;
      * file names under ``execution_dir/htmls``, when that directory exists.

    Args:
        execution_dir: Directory of a single run.

    Returns:
        Set of integer task IDs.
    """
    args = get_args(execution_dir)

    all_tasks = set()
    task_list = args.get("task_list")
    if task_list and isinstance(task_list, list):
        all_tasks.update(get_ids_from_tst_config_list(tst_config_list=task_list))
    else:
        # key=int: if IDs come back as strings, a plain sort is lexicographic
        # ("10" < "2") and the index slice would not match the numeric range.
        task_ids = sorted(get_task_ids_from_dir(args["test_config_base_dir"]), key=int)
        start_idx = args.get("start_idx", 0)
        end_idx = args.get("end_idx", 0)
        # A stored None (e.g. an unset CLI option) used to crash in max(None, 0).
        # Follow slice semantics instead: None start -> 0, None end -> no bound.
        # NOTE(review): a *missing* end_idx still defaults to 0 (empty range), as before.
        start_idx = 0 if start_idx is None else max(start_idx, 0)
        end_idx = len(task_ids) if end_idx is None else min(max(end_idx, 0), len(task_ids))
        all_tasks.update(task_ids[start_idx:end_idx])

    attempted_path = os.path.join(execution_dir, "attempted_tasks.txt")
    if os.path.isfile(attempted_path):
        all_tasks.update(get_ids_from_tst_config_list(txt_path=attempted_path))

    html_dir = os.path.join(execution_dir, "htmls")
    # isdir (not a listdir membership test) so a stray *file* named "htmls" is skipped.
    if os.path.isdir(html_dir):
        for file_name in os.listdir(html_dir):
            all_tasks.add(get_task_id_from_file_path(file_name))

    return {int(task_id) for task_id in all_tasks}


def get_failed_execution_tasks(execution_dir):
    """Return IDs of tasks recorded as unfinished or errored in one run.

    Sources (each optional):
      * ``unfinished_tasks.txt`` in ``execution_dir``, parsed with
        ``get_ids_from_tst_config_list``.
      * ``error.txt`` in ``execution_dir``: every line ending in
        ``<digits>.json`` (ignoring trailing whitespace) contributes those
        digits as a task ID.

    Args:
        execution_dir: Directory of a single run.

    Returns:
        Set of integer task IDs.
    """
    failed_tasks = set()

    unfinished_path = os.path.join(execution_dir, "unfinished_tasks.txt")
    if os.path.isfile(unfinished_path):
        failed_tasks.update(get_ids_from_tst_config_list(txt_path=unfinished_path))

    error_path = os.path.join(execution_dir, "error.txt")
    if os.path.isfile(error_path):
        # Matches "**/<task_id>.json" at the end of a line.
        task_json_re = re.compile(r"(\d+)\.json$")
        # Error logs can hold arbitrary text (tracebacks, page content); don't
        # let a single undecodable byte abort the whole scan.
        with open(error_path, "r", encoding="utf-8", errors="replace") as file:
            for line in file:
                # rstrip so trailing spaces don't defeat the end-anchored match.
                match = task_json_re.search(line.rstrip())
                if match:
                    failed_tasks.add(int(match.group(1)))

    return {int(task_id) for task_id in failed_tasks}


def get_successful_execution_tasks(execution_dir):
    """Return IDs of tasks listed in the run's ``summary_data.csv``.

    Args:
        execution_dir: Directory of a single run.

    Returns:
        Set of integer task IDs; empty when ``summary_data.csv`` is absent.
    """
    summary_path = os.path.join(execution_dir, "summary_data.csv")
    if not os.path.exists(summary_path):
        return set()
    return {int(task_id) for task_id in get_task_ids_from_csv(summary_path)}


# Per-test-config record built by get_unfinished_tasks:
#   tasks      -- set of integer task IDs (attempted/failed, minus succeeded)
#   source_dir -- run directory in which this test config was first encountered
UnfinishedTaskEntry = TypedDict(
    "UnfinishedTaskEntry",
    {
        "tasks": set[int],
        "source_dir": str,
    },
)


def get_unfinished_tasks(experiment_dir_path) -> dict[str, UnfinishedTaskEntry]:
    """Group, per test config directory, the tasks that still need to be run.

    Every ``args.json`` found below ``experiment_dir_path`` marks one run. Runs
    are grouped by their ``test_config_base_dir``. For each group the result is
    (attempted union failed) minus succeeded, where "succeeded" is the union
    across *all* runs of the group. A task that failed in one run but succeeded
    in another is therefore not reported.

    Args:
        experiment_dir_path: Root directory searched recursively for ``args.json``.

    Returns:
        Mapping ``test_config_base_dir -> UnfinishedTaskEntry``. ``source_dir``
        is the directory of the first run of that group in sorted path order.
    """
    unfinished_per_tst_config = {}
    finished_per_tst_config = {}
    # Sort so the run chosen as ``source_dir`` does not depend on filesystem
    # traversal order.
    args_files = sorted(find_files(experiment_dir_path, "args.json", upwards=False, downwards=True))

    for args_file in args_files:
        tst_config = get_attribute_from_args_file(args_file_path=args_file, attribute="test_config_base_dir")
        single_run_dir = Path(args_file).parent
        if tst_config not in unfinished_per_tst_config:
            unfinished_per_tst_config[tst_config] = UnfinishedTaskEntry(tasks=set(), source_dir=str(single_run_dir))
            finished_per_tst_config[tst_config] = set()

        pending = unfinished_per_tst_config[tst_config]["tasks"]
        pending.update(get_attempted_tasks(single_run_dir))
        pending.update(get_failed_execution_tasks(single_run_dir))
        finished_per_tst_config[tst_config].update(get_successful_execution_tasks(single_run_dir))

    # Subtract successes only after all runs are seen, so a later success
    # cancels an earlier failure regardless of run order.
    for tst_config, entry in unfinished_per_tst_config.items():
        entry["tasks"] -= finished_per_tst_config[tst_config]

    return unfinished_per_tst_config


def consolidate_unfinished_tasks(experiment_dir_path: str | Path, save_dir: str | Path = "") -> None:
    """Write one ``unfinished_tasks_<domain>.txt`` file per test config.

    ``<domain>`` is the text after the last ``_`` in the final component of
    the test config directory (e.g. ``.../test_shopping`` -> ``shopping``).

    Args:
        experiment_dir_path: Root directory searched recursively for runs.
        save_dir: Output directory; created if missing. When empty, each file
            is written into the ``source_dir`` of its test config.
    """
    print(f"Consolidating unfinished tasks from {experiment_dir_path}")

    unfinished_per_tst_config = get_unfinished_tasks(experiment_dir_path)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for tst_config, entry in unfinished_per_tst_config.items():
        # Path(...).name ignores a trailing separator; the previous
        # split("/")[-1] produced an empty domain for "a/test_shopping/".
        domain = Path(tst_config).name.split("_")[-1]
        # NOTE(review): two configs with the same domain suffix will target the
        # same file when save_dir is set; the later one overwrites the earlier.
        out_dir = save_dir or entry["source_dir"]
        write_unfinished_tasks_to_file(
            entry["tasks"], tst_config, os.path.join(out_dir, f"unfinished_tasks_{domain}.txt")
        )


def write_unfinished_tasks_to_file(unfinished_tasks, test_config_base_dir, output_file):
    """Persist unfinished task IDs, or delete an existing file when there are none.

    File format: the first line is ``test_config_base_dir``; each following
    line is one task ID, in ascending order.

    Args:
        unfinished_tasks: Collection of task IDs; empty means nothing to write.
        test_config_base_dir: Written verbatim as the header line.
        output_file: Destination path; overwritten if it already exists.
    """
    if unfinished_tasks:
        lines = [f"{test_config_base_dir}\n"]
        lines.extend(f"{task_id}\n" for task_id in sorted(unfinished_tasks))
        with open(output_file, "w") as out:
            out.write("".join(lines))
        print(f"Finished writing unfinished tasks to {output_file}")
        return

    # Nothing to write: remove a previously written file, if any.
    if os.path.exists(output_file):
        os.remove(output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Consolidate unfinished tasks from an experiment directory.")
    parser.add_argument(
        "--d",
        type=str,
        # Previously this path was hard-coded after parsing, silently ignoring
        # the CLI; keep it as the default so a bare invocation is unchanged.
        default="results/gemini-2.0-flash-001/p_run-shopping-2025-02-24-2033",
        help="Experiment directory to scan recursively for args.json files.",
    )
    parser.add_argument(
        "--o",
        type=str,
        default="",
        help="Output directory (default: write next to each test config's source run).",
    )
    args = parser.parse_args()

    consolidate_unfinished_tasks(args.d, args.o)
