import argparse
import json
from typing import List
from parse_tree import Parse_Tree
from tqdm import tqdm


def main(file1: str, file2: str) -> None:
    """Print the average normalized tree-kernel distance between paired trees.

    Trees are loaded from both files with ``load_trees_from_file`` and paired
    positionally: the i-th tree of ``file1`` is compared with the i-th tree of
    ``file2``. If the files hold different numbers of trees, the surplus trees
    of the longer file are ignored (``zip`` stops at the shorter input).

    Args:
        file1: Path of the first JSON file of tree records.
        file2: Path of the second JSON file of tree records.

    Raises:
        ValueError: If there is no pair of trees to compare, i.e. at least one
            of the files contains no trees.
    """
    # Load the parse trees from the two files
    trees1 = load_trees_from_file(file1)
    trees2 = load_trees_from_file(file2)

    # zip() stops at the shorter list, so size the progress bar to the number
    # of pairs that will actually be processed.
    num_pairs = min(len(trees1), len(trees2))
    if num_pairs == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(
            f"No tree pairs to compare: {file1!r} has {len(trees1)} tree(s), "
            f"{file2!r} has {len(trees2)} tree(s)"
        )

    # Compute the normalized tree kernel distance between each pair of trees.
    # NOTE(review): lambda_param=0.5 is hard-coded; its meaning is defined by
    # Parse_Tree.normalize_tree_kernel_distance - confirm there before tuning.
    distances = [
        Parse_Tree.normalize_tree_kernel_distance(tree1, tree2, lambda_param=0.5)
        for tree1, tree2 in tqdm(zip(trees1, trees2), total=num_pairs)
    ]

    # Output the average distance
    avg_distance = sum(distances) / len(distances)
    print(f"The average tree kernel distance between the two sets of trees is {avg_distance:.4f}")


def load_trees_from_file(filename: str) -> List[Parse_Tree]:
    """Load a JSON file of records and build one Parse_Tree per record.

    The file must contain a JSON list of dictionaries. Each dictionary needs a
    ``'path_dict'`` entry, which is handed to ``Parse_Tree.from_dict``, and
    either a ``'pred'`` or a ``'target'`` entry, whose value is stored on the
    tree's ``name`` attribute. ``'pred'`` takes precedence when both exist.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The constructed trees, in the same order as the records in the file.

    Raises:
        KeyError: If a record has no ``'path_dict'`` entry, or has neither a
            ``'pred'`` nor a ``'target'`` entry.
    """
    # JSON text is UTF-8 by specification; pass the encoding explicitly so the
    # result does not depend on the platform's locale default.
    with open(filename, "r", encoding="utf-8") as file:
        records = json.load(file)

    # Convert each dictionary to a Parse_Tree object
    trees = []
    for record in records:
        tree = Parse_Tree.from_dict(record["path_dict"])
        # Prefer the 'pred' label; fall back to 'target' otherwise.
        tree.name = record["pred"] if "pred" in record else record["target"]
        trees.append(tree)

    return trees


if __name__ == "__main__":
    # Command-line entry point: two positional arguments, each naming a file
    # that contains a list of dictionaries, are parsed and handed to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "file1",
        type=str,
        help="The name of the first file containing a list of dictionaries",
    )
    arg_parser.add_argument(
        "file2",
        type=str,
        help="The name of the second file containing a list of dictionaries",
    )
    cli_args = arg_parser.parse_args()

    # Run the comparison on the two files given on the command line
    main(file1=cli_args.file1, file2=cli_args.file2)