import os
import ujson
import argparse
import numpy as np
from tqdm import tqdm

# Hard-negative rank window (M = 15): ranks 1 through 15 inclusive.
NEGATIVE_INDICES = range(1, 16)

def load_jsonl(file_path: str) -> list:
    """Read a JSON Lines file and return its parsed records.

    Args:
        file_path: Path of the file to read; each non-blank line is parsed
            with ``ujson.loads``.

    Returns:
        A list with one parsed value per non-blank line, in file order.
    """
    # The context manager guarantees the handle is closed, even if a line
    # fails to parse part-way through the file.
    with open(file_path, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing newline at end of file) carry no record.
        return [ujson.loads(line) for line in f if line.strip()]
def save_jsonl(data: list[dict], file_path: str) -> None:
    """Write ``data`` to ``file_path`` as JSON Lines, one item per line.

    Args:
        data: Items to serialize; each is written with ``ujson.dump``.
        file_path: Destination path; an existing file is overwritten.
    """
    print("Saving file")
    # The ``with`` block closes the file on exit, so no explicit close()
    # is needed afterwards.
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in tqdm(data):
            ujson.dump(item, f)
            f.write('\n')

def filter_dataset(dataset: list[dict], *, max_positive_rank: int = 20,
                   negative_ranks=None) -> list[dict]:
    """Drop inconsistent examples and trim negatives to the accepted ranks.

    An item is kept only if its ``'positive_code_rank'`` lies in
    ``[0, max_positive_rank]``. For every kept item, the positions of
    ``'negative_code_rank'`` whose rank is in ``negative_ranks`` are selected,
    and every value whose key contains ``"negative_"`` (including
    ``'negative_code_rank'`` itself) is reduced to those positions.

    Kept items are modified in place; the returned list holds the same dict
    objects that were passed in.

    Args:
        dataset: Items with ``'positive_code_rank'`` and
            ``'negative_code_rank'`` entries.
        max_positive_rank: Largest accepted positive rank (K, default 20).
        negative_ranks: Container of accepted negative ranks; defaults to
            the module-level ``NEGATIVE_INDICES``.

    Returns:
        The list of kept items, in input order.

    Raises:
        TypeError: If a ``"negative_"`` value is not a list or tuple.
    """
    if negative_ranks is None:
        negative_ranks = NEGATIVE_INDICES

    ret = []
    for item in tqdm(dataset):
        positive_code_rank = item['positive_code_rank']

        # Consistency filtering (K = 20 by default)
        if positive_code_rank > max_positive_rank or positive_code_rank < 0:
            continue

        # Positions of the negatives whose rank falls in the accepted window.
        # Computed before the loop below rewrites 'negative_code_rank'.
        negative_pass_idx = [
            i for i, rank in enumerate(item['negative_code_rank'])
            if rank in negative_ranks
        ]

        # Only existing keys are reassigned, so the dict does not change size
        # while it is being iterated.
        for k, v in item.items():
            if "negative_" not in k:
                continue
            if not isinstance(v, (list, tuple)):
                raise TypeError(
                    f"Expected a list for key {k!r}, got {type(v).__name__}"
                )
            # Plain list indexing keeps each element exactly as loaded; a
            # NumPy round-trip would coerce mixed element types and reject
            # ragged nested lists.
            item[k] = [v[i] for i in negative_pass_idx]
        ret.append(item)
    return ret

def main(args):
    """Load the dataset, filter it, and save the result.

    Args:
        args: Object with ``file_path`` (input JSONL) and ``save_file_name``
            (output JSONL) attributes.

    Raises:
        FileNotFoundError: If ``args.file_path`` does not exist.
    """
    # Fail early with a specific error type; FileNotFoundError is still an
    # Exception subclass, so broad handlers keep working.
    if not os.path.exists(args.file_path):
        raise FileNotFoundError(f"Path {args.file_path} doesn't exist. ")

    dataset = load_jsonl(args.file_path)
    filtered_dataset = filter_dataset(dataset)
    print(f"Size of filtered dataset: {len(filtered_dataset)}")

    print(f"Finished filtering all shard. Saving to {args.save_file_name}")
    save_jsonl(filtered_dataset, args.save_file_name)


if __name__ == '__main__':
    # Script entry point: both paths are optional flags with defaults.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ('--file_path', "repo_contrastive_mined.jsonl"),
        ('--save_file_name', "repo_contrastive_mined_filtered.jsonl"),
    ):
        cli.add_argument(flag, type=str, default=fallback)

    main(cli.parse_args())