import json
import argparse
import logging
import os
import re
from tqdm import tqdm
from openai import OpenAI
from math_verify import parse, verify


# Configure the root logger at import time: INFO level, with each record
# formatted as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def read_json_fields(filename):
    """Load a JSON file and return its parsed contents.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.

    Args:
        filename: Path to the JSON file to read.

    Returns:
        The deserialized JSON value, or ``None`` when the file is missing,
        is not valid JSON, or any other error occurs. Errors are logged
        rather than raised.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        logging.error("The file was not found.")
    except json.JSONDecodeError:
        logging.error("There was an error decoding the JSON file.")
    except Exception as e:
        logging.error(f"An error occurred: {e}")
    # Every failure path falls through to here; make the result explicit.
    return None


def _select_fields(item, field_names):
    """Return *item* restricted to *field_names* when a list of names is given."""
    if isinstance(field_names, list) and isinstance(item, dict):
        # Names absent from the record are skipped rather than raising KeyError.
        return {name: item[name] for name in field_names if name in item}
    return item


def read_json_objects(filename, field_names=None):
    """Read a list of records from a ``.jsonl`` or ``.json`` file.

    The parser is chosen from the file extension (case-sensitive):
    ``.jsonl`` is read as one JSON value per line, skipping blank lines;
    ``.json`` is read as a single JSON document whose top-level value is
    iterated into a list. Files are decoded as UTF-8.

    Args:
        filename: Path to the input file.
        field_names: Optional list of keys to keep in each record. When it
            is a list, every dict record is reduced to those keys that are
            present in it; non-dict records are returned unchanged. When it
            is ``None`` (or not a list) records are returned in full.

    Returns:
        A list of parsed records. An empty list is returned, and the error
        logged, when the extension is unknown, the file is missing, the
        content cannot be decoded, or any other error occurs.
    """
    file_extension = os.path.splitext(filename)[1]
    if file_extension not in ('.jsonl', '.json'):
        logging.error(f"Unknown file extension {file_extension}")
        return []

    is_jsonl = file_extension == '.jsonl'
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            if is_jsonl:
                # Blank lines (e.g. a trailing newline) are not records.
                records = [json.loads(line) for line in file if line.strip()]
            else:
                records = list(json.load(file))
    except FileNotFoundError:
        logging.error("The file was not found.")
        return []
    except json.JSONDecodeError:
        if is_jsonl:
            logging.error("There was an error decoding the JSONL file.")
        else:
            logging.error("There was an error decoding the JSON file.")
        return []
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return []

    return [_select_fields(record, field_names) for record in records]


def write_data_to_json_file(data, file_path):
    """Serialize *data* as indented JSON and write it to *file_path*.

    Because ``ensure_ascii=False`` emits non-ASCII characters verbatim, the
    file is opened explicitly as UTF-8; relying on the platform default
    encoding could fail on such characters.

    Args:
        data: Any JSON-serializable value.
        file_path: Destination path; an existing file is overwritten.

    Returns:
        None. Failures are logged rather than raised.
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, ensure_ascii=False, indent=4)
        logging.info(f"Data successfully written to {file_path}")
    except Exception as e:
        logging.error(f"An error occurred: {e}")


def math_verify(llm_res, ground_truth):
    """Check whether an LLM response matches the reference answer.

    Both strings are run through Math-Verify's ``parse`` and the results are
    compared with ``verify``.

    Args:
        llm_res: Raw model output text containing the candidate answer.
        ground_truth: Reference answer text.

    Returns:
        The result of ``verify`` for the parsed pair (truthy when the
        candidate is judged equivalent to the reference).
    """
    llm_answer = parse(llm_res)
    ground_truth_answer = parse(ground_truth)
    # Math-Verify's documented call is verify(gold, answer) and the order
    # matters, so the reference goes first and the candidate second.
    return verify(ground_truth_answer, llm_answer)

        
def main():
    """Compute Math-Verify accuracy over the records in the ``--input`` file.

    Each record must provide an ``output`` field (model response) and an
    ``answer`` field (reference). The totals and the accuracy percentage are
    logged at INFO level. If no records could be loaded, an error is logged
    and the function returns without computing an accuracy.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, required=True, help='path to the input json file')
    args = parser.parse_args()

    data_list = read_json_objects(args.input)
    # Guard against both a failed read (None) and an empty data set, either
    # of which would otherwise crash the loop or the division below.
    if not data_list:
        logging.error(f"No records loaded from {args.input}; cannot compute accuracy.")
        return

    total = 0
    correct = 0
    for item in data_list:
        if math_verify(item['output'], item['answer']):
            correct += 1
        total += 1
    logging.info(f"Total: {total}, correct: {correct}, Math-verify accuracy: {float(correct) / total * 100} %")
    

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()