import asyncio
import random
import re
import os
import time

import inspect
import openai
import numpy as np
import traceback
import argparse
import jsonlines
import datetime
from datasets import load_from_disk, load_dataset
import json
import multiprocessing
from tqdm import tqdm
from functools import partial
from concurrent.futures import ProcessPoolExecutor
from openai import OpenAI
import multiprocessing
import sys
import copy
import math
import re
import itertools
from modelscope import AutoModelForSequenceClassification, AutoTokenizer, AutoModel
import torch
sys.path.append('..')
from global_utils import QwenPRM, auto_get_rm, rm_path_dict, MixJudge
from global_utils.utils import (
    generate_together,
    generate_openai,
    generate_with_references,
    DEBUG,
    generate_general,
    generate_general_rm
)
from moa_api import raw_moa_api, async_raw_moa_api
from pathlib import Path
from math_equivalence import is_equiv

# Process-count constant; not referenced by the code visible in this file.
MAX_PROCESSES = 8

# JSONL benchmark file read line-by-line by single_agent_test.
DATA_PATH = './dataset/test-500.jsonl'

# Named model ensembles, keyed by a "<count>_<size>" label.
# NOTE(review): '6_mid' currently lists only five models although its key
# suggests six -- confirm whether an entry is missing.
MODEL_LIST_DICT = {
    '6_small': [
        'Qwen2.5-7B-Instruct',
        'glm-4-9b-chat',
        'Qwen2-7B-Instruct',
        'Meta-Llama-3.1-8B-Instruct',
        'internlm2_5-7b-chat',
        'gemma-2-9b-it',
    ],
    '4_mid': [
        'Meta-Llama-3.3-70B-Instruct',
        'Qwen2.5-32b-Instruct',
        'gemma_3_27b_it',
        'QwQ-32B',
    ],
    '6_mid': [
        'Meta-Llama-3.3-70B-Instruct',
        'Qwen2.5-32b-Instruct',
        'gemma_3_27b_it',
        'QwQ-32B',
        'EXAONE-Deep-32B',
    ],
    '4_mid+2_small': [
        'Meta-Llama-3.3-70B-Instruct',
        'Qwen2.5-32b-Instruct',
        'gemma_3_27b_it',
        'QwQ-32B',
        'Qwen2.5-7B-Instruct',
        'internlm2_5-7b-chat',
    ],
}


def last_boxed_only_string(string):
    r"""Return the last ``\boxed...{...}`` span of *string*, braces included.

    The search starts at the right-most occurrence of ``\boxed``; if there is
    none, the right-most ``\fbox`` is used instead. From that position the
    text is scanned forward, tracking brace depth, until the brace that
    brings the depth back to zero is found.

    Args:
        string: Text to search (typically a model response).

    Returns:
        The substring from the ``\boxed``/``\fbox`` marker through its
        matching closing brace, or ``None`` when no marker is present or the
        braces never balance.
    """
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    depth = 0
    for i in range(idx, len(string)):
        ch = string[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # This brace closes the group opened right after the marker.
                return string[idx:i + 1]

    # Ran off the end without the braces balancing.
    return None


def remove_boxed(s):
    r"""Strip the ``\boxed{`` prefix and trailing ``}`` from *s*.

    Args:
        s: A string expected to look like ``\boxed{...}``. ``None`` and other
            non-string values are accepted and treated as "no match".

    Returns:
        The text between ``\boxed{`` and the final ``}``, or ``None`` when
        *s* is not a string of exactly that shape.
    """
    left = "\\boxed{"
    # Explicit checks instead of assert/bare-except: asserts are stripped
    # under ``python -O``, which would let malformed input be sliced anyway.
    if not isinstance(s, str):
        return None
    if s.startswith(left) and s.endswith("}"):
        return s[len(left):-1]
    return None

def build_prompt(question):
    """Build the two-message chat prompt for one math question.

    Args:
        question: Problem text; it is embedded after a ``question:`` prefix.

    Returns:
        A list with a system message carrying the fixed solver instructions,
        followed by a user message containing the question.
    """
    system_prompt = (
        "You are a math problem solver. "
        "Please solve the following math problem. "
        "Be sure to explain your solution in detail. "
        "The numerical values in the answer should be surrounded by \\boxed{}. "
        "The final answer should start with 'The answer is' and give the conclusion directly. "
        "Do not add any extra content."
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"question:{question}"},
    ]


def extract_answer(answer):
    r"""Extract the final answer string from a model response.

    The content of the last ``\boxed{...}`` in *answer* is preferred. When no
    such span can be extracted, a list of regexes is tried in order, from the
    most specific ("answer is <number>") to the loosest (any number), and the
    last match of the first pattern that matches is returned.

    Args:
        answer: Full text of the model response.

    Returns:
        The extracted answer string, or ``'-114514'`` as a sentinel when
        nothing matches at all.
    """
    ext_ans = remove_boxed(last_boxed_only_string(answer))
    if ext_ans is not None:
        return ext_ans

    # No usable boxed answer: fall back to pattern matching on the raw text.
    print('Not match, use final match')
    free_text_pattern = r'[Aa]nswer is (.*)'
    patterns = [
        r'[Aa]nswer is ([+-]?\d*\.?\d+)',
        free_text_pattern,
        r'is ([+-]?\d*\.?\d+)',
        r'([+-]?\d*\.?\d+)',
    ]
    for pattern in patterns:
        match = re.findall(pattern, answer)
        if match:
            result = match[-1]
            # Drop a sentence-ending period from the free-text capture.
            # endswith() is used because the capture may be an empty string,
            # in which case result[-1] would raise IndexError.
            if pattern == free_text_pattern and result.endswith('.'):
                return result[:-1]
            return result
    return str(-114514)

def single_agent_test(model, max_tokens=2048, logprobs=1, exp_suffix='', temperature=0.7):
    """Run one model over every problem in DATA_PATH and record the results.

    Results are appended one JSON object per line to
    ``result/<exp_name>/<model><exp_suffix>/result.json`` and a running
    accuracy summary is rewritten to ``summary.json`` in the same directory
    after every question. If ``result.json`` already exists, questions whose
    ``question_id`` is present there are skipped, so an interrupted run can
    be resumed.

    Args:
        model: Model name passed to ``generate_general``; also used in the
            result directory name.
        max_tokens: Generation limit passed to ``generate_general``. The
            value 8192 selects the ``single_agent_8k`` experiment directory;
            any other value selects ``single_agent``.
        logprobs: Passed through to ``generate_general``.
        exp_suffix: Text appended to the model name in the result directory.
        temperature: Sampling temperature passed to ``generate_general``.
    """
    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        examples = f.readlines()
    exp_name = 'single_agent_8k' if max_tokens == 8192 else 'single_agent'
    result_dir = os.path.join('result', exp_name, model + exp_suffix)
    result_file = os.path.join(result_dir, 'result.json')
    summary_file = os.path.join(result_dir, 'summary.json')
    os.makedirs(result_dir, exist_ok=True)

    final_res = []
    # A set gives O(1) "already answered?" checks in the main loop.
    done_question_ids = set()
    if os.path.exists(result_file):
        # jsonlines.open owns the file handle and closes it on exit;
        # Reader(open(...)) would leave the underlying file open.
        with jsonlines.open(result_file, mode='r') as reader:
            for pre_r_dict in reader:
                done_question_ids.add(pre_r_dict['question_id'])
                # NOTE(review): argument order here is (answer, pred) while
                # the main loop uses (pred, answer) -- confirm that is_equiv
                # is symmetric.
                final_res.append(is_equiv(pre_r_dict['answer'], pre_r_dict['pred']))

    for question_id, example in tqdm(enumerate(examples), total=len(examples)):
        if question_id in done_question_ids:
            continue
        # NOTE(review): eval() executes arbitrary code from the data file and
        # rejects JSON literals such as null/true/false. If the dataset is
        # strict JSON lines, json.loads(example) is the safer parser.
        example = eval(example)
        question, answer = example['problem'], example['answer']
        print(f"Question: {question}")
        messages = build_prompt(question)

        response = generate_general(model, messages, max_tokens, temperature, logprobs=logprobs)
        # A non-string return is indexed as a mapping holding the text and
        # its cumulative log-probability.
        if not isinstance(response, str):
            cumulative_logprob = response['cumulative_logprob']
            response = response['response']
        else:
            cumulative_logprob = None
        pred = extract_answer(response)
        is_correct = is_equiv(pred, answer)
        final_res.append(is_correct)
        # NOTE(review): the key is named 'mean_logprob' but the value stored
        # is the 'cumulative_logprob' field; the key is kept for existing
        # readers of the result file.
        res_dict = {'question_id': question_id, 'question': question,
                    'answer': answer, 'pred': pred, 'is_correct': is_correct, 'model_response': response,
                    'solution': example['solution'], 'level': example['level'],
                    'subject': example['subject'], 'mean_logprob': cumulative_logprob}
        num_correct = sum(final_res)
        sum_dict = {'corr': num_correct, 'wrong': len(final_res) - num_correct,
                    'acc': num_correct / len(final_res)}
        with open(summary_file, "w") as fo:
            fo.write(json.dumps(sum_dict))
        # Append and close per question so each result is flushed to disk
        # before the next (potentially slow) generation call.
        with jsonlines.open(result_file, mode='a') as writer:
            writer.write(res_dict)
    print(1)




# Command-line interface. Every option is parsed unconditionally; only the
# options whose names match a parameter of the selected task function are
# forwarded to it.
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, default="top-3",
                    help="The mode for generate_moa_by_reward.")
parser.add_argument("--json_path", type=str,
                    default="/cpfs01/shared/mabasic/tangshengji/moa/gsm8k/result/Qwem2_rm.json",
                    help="The json path for generate_moa_by_reward.")
parser.add_argument("--model", type=str, default="Qwen2.5-7B-Instruct")
parser.add_argument("--rm_model", type=str, default="Skywork-8B-Reward-Models")
parser.add_argument("--exclude_agent", type=str, default="")
parser.add_argument("--task", type=str, default='temp_test')
parser.add_argument("--N", type=int, default=1)
parser.add_argument("--beam_width", type=int, default=2)

parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--residual", action='store_true')
parser.add_argument("--bi_greedy", action='store_true')
parser.add_argument("--use_sc", action='store_true')
parser.add_argument("--use_rm", action='store_true')
parser.add_argument("--drop_greedy", action='store_true')
parser.add_argument("--rounds", type=int, default=1)
parser.add_argument("--max_tokens", type=int, default=32768)
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--model_list", type=str, default="general_4_mid_llm")
parser.add_argument("--logprobs", type=int, default=1)
parser.add_argument("--exp_suffix", type=str, default='')

args = parser.parse_args()

# Look the task up by name among this module's globals.
task_function = globals().get(args.task)
if not callable(task_function):
    print(f"Cannot find function {args.task}")
else:
    # Inspect the task's signature so only the arguments it declares are passed.
    params = inspect.signature(task_function).parameters
    filtered_args = {name: getattr(args, name) for name in params if hasattr(args, name)}

    # Invoke the selected task with the matching CLI values.
    task_function(**filtered_args)