from argparse import ArgumentParser, Namespace

import tqdm
from toolemu.utils import append_jsonl, batchify, thread_pool_executor
from toolemu.utils.my_typing import *


class FuncExecutor:
    """Run a function over a dataset in batches, with retries, saving results.

    Each data point is passed to the function; successful results are appended
    to an output file as they arrive, failed data points are retried for up to
    ``num_retries`` extra rounds, and whatever still fails after the last round
    is appended to the end of the same file.
    """

    def __init__(
        self,
        num_retries: int = 0,
        batch_size: int = 5,
        threading_request_timeout: int = 3000,
        unordered: bool = False,
        sequential_generation: bool = False,
        disable_progress: bool = False,
        disable_timer: bool = False,
        verbose: bool = True,
    ):
        """Store the execution settings.

        Args:
            num_retries: Number of extra rounds run on the failed data points.
            batch_size: Number of data points handed to one batch; also used as
                the thread count. Values <= 1 call the function directly.
            threading_request_timeout: Forwarded to ``thread_pool_executor`` as
                ``request_timeout``.
            unordered: Forwarded to ``thread_pool_executor`` as ``unordered``.
            sequential_generation: Forwarded to ``thread_pool_executor``.
            disable_progress: When true, ``show_progress=False`` is forwarded.
            disable_timer: When true, ``enable_timer=False`` is forwarded.
            verbose: Whether to print per-batch and per-round status messages.
        """
        self._num_retries = num_retries
        self._batch_size = batch_size
        self._threading_request_timeout = threading_request_timeout
        self._unordered = unordered
        self._sequential_generation = sequential_generation
        self._disable_progress = disable_progress
        self._disable_timer = disable_timer
        self._verbose = verbose

    def write_to_file(self, output_file: str, results: Any):
        """Write the results to the output file."""
        append_jsonl(output_file, results)

    def _run_batch(self, func: Callable, batch_inputs: List[Any]):
        """Apply ``func`` to every item of one batch and return the raw results."""
        if self._batch_size <= 1:  # directly call the function
            return [func(batch_inputs[0])]
        return thread_pool_executor(
            func,
            batch_inputs,
            unordered=self._unordered,
            sequential_generation=self._sequential_generation,
            show_progress=not self._disable_progress,
            num_threads=self._batch_size,
            request_timeout=self._threading_request_timeout,
            enable_timer=not self._disable_timer,
        )

    def run(
        self,
        func: Callable,
        dataset: List[Any],
        output_file: str,
        result_parse_func: Callable = None,
    ):
        """Run the function on the dataset.

        Args:
            func: Called with a single data point. Unless ``result_parse_func``
                is given, its return value must unpack into
                ``(failed_item, result)``.
            dataset: The data points to process.
            output_file: File that results are appended to.
            result_parse_func: Optional callable turning one raw return value
                of ``func`` into the ``(failed_item, result)`` pair.

        A ``failed_item`` of ``None`` marks success and ``result`` is written
        immediately. Otherwise ``failed_item`` is queued for the next round and
        ``result`` is kept; the kept results of the final round are appended to
        ``output_file`` after all successful ones.
        """
        failed_results = []
        # A non-positive batch size already means "call directly" in
        # ``_run_batch``; chunk one item at a time so batching agrees with it.
        chunk_size = max(1, self._batch_size)

        for retries in range(self._num_retries + 1):
            batched_dataset = batchify(dataset, chunk_size)
            if self._verbose:
                msg = "Executing the function" if retries == 0 else "Retrying"
                print(f"{msg} for {len(dataset)} data points.")
            remaining_dataset = []
            # Only the failures of the most recent round are reported at the end.
            failed_results = []

            for batch_inputs in tqdm.tqdm(batched_dataset):
                results = self._run_batch(func, batch_inputs)
                success_count = 0
                for res in results:
                    if result_parse_func is not None:
                        failed_item, result = result_parse_func(res)
                    else:
                        failed_item, result = res
                    if failed_item is None:  # the function execution is successful
                        self.write_to_file(output_file, result)
                        success_count += 1
                    else:  # the function execution failed, to be retried
                        remaining_dataset.append(failed_item)
                        failed_results.append(result)

                if self._verbose:
                    print(
                        f"Saved {success_count} successful execution results "
                        f"(out of {len(results)}) to {output_file}."
                    )

            dataset = remaining_dataset
            if not dataset:
                break
            if self._verbose:
                print(
                    f"Failed to execute the function for {len(dataset)} "
                    f"data points at round {retries}.",
                )
                # Purely diagnostic: never let it crash the run when the failed
                # items are not dicts or some of them carry no "idx".
                if all(isinstance(item, dict) and "idx" in item for item in dataset):
                    print(f"Failed indexes: {[item['idx'] for item in dataset]}")

        if failed_results:
            for result in failed_results:
                self.write_to_file(output_file, result)
            # Printed regardless of ``verbose``.
            print(
                f"Saved {len(failed_results)} failed cases to the end of {output_file}."
            )

    @classmethod
    def register_args(
        cls,
        parser: ArgumentParser,
        prefix: str = None,
        shortprefix: str = None,
        default_num_retries: int = 0,
        default_batch_size: int = 5,
        default_timeout: int = 3000,
    ):
        """Register the executor's command-line arguments on ``parser``.

        Args:
            parser: The parser the options are added to.
            prefix: Optional prefix for the long option names; ``"x"`` yields
                ``--x-num-retries`` and so on.
            shortprefix: Optional prefix for the short option names. Defaults
                to the first character of ``prefix``. Ignored when ``prefix``
                is ``None``.
            default_num_retries: Default for ``--num-retries``.
            default_batch_size: Default for ``--batch-size``.
            default_timeout: Default for ``--threading-request-timeout``.
        """
        if prefix is None:
            prefix = ""
            shortprefix = ""
        else:
            prefix += "-"
            shortprefix = shortprefix or prefix[0]
        parser.add_argument(
            f"--{prefix}num-retries",
            f"-{shortprefix}nr",
            type=int,
            default=default_num_retries,
            help=(
                "The number of retries when the function execution fails. "
                "Do not set a too large number as it may take a long time to finish."
            ),
        )
        parser.add_argument(
            f"--{prefix}batch-size",
            f"-{shortprefix}bs",
            type=int,
            default=default_batch_size,
            help="The number of the function executions in parallel",
        )
        parser.add_argument(
            f"--{prefix}threading-request-timeout",
            f"-{shortprefix}trt",
            type=int,
            default=default_timeout,
            help="Timeout for a single function execution request",
        )
        parser.add_argument(
            f"--{prefix}sequential-generation",
            f"-{shortprefix}sg",
            action="store_true",
            help="Run function sequentially over the dataset instead of in parallel",
        )
        parser.add_argument(
            f"--{prefix}disable-progress",
            f"-{shortprefix}xprog",
            action="store_true",
            help="Disable progress bar for multi-threaded execution",
        )
        parser.add_argument(
            f"--{prefix}disable-timer",
            f"-{shortprefix}xtimer",
            action="store_true",
            help="Disable timer for multi-threaded execution",
        )

    @classmethod
    def from_args(
        cls,
        args: Namespace,
        prefix: str = None,
        **kwargs,
    ):
        """Create an executor from arguments.

        Args:
            args: Parsed arguments holding the options added by
                ``register_args``.
            prefix: The same prefix that was given to ``register_args``.
            **kwargs: Extra constructor arguments; they override values read
                from ``args``.

        Returns:
            A new instance of ``cls``.
        """
        if prefix is None:
            prefix = ""
        else:
            # argparse turns the hyphens of a long option into underscores when
            # deriving the attribute name, so a hyphenated prefix must be
            # normalized the same way to find the attributes again.
            prefix = prefix.replace("-", "_") + "_"
        option_names = (
            "num_retries",
            "batch_size",
            "threading_request_timeout",
            "sequential_generation",
            "disable_progress",
            "disable_timer",
        )
        args_dict = {name: getattr(args, prefix + name) for name in option_names}
        args_dict.update(kwargs)
        return cls(**args_dict)
