import os
import re
from utils.utils import write_jsonl, read_jsonl
from utils.parser_utils import extract_answer

# (old, new) substring pairs. The only consumer visible in this file is the
# commented-out ``format_solution`` below, which applies them in list order
# with ``str.replace``.
list_of_subs = [
    ('an ', ''),
    ('a ', ''),
    ('.$', '$'),
    ('\\$', ''),
    (r'\ ', ''),
    (' ', ''),
    ('mbox', 'text'),
    (',\\text{and}', ','),
    ('\\text{and}', ','),
    ('\\text{m}', '\\text{}'),
]

# Substrings that the commented-out ``format_solution`` below deletes from an
# answer, in list order.
list_of_words = [
    'square',
    'ways',
    'integers',
    'dollars',
    'mph',
    'inches',
    'ft',
    'hours',
    'km',
    'units',
    '\\ldots',
    'sue',
    'points',
    'feet',
    'minutes',
    'digits',
    'cents',
    'degrees',
    'cm',
    'gm',
    'pounds',
    'meters',
    'meals',
    'edges',
    'students',
    'childrentickets',
    'multiples',
    'pages',
    'slic',
    'm',
    'pieces',
    'slices',
]
# def format_solution(short_answer: str) -> str:
#     """Formats answer for uniformization purposes."""
#     short_answer = short_answer.strip()

#     # Perform basic substitutions using list_of_subs
#     for el1, el2 in list_of_subs:
#         short_answer = short_answer.replace(el1, el2)

#     # Remove any of the common words from the list_of_words
#     for el in list_of_words:
#         short_answer = short_answer.replace(el, '')

#     # Extract content between dollar signs
#     short_answer = re.sub(r'(.*?)(\$)(.*?)(\$)(.*)', '$\\3$', short_answer)

#     # Remove \text{} and \textbf{} but keep the content
#     short_answer = re.sub(r'(\\text\{)(.*?)(\})', '\\2', short_answer)
#     short_answer = re.sub(r'(\\textbf\{)(.*?)(\})', '\\2', short_answer)

#     # Extract content inside \boxed{}
#     short_answer = re.sub(r'(\\boxed\{)(.*)(\})', '\\2', short_answer)

#     # Reformat \fracab as \frac{a}{b}
#     short_answer = re.sub(r'(frac)([^{])(.)', 'frac{\\2}{\\3}', short_answer)

#     # Reformat \sqrta as \sqrt{a}
#     short_answer = re.sub(r'(sqrt)([^{])', 'sqrt{\\2}', short_answer)

#     # Keep only the part after the last equals sign
#     short_answer = short_answer.split('=')[-1]

#     # Remove remaining dollar signs
#     short_answer = short_answer.replace('$', '')

#     # Remove commas from numbers like "100,000" -> "100000"
#     if short_answer.replace(',', '').isdigit():
#         short_answer = short_answer.replace(',', '')

#     # Strip unnecessary punctuation like colons or periods
#     short_answer = short_answer.strip(':.')
    
#     return short_answer



def process_math(is_train=False):
    """
    Split the MATH train or test file into one JSONL file per problem type.

    Every record is rewritten in place: ``problem`` is renamed to
    ``question``, ``solution`` is renamed to ``answer``, a running ``idx``
    (restarting at 0 for each problem type) is added, and ``ground_truth`` is
    set to ``extract_answer(answer)`` unless one of the hard-coded overrides
    below applies. Records whose ground truth is the empty string are printed
    and counted, and a per-type summary line is printed.

    Args:
        is_train: When true, read ``data/math/train.jsonl`` and write into
            ``data/train``; otherwise read ``data/math/test.jsonl`` and write
            into ``data/test``.
    """
    if is_train:
        source_file = 'data/math/train.jsonl'
        save_dir = 'data/train'
        # Hand-written ground truths keyed by (lower-cased type, idx);
        # presumably cases where extract_answer's result was judged wrong.
        overrides = {
            ('number theory', 661): '0',
            ('number theory', 663): '0',
            ('precalculus', 547): '\\begin{pmatrix} 0 & 3 \\\\ 0 & -1 \\end{pmatrix}',
            ('prealgebra', 1078): '4',
        }
    else:
        source_file = 'data/math/test.jsonl'
        save_dir = 'data/test'
        # The prealgebra entry used to be compared against the lower-case
        # spelling while every other check used capitalised type names, so it
        # could not fire on capitalised data. Keys are matched on the
        # lower-cased type so that either spelling works.
        overrides = {
            ('prealgebra', 551): '90',
            ('number theory', 483): 'Saturday',
        }

    data = read_jsonl(source_file)

    # Group once, preserving the file order inside each type, instead of
    # rescanning the whole dataset for every type.
    problems_by_type = {}
    for problem in data:
        problems_by_type.setdefault(problem['type'], []).append(problem)

    # Harmless if write_jsonl already creates the directory itself.
    os.makedirs(save_dir, exist_ok=True)

    for problem_type, problems in problems_by_type.items():
        type_key = problem_type.lower()
        empty_correct_answer_count = 0
        for idx, problem in enumerate(problems):
            # Rename the keys to the names used by the rest of the pipeline.
            problem['question'] = problem.pop('problem')
            problem['answer'] = problem.pop('solution')
            problem['idx'] = idx

            ground_truth = extract_answer(problem['answer'])
            ground_truth = overrides.get((type_key, idx), ground_truth)
            problem['ground_truth'] = ground_truth

            if ground_truth == '':
                # Surface records for which no answer could be extracted.
                empty_correct_answer_count += 1
                print(problem['idx'])
                print(problem['type'])
                print(problem['answer'])

        print('empty correct answer count {0} out of {1} in {2} subset'.format(empty_correct_answer_count, len(problems), problem_type))
        write_jsonl(problems, os.path.join(save_dir, f'{type_key}.jsonl'))

def process_gsm8k():
    """
    Add an extracted ``ground_truth`` field to every GSM 8k record.

    Reads ``data/gsm8k.jsonl``, sets ``ground_truth`` on each record to
    ``extract_answer(record['answer'])`` (or to ``''`` when that result is
    falsy), prints the records for which nothing was extracted together with
    a summary count, and writes the records back to ``data/gsm8k.jsonl`` --
    i.e. the input file is overwritten in place.
    """
    test_file = 'data/gsm8k.jsonl'
    save_dir = 'data'

    data = read_jsonl(test_file)
    empty_correct_answer_count = 0

    for idx, problem in enumerate(data):
        ground_truth = extract_answer(problem['answer'])
        if ground_truth:
            problem['ground_truth'] = ground_truth
            continue

        # The counter was previously never incremented, so the summary line
        # always reported zero.
        empty_correct_answer_count += 1
        print(idx)
        # Records are not guaranteed to have a 'type' key (the stock GSM8K
        # schema is question/answer only), so print it only when present
        # instead of raising KeyError.
        if 'type' in problem:
            print(problem['type'])
        print(problem['answer'])
        problem['ground_truth'] = ''

    print('empty correct answer count {0} out of {1} in GSM 8k subset'.format(empty_correct_answer_count, len(data)))

    write_jsonl(data, os.path.join(save_dir, 'gsm8k.jsonl'))


def reprocess_llama3_8b_train_results():
    """
    Recompute ``ground_truth`` in the stored Llama 3.1 8B train result files.

    For every (method, dataset) combination the matching file under
    ``results/`` is loaded; each row gets ``ground_truth`` set from
    ``extract_answer(row['solution'])`` (or from a hard-coded override), the
    ``solution`` text is stored under ``answer`` and the ``solution`` key is
    removed. The rows are then written back to the same file.
    """
    model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'
    methods = ('cot', 'pal', 'codenl', 'nlcode')
    subsets = (
        'algebra',
        'counting & probability',
        'geometry',
        'number theory',
        'intermediate algebra',
        'precalculus',
        'prealgebra',
    )
    # (type, idx, ground truth) triples that replace the extracted answer.
    manual_fixes = (
        ('Number Theory', 661, '0'),
        ('Number Theory', 663, '0'),
        ('Precalculus', 547, '\\begin{pmatrix} 0 & 3 \\\\ 0 & -1 \\end{pmatrix}'),
        ('Prealgebra', 1078, '4'),
    )

    for method in methods:
        for subset in subsets:
            path = f'results/{model_name}_{subset}_Level 5_{method}_train.jsonl'
            rows = read_jsonl(path)
            for row in rows:
                solution_text = row['solution']
                ground_truth = extract_answer(solution_text)
                idx = row['idx']
                problem_type = row['type']

                for fix_type, fix_idx, fix_value in manual_fixes:
                    if problem_type == fix_type and idx == fix_idx:
                        ground_truth = fix_value

                row['ground_truth'] = ground_truth
                row['answer'] = solution_text
                del row['solution']

            # Overwrite the result file with the updated rows.
            write_jsonl(rows, path)
    
        
# Script entry point. The pipeline steps are toggled by commenting calls in
# or out; as written only the result reprocessing step runs.
if __name__ == '__main__':
    # Only consumed by the (currently disabled) process_math call below.
    is_train = True
    # process_math(is_train)
    # process_gsm8k()
    reprocess_llama3_8b_train_results()