import argparse
import glob
import json
import os

def combine_metric_files(metric_files_dir):
    """Load every ``*.json`` file in a directory and concatenate their contents.

    Args:
        metric_files_dir: Path to the directory holding the metric files,
            with or without a trailing path separator.

    Returns:
        One list containing the items of every file, with files visited in
        sorted path order. Each file's parsed content is appended with
        ``list.extend``, so files are expected to hold a JSON array.
    """
    # os.path.join adds the separator only when it is missing; plain string
    # concatenation turned "metrics" into the pattern "metrics*.json".
    pattern = os.path.join(metric_files_dir, "*.json")

    all_data = []
    # glob yields paths in arbitrary order; sort so the result is reproducible.
    for path in sorted(glob.glob(pattern)):
        with open(path, "r", encoding="utf-8") as f:
            all_data.extend(json.load(f))
    return all_data
    
def get_subset(args):
    """Select a percentage of the training data by a metric and save it as JSON.

    Args:
        args: Namespace with the attributes ``metric`` ("ppl", "mir" or
            "learnability"), ``percent`` (integer in [0, 100]), ``data_file``,
            ``save_path``, and the metric source: ``metric_files_dir`` for
            "ppl"/"mir", ``metric_file`` for "learnability".

    Raises:
        ValueError: If the metric is unknown, ``percent`` is outside
            [0, 100], or the metric source needed by the metric is ``None``.

    The selected entries of ``data_file`` are written to ``save_path``;
    nothing is returned.
    """
    if args.metric not in ("ppl", "mir", "learnability"):
        raise ValueError(f"{args.metric} metric is not implemented!")
    # A percent above 100 would make the slice start negative and silently
    # select an unrelated tail of the metric list.
    if not 0 <= args.percent <= 100:
        raise ValueError(f"args.percent must be between 0 and 100, got {args.percent}")

    with open(args.data_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    if args.metric in ("ppl", "mir"):
        if args.metric_files_dir is None:
            raise ValueError("If args.metric == 'ppl' or 'mir' then args.metric_files_dir cannot be None")

        combined_metric_data = combine_metric_files(args.metric_files_dir)
        metric_data = sorted(combined_metric_data, key=lambda x: x[args.metric], reverse=True)
    else:
        if args.metric_file is None:
            raise ValueError("If args.metric == 'learnability' then args.metric_file cannot be None")

        # The learnability file is used in the order it is stored.
        with open(args.metric_file, "r", encoding="utf-8") as f:
            metric_data = json.load(f)

    # Keep the trailing `percent`% of the metric entries.
    # NOTE(review): the original comments said "highest ppl/mir", but the
    # ppl/mir list is sorted in descending order, so its tail holds the
    # lowest-valued entries. Behaviour is kept as it was -- confirm the intent.
    start = (100 - args.percent) * len(metric_data) // 100
    # Metric ids look like "<data id>::<suffix>"; only the part before "::" is
    # matched against the data ids. A set keeps each lookup O(1).
    relevant_ids = {entry['id'].split("::")[0] for entry in metric_data[start:]}
    relevant_data = [item for item in data if item['id'] in relevant_ids]

    with open(args.save_path, "w", encoding="utf-8") as f:
        json.dump(relevant_data, f, indent=4)

if __name__ == "__main__":
    # Command-line entry point: parse the options and write the selected subset.
    parser = argparse.ArgumentParser()
    parser.add_argument("--metric", type=str, default="ppl", choices=['ppl', 'mir', 'learnability'])

    # Only one of the two metric sources is used, depending on --metric, so
    # neither can be unconditionally required. get_subset raises ValueError
    # when the one it needs is missing.
    parser.add_argument("--metric_file", type=str, default=None,
            help='Path to the file containing per datapoint metrics '
                 '(needed when --metric is learnability)')
    parser.add_argument("--metric_files_dir", type=str, default=None,
            help='Path to the directory containing per datapoint metrics files '
                 '(needed when --metric is ppl or mir)')
    parser.add_argument("--percent", type=int, required=True,
            help='Percent of required data')
    parser.add_argument("--data_file", type=str, required=True,
            help='Path to training data')
    parser.add_argument("--save_path", type=str, required=True)
    args = parser.parse_args()

    get_subset(args)