import os
import sys
import re
import random
from modelscope.msdatasets import MsDataset
from .eval_math import math_equal
from tqdm import tqdm
from .math_utils import math_is_correct

from .src.xVerify.model import Model
from .src.xVerify.eval import Evaluator

# Make the directory named by the PROJECT_ROOT environment variable importable.
# Nothing happens when the variable is unset/empty or the path is already listed.
project_root = os.getenv("PROJECT_ROOT")
if project_root and sys.path.count(project_root) == 0:
    sys.path.append(project_root)

# Suffix appended to every problem statement, keyed by prompt style name.
# "raw" leaves the problem untouched; the other entries ask the model to
# present its final answer in a particular format.
additional_prompt = {
    "raw": "",
    "previous": ' the answer should start with "Final Answer".',
    "better": " Please reason step by step, and put your final answer within \\boxed{}.",
    "answer": ' The answer should start with "The answer is:".',
}

class MathDataset(MsDataset):
    """Math benchmark loaded through ModelScope, with answer-scoring helpers.

    The ``test`` split of the ``default`` subset is loaded from
    ``dataset_path``. Every item is indexed by position and is expected to
    expose a ``'problem'`` and a ``'solution'`` entry. All evaluation methods
    pair ``answer[i]`` with ``self.dataset[i]``, so answers must be supplied
    in dataset order.

    NOTE(review): the ``MsDataset`` initializer is not invoked; only the
    attributes assigned in ``__init__`` below are set up by this class.
    """

    def __init__(self, dataset_path, prompt_type="better", shuffle=True, xverify_path=None, device="cuda:0"):
        """Load the dataset and optionally set up the xVerify judge.

        Args:
            dataset_path: Location handed to ``MsDataset.load``.
            prompt_type: Key into ``additional_prompt`` selecting the suffix
                appended to each problem.
            shuffle: When true, the items are materialized into a list and
                shuffled deterministically (seed 42).
            xverify_path: Path or URL of the xVerify model. When ``None`` no
                judge model is created and ``eval_xverify`` is unavailable.
            device: Device string forwarded to the xVerify ``Model``.
        """
        original_ds = MsDataset.load(dataset_path, subset_name='default', split='test')
        if shuffle:
            dataset_list = list(original_ds)
            # A private generator seeded with 42 yields the same order as
            # seeding the module-level RNG, but never disturbs global random
            # state (even if something above raises).
            random.Random(42).shuffle(dataset_list)
            self.dataset = dataset_list
        else:
            self.dataset = original_ds
        self.prompt_type = prompt_type

        # Always define these so a missing judge is detectable explicitly.
        self.model = None
        self.evaluator = None
        if xverify_path is not None:
            self.model = Model(
                model_name="xVerify-9B-C",
                model_path_or_url=xverify_path,
                inference_mode="local",
                device=device,
            )
            self.evaluator = Evaluator(self.model)

    def __getitem__(self, index):
        """Return the dataset item stored at ``index``."""
        return self.dataset[index]

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.dataset)

    @staticmethod
    def _accuracy(correct_count, total_count):
        """Return ``correct_count / total_count``, or 0.0 when nothing was scored."""
        return correct_count / total_count if total_count else 0.0

    def _prompt_for(self, i):
        """Return problem ``i`` followed by the configured prompt suffix."""
        return self.dataset[i]['problem'] + additional_prompt[self.prompt_type]

    def get_prompt(self, index=None):
        """Build prompts for the dataset.

        Args:
            index: ``None`` to build a prompt for every item; otherwise the
                position of the last item to include.

        Returns:
            A list of prompt strings. With an integer ``index`` the list
            covers items ``0`` through ``index`` inclusive (``index + 1``
            prompts), not just the single item.
        """
        count = len(self.dataset) if index is None else index + 1
        return [self._prompt_for(i) for i in range(count)]

    def result_eval(self, answer):
        """Score answers with ``math_equal`` against the reference solutions.

        Args:
            answer: Sequence of model answers, aligned with the dataset order.

        Returns:
            Fraction of answers judged equal to the reference solution;
            0.0 when ``answer`` is empty.
        """
        total_count = len(answer)
        correct_count = sum(
            1
            for i, attempt in enumerate(answer)
            if math_equal(attempt, self.dataset[i]['solution'])
        )
        return self._accuracy(correct_count, total_count)

    def result_eval_reward(self, answer, thinking_chain, reward_func, system_prompt):
        """Score answers and collect reward values split by correctness.

        Answers equal to the literal ``"Error Answer"`` receive no reward and
        are not counted as correct, but they still count toward the accuracy
        denominator.

        Args:
            answer: Sequence of model answers, aligned with the dataset order.
            thinking_chain: Sequence aligned with ``answer``; element ``i`` is
                passed to ``reward_func`` for answer ``i``.
            reward_func: Callable invoked as
                ``reward_func(system_prompt, prompt, thinking_chain[i])``.
            system_prompt: Passed through unchanged to ``reward_func``.

        Returns:
            Dict with ``"accuracy"`` (0.0 when ``answer`` is empty),
            ``"correct_reward"`` and ``"wrong_reward"`` (lists of reward
            values for correct and incorrect answers respectively).
        """
        total_count = len(answer)
        correct_count = 0
        correct_reward = []
        wrong_reward = []
        suffix = additional_prompt[self.prompt_type] if total_count else ""

        for i, attempt in enumerate(answer):
            item = self.dataset[i]
            correct_answer = item['solution']
            if attempt == "Error Answer":
                continue
            is_correct = math_equal(attempt, correct_answer)
            reward = reward_func(system_prompt, item['problem'] + suffix, thinking_chain[i])
            if is_correct:
                correct_count += 1
                correct_reward.append(reward)
            else:
                wrong_reward.append(reward)

        return {
            "accuracy": self._accuracy(correct_count, total_count),
            "correct_reward": correct_reward,
            "wrong_reward": wrong_reward,
        }

    def eval_xverify(self, answer):
        """Score answers with the xVerify judge model.

        Args:
            answer: Sequence of model outputs, aligned with the dataset order.

        Returns:
            Tuple ``(accuracy, correct_list)`` where ``correct_list[i]`` is 1
            when the judge returned ``"Correct"`` for answer ``i`` and 0
            otherwise. Accuracy is 0.0 when ``answer`` is empty.

        Raises:
            RuntimeError: If the instance was created without ``xverify_path``
                and therefore has no evaluator.
        """
        evaluator = getattr(self, "evaluator", None)
        if evaluator is None:
            raise RuntimeError(
                "xVerify evaluator is not configured; "
                "pass xverify_path when constructing MathDataset."
            )

        correct_list = []
        for i, llm_output in enumerate(tqdm(answer, desc="Evaluating xVerify")):
            item = self.dataset[i]
            result = evaluator.single_evaluate(
                question=item['problem'],
                llm_output=llm_output,
                correct_answer=item['solution'],
            )
            correct_list.append(1 if result == "Correct" else 0)
        return self._accuracy(sum(correct_list), len(correct_list)), correct_list

    def eval_math_is_correct(self, answer):
        """Score answers with ``math_is_correct`` against the reference solutions.

        Args:
            answer: Sequence of model outputs, aligned with the dataset order.

        Returns:
            Tuple ``(accuracy, correct_list)`` where ``correct_list[i]`` is 1
            when ``math_is_correct`` returned 1 for answer ``i`` and 0
            otherwise. Accuracy is 0.0 when ``answer`` is empty.
        """
        correct_list = []
        for i, llm_output in enumerate(tqdm(answer, desc="Evaluating math_is_correct")):
            correct_answer = self.dataset[i]['solution']
            result = math_is_correct(llm_output, correct_answer)
            correct_list.append(1 if result == 1 else 0)
        return self._accuracy(sum(correct_list), len(correct_list)), correct_list

if __name__ == '__main__':
    # Manual smoke test: load the dataset from a local path, show the first
    # item, then score one hand-written output with math_is_correct.
    # NOTE(review): this module uses relative imports, so it presumably has to
    # be launched as part of its package (python -m ...) — confirm.
    # NOTE(review): xverify_path="" is not None, so the xVerify model is still
    # constructed even though only eval_math_is_correct is used — confirm intended.
    ds = MathDataset(
        dataset_path='/data1/efficient-reasoning/competition_math',
        xverify_path="",
    )
    # print(ds.get_prompt(index=0))
    print(ds.dataset[0])
    accuracy = ds.eval_math_is_correct(
        ['the answer is $\\boxed{\\pi-2}$ vertical asymptotes']
    )
    # eval_math_is_correct returns (accuracy, per-item flags); show the accuracy.
    print(accuracy[0])