import argparse
import numpy as np
import json
import os

def main(args):
    """Score every record of ``args.file`` and write a per-record report.

    The answer format is inferred from markers in the file path ("-nl-",
    "-no-format-", "-json-", checked in that order). Each non-blank JSONL
    record is scored with ``check_answer``. The aggregate metrics are written
    as the first line of ``<root>_report<ext>`` (next to the input file),
    followed by one line per scored record, and are also printed.

    Args:
        args: namespace with a ``file`` attribute (path to the JSONL results).

    Raises:
        ValueError: if no known format marker appears in ``args.file``.
    """
    if "-nl-" in args.file:
        format_type = "nl"
    elif "-no-format-" in args.file:
        format_type = "no-format"
    elif "-json-" in args.file:
        format_type = "json"
    else:
        # check_answer cannot score an unknown format; fail up front instead of
        # crashing on the first record with an obscure error.
        raise ValueError(
            f"Cannot infer answer format from {args.file!r}: expected "
            "'-nl-', '-no-format-' or '-json-' in the path"
        )

    report_lines = []
    with open(args.file, "r", encoding="utf-8") as f:
        for raw in f:
            if not raw.strip():
                continue  # tolerate blank / trailing empty lines
            report_lines.append(check_answer(json.loads(raw), format_type))

    corrects = [record["correct"] for record in report_lines]
    formats = [record["format"] for record in report_lines]
    answer_parsed = [
        (record["answer"] is not None) and record["answer"] != ""
        for record in report_lines
    ]

    def _mean(values):
        # np.mean([]) emits a RuntimeWarning; return nan quietly for no records.
        return float(np.mean(values)) if values else float("nan")

    acc = _mean(corrects)
    well_formatted = _mean(formats)
    parsed_rate = _mean(answer_parsed)

    # splitext guarantees a distinct output path; a plain str.replace on
    # '.jsonl' would overwrite the input when the suffix is absent.
    root, ext = os.path.splitext(args.file)
    report_path = f"{root}_report{ext}"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(json.dumps({"Accuracy": acc, "Well-formatted": well_formatted, "Answer-Parsed": parsed_rate}) + "\n")
        for line in report_lines:
            f.write(json.dumps(line) + "\n")

    print(f"Well-formatted: {well_formatted:.4f}")
    print(f"Answer-parsed: {parsed_rate:.4f}")
    print(f"Accuracy: {acc:.4f}")
            
            
def check_answer(line, format_type, file_name=None):
    """Score one model output against its gold label.

    Args:
        line: record with at least "output" and "label" keys.
        format_type: one of "json", "nl" or "no-format".
        file_name: path of the file being scored. It is only consulted for
            "no-format", to choose structured vs. unstructured parsing. When
            omitted, the module-level ``args.file`` (bound by the ``__main__``
            entry point) is used, which preserves the original behaviour.

    Returns:
        dict with "format" (well-formatted flag), "correct" (label equals the
        extracted answer), "label" and "answer" (extracted answer or None).

    Raises:
        ValueError: if ``format_type`` is not one of the supported values.
    """
    if format_type == "json":
        well_formatted, answer_str = check_json_answer(line)
    elif format_type == "nl":
        well_formatted, answer_str = check_nl_answer(line)
    elif format_type == "no-format":
        if file_name is None:
            file_name = args.file  # global bound in the __main__ block
        if ("-structured-" in file_name) or ("-decog-" in file_name) or ("0shot" not in file_name):
            well_formatted, answer_str = check_structured_answer(line)
        else:
            well_formatted, answer_str = check_unstructured_answer(line)
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(f"Unknown format_type: {format_type!r}")
    correct = line["label"] == answer_str
    return {"format": well_formatted, "correct": correct, "label": line["label"], "answer": answer_str}

def check_nl_answer(line):
    """Extract the answer that follows an "...answer is" phrase in free text.

    The markers omit the leading "T"/"t" so that both "The answer is" and
    "the answer is" match. "he final answer is" is tried before "he answer is".
    For the first marker present, the text after its first occurrence is passed
    to ``parse_and_clean``.

    Returns:
        (True, cleaned_answer) if a marker is found, otherwise (False, None).
    """
    text = line["output"]
    for marker in ("he final answer is", "he answer is"):
        start = text.find(marker)
        if start == -1:
            continue
        return True, parse_and_clean(text[start + len(marker):])
    return False, None

def check_structured_answer(line):
    """Extract the answer from a structured (step-by-step) free-text output.

    The extraction rule is identical to ``check_nl_answer``. It looks for
    "he final answer is", then "he answer is" (matching both "The" and "the"),
    and parses the text after the first match. This function therefore
    delegates instead of keeping a second copy of the same code that could
    drift.

    Returns:
        (True, cleaned_answer) on a match, otherwise (False, None).
    """
    return check_nl_answer(line)

# check upper bound
def check_unstructured_answer(line):
    label = line["label"]
    last_sentence = line["output"].split('\n')[-1]
    if label in last_sentence:
        return True, label
    # add comma for every 3 digits
    for pos in range(3, 999, 4):
        if pos >= len(label):
            break
        label = label[:-pos] + "," + label[-pos:]
    if label in last_sentence:
        return True, line["label"]
    return True, line["output"].split('\n')[-1]
    
def check_json_answer(line):
    """Validate a JSON-formatted output and extract its "answer" field.

    If the output contains a ``` code fence (optionally ```json), only the text
    inside the first fence is parsed. The output is well-formatted when it
    parses to a JSON object that has an "answer" key and either a "reason" or a
    "step_by_step_reasoning" key.

    Returns:
        (True, str(answer)) when well-formatted, otherwise (False, None).
    """
    output = line["output"]
    if '```' in output:
        output = output.replace('```json', '```').split('```')[1]
    try:
        parsed_json = json.loads(output)
    except json.JSONDecodeError:
        return False, None
    # A bare scalar/array (e.g. the model printed just `42`) is not the required
    # object. Without this check `"reason" in 42` raised TypeError, and a list
    # or string containing the keys crashed at parsed_json["answer"].
    if not isinstance(parsed_json, dict):
        return False, None
    has_reasoning = "reason" in parsed_json or "step_by_step_reasoning" in parsed_json
    if not has_reasoning or "answer" not in parsed_json:
        return False, None
    return True, str(parsed_json["answer"])

def parse_and_clean(input_string):
    """Pull the leading number out of the text that follows an answer phrase.

    Steps:
      1. Drop everything up to and including the first ':' (if any).
      2. Drop everything up to and including the last '=' (if any), so that an
         equation such as "3 + 4 = 7" yields its right-hand side.
      3. Skip to the first digit. A '-' immediately before it is kept as the
         sign, so negative answers are not turned positive.
      4. Keep the following run of digits, commas, quotes, spaces and '$', then
         remove the commas, quotes, spaces and '$'.

    Returns:
        The cleaned number string, or "" when no digit is found.
    """
    input_string = input_string[input_string.find(':') + len(':'):]
    input_string = input_string[input_string.rfind('=') + len('='):]

    sign = ""
    for idx, char in enumerate(input_string):
        if char.isdigit():
            if idx > 0 and input_string[idx - 1] == "-":
                sign = "-"
            input_string = input_string[idx:]
            break

    valid_chars = set("0123456789,'\" $")
    result = []
    for char in input_string:
        if char not in valid_chars:
            break
        result.append(char)
    # One translate pass deletes all separator/quote/currency characters.
    cleaned_string = "".join(result).translate(str.maketrans("", "", "\"' ,$"))
    return sign + cleaned_string

def parse_args():
    """Parse command-line options.

    Returns:
        argparse.Namespace with ``file`` (path to the JSONL results file to
        score) and ``feature`` (always True).
    """
    parser = argparse.ArgumentParser(description="Score model outputs in a JSONL results file.")
    # Required: the old default "" only failed later with a confusing
    # FileNotFoundError / unknown-format error.
    parser.add_argument(
        '--file',
        type=str,
        required=True,
        help="JSONL file of model outputs; its path must contain '-nl-', '-no-format-' or '-json-'.",
    )
    # NOTE(review): `feature` is not read anywhere in this file; kept so the
    # returned namespace is unchanged for any external caller.
    parser.set_defaults(feature=True)
    return parser.parse_args()

if __name__ == "__main__":
    # `args` is deliberately bound at module level rather than passed only to
    # main(): check_answer reads `args.file` as a global for the "no-format"
    # format type, so do not move this into a local scope.
    args = parse_args()
    main(args)