import time
import functools
import signal
import re
import ray
import numpy as np
from thinker_task.ppo.tools.math_utils import solution2answer, is_equiv, _is_latex_equal, extract_boxed_answer
from vllm import SamplingParams

def extract_answer(response: str, extract_type: int=1):
    r"""Pull the final answer out of a model response.

    Args:
        response: Raw response text.
        extract_type: Extraction strategy.
            0 - search for a ``\boxed{...}`` span inside ``<answer>...</answer>``
                tags and return the last regex match, or ``""`` if nothing
                matches.
            1 - delegate to ``extract_boxed_answer`` (imported helper).
            any other value - return ``response`` unchanged.

    Returns:
        The extracted answer as described above.
    """
    if extract_type == 1:
        return extract_boxed_answer(response)
    if extract_type != 0:
        return response

    # DOTALL lets the tagged span cover several lines.
    tagged_boxed = re.compile(r"<answer>.*?(\\boxed{.*}).*?</answer>", re.DOTALL)
    hits = tagged_boxed.findall(response)
    return hits[-1] if hits else ""

def check_answer(answer: str, response: str, extract_type: int=1):
    """Grade a model response against a reference answer.

    The answer is first extracted from ``response`` with ``extract_answer``.
    Yes/no references are compared case-insensitively; every other reference
    is normalised with ``solution2answer`` and compared with ``is_equiv`` and
    then ``_is_latex_equal``.

    Args:
        answer: Reference answer.
        response: Model response to grade.
        extract_type: Forwarded to ``extract_answer``.

    Returns:
        A ``(score, final_answer)`` tuple where ``score`` is
            1 - the extracted answer matches the reference,
            2 - an answer was extracted but does not match,
            0 - the reference or the extracted answer is missing/empty, or
                the reference is yes/no and the extracted answer is neither.
    """
    yes_no = ("yes", "no")
    final_answer = extract_answer(response, extract_type)

    # extract_boxed_answer is defined elsewhere; if it yields None when no
    # boxed span is found (TODO confirm), calling .lower() below would raise
    # AttributeError. Treat that case as "no answer extracted".
    if final_answer is None:
        return 0, ""

    if final_answer.lower() in yes_no:
        final_answer = final_answer.lower()

    if not answer or not final_answer:
        return 0, final_answer

    if answer.lower() in yes_no:
        guess = final_answer.lower()
        if answer.lower() == guess:
            return 1, guess
        if guess in yes_no:
            return 2, guess
        return 0, guess

    answer = solution2answer(answer)
    final_answer = solution2answer(final_answer)
    # Cheap equivalence check first; the LaTeX comparison only runs if it fails.
    if is_equiv(answer, final_answer) or _is_latex_equal(answer, final_answer):
        return 1, final_answer

    # The normalised answer is falsy here only if solution2answer returned an
    # empty value for it.
    return (2 if final_answer else 0), final_answer

class TimeoutException(Exception):
    """Raised by ``timeout_handler`` when the SIGALRM alarm fires."""

def timeout_handler(signum, frame):
    """Signal-handler callback that aborts the interrupted code.

    Both arguments are the values passed to a signal handler and are unused;
    the handler simply raises ``TimeoutException``.
    """
    raise TimeoutException

def timeout_decorator(seconds, default_value):
    """
    Build a decorator that bounds a function's run time using SIGALRM.

    The decorated function is called with an alarm of ``seconds`` armed. If
    the alarm fires first, ``timeout_handler`` raises ``TimeoutException``,
    which is caught here: a message is printed and ``default_value`` is
    returned instead of the function's result. The exception never reaches
    the caller.

    Args:
        seconds: Alarm delay handed to ``signal.alarm`` (whole seconds).
        default_value: Value returned when the call times out.

    Returns:
        A decorator that wraps a function with the timeout logic.

    Note:
        This relies on ``signal.signal`` / ``signal.alarm``, so per the
        ``signal`` module documentation it is Unix-only and must run in the
        main thread.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Install our handler but remember the one that was active, so a
            # handler owned by the surrounding program is not lost.
            previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
            try:
                signal.alarm(seconds)
                return func(*args, **kwargs)
            except TimeoutException:
                print("timeout occured with args:", args, kwargs)
                return default_value
            finally:
                signal.alarm(0)  # Disable the alarm
                # signal.signal() returns None when the previous handler was
                # not installed from Python; None cannot be re-installed.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)
        return wrapper
    return decorator

@ray.remote(num_cpus=0.25)
def execute_task(arg, timeout=10, default_value=(0, ''), f=None):
    """Ray remote task that applies ``f`` to a single work item.

    Args:
        arg: The work item. A dict is expanded as keyword arguments, a tuple
            as positional arguments, anything else is passed as one argument.
        timeout: Currently unused; the in-task timeout wrapper is disabled.
        default_value: Currently unused, for the same reason.
        f: Callable to run; ``check_answer`` is used when ``None``.

    Returns:
        Whatever ``f`` returns for the given item.
    """
    target = check_answer if f is None else f
    # NOTE: wrapping ``target`` with timeout_decorator(timeout, default_value)
    # is intentionally switched off here, so the call runs unbounded.
    if isinstance(arg, dict):
        return target(**arg)
    if isinstance(arg, tuple):
        return target(*arg)
    return target(arg)

def parallel_f(args_list, num_workers=16, timeout=60, default_value=(0, ''), f=None):
    """
    Run ``f`` on every element of ``args_list`` through Ray remote tasks.

    At most ``num_workers`` tasks are in flight at once. The elapsed wall time
    of each task is measured from the moment it is submitted:
      * once it exceeds ``timeout`` seconds a message is logged, and logged
        again each time a further ``timeout`` seconds have passed;
      * once it exceeds ``5 * timeout`` seconds the task is force-cancelled
        and ``default_value`` is stored as its result.
    A task whose result cannot be fetched (``ray.get`` raises) also yields
    ``default_value``.

    Args:
        args_list: Sequence of work items, one per task (see ``execute_task``
            for how dict/tuple items are expanded).
        num_workers: Maximum number of concurrently running tasks (>= 1).
        timeout: Soft time limit in seconds; the hard limit is five times it.
        default_value: Result used for cancelled or failed tasks.
        f: Callable forwarded to ``execute_task``; ``None`` selects its default.

    Returns:
        A list with one result per element of ``args_list``, in input order.
        An empty input yields an empty list.

    Raises:
        ValueError: If ``num_workers`` is smaller than 1 while there is work
            to do (no task could ever be launched).
    """
    total = len(args_list)
    if total == 0:
        # Same shape as the non-empty case: a single list of results.
        return []
    if num_workers < 1:
        raise ValueError(f"num_workers must be at least 1, got {num_workers}")

    results = [None] * total
    next_idx = 0  # index of the next item that has not been submitted yet
    # running_tasks maps future -> (index, start_time)
    running_tasks = {}
    # timeout_tasks maps index -> time of the last timeout log line
    timeout_tasks = {}

    while next_idx < total or running_tasks:
        # Launch new tasks while we haven't hit the concurrency limit.
        while len(running_tasks) < num_workers and next_idx < total:
            future = execute_task.remote(
                arg=args_list[next_idx], timeout=timeout, default_value=default_value, f=f
            )
            running_tasks[future] = (next_idx, time.time())
            next_idx += 1

        current_time = time.time()

        to_cancel = []
        for future, (idx, start_time) in running_tasks.items():
            elapsed = current_time - start_time
            if elapsed > timeout:
                if idx not in timeout_tasks:
                    timeout_tasks[idx] = current_time
                    ray.logger.info(f"Timeout occured with idx {idx} ({elapsed:.1f}s) args {args_list[idx]}")
                elif current_time - timeout_tasks[idx] > timeout:
                    timeout_tasks[idx] = current_time
                    ray.logger.info(f"Continue timeout occured with idx {idx}) ({elapsed:.1f}s) args {args_list[idx]}")

            # Hard limit: give up on the task entirely.
            if elapsed > timeout * 5:
                to_cancel.append(future)
                ray.logger.info(f"Forceful termination with idx {idx} ({elapsed:.1f}s) args {args_list[idx]}")

        for future in to_cancel:
            idx, _ = running_tasks.pop(future)
            # Cancellation is best-effort: a failure is logged, never raised.
            try:
                ray.cancel(future, force=True)
            except Exception as e:
                ray.logger.info(f"Failed to cancel task with idx {idx}: {e}")
            results[idx] = default_value

        # Poll for finished tasks.
        if running_tasks:
            done, _ = ray.wait(list(running_tasks), num_returns=len(running_tasks), timeout=0.1)
            for future in done:
                # Check if it wasn't already cancelled.
                if future in running_tasks:
                    idx, _ = running_tasks.pop(future)
                    try:
                        results[idx] = ray.get(future)
                    except Exception as e:
                        ray.logger.info(e)
                        results[idx] = default_value

        # Sleep briefly to avoid busy-waiting.
        time.sleep(0.01)

    return results

def filter_concat(x, y, eot, unsqueeze=False):
    """Extend the unfinished rows of ``x`` in place and return ``x``.

    Args:
        x: Mutable sequence of rows; ``x[j]`` is replaced by ``x[j] + piece``
            for every ``j`` where ``eot[j]`` is false.
        y: Either a non-empty list holding one piece per unfinished row (in
            order), or any other value, which is then used as the piece for
            every unfinished row. An empty list falls in the second case.
        eot: Per-row "finished" flags; a list is converted to a boolean array.
        unsqueeze: When true each piece is wrapped in a one-element list
            before being added.

    Returns:
        The same ``x`` object, updated.
    """
    if isinstance(eot, list):
        eot = np.array(eot, dtype=np.bool_)

    one_piece_per_row = isinstance(y, list) and len(y) > 0
    if one_piece_per_row:
        assert len(y) == sum(~eot), f"len(y): {len(y)}, sum(~eot): {sum(~eot)}; y: {y}, eot: {eot}"

    open_rows = [row for row, finished in enumerate(eot) if not finished]
    for k, row in enumerate(open_rows):
        piece = y[k] if one_piece_per_row else y
        if unsqueeze:
            piece = [piece]
        x[row] = x[row] + piece
    return x

def unsqueeze(x):
    """Return a new list in which every element of ``x`` is wrapped in its own one-element list."""
    wrapped = []
    for item in x:
        wrapped.append([item])
    return wrapped

def join_list(x, start_idx=None):
    """Flatten each row of ``x`` (a sequence of chunks) into one flat list.

    Args:
        x: Iterable of rows, each row being a sequence of iterable chunks.
        start_idx: When given, only the chunks ``row[start_idx:]`` are joined;
            a row with no chunks in that range yields ``[]``.

    Returns:
        A list with one flattened list per row.
    """
    flattened = []
    for chunks in x:
        # Slicing past the end already gives an empty selection, so no
        # separate length check is needed.
        selected = chunks if start_idx is None else chunks[start_idx:]
        # Single pass over the items; repeated list concatenation would
        # re-copy the accumulated prefix for every chunk.
        flattened.append([item for chunk in selected for item in chunk])
    return flattened

def filter_list(x, eot):
    """Return the elements of ``x`` whose matching ``eot`` flag is false, in order."""
    kept = []
    for item, finished in zip(x, eot):
        if finished:
            continue
        kept.append(item)
    return kept

def filter_fill(x, y, eot):
    """Overwrite the unfinished slots of ``x`` in place and return ``x``.

    Args:
        x: Mutable sequence; ``x[j]`` is replaced wherever ``eot[j]`` is false.
        y: Replacement values, one per unfinished slot, in order.
        eot: Per-slot "finished" flags; a list is converted to a boolean array.

    Returns:
        The same ``x`` object, updated.
    """
    if isinstance(eot, list):
        eot = np.array(eot, dtype=np.bool_)
    assert len(y) == sum(~eot), f"len(y): {len(y)}, sum(~eot): {sum(~eot)}; y: {y}, eot: {eot}"

    open_slots = [slot for slot, finished in enumerate(eot) if not finished]
    for k, slot in enumerate(open_slots):
        x[slot] = y[k]
    return x

def encode_prompts(prompts, tokenizer, padding=False, add_special_tokens=False, **kwargs):
    """
    Turn prompts into token-id lists with the given tokenizer.

    Two input layouts are accepted:
      - a list of strings: tokenized in a single batched call, and the
        tokenizer's "input_ids" are returned as-is;
      - a list of lists of strings (turns): turn number t of every prompt that
        has one is tokenized together in one batched call, and the ids of each
        prompt's turns are concatenated in turn order.

    An empty ``prompts`` returns ``[]``. In the nested layout a ValueError is
    raised if any element is not a list.
    """
    if not prompts:
        return []

    # Flat layout: one batched tokenizer call.
    if isinstance(prompts[0], str):
        encoded = tokenizer(prompts, padding=padding, add_special_tokens=add_special_tokens, **kwargs)
        return encoded["input_ids"]

    # Nested layout: every element has to be a list of turns.
    for conversation in prompts:
        if not isinstance(conversation, list):
            raise ValueError("All elements of prompts must be lists of strings.")

    merged = [[] for _ in prompts]
    deepest = max(len(conversation) for conversation in prompts)

    # Walk turn by turn; each round batches the prompts that still have a turn.
    # Assumes the tokenizer returns one id list per input string.
    for turn in range(deepest):
        owners = [row for row, conversation in enumerate(prompts) if turn < len(conversation)]
        batch = [prompts[row][turn] for row in owners]
        turn_ids = tokenizer(batch, padding=padding, add_special_tokens=add_special_tokens, **kwargs)["input_ids"]
        for row, tokens in zip(owners, turn_ids):
            merged[row].extend(tokens)
    return merged

def clone_sampling_params(sampling_params, **kwargs):
    """Build a new SamplingParams from an existing one, with overrides.

    Every annotated field of ``sampling_params`` is copied except the listed
    internal ones and any field whose name starts with an underscore; the
    values in ``kwargs`` then replace or extend the copied ones, and the
    result is passed to ``SamplingParams.from_optional``.
    """
    # Original author's note: cloning the object and then modifying its
    # attributes does not work, hence the rebuild from a plain dict.
    skipped = {"_real_n", "output_text_buffer_length", "_all_stop_token_ids"}
    cloned_args = {}
    for name in sampling_params.__annotations__:
        if name.startswith("_") or name in skipped:
            continue
        cloned_args[name] = getattr(sampling_params, name)
    cloned_args.update(kwargs)
    return SamplingParams.from_optional(**cloned_args)