from dataclasses import dataclass
from datasets import load_dataset,concatenate_datasets
from tqdm import tqdm
from src import prompt,math_utils
import re
from dataclasses import asdict
import random


@dataclass
class Question:
    """One benchmark problem paired with its reference answer."""

    # Identifier assigned by the loader that created the record.
    id: int
    # Problem statement as presented to the model.
    origin_question: str
    # Reference answer the model output is compared against.
    correct_answer: str


@dataclass
class QuestionQuery:
    """A single question together with the prompt built for it."""

    id: str
    origin_question: str
    origin_correct_answer: str
    prompt: str

    def __json__(self):
        """Return the record as a JSON-serializable dictionary."""
        as_mapping = asdict(self)
        return as_mapping

@dataclass
class QuestionOutput:
    """A single question, the prompt used for it, and the generated text."""

    id: str
    origin_correct_answer: str
    origin_question: str
    prompt: str
    generated_text: str

    def __json__(self):
        """Return the record as a JSON-serializable dictionary."""
        as_mapping = asdict(self)
        return as_mapping

@dataclass
class BatchPrompt:
    """A prompt covering one or more questions, ready to send to a model."""

    # Ids of the questions covered by this prompt.
    id_list: list[int]
    # Question texts, in the same order as ``id_list``.
    origin_question_list: list[str]
    # Reference answers, in the same order as ``id_list``.
    origin_correct_answer_list: list[str]
    # User-message content before any chat template is applied.
    prompt: str
    # Final string handed to the model (chat-template output).
    text: str

@dataclass
class BatchOutput:
    """Generation result for one batched prompt."""

    id_list: list[int]
    origin_question_list: list[str]
    origin_correct_answer_list: list[str]
    prompt: str
    generated_text: str
    # NOTE(review): presumably the prompt before further processing; confirm with the producer.
    origin_prompt: str
    # NOTE(review): presumably the token count of ``generated_text``; confirm with the producer.
    generated_text_token_len: int

    def __json__(self):
        """Return the record as a JSON-serializable dictionary."""
        as_mapping = asdict(self)
        return as_mapping

# Path prefix used to build the default ``data_path`` of every loader in this
# module (each default is f'{data_set_dir}/<dataset-specific suffix>').
data_set_dir="/dataset"


def build_distill_prompt(tokenizer, question_list: list[str]) -> list[str]:
    """Build one distillation prompt string per question.

    Each question gets a fixed "reason step by step / \\boxed{}" instruction
    appended. When a tokenizer is supplied, the result is wrapped in a
    system + user conversation and rendered through its chat template.

    Args:
        tokenizer: Object exposing ``apply_chat_template``, or ``None`` to
            return the plain instruction-augmented question text.
        question_list: Raw question texts.

    Returns:
        Prompt strings in the same order as ``question_list``.
    """
    instruction = "\n\nPlease reason step by step, and put your final answer within \\boxed{}."
    prompts = []
    for raw_question in question_list:
        user_content = raw_question + instruction
        if tokenizer is None:
            # No chat template available: use the bare text.
            prompts.append(user_content)
            continue
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]
        rendered = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        prompts.append(rendered)
    return prompts

def build_eval_prompt(tokenizer, question_list: list[Question]) -> list[BatchPrompt]:
    """Build one single-question evaluation prompt per ``Question``.

    The question text gets a fixed "reason step by step / \\boxed{}"
    instruction appended, is placed in a system + user conversation, and is
    rendered through the tokenizer's chat template.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        question_list: Questions to build prompts for.

    Returns:
        One ``BatchPrompt`` per question, each holding single-element lists.
    """
    instruction = "\n\nPlease reason step by step, and put your final answer within \\boxed{}."

    def _to_batch_prompt(item):
        # ``prompt`` keeps the raw user content; ``text`` is the rendered chat.
        user_content = item.origin_question + instruction
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]
        rendered = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        return BatchPrompt(
            id_list=[item.id],
            origin_question_list=[item.origin_question],
            origin_correct_answer_list=[item.correct_answer],
            prompt=user_content,
            text=rendered,
        )

    return [_to_batch_prompt(item) for item in question_list]

def build_prompts(tokenizer, questions: list[Question], batch_size=1, language='zh') -> list[BatchPrompt]:
    """Group questions into batches and build one chat prompt per batch.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        questions: Questions to batch, in order.
        batch_size: Number of questions per prompt; the last batch may be
            smaller.
        language: Forwarded unchanged to ``prompt.batch_chat_prompt``.

    Returns:
        One ``BatchPrompt`` per batch, in input order.
    """
    batches = []
    for start in tqdm(range(0, len(questions), batch_size)):
        chunk = questions[start:start + batch_size]
        question_texts = [item.origin_question for item in chunk]
        user_content = prompt.batch_chat_prompt(question_texts, language)
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]
        rendered = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        batches.append(
            BatchPrompt(
                id_list=[item.id for item in chunk],
                origin_question_list=question_texts,
                origin_correct_answer_list=[item.correct_answer for item in chunk],
                prompt=user_content,
                text=rendered,
            )
        )
    return batches


def load_math_500(data_path=f'{data_set_dir}/HuggingFaceH4/MATH-500') -> list[Question]:
    """Load the ``test`` split of the dataset at ``data_path`` as questions.

    Args:
        data_path: Path or identifier passed to ``load_dataset``.

    Returns:
        One ``Question`` per row; ``id`` is the row index, the text comes
        from the ``problem`` field and the answer from the ``answer`` field.
    """
    rows = load_dataset(data_path)['test']
    return [
        Question(
            id=idx,
            origin_question=row['problem'],
            correct_answer=row['answer'],
        )
        for idx, row in enumerate(rows)
    ]

def load_gsm_8k(data_path=f'{data_set_dir}/openai/gsm8k',data_split='train')->list[Question]:
    """Load GSM8K questions with the final numeric answer extracted.

    The reference answer is the number that follows the ``####`` marker in
    each row's ``answer`` field. Thousands separators are stripped and a
    decimal part is kept, so ``"#### 1,080"`` yields ``"1080"`` instead of
    being truncated to ``"1"``.

    Args:
        data_path: Path or identifier passed to ``load_dataset`` (config
            ``'main'``).
        data_split: Name of the split to read.

    Returns:
        One ``Question`` per row with an extractable answer; ``id`` is the
        row index in the split. Rows without a ``####`` number are printed
        and skipped.
    """
    # Optional sign, digits with optional commas, optional decimal part.
    answer_pattern = re.compile(r"####\s*(-?[\d,]*\.?\d+)")
    rows = load_dataset(data_path, 'main')[data_split]
    questions = []
    for idx, row in enumerate(rows):
        found = answer_pattern.search(row['answer'])
        if found is None:
            # Surface the unparsable answer so it can be inspected.
            print(row['answer'])
            continue
        questions.append(
            Question(
                id=idx,
                origin_question=row['question'],
                correct_answer=found.group(1).replace(',', ''),
            )
        )
    return questions

def load_aqua_rat(data_path=f'{data_set_dir}/deepmind/aqua_rat', data_split='train') -> list[Question]:
    """Load AQuA-RAT rows as multiple-choice questions.

    The question text is the ``question`` field followed by the formatted
    ``options`` field and an instruction to output only the option letter.

    Args:
        data_path: Path or identifier passed to ``load_dataset`` (config
            ``'raw'``).
        data_split: Name of the split to read.

    Returns:
        One ``Question`` per row; ``id`` is the row index and the answer is
        the ``correct`` field.
    """
    template = '''{origin_question} The answer options are {options}. Please only output the option letter.'''
    rows = load_dataset(data_path, 'raw')[data_split]
    questions = []
    for idx, row in enumerate(rows):
        rendered = template.format(origin_question=row['question'], options=row['options'])
        questions.append(
            Question(id=idx, origin_question=rendered, correct_answer=row['correct'])
        )
    return questions

def load_omni_math(data_path=f'{data_set_dir}/KbsdJames/Omni-MATH', data_split='test') -> list[Question]:
    """Load every row of the Omni-MATH split as a question.

    No filtering on the answer format is applied.

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Name of the split to read.

    Returns:
        One ``Question`` per row, with ids numbered from 1 in row order;
        text comes from ``problem`` and the answer from ``answer``.
    """
    rows = load_dataset(data_path)[data_split]
    return [
        Question(
            id=number,
            origin_question=row['problem'],
            correct_answer=row['answer'],
        )
        for number, row in enumerate(rows, start=1)
    ]

def load_GAIR_LIMO(data_path=f'{data_set_dir}/GAIR/LIMO', data_split='train') -> list[Question]:
    """Load LIMO rows whose answer passes ``math_utils.is_number``.

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Name of the split to read.

    Returns:
        One ``Question`` per kept row. Ids are numbered from 1 over the kept
        rows only, so they are contiguous even when rows are skipped.
    """
    rows = load_dataset(data_path)[data_split]
    # The check runs on the stringified answer; the stored answer is the raw field.
    numeric_rows = (row for row in rows if math_utils.is_number(f"{row['answer']}"))
    return [
        Question(
            id=number,
            origin_question=row['question'],
            correct_answer=row['answer'],
        )
        for number, row in enumerate(numeric_rows, start=1)
    ]


def load_MATH_lighteval_dataset(data_path=f'{data_set_dir}/DigitalLearningGmbH/MATH-lighteval',data_split='all'):
    """Load the raw MATH-lighteval dataset (config ``'default'``).

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Split name to select, or ``'all'`` to return every split.

    Returns:
        The object returned by ``load_dataset`` when ``data_split`` is
        ``'all'``, otherwise the selected split. This is the raw dataset,
        not a list of ``Question`` objects.
    """
    splits = load_dataset(data_path, 'default')
    if data_split == 'all':
        return splits
    return splits[data_split]

def load_MATH_lighteval(data_path=f'{data_set_dir}/DigitalLearningGmbH/MATH-lighteval', data_split='train') -> list[Question]:
    """Load one MATH-lighteval split as questions.

    Args:
        data_path: Path or identifier forwarded to
            ``load_MATH_lighteval_dataset``.
        data_split: Name of the split to read.

    Returns:
        One ``Question`` per row; ``id`` is the row index, text comes from
        ``problem`` and the reference answer is the ``solution`` field.
    """
    rows = load_MATH_lighteval_dataset(data_path, data_split)
    questions = []
    for idx in range(len(rows)):
        row = rows[idx]
        questions.append(
            Question(
                id=idx,
                origin_question=row['problem'],
                correct_answer=row['solution'],
            )
        )
    return questions
def load_lighteval_extract_int_dataset(data_path=f'{data_set_dir}/lighteval_extract_int',data_split='all'):
    """Load the raw ``lighteval_extract_int`` dataset (config ``'default'``).

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Split name to select, or ``'all'`` to return every split.

    Returns:
        The object returned by ``load_dataset`` when ``data_split`` is
        ``'all'``, otherwise the selected split. This is the raw dataset,
        not a list of ``Question`` objects.
    """
    splits = load_dataset(data_path, 'default')
    if data_split == 'all':
        return splits
    return splits[data_split]

def load_lighteval_extract_int(data_path=f'{data_set_dir}/lighteval_extract_int',data_split='train')->list[Question]:
    """Load one ``lighteval_extract_int`` split as questions.

    Args:
        data_path: Path or identifier forwarded to
            ``load_lighteval_extract_int_dataset``.
        data_split: Name of the split to read.

    Returns:
        One ``Question`` per row; ``id`` is the row index, text comes from
        ``problem`` and the answer is the stringified ``answer`` field.
    """
    # Use this dataset's own raw loader rather than the MATH-lighteval one.
    rows = load_lighteval_extract_int_dataset(data_path, data_split)
    questions = []
    for idx in range(len(rows)):
        row = rows[idx]
        questions.append(
            Question(
                id=idx,
                origin_question=row['problem'],
                correct_answer=f"{row['answer']}",
            )
        )
    return questions

def load_aime_2024_dataset(data_path=f'{data_set_dir}/HuggingFaceH4/aime_2024',data_split='train'):
    """Load one raw split of the AIME 2024 dataset.

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Name of the split to select.

    Returns:
        The selected split as returned by ``load_dataset``. This is the raw
        dataset, not a list of ``Question`` objects.
    """
    return load_dataset(data_path)[data_split]

def load_aime_2024(data_path=f'{data_set_dir}/HuggingFaceH4/aime_2024', data_split='train', rep=10) -> list[Question]:
    """Load AIME 2024 questions, repeating the whole set ``rep`` times.

    Args:
        data_path: Path or identifier forwarded to
            ``load_aime_2024_dataset``.
        data_split: Name of the split to read.
        rep: Number of times the full question set is appended.

    Returns:
        ``rep`` consecutive copies of the question list. Within each copy
        ``id`` is the row index, so ids repeat across copies.
    """
    rows = load_aime_2024_dataset(data_path, data_split)
    return [
        Question(
            id=idx,
            origin_question=row['problem'],
            correct_answer=f"{row['answer']}",
        )
        for _ in range(rep)
        for idx, row in enumerate(rows)
    ]

def load_aime_2025_dataset(data_path=f'{data_set_dir}/opencompass/AIME2025', data_split='test'):
    """Load both AIME 2025 configs and concatenate them.

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Name of the split to take from each config.

    Returns:
        The ``AIME2025-I`` split followed by the ``AIME2025-II`` split,
        joined with ``concatenate_datasets``.
    """
    parts = [
        load_dataset(data_path, config_name)[data_split]
        for config_name in ('AIME2025-I', 'AIME2025-II')
    ]
    return concatenate_datasets(parts)

def load_aime_2025(data_path=f'{data_set_dir}/opencompass/AIME2025', data_split='test', rep=10) -> list[Question]:
    """Load AIME 2025 questions, repeating the whole set ``rep`` times.

    Args:
        data_path: Path or identifier forwarded to
            ``load_aime_2025_dataset``.
        data_split: Name of the split to read.
        rep: Number of times the full question set is appended.

    Returns:
        ``rep`` consecutive copies of the question list. Within each copy
        ``id`` is the row index, so ids repeat across copies.
    """
    rows = load_aime_2025_dataset(data_path, data_split)
    return [
        Question(
            id=idx,
            origin_question=row['question'],
            correct_answer=f"{row['answer']}",
        )
        for _ in range(rep)
        for idx, row in enumerate(rows)
    ]


def load_deepscaler_dataset(data_path=f'{data_set_dir}/nanoverl/deepscaler',data_split='train'):
    """Load one raw split of the deepscaler dataset.

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Name of the split to select.

    Returns:
        The selected split as returned by ``load_dataset``. This is the raw
        dataset, not a list of ``Question`` objects.
    """
    return load_dataset(data_path)[data_split]

def load_deepscaler(data_path=f'{data_set_dir}/nanoverl/deepscaler', data_split='train', rep=1) -> list[Question]:
    """Load deepscaler rows whose answer passes ``math_utils.is_number``.

    Args:
        data_path: Path or identifier forwarded to
            ``load_deepscaler_dataset``.
        data_split: Name of the split to read.
        rep: Number of times the filtered set is appended.

    Returns:
        ``rep`` consecutive copies of the kept questions. ``id`` is the
        row's index in the unfiltered split, so ids may have gaps.
    """
    rows = load_deepscaler_dataset(data_path, data_split)
    questions = []
    for _ in range(rep):
        for idx, row in enumerate(rows):
            answer_text = f"{row['answer']}"
            if math_utils.is_number(answer_text):
                questions.append(
                    Question(
                        id=idx,
                        origin_question=row['problem'],
                        correct_answer=answer_text,
                    )
                )
    return questions

def load_deepscaler_dataset_v2(data_path=f'{data_set_dir}/agentica-org/DeepScaleR-Preview-Dataset',data_split='train'):
    """Load one raw split of the DeepScaleR-Preview dataset.

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Name of the split to select.

    Returns:
        The selected split as returned by ``load_dataset``. This is the raw
        dataset, not a list of ``Question`` objects.
    """
    return load_dataset(data_path)[data_split]

def load_deepscaler_v2(data_path=f'{data_set_dir}/agentica-org/DeepScaleR-Preview-Dataset',data_split='train',rep=1)->list[Question]:
    """Load DeepScaleR-Preview rows whose answer passes ``math_utils.is_number``.

    Args:
        data_path: Path or identifier forwarded to
            ``load_deepscaler_dataset_v2``.
        data_split: Name of the split to read.
        rep: Number of times the filtered set is appended.

    Returns:
        ``rep`` consecutive copies of the kept questions. Ids are numbered
        from 1 over the kept rows and restart at 1 for every copy.
    """
    # Use the v2 raw loader that belongs to this dataset, not the v1 one.
    rows = load_deepscaler_dataset_v2(data_path, data_split)
    questions = []
    for _ in range(rep):
        number = 0  # restarts per repetition so each copy carries the same ids
        for row in rows:
            answer_text = f"{row['answer']}"
            if not math_utils.is_number(answer_text):
                continue
            number += 1
            questions.append(
                Question(
                    id=number,
                    origin_question=row['problem'],
                    correct_answer=answer_text,
                )
            )
    return questions

def load_OpenR1_Math(data_path=f'{data_set_dir}/open-r1/OpenR1-Math-220k',data_split='train'):
    """Load one raw split of the OpenR1-Math dataset (config ``'default'``).

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Name of the split to select.

    Returns:
        The selected split as returned by ``load_dataset``. This is the raw
        dataset, not a list of ``Question`` objects.
    """
    return load_dataset(data_path, 'default')[data_split]

def load_DeepMath_103K_dataset(data_path=f'{data_set_dir}/zwhe99/DeepMath-103K',data_split='train'):
    """Load one raw split of the DeepMath-103K dataset.

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Name of the split to select.

    Returns:
        The selected split as returned by ``load_dataset``. This is the raw
        dataset, not a list of ``Question`` objects.
    """
    return load_dataset(data_path)[data_split]

def load_DeepMath_103K(data_path=f'{data_set_dir}/zwhe99/DeepMath-103K',data_split='train',rep=1)->list[Question]:
    """Load DeepMath-103K rows whose answer passes ``math_utils.is_number``.

    Args:
        data_path: Path or identifier forwarded to
            ``load_DeepMath_103K_dataset``.
        data_split: Name of the split to read.
        rep: Number of times the filtered set is appended.

    Returns:
        ``rep`` consecutive copies of the kept questions. ``id`` is the
        row's index in the unfiltered split; text comes from ``question``
        and the answer is the stringified ``final_answer`` field.
    """
    # Use this dataset's own raw loader rather than the deepscaler one.
    rows = load_DeepMath_103K_dataset(data_path, data_split)
    questions = []
    for _ in range(rep):
        for idx, row in enumerate(rows):
            answer_text = f"{row['final_answer']}"
            if not math_utils.is_number(answer_text):
                continue
            questions.append(
                Question(
                    id=idx,
                    origin_question=row['question'],
                    correct_answer=answer_text,
                )
            )
    return questions

def load_AMC_23(data_path=f'{data_set_dir}/knoveleng/AMC-23', data_split='train', rep=10) -> list[Question]:
    """Load AMC-23 questions, repeating the whole set ``rep`` times.

    Args:
        data_path: Path or identifier passed to ``load_dataset``.
        data_split: Name of the split to read.
        rep: Number of times the full question set is appended.

    Returns:
        ``rep`` consecutive copies of the question list. Within each copy
        ``id`` is the row index, so ids repeat across copies.
    """
    rows = load_dataset(data_path)[data_split]
    return [
        Question(
            id=idx,
            origin_question=row['problem'],
            correct_answer=f"{row['answer']}",
        )
        for _ in range(rep)
        for idx, row in enumerate(rows)
    ]

def load_GPQA_diamond(data_path=f'{data_set_dir}/Idavidrein/gpqa', data_split='train', rep=1) -> list[Question]:
    """Load GPQA-diamond rows as four-option multiple-choice questions.

    For every row the correct answer is placed at a randomly chosen letter
    (via the module-level ``random`` state); the three incorrect answers
    fill the remaining letters in their original order.

    Args:
        data_path: Path or identifier passed to ``load_dataset`` (config
            ``'gpqa_diamond'``).
        data_split: Name of the split to read.
        rep: Number of passes over the split; each pass draws fresh random
            positions.

    Returns:
        One ``Question`` per row per pass; ``id`` is the row index and the
        answer is the letter holding the correct option.
    """
    letters = ['A', 'B', 'C', 'D']
    rows = load_dataset(data_path, 'gpqa_diamond')[data_split]
    questions = []
    for _ in range(rep):
        for idx, row in enumerate(rows):
            answer = random.choice(letters)
            # Incorrect answers are consumed in order as non-answer letters appear.
            wrong_fields = iter(
                ("Incorrect Answer 1", "Incorrect Answer 2", "Incorrect Answer 3")
            )
            pieces = ["Please choose the correct answer from among the following options: \n"]
            for letter in letters:
                if letter == answer:
                    field = 'Correct Answer'
                else:
                    field = next(wrong_fields)
                pieces.append(f"{letter}: {row[field]}\n")
            options_block = "".join(pieces)
            questions.append(
                Question(
                    id=idx,
                    origin_question=f"{row['Question']}\n{options_block}",
                    correct_answer=f"{answer}",
                )
            )
    return questions


# Registry mapping a dataset name to the loader function defined in this module.
# NOTE(review): 'DeepMath_103K_dataset' maps to the raw-dataset loader, which
# returns a dataset split rather than list[Question] like the other entries;
# confirm this is intended.
loaderFuncDict={
    'math_500':load_math_500,
    'gsm_8k':load_gsm_8k,
    'aqua_rat':load_aqua_rat,
    'omni_math':load_omni_math,
    'GAIR_LIMO':load_GAIR_LIMO,
    'lighteval_extract_int':load_lighteval_extract_int,
    'deepscaler_v2':load_deepscaler_v2,
    'DeepMath_103K_dataset':load_DeepMath_103K_dataset,
    'DeepMath_103K':load_DeepMath_103K,
    'AMC_23':load_AMC_23,
    'GPQA_diamond':load_GPQA_diamond,
}