import asyncio, nest_asyncio
import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import time
import numpy as np
from tqdm import tqdm

import ray
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput, CompletionOutput

from thinker_task.exp_engine.accelerators.inference.vllm_engine import LLMActor
from thinker_task.exp_engine.accelerators.inference.utils import parallel_f, filter_concat, filter_list, filter_fill

class MultiCompletionOutput(CompletionOutput):
    """vLLM ``CompletionOutput`` extended with multi-attempt bookkeeping.

    Extra attributes (on top of those of ``CompletionOutput``):
        answer_status: Result of the answer check. ``MulLLMActor`` treats 1 as
            correct and 0/2 as a wrong answer that may be retried.
        attempt_used: Number of attempts the model consumed.
        attempt_remaining: Number of attempts left when generation stopped.
        final_answer: Answer string extracted by the answer checker ("" if none).
    """

    def __init__(
        self,
        answer_status: Optional[int] = None,
        attempt_used: Optional[int] = None,
        attempt_remaining: Optional[int] = None,
        final_answer: Optional[str] = "",
        *args,  # forwarded positionally to CompletionOutput
        **kwargs,  # forwarded as keywords to CompletionOutput (index, text, token_ids, ...)
    ) -> None:
        super().__init__(*args, **kwargs)
        self.answer_status = answer_status
        self.attempt_remaining = attempt_remaining
        self.attempt_used = attempt_used
        self.final_answer = final_answer

    def __repr__(self) -> str:
        base = super().__repr__()
        # If the parent repr is closed with ")", reopen it so the extra fields
        # land inside the parentheses instead of trailing after them.
        if base.endswith(")"):
            base = base[:-1]
        # Report this subclass's name rather than the parent's.
        parent_prefix = CompletionOutput.__name__ + "("
        if base.startswith(parent_prefix):
            base = type(self).__name__ + "(" + base[len(parent_prefix):]
        return (
            f"{base}, answer_status={self.answer_status}, "
            f"attempt_used={self.attempt_used}, "
            f"attempt_remaining={self.attempt_remaining}, "
            f"final_answer={self.final_answer!r})"
        )

class MulLLMActor(LLMActor):
    """LLM actor that lets the model answer each question over several attempts.

    Every question is wrapped in ``prompt_template`` and generated. When
    reference answers are supplied (training mode), each response is checked
    via ``parallel_f``. A wrong answer with attempts left appends a follow-up
    user turn built from ``prompt_wrong_template``, and the whole conversation
    is resubmitted. In eval mode (no answers) each question gets exactly one
    attempt and no answer checking happens.
    """

    def __init__(self,
                 *args,
                 min_attempt=1,
                 max_attempt=3,
                 repeat_question=False,
                 num_workers=16,
                 prompt_template=None,
                 prompt_wrong_template=None,
                 prompt_type=0,
                 mini_batch_size=256,
                 **kwargs):
        """Create the actor.

        Args:
            *args, **kwargs: Forwarded to ``LLMActor``; ``enable_prefix_caching``
                is always forced to True.
            min_attempt, max_attempt: In training mode each question gets a
                number of attempts drawn uniformly from [min_attempt, max_attempt].
            repeat_question: If True, the default wrong-answer turn restates the
                question.
            num_workers: Worker count passed to ``parallel_f`` for answer checking.
            prompt_template: Custom first-turn template; must contain
                ``{input}`` and ``{attempt_remaining}``.
            prompt_wrong_template: Custom follow-up template; must contain
                ``{attempt_remaining}``.
            prompt_type: Selects the default templates (0 = system prompt with
                <answer> tags; anything else = short \\boxed{} prompt). It is
                also forwarded to the answer checker.
            mini_batch_size: Number of prompts handled per ``_generate`` call.

        Raises:
            ValueError: if ``max_attempt`` < 1 (generation could never run).
        """
        if max_attempt < 1:
            raise ValueError(f"max_attempt must be >= 1, got {max_attempt}")

        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        # Each retry resubmits the full conversation so far, so the prefix is shared.
        kwargs["enable_prefix_caching"] = True
        super().__init__(*args, **kwargs)

        self._tokenizer = self.llm.get_tokenizer()
        self.num_workers = num_workers
        self.min_attempt = min_attempt
        self.max_attempt = max_attempt
        self.mini_batch_size = mini_batch_size
        self.repeat_question = repeat_question
        self.prompt_type = prompt_type
        # NOTE(review): not referenced elsewhere in this class; kept for external users.
        self.executor = ThreadPoolExecutor(max_workers=self.num_workers)

        if not prompt_template:
            if prompt_type == 0:
                self.prompt_template = \
"""<|im_start|>System: A conversation between User and Assistant. The User asks a question, and the Assistant solves it. \
The Assistant first thinks about the reasoning process in the mind and then provides the User with \
the answer. The reasoning process is enclosed within <think> </think> and answer is enclosed within \
<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer \
here </answer>. <|im_end|>
<|im_start|>User: You have {attempt_remaining} attempts to answer the question. You must put your answer \
inside <answer> </answer> tags, i.e., <answer> answer here </answer>. And your final answer will be extracted \
automatically by the \\boxed{{}} tag. This is the problem:
{input}
<|im_end|>
<|im_start|>Assistant: <think>"""
            else:
                self.prompt_template = \
"""<|im_start|>User: You have {attempt_remaining} attempts to answer the question: {input} Let's think step by step and output the final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>Assistant: <think>"""
        else:
            self.prompt_template = prompt_template
            assert "{input}" in prompt_template, f"prompt_template should specify question, not {prompt_template}"
            assert "{attempt_remaining}" in prompt_template, f"prompt_template should specify remaining_attempt, not {prompt_template}"

        # Every default follow-up template starts with "<|im_end|>\n" to close the
        # assistant turn that the previous response left open.
        if not prompt_wrong_template:
            if not self.repeat_question:
                self.prompt_wrong_template = \
"""<|im_end|>
<|im_start|>User: Your previous answer is wrong. Try alternatives and refine your answer. You have {attempt_remaining} attempts to answer the question.<|im_end|>
<|im_start|>Assistant: <think>"""
            else:
                if prompt_type == 0:
                    self.prompt_wrong_template = \
"""<|im_end|>
<|im_start|>User: Your previous answer is wrong. Try alternatives and refine your answer. You have {attempt_remaining} attempts to answer the question. You must put your answer \
inside <answer> </answer> tags, i.e., <answer> answer here </answer>. And your final answer will be extracted \
automatically by the \\boxed{{}} tag. This is the problem:
{input}
<|im_end|>
<|im_start|>Assistant: <think>"""
                else:
                    self.prompt_wrong_template = \
"""<|im_end|>
<|im_start|>User: Your previous answer is wrong. Try alternatives and refine your answer. You have {attempt_remaining} attempts to answer the question: {input} Let's think step by step and output the final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>Assistant: <think>"""
        else:
            self.prompt_wrong_template = prompt_wrong_template
            assert "{attempt_remaining}" in prompt_wrong_template, f"prompt_wrong_template should specify remaining_attempt, not {prompt_wrong_template}"

    def generate(self, prompts=None, prompt_token_ids=None, sampling_params=None, answers=None, use_tqdm=False):
        """Generate multi-attempt completions, processing prompts in mini-batches.

        Args:
            prompts: A question string or a list of question strings (raw
                questions; they are inserted into ``prompt_template``).
            prompt_token_ids: Token ids of the questions; decoded and used only
                when ``prompts`` is None.
            sampling_params: vLLM ``SamplingParams``; ``n`` must be 1.
            answers: Reference answers, one per prompt. None selects eval mode
                (a single attempt, no answer checking).
            use_tqdm: Show a progress bar over prompts.

        Returns:
            A list of ``RequestOutput`` objects, one per prompt, each holding a
            single ``MultiCompletionOutput``.

        Raises:
            ValueError: if neither ``prompts`` nor ``prompt_token_ids`` is given.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be provided")

        batch_size = len(prompts) if prompts is not None else len(prompt_token_ids)
        ray.logger.info(f"Generate called with batch size {batch_size}; eval {answers is None}")
        if answers is not None:
            assert len(answers) == batch_size, f"expected {batch_size} answers, got {len(answers)}"

        step = self.mini_batch_size
        starts = range(0, batch_size, step)

        def _split(seq):
            # One slice per mini-batch; a missing input becomes one None per mini-batch.
            if seq is None:
                return [None] * len(starts)
            return [seq[i:i + step] for i in starts]

        outputs = []
        start_time = time.time()
        pbar = tqdm(total=batch_size, desc="Generating") if use_tqdm else None
        try:
            for prompt_batch, token_batch, answer_batch in zip(
                    _split(prompts), _split(prompt_token_ids), _split(answers)):
                batch_outputs = self._generate(prompt_batch, prompt_token_ids=token_batch,
                                               sampling_params=sampling_params, answer=answer_batch)
                outputs.extend(batch_outputs)
                if pbar is not None:
                    # Advance by the real batch length; the last mini-batch may be short.
                    pbar.update(len(batch_outputs))
        finally:
            if pbar is not None:
                pbar.close()

        if answers is not None and outputs:
            iscorrect = np.array([x.outputs[0].answer_status == 1 for x in outputs], dtype=np.float32)
            acc = np.mean(iscorrect)
            acc_str = f" (acc: {acc*100:.2f})"
        else:
            acc_str = ""
        print(f"Complete processing {batch_size} prompts in {time.time() - start_time:.2f} seconds{acc_str}.")
        return outputs

    def _generate(self, prompt=None, prompt_token_ids=None, sampling_params=None, answer=None):
        """Run the multi-attempt conversation loop for one mini-batch.

        Each round generates for the still-active samples. In training mode it
        then checks their answers and either ends a sample or appends a
        wrong-answer follow-up turn for the next round. The loop stops after
        ``max_attempt`` rounds or once every sample has ended.

        Returns:
            One ``RequestOutput`` per sample. Its completion text is everything
            generated after the first-turn prompt, including follow-up turns.
        """
        is_eval = answer is None
        if prompt is None:
            prompt = self._tokenizer.batch_decode(prompt_token_ids)
        batch_size = len(prompt)

        if sampling_params is None:
            sampling_params = SamplingParams()
        assert sampling_params.n == 1, f"sampling_params.n should be 1 instead of {sampling_params.n}"

        if is_eval:
            attempt_remaining = np.ones((batch_size,), dtype=np.int64)
        else:
            # Uniform integer in [min_attempt, max_attempt] per sample.
            attempt_remaining = np.random.choice(self.max_attempt + 1 - self.min_attempt, size=batch_size) + self.min_attempt

        question = prompt
        prompt = [self.prompt_template.format(attempt_remaining=n, input=q) for n, q in zip(attempt_remaining, question)]

        eot = [False] * batch_size  # per-sample "conversation ended" flags
        full_response = list(prompt)
        final_answer = [''] * batch_size
        stop_reason = [None] * batch_size
        finish_reason = [None] * batch_size
        answer_status = [0] * batch_size
        attempt_used = [0] * batch_size
        first_out = None

        for round_idx in range(self.max_attempt):
            # filter_* helpers presumably operate on the entries whose eot flag is False.
            active_input = filter_list(full_response, eot)
            out = self.llm.generate(active_input, sampling_params, use_tqdm=False)
            if round_idx == 0:
                # All samples are active in the first round, so this is index-aligned with the batch.
                first_out = out

            completions = [o.outputs[0] for o in out]
            active_response = [c.text.replace("<|im_start|>", "").replace("<|im_end|>", "") for c in completions]
            full_response = filter_concat(full_response, active_response, eot)
            stop_reason = filter_fill(stop_reason, [c.stop_reason for c in completions], eot)
            finish_reason = filter_fill(finish_reason, [c.finish_reason for c in completions], eot)

            if is_eval:
                active_answer_status = [0] * len(completions)
            else:
                check_out = parallel_f(
                    args_list=list(zip(filter_list(answer, eot), active_response, [self.prompt_type] * len(active_response))),
                    num_workers=self.num_workers,
                    default_value=(0, ''),
                )
                active_answer_status = [x[0] for x in check_out]
                active_final_answer = [x[1] for x in check_out]
                answer_status = filter_fill(answer_status, active_answer_status, eot)
                final_answer = filter_fill(final_answer, active_final_answer, eot)

            # A completion with no generated tokens ends the sample (originally labelled
            # "max token reached" -- presumably the context is already full; confirm).
            active_no_tokens = [len(c.token_ids) == 0 for c in completions]

            new_eot = []
            active_idx = 0
            for idx, ended in enumerate(eot):
                if ended:
                    new_eot.append(True)
                    continue
                status = active_answer_status[active_idx]
                no_tokens = active_no_tokens[active_idx]
                active_idx += 1
                attempt_remaining[idx] -= 1
                attempt_used[idx] += 1
                if no_tokens:
                    new_eot.append(True)
                elif status in (0, 2) and attempt_remaining[idx] > 0:
                    # Wrong answer with attempts left: open a new user turn and retry.
                    full_response[idx] += self.prompt_wrong_template.format(
                        attempt_remaining=attempt_remaining[idx], input=question[idx])
                    new_eot.append(False)
                elif status == 1:
                    answer_status[idx] = 1
                    new_eot.append(True)
                else:
                    # Wrong answer and no attempts left.
                    new_eot.append(True)

            eot = new_eot
            if all(eot):
                break

        request_outputs = []
        for i in range(batch_size):
            text_i = full_response[i][len(prompt[i]):]
            completion_output = MultiCompletionOutput(
                index=0,
                text=text_i,
                # NOTE(review): add_special_tokens=True may prepend BOS for some tokenizers.
                token_ids=self._tokenizer.encode(text_i, add_special_tokens=True),
                cumulative_logprob=None,
                logprobs=None,  # TODO
                finish_reason=finish_reason[i],
                stop_reason=stop_reason[i],
                answer_status=answer_status[i],
                attempt_remaining=attempt_remaining[i],
                attempt_used=attempt_used[i],
                final_answer=final_answer[i],
            )
            request_outputs.append(RequestOutput(
                request_id=first_out[i].request_id,
                prompt=first_out[i].prompt,
                prompt_token_ids=first_out[i].prompt_token_ids,
                prompt_logprobs=None,
                outputs=[completion_output],
                # Every sample's conversation has ended by now.
                finished=True,
            ))

        return request_outputs

if __name__ == "__main__":
    # Smoke test: three math questions with reference answers (training mode, so
    # wrong answers are retried up to the sampled number of attempts).
    model = "Qwen/Qwen2.5-1.5B"
    llm = MulLLMActor(model, min_attempt=3, repeat_question=True, tensor_parallel_size=1, prompt_type=1)
    sampling_params = SamplingParams(n=1, temperature=0.6, max_tokens=1000)
    # Backslashes are escaped explicitly ("\\cdot", "\\le") to avoid invalid escape-sequence warnings.
    prompts = [
        "If $x + y = 1$ and $x - y = 3$, what is the value of $x - 2y$?",
        "If $\\sqrt{5 + x} + \\sqrt{20 - x} = 7$, what is the value of $(5 + x)(20 - x)$?",
        "Let $d_1 = a^2 + 2^a + a \\cdot 2^{(a+1)/2}$ and $d_2 = a^2 + 2^a - a \\cdot 2^{(a+1)/2}$. If $1 \\le a \\le 251$, how many integral values of $a$ are there such that $d_1 \\cdot d_2$ is a multiple of $5$?",
    ]
    answers = ["4", "144", "50"]
    request_outputs = llm.generate(prompts=prompts, sampling_params=sampling_params, answers=answers)

    # Alternative: pass pre-tokenized questions instead of strings.
    #   prompt_token_ids = llm._tokenizer(prompts, padding=False)["input_ids"]
    #   request_outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, answers=answers)
    print(request_outputs)