import os
import re
import json
import multiprocessing
import regex

from tqdm import tqdm

from .eval import BaseEvaluator
from utils.analyze_answer import compare_answer_golden
from utils import build_prompt, generate_hash, check_empty_result, get_client, get_response
from common.constants import REASONING_PROCESS_REGS, ANSWER_REGS
from datasets import load_dataset

class MATHEvaluator(BaseEvaluator):
    def package_params(self):
        """Collect query parameters for every test problem that still needs a response.

        Loads the dataset named by ``self.dataset_name_from_hf`` and walks its
        ``test`` split. For each problem a prompt is built and a hash of the
        question text is used as the raw-result file name. A parameter tuple
        is appended to ``self.params`` when:

        * no raw-result file exists for the question yet, or
        * the existing file cannot be parsed as JSON (for example, it was
          truncated by an interrupted run), or
        * ``check_empty_result`` reports the stored result as empty.

        Each appended tuple has the layout
        ``(prompt, question, golden, golden_solution, subject, level,
        question_hash, model_name, temperature, top_p)``, which is what
        ``query_pipeline`` unpacks.
        """
        self.logger.info("Packaging Params")
        dataset = load_dataset(self.dataset_name_from_hf, cache_dir=self.huggingface_cache)
        test_split = dataset['test']
        # `dataset` is indexed by split name, so len(dataset) would count splits;
        # the progress bar must be sized by the number of test problems instead.
        with tqdm(total=len(test_split), desc='Packaging Params') as progress_bar:
            for data in test_split:
                question = data['problem']
                full_prompt = build_prompt(
                    question,
                    include_cot=self.cot_reasoning,
                    question_type=self.question_type,
                    answer_type=self.answer_type,
                )
                prompt = full_prompt['prompt']
                golden = data['answer']
                golden_solution = data['solution']
                subject = data['subject']
                level = data['level']
                question_hash = generate_hash(question)
                result_path = self.raw_results_dir.joinpath(f"{question_hash}.json")

                needs_query = True
                if result_path.exists():
                    try:
                        with open(result_path, "r", encoding="utf-8") as f:
                            existing_data = json.load(f)
                    except json.JSONDecodeError as e:
                        # An unreadable cache file is treated like a missing one
                        # so the question is simply queried again.
                        self.logger.error(
                            f"Packaging-{question_hash}: unreadable cached result, re-querying: {e}"
                        )
                    else:
                        needs_query = check_empty_result(existing_data)

                if needs_query:
                    self.params.append((prompt, question, golden, golden_solution, subject, level, question_hash,
                                        self.model_name, self.temperature, self.top_p))
                progress_bar.update(1)


    def run_query(self):
        """Send every packaged parameter tuple to the model and store the raw results.

        ``self.query_pipeline`` is mapped over ``self.params`` in a
        ``multiprocessing.Pool`` of ``self.num_workers`` processes. Results are
        consumed in completion order; each one is written to
        ``<raw_results_dir>/<question_hash>.json`` as
        ``{question_hash: {"result": result}}``, whether or not the query
        succeeded. The query settings of the first completed result are logged
        once, and success/failure counts are logged at the end.
        """
        settings_logged = False
        n_ok = 0
        n_failed = 0
        self.logger.info("Getting Model's Responses")
        with multiprocessing.Pool(self.num_workers) as pool, \
                tqdm(total=len(self.params), desc='Querying Models') as progress_bar:
            outcomes = pool.imap_unordered(self.query_pipeline, self.params)
            for question_hash, result, query_settings, query_flag in outcomes:
                payload = {str(question_hash): {"result": result}}
                output_path = self.raw_results_dir.joinpath(f"{question_hash}.json")
                with open(output_path, "w") as f:
                    json.dump(payload, f, indent=4)

                if query_flag:
                    n_ok += 1
                else:
                    n_failed += 1

                progress_bar.set_description(f'Querying - {question_hash}')
                progress_bar.update(1)

                # Log the effective query settings only for the first result.
                if not settings_logged:
                    self.logger.info(f'query_settings: {query_settings}')
                    settings_logged = True
        self.logger.info("Getting Responses Complete")
        self.logger.info(f"Success: {n_ok}, Failure: {n_failed}, Total: {len(self.params)}")


    def query_pipeline(self, params):
        """Query the model for a single problem and parse its response.

        Args:
            params: A 10-tuple ``(prompt, question, golden, golden_solution,
                subject, level, question_hash, model_name, temperature,
                top_p)``, as appended to ``self.params`` by ``package_params``.

        Returns:
            A tuple ``(question_hash, result, query_settings, query_flag)``.
            On success ``result`` is a dict holding the prompt, the raw model
            response, the reasoning/answer parts extracted with
            ``REASONING_PROCESS_REGS`` / ``ANSWER_REGS`` (``None`` when a
            pattern does not match), the reference data and the sampling
            settings, and ``query_flag`` is ``True``. If anything raises while
            querying or parsing, ``result`` is ``{}``, ``query_flag`` is
            ``False`` and ``query_settings`` is ``{}`` unless the settings had
            already been obtained from ``get_response``.

        Raises:
            ValueError: If ``params`` does not contain exactly ten items.
        """
        # Unpack outside the try block so question_hash is always bound for the
        # error log and the return value; a malformed tuple fails loudly here
        # instead of triggering an UnboundLocalError in the except handler.
        (prompt, question, golden, golden_solution, subject, level,
         question_hash, model_name, temperature, top_p) = params

        query_flag = False
        query_settings = None
        try:
            client = get_client(
                model_name=model_name,
            )
            model_response, model_intrinsic_reasoning, usage, query_settings = get_response(
                client=client,
                model_name=model_name,
                prompt=prompt,
                temperature=temperature,
                top_p=top_p,
                max_tokens=self.max_tokens,
                thinking_budget=self.thinking_budget,
                enable_intrinsic_reasoning=self.enable_intrinsic_reasoning,
            )

            reasoning_process_match = re.search(REASONING_PROCESS_REGS, model_response, re.DOTALL)
            answer_match = re.search(ANSWER_REGS, model_response, re.DOTALL)
            reasoning_process_text = reasoning_process_match.group(1) if reasoning_process_match else None
            answer_text = answer_match.group(1) if answer_match else None
            result = {
                "prompt": prompt,
                "question": question,
                "model_response": model_response,
                "reasoning_process": reasoning_process_text.strip() if reasoning_process_text else None,
                "intrinsic_reasoning": model_intrinsic_reasoning.strip() if model_intrinsic_reasoning else None,
                "enable_intrinsic_reasoning": self.enable_intrinsic_reasoning,
                "cot_reasoning": self.cot_reasoning,
                "model_answer": answer_text.strip() if answer_text else None,
                "golden": golden,
                "golden_solution": golden_solution,
                "subject": subject,
                "level": level,
                "model_name": model_name,
                "task": "math",
                "temperature": temperature,
                "top_p": top_p,
                "usage": usage,
            }
            query_flag = True
        except Exception as e:
            # Best-effort worker: a failed query yields an empty result that is
            # still returned to the caller rather than aborting the whole pool.
            result = {}
            if query_settings is None:
                query_settings = {}
            self.logger.error(f"Querying-{question_hash}: Error: {e}")
        return question_hash, result, query_settings, query_flag


    def process_results(self):
        """Score every stored raw result and write the processed summaries.

        Each ``*.json`` file in ``self.raw_results_dir`` is loaded; its file
        stem is used as the question hash and as the top-level key of the
        stored dict. For every non-empty result:

        * the reference answer is wrapped in ``\\boxed{...}`` (in place),
        * the model answer is taken from ``model_answer`` or, when that is
          empty, from the last ``\\boxed{...}`` expression found in
          ``model_response``,
        * ``compare_answer_golden`` is called and its return value stored
          under ``'evaluation'``.

        Files whose ``result`` is empty (as written for failed queries) are
        skipped and excluded from the accuracy denominator; the number skipped
        is logged. All evaluated entries are written to
        ``processed_results.json`` and the incorrect ones to
        ``processed_error_results.json`` inside ``self.processed_results_dir``.
        Accuracy is logged as correct / evaluated, or ``0.0`` when nothing was
        evaluated.
        """
        self.logger.info("Processing Results")
        # Only JSON files are results; counting them once keeps the progress
        # bar and the accuracy denominator consistent with what is evaluated.
        result_files = sorted(
            path for path in self.raw_results_dir.iterdir()
            if path.is_file() and path.suffix == '.json'
        )
        acc = 0
        total = 0
        skipped = 0
        processed_results_list = []
        processed_error_results_list = []
        for result_file in tqdm(result_files, desc='Processing Results'):
            question_hash = str(result_file.stem)
            with open(result_file, "r", encoding="utf-8") as f:
                result_data = json.load(f)
            entry = result_data[question_hash]
            result = entry['result']
            if not result:
                # Failed queries are stored with an empty result and have no
                # 'golden' / 'model_answer' fields to evaluate.
                skipped += 1
                continue

            result['golden'] = '\\boxed{' + result['golden'] + '}'
            model_answer = result['model_answer'] or self._extract_last_boxed(result['model_response'])
            entry['evaluation'] = compare_answer_golden(
                question_hash=question_hash,
                dataset='MATH-500',
                model_answer=model_answer,
                golden_answer=result['golden'],
                logger=self.logger,
            )
            total += 1
            if entry['evaluation']['correct']:
                acc += 1
            else:
                processed_error_results_list.append(result_data)
            processed_results_list.append(result_data)

        with open(self.processed_results_dir.joinpath("processed_results.json"), "w", encoding="utf-8") as f:
            json.dump(processed_results_list, f, indent=4)
        with open(self.processed_results_dir.joinpath("processed_error_results.json"), "w", encoding="utf-8") as f:
            json.dump(processed_error_results_list, f, indent=4)

        if skipped:
            self.logger.error(f"Skipped {skipped} empty result file(s); they are excluded from the accuracy")
        # Guard against division by zero when there is nothing to evaluate.
        accuracy = acc / total if total else 0.0
        self.logger.info(f"Accuracy: {accuracy}")
        self.logger.info("Processing Results Complete")

    @staticmethod
    def _extract_last_boxed(text):
        r"""Return the last ``\boxed{...}`` expression in *text*, or ``None``.

        The returned string includes the ``\boxed`` prefix and its braces.
        Group 1 of the pattern matches a balanced brace group and recurses
        into itself via ``(?1)``, so contents with ordinary nested braces such
        as ``\boxed{\frac{1}{2}}`` are matched in full.
        """
        if not text:
            return None
        matches = [m.group(0) for m in regex.finditer(r'\\boxed(\{(?:[^{}]|(?1))*\})', text)]
        return matches[-1] if matches else None