import argparse, os

from dotenv import load_dotenv
from build_dataset import main as build_dataset
from print_pulls import main as print_pulls
from multiprocessing import Pool
from typing import Dict, List


# Load environment variables before anything reads them (presumably from a
# local .env file, python-dotenv's default); main() looks up GITHUB_TOKENS
# through os.getenv.
load_dotenv()


def split_instances(input_list: List, n: int) -> List:
    """
    Partition a list into n consecutive sublists of near-equal size

    The sizes differ by at most one element; when the length is not a
    multiple of n, the leading sublists are the ones that get the extra
    element. Sublists may be empty when n exceeds the list length.

    Args:
        input_list (list): List to split
        n (int): Number of sublists to split into
    Returns:
        result (list): List of sublists
    """
    base_size, extras = divmod(len(input_list), n)
    chunks = []
    cursor = 0

    for idx in range(n):
        # The first `extras` chunks absorb one leftover element each
        size = base_size
        if idx < extras:
            size += 1
        stop = cursor + size
        chunks.append(input_list[cursor:stop])
        cursor = stop

    return chunks


def construct_data_files(data: Dict):
    """
    Create the PR data file and the task instance file for a batch of repositories

    For each repository, pull request data is written with `print_pulls` and
    then turned into task instances with `build_dataset`. A step whose output
    file already exists is skipped. A failure on one repository (including a
    malformed repository name) is reported and the remaining repositories are
    still processed.

    Args:
        data (dict): Dictionary containing the following keys:
            repos (list): List of repositories ("owner/name") to retrieve
                instruction data for
            path_prs (str): Path to save PR data files to
            path_tasks (str): Path to save task instance data files to
            token (str): GitHub token to use for API requests
    """
    repos, path_prs, path_tasks, token = (
        data["repos"],
        data["path_prs"],
        data["path_tasks"],
        data["token"],
    )
    for repo in repos:
        # Trim whitespace on both sides of a stray comma, so that input such
        # as "owner/name, " does not keep a dangling comma
        repo = repo.strip().strip(",").strip()
        try:
            # Parsed inside the try block so a malformed name only skips this
            # repository instead of aborting the whole batch
            name_parts = repo.split("/")
            if len(name_parts) < 2 or not name_parts[1]:
                raise ValueError("expected repository in 'owner/name' format")
            repo_name = name_parts[1]

            path_pr = os.path.join(path_prs, f"{repo_name}-prs.jsonl")
            if not os.path.exists(path_pr):
                print(f"Pull request data for {repo} not found, creating...")
                print_pulls(repo, path_pr, token)
                print(f"Successfully saved PR data for {repo} to {path_pr}")
            else:
                print(
                    f"Pull request data for {repo} already exists at {path_pr}, skipping..."
                )

            path_task = os.path.join(path_tasks, f"{repo_name}-task-instances.jsonl")
            if not os.path.exists(path_task):
                print(f"Task instance data for {repo} not found, creating...")
                build_dataset(path_pr, path_task, token)
                print(
                    f"Successfully saved task instance data for {repo} to {path_task}"
                )
            else:
                print(
                    f"Task instance data for {repo} already exists at {path_task}, skipping..."
                )
        except Exception as e:
            # Deliberately broad: collection is best-effort per repository
            print(f"Something went wrong for {repo}, skipping: {e}")


def main(repos, path_prs, path_tasks):
    """
    Spawns one worker process per GitHub token for collecting fine tuning data

    Tokens are read from the comma separated GITHUB_TOKENS environment
    variable. The repositories are divided between the tokens as evenly as
    possible and each share is handed to `construct_data_files` in its own
    process.

    Args:
        repos (list): List of repositories to retrieve instruction data for
        path_prs (str): Path to save PR data files to
        path_tasks (str): Path to save task instance data files to
    Raises:
        ValueError: If GITHUB_TOKENS is unset or contains no usable token
    """
    path_prs, path_tasks = os.path.abspath(path_prs), os.path.abspath(path_tasks)
    print(f"Will save PR data to {path_prs}")
    print(f"Will save task instance data to {path_tasks}")
    print(f"Received following repos to create task instances for: {repos}")

    # Strip each token and drop blanks so that "a, b," yields two clean tokens
    # rather than a padded one and an empty one
    raw_tokens = os.getenv("GITHUB_TOKENS") or ""
    tokens = [token.strip() for token in raw_tokens.split(",") if token.strip()]
    if not tokens:
        raise ValueError(
            "GITHUB_TOKENS environment variable must be set to a comma "
            "separated list of GitHub tokens"
        )

    data_task_lists = split_instances(repos, len(tokens))

    # One work item per token; zip pairs each share of repos with its token
    data_pooled = [
        {
            "repos": repo_batch,
            "path_prs": path_prs,
            "path_tasks": path_tasks,
            "token": token,
        }
        for repo_batch, token in zip(data_task_lists, tokens)
    ]

    with Pool(len(tokens)) as p:
        p.map(construct_data_files, data_pooled)


if __name__ == "__main__":
    # Command line entry point. Every option is marked required because main()
    # has no defaults and cannot work with the None that argparse would
    # otherwise pass for an omitted option.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repos",
        nargs="+",
        required=True,
        help="List of repositories to create task instances for",
    )
    parser.add_argument(
        "--path_prs",
        type=str,
        required=True,
        help="Path to save PR data files to",
    )
    parser.add_argument(
        "--path_tasks",
        type=str,
        required=True,
        help="Path to save task instance data files to",
    )
    args = parser.parse_args()
    # Option names match main()'s parameter names, so the namespace unpacks directly
    main(**vars(args))
