import os
import random
import time
import torch
import json
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError
from typing import Any, Callable, List, Optional, Tuple

from pebble import ProcessPool, ThreadPool
# from pebble.common import ProcessExpired
from ray.util import pdb

# process_timeout = 30 
# thread_timeout = 60 # JINA is very slow

# def split_tasks_by_type(list_of_kwargs: List[dict]) -> Tuple[List[dict], List[dict]]:
#     """Split tasks into two groups: those that need a process and those that need a thread."""
#     process_tasks = []
#     thread_tasks = []
    
#     for kwargs in list_of_kwargs:
#         if kwargs.get('tool_name') in ['PythonInterpreter'] or 'interpreter' in kwargs.get('tool_name').lower():
#             process_tasks.append(kwargs)
#         else:
#             thread_tasks.append(kwargs)
            
#     return process_tasks, thread_tasks


# import threading
# import psutil

# def monitor_memory(pid, limit_gb, timeout, callback):
#     try:
#         proc = psutil.Process(pid)
#         start_time = time.time()
#         while time.time() - start_time < timeout:
#             mem_gb = proc.memory_info().rss / (1024 ** 3)
#             if mem_gb > limit_gb:
#                 proc.terminate()
#                 callback(pid, f"Memory limit exceeded: {mem_gb:.2f} GB")
#                 break
#             if not proc.is_running():
#                 break
#             time.sleep(0.2)
#     except psutil.NoSuchProcess:
#         pass

# def on_memory_violation(pid, msg):
#     print(f"[MONITOR] PID {pid}: {msg}")

# def parallel_exec(
#     fn: Callable,
#     list_of_kwargs: List[dict],
#     max_workers: Optional[int] = None,
#     jitter: float = 0.0,
# ) -> list:
#     process_tasks, thread_tasks = split_tasks_by_type(list_of_kwargs)
    
#     all_results = []
    
#     with ProcessPool(max_workers=min(max_workers * 2, os.cpu_count())) as process_pool, \
#          ThreadPool(max_workers=max_workers) as thread_pool:
        
#         process_futures = []
#         for kwargs in process_tasks:
#             kwargs = deepcopy(kwargs)
#             index = kwargs['index']
#             t_index = kwargs['t_index']
#             future = process_pool.schedule(fn, kwargs=kwargs, timeout=process_timeout)
#             process_futures.append((future, index, t_index))
#             if jitter > 0.0:
#                 time.sleep(jitter * random.random())
        
#         thread_futures = []
#         for kwargs in thread_tasks:
#             kwargs = deepcopy(kwargs)
#             index = kwargs.get('index')
#             t_index = kwargs.get('t_index')
#             future = thread_pool.schedule(fn, kwargs=kwargs, timeout=thread_timeout)
#             thread_futures.append((future, index, t_index))
#             if jitter > 0.0:
#                 time.sleep(jitter * random.random())
        
#         # Collect thread results
#         for future, index, t_index in thread_futures:
#             try:
#                 result = future.result()
#                 all_results.append(result)
#             except TimeoutError:
#                 all_results.append((
#                     index,
#                     t_index,
#                     {
#                         "success": False,
#                         "error_message": f"Thread task (index={index}) timed out after {thread_timeout} seconds"
#                     }
#                 ))
#             except Exception as e:
#                 all_results.append((
#                     index,
#                     t_index,
#                     {
#                         "success": False,
#                         "error_message": f"Thread task (index={index}) failed: {str(e)}"
#                     }
#                 ))
        
#         # Collect process results
#         for future, index, t_index in process_futures:
#             try:
#                 result = future.result()
#                 all_results.append(result)
#             except TimeoutError:
#                 all_results.append((
#                     index,
#                     t_index,
#                     {
#                         "success": False,
#                         "error_message": f"Task (index={index}) timed out after {process_timeout} seconds"
#                     }
#                 ))
#             except Exception as e:
#                 all_results.append((
#                     index,
#                     t_index,
#                     {
#                         "success": False,
#                         "error_message": f"Task (index={index}) failed with error: {str(e)}"
#                     }
#                 ))
    
#     return all_results


# def parallel_exec(
#     fn: Callable,
#     list_of_kwargs: List[dict],
#     max_workers: Optional[int] = None,
#     jitter: float = 0.0,
# ) -> list:
#     process_tasks, thread_tasks = split_tasks_by_type(list_of_kwargs)
    
#     all_results = []
    
#     with ProcessPool(max_workers=min(max_workers // 2, os.cpu_count())) as process_pool, \
#          ThreadPoolExecutor(max_workers=max_workers) as thread_pool:
        
#         process_futures = []
#         if process_tasks:
#             for kwargs in process_tasks:
#                 kwargs = deepcopy(kwargs)
#                 index = kwargs['index']
#                 t_index = kwargs['t_index']
#                 future = process_pool.schedule(fn, kwargs=kwargs, timeout=process_timeout)
#                 process_futures.append((future, index, t_index))
#                 if jitter > 0.0:
#                     time.sleep(jitter * random.random())
        
#         thread_futures = []
#         if thread_tasks:
#             for kwargs in thread_tasks:
#                 thread_futures.append(thread_pool.submit(fn, **kwargs))
#                 if jitter > 0.0:
#                     time.sleep(jitter * random.random())
        
#         # Collect process results
#         for future, index, t_index in process_futures:
#             try:
#                 result = future.result()
#                 all_results.append(result)
#             except TimeoutError:
#                 all_results.append((
#                     index,
#                     t_index,
#                     json.dumps({
#                         "success": False,
#                         "error_message": f"[Python Interpreter Error]: timed out after {process_timeout} seconds"
#                     })
#                 ))
#             except Exception as e:
#                 all_results.append((
#                     index,
#                     t_index,
#                     json.dumps({
#                         "success": False,
#                         "error_message": f"[Python Interpreter Error]: failed with error: {str(e)}"
#                     })
#                 ))
        
#         # Collect thread results
#         for future in as_completed(thread_futures):
#             all_results.append(future.result())
    
#     return all_results

# def parallel_exec(
#     fn: Callable,
#     list_of_kwargs: List[dict],
#     max_workers: Optional[int] = None,
#     jitter: float = 0.0,
# ) -> list:
#     """
#     Executes a given function `fn` in parallel, using multiple threads, on a list of argument tuples.
#     The function limits the number of concurrent executions to `max_workers` and processes tasks in chunks,
#     pausing between each chunk to avoid hitting rate limits or quotas.

#     Args:
#     - fn (Callable): The function to execute in parallel.
#     - list_of_kwargs (list): A list of dicts, where each dict contains arguments for a single call to `fn`.
#     - max_workers (int, optional): The maximum number of threads that can be used to execute the tasks
#       concurrently.
#     - jitter (float, optional): Wait for jitter * random.random() before submitting the next job.

#     Returns:
#     - A list containing the results of the function calls. The order of the results corresponds to the order
#       the tasks were completed, which may not necessarily be the same as the order of `list_of_kwargs`.

#     """
#     pdb.set_trace()
#     process_tasks, thread_tasks = split_tasks_by_type(list_of_kwargs)

#     process_results = []
#     if process_tasks:
#         with ProcessPool(max_workers=max_workers) as pool:
#             future = pool.map(fn, process_tasks, timeout=process_timeout)
#             try:
#                 for result in future.result():
#                     process_results.append(result)
#             except (TimeoutError, Exception) as error:
#                 task_kwargs = error.args[1]
#                 index = task_kwargs['index']
#                 t_index = task_kwargs['t_index']
#                 error_message = (
#                     f"Task (index={index}) timed out after {process_timeout} seconds"
#                     if isinstance(error, TimeoutError)
#                     else f"Task (index={index}) failed with error: {str(error)}"
#                 )
#                 process_results.append((
#                     index,
#                     t_index,
#                     {
#                         "success": False,
#                         "error_message": error_message
#                     }
#                 ))

#     thread_results = []
#     if thread_tasks:
#         with ThreadPoolExecutor(max_workers=max_workers) as executor:
#             # Get the tasks for the current chunk
#             futures = []
#             for kwargs in thread_tasks:
#                 futures.append(executor.submit(fn, **kwargs))
#                 if jitter > 0.0:
#                     time.sleep(jitter * random.random())
#             for future in as_completed(futures): # result order depends on which task finishes first
#                 thread_results.append(future.result())

#     return process_results + thread_results

def parallel_exec(
    fn: Callable,
    list_of_kwargs: List[dict],
    max_workers: Optional[int] = None,
    jitter: float = 0.0,
) -> list:
    """
    Call `fn` once per kwargs dict, concurrently, on a thread pool.

    Every dict in `list_of_kwargs` is submitted to a `ThreadPoolExecutor` as `fn(**kwargs)`.
    All tasks are submitted up front (there is no chunking); the optional `jitter` only staggers
    the submissions so that the calls do not all start at the same instant.

    Args:
    - fn (Callable): The function to execute in parallel.
    - list_of_kwargs (list): A list of dicts, where each dict contains the keyword arguments for a
      single call to `fn`.
    - max_workers (int, optional): The maximum number of threads used to execute the tasks
      concurrently. It is passed to `ThreadPoolExecutor` unchanged; `None` selects the executor's
      default.
    - jitter (float, optional): When greater than 0, sleep for `jitter * random.random()` seconds
      between two consecutive submissions. No sleep happens after the last submission.

    Returns:
    - A list containing the results of the function calls. The order of the results corresponds to the order
      the tasks were completed, which may not necessarily be the same as the order of `list_of_kwargs`.

    Raises:
    - Any exception raised inside `fn` is re-raised here when that task's result is collected.

    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for position, kwargs in enumerate(list_of_kwargs):
            # Stagger only *between* submissions: sleeping after the final job would
            # merely delay result collection without spacing anything out.
            if position > 0 and jitter > 0.0:
                time.sleep(jitter * random.random())
            futures.append(executor.submit(fn, **kwargs))
        # Result order depends on which task finishes first.
        return [future.result() for future in as_completed(futures)]


# Debugging helper: runs the same kind of task list one call at a time instead of in parallel.
def serial_exec(fn: Callable, list_of_kwargs: List[dict]) -> List[Any]:
    """
    Call `fn` once per kwargs dict, one after another, in the calling thread.

    Intended as a debugging counterpart to a parallel runner: the calls happen strictly in the
    order of `list_of_kwargs`, so failures are easy to reproduce and step through.

    The function does not query `torch.distributed` and does not copy the kwargs, so it works
    without an initialized process group and with arguments that cannot be deep-copied.

    Args:
    - fn (Callable): The function to execute.
    - list_of_kwargs (list): A list of dicts, where each dict contains the keyword arguments for a
      single call to `fn`. The dicts are passed through as-is and are not modified here.

    Returns:
    - A list with one result per dict, in the same order as `list_of_kwargs`.

    Raises:
    - Any exception raised by `fn` propagates immediately; later tasks are not executed.

    """
    return [fn(**kwargs) for kwargs in list_of_kwargs]
