import argparse
import glob
import json
import math
import os

def combine_metric_files(metric_files_dir):
    """Merge every ``*.json`` metric file at a location into one mapping.

    Each file must contain a JSON list of records, and each record must
    provide an ``"id"`` entry and an ``"ls"`` entry.
    NOTE(review): callers in this file treat the ``"ls"`` value as a
    perplexity -- confirm that this is the key the metric files really use.

    Args:
        metric_files_dir: Path of the directory holding the JSON files; a
            trailing path separator is optional. If the string is not an
            existing directory it is used as a raw prefix to which
            ``"*.json"`` is appended.

    Returns:
        A dict mapping each record's ``"id"`` to its ``"ls"`` value. When an
        id occurs more than once, the record from the file that sorts last
        wins.

    Raises:
        KeyError: If a record lacks an ``"id"`` or ``"ls"`` entry.
    """
    if os.path.isdir(metric_files_dir):
        # Joining makes "dir" and "dir/" behave the same way.
        pattern = os.path.join(metric_files_dir, "*.json")
    else:
        pattern = metric_files_dir + "*.json"

    all_data = {}
    # glob returns paths in arbitrary order; sort so duplicate ids resolve
    # deterministically.
    for path in sorted(glob.glob(pattern)):
        with open(path, "r", encoding="utf-8") as f:
            for record in json.load(f):
                # The key is the string "id", not the builtin id() function.
                all_data[record["id"]] = record["ls"]
    return all_data

def compute_learnability_scores(args):
    """Compute per-sample learnability scores and write them to a JSON file.

    For every id found in the fine-tuned values the score is
    ``1 - log(fine_tuned) / log(base)``, where ``base`` is the value stored
    for the same id in the pre-trained values. The results are written to
    ``args.save_path`` as a JSON list of ``{"id": ..., "ls": ...}`` objects,
    sorted by score from highest to lowest.

    Args:
        args: Object exposing ``fine_tuned_vals`` and ``pre_trained_vals``
            (both passed to ``combine_metric_files``) and ``save_path`` (the
            output file path).

    Raises:
        KeyError: If an id in the fine-tuned values has no pre-trained value.
        ZeroDivisionError: If a pre-trained value is exactly 1, because its
            logarithm is 0 and the score is undefined.
        ValueError: Raised by ``math.log`` for a non-positive value.
    """
    fine_tuned_ppls_dict = combine_metric_files(args.fine_tuned_vals)
    base_ppls_dict = combine_metric_files(args.pre_trained_vals)

    # Fail early with context instead of a bare KeyError from inside the loop.
    missing_ids = [key for key in fine_tuned_ppls_dict if key not in base_ppls_dict]
    if missing_ids:
        raise KeyError(
            f"{len(missing_ids)} id(s) have no pre-trained value, "
            f"e.g. {missing_ids[0]!r}"
        )

    learnability_scores = []
    for sample_id, fine_tuned_ppl in fine_tuned_ppls_dict.items():
        base_log = math.log(base_ppls_dict[sample_id])
        if base_log == 0:
            # Same exception type the bare division would raise, but it
            # names the offending sample.
            raise ZeroDivisionError(
                f"pre-trained value for id {sample_id!r} is 1, "
                "so its log is 0 and the score is undefined"
            )
        learnability_score = 1 - (math.log(fine_tuned_ppl) / base_log)
        learnability_scores.append({"id": sample_id, "ls": learnability_score})

    learnability_scores.sort(key=lambda entry: entry["ls"], reverse=True)
    with open(args.save_path, "w", encoding="utf-8") as f:
        json.dump(learnability_scores, f, indent=4)

if __name__ == "__main__":
    # Command-line entry point: every option is a required string.
    parser = argparse.ArgumentParser()

    # The two input directories differ only in flag name and help text.
    directory_options = (
        ("--pre_trained_vals",
         "Path of the directory containing pre-trained ppl vals"),
        ("--fine_tuned_vals",
         "Path of the directory containing fine-tuned ppl vals"),
    )
    for flag, description in directory_options:
        parser.add_argument(flag, type=str, required=True, help=description)

    # Output file that receives the sorted scores.
    parser.add_argument("--save_path", type=str, required=True)

    args = parser.parse_args()
    compute_learnability_scores(args)