import concurrent.futures
import time
from argparse import Namespace
from multiprocessing.pool import ThreadPool

import tqdm

from .my_typing import *


def threading_register_args(parser, default_batch_size=8, default_timeout=3000):
    """Attach the threading-related command-line options to ``parser``.

    Two integer options (``--batch-size`` and ``--threading-request-timeout``)
    are registered first, followed by three boolean switches
    (``--sequential-generation``, ``--disable-progress``, ``--disable-timer``).

    Args:
        parser: An argparse-style parser exposing ``add_argument``.
        default_batch_size: Default value for ``--batch-size``.
        default_timeout: Default value for ``--threading-request-timeout``.
    """
    # (long flag, short flag, default, help text) for every int-valued option.
    int_options = (
        (
            "--batch-size",
            "-bs",
            default_batch_size,
            "The number of examples to generate in parallel",
        ),
        (
            "--threading-request-timeout",
            "-trt",
            default_timeout,
            "Timeout for a single generation request",
        ),
    )
    for long_flag, short_flag, default_value, description in int_options:
        parser.add_argument(
            long_flag,
            short_flag,
            type=int,
            default=default_value,
            help=description,
        )

    # (long flag, short flag, help text) for every on/off switch.
    switch_options = (
        (
            "--sequential-generation",
            "-sg",
            "Run generation sequentially instead of in parallel",
        ),
        (
            "--disable-progress",
            "-xprog",
            "Disable progress bar for multi-threaded generation",
        ),
        (
            "--disable-timer",
            "-xtimer",
            "Disable timer for multi-threaded generation",
        ),
    )
    for long_flag, short_flag, description in switch_options:
        parser.add_argument(
            long_flag,
            short_flag,
            action="store_true",
            help=description,
        )


def batchify(data: List[Any], batch_size: int) -> List[List[Any]]:
    """Split ``data`` into consecutive batches of at most ``batch_size`` items.

    The final batch is shorter when ``len(data)`` is not a multiple of
    ``batch_size``. Empty input yields an empty list.

    Args:
        data: The sequence to split; it is sliced, not mutated.
        batch_size: Maximum number of items per batch; must be positive.

    Returns:
        A list of slices of ``data``, in the original order.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    # A negative step makes range() empty, which would silently drop every
    # item; reject it the same way a zero step is rejected.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    return [
        data[start : start + batch_size]
        for start in range(0, len(data), batch_size)
    ]


def thread_pool_executor(
    gen_func: Callable,
    batch_inputs: List[Any],
    unordered: bool = True,
    sequential_generation: bool = False,
    show_progress: bool = True,
    num_threads: int = 10,
    request_timeout: int = 60,
    enable_timer: bool = True,
) -> List[Any]:
    """Run a generation function over ``batch_inputs``, optionally in parallel.

    Modified from https://github.com/openai/evals/blob/main/evals/eval.py#L148

    Args:
        gen_func: Callable invoked once per element of ``batch_inputs``.
        batch_inputs: The inputs to process.
        unordered: In threaded mode, return results in completion order
            (``imap_unordered``) instead of input order (``imap``). Ignored
            in sequential mode, which always preserves input order.
        sequential_generation: Call ``gen_func`` directly, one input at a
            time, in the calling thread. ``request_timeout`` is not applied
            in this mode.
        show_progress: Wrap the result iterator in a tqdm progress bar.
        num_threads: Number of worker threads used in threaded mode.
        request_timeout: Seconds to wait for a single ``gen_func`` call in
            threaded mode before abandoning it and retrying.
        enable_timer: Print the total wall-clock time once finished.

    Returns:
        A list with one result per input.
    """

    def worker_thread(inputs):
        """Evaluate a single input, retrying for as long as it times out.

        Each attempt runs in its own single-thread executor so that a hung
        call can be abandoned. There is no retry limit: an input that always
        times out is retried forever. Exceptions other than a timeout
        propagate to the caller.
        """
        while True:
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            try:
                future = executor.submit(gen_func, inputs)
                return future.result(timeout=request_timeout)
            except concurrent.futures.TimeoutError:
                # Abandon the stuck attempt and start over with a new executor.
                continue
            finally:
                # Always release the executor (success, timeout, or error);
                # wait=False so a still-running attempt cannot block us.
                executor.shutdown(wait=False)

    def collect(result_iter):
        """Drain ``result_iter`` into a list, with a progress bar if enabled."""
        if show_progress:
            result_iter = tqdm.tqdm(result_iter, total=len(batch_inputs))
        return list(result_iter)

    start = time.time()
    if sequential_generation:
        # No pool is needed here, so none is created.
        print("Running in sequential mode!")
        results = collect(map(gen_func, batch_inputs))
    else:
        print(f"Running in threaded mode with {num_threads} threads!")
        with ThreadPool(num_threads) as pool:
            mapper = pool.imap_unordered if unordered else pool.imap
            # Results must be consumed before the pool is torn down on exit.
            results = collect(mapper(worker_thread, batch_inputs))
    if enable_timer:
        print(f"Run {len(batch_inputs)} calls took {time.time() - start:.2f}s")
    return results


def thread_pool_executor_with_args(
    gen_func: Callable, batch_inputs: List[Any], args: Namespace, unordered: bool = True
):
    """Call ``thread_pool_executor`` with its settings read from parsed ``args``.

    ``args`` must carry the attributes ``sequential_generation``,
    ``disable_progress``, ``batch_size``, ``threading_request_timeout`` and
    ``disable_timer``. The two ``disable_*`` flags are negated to produce
    the corresponding ``show_progress`` / ``enable_timer`` settings.
    """
    options = {
        "unordered": unordered,
        "sequential_generation": args.sequential_generation,
        "show_progress": not args.disable_progress,
        "num_threads": args.batch_size,
        "request_timeout": args.threading_request_timeout,
        "enable_timer": not args.disable_timer,
    }
    return thread_pool_executor(gen_func, batch_inputs, **options)
