import concurrent.futures
import time
from multiprocessing.pool import ThreadPool

import tqdm

from .my_typing import *


def batchify(data: List[Any], batch_size: int) -> List[List[Any]]:
    """Split ``data`` into consecutive batches of at most ``batch_size`` items.

    Args:
        data: Sequence to split; it is sliced, never mutated.
        batch_size: Maximum number of items per batch. Must be positive.

    Returns:
        A list of slices of ``data`` in their original order. Every batch
        holds ``batch_size`` items except possibly the last, which holds the
        remainder. Empty ``data`` yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    # A negative step would make range() empty and silently drop every item,
    # and a zero step fails with an unrelated message; reject both up front.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    return [
        data[start:start + batch_size]
        for start in range(0, len(data), batch_size)
    ]


def thread_pool_executor(
        gen_func: Callable,
        batch_inputs: List[Any],
        unordered: bool = True,
        sequential_generation: bool = False,
        show_progress: bool = True,
        num_threads: int = 10,
        request_timeout: int = 60,
        enable_timer: bool = True,
) -> List[Any]:
    """Apply ``gen_func`` to every item of ``batch_inputs`` and collect the results.

    Modified from https://github.com/openai/evals/blob/main/evals/eval.py#L148

    Args:
        gen_func: Callable invoked with one item of ``batch_inputs`` at a time.
        batch_inputs: Items to process.
        unordered: In threaded mode, collect results in completion order
            instead of input order. Has no effect in sequential mode.
        sequential_generation: Call ``gen_func`` directly in the calling
            thread, one item at a time, with no timeout and no retry.
        show_progress: Display a tqdm progress bar while collecting results.
        num_threads: Size of the thread pool (threaded mode only).
        request_timeout: Seconds to wait for a single ``gen_func`` call in
            threaded mode before abandoning it and retrying the same item.
        enable_timer: Print the total elapsed wall-clock time when done.

    Returns:
        A list with one result per input item. The order matches
        ``batch_inputs`` unless threaded mode is used with ``unordered=True``.

    Raises:
        Exception: Whatever ``gen_func`` raises is propagated to the caller.
            In threaded mode a timeout is retried instead (see the note in
            ``worker_thread``).
    """

    def worker_thread(inputs):
        """Run ``gen_func(inputs)``, retrying until a call finishes in time."""
        while True:
            # A fresh single-thread executor per attempt lets us stop waiting
            # on a hung call; the abandoned call keeps running in its thread.
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            future = executor.submit(gen_func, inputs)
            try:
                return future.result(timeout=request_timeout)
            except concurrent.futures.TimeoutError:
                # Timed out: fall through and retry with a new executor.
                # NOTE: on Python 3.11+ this is the builtin TimeoutError, so
                # a TimeoutError raised by gen_func itself is retried too.
                pass
            finally:
                # Release the executor on every path (success, error and
                # timeout) without blocking on a possibly hung call.
                executor.shutdown(wait=False)

    def collect(result_iter):
        """Drain ``result_iter`` into a list, optionally behind a progress bar."""
        # A disabled tqdm bar only passes items through, so skip it entirely.
        if show_progress:
            result_iter = tqdm.tqdm(result_iter, total=len(batch_inputs))
        return list(result_iter)

    start = time.time()
    if sequential_generation:
        # No pool is created here: sequential mode never uses worker threads.
        print("Running in sequential mode!")
        results = collect(map(gen_func, batch_inputs))
    else:
        print(f"Running in threaded mode with {num_threads} threads!")
        with ThreadPool(num_threads) as pool:
            dispatch = pool.imap_unordered if unordered else pool.imap
            # Results must be fully consumed before the pool is torn down
            # on leaving the ``with`` block.
            results = collect(dispatch(worker_thread, batch_inputs))
    if enable_timer:
        print(f"Run {len(batch_inputs)} calls took {time.time() - start:.2f}s")
    return results
