import os
import json
import tqdm
import torch
import re
from transformers import BertTokenizer, BertModel
from scipy.spatial.distance import cosine
import matplotlib.pyplot as plt
import argparse

# Math-pretrained BERT checkpoint; main() uses this tokenizer/model pair to
# embed problem texts for the semantic score.
_BERT_CHECKPOINT = 'AnReu/math_pretrained_bert'
tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)
model = BertModel.from_pretrained(_BERT_CHECKPOINT)

def check_labeled(file_path, k=10):
    """Return True if every candidate ``a_0`` .. ``a_{k-1}`` in a JSON file is labeled.

    A candidate counts as labeled when its entry carries both a ``'label'``
    and a ``'syntax'`` field. Candidates missing from the file count as
    unlabeled instead of raising ``KeyError``.

    Args:
        file_path: Path to the JSON record.
        k: Number of candidates to inspect (default 10).

    Returns:
        True when all ``k`` candidates are labeled, else False.

    Raises:
        ValueError or TypeError: if a ``label``/``syntax`` value cannot be
            converted to ``int``.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    labels = []
    for i in range(k):
        entry = data.get(f"a_{i}", {})
        if 'label' in entry and 'syntax' in entry:
            # The product is not used further; the int() conversions keep
            # malformed label values from passing silently.
            labels.append(int(entry['label']) * int(entry['syntax']))
    # Threshold tied to k so the two can never drift apart.
    return len(labels) >= k

def get_json_files(root_dir):
    """Recursively collect the paths of all ``*.json`` files under ``root_dir``.

    Paths are returned in ``os.walk`` order, each joined onto the directory
    it was found in. A non-existent ``root_dir`` yields an empty list.
    """
    # NOTE(review): no check_labeled() filtering is applied; every .json file
    # is returned.
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.endswith('.json')
    ]

def _natural_problem_text(data):
    """Return the reference problem text, including its final answer, for one record.

    Records with ``informal_statement`` have a trailing "Show that it is <N>."
    rewritten to "The final Answer is $<N>$". Other records are built from
    ``natural problem`` and ``natural answer``.
    """
    if "informal_statement" in data:
        natural_problem = data['informal_statement']
        # NOTE(review): '.' does not match newlines here, so a multi-line
        # statement is left unrewritten -- confirm that is intended.
        match = re.match(r'(.*) Show that it is (\d+)\.', natural_problem)
        if match:
            informal_statement, number = match.group(1), match.group(2)
            natural_problem = f"{informal_statement} The final Answer is ${number}$"
        return natural_problem
    return data['natural problem'] + " The final Answer is $" + data['natural answer'] + "$"


def _symbolic_score(components, index, num_candidates):
    """Return the fraction of candidates sharing a component with candidate ``index``.

    Returns 0.0 when ``index`` belongs to no component. The original code
    crashed or reused a stale score in that case.
    """
    for component in components:
        if index in component:
            return len(component) / float(num_candidates)
    return 0.0


def _semantic_score(text_a, text_b):
    """Return 1 - cosine distance between the BERT pooler embeddings of two texts."""
    inputs = tokenizer([text_a, text_b], padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        embeddings = model(**inputs, return_dict=True).pooler_output
    # float() yields a plain Python float; it serializes identically to numpy float64.
    return float(1 - cosine(embeddings[0], embeddings[1]))


def _write_json_atomic(path, data):
    """Write ``data`` as indented JSON to ``path`` through a sibling temp file.

    ``os.replace`` swaps the file in only after a complete write, so an
    interrupted run cannot leave ``path`` truncated.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    os.replace(tmp_path, path)


def main(json_file_paths, suffix="gpt3.5", num_candidates=10):
    """Annotate each candidate formalization with a semantic and a symbolic score.

    For every JSON record:
    - The semantic score is the BERT similarity between the reference problem
      text and each candidate's ``'informal problem'``.
    - The symbolic score is the size of the candidate's component in
      ``prediction_{suffix}``, divided by ``num_candidates``.

    Both scores are stored under ``a_{i}_{suffix}`` and the file is rewritten
    in place. Records with a missing or empty ``prediction_{suffix}`` are
    skipped.

    Args:
        json_file_paths: Paths of the JSON records to update.
        suffix: Model suffix selecting ``prediction_{suffix}`` and the
            ``a_{i}_{suffix}`` candidate entries.
        num_candidates: Number of candidates per record (default 10).
    """
    pred = f"prediction_{suffix}"
    for json_file in tqdm.tqdm(json_file_paths):
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if pred not in data or not data[pred]:
            print(f"skip {json_file}")
            continue
        # NOTE(review): presumably each component is a list of integer
        # candidate indices (membership is tested with an int); confirm.
        components = list(data[pred].values())
        origin_problem = _natural_problem_text(data)
        for i in range(num_candidates):
            candidate = data[f'a_{i}_{suffix}']
            # semantic_score is inserted before symbolic_score to keep the
            # existing key order in the written files.
            candidate["semantic_score"] = _semantic_score(origin_problem, candidate['informal problem'])
            candidate["symbolic_score"] = _symbolic_score(components, i, num_candidates)
        _write_json_atomic(json_file, data)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Compute semantic and symbolic scores for autoformalization candidates in JSON files.")
    parser.add_argument("--root_dir_list", default="../dataset/miniF2F/informal/", type=str, help="Comma-separated list of root directories")
    parser.add_argument("--suffixes", default="gpt3.5,deepseek", type=str,
                        help="Comma-separated list of model suffixes to score, processed in order")
    args = parser.parse_args()
    # Strip whitespace and drop empty entries so inputs like "a, b," work.
    root_dir_list = [p.strip() for p in args.root_dir_list.split(",") if p.strip()]
    suffixes = [s.strip() for s in args.suffixes.split(",") if s.strip()]
    json_file_paths = []
    for path in root_dir_list:
        json_file_paths += get_json_files(path)
    print(f"Totally have {len(json_file_paths)} to do in {root_dir_list}")
    for suffix in suffixes:
        main(json_file_paths, suffix=suffix)
