import multiprocessing
import json
def bleu_worker(args):
    """Score a single generated sample and serialise the result as one JSON line.

    The inputs arrive packed in one tuple (presumably so the function can be
    handed to a pool-style ``map`` that passes a single item per call -- confirm
    against callers). The line is returned rather than written, so the caller
    decides how and where to persist it.

    Args:
        args: A ``(scorer, output_dict)`` pair. ``scorer`` must provide
            ``compute(predictions=..., references=...)`` returning a mapping
            with a ``"score"`` entry. ``output_dict`` must contain the keys
            ``"display_output"``, ``"reference"``, ``"id"``,
            ``"reference_id"`` and ``"watermark_processor"``.

    Returns:
        A newline-terminated JSON object string with the keys ``"id"``,
        ``"reference_id"``, ``"watermark_processor"`` and ``"bleu_score"``.
    """
    scorer, sample = args

    # The scorer is called with one-element lists: one prediction, one reference.
    result = scorer.compute(
        predictions=[sample["display_output"]],
        references=[sample["reference"]],
    )

    # Copy the identifying fields through unchanged, then append the score.
    record = {
        key: sample[key]
        for key in ("id", "reference_id", "watermark_processor")
    }
    record["bleu_score"] = result["score"]
    return json.dumps(record) + "\n"




def compute_bleu(output_path, bleu_save_path):
    """Compute a per-sample BLEU score for generated outputs and save them as JSON lines.

    The generation results are loaded from ``output_path``, sorted by ``"id"``
    and split by their ``"watermark_processor"`` value. Each split gets the
    ``"reference"`` column of the input dataset attached, every row is scored
    with the ``sacrebleu`` metric via ``bleu_worker``, and one JSON object per
    row is written to ``bleu_save_path``.

    Args:
        output_path: Data file(s) passed to ``datasets.load_dataset("json", ...)``
            as the ``"test"`` split. Rows must provide ``"id"``,
            ``"reference_id"``, ``"watermark_processor"`` and
            ``"display_output"``.
        bleu_save_path: Destination file; it is opened in ``"w"`` mode, so any
            existing content is overwritten.

    Raises:
        AssertionError: If a watermark-processor split does not have exactly
            as many rows as the input dataset.
    """
    # Imports are function-local, so importing this module does not require
    # the package siblings or the heavy third-party libraries.
    # get_in_ds is the alternative input dataset; it is currently not used.
    from . import get_in_ds, get_in_ds_undetectable_exp
    from datasets import load_dataset

    in_ds = get_in_ds_undetectable_exp()

    out_ds = load_dataset("json", data_files={"test": output_path})["test"]
    out_ds = out_ds.sort("id")

    # Loop-invariant values, fetched once instead of once per processor type.
    expected_len = len(in_ds)
    references = in_ds["reference"]

    # All splits are built and validated before the output file is opened, so
    # a size mismatch never truncates an existing result file.
    s_out_dss = {}
    for wp_type in set(out_ds["watermark_processor"]):
        s_out_ds = out_ds.filter(lambda x: x["watermark_processor"] == wp_type)
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is kept for backward compatibility.
        if len(s_out_ds) != expected_len:
            raise AssertionError(
                f"watermark_processor {wp_type!r} has {len(s_out_ds)} outputs "
                f"but the input dataset has {expected_len} rows"
            )
        # NOTE(review): references are paired with outputs purely by position;
        # this assumes in_ds rows are in the same order as the id-sorted
        # outputs -- confirm.
        s_out_dss[wp_type] = s_out_ds.add_column("reference", references)

    import evaluate
    import tqdm

    bleu_scorer = evaluate.load("sacrebleu")

    with open(bleu_save_path, "w") as f:
        for s_out_ds in s_out_dss.values():
            # Iterating the dataset yields the same row dicts as integer
            # indexing, without the repeated lookup.
            for output_dict in tqdm.tqdm(s_out_ds):
                # bleu_worker returns the newline-terminated JSON line.
                f.write(bleu_worker((bleu_scorer, output_dict)))
                
                

