from collections import Counter
import os
import time
from typing import Any, Dict, List, Optional
import random
import re
import string

import datasets
import openai
from openai import OpenAI

# -----------------------------------------------------------------------------
# 1. Prompt 模板配置
# -----------------------------------------------------------------------------
QUERY_TEMPLATE = "{Question}\n\n{Options}"
QUERY_TEMPLATE_API = "{Question}\nAnswer Choices:\n{Options}"

# Chain-of-thought (CoT) length control. At most one instruction is appended to
# QUERY_TEMPLATE: the first of PROMPTLONG, PROMPTSHORT, PROMPTTOKEN, PROMPTSTEP
# that is set (checked in that order) wins. PROMPTTOKEN / PROMPTSTEP values are
# inserted verbatim as the budget.
# NOTE(review): a value containing "{" or "}" would break the later
# QUERY_TEMPLATE.format(...) call.
if os.getenv("PROMPTLONG") is not None:
    QUERY_TEMPLATE += '\n\nAnswer after a long amount of thinking. If you feel like you are finished early, spend the extra time trying to double-check your work until you are absolutely sure that you have the correct answer.'
elif os.getenv("PROMPTSHORT") is not None:
    QUERY_TEMPLATE += '\n\nAnswer after a short amount of thinking. Do not spend excessive time double-checking your work.'
elif os.getenv("PROMPTTOKEN") is not None:
    QUERY_TEMPLATE += f"\n\nThink for up to {os.environ['PROMPTTOKEN']} tokens."
elif os.getenv("PROMPTSTEP") is not None:
    QUERY_TEMPLATE += f"\n\nThink for up to {os.environ['PROMPTSTEP']} steps."

print("QUERY_TEMPLATE: ", QUERY_TEMPLATE)

# Matches "Answer:" in any case and captures what follows: the rest of the line,
# or (with re.DOTALL) the rest of the text.
ANSWER_PATTERN = r"(?i)Answer\s*:\s*(.*)"
# Accepted option labels: the 26 uppercase ASCII letters "A" .. "Z".
VALID_LETTERS = list(string.ascii_uppercase)

# Prompt sent to the answer-extraction model by extract_answer(). It asks for a
# single capitalized option letter, or "-1" when the attempt picked no choice.
# Filled via %-style mapping formatting with the keys "expression1" (question)
# and "expression2" (attempt), so the template itself must contain no other
# literal "%" characters. .strip() removes the leading/trailing newlines.
EXTRACTION_TEMPLATE = r"""
Look at the following question and an attempt by a student and extract which choice the student picked. If the student did not pick any choice, respond with "-1".

The options are labeled with alphabetic letters (A, B, C, D, E, ...).

Examples:

    Question: ...
    Attempt: Answer: **A**

A

    Question: ...
    Attempt: ...The answer is therefore Elephant...

B

    Question: ...
    Attempt: Answer: None of the above

-1

    Question: ...
    Attempt: ...Answer: D), because...

D

---

YOUR TASK

Respond only with the capitalized alphabetic letter (without quotes) or -1. Do not include a rationale.

    Question: %(expression1)s
    Attempt: %(expression2)s
""".strip()

def extract_answer(sampler, question: str, attempt: str):
    """Ask ``sampler`` which option letter ``attempt`` picked for ``question``.

    EXTRACTION_TEMPLATE is filled with the question and the attempt and sent as
    a single user message. The sampler's reply is returned unchanged; the
    template asks it to be a capital letter or "-1".
    """
    substitutions = {"expression1": question, "expression2": attempt}
    user_turn = {"content": EXTRACTION_TEMPLATE % substitutions, "role": "user"}
    return sampler([user_turn])

class ChatCompletionSampler:
    """
    Sample from OpenAI's chat completion API
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        system_message: str | None = None,
        temperature: float = 0.5,
        max_tokens: int = 1024,
    ):
        self.api_key_name = "OPENAI_API_KEY"
        # 请根据实际情况修改 Endpoint 和 Key
        self.client = openai.AzureOpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
            api_version="2025-01-01-preview",
            azure_endpoint="https://sub1-ai.openai.azure.com"
        )
        self.model = "gpt-4o-mini-2"
        self.system_message = "You are a helpful and harmless assistant."
        self.temperature = temperature
        self.max_tokens = max_tokens

    def __call__(self, message_list) -> str:
        if self.system_message:
            message_list = [{"role": "system", "content": self.system_message}] + message_list
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                return response.choices[0].message.content
            except Exception as e:
                exception_backoff = 2**trial
                print(f"Rate limit or Error: wait {exception_backoff}s. Error: {e}")
                time.sleep(exception_backoff)
                trial += 1
                if trial > 5:
                    return ""

# -----------------------------------------------------------------------------
# 2. 处理结果逻辑
# -----------------------------------------------------------------------------
def _strip_chat_prefix(text: str) -> str:
    """Return the part of a generation that follows its chat-template marker.

    If ``"<|im_start|>answer\\n"`` occurs, everything after its last occurrence
    is returned as-is. Otherwise, if a bare ``"<|im_start|>"`` occurs, the text
    after its last occurrence is returned without its first line (presumably
    the turn's role name -- TODO confirm against the chat template).
    """
    answer_tag = "<|im_start|>answer\n"
    turn_tag = "<|im_start|>"
    if answer_tag in text:
        # The tag already ends with "\n": the next line is answer text, keep it.
        return text.split(answer_tag)[-1]
    if turn_tag in text:
        text = text.split(turn_tag)[-1]
        if "\n" in text:
            text = "\n".join(text.split("\n")[1:])
    return text


def _as_choice_letter(text: str) -> str:
    """Strip ``text`` and upper-case it when it is a single ASCII letter."""
    text = text.strip()
    if len(text) == 1 and text.lower() in string.ascii_lowercase:
        return text.upper()
    return text


def process_results(doc: dict, results: List[str]) -> Dict[str, Any]:
    """Score a doc's generation(s) against its gold answer letter.

    Args:
        doc: Processed doc with ``"answer"`` (gold letter) and, when the
            PROCESSOR env var is "gpt-4o-mini", ``"Question"`` and
            ``"formatted_options_api"``.
        results: ``[text]`` for a single generation, or ``[[text_1, ..., text_n]]``
            for n sampled generations.

    Returns:
        ``exact_match`` (0/1 for the first generation) and ``extracted_answers``.
        For nested results also ``exact_matches`` (one 0/1 per generation) and
        ``cov@k`` / ``maj@k`` for every power of two 2 <= k <= n: whether any of
        the first k answers is correct, and whether their most common answer is.

    Answers that are not a valid letter are sent to the extraction model when
    PROCESSOR is "gpt-4o-mini". If still invalid, they default to "A" and a
    warning is printed.
    """
    exact_matches: List[int] = []
    n_res_list: List[int] = []
    metrics: Dict[str, Any] = {"exact_match": None, "extracted_answers": []}

    if isinstance(results[0], list):
        results = results[0]
        # Powers of two up to the sample count, e.g. 64 samples -> [2, 4, ..., 64].
        n_res_list = [2**i for i in range(1, len(results).bit_length())]
        metrics = {
            **metrics,
            "exact_matches": exact_matches,
            **{f"cov@{n}": -1 for n in n_res_list},
            **{f"maj@{n}": -1 for n in n_res_list},
        }

    sampler = None
    question = ""
    if os.getenv("PROCESSOR", "") == "gpt-4o-mini":
        sampler = ChatCompletionSampler(model="gpt-4o-mini")
        question = QUERY_TEMPLATE_API.format(Question=doc["Question"], Options=doc["formatted_options_api"])

    for i, raw in enumerate(results, start=1):
        a = _strip_chat_prefix(raw)
        if (box := last_boxed_only_string(a)) is not None:
            a = remove_boxed(box)
        elif (matches := re.findall(ANSWER_PATTERN, a, re.DOTALL)):
            # With DOTALL the greedy capture may run to the end of the text.
            a = matches[-1]
        a = _as_choice_letter(a)

        if a not in VALID_LETTERS and sampler is not None:
            # Free-form answer: ask the extraction model to map it to a letter.
            a = _as_choice_letter(extract_answer(sampler, question, a) or "")

        if a not in VALID_LETTERS:
            print(f"Warning: Default to A as given {raw[:50]}... extracted {a}")
            a = "A"

        metrics["extracted_answers"].append(a)
        is_correct = int(a == doc["answer"])
        exact_matches.append(is_correct)

        if i == 1:
            metrics["exact_match"] = is_correct
        # n_res_list is empty for non-nested results and starts at 2 otherwise.
        if i in n_res_list:
            metrics[f"cov@{i}"] = int(any(exact_matches))
            majority = Counter(metrics["extracted_answers"]).most_common(1)[0][0]
            metrics[f"maj@{i}"] = int(majority == doc["answer"])

    return metrics

# -----------------------------------------------------------------------------
# 3. 核心修改：process_docs 适配新数据结构 + 多进程加速
# -----------------------------------------------------------------------------
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Map raw GPQA rows to the doc format used by prompting and scoring.

    Each row is read for ``"question"``, ``"options"`` (option strings, used in
    their stored order) and the gold answer. The gold answer is taken from
    ``"answer_letter"`` when present and non-empty. Otherwise it is the letter
    of the option whose stripped text equals the stripped ``"answer"``.
    ``"uuid"`` is optional.

    The mapping returns ``"Question"``, ``"formatted_options"`` ("A) ..."
    lines), ``"formatted_options_api"`` ("(A) ..." lines), ``"uuid"``, and
    ``"answer"`` set to the gold letter. If no gold letter can be determined,
    "A" is used and a RuntimeWarning is issued.
    """

    def _process_doc(doc):
        options = doc["options"]
        # Options are deliberately not shuffled: the dataset is expected to be
        # randomized already, so letters follow the stored order.

        answer_letter = doc.get("answer_letter")
        if not answer_letter:
            answer_letter = None
            correct_text = doc["answer"]
            if isinstance(correct_text, str):
                target = correct_text.strip()
                for idx, opt in enumerate(options):
                    if isinstance(opt, str) and opt.strip() == target:
                        answer_letter = chr(ord("A") + idx)
                        break
            if answer_letter is None:
                import warnings

                # A wrong gold label silently corrupts accuracy, so warn. Under
                # the default warnings filter this fixed message is shown once
                # per worker process, not once per row.
                warnings.warn(
                    "process_docs: gold answer text not found among options; "
                    "defaulting the gold answer to 'A'.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                answer_letter = "A"

        formatted_options = "\n".join(
            f"{chr(ord('A') + i)}) {opt}" for i, opt in enumerate(options)
        )
        formatted_options_api = "\n".join(
            f"({chr(ord('A') + i)}) {opt}" for i, opt in enumerate(options)
        )

        return {
            "Question": doc["question"],
            "formatted_options": formatted_options,
            "formatted_options_api": formatted_options_api,
            "answer": answer_letter,
            "uuid": doc.get("uuid", ""),
        }

    # Leave one core for the system, but use at least one process and never
    # more processes than there are rows.
    num_proc = max(1, (os.cpu_count() or 1) - 1)
    num_proc = min(num_proc, max(1, len(dataset)))

    print(f"Starting data processing with {num_proc} processes...")

    return dataset.map(
        _process_doc,
        num_proc=num_proc,
        desc="Formatting Options & Answers",
    )

# -----------------------------------------------------------------------------
# 4. 文本生成函数 & 辅助工具
# -----------------------------------------------------------------------------
def doc_to_text_gpqa(doc: dict) -> str:
    """Render the model prompt for one processed doc.

    Fills QUERY_TEMPLATE's ``{Question}`` and ``{Options}`` fields from the
    doc's ``"Question"`` and ``"formatted_options"`` entries.
    """
    fields = {"Question": doc["Question"], "Options": doc["formatted_options"]}
    return QUERY_TEMPLATE.format(**fields)

def last_boxed_only_string(string: str) -> Optional[str]:
    """Return the last ``\\boxed``/``\\fbox`` expression in ``string``, or None.

    If ``"\\boxed "`` (with a space) appears anywhere, returns ``"\\boxed "``
    plus the text after its last occurrence, up to the next ``"$"``.
    Otherwise, starting at the last ``"\\boxed"`` (else the last ``"\\fbox"``),
    returns the span through the brace that closes the first ``{``. Returns
    None when no marker exists or the braces never balance.
    """
    spaced_marker = "\\boxed "
    if spaced_marker in string:
        tail = string.split(spaced_marker)[-1]
        return spaced_marker + tail.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None

def remove_boxed(s: str) -> str:
    """Return the contents of a ``\\boxed``/``\\fbox`` expression.

    Handles the three shapes produced by ``last_boxed_only_string``:
    ``"\\boxed X"`` -> ``"X"``, ``"\\boxed{X}"`` -> ``"X"`` and
    ``"\\fbox{X}"`` -> ``"X"``. Any other input is sliced as if it started
    with ``"\\boxed{"`` and ended with ``"}"``, for backward compatibility.
    """
    if "\\boxed " in s:
        return s[len("\\boxed "):]
    # "\fbox{" is one character shorter than "\boxed{", so slice by the actual prefix.
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left) and s.endswith("}"):
            return s[len(left):-1]
    return s[len("\\boxed{"):-1]