import asyncio
import json
import os
import re
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from itertools import islice, zip_longest
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from datetime import datetime

import numpy as np
import ray
import torch
from loguru import logger
from typing_extensions import override

from thinker_task.ppo.trainer import RayPPOTrainer
from thinker_task.ppo.tools.math_utils import is_equal, is_equal_sync, solution2answer, extract_boxed_answer
from thinker_task.ppo.utils import check_reflection_pattern, save_debug_data, join_str, join_ls_str
from thinker_task.exp_engine.accelerators.inference.utils import parallel_f

# Node count, read from the NUM_NODE environment variable (defaults to 1).
NUM_NODE = int(os.environ.get("NUM_NODE", 1))
# Module-wide thread pool; it is handed to `is_equal` by `latex_equal` and
# `CustomRewardTrainer.generate_vllm` further down in this file.
executor = ThreadPoolExecutor(max_workers=64)

def repeatness(s: str):
    """Return a repetition score for ``s`` derived from its suffix structure.

    The characters are converted to code points, suffix ranks are built by
    prefix doubling, and the common-prefix lengths of rank-adjacent suffixes
    are summed.  The sum is scaled by ``2 / (n * (n + 1))`` where ``n`` is the
    length of ``s``.  Strings of length 0 or 1 yield ``0``.
    """
    codes = [ord(ch) for ch in s]
    length = len(codes)
    if length <= 1:
        return 0

    def to_ranks(seq):
        # Dense rank of every element among the distinct values of ``seq``.
        order = {value: pos for pos, value in enumerate(sorted(set(seq)))}
        return [order[value] for value in seq]

    def build_suffix_array(seq):
        size = len(seq)
        rank = to_ranks(seq)
        span = 1
        # Each round ranks suffixes by twice as many leading symbols; positions
        # past the end are padded with -1 so shorter suffixes sort first.
        while span < size - 1:
            pairs = list(zip_longest(rank, islice(rank, span, None), fillvalue=-1))
            rank = to_ranks(pairs)
            span <<= 1
        suffixes = [0] * size
        for start, position in enumerate(rank):
            suffixes[position] = start
        return rank, suffixes

    def lcp_lengths(seq, suffixes, rank):
        size = len(seq)
        lengths = [0] * size
        matched = 0
        for start in range(size):
            position = rank[start]
            if position == size - 1:
                # Last suffix in rank order has no successor to compare with.
                matched = 0
                continue

            neighbour = suffixes[position + 1]
            while (
                start + matched < size
                and neighbour + matched < size
                and seq[start + matched] == seq[neighbour + matched]
            ):
                matched += 1

            lengths[position] = matched
            # The next suffix shares at least ``matched - 1`` symbols.
            if matched > 0:
                matched -= 1
        return lengths

    rank, suffixes = build_suffix_array(codes)
    total = sum(lcp_lengths(codes, suffixes, rank))

    return total * 2 / (length * (length + 1))

@ray.remote(num_cpus=1)
def extract_final_answers_batch(prompt_type: int, responses: List[str]) -> List[str]:
    """Extract one final answer string per response.

    For ``prompt_type == 0`` the answer is the last ``\\boxed{...}`` capture
    found inside an ``<answer>...</answer>`` span (empty string when nothing
    matches).  For any other prompt type the work is delegated to
    ``extract_boxed_answer``.
    """
    if prompt_type != 0:
        return [extract_boxed_answer(response) for response in responses]

    answer_pattern = re.compile(r"<answer>.*?(\\boxed{.*}).*?</answer>", re.DOTALL)
    extracted = []
    for response in responses:
        found = answer_pattern.findall(response)
        extracted.append(found[-1] if found else "")
    return extracted

async def latex_equal(gt_answer, given_answer, math_mode="legacy"):
    """Compare a candidate answer with the reference answer (async).

    A reference of "yes"/"no" (any casing) is compared case-insensitively and
    yields 1 or 0.  Otherwise both answers go through ``solution2answer`` and
    the result of ``is_equal`` (run with the module-level ``executor``) is
    returned.
    """
    reference = gt_answer.lower()
    if reference in ["yes", "no"]:
        return 1 if given_answer.lower() == reference else 0

    return await is_equal(
        solution2answer(gt_answer),
        solution2answer(given_answer),
        executor,
        math_mode=math_mode,
    )

def latex_equal_sync(gt_answer, given_answer, math_mode="legacy"):
    """Compare a candidate answer with the reference answer (blocking).

    A reference of "yes"/"no" (any casing) is compared case-insensitively and
    yields 1 or 0.  Otherwise both answers go through ``solution2answer`` and
    the result of ``is_equal_sync`` is returned.
    """
    reference = gt_answer.lower()
    if reference in ["yes", "no"]:
        return 1 if given_answer.lower() == reference else 0

    return is_equal_sync(
        solution2answer(gt_answer),
        solution2answer(given_answer),
        math_mode=math_mode,
    )

async def check_multi_ans(gt_answer: str, answers: list):
    """Return 1 if the majority-vote answer matches ``gt_answer``, else 0.

    Answers are grouped by equivalence under ``latex_equal`` (math_verify
    mode); each group is represented by its first-seen answer string.  The
    group with the most votes wins, ties going to the group whose
    representative appears earliest in ``answers``.

    Args:
        gt_answer: Ground-truth answer.
        answers: Candidate answers to vote over.

    Returns:
        ``int`` consensus-at-n indicator; 0 when ``answers`` is empty.
    """
    if not answers:
        # No candidates means there is no consensus answer to grade.
        return 0

    # Tally keyed by the first-seen representative of each equivalence class.
    answer_votes = {}
    for answer in answers:
        for existing_answer in answer_votes:
            if await latex_equal(answer, existing_answer, math_mode="math_verify"):
                # Credit the matched representative, not the new spelling,
                # so equivalent answers pool their votes.
                answer_votes[existing_answer] += 1
                break
        else:
            answer_votes[answer] = 1

    # Find the answer with the most votes (first one in case of a tie)
    consensus_answer = max(answer_votes.items(), key=lambda x: (x[1], -answers.index(x[0])))[0]

    # Check if consensus answer matches the ground truth; math_mode belongs to
    # latex_equal, not to int().
    return int(await latex_equal(gt_answer, consensus_answer, math_mode="math_verify"))

def check_multi_ans_sync(gt_answer: str, answers: list):
    """Return 1 if the majority-vote answer matches ``gt_answer``, else 0.

    Blocking counterpart of ``check_multi_ans``.  Answers are grouped by
    equivalence under ``latex_equal_sync`` (math_verify mode); each group is
    represented by its first-seen answer string.  The group with the most
    votes wins, ties going to the group whose representative appears earliest
    in ``answers``.

    Args:
        gt_answer: Ground-truth answer.
        answers: Candidate answers to vote over.

    Returns:
        ``int`` consensus-at-n indicator; 0 when ``answers`` is empty.
    """
    if not answers:
        # No candidates means there is no consensus answer to grade.
        return 0

    # Tally keyed by the first-seen representative of each equivalence class.
    answer_votes = {}
    for answer in answers:
        for existing_answer in answer_votes:
            if latex_equal_sync(answer, existing_answer, math_mode="math_verify"):
                # Credit the matched representative, not the new spelling,
                # so equivalent answers pool their votes.
                answer_votes[existing_answer] += 1
                break
        else:
            answer_votes[answer] = 1

    # Find the answer with the most votes (first one in case of a tie)
    consensus_answer = max(answer_votes.items(), key=lambda x: (x[1], -answers.index(x[0])))[0]

    # Check if consensus answer matches the ground truth
    return int(latex_equal_sync(gt_answer, consensus_answer, math_mode="math_verify"))

class CustomRewardTrainer(RayPPOTrainer):
    @override
    async def custom_reward_fn(
        self,
        prompts: List[str],
        outputs: List[Any],
        extras: List[dict],
    ) -> Tuple[List[str], List[str], List[torch.Tensor]]:
        """Score generated responses and build reward tensors plus log metrics.

        A response gets a raw score of 1.0 only when its ``stop_reason`` is
        ``"stop"`` and its ``answer_status`` is 1; otherwise 0.0.  When
        ``self.cfg.use_grpo`` is set, every score is centred on the mean raw
        score of its prompt group and divided by the group's standard
        deviation when that deviation is positive.

        Args:
            prompts: One prompt per output; identical prompts form a group.
            outputs: Per-sample dicts.  Reads ``response``, ``answer_status``,
                ``stop_reason``, ``final_answer`` (and ``attempt_used`` when
                ``self.cfg.multi_attempt``); writes ``repeat_score`` and
                ``reflection_pattern_score`` back into each dict.
            extras: Not used by this method.

        Returns:
            ``(res_prompts, res_responses, res_score_tensors, info, log_dict)``.
            The first four entries only cover samples whose response is
            non-empty.  NOTE(review): the return annotation lists only the
            first three members; it is left untouched to keep the signature
            unchanged.
        """
        # make log metrics
        scores = []
        avg_non_stop_count = 0
        pass_at_n_dict = defaultdict(list)
        num_tokens: List[int] = []

        @ray.remote(num_cpus=1)
        def get_repeat_score(res):
            return repeatness(res)

        @ray.remote(num_cpus=1)
        def get_reflection_pattern_score(res):
            reflection_pattern_dict = check_reflection_pattern(res)
            reflection_pattern_num = sum(reflection_pattern_dict.values())
            return reflection_pattern_num

        # Two remote tasks per output, interleaved:
        # [repeat_0, reflection_0, repeat_1, reflection_1, ...]
        rep_tasks = []
        for output in outputs:
            response = join_str(output["response"])
            # calculate repeat score for log
            rep_tasks.extend([get_repeat_score.remote(response), get_reflection_pattern_score.remote(response)])
        rep_task_results = ray.get(rep_tasks)

        repeat_scores = [rep_task_results[idx * 2] for idx in range(len(outputs))]
        reflection_pattern_scores = [rep_task_results[idx * 2 + 1] for idx in range(len(outputs))]

        responses = [output["response"] for output in outputs]
        output_tokens = self._tokenize(responses, self.cfg.generate_max_len, padding=False)

        for idx, output in enumerate(outputs):
            prompt, out_token = prompts[idx], output_tokens[idx]
            answer_status = output["answer_status"]
            stop_reason = output["stop_reason"]
            output["repeat_score"] = repeat_scores[idx]
            output["reflection_pattern_score"] = reflection_pattern_scores[idx]

            # only correct and stopped responses can acquire reward
            score = 0.0
            if stop_reason == "stop":
                if answer_status == 1:
                    score = 1.0
            else:
                avg_non_stop_count += 1
            scores.append(score)

            # calculate pass@n
            pass_at_n_dict[prompt].append(score)
            # log num_tokens
            num_tokens.append(len(out_token))

        # must run before GRPO, because GRPO rewrites `scores` in place
        num_tokens_arr = np.array(num_tokens, dtype=np.float32)  # must be float to calculate mean and std
        scores_arr = np.array(scores)
        correct_tokens_arr = np.array([]) if np.all(scores_arr == 0) else np.array(num_tokens_arr[scores_arr == 1])
        incorrect_tokens_arr = np.array([]) if np.all(scores_arr == 1) else np.array(num_tokens_arr[scores_arr == 0])

        # GRPO reward normalization
        if self.cfg.use_grpo:
            for i, prompt in enumerate(prompts):
                group_scores = pass_at_n_dict[prompt]
                scores[i] -= np.mean(group_scores)
                # Parenthesise the walrus: without the parentheses `std` is
                # bound to the boolean comparison and the division is by True.
                if (std := np.std(group_scores)) > 0:
                    scores[i] /= std

        def dump_results(prompts, outputs, scores):
            # Currently not invoked; kept for ad-hoc debugging dumps.
            saved = []
            for prompt, output, score in zip(prompts, outputs, scores):
                saved.append(dict(prompt=prompt, score=score, outputs=output))
            dump_path = os.path.join(self.cfg.save_path, f"iter{self.global_step}_generation_results.json")
            with open(dump_path, "w") as f:
                json.dump(saved, f, ensure_ascii=False, indent=2)

        answer_acc = sum([1 for output in outputs if output["answer_status"] == 1]) / len(outputs) if len(outputs) > 0 else 0
        total_pass_at_n = sum(1 for v in pass_at_n_dict.values() if np.sum(v) > 0)
        log_dict = {
            "score": answer_acc,
            "non_stop_count": avg_non_stop_count / len(prompts),
            "repeat_score": sum(repeat_scores) / len(prompts),
            "reflection_pattern_score": sum(reflection_pattern_scores) / len(prompts),
            "pass_at_n": total_pass_at_n / len(pass_at_n_dict),
            "avg_pass": sum([np.mean(v) for v in pass_at_n_dict.values() if np.sum(v) > 0]) / total_pass_at_n if total_pass_at_n > 0 else 0,
            "num_tokens": np.mean(num_tokens_arr).item(),
            "std_num_tokens": np.std(num_tokens_arr).item(),
            "correct_num_tokens": 0 if len(correct_tokens_arr) == 0 else np.mean(correct_tokens_arr).item(),
            "std_correct_num_tokens": 0 if len(correct_tokens_arr) == 0 else np.std(correct_tokens_arr).item(),
            "incorrect_num_tokens": 0 if len(incorrect_tokens_arr) == 0 else np.mean(incorrect_tokens_arr).item(),
            "std_incorrect_num_tokens": 0 if len(incorrect_tokens_arr) == 0 else np.std(incorrect_tokens_arr).item(),
        }

        if self.cfg.multi_attempt:
            attempt_used = []
            attempt_used_success = []
            attempt_used_failure = []
            for output in outputs:
                attempt_used.append(output["attempt_used"])
                if output["answer_status"] == 1:
                    attempt_used_success.append(output["attempt_used"])
                else:
                    attempt_used_failure.append(output["attempt_used"])
            log_dict["attempt_used"] = sum(attempt_used) / len(attempt_used)
            if len(attempt_used_success) > 0:
                log_dict["attempt_used_success"] = sum(attempt_used_success) / len(attempt_used_success)
            else:
                log_dict["attempt_used_success"] = 0
            if len(attempt_used_failure) > 0:
                log_dict["attempt_used_failure"] = sum(attempt_used_failure) / len(attempt_used_failure)
            else:
                log_dict["attempt_used_failure"] = 0

        # All values are still numeric here; the "table" entry is added later.
        logging_str = ",".join([f"{k}: {v:.4f}" for k, v in log_dict.items()])
        logger.info(logging_str)

        # make a per-token score tensor for each output, for example: [0, 0, 0, 0, r]
        score_tensors = []
        for score, output_token in zip(scores, output_tokens):
            score_tensor = torch.zeros(len(output_token))
            if len(output_token) > 0:
                score_tensor[-1] = score
            score_tensors.append(score_tensor)

        # remove empty responses
        res_prompts = []
        res_responses = []
        res_score_tensors = []
        info = {
            "answer_status": [],
            "response_type": [],
        }

        for n, (prompt, response, score_tensor) in enumerate(zip(prompts, responses, score_tensors)):
            if len(response) > 0:
                res_prompts.append(prompt)
                res_responses.append(response)
                res_score_tensors.append(score_tensor)
                info["response_type"].append(0)
                info["answer_status"].append(outputs[n]["answer_status"])

        log_dict["table"] = {
            "prompt": prompts[0],
            "response": join_str(outputs[0]['response']),
            "final_answer": outputs[0]['final_answer'],
            "answer_status": outputs[0]['answer_status'],
            "stop_reason": outputs[0]['stop_reason'],
            "score": scores[0],
        }
        if self.cfg.multi_attempt:
            log_dict["table"]["attempt_used"] = outputs[0]['attempt_used']

        return res_prompts, res_responses, res_score_tensors, info, log_dict
    
    @override
    async def summary_reward_fn(
        self,
        prompts: List[str],
        outputs: List[Any],
        extras: List[dict],
    ) -> Tuple[List[str], List[str], List[torch.Tensor]]:
        """Build per-stage metrics and reward tensors for multi-stage outputs.

        Each output dict is read for ``response``, ``prompt``,
        ``answer_status``, ``final_answer``, ``response_status``,
        ``stop_reason``, ``response_ids`` and ``rewards``.  Stage ``n`` of an
        output is described by ``response_status[n]``, ``stop_reason[n]`` and
        the ``n``-th odd-indexed entry of ``response_ids``; the stage names
        are "fast", "verify", "slow" and "summary".

        Returns ``(res_prompts, res_responses, res_score_tensors, info,
        log_dict)``; the first four only cover outputs whose ``response`` is
        non-empty.  ``extras`` is not used.
        """
        # make log metrics
        pass_at_n_dict = defaultdict(list)

        @ray.remote(num_cpus=1)
        def get_repeat_score(res):
            return repeatness(res)

        @ray.remote(num_cpus=1)
        def get_reflection_pattern_score(res):
            reflection_pattern_dict = check_reflection_pattern(res)
            reflection_pattern_num = sum(reflection_pattern_dict.values())
            return reflection_pattern_num

        # Two remote tasks per output, interleaved (repeat, reflection, ...).
        rep_tasks = []
        for output in outputs:
            text = join_str(output["response"])
            rep_tasks.append(get_repeat_score.remote(text))
            rep_tasks.append(get_reflection_pattern_score.remote(text))
        rep_task_results = ray.get(rep_tasks)

        repeat_scores = rep_task_results[0::2]
        reflection_pattern_scores = rep_task_results[1::2]
        for idx, output in enumerate(outputs):
            pass_at_n_dict[prompts[idx]].append(float(output["answer_status"] == 1))

        if len(outputs) > 0:
            answer_acc = sum(1 for output in outputs if output["answer_status"] == 1) / len(outputs)
        else:
            answer_acc = 0
        solved_groups = [v for v in pass_at_n_dict.values() if np.sum(v) > 0]
        total_pass_at_n = len(solved_groups)
        log_dict = {
            "score": answer_acc,
            "repeat_score": sum(repeat_scores) / len(prompts),
            "reflection_pattern_score": sum(reflection_pattern_scores) / len(prompts),
            "pass_at_n": total_pass_at_n / len(pass_at_n_dict),
            "avg_pass": sum([np.mean(v) for v in solved_groups]) / total_pass_at_n if total_pass_at_n > 0 else 0,
        }

        def stage_token_total(position, keep):
            # Sum the token count of stage `position` over the outputs that
            # have such a stage in response_ids and satisfy `keep`.
            total = 0
            for sample in outputs:
                segments = sample["response_ids"][1::2]
                if len(segments) > position and keep(sample):
                    total += len(segments[position])
            return total

        for n, mode in enumerate(("fast", "verify", "slow", "summary")):
            reached = [output for output in outputs if len(output["response_status"]) > n]
            num_response = len(reached)
            num_corr_response = sum(1 for output in reached if output["response_status"][n] == 1)
            num_incorr_response = num_response - num_corr_response
            non_stop_count = sum(1 for output in reached if output["stop_reason"][n] != "stop")

            if num_response > 0:
                log_dict[f"{mode}_num_tokens"] = stage_token_total(n, lambda sample: True) / num_response
                log_dict[f"{mode}_acc"] = num_corr_response / num_response
                log_dict[f"{mode}_non_stop_count"] = non_stop_count / num_response
            else:
                log_dict[f"{mode}_num_tokens"] = 0
                log_dict[f"{mode}_acc"] = 0
                log_dict[f"{mode}_non_stop_count"] = 0

            if num_corr_response > 0:
                correct_tokens = stage_token_total(n, lambda sample: sample["response_status"][n] == 1)
                log_dict[f"{mode}_correct_num_tokens"] = correct_tokens / num_corr_response
            else:
                log_dict[f"{mode}_correct_num_tokens"] = 0

            if num_incorr_response > 0:
                incorrect_tokens = stage_token_total(n, lambda sample: sample["response_status"][n] != 1)
                log_dict[f"{mode}_incorrect_num_tokens"] = incorrect_tokens / num_incorr_response
            else:
                log_dict[f"{mode}_incorrect_num_tokens"] = 0

        # Verification accuracy, split by whether the first stage was correct.
        verified = [output["response_status"] for output in outputs if len(output["response_status"]) > 1]
        verify_pos_num = sum(1 for status in verified if status[0] == 1)
        verify_pos_corr_num = sum(1 for status in verified if status[0] == 1 and status[1] == 1)
        verify_neg_num = sum(1 for status in verified if status[0] != 1)
        verify_neg_corr_num = sum(1 for status in verified if status[0] != 1 and status[1] == 1)
        log_dict["verify_pos_acc"] = verify_pos_corr_num / verify_pos_num if verify_pos_num > 0 else 0
        log_dict["verify_neg_acc"] = verify_neg_corr_num / verify_neg_num if verify_neg_num > 0 else 0

        logger.info(",".join(f"{k}: {v:.4f}" for k, v in log_dict.items()))

        # Rolling window (last 10 calls) of the fast-stage accuracy.
        if not hasattr(self, "trail_fast_acc"):
            self.trail_fast_acc = deque(maxlen=10)
        self.trail_fast_acc.append(log_dict["fast_acc"])
        mean_fast_acc = np.mean(self.trail_fast_acc).item()

        # make a per-token score tensor for each output, for example: [0, 0, 0, 0, r]
        score_tensors = []
        for output in outputs:
            segments = []
            for i, _ in enumerate(output["response_status"]):
                if i == 0:
                    # first stage: response tokens only
                    token_len = len(output["response_ids"][1::2][i])
                else:
                    # later stages: prompt tokens + response tokens
                    token_len = len(output["response_ids"][::2][i]) + len(output["response_ids"][1::2][i])
                segment = torch.zeros(token_len)
                if token_len > 0:
                    r = output["rewards"][i]
                    if self.cfg.verify_reweight and i == 1:
                        # Reweight the verify-stage reward by the rolling fast accuracy.
                        if output["response_status"][0] == 1:
                            r = r * (1 - mean_fast_acc)
                        else:
                            r = r * mean_fast_acc
                    segment[-1] = r
                segments.append(segment)
            merged = torch.cat(segments)
            expected_len = sum([len(x) for x in output["response_ids"][1:]])
            assert len(merged) == expected_len, f"len(score_tensor): {len(segments)}, len(output['response_ids'][1:]): {expected_len}"
            score_tensors.append(merged)

        res_prompts = []
        res_responses = []
        res_score_tensors = []
        info = {
            "answer_status": [],
            "response_status": [],
            "response_ids": [],
        }

        # Drop outputs with an empty response.
        for output, score_tensor in zip(outputs, score_tensors):
            if len(output["response"]) == 0:
                continue
            res_prompts.append(output["prompt"])
            res_responses.append(output["response"])
            res_score_tensors.append(score_tensor)
            info["answer_status"].append(output["answer_status"])
            info["response_status"].append(output["response_status"])
            info["response_ids"].append(output["response_ids"])

        first = outputs[0]
        log_dict["table"] = {
            "prompt": prompts[0],
            "response": join_str(first['response'][1:]),
            "final_answer": first['final_answer'],
            "response_status": ",".join([str(x) for x in first['response_status']]),
            "answer_status": first['answer_status'],
            "stop_reason": "/".join(first['stop_reason']),
            "score": ",".join([str(x) for x in first['rewards']]),
        }

        return res_prompts, res_responses, res_score_tensors, info, log_dict

    @override
    @torch.no_grad()
    async def generate_vllm(
        self,
        gen_func: Callable[[List[str]], Awaitable[List[str | Any]]],
        prompts: List[str],
        extras: List[dict],
        **kwargs,
    ) -> List[str | Any]:
        """Generate responses through ``gen_func`` and return one dict per sample.

        With ``self.cfg.multi_attempt`` or ``self.cfg.summary`` enabled the
        ground-truth answers are forwarded to ``gen_func`` and its
        column-oriented output is transposed into per-sample dicts.
        Otherwise final answers are extracted from the responses in remote
        batches, compared with ``extras[i]["answer"]``, and each sample gets
        ``response``, ``answer_status`` (1 correct, 2 answer present but
        wrong, 0 no answer extracted), ``stop_reason`` and ``final_answer``.
        """
        from vllm import SamplingParams

        # read sampling params from self.cfg
        sampling_params = SamplingParams(
            temperature=self.cfg.temperature,
            top_p=self.cfg.top_p,
            top_k=self.cfg.top_k,
            max_tokens=self.cfg.generate_max_len,
            skip_special_tokens=False,
            include_stop_str_in_output=True,
            stop=list(self.cfg.stop),
        )

        if self.cfg.multi_attempt or self.cfg.summary:
            answers = [x['answer'] for x in extras]
            out = await gen_func(
                prompts=prompts, sampling_params=sampling_params, use_tqdm=False, truncate_prompt=True, answers=answers,
            )
            # Transpose {key: [values...]} into one dict per sample.
            return [{k: v[n] for k, v in out.items()} for n in range(len(out["response"]))]

        out = await gen_func(
            prompts=prompts, sampling_params=sampling_params, use_tqdm=False, truncate_prompt=True
        )
        responses, stop_reasons = out["response"], out["stop_reason"]

        # Extract the final answers directly from the generated text, in
        # remote batches of BATCH_SIZE responses.
        BATCH_SIZE = 16
        extract_tasks = [
            extract_final_answers_batch.remote(self.cfg.prompt_type, responses[start : start + BATCH_SIZE])
            for start in range(0, len(responses), BATCH_SIZE)
        ]
        batched_results = await asyncio.gather(*(asyncio.to_thread(ray.get, task) for task in extract_tasks))
        final_answers = [answer for batch in batched_results for answer in batch]

        # Judge each extracted answer against its reference.
        equal_results = await asyncio.gather(
            *[
                is_equal(solution2answer(extra["answer"]), solution2answer(final_answer), executor)
                for extra, final_answer in zip(extras, final_answers)
            ]
        )

        results = []
        for _extra, response, final_answer, stop_reason, equal_result in zip(
            extras, responses, final_answers, stop_reasons, equal_results
        ):
            # 1: correct answer; 2: correct format, wrong answer; 0: wrong format, wrong answer
            answer_status = 1 if equal_result else (2 if final_answer else 0)
            results.append(
                {
                    "response": response,
                    "answer_status": answer_status,
                    "stop_reason": stop_reason,
                    "final_answer": final_answer,
                }
            )

        return results
        
    async def record_eval(self, prompts_all, file_names_all, outputs_all, final_answers_all, gt_answers_all, prefix=""):
        """Grade evaluation outputs, aggregate per-dataset metrics and persist them.

        Every sample is graded with ``latex_equal`` ("both" mode for the
        ``olympiadbench`` dataset, "legacy" otherwise).  Per dataset the mean
        response length and the accuracy in percent are computed; with
        ``self.cfg.summary`` and an empty ``prefix`` the same is done for the
        samples that have more than five ``token_ids`` segments (``slow_*``
        keys).  Per-sample records are written as JSONL under
        ``self.cfg.save_path`` and every metric is sent to ``self.writer``
        under ``eval/``.

        Args:
            prompts_all, file_names_all, outputs_all, final_answers_all,
            gt_answers_all: Parallel per-sample sequences.  Each output must
                expose ``.text`` and ``.token_ids``.
            prefix: ``""``, ``"pre_"`` or ``"verify_"``; selects which
                ``token_ids`` segments count as the response in summary mode
                and is prepended to the dump file name and to every metric
                key that does not start with ``"slow"``.

        Returns:
            ``(log_dict, iscorrects)``.

        Raises:
            ValueError: in summary mode when ``prefix`` is not one of the
                supported values.
        """
        output_for_save = []
        log_dict = defaultdict(float)
        track_slow = prefix == "" and self.cfg.summary

        iscorrects = []
        for prompt, output, final_answer, answer, file_name in zip(
            prompts_all, outputs_all, final_answers_all, gt_answers_all, file_names_all
        ):
            # olympiadbench has multiple answers, so it needs a different mode
            math_mode = "both" if file_name == "olympiadbench" else "legacy"
            iscorrect = await latex_equal(answer, final_answer, math_mode=math_mode)
            iscorrects.append(iscorrect)
            output_for_save.append(
                dict(
                    prompt=prompt,
                    output=output.text,
                    final_answer=final_answer,
                    answer=answer,
                    iscorrect=iscorrect,
                )
            )

            token_ids = output.token_ids
            if not self.cfg.summary:
                response_len = len(token_ids)
            elif prefix == "":
                # sum over segments 1, 3, 5 where they exist
                response_len = sum(len(token_ids[i]) for i in range(1, min(len(token_ids), 6), 2))
            elif prefix == "pre_":
                response_len = len(token_ids[1]) if len(token_ids) > 1 else 0
            elif prefix == "verify_":
                response_len = len(token_ids[3]) if len(token_ids) > 3 else 0
            else:
                # Previously this fell through with `response_len` unbound.
                raise ValueError(f"Unsupported prefix for summary evaluation: {prefix!r}")

            log_dict[f"{file_name}_total_response_length"] += response_len
            log_dict[f"{file_name}_correct"] += iscorrect
            log_dict[f"{file_name}_total"] += 1

            if track_slow and len(token_ids) > 5:
                log_dict[f"slow_{file_name}_total_response_length"] += len(token_ids[5])
                log_dict[f"slow_{file_name}_correct"] += iscorrect
                log_dict[f"slow_{file_name}_total"] += 1

        def _finalize(key_prefix):
            # Replace the three running sums of one group with its mean
            # response length and its accuracy in percent.  A group with no
            # samples reports 0 instead of raising ZeroDivisionError.
            total_length = log_dict.pop(f"{key_prefix}_total_response_length", 0.0)
            correct = log_dict.pop(f"{key_prefix}_correct", 0.0)
            total = log_dict.pop(f"{key_prefix}_total", 0.0)
            log_dict[f"{key_prefix}_response_length"] = total_length / total if total > 0 else 0.0
            log_dict[f"{key_prefix}_acc"] = correct / total * 100 if total > 0 else 0.0

        # get all file_names from self.cfg.eval_prompt_data
        all_file_names: List[str] = [
            os.path.splitext(os.path.basename(file_path))[0] for file_path in self.cfg.eval_prompt_data
        ]
        for file_name in all_file_names:
            _finalize(file_name)
            if track_slow:
                _finalize(f"slow_{file_name}")

        # calculate averages over datasets
        num_files = len(all_file_names)
        log_dict["avg_acc"] = (
            sum(log_dict[f"{file_name}_acc"] for file_name in all_file_names) / num_files if num_files else 0.0
        )
        log_dict["avg_response_length"] = (
            sum(log_dict[f"{file_name}_response_length"] for file_name in all_file_names) / num_files
            if num_files
            else 0.0
        )

        dump_file_name = prefix + f"eval_output_iter{self.global_step}"
        # join all acc from all_file_names
        for file_name in all_file_names:
            dump_file_name += f"_{file_name}{log_dict[f'{file_name}_acc']:.4f}"
        dump_file_name += ".jsonl"
        # dump as jsonl
        with open(os.path.join(self.cfg.save_path, dump_file_name), "w") as f:
            for item in output_for_save:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")

        log_dict = {prefix + k if not k.startswith("slow") else k: v for k, v in log_dict.items()}
        logging_str = ",".join([f"{k}: {v:.4f}" for k, v in log_dict.items()])
        logger.info(logging_str)
        for k, v in log_dict.items():
            self.writer.add_scalar(f"eval/{k}", v, self.global_step)
        return log_dict, iscorrects
    
    @override
    async def eval(self, multi_iter: int = 1) -> dict:
        """Generate on ``self.eval_dataset`` with the vLLM engines and score the answers.

        The (optionally repeated) eval set is loaded as a single batch and its
        prompts are split into contiguous, equally sized slices, one per engine
        in ``self.vllm_engines``. Extracted answers are scored through
        ``self.record_eval``. When ``self.cfg.summary`` is set, the fast answers
        (metric prefix ``"pre_"``) and the verification answers (prefix
        ``"verify_"``) are scored as well. When ``multi_iter > 1`` the samples
        are grouped by ``(file_name, prompt)`` and ``pass_at_n`` / ``con_at_n``
        style statistics are averaged per ``file_name`` and added to the result.

        Args:
            multi_iter: Number of times the eval dataset is repeated. Values
                greater than 1 enable the per-prompt aggregate statistics and,
                in summary mode, pass ``verify_sum=True`` to the engines.

        Returns:
            The merged metric dictionary (metric name -> value).
        """
        logger.info("Start evaluating on val set")
        from vllm import SamplingParams

        sampling_params = SamplingParams(
            temperature=self.cfg.temperature,
            top_p=self.cfg.top_p,
            max_tokens=self.cfg.generate_max_len,
            stop=list(self.cfg.stop),
            skip_special_tokens=False,
            include_stop_str_in_output=True,
        )

        from torch.utils.data import DataLoader, ConcatDataset

        summary = self.cfg.summary

        dataset = self.eval_dataset
        if multi_iter > 1:
            dataset = ConcatDataset([dataset] * multi_iter)

        # One batch holding the whole dataset; prompts are then sharded by hand.
        dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, drop_last=False)
        # Ceil-divide so every prompt is assigned to some engine.
        prompt_pre_llm = (len(dataset) + self.cfg.vllm_num_engines - 1) // self.cfg.vllm_num_engines

        prompts_all = []
        file_names_all = []
        gt_answers_all = []

        outputs_all = []
        final_answers_all = []

        # summary specific data
        verified_final_answers_all = []
        fast_answers_all = []
        verified_fast_answers_all = []
        verify_answers_all = []

        for batch in dataloader:
            prompts = list(batch[0])
            answers = list(batch[1]["answer"])
            file_names = list(batch[1]["file_name"])

            prompts_all.extend(prompts)
            gt_answers_all.extend(answers)
            file_names_all.extend(file_names)

            pending = []
            for i, llm in enumerate(self.vllm_engines):
                kwargs = dict(
                    prompts=prompts[i * prompt_pre_llm : (i + 1) * prompt_pre_llm], sampling_params=sampling_params
                )
                if summary and multi_iter > 1:
                    kwargs["verify_sum"] = True

                pending.append(llm.generate.remote(**kwargs))
            per_engine_outputs = await asyncio.gather(*pending)
            # Flatten in engine order (single pass instead of repeated list concatenation).
            vllm_outputs = [out for engine_outputs in per_engine_outputs for out in engine_outputs]

            for vllm_output in vllm_outputs:
                completion = vllm_output.outputs[0]
                if summary:
                    # NOTE(review): all_answers layout as read here -- [0] fast answer,
                    # [1] verification answer, [2] final (slow) answer, [4] "yes" when the
                    # slow answer was verified; confirm against the engine implementation.
                    all_answers = completion.all_answers
                    fast_answer = all_answers[0]
                    verify_answer = all_answers[1] if len(all_answers) >= 2 else ""
                    if len(completion.text) <= 4:
                        # verified fast response
                        final_answer = fast_answer
                        verified_final_answer = fast_answer
                        verified_fast_answer = fast_answer
                    else:
                        # unverified slow response
                        # NOTE(review): assumes a slow response always carries >= 3 answers.
                        final_answer = all_answers[2]
                        slow_verified = len(all_answers) >= 5 and all_answers[4] == "yes"
                        verified_final_answer = final_answer if slow_verified else None
                        verified_fast_answer = None

                    final_answers_all.append(final_answer)
                    verified_final_answers_all.append(verified_final_answer)
                    fast_answers_all.append(fast_answer)
                    verified_fast_answers_all.append(verified_fast_answer)
                    verify_answers_all.append(verify_answer)
                else:
                    final_answers_all.append(extract_boxed_answer(completion.text))

                outputs_all.append(completion)

        log_dict = {}
        log_dict_, iscorrects_final = await self.record_eval(prompts_all, file_names_all, outputs_all, final_answers_all, gt_answers_all, prefix="")
        log_dict.update(log_dict_)

        if summary:
            log_dict_, iscorrects_fast = await self.record_eval(prompts_all, file_names_all, outputs_all, fast_answers_all, gt_answers_all, prefix="pre_")
            log_dict.update(log_dict_)
            # The verification target is "yes" exactly when the fast answer was correct.
            gt_verify_answers_all = ["yes" if iscorrect == 1 else "no" for iscorrect in iscorrects_fast]
            log_dict_, _ = await self.record_eval(prompts_all, file_names_all, outputs_all, verify_answers_all, gt_verify_answers_all, prefix="verify_")
            log_dict.update(log_dict_)

        if multi_iter > 1:
            # Group the repeated samples of each question and compute pass_at_n / con_at_n
            # (plus the verified variants in summary mode).
            aggregated_data = defaultdict(lambda: {
                'gt_answers': [],
                'final_answers': [],
                'verified_final_answers': [],
                'fast_answers': [],
                'verified_fast_answers': [],
                'iscorrects_final': [],
                'iscorrects_fast': [],
            })
            logger.info("Start aggregating data for computing pass_at_n and con_at_n")
            for i, (file_name, prompt) in enumerate(zip(file_names_all, prompts_all)):
                # Repeats of the same question share the (file_name, prompt) key.
                entry = aggregated_data[(file_name, prompt)]
                entry['gt_answers'].append(gt_answers_all[i])
                entry['final_answers'].append(final_answers_all[i])
                entry['iscorrects_final'].append(iscorrects_final[i])

                if summary:
                    entry['fast_answers'].append(fast_answers_all[i])
                    entry['verified_final_answers'].append(verified_final_answers_all[i])
                    entry['verified_fast_answers'].append(verified_fast_answers_all[i])
                    entry['iscorrects_fast'].append(iscorrects_fast[i])

            logger.info("Computing pass_at_n")
            # pass_at_n: at least one of the repeated final answers is correct.
            # NOTE: no fast-answer pass@n is stored under 'v_fast_con_at_n'; that key is
            # reserved for the consistency score assigned further down. A fast pass@n
            # metric would need its own key and an entry in stat_keys.
            for entry in aggregated_data.values():
                entry['pass_at_n'] = int(any(entry['iscorrects_final']))

            logger.info("Computing con_at_n")
            # Build, per question, the answer lists handed to check_multi_ans_sync.
            gt_answers, final_answers = [], []
            verified_fast_answers, verified_final_answers = [], []
            for entry in aggregated_data.values():
                gt_answers.append(entry['gt_answers'][0])
                final_answers.append(entry['final_answers'])
                if summary:
                    # Prefer verified answers; fall back to all answers when none was verified.
                    kept_fast = [answer for answer in entry['verified_fast_answers'] if answer is not None]
                    if not kept_fast:
                        kept_fast = entry['fast_answers']
                    verified_fast_answers.append(kept_fast)

                    kept_final = [answer for answer in entry['verified_final_answers'] if answer is not None]
                    if not kept_final:
                        kept_final = entry['final_answers']
                    verified_final_answers.append(kept_final)

            logger.info("Computing con_at_n for final answers")
            con_at_n = parallel_f(list(zip(gt_answers, final_answers)), num_workers=8, timeout=180, default_value=False, f=check_multi_ans_sync)
            if summary:
                logger.info("Computing con_at_n for verified fast answers")
                v_fast_con_at_n = parallel_f(list(zip(gt_answers, verified_fast_answers)), num_workers=8, timeout=180, default_value=False, f=check_multi_ans_sync)
                logger.info("Computing con_at_n for verified final answers")
                v_con_at_n = parallel_f(list(zip(gt_answers, verified_final_answers)), num_workers=8, timeout=180, default_value=False, f=check_multi_ans_sync)

            # Results come back in the same order as aggregated_data was iterated above.
            for n, entry in enumerate(aggregated_data.values()):
                entry['con_at_n'] = int(con_at_n[n])
                if summary:
                    entry['v_con_at_n'] = int(v_con_at_n[n])
                    entry['v_fast_con_at_n'] = int(v_fast_con_at_n[n])

            # List of statistic keys to compute and average
            stat_keys = ["pass_at_n", "con_at_n"]
            if summary:
                stat_keys.extend(["v_fast_con_at_n", "v_con_at_n"])

            # Collect every question's statistics under "<file_name>_<stat>".
            per_dataset_values = defaultdict(list)
            for (dataset_name, _), entry in aggregated_data.items():
                for stat_key in stat_keys:
                    per_dataset_values[f"{dataset_name}_{stat_key}"].append(entry.get(stat_key, 0))

            # Each list has at least one element, so the mean is always defined.
            log_dict_multi = {key: sum(values) / len(values) for key, values in per_dataset_values.items()}

            logging_str = ",".join([f"{k}: {v:.4f}" for k, v in log_dict_multi.items()])
            logger.info(logging_str)

            log_dict.update(log_dict_multi)

        return log_dict
    
    