import re
import json
import multiprocessing

from tqdm import tqdm

from .eval import BaseEvaluator
from utils.analyze_answer import compare_answer_golden
from utils import build_prompt, generate_hash, check_empty_result, get_client, get_response
from common.constants import REASONING_PROCESS_REGS, ANSWER_REGS
from datasets import load_dataset


class MMLUReduxEvaluator(BaseEvaluator):
    def package_params(self):
        """Load every configured subset and queue the questions that still need a query.

        Each name in ``self.config_names`` is loaded through
        ``self.load_dataset_from_configs`` in a small process pool. For every
        loaded question a prompt is built and a hash is computed; the question is
        appended to ``self.params`` when no usable result file
        ``<question_hash>.json`` exists in ``self.raw_results_dir`` (the file is
        missing, unreadable as JSON, or flagged by ``check_empty_result``).

        A question hash is queued at most once, including hashes that are
        already present in ``self.params`` when this method is called.

        Each queued entry is the tuple
        ``(prompt, question, golden, subset, question_hash, model_name,
        temperature, top_p)``.
        """
        self.logger.info("Concatenating mmlu-redux from different subsets")
        config_names = list(self.config_names)
        dataset_list = []
        # Cap the pool at 3 workers to avoid Too Many Requests errors.
        pool_size = min(self.num_workers, 3)
        with tqdm(total=len(config_names), desc='Concatenating mmlu-redux from different subsets') as progress_bar:
            with multiprocessing.Pool(pool_size) as pool:
                for dataset in pool.imap_unordered(self.load_dataset_from_configs, config_names):
                    dataset_list.extend(dataset)
                    progress_bar.update(1)

        self.logger.info("Packaging Params")
        # Hashes that are already queued; a set makes each duplicate check O(1)
        # instead of rescanning self.params for every question.
        queued_hashes = {param[4] for param in self.params}
        with tqdm(total=len(dataset_list), desc='Packaging Params') as progress_bar:
            for data in dataset_list:
                question = data['question']
                full_prompt = build_prompt(
                    question,
                    include_cot=self.cot_reasoning,
                    question_type=self.question_type,
                    answer_type=self.answer_type,
                )
                prompt = full_prompt['prompt']
                golden = data['answer']
                subset = data['subset']
                question_hash = generate_hash(question)

                result_path = self.raw_results_dir.joinpath(f"{question_hash}.json")
                if result_path.exists():
                    try:
                        with open(result_path, "r") as f:
                            existing_data = json.load(f)
                    except json.JSONDecodeError:
                        # A truncated file (e.g. from an interrupted run) holds no
                        # usable answer, so the question is queried again.
                        self.logger.warning(f"Unreadable result file, re-querying: {result_path}")
                        needs_query = True
                    else:
                        needs_query = check_empty_result(existing_data)
                else:
                    needs_query = True

                if needs_query:
                    if question_hash in queued_hashes:
                        self.logger.warning(f"Duplicate hash skipped: {question_hash}")
                    else:
                        queued_hashes.add(question_hash)
                        self.params.append((prompt, question, golden, subset, question_hash,
                                            self.model_name, self.temperature, self.top_p))
                progress_bar.update(1)
            self.logger.info(f"Total params: {len(self.params)}")


    def run_query(self):
        """Query the model for every entry of ``self.params`` and persist each result.

        ``self.query_pipeline`` is mapped over ``self.params`` in a process pool
        of ``self.num_workers`` workers. As results arrive (in completion
        order), each one is written to ``<question_hash>.json`` inside
        ``self.raw_results_dir`` and counted as a success or a failure. The
        query settings of the first result received are logged once.
        """
        n_ok = 0
        n_failed = 0
        settings_logged = False
        self.logger.info("Getting Model's Responses")
        with multiprocessing.Pool(self.num_workers) as pool:
            with tqdm(total=len(self.params), desc='Querying Models') as progress_bar:
                outcomes = pool.imap_unordered(self.query_pipeline, self.params)
                for question_hash, result, query_settings, query_flag in outcomes:
                    # One file per question, keyed by the question hash.
                    payload = {str(question_hash): {"result": result}}
                    out_path = self.raw_results_dir.joinpath(f"{question_hash}.json")
                    with open(out_path, "w") as f:
                        json.dump(payload, f, indent=4)

                    if query_flag:
                        n_ok += 1
                    else:
                        n_failed += 1

                    progress_bar.set_description(f'Querying - {question_hash}')
                    progress_bar.update(1)

                    # Log the settings only for the first result received.
                    if not settings_logged:
                        self.logger.info(f'query_settings: {query_settings}')
                        settings_logged = True
        self.logger.info("Getting Responses Complete")
        self.logger.info(f"Success: {n_ok}, Failure: {n_failed}, Total: {len(self.params)}")


    def query_pipeline(self, params):
        query_flag = False
        query_settings = None
        try:
            prompt, question, golden, subset, question_hash, model_name, temperature, top_p = params
            client = get_client(
                model_name=model_name,
            )
            model_response, model_intrinsic_reasoning, usage, query_settings = get_response(
                                                                    client=client,
                                                                    model_name=model_name,
                                                                    prompt=prompt,
                                                                    temperature=temperature,
                                                                    top_p=top_p,
                                                                    max_tokens=self.max_tokens,
                                                                    thinking_budget=self.thinking_budget,
                                                                    enable_intrinsic_reasoning=self.enable_intrinsic_reasoning,
                                                                )
            
            reasoning_process_match = re.search(REASONING_PROCESS_REGS, model_response, re.DOTALL)
            answer_match = re.search(ANSWER_REGS, model_response, re.DOTALL)                
            reasoning_process_text = reasoning_process_match.group(1) if reasoning_process_match else None
            answer_text = answer_match.group(1) if answer_match else None
            result = {
                "prompt": prompt,
                "question": question,
                "model_response": model_response,
                "reasoning_process": reasoning_process_text.strip() if reasoning_process_text else None,
                "intrinsic_reasoning": model_intrinsic_reasoning.strip() if model_intrinsic_reasoning else None,
                "enable_intrinsic_reasoning": self.enable_intrinsic_reasoning,
                "cot_reasoning": self.cot_reasoning,
                "model_answer": answer_text.strip() if answer_text else None,
                "golden": golden,
                "subset": subset,
                "model_name": model_name,
                "task": "general",
                "temperature": temperature,
                "top_p": top_p,
                "usage": usage,
            }
            query_flag = True
        except Exception as e:
            result = {}
            query_settings if query_settings is not None else {} 
            self.logger.error(f"Querying-{question_hash}: Error: {e}")
        return question_hash, result, query_settings, query_flag


    def process_results(self):
        """Score every saved raw result and write the processed summaries.

        Each ``*.json`` file in ``self.raw_results_dir`` is loaded and its
        ``model_answer`` is compared with its ``golden`` answer through
        ``compare_answer_golden``; the outcome is stored under the
        ``'evaluation'`` key of the entry. All evaluated entries are written to
        ``processed_results.json`` and the incorrect ones to
        ``processed_error_results.json`` inside ``self.processed_results_dir``.

        Files whose entry has no usable result (for example a failed query
        saved with an empty ``result``) are skipped and reported in the log.
        The logged accuracy is computed over the evaluated entries only and is
        ``0.0`` when there is nothing to evaluate.
        """
        self.logger.info("Processing Results")
        # Only JSON files are results; sorting makes the output order stable.
        result_files = sorted(
            path for path in self.raw_results_dir.iterdir()
            if path.is_file() and path.suffix == '.json'
        )
        acc = 0
        skipped = 0
        processed_results_list = []
        processed_error_results_list = []
        with tqdm(total=len(result_files), desc='Processing Results') as progress_bar:
            for result_file in result_files:
                progress_bar.update(1)
                question_hash = str(result_file.stem)
                with open(result_file, "r") as f:
                    result_data = json.load(f)

                entry = result_data.get(question_hash) if isinstance(result_data, dict) else None
                result = entry.get('result') if isinstance(entry, dict) else None
                if not isinstance(result, dict) or 'model_answer' not in result or 'golden' not in result:
                    # Nothing to compare against, so the entry cannot be scored.
                    skipped += 1
                    self.logger.warning(f"Skipping result without answer data: {result_file.name}")
                    continue

                evaluation = compare_answer_golden(question_hash=question_hash,
                                                   dataset="mmlu_redux",
                                                   model_answer=result['model_answer'],
                                                   golden_answer=result['golden'],
                                                   logger=self.logger)
                entry['evaluation'] = evaluation
                if evaluation['correct']:
                    acc += 1
                else:
                    processed_error_results_list.append(result_data)
                processed_results_list.append(result_data)
        with open(self.processed_results_dir.joinpath("processed_results.json"), "w") as f:
            json.dump(processed_results_list, f, indent=4)
        with open(self.processed_results_dir.joinpath("processed_error_results.json"), "w") as f:
            json.dump(processed_error_results_list, f, indent=4)

        total = len(processed_results_list)
        if skipped:
            self.logger.warning(f"Skipped {skipped} result file(s) without answer data")
        # Guard against an empty (or entirely unusable) results directory.
        accuracy = acc / total if total else 0.0
        self.logger.info(f"Accuracy: {accuracy}")
        self.logger.info("Processing Results Complete")
        

    def prepare_dataset(self, origin_dataset, config_name):
        dataset = [] 
        int2answer = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F', 6: 'G', 7: 'H', 8: 'I', 9: 'J'}
        Choice2answer = {'A' : 'A', 'B' : 'B', 'C' : 'C', 'D' : 'D', 'E' : 'E', 'F' : 'F', 'G' : 'G', 'H' : 'H', 'I' : 'I'}
        for data in origin_dataset:
            error_type = data['error_type']
            if error_type == 'ok':
                answer = int2answer[int(data['answer'])] 
            elif error_type == 'wrong_groundtruth' and data['correct_answer'] is not None:
                try:
                    answer = Choice2answer[data['correct_answer']]
                except:
                    answer = int2answer[int(data['correct_answer'])]
            else:
                continue
            question = data['question']
            origin_choices = data['choices']
            choices = [f"{int2answer[i]}. {item}" for i, item in enumerate(origin_choices)]
            question_with_options = question + "\n" + "\n".join(choices)
            dataset.append({
                'question': question_with_options,
                'answer': answer,
                'subset': config_name
            })
        return dataset 
    

    def load_dataset_from_configs(self, config_name):
        """Load one subset and return its prepared records.

        The config ``config_name`` of ``self.dataset_name_from_hf`` is loaded
        with ``load_dataset`` (using ``self.huggingface_cache`` as cache
        directory) and its ``'test'`` split is passed to
        ``self.prepare_dataset``.

        Args:
            config_name: Name of the dataset config to load.

        Returns:
            The list produced by ``self.prepare_dataset`` for that split.
        """
        raw_splits = load_dataset(
            self.dataset_name_from_hf,
            config_name,
            cache_dir=self.huggingface_cache,
        )
        return self.prepare_dataset(raw_splits['test'], config_name)