import asyncio, nest_asyncio
import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import time
import numpy as np
from packaging import version
from tqdm import tqdm

import ray
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput, CompletionOutput

from thinker_task.exp_engine.accelerators.inference.vllm_engine import LLMActor
from thinker_task.exp_engine.accelerators.inference.utils import parallel_f, filter_concat, filter_list, filter_fill, unsqueeze, join_list, clone_sampling_params

class SumCompletionOutput(CompletionOutput):
    """``CompletionOutput`` extended with the per-turn scoring results of ``SumLLMActor``.

    Extra attributes:
        response_status: per-turn status codes
            (-1 for na, 0 for wrong format, 1 for correct ans, 2 for wrong ans but correct format).
        rewards: per-turn rewards (None in eval mode).
        answer_status: status code of the final answer.
        all_answers: every answer extracted along the way, in turn order.
        final_answer: the answer taken as the sample's final one.

    All remaining positional/keyword arguments are forwarded unchanged to
    ``CompletionOutput.__init__`` (callers in this file pass them by keyword).
    """

    def __init__(
        self,
        response_status: Optional[List[int]] = None,
        rewards: Optional[List[float]] = None,
        answer_status: Optional[int] = None,
        all_answers: Optional[List[str]] = None,
        final_answer: Optional[str] = "",
        *args,  # forwarded to CompletionOutput
        **kwargs  # forwarded to CompletionOutput
    ) -> None:
        super().__init__(*args, **kwargs)
        self.all_answers = all_answers
        self.final_answer = final_answer
        self.rewards = rewards
        self.answer_status = answer_status
        self.response_status = response_status

    def __repr__(self) -> str:
        base = super().__repr__()
        extra = (
            f", response_status={self.response_status}"
            f", answer_status={self.answer_status}"
            f", rewards={self.rewards}"
            f", final_answer={self.final_answer}"
        )
        return base + extra

class SumLLMActor(LLMActor):
    """vLLM actor that answers each prompt through a multi-turn chat pipeline.

    For every prompt, up to five turns are appended to the same conversation:

    1. fast answer    - a concise solution (``fast_template``);
    2. verify         - the model judges its fast answer with a boxed Yes/No
                        (``verify_template``), or empty placeholders when ``verify_skip``;
    3. slow answer    - a ``<think>`` re-solve for samples that still need it;
    4. summary        - a concise summary of the slow solution (training mode unless
                        ``summary_skip``; eval mode only with ``verify_sum``);
    5. verify summary - optional (``verify_sum=True``) self-check of the summary.

    Each turn is scored into a status code (-1 for na, 0 for wrong format,
    1 for correct ans, 2 for wrong ans but correct format) and, when ground-truth
    answers are supplied, into per-turn rewards.

    The per-call conversation state (``full_prompts``, ``full_prompt_ids``,
    ``full_stop_reasons``, ``eot``) lives on ``self`` while a mini-batch is
    processed, so concurrent ``_generate`` calls on one instance would interfere.
    """

    def __init__(self,
                 *args,
                 num_workers=8,
                 prompt_type=0,
                 summary_min_token=300,
                 summary_max_token=1000,
                 verify_max_token=6000,
                 slow_max_token=6000,
                 reward_right_format=0.25,
                 summary_temperature=0.6,
                 summary_skip=False,
                 verify_skip=False,
                 summary_reward_coef=1.,
                 fast_reward_coef=1.,
                 summary_nonstop_discount=1.,
                 **kwargs):
        """Configure the pipeline and initialise the underlying ``LLMActor``.

        Args:
            *args: forwarded to ``LLMActor`` (the ``__main__`` example passes the model path first).
            num_workers: worker count passed to ``parallel_f`` when checking answers.
            prompt_type: prompt template family; only 0 is supported.
            summary_min_token: summaries shorter than this many tokens get status 0.
            summary_max_token: ``max_tokens`` for the fast and summary turns (also the
                word limit quoted in their prompts).
            verify_max_token: ``max_tokens`` for the verify turn.
            slow_max_token: ``max_tokens`` for the slow turn.
            reward_right_format: reward for status 2 (correct format, wrong answer).
            summary_temperature: sampling temperature of the summary turn.
            summary_skip: do not run the summary turn after the slow turn (ignored in
                eval mode with ``verify_sum``).
            verify_skip: replace the verify turn by empty placeholders and send every
                unfinished sample to the slow turn (``slow_template_vmode``).
            summary_reward_coef: multiplier for the summary-turn reward.
            fast_reward_coef: multiplier for the fast-turn reward.
            summary_nonstop_discount: multiplier for any turn whose finish reason is not "stop".
            **kwargs: forwarded to ``LLMActor``; ``enable_prefix_caching`` is forced to True.
        """
        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        kwargs["enable_prefix_caching"] = True

        super().__init__(*args, **kwargs)

        self._tokenizer = self.llm.get_tokenizer()
        self.num_workers = num_workers
        self.mini_batch_size = 256
        self.prompt_type = prompt_type
        self.summary_temperature = summary_temperature
        self.summary_reward_coef = summary_reward_coef
        self.fast_reward_coef = fast_reward_coef
        self.summary_nonstop_discount = summary_nonstop_discount
        self.reward_right_format = reward_right_format

        self.max_model_len = kwargs.get("max_model_len", 8000)
        self.summary_min_token = summary_min_token
        self.summary_max_token = summary_max_token
        self.verify_max_token = verify_max_token
        self.slow_max_token = slow_max_token
        self.summary_skip = summary_skip
        self.verify_skip = verify_skip
        # The longest conversation (fast + verify + slow + summary) must fit in the context.
        assert self.max_model_len >= self.summary_max_token * 2 + self.verify_max_token + self.slow_max_token, f"max_model_len {self.max_model_len} should be larger than summary_max_token {self.summary_max_token} * 2 + verify_max_token {self.verify_max_token} + slow_max_token {self.slow_max_token}"
        assert self.prompt_type == 0

        # NOTE(review): not referenced by the methods of this class; kept for external users.
        self.executor = ThreadPoolExecutor(max_workers=self.num_workers)

        # Turn templates, rendered with str.format (hence the doubled braces). Every turn
        # after the first starts with <|im_end|> to close the previous assistant message.
        self.fast_template = \
"""<|im_start|>User: Answer the below question with concise steps and output the final answer within \\boxed{{}}. \
 Limit your response below {summary_max_token} words. This is the problem:
{input}
<|im_end|>
<|im_start|>Assistant: """
        self.verify_template = \
"""<|im_end|>
<|im_start|>User: Is your answer above correct? Please verify each step and the answer carefully. \
Output \\boxed{{Yes}} if your answer is correct, or \\boxed{{No}} if your answer is incorrect.
<|im_end|>
<|im_start|>Assistant: """
        # Self-contained conversation (question + summary as the answer + verify request),
        # used with start_idx=8 in the verify-summary turn.
        self.verify_sum_template = self.fast_template + "{response}" + self.verify_template
        self.slow_template_vmode = \
"""<|im_end|>
<|im_start|>User: Based on the above discussion, think carefully and solve the problem again but with a slow reasoning process. \
The reasoning process should be enclosed within <think> </think>. This is the problem:
{input}
Let's think step by step and output the final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>Assistant: <think>"""
        self.slow_template = \
"""<|im_end|>
<|im_start|>User: Your initial answer is incorrect. Now, think about the possible errors and consider alternative solutions. \
The reasoning process should be enclosed within <think> </think>. This is the problem:
{input}
Let's think step by step and output the final answer within \\boxed{{}}.
<|im_end|>
<|im_start|>Assistant: <think>"""
        self.summary_template = \
"""<|im_end|>
<|im_start|>User: Your final answer is correct. Now summarize the steps leading to your final answer concisely and precisely, excluding \
internal reasoning. Limit your response between {summary_min_token} and {summary_max_token} words. This is the problem:
{input}
<|im_end|>
<|im_start|>Assistant: """

    def generate(self, prompts=None, prompt_token_ids=None, sampling_params=None, answers=None, use_tqdm=False, **kwargs):
        """Run the pipeline over all prompts in mini-batches of ``self.mini_batch_size``.

        Args:
            prompts: a prompt string or a list of prompt strings; when None,
                ``prompt_token_ids`` is decoded instead.
            prompt_token_ids: list of token-id lists, used when ``prompts`` is None.
            sampling_params: base ``SamplingParams``; ``n`` must be 1.
            answers: ground-truth answers, one per prompt. None selects eval mode
                (no rewards are computed).
            use_tqdm: show a progress bar.
            **kwargs: forwarded to ``_generate`` (``fast``, ``verify_sum``).

        Returns:
            One ``RequestOutput`` per prompt, in input order.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        batch_size = len(prompts) if prompts is not None else len(prompt_token_ids)
        ray.logger.info(f"Generate called with batch size {batch_size}; eval {answers is None}")

        if answers is not None:
            assert len(answers) == batch_size

        mini_batch_size = self.mini_batch_size
        outputs = []
        start_time = time.time()
        pbar = tqdm(total=batch_size, desc="Generating") if use_tqdm else None

        for start in range(0, batch_size, mini_batch_size):
            end = start + mini_batch_size
            if version.parse(self.__version__) >= version.parse("0.7"):
                self.llm.reset_prefix_cache()
            prompt_batch = prompts[start:end] if prompts is not None else None
            prompt_token_ids_batch = prompt_token_ids[start:end] if prompt_token_ids is not None else None
            answer_batch = answers[start:end] if answers is not None else None
            outputs.extend(self._generate(prompt_batch, prompt_token_ids=prompt_token_ids_batch, sampling_params=sampling_params, answer=answer_batch, **kwargs))
            if pbar is not None:
                # The last mini-batch may be smaller than mini_batch_size.
                pbar.update(min(end, batch_size) - start)
        if pbar is not None:
            pbar.close()

        if answers is not None and outputs:
            iscorrect = np.array([x.outputs[0].answer_status == 1 for x in outputs], dtype=np.float32)
            acc = np.mean(iscorrect)
            acc_str = f" (acc: {acc*100:.2f})"
        else:
            acc_str = ""
        ray.logger.info(f"Complete processing {batch_size} prompts in {time.time() - start_time:.2f}s{acc_str}.")
        return outputs

    def _check_responses(self, targets, responses):
        """Score each response against its target with ``parallel_f``.

        ``targets`` must already be restricted to the unfinished samples (via
        ``filter_list(..., self.eot)``) so that it lines up one-to-one with
        ``responses``. Callers use each result as ``(status, extracted_answer)``.
        """
        assert len(targets) == len(responses), f"targets ({len(targets)}) and responses ({len(responses)}) are misaligned"
        check_out = parallel_f(
            args_list=list(zip(targets, responses, [1] * len(responses))),
            num_workers=self.num_workers,
        )
        assert len(check_out) == np.sum(~self.eot).item(), f"check_out {check_out} should be same as sum(~self.eot) {self.eot}"
        return check_out

    def _compute_rewards(self, stage_statuses, batch_size):
        """Turn per-turn status arrays into per-sample reward and status lists.

        Args:
            stage_statuses: ``(status_array, coef)`` pairs in turn order; a -1 entry
                means that turn did not happen for the sample and is skipped.
            batch_size: number of samples.

        Returns:
            ``(rewards, response_status)``, each a list of per-sample lists with one
            entry per turn that happened. Status 1 earns 1, status 2 earns
            ``reward_right_format``, anything else 0. The reward is multiplied by
            ``summary_nonstop_discount`` when that turn's finish reason is not "stop",
            then by the turn's ``coef``.
        """
        rewards = [[] for _ in range(batch_size)]
        response_status = [[] for _ in range(batch_size)]
        for i in range(batch_size):
            for status, coef in stage_statuses:
                if status[i] == -1:
                    continue
                # Turns are recorded in order, so the k-th recorded status of a sample
                # pairs with its k-th finish reason (robust to turns that were skipped).
                turn = len(response_status[i])
                response_status[i].append(status[i].item())
                if status[i] == 1:
                    reward = 1
                elif status[i] == 2:
                    reward = self.reward_right_format
                else:
                    reward = 0
                if self.summary_nonstop_discount != 1. and self.full_stop_reasons[i][turn] != "stop":
                    reward = reward * self.summary_nonstop_discount
                rewards[i].append(reward * coef)
        return rewards, response_status

    def _generate(self, prompt=None, prompt_token_ids=None, sampling_params=None, answer=None, fast=False, verify_sum=False):
        """Run the fast / verify / slow / summary / verify-summary turns for one mini-batch.

        Args:
            prompt: list of question strings; decoded from ``prompt_token_ids`` when None.
            prompt_token_ids: token-id lists, used only when ``prompt`` is None.
            sampling_params: base ``SamplingParams`` (``n`` must be 1); each turn clones it
                with its own ``max_tokens`` (and temperature for the summary turn).
            answer: ground-truth answers; None selects eval mode.
            fast: stop after the fast turn.
            verify_sum: run the verify-summary turn after the summary; in eval mode this
                also enables the summary turn.

        Returns:
            One ``RequestOutput`` per sample. Its single ``SumCompletionOutput`` holds the
            per-turn prompt/response segments (``text``/``token_ids``), the finish reasons,
            ``answer_status``, ``final_answer``, ``all_answers`` and, outside eval mode,
            ``rewards``/``response_status``.
        """
        if prompt is None:
            prompt = self._tokenizer.batch_decode(prompt_token_ids)
        question = prompt
        batch_size = len(prompt)

        is_eval = answer is None
        if answer is None:
            answer = [None] * batch_size

        if sampling_params is None:
            sampling_params = SamplingParams()
        assert sampling_params.n == 1, f"sampling_params.n should be 1 instead of {sampling_params.n}"

        # Status codes: -1 for na, 0 for wrong format, 1 for correct ans, 2 for wrong ans but correct format.
        fast_answer_status = np.full(batch_size, -1, dtype=np.int64)
        slow_answer_status = np.full(batch_size, -1, dtype=np.int64)
        answer_status = np.full(batch_size, -1, dtype=np.int64)
        verify_status = np.full(batch_size, -1, dtype=np.int64)
        summary_status = np.full(batch_size, -1, dtype=np.int64)
        verify_summary_status = np.full(batch_size, -1, dtype=np.int64)
        final_answer = ["" for _ in range(batch_size)]
        all_answers = [[] for _ in range(batch_size)]

        # Per-sample conversation state, reset on every call and extended by inner_generate.
        self.full_prompt_ids = [[] for _ in range(batch_size)]
        self.full_prompts = [[] for _ in range(batch_size)]
        self.full_response_types = [[] for _ in range(batch_size)]
        self.full_stop_reasons = [[] for _ in range(batch_size)]
        # eot[i] is True once sample i takes no further turns; it never flips back to False.
        self.eot = np.zeros(batch_size, dtype=bool)

        start_idx = None
        verify_sampling_params = clone_sampling_params(sampling_params, max_tokens=self.verify_max_token)
        slow_template = self.slow_template_vmode if self.verify_skip else self.slow_template

        # Step 1 - fast answer.
        fast_sampling_params = clone_sampling_params(sampling_params, max_tokens=self.summary_max_token)
        fast_response = self.inner_generate(fast_sampling_params, self.fast_template, input=question, summary_max_token=self.summary_max_token)

        check_out = self._check_responses(filter_list(answer, self.eot), fast_response)
        fast_answer_status = filter_fill(fast_answer_status, [x[0] for x in check_out], self.eot)
        answer_status = filter_fill(answer_status, [x[0] for x in check_out], self.eot)
        final_answer = filter_fill(final_answer, [x[1] for x in check_out], self.eot)
        all_answers = filter_concat(all_answers, [x[1] for x in check_out], self.eot, unsqueeze=True)

        # Scatter back to full batch size so later templates can index it per sample.
        fast_response = filter_fill(["" for _ in range(batch_size)], fast_response, self.eot)

        if fast:
            self.eot = np.ones(batch_size, dtype=bool)

        # Step 2 - verify.
        if not np.all(self.eot):
            skip = np.ones(batch_size, dtype=np.bool_) if self.verify_skip else None
            verify_response = self.inner_generate(verify_sampling_params, self.verify_template, input=question, answer=final_answer, response=fast_response, skip=skip, start_idx=start_idx)

            # Target is "yes" iff the fast answer was correct; restrict it to the still
            # unfinished samples so it lines up with verify_response.
            verify_gt_answer = ["yes" if x == 1 else "no" for x in fast_answer_status]
            check_out = self._check_responses(filter_list(verify_gt_answer, self.eot), verify_response)
            verify_status = filter_fill(verify_status, [x[0] for x in check_out], self.eot)
            assert not np.any(verify_status[~self.eot] == -1), f"verify_status {verify_status} should not contain -1; eot {self.eot}; answer {answer}"
            verify_answer = [x[1] for x in check_out]
            all_answers = filter_concat(all_answers, [x[1] for x in check_out], self.eot, unsqueeze=True)

            # verify_answer only covers unfinished samples; index the full-size copy by batch position.
            full_verify_answer = filter_fill(["" for _ in range(batch_size)], verify_answer, self.eot)
            for i in range(batch_size):
                if verify_status[i] == 2 and full_verify_answer[i] not in ["yes", "no"]:
                    verify_status[i] = 0

            if not self.verify_skip:
                if not is_eval:
                    # Training: a correct fast answer ends the episode. Samples that already
                    # finished (e.g. context overflow) stay finished.
                    self.eot = np.logical_or(self.eot, np.asarray(fast_answer_status) == 1)
                else:
                    # Eval: trust the model's own verdict.
                    self.eot = filter_fill(self.eot, np.array([x == "yes" for x in verify_answer], dtype=np.bool_), self.eot.copy())

        # Step 3 - slow answer.
        if not np.all(self.eot):
            slow_sampling_params = clone_sampling_params(sampling_params, max_tokens=self.slow_max_token)
            slow_response = self.inner_generate(slow_sampling_params, slow_template, input=question, answer=final_answer, response=fast_response, start_idx=start_idx)

            check_out = self._check_responses(filter_list(answer, self.eot), slow_response)
            slow_answer_status = filter_fill(slow_answer_status, [x[0] for x in check_out], self.eot)
            assert not np.any(slow_answer_status[~self.eot] == -1), f"slow_answer_status {slow_answer_status} should not contain -1; eot {self.eot}; answer {answer}"
            answer_status = filter_fill(answer_status, [x[0] for x in check_out], self.eot)
            final_answer = filter_fill(final_answer, [x[1] for x in check_out], self.eot)
            all_answers = filter_concat(all_answers, [x[1] for x in check_out], self.eot, unsqueeze=True)

            if is_eval and verify_sum:
                pass  # eval + verify_sum: every still-active sample proceeds to the summary
            elif is_eval or self.summary_skip:
                self.eot = np.ones(batch_size, dtype=bool)
            else:
                # Training: only correct slow answers get summarised.
                self.eot = filter_fill(self.eot, [x[0] != 1 for x in check_out], self.eot)

        # Step 4 - summary.
        if not np.all(self.eot):
            summary_sampling_params = clone_sampling_params(sampling_params, max_tokens=self.summary_max_token, temperature=self.summary_temperature)
            summary_response = self.inner_generate(summary_sampling_params, self.summary_template, input=question, summary_min_token=self.summary_min_token, summary_max_token=self.summary_max_token, start_idx=start_idx)

            # The summary must reproduce the current final answer.
            check_out = self._check_responses(filter_list(final_answer, self.eot), summary_response)
            all_answers = filter_concat(all_answers, [x[1] for x in check_out], self.eot, unsqueeze=True)
            summary_status = filter_fill(summary_status, [x[0] for x in check_out], self.eot)
            assert not np.any(summary_status[~self.eot] == -1), f"summary_status {summary_status} should not contain -1; eot {self.eot}; answer {answer}"
            for j, finished in enumerate(self.eot):
                # The last segment of an unfinished sample is its summary response.
                if not finished and len(self.full_prompt_ids[j][-1]) < self.summary_min_token:
                    summary_status[j] = 0  # wrong format for summary that are too short

            full_summary_response = filter_fill(["" for _ in range(batch_size)], summary_response, self.eot)

            if not verify_sum:
                self.eot = np.ones(batch_size, dtype=bool)
            else:
                self.eot = filter_fill(self.eot, [x[0] != 1 for x in check_out], self.eot)

        # Step 5 - verify summary.
        if not np.all(self.eot) and verify_sum:
            verify_sum_response = self.inner_generate(verify_sampling_params, self.verify_sum_template, input=question, answer=final_answer, response=full_summary_response, summary_max_token=self.summary_max_token, skip=None, start_idx=8)
            # Target is "yes" iff the slow answer was correct; restricted to unfinished samples.
            verify_sum_gt_answer = ["yes" if x == 1 else "no" for x in slow_answer_status]
            check_out = self._check_responses(filter_list(verify_sum_gt_answer, self.eot), verify_sum_response)
            all_answers = filter_concat(all_answers, [x[1] for x in check_out], self.eot, unsqueeze=True)
            verify_summary_status = filter_fill(verify_summary_status, [x[0] for x in check_out], self.eot)
            assert not np.any(verify_summary_status[~self.eot] == -1), f"verify_summary_status {verify_summary_status} should not contain -1; eot {self.eot}; answer {answer}"
            self.eot = np.ones(batch_size, dtype=bool)

        # Rewards (training only).
        rewards = response_status = None
        if not is_eval:
            stage_statuses = (
                (fast_answer_status, self.fast_reward_coef),
                (verify_status, 1),
                (slow_answer_status, 1),
                (summary_status, self.summary_reward_coef),
                (verify_summary_status, 1),
            )
            rewards, response_status = self._compute_rewards(stage_statuses, batch_size)
            for i in range(batch_size):
                # One reward per (prompt, response) segment pair.
                assert len(rewards[i]) == len(self.full_prompt_ids[i]) // 2, f"rewards {rewards[i]} should be same as full_prompt_ids {self.full_prompt_ids[i]}; fast_answer_status {fast_answer_status[i]}, verify_status {verify_status[i]}, slow_answer_status {slow_answer_status[i]}, summary_status {summary_status[i]}"

        request_outputs = []
        for i in range(batch_size):
            completion_output = SumCompletionOutput(
                index=0,
                text=self.full_prompts[i],
                token_ids=self.full_prompt_ids[i],
                cumulative_logprob=None,
                logprobs=None,
                response_status=response_status[i] if response_status is not None else None,
                rewards=rewards[i] if rewards is not None else None,
                answer_status=answer_status[i].item(),
                all_answers=all_answers[i],
                final_answer=final_answer[i],
                stop_reason=self.full_stop_reasons[i],
                finish_reason=self.full_stop_reasons[i],
            )
            request_output = RequestOutput(
                request_id=0,
                prompt=self.full_prompts[i][0],
                prompt_token_ids=self.full_prompt_ids[i][0],
                prompt_logprobs=None,
                outputs=[completion_output],
                finished=True,
            )
            request_outputs.append(request_output)

        return request_outputs

    def inner_generate(self, sampling_params, template, **kwargs):
        """Render, generate and record one chat turn for every unfinished sample.

        ``template`` (a string, or a per-sample list of strings) is formatted with
        ``kwargs``; list-valued kwargs are indexed per sample. A sample whose
        accumulated history plus the new prompt exceeds ``max_model_len`` is marked
        finished in ``self.eot``. The new prompt and response are appended to
        ``self.full_prompts`` / ``self.full_prompt_ids``, and the finish reason to
        ``self.full_stop_reasons``.

        Extra keyword arguments (popped from ``kwargs``):
            skip: optional bool mask. Unfinished samples with ``skip=True`` are not
                generated; they get an empty prompt/response placeholder pair and a
                "stop" finish reason instead.
            start_idx: forwarded to ``join_list`` when a sample's history is joined into
                the model input (presumably drops the first ``start_idx`` segments -
                TODO confirm in utils).

        Returns:
            The generated texts of the samples still unfinished after the length check,
            in batch order ("" for skipped samples); ``[]`` if every sample is finished.
        """
        if np.all(self.eot):
            return []
        skip = kwargs.pop("skip", None)
        start_idx = kwargs.pop("start_idx", None)

        # Render the new turn's prompt for each unfinished sample.
        prompts_ = []
        for j, finished in enumerate(self.eot):
            if not finished:
                template_ = template[j] if isinstance(template, list) else template
                prompts_.append(template_.format(**{k: v[j] if isinstance(v, list) else v for k, v in kwargs.items()}))
        prompt_token_ids_ = self._tokenizer(prompts_, padding=False, add_special_tokens=False)["input_ids"]

        # Context-length check; `i` walks the rendered prompts, `j` walks the batch.
        i = 0
        prompts = []
        prompt_token_ids = []
        eot_ = self.eot.copy()
        for j, finished in enumerate(eot_):
            if not finished:
                if sum(len(x) for x in self.full_prompt_ids[j]) + len(prompt_token_ids_[i]) > self.max_model_len:
                    self.eot[j] = True
                elif skip is None or not skip[j]:
                    prompts.append(prompts_[i])
                    prompt_token_ids.append(prompt_token_ids_[i])
                i += 1
        if np.all(self.eot):
            return []

        # Finished or deliberately skipped samples get no generation this turn.
        if skip is not None:
            skip_gen = np.logical_or(self.eot, skip)
        else:
            skip_gen = self.eot.copy()

        self.full_prompts = filter_concat(self.full_prompts, prompts, skip_gen, unsqueeze=True)
        self.full_prompt_ids = filter_concat(self.full_prompt_ids, prompt_token_ids, skip_gen, unsqueeze=True)
        prompt_token_ids = join_list(filter_list(self.full_prompt_ids, skip_gen), start_idx=start_idx)

        if np.any(~skip_gen):
            out = self.llm.generate(
                prompts=None,
                prompt_token_ids=prompt_token_ids,
                sampling_params=sampling_params,
                use_tqdm=False
            )
            responses = [output.outputs[0].text for output in out]
            # The response is re-tokenized from its text rather than using the engine's token ids.
            response_ids = self._tokenizer(responses, add_special_tokens=False, padding=False)["input_ids"]
            stop_reasons = [output.outputs[0].finish_reason for output in out]

            self.full_prompts = filter_concat(self.full_prompts, responses, skip_gen, unsqueeze=True)
            self.full_prompt_ids = filter_concat(self.full_prompt_ids, response_ids, skip_gen, unsqueeze=True)
            self.full_stop_reasons = filter_concat(self.full_stop_reasons, stop_reasons, skip_gen, unsqueeze=True)
        else:
            responses = []

        if skip is not None:
            # Placeholders only for samples that are skipped but still unfinished; finished
            # samples must not gain segments (that would break the one-reward-per-pair invariant).
            placeholder_mask = np.logical_or(~skip, self.eot)
            self.full_prompts = filter_concat(self.full_prompts, "", placeholder_mask, unsqueeze=True)
            self.full_prompts = filter_concat(self.full_prompts, "", placeholder_mask, unsqueeze=True)  # empty place holder for prompt and response
            self.full_prompt_ids = filter_concat(self.full_prompt_ids, [], placeholder_mask, unsqueeze=True)
            self.full_prompt_ids = filter_concat(self.full_prompt_ids, [], placeholder_mask, unsqueeze=True)  # empty place holder for prompt and response
            self.full_stop_reasons = filter_concat(self.full_stop_reasons, "stop", placeholder_mask, unsqueeze=True)
            responses_ = ["" for _ in range(self.eot.shape[0])]
            responses_ = filter_fill(responses_, responses, skip_gen)
            responses = filter_list(responses_, self.eot)
        return responses

if __name__ == "__main__":
    # Smoke run: three small math problems with known answers (training mode, so
    # rewards are computed). Override the model path with SUM_LLM_MODEL.
    # generate() also accepts prompt_token_ids=... instead of prompts=....
    model = os.environ.get(
        "SUM_LLM_MODEL",
        "/home/schk/RS/dualityz/.idea/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B_im",
    )
    llm = SumLLMActor(model, tensor_parallel_size=1, prompt_type=0, max_model_len=16000)
    sampling_params = SamplingParams(n=1, temperature=0.6, max_tokens=2000)
    # Backslashes are escaped explicitly: "\c" / "\l" are invalid escape sequences.
    prompts = [
        "If $x + y = 1$ and $x - y = 3$, what is the value of $x - 2y$?",
        "If $\\sqrt{5 + x} + \\sqrt{20 - x} = 7$, what is the value of $(5 + x)(20 - x)$?",
        "Let $d_1 = a^2 + 2^a + a \\cdot 2^{(a+1)/2}$ and $d_2 = a^2 + 2^a - a \\cdot 2^{(a+1)/2}$. If $1 \\le a \\le 251$, how many integral values of $a$ are there such that $d_1 \\cdot d_2$ is a multiple of $5$?",
    ]
    answers = ["4", "144", "50"]
    request_outputs = llm.generate(prompts=prompts, sampling_params=sampling_params, answers=answers)
    print(request_outputs)