import argparse
import json
import pdb
import jsonlines
from fraction import Fraction
import re
import util
import sys
MAX_INT = sys.maxsize  # used as an open-ended slice end and as the --end CLI default
INVALID_ANS = "[invalid]"  # sentinel string; not referenced elsewhere in the visible code

# Module-level sink for records of completions from which no answer could be
# extracted; appended to by process_results() and test(), whose length test()
# reports.
invalid_outputs = []

def is_number(s):
    """Return True if ``s`` parses as a float or is a single Unicode numeric char.

    Strings that ``float()`` rejects (e.g. ``"abc"``, ``""``) fall back to
    ``unicodedata.numeric``, which accepts characters such as ``"½"``.
    A non-string that ``float()`` cannot handle at all (e.g. ``None``) raises
    ``TypeError``.
    """
    import unicodedata

    try:
        float(s)
    except ValueError:
        try:
            unicodedata.numeric(s)
        except (TypeError, ValueError):
            return False
    return True

def remove_boxed(s):
    """Strip a ``\\boxed{...}`` wrapper and return its contents.

    ``s`` must start with ``\\boxed{`` and end with ``}``; anything else,
    including ``None`` or non-string input, yields None. Only the outermost
    wrapper is removed, e.g. ``\\boxed{\\frac{1}{2}}`` -> ``\\frac{1}{2}``.
    """
    left = "\\boxed{"
    # Explicit checks rather than `assert`: asserts are stripped under
    # `python -O`, which would make every string look like a boxed answer.
    if not isinstance(s, str) or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]


def extract_answer_number(extract_ans):
    """Return the first number found in ``extract_ans``, rounded to an int.

    The first token of the form ``[sign][digits][one of . , /]digits`` is used:

    * ``a/b`` is evaluated as a fraction and rounded. A zero denominator
      (``"5/0"``, ``"5/00"``) yields the numerator itself.
    * Otherwise commas are dropped (treated as thousands separators) and the
      value is rounded with Python's ``round`` (ties go to the even integer).

    Returns None when ``extract_ans`` is None, nothing matches, the part
    before ``/`` is empty or a bare sign (e.g. ``"/5"``), or the value is too
    large to represent as a float.
    """
    if extract_ans is None:
        return None
    match = re.search(r'[\-+]?\d*[\.,/]?\d+', extract_ans)
    if match is None:
        return None
    token = match.group()

    if '/' in token:
        # The pattern allows at most one separator, so both sides of the slash
        # are optionally-signed digit runs (the numerator may be empty or a
        # bare sign, which int() rejects).
        numerator, denominator = token.split('/')
        try:
            num = int(numerator)
            den = int(denominator)
        except ValueError:
            return None
        if den == 0:
            # Keep the established convention: "n/0" scores as n.
            return num
        try:
            # int / int is correctly-rounded true division, equivalent to
            # reducing the fraction first.
            return round(num / den)
        except OverflowError:
            return None

    value = float(token.replace(',', ''))
    if value in (float('inf'), float('-inf')):
        # Very long digit runs overflow to +/-inf, which round() cannot convert.
        return None
    return round(value)

def process_results(doc, completion, answer):
    """Grade a completion that states its result as ``"The answer is: ..."``.

    Args:
        doc: The question; only stored in the invalid-output record.
        completion: Full model output text.
        answer: Reference answer passed to ``util.is_equiv`` (via ``judge``).

    Returns:
        True if the text after the last ``"The answer is: "`` marker is judged
        equivalent to ``answer``, else False. Completions without the marker
        are appended to the module-level ``invalid_outputs`` list and score
        False.
    """
    split_ans = completion.split('The answer is: ')
    if len(split_ans) > 1:
        # judge() applies the shared cleanup (truncate at '.\n', strip, drop
        # one trailing '.') before comparing.
        return judge(split_ans[-1], answer)
    invalid_outputs.append({'question': doc, 'output': completion, 'answer': answer})
    return False

def batch_data(data_list, batch_size=1):
    """Split ``data_list`` into consecutive batches of ``batch_size`` items.

    The remainder is merged into the final batch, so when the input holds at
    least ``batch_size`` items the last batch has between ``batch_size`` and
    ``2 * batch_size - 1`` items. A shorter input (including an empty one)
    comes back as a single batch containing everything. Assumes a positive
    ``batch_size``.
    """
    n_full = len(data_list) // batch_size
    batches = [
        data_list[idx * batch_size:(idx + 1) * batch_size]
        for idx in range(n_full - 1)
    ]
    # When n_full == 0 this start is -batch_size; because the input is then
    # shorter than batch_size, the slice still begins at the first element.
    batches.append(data_list[(n_full - 1) * batch_size:MAX_INT])
    return batches

def judge(ans, answer):
    """Compare a raw answer string against the reference with ``util.is_equiv``.

    Only the text before the first period-newline (``'.\\n'``) is kept; it is
    stripped, a single trailing ``'.'`` is removed if present, and the result
    is stripped again before comparison.

    Returns:
        True when ``util.is_equiv`` reports a match, otherwise False.
    """
    head, _, _ = ans.partition('.\n')
    candidate = head.strip()
    if candidate.endswith('.'):
        candidate = candidate[:-1]
    return bool(util.is_equiv(candidate.strip(), answer))

def _load_jsonl(path):
    """Read a JSON-Lines file into a list of objects, ignoring blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def test(input_file, ground_truth_file):
    """Score model completions against ground-truth solutions and print accuracy.

    Records from the two files are paired by position. For each pair:

    * The candidate answer is the last ``\\boxed{...}`` in the first entry of
      the completion record's ``'code'`` field, extracted with
      ``util.last_boxed_only_string`` and ``remove_boxed``.
    * The reference answer is extracted the same way from the ground-truth
      record's ``'solution'`` field.
    * The two answers are compared with ``judge``.

    Completions with no extractable answer (including a missing or empty
    ``'code'`` field) count as wrong. They are appended to the module-level
    ``invalid_outputs`` list.

    Args:
        input_file: JSON-Lines file of model outputs; each record presumably
            looks like ``{'code': [completion, ...], ...}``.
        ground_truth_file: JSON-Lines file whose records carry a
            ``'solution'`` string.

    Returns:
        The accuracy as a float (0.0 when there are no pairs to score).
    """
    input_data = _load_jsonl(input_file)
    gt_data = _load_jsonl(ground_truth_file)
    if len(input_data) != len(gt_data):
        # zip() below silently stops at the shorter list; make that visible.
        print(
            f'warning: {len(input_data)} outputs vs {len(gt_data)} references; '
            f'scoring only the first {min(len(input_data), len(gt_data))}',
            file=sys.stderr,
        )

    result = []
    for u_id, (input_sample, ground_truth_sample) in enumerate(zip(input_data, gt_data)):
        # Read the completion once, so a malformed record is recorded as an
        # invalid output instead of raising in the bookkeeping branch below.
        try:
            output = input_sample['code'][0]
        except (KeyError, IndexError, TypeError):
            output = None

        candidate_answer = None
        if output is not None:
            try:
                candidate_answer = remove_boxed(util.last_boxed_only_string(output))
            except Exception:
                # Best effort: an unparsable completion simply has no answer.
                candidate_answer = None

        ground_answer = remove_boxed(util.last_boxed_only_string(ground_truth_sample['solution']))

        if candidate_answer is not None:
            result.append(bool(judge(candidate_answer, ground_answer)))
        else:
            result.append(False)
            invalid_outputs.append({'id': u_id, 'output': output, 'answer': ground_answer})

    acc = sum(result) / len(result) if result else 0.0
    print('len invalid outputs ====', len(invalid_outputs))
    # NOTE: the 'gsm8k' labels look historical (this scores MATH data); kept
    # verbatim in case log parsers match on them.
    print('gsm8k length====', len(result), ', gsm8k acc====', acc)
    return acc


def parse_args():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    Options (flag, type, default), registered in this order:
    ``--model`` (str, ''), ``--data_file`` (str, ''), ``--start`` (int, 0),
    ``--end`` (int, MAX_INT), ``--batch_size`` (int, 400),
    ``--tensor_parallel_size`` (int, 8).
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--model", str, ''),  # model path
        ("--data_file", str, ''),  # data path
        ("--start", int, 0),  # start index
        ("--end", int, MAX_INT),  # end index
        ("--batch_size", int, 400),  # batch_size
        ("--tensor_parallel_size", int, 8),  # tensor_parallel_size
    )
    for flag, kind, fallback in option_specs:
        cli.add_argument(flag, type=kind, default=fallback)
    return cli.parse_args()

if __name__ == "__main__":
    # CLI flags are parsed (so malformed flags still error out), but none of
    # the parsed options are used: the evaluated file paths are fixed below.
    args = parse_args()
    test(
        './MATH_test_Qwen2.5-Math-1.5B-Instruct_1.jsonl',  # presumably model completions
        'MATH_500.jsonl',  # ground-truth records with 'solution' fields
    )