import pathlib
import click
import json
import re
import regex
import copy
import pandas as pd

from tqdm import tqdm

from utils.analyze_answer import compare_answer_golden
from utils.logger import get_logger
from common.constants import *

class BaseProcessor:
    """Shared state for the per-dataset result processors.

    Holds the directory of raw per-question result files, the directory that
    processed output is written to, the dataset name, the run identifier
    (``model_info``) and a logger obtained from ``get_logger()``. Concrete
    subclasses supply :meth:`process_results`.
    """

    def __init__(self, raw_results_dir: pathlib.Path, processed_results_dir: pathlib.Path, dataset, model_info):
        # Input and output locations for this run.
        self.raw_results_dir, self.processed_results_dir = raw_results_dir, processed_results_dir
        # Identification of the run being processed.
        self.dataset, self.model_info = dataset, model_info
        self.logger = get_logger()

    def process_results(self):
        """Score the raw results of this run; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on the base class.
        """
        message = "Subclasses should implement this method."
        raise NotImplementedError(message)
    

class MATHProcessor(BaseProcessor):
    """Score MATH-500 results by comparing a boxed model answer with the golden answer."""

    # Captures a complete ``boxed{...}`` expression (group 1), allowing nested
    # braces through the ``regex`` module's ``(?R)`` recursion.
    _BOXED_PATTERN = r'(\\boxed\{((?:.+|(?R))*)\})'

    @classmethod
    def _extract_boxed(cls, model_response):
        """Return the last boxed expression found in ``model_response``, or ``None``."""
        if not model_response:
            return None
        matches = regex.findall(cls._BOXED_PATTERN, model_response)
        return matches[-1][0] if matches else None

    def process_results(self):
        """Evaluate every raw ``<question_hash>.json`` file of this run.

        The golden answer is wrapped in ``\\boxed{}``. The model answer is the
        stored ``model_answer`` when it is non-empty; otherwise it is the last
        boxed expression in ``model_response``; otherwise it is the whole
        response. The answer that is scored is also the one stored. All results
        go to ``processed_results.json`` and the incorrect ones to
        ``processed_error_results.json`` in ``processed_results_dir``, which is
        created if needed. The accuracy is then logged.
        """
        self.logger.info("Processing Results")
        # Only .json files are results. Counting other entries skewed both the
        # accuracy denominator and the progress-bar total. Sorting gives a
        # stable output order.
        result_files = sorted(
            path for path in self.raw_results_dir.iterdir()
            if path.is_file() and path.suffix == '.json'
        )
        acc = 0
        processed_results_list = []
        processed_error_results_list = []
        for result_file in tqdm(result_files, desc='Processing Results'):
            question_hash = result_file.stem
            with open(result_file, "r") as f:
                result_data = json.load(f)
            result = result_data[question_hash]['result']
            result['golden'] = '\\boxed{' + result['golden'] + '}'
            # Falsy stored answers (None or "") fall back to extraction, then to
            # the raw response, so the stored and the scored answer always agree.
            model_answer = (
                result['model_answer']
                or self._extract_boxed(result['model_response'])
                or result['model_response']
            )
            result['model_answer'] = model_answer

            result_data[question_hash]['evaluation'] = compare_answer_golden(
                question_hash=question_hash,
                dataset='MATH-500',
                model_answer=model_answer,
                golden_answer=result['golden'],
                logger=self.logger,
            )
            if result_data[question_hash]['evaluation']['correct']:
                acc += 1
            else:
                processed_error_results_list.append(result_data)
            processed_results_list.append(result_data)

        # The destination may not exist yet; open(..., "w") does not create it.
        self.processed_results_dir.mkdir(parents=True, exist_ok=True)
        with open(self.processed_results_dir / "processed_results.json", "w") as f:
            json.dump(processed_results_list, f, indent=4)
        with open(self.processed_results_dir / "processed_error_results.json", "w") as f:
            json.dump(processed_error_results_list, f, indent=4)

        total = len(result_files)
        if total:
            self.logger.info(f"Accuracy: {acc / total}")
        else:
            self.logger.warning(f"No .json result files found in {self.raw_results_dir}; accuracy undefined.")
        self.logger.info("Processing Results Complete")


class MMLUReduxProcessor(BaseProcessor):
    """Score MMLU-Redux results by isolating a single-character choice from the response."""

    # Final pass: the whole candidate string must be a bare letter (optional period).
    _BARE_LETTER_PATTERN = r"\n? ?([A-D])\.?"
    # Loose patterns, trusted only when they match exactly once in the full
    # response.
    _UNIQUE_MATCH_PATTERNS = frozenset((
        r"\{([A-D])\}",
        r"\\text\{([A-D])\b",
        r"^ ?([A-D])\. ",
        r"\* ?([A-D])\b",
    ))
    # Extraction patterns, applied in this order to every result.
    _CHOICE_PATTERNS = (
        r"\\boxed\{(.+?)\}",
        r"(?:A|a)nswer is ?\*{0,2} ?([A-D])\b",
        r"(?:A|a)nswer\"?: ?\*{0,2} ?\"?([A-D])\b",
        r"(?:A|a)nswer\n{0,2}([A-D])\b",
        r"\{([A-D])\}",
        r"\\text\{([A-D])\b",
        r"^ ?([A-D])\. ",
        r"\* ?([A-D])\b",
        _BARE_LETTER_PATTERN,
    )

    @classmethod
    def _extract_choice(cls, model_response, model_answer):
        """Narrow ``model_answer`` down to one character using ``_CHOICE_PATTERNS``.

        ``model_answer`` is the initial candidate, or ``None`` if there is none.
        While no answer is known, each pattern is searched for in
        ``model_response``. Loose patterns are accepted only on a unique match;
        for the others the last match wins. Once an answer is known but longer
        than one character, each pattern is applied to that answer instead and
        accepted only on a unique match. The bare-letter pattern must match the
        whole candidate. Returns a single character, or ``None`` when none
        could be isolated.
        """
        answer = model_answer
        for pattern in cls._CHOICE_PATTERNS:
            if pattern == cls._BARE_LETTER_PATTERN:
                candidate = model_response if answer is None else answer
                match = re.fullmatch(pattern, candidate)
                if match:
                    answer = match.group(1)
            elif answer is None:
                matches = re.findall(pattern, model_response, re.DOTALL)
                if pattern in cls._UNIQUE_MATCH_PATTERNS:
                    if len(matches) == 1:
                        answer = matches[0]
                elif matches:
                    answer = matches[-1]
            elif len(answer) != 1:
                matches = re.findall(pattern, answer, re.DOTALL)
                if len(matches) == 1:
                    answer = matches[0]
        # Anything that is not exactly one character is not a usable choice.
        if answer is not None and len(answer) != 1:
            answer = None
        return answer

    def process_results(self):
        """Extract and score the answer of every raw ``<question_hash>.json`` file.

        The initial candidate is the stripped last ``ANSWER_REGS`` capture in
        ``model_response``; it is refined by :meth:`_extract_choice` and stored
        as ``model_answer``. All results go to ``processed_results.json`` and
        the incorrect ones to ``processed_error_results.json`` in
        ``processed_results_dir``, which is created if needed. The accuracy is
        then logged.
        """
        self.logger.info("Processing Results")
        # Only .json files are results. Counting other entries skewed both the
        # accuracy denominator and the progress-bar total.
        result_files = sorted(
            path for path in self.raw_results_dir.iterdir()
            if path.is_file() and path.suffix == '.json'
        )
        acc = 0
        processed_results_list = []
        processed_error_results_list = []
        for result_file in tqdm(result_files, desc='Processing Results'):
            question_hash = result_file.stem
            with open(result_file, "r") as f:
                result_data = json.load(f)
            result = result_data[question_hash]['result']

            model_response = result['model_response']
            answer_matches = re.findall(ANSWER_REGS, model_response, re.DOTALL)
            answer_text = answer_matches[-1].strip() if answer_matches else None
            result['model_answer'] = self._extract_choice(model_response, answer_text)

            result_data[question_hash]['evaluation'] = compare_answer_golden(
                question_hash=question_hash,
                dataset="mmlu_redux",
                model_answer=result['model_answer'],
                golden_answer=result['golden'],
                logger=self.logger,
            )
            if result_data[question_hash]['evaluation']['correct']:
                acc += 1
            else:
                processed_error_results_list.append(result_data)
            processed_results_list.append(result_data)

        # The destination may not exist yet; open(..., "w") does not create it.
        self.processed_results_dir.mkdir(parents=True, exist_ok=True)
        with open(self.processed_results_dir / "processed_results.json", "w") as f:
            json.dump(processed_results_list, f, indent=4)
        with open(self.processed_results_dir / "processed_error_results.json", "w") as f:
            json.dump(processed_error_results_list, f, indent=4)

        total = len(result_files)
        if total:
            self.logger.info(f"Accuracy: {acc / total}")
        else:
            self.logger.warning(f"No .json result files found in {self.raw_results_dir}; accuracy undefined.")
        self.logger.info("Processing Results Complete")


class SimpleQAProcessor(BaseProcessor):
    """Score SimpleQA results from the grader decision already stored in each raw file."""

    def process_results(self):
        """Mark every raw ``<question_hash>.json`` result correct or incorrect.

        A result is correct when ``str(evaluation['llm_evaluation_decision'])``
        stripped equals ``"A"``; the verdict is stored in
        ``evaluation['correct']``. All results go to ``processed_results.json``
        and the incorrect ones to ``processed_error_results.json`` in
        ``processed_results_dir``, which is created if needed. The accuracy is
        then logged.
        """
        self.logger.info("Processing Results")
        # Only .json files are results. Counting other entries skewed both the
        # accuracy denominator and the progress-bar total.
        result_files = sorted(
            path for path in self.raw_results_dir.iterdir()
            if path.is_file() and path.suffix == '.json'
        )
        acc = 0
        processed_results_list = []
        processed_error_results_list = []
        for result_file in tqdm(result_files, desc='Processing Results'):
            question_hash = result_file.stem
            with open(result_file, "r") as f:
                result_data = json.load(f)
            evaluation = result_data[question_hash]['evaluation']
            evaluation['correct'] = str(evaluation['llm_evaluation_decision']).strip() == "A"
            if evaluation['correct']:
                acc += 1
            else:
                processed_error_results_list.append(result_data)
            processed_results_list.append(result_data)

        # The destination may not exist yet; open(..., "w") does not create it.
        self.processed_results_dir.mkdir(parents=True, exist_ok=True)
        with open(self.processed_results_dir / "processed_results.json", "w") as f:
            json.dump(processed_results_list, f, indent=4)
        with open(self.processed_results_dir / "processed_error_results.json", "w") as f:
            json.dump(processed_error_results_list, f, indent=4)

        total = len(result_files)
        if total:
            self.logger.info(f"Accuracy: {acc / total}")
        else:
            self.logger.warning(f"No .json result files found in {self.raw_results_dir}; accuracy undefined.")
        self.logger.info("Processing Results Complete")


# Ordered (lineage set, usage-key map) pairs. Each map names the keys under
# which a model family's stored ``usage`` dict reports input and output token
# counts. Lookups scan the list in order, so the first lineage that contains
# the model wins.
TOKEN_USAGE_MAPPING = [
    (
        set().union(
            OPENAI_LINEAGE,
            QWEN_LINEAGE,
            DEEPSEEK_LINEAGE,
            MINIMAX_LINEAGE,
            KIMI_LINEAGE,
            GLM_LINEAGE,
            SPARK_LINEAGE,
            DOUBAO_LINEAGE,
        ),
        {"input": "prompt_tokens", "output": "completion_tokens"},
    ),
    (set(CLAUDE_LINEAGE), {"input": "input_tokens", "output": "output_tokens"}),
    (set(GEMINI_LINEAGE), {"input": "prompt_tokens", "output": "output_tokens"}),
    (set(GROK_LINEAGE), {"input": "prompt_tokens", "output": "completion_tokens"}),
]

def get_output_token_field(model_name):
    """Return the ``usage`` key holding output-token counts for ``model_name``.

    Scans ``TOKEN_USAGE_MAPPING`` in order and uses the first lineage that
    contains ``model_name``. Returns ``None`` if no lineage matches.
    """
    return next(
        (
            field_names["output"]
            for lineage, field_names in TOKEN_USAGE_MAPPING
            if model_name in lineage
        ),
        None,
    )


def _configure_logger(model_name, dataset_name, temperature, top_p, cot_reasoning,
                      enable_intrinsic_reasoning, log_dir):
    """Call ``get_logger`` with one run's settings and log when defaults apply.

    NOTE(review): the processors later call ``get_logger()`` without arguments;
    presumably that returns the logger configured here — confirm in utils.logger.
    """
    logger = get_logger(model_name, dataset_name, temperature, top_p, cot_reasoning, log_dir)
    if model_name in DEFAULT_TEMPERATURE_TOP_P_MODELS or enable_intrinsic_reasoning:
        logger.info(f"Using default temperature, top-p, max-tokens, and thinking-budget for {model_name}.")


def _first_json_file(directory):
    """Return the alphabetically first ``.json`` file in ``directory``, or ``None``."""
    json_files = sorted(
        path for path in directory.iterdir() if path.is_file() and path.suffix == '.json'
    )
    return json_files[0] if json_files else None


def _count_output_tokens(model_name, usage, output_token_field):
    """Return the output-token count recorded in ``usage``, or ``None`` if unknown.

    The count is unknown when the model's lineage has no known usage field
    (``output_token_field`` is ``None``) or that field is absent from ``usage``.
    """
    if output_token_field is None or output_token_field not in usage:
        return None
    if model_name in GEMINI_LINEAGE:
        # Sum of thought tokens and output tokens (missing/None treated as 0).
        return (usage.get("thoughts_tokens", 0) or 0) + (usage.get("output_tokens", 0) or 0)
    if model_name in GROK_LINEAGE:
        # NOTE(review): TOKEN_USAGE_MAPPING lists "completion_tokens" as Grok's
        # output field, yet this sums "output_tokens"; confirm the stored schema.
        return (usage.get("reasoning_tokens", 0) or 0) + (usage.get("output_tokens", 0) or 0)
    return usage[output_token_field]


def _answer_has_good_format(dataset, answer_text):
    """Return whether the captured answer text matches the dataset's answer template.

    ``None`` (no answer captured) is never well formatted. Datasets other than
    ``mmlu-redux`` and ``MATH-500`` accept any captured answer.
    """
    if answer_text is None:
        return False
    answer_text = answer_text.strip()
    if dataset == 'mmlu-redux':
        return re.fullmatch(r"(?:\*\* ?)?\{?[A-Z]\}?\.?(?: ?\*\*)?", answer_text) is not None
    if dataset == 'MATH-500':
        return re.fullmatch(
            r"(?:\*\* ?)?\${0,2} ?\n?(?:\\\[)?\n?(?:\\\()?\\boxed\{.*\}(?:\\\))?\n?(?:\\\])?\n? ?\${0,2}(?: ?\*\*)?",
            answer_text,
            re.DOTALL,
        ) is not None
    return True


def _increment(counts, dataset, model_info):
    """Add one to ``counts[dataset][model_info]``, starting from zero."""
    counts[dataset][model_info] = counts[dataset].get(model_info, 0) + 1


def _write_counts_csv(counts, value_column, path):
    """Write ``{dataset: {model_info: count}}`` to ``path``, one row per pair.

    The header row is written even when there are no counts.
    """
    rows = [
        {'dataset': dataset, 'model_info': model_info, value_column: count}
        for dataset, per_model in counts.items()
        for model_info, count in per_model.items()
    ]
    pd.DataFrame(rows, columns=['dataset', 'model_info', value_column]).to_csv(path, index=False)


@click.command()
@click.option("--raw-root-dir", required=True, type=str, help="Directory of the raw results.")
@click.option("--processed-root-dir", required=True, type=str, help="Directory to save processed results.")
@click.option("--output-dir", required=True, type=str, help="Directory to save other output results.")
@click.option("--log-dir", required=True, type=click.Path(path_type=pathlib.Path, exists=True), help="Existing directory to save logs.")
def main(raw_root_dir, processed_root_dir, output_dir, log_dir):
    """Score raw benchmark results and write summary CSVs.

    Every RAW_ROOT_DIR/<dataset>/<run> directory is scored with its dataset's
    processor into PROCESSED_ROOT_DIR/<dataset>/<run>. Over-token, refused,
    badly formatted and per-question token counts are then written to
    OUTPUT_DIR.
    """
    PROCESSOR_REGISTRY = {
        'MATH-500': MATHProcessor,
        'mmlu-redux': MMLUReduxProcessor,
        'simple-qa': SimpleQAProcessor,
    }

    raw_root_dir = pathlib.Path(raw_root_dir)
    processed_root_dir = pathlib.Path(processed_root_dir)
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Per-dataset tallies keyed by model_info (the run directory's name).
    over_token_dict = {dataset: {} for dataset in PROCESSOR_REGISTRY}
    refuse_answering_dict = {dataset: {} for dataset in PROCESSOR_REGISTRY}
    error_format_dict = {dataset: {} for dataset in PROCESSOR_REGISTRY}
    token_num_dict = {dataset: {} for dataset in PROCESSOR_REGISTRY}

    for dataset, processor_cls in PROCESSOR_REGISTRY.items():
        raw_dataset_dir = raw_root_dir / dataset
        if not raw_dataset_dir.is_dir():
            click.echo(f"Skipping {dataset}: {raw_dataset_dir} is not a directory.", err=True)
            continue
        for raw_results_dir in sorted(raw_dataset_dir.iterdir()):
            # Stray files (e.g. .DS_Store) are not runs.
            if not raw_results_dir.is_dir():
                continue
            first_result_file = _first_json_file(raw_results_dir)
            if first_result_file is None:
                click.echo(f"Skipping {raw_results_dir}: no .json result files.", err=True)
                continue

            model_info = raw_results_dir.name
            processed_results_dir = processed_root_dir / dataset / model_info
            processed_results_dir.mkdir(parents=True, exist_ok=True)

            # Run-level settings are read from a single result file.
            # NOTE(review): assumes every file of a run shares these settings.
            with open(first_result_file, "r") as f:
                run_config = next(iter(json.load(f).values()))['result']
            cot_reasoning = run_config['cot_reasoning']
            _configure_logger(
                run_config['model_name'], dataset, run_config['temperature'], run_config['top_p'],
                cot_reasoning, run_config['intrinsic_reasoning'], log_dir,
            )

            processor: BaseProcessor = processor_cls(raw_results_dir, processed_results_dir, dataset, model_info)
            processor.process_results()

            # Output-token threshold for the over-token tally: 28672 when the
            # run path contains 'intrinsic-reasoning', otherwise 8192.
            max_token = 28672 if 'intrinsic-reasoning' in str(processed_results_dir) else 8192
            with open(processed_results_dir / 'processed_results.json', 'r') as file:
                data = json.load(file)

            for item in data:
                hash_key, record = next(iter(item.items()))
                result = record['result']
                item_model_name = result['model_name']
                model_response = result['model_response']
                usage = result['usage']
                correct = record['evaluation']['correct']
                # Recomputed for every item. Previously a missing usage field left
                # the previous item's count in place, or raised NameError on the
                # first item.
                output_token = _count_output_tokens(
                    item_model_name, usage, get_output_token_field(item_model_name)
                )

                # Incorrect answers that reached the token threshold. `== False`
                # is kept on purpose so that a None verdict is not counted.
                over_token = output_token is not None and output_token >= max_token and correct == False  # noqa: E712
                if over_token:
                    _increment(over_token_dict, dataset, model_info)

                # Questions the model refused, or requests that errored.
                refuse_answering = model_response == 'sensitive content' or 'error' in usage
                if refuse_answering:
                    _increment(refuse_answering_dict, dataset, model_info)

                if over_token or refuse_answering:
                    continue

                # Responses that do not meet the template requirements.
                answer_match = re.search(ANSWER_REGS, model_response, re.DOTALL)
                answer_text = answer_match.group(1) if answer_match else None
                missing_reasoning = cot_reasoning and result['reasoning_process'] is None
                if missing_reasoning or not _answer_has_good_format(dataset, answer_text):
                    _increment(error_format_dict, dataset, model_info)

                # Record output tokens only when the count is actually known.
                if output_token is not None:
                    token_num_dict[dataset].setdefault(model_info, {})[hash_key] = output_token

    _write_counts_csv(over_token_dict, 'over_token_num', output_dir / 'over_token.csv')
    _write_counts_csv(refuse_answering_dict, 'sensitive_content_num', output_dir / 'refuse_answering.csv')
    _write_counts_csv(error_format_dict, 'error_format_num', output_dir / 'error_format.csv')

    token_num_rows = [
        {'dataset': dataset, 'model_info': model_info, 'question_hash': question_hash, 'token_num': token_num}
        for dataset, per_model in token_num_dict.items()
        for model_info, per_question in per_model.items()
        for question_hash, token_num in per_question.items()
    ]
    pd.DataFrame(
        token_num_rows, columns=['dataset', 'model_info', 'question_hash', 'token_num']
    ).to_csv(output_dir / 'token_num.csv', index=False)


if __name__ == "__main__":
    # Script entry point: run the click command defined above.
    main()
