def get_wps(reweight_type,model_str):
    """Build the watermark logits processors for one experiment preset.

    Args:
        reweight_type: Preset name. Recognised values are ``'ITS'``, ``'EXP'``,
            ``'baselines'``, ``'main_exp'``, ``'nmark'``, ``'test'``,
            ``'GumbelMax'``, ``'SynthID'`` and the string ``'None'``.
        model_str: Model identifier used to look up the ``N_Reweight`` sizes.
            It is only read for the ``'main_exp'`` and ``'nmark'`` presets.

    Returns:
        A list of processors. ``'None'`` yields ``[None]`` (generate without a
        watermark). ``'baselines'`` and ``'main_exp'`` additionally append the
        ``lm_watermarking`` processors for delta in 0.5, 1.0, 1.5 and 2.0.
        An unrecognised preset prints a message and yields an empty list.

    Raises:
        KeyError: If ``model_str`` is missing from the size table for a preset
            that needs it.

    Note:
        The global ``random`` module is reseeded with 42 on every call so the
        derived private key is the same each time.
    """
    from new_unbiased_watermark import (
        Dip_Reweight,
        EXP_edit_Reweight,
        GumbelMax_Reweight,
        ITS_edit_Reweight,
        N_Reweight,
        NGramHashing,
        PrevN_ContextCodeExtractor,
        STA_Reweight,
        SynthID_Text_Reweight,
        Unigram_Reweight,
        WatermarkLogitsProcessor,
        WatermarkLogitsProcessor_Kuditipudi_OriImplement,
    )

    from ..lm_watermarking.watermark_processor import (
        WatermarkLogitsProcessor as WatermarkLogitsProcessor_John,
    )

    import copy
    import random

    key_set_size = 512
    random.seed(42)
    private_key = random.getrandbits(1024).to_bytes(128, "big")

    # 'None' is a valid preset ("no watermark"); answer it before the dispatch
    # below so it is not reported as an unknown reweight_type.
    if reweight_type == 'None':
        return [None]

    watermark_key_list = [
        NGramHashing(PrevN_ContextCodeExtractor(2), ignore_history=False)
    ]

    # Extended sweep (disabled), kept for reference:
    # n_list={'meta-llama/Llama-2-7b-chat-hf':[2, 3, 4, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 4000, 8000, 16000, 32000],
    #         'meta-llama/Llama-3.2-3B-Instruct':[2, 3, 4, 5, 10, 20, 50, 100, 167, 501, 1002, 2004, 4008, 8016, 16032, 32064, 64128, 128256],
    #         'mistralai/Mistral-7B-Instruct-v0.3':[2, 3, 4, 5, 8, 10, 16, 20, 32, 50, 64, 100, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768],
    #         'microsoft/Phi-3.5-mini-instruct':[2, 3, 4, 5, 10, 20, 50, 100, 167, 501, 1002, 2004, 4008, 8016, 16032, 32064],
    #         'deepseek-ai/deepseek-llm-7b-chat':[2, 3, 4, 5, 10, 20, 50, 100, 200, 400, 800, 1600, 3200, 6400, 12800, 25600, 51200, 102400],
    #         'Qwen/Qwen2.5-3B-Instruct':[2, 3, 4, 5, 10, 20, 50, 100, 250, 500, 1187, 2374, 4748, 9496, 18992, 37984, 75968, 151936]}
    n_list = {
        'meta-llama/Llama-2-7b-chat-hf': [10, 20, 50, 100],
        'meta-llama/Llama-3.2-3B-Instruct': [10, 20, 50, 100],
        'mistralai/Mistral-7B-Instruct-v0.3': [10, 20, 50, 100],
        'microsoft/Phi-3.5-mini-instruct': [10, 20, 50, 100],
        'deepseek-ai/deepseek-llm-7b-chat': [10, 20, 50, 100],
        'Qwen/Qwen2.5-3B-Instruct': [10, 20, 50, 100],
    }

    def baseline_reweights():
        # Shared by the 'baselines' and 'main_exp' presets; a fresh list each call.
        return [
            Unigram_Reweight(delta=0.5, gamma=0.5),
            Unigram_Reweight(delta=1.0, gamma=0.5),
            Unigram_Reweight(delta=1.5, gamma=0.5),
            Unigram_Reweight(delta=2.0, gamma=0.5),
            Dip_Reweight(alpha=0.5),
            Dip_Reweight(alpha=0.4),
            Dip_Reweight(alpha=0.3),
            STA_Reweight(gamma=0.5),
        ]

    reweight_list_stanford = []
    reweight_list = []
    if reweight_type == 'ITS':
        reweight_list_stanford = [ITS_edit_Reweight()]
    elif reweight_type == 'EXP':
        reweight_list_stanford = [EXP_edit_Reweight()]
    elif reweight_type == 'baselines':
        reweight_list = baseline_reweights()
    elif reweight_type == 'main_exp':
        reweight_list = baseline_reweights()
        reweight_list += [N_Reweight(n) for n in n_list[model_str]]
    elif reweight_type == 'nmark':
        reweight_list = [N_Reweight(n) for n in n_list[model_str]]
    elif reweight_type == 'test':
        reweight_list = [N_Reweight(20)]
    elif reweight_type == 'GumbelMax':
        reweight_list = [
            GumbelMax_Reweight(is_baseline=True),
            GumbelMax_Reweight(is_baseline=False),
        ]
    elif reweight_type == 'SynthID':
        reweight_list = [SynthID_Text_Reweight(m=30)]
    else:
        print('Unknown reweight_type: ',reweight_type)

    # EXP_edit and ITS_edit use their own processor class.
    wm_wps = [
        WatermarkLogitsProcessor_Kuditipudi_OriImplement(
            key_set_size=key_set_size,
            reweight=reweight,
        )
        for reweight in reweight_list_stanford
    ]

    # Common WatermarkLogitsProcessor: every processor gets its own deep
    # copies of the reweight and of the watermark key.
    for wm_key in watermark_key_list:
        for reweight in reweight_list:
            wm_wps.append(
                WatermarkLogitsProcessor(
                    private_key,
                    reweight=copy.deepcopy(reweight),
                    watermark_key_list=[copy.deepcopy(wm_key)],
                )
            )

    if reweight_type in ('baselines', 'main_exp'):
        john_wps = [
            WatermarkLogitsProcessor_John(
                vocab_size=0,  # placeholder
                gamma=0.5,
                delta=delta,
                seeding_scheme="simple_1",
            )
            for delta in [0.5, 1.0, 1.5, 2.0]
        ]
        return [*wm_wps, *john_wps]
    return [*wm_wps]


def get_num_gpus():
    """Return the number of CUDA devices reported by torch."""
    from torch import cuda

    return cuda.device_count()


def batched_wp_task_worker(tq, get_in_ds, reweight_type,dataset_name,model_str,batch_size=8):
    """Enqueue one generation task per (batch, watermark processor) pair.

    Args:
        tq: Task queue; receives dicts with keys ``"batch"`` and
            ``"watermark_processor"``.
        get_in_ds: Callable invoked as ``get_in_ds(dataset_name=...)`` that
            returns the input dataset.
        reweight_type: Preset name forwarded to ``get_wps``.
        dataset_name: Forwarded to ``get_in_ds``.
        model_str: Forwarded to ``get_wps``.
        batch_size: Number of examples per batch.
    """
    ds = get_in_ds(dataset_name=dataset_name)

    from .common import get_wps

    wps = get_wps(reweight_type=reweight_type,model_str=model_str)

    from tqdm import tqdm

    # Round up so the progress bar also counts a final partial batch
    # (assumes ds is a Hugging Face Dataset, whose iter() keeps the last
    # partial batch by default — TODO confirm for other dataset types).
    num_batches = (len(ds) + batch_size - 1) // batch_size
    for batch in tqdm(ds.iter(batch_size=batch_size), total=num_batches):
        for wp in wps:
            tq.put({"batch": batch, "watermark_processor": wp})

def merged_task_worker(
    get_in_ds,
    output_filepath,
    tq,
    batch_size=8,
    watermark_only=False,
    wh_only=False,
    no_gumbel=False,
    beta_only=False,
    dataset_name=None,
    use_other_wp=False,
    reweight_type=None,
    model_str=None,
):
    """Feed previously generated outputs into *tq* in batches.

    The JSON-lines file at *output_filepath* is loaded, sorted by ``"id"`` and
    split per watermark processor with ``add_reference``.

    * ``use_other_wp=True``: every split is re-labelled with the ``repr`` of
      each processor returned by ``get_wps`` and all batches are enqueued.
    * otherwise: each split is enqueued unless one of the filter flags
      (``watermark_only``, ``wh_only``, ``no_gumbel``, ``beta_only``) rejects
      its processor string.
    """
    in_ds = get_in_ds(dataset_name=dataset_name)

    from datasets import load_dataset
    from tqdm import tqdm

    out_ds = load_dataset("json", data_files={"test": output_filepath})["test"].sort("id")
    dss, wps = add_reference(in_ds, out_ds)

    def enqueue_batches(dataset):
        # Push every batch of `dataset` onto the task queue with a progress bar.
        progress = tqdm(
            dataset.iter(batch_size=batch_size), total=len(dataset) // batch_size
        )
        for batch in progress:
            tq.put(batch)

    if use_other_wp:
        from .common import get_wps

        wps = [repr(wp) for wp in get_wps(reweight_type=reweight_type, model_str=model_str)]
        # NOTE(review): per the original author this path was only implemented
        # for outputs generated without a watermark; check before using it
        # with results generated for multiple processors.
        for ds in dss:
            for wp in wps:
                # Overwrite the processor column with the target processor's repr.
                ds = ds.map(lambda x: {"watermark_processor": wp})
                enqueue_batches(ds)
        return

    for ds, wp_str in zip(dss, wps):
        rejected = (
            (watermark_only and ("John" in wp_str or wp_str == "None"))
            or (wh_only and ", True)" in wp_str)
            or (no_gumbel and "Gumbel" in wp_str)
            or (beta_only and "Beta_Reweight" not in wp_str)
        )
        if rejected:
            continue
        enqueue_batches(ds)


def log(line: dict, f):
    """Append *line* to *f* as one JSON document followed by a newline, then flush."""
    from json import dumps

    f.write(dumps(line))
    f.write("\n")
    f.flush()


def simple_store_worker(path, rq, rqe):
    """Drain the result queue *rq* into a JSON-lines file at *path*.

    The loop polls *rq* with a one-second timeout and stops once *rqe* is set
    and the queue is empty. Each queued item must be a dict:

    * ``{}`` is ignored;
    * if the first value is a list, every value must be a list of the same
      length and one line is written per index (column-wise batch);
    * otherwise the dict is written as a single line.

    Args:
        path: Output file; it is truncated on open. Missing parent
            directories are created.
        rq: Queue of result dicts.
        rqe: Event signalling that no more results will be produced.
    """
    import os
    from queue import Empty

    # dirname is '' for a bare filename, and os.makedirs('') raises; only
    # create the directory when there is one. exist_ok avoids a
    # check-then-create race between workers.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(path, "w", encoding="utf-8") as f:
        while not (rqe.is_set() and rq.empty()):
            try:
                result = rq.get(timeout=1)
            except Empty:
                continue
            assert isinstance(result, dict)
            if result == {}:
                continue
            columns = list(result.values())
            if isinstance(columns[0], list):
                assert all(isinstance(col, list) for col in columns)
                n_rows = len(columns[0])
                assert all(len(col) == n_rows for col in columns)
                for i in range(n_rows):
                    log({k: v[i] for k, v in result.items()}, f)
            else:
                log(result, f)


def group_batch(batch):
    """Return a copy of *batch* with every value wrapped in a one-element list."""
    grouped = {}
    for key, value in batch.items():
        grouped[key] = [value]
    return grouped


from typing import Union


def tokenize_batch(
    example,
    tokenizer,
    fields=["input"],
    max_length: Union[int, dict] = 512,
    task_template: Union[str, dict] = {},
    padding_side={},
):
    """Tokenize the requested fields of a batch.

    Args:
        example: Mapping from field name to a list of strings.
        tokenizer: Tokenizer object; it is called with keyword arguments and
            may be mutated (``tgt_lang`` / ``pad_token``) depending on its
            ``name_or_path``.
        fields: Field names to tokenize; fields missing from *example* are
            skipped.
        max_length: One limit for all fields, or a per-field dict.
        task_template: Per-field ``str.format`` templates; a bare string is
            treated as the template for ``"input"``.
        padding_side: Optional per-field ``padding_side`` override.

    Returns:
        Dict mapping each processed field to the tokenizer's return value
        (called with ``truncation=True, padding=True, return_tensors="pt"``).
    """
    # FIXME: repair the task template
    if isinstance(task_template, str):
        #  like "{input}"
        task_template = {"input": task_template}

    model_name = tokenizer.name_or_path
    if model_name == "facebook/mbart-large-en-ro":
        tokenizer.tgt_lang = "ro_RO"

    # For these checkpoints the EOS token is reused as the padding token.
    eos_padded_tags = (
        "Llama-2-7b-chat-hf",
        "Llama-3.2-3B-Instruct",
        "Mistral-7B-Instruct-v0.3",
    )
    if model_name == "daryl149/llama-2-7b-chat-hf" or any(
        tag in model_name for tag in eos_padded_tags
    ):
        tokenizer.pad_token = tokenizer.eos_token

    # For these checkpoints each "input" is wrapped as a single user turn.
    chat_tags = (
        "Llama-3.2-3B-Instruct",
        "Mistral-7B-Instruct-v0.3",
        'Qwen/Qwen2.5-3B-Instruct',
    )
    wrap_with_chat_template = any(tag in model_name for tag in chat_tags)

    encoded = {}
    for field in fields:
        if field not in example:
            continue

        call_args = {
            "max_length": max_length[field] if isinstance(max_length, dict) else max_length
        }

        if field in task_template:
            template = task_template[field]
            texts = [template.format(**{field: s}) for s in example[field]]
        else:
            texts = example[field]

        # Target-side fields go through `text_target`, everything else through `text`.
        text_key = "text_target" if field in ["output", "reference"] else "text"
        call_args[text_key] = texts

        if field == "output":
            call_args["add_special_tokens"] = False
        if field in padding_side:
            call_args["padding_side"] = padding_side[field]

        if field == 'input' and wrap_with_chat_template:
            call_args["text"] = [
                tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    add_generation_prompt=True,
                    tokenize=False,
                )
                for prompt in call_args["text"]
            ]

        encoded[field] = tokenizer(
            **call_args,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )

    return encoded


def set_spawn():
    """Select the "spawn" multiprocessing start method.

    A ``RuntimeError`` from ``set_start_method`` (raised, for example, when a
    start method was already chosen) is ignored.
    """
    from contextlib import suppress

    import torch.multiprocessing as mp

    with suppress(RuntimeError):
        mp.set_start_method("spawn")


def remove_tailing_pad_s(s: str):
    """Strip every trailing ``"<pad>"`` from *s*.

    Only pads at the very end are removed; a ``"<pad>"`` followed by other
    text is kept.
    """
    pad = "<pad>"
    stripped = s
    while stripped.endswith(pad):
        stripped = stripped[: len(stripped) - len(pad)]
    return stripped


def remove_tailing_pad(strs: list[str]):
    """Apply ``remove_tailing_pad_s`` to every string and return the new list."""
    return list(map(remove_tailing_pad_s, strs))


def transformer_worker(
    tq,
    tqe,
    rq,
    gpu_id,
    model_str,
    generation_kwargs={},
    decoder_only=False,
    tokenization_kwargs={},
    access_token=None,
):
    """Generation worker: sample completions for queued batches on one GPU.

    The model is loaded onto ``cuda:{gpu_id}``, patched with
    ``new_unbiased_watermark.patch_model`` and put in eval mode. The worker
    then polls *tq* (one-second timeout) until *tqe* is set and the queue is
    empty. For every task it tokenizes the batch, seeds the RNGs from the
    batch ids and the processor's ``repr``, samples with ``do_sample=True``,
    ``top_k=0``, ``top_p=1``, and puts the decoded results on *rq*.

    Args:
        tq: Task queue of dicts with keys ``"batch"`` and
            ``"watermark_processor"`` (a processor or ``None``).
        tqe: Event signalling that no more tasks will be queued.
        rq: Result queue; receives dicts with keys ``"output"``,
            ``"display_output"``, ``"output_ids"``, ``"id"``,
            ``"reference_id"`` and ``"watermark_processor"``.
        gpu_id: CUDA device index.
        model_str: Model / tokenizer identifier for ``from_pretrained``.
        generation_kwargs: Extra keyword arguments for ``model.generate``.
        decoder_only: Load a causal LM and strip the prompt tokens from the
            generated ids; otherwise load a seq2seq model.
        tokenization_kwargs: Extra keyword arguments for ``tokenize_batch``.
        access_token: Passed as ``token=`` to ``from_pretrained``.

    Note:
        SECURITY: earlier revisions embedded a Hugging Face access token as
        the default of ``access_token``. That credential has been in source
        control and should be revoked. With ``None`` the Hugging Face
        libraries are expected to fall back to the locally stored login /
        ``HF_TOKEN`` environment variable — confirm for the installed version.
    """
    import hashlib
    from queue import Empty

    from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
    from transformers import AutoTokenizer, LogitsProcessorList, set_seed

    from new_unbiased_watermark import patch_model

    print('start loading model...',flush=True)
    model_cls = AutoModelForCausalLM if decoder_only else AutoModelForSeq2SeqLM
    model = model_cls.from_pretrained(model_str, token=access_token).to(f"cuda:{gpu_id}")
    patch_model(model)
    tokenizer = AutoTokenizer.from_pretrained(model_str, token=access_token)

    model.eval()

    with torch.no_grad():
        while not (tqe.is_set() and tq.empty()):
            try:
                task = tq.get(timeout=1)
            except Empty:
                continue
            batch = task["batch"]
            tbatch = tokenize_batch(batch, tokenizer, **tokenization_kwargs)
            wp = task["watermark_processor"]
            lps = []
            if wp is not None:
                # Processors are duck-typed: only some expose these hooks.
                if "reset_watermark_key" in dir(wp):
                    batch_size = len(batch["id"])
                    wp.reset_watermark_key(batch_size)
                if "vocab_size" in dir(wp):
                    wp.vocab_size = model.config.vocab_size

                lps.append(wp)

            # For reproducibility and sufficient randomness: the seed is a
            # hash of the batch ids together with the processor's repr.
            digest = hashlib.sha256(
                (str(batch["id"]) + repr(wp)).encode("utf-8")
            ).digest()
            seed = int.from_bytes(digest, "big") % (2**32 - 1)
            set_seed(seed)

            outputs_ids = model.generate(
                tbatch["input"]["input_ids"].to(device=model.device),
                attention_mask=tbatch["input"]["attention_mask"].to(
                    device=model.device
                ),
                do_sample=True,
                num_beams=1,
                top_k=0,  # default
                top_p=1,
                logits_warper=LogitsProcessorList(lps),
                **generation_kwargs,
            )

            if decoder_only:
                # Drop the prompt so only newly generated tokens remain.
                outputs_ids = outputs_ids[:, tbatch["input"]["input_ids"].shape[1] :]
            outputs = tokenizer.batch_decode(outputs_ids, skip_special_tokens=False)
            outputs = remove_tailing_pad(outputs)
            display_outputs = tokenizer.batch_decode(
                outputs_ids, skip_special_tokens=True
            )
            wp_str = repr(wp)
            rq.put(
                {
                    "output": outputs,
                    "display_output": display_outputs,
                    "output_ids": outputs_ids.tolist(),
                    "id": batch["id"],
                    "reference_id": batch["reference_id"],
                    "watermark_processor": [wp_str] * len(outputs),
                }
            )


def add_reference(in_ds, out_ds):
    """Split *out_ds* into one sub-dataset per distinct watermark processor.

    Assumes *out_ds* is ordered by ids. Despite the name, no input/reference
    columns are attached here and *in_ds* is currently unused.

    Args:
        in_ds: Input dataset (unused).
        out_ds: Dataset-like object supporting ``out_ds["watermark_processor"]``
            and ``out_ds.filter(fn)``.

    Returns:
        ``(splits, wp_types)`` where ``wp_types`` is the set of distinct
        ``"watermark_processor"`` values and ``splits`` holds one filtered
        dataset per value, in the iteration order of ``wp_types``.
    """
    wp_types = set(out_ds["watermark_processor"])

    s_out_dss = []
    for wp_type in wp_types:
        # filter() is evaluated immediately in the tested case, so the lambda
        # sees the current wp_type.
        s_out_dss.append(
            out_ds.filter(lambda x: x["watermark_processor"] == wp_type)
        )

    return s_out_dss, wp_types


def bertscore_worker(tq, tqe, rq, gpu_id=0):
    """Score queued batches with BERTScore and forward them to *rq*.

    A ``BERTScorer`` (``lang="de"``, baseline rescaling) is created on
    ``cuda:{gpu_id}``. Each batch's ``"display_output"`` is scored against
    its ``"reference"`` and the batch is re-emitted with
    ``bertscore.precision`` / ``bertscore.recall`` / ``bertscore.f1`` added.
    The loop ends once *tqe* is set and *tq* is empty.
    """
    import bert_score

    scorer = bert_score.BERTScorer(
        lang="de",
        rescale_with_baseline=True,
        device=f"cuda:{gpu_id}",
        use_fast_tokenizer=True,
    )

    from queue import Empty

    while True:
        if tqe.is_set() and tq.empty():
            break
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue
        precision, recall, f1 = scorer.score(
            batch["display_output"], batch["reference"]
        )
        scored = {**batch}
        scored["bertscore.precision"] = precision.tolist()
        scored["bertscore.recall"] = recall.tolist()
        scored["bertscore.f1"] = f1.tolist()
        rq.put(scored)


def rouge_worker(tq, tqe, rq):
    """Compute per-example ROUGE-1/2/L for queued batches and forward them.

    Each batch's ``"display_output"`` is compared with its ``"reference"``
    (stemming on, no aggregation). The scores are merged with the batch;
    on a key clash the batch's own value wins. The loop ends once *tqe* is
    set and *tq* is empty.
    """
    import evaluate

    rouge = evaluate.load("rouge")

    from queue import Empty

    while True:
        if tqe.is_set() and tq.empty():
            break
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue
        merged = {
            **rouge.compute(
                predictions=batch["display_output"],
                references=batch["reference"],
                rouge_types=["rouge1", "rouge2", "rougeL"],
                use_stemmer=True,
                use_aggregator=False,
            )
        }
        # Batch fields are applied last so they override same-named scores.
        merged.update(batch)
        rq.put(merged)


def remove_text_worker(tq, tqe, rq):
    """Forward batches from *tq* to *rq* with the text fields removed.

    The keys ``"input"``, ``"output"``, ``"reference"`` and
    ``"display_output"`` are deleted in place when present. The loop ends
    once *tqe* is set and *tq* is empty.
    """
    from queue import Empty

    text_fields = ("input", "output", "reference", "display_output")
    while True:
        if tqe.is_set() and tq.empty():
            break
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue
        for name in text_fields:
            if name in batch:
                del batch[name]
        rq.put(batch)


import torch


@torch.no_grad()
def get_ppl(model, tbatch, decoder_only=False):
    """Return the perplexity of the tokenized outputs under *model*.

    Args:
        model: Callable model exposing ``.device``; its return value must
            have a ``.logits`` attribute.
        tbatch: Dict with ``"input"`` and ``"output"`` entries, each holding
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        decoder_only: Whether prompt and output are concatenated into one
            sequence (causal LM) or fed as encoder / decoder inputs.

    Returns:
        A list with one perplexity: ``exp`` of the attention-masked mean
        token cross-entropy.

    Note:
        The function currently asserts ``decoder_only`` and a batch size of
        one *after* the forward pass, so the seq2seq branch always ends in an
        ``AssertionError``.
    """
    source = tbatch["input"]
    target = tbatch["output"]

    if decoder_only:
        # Feed prompt + output (minus its last token) and read the logits
        # starting at the last prompt position, so they align with the output ids.
        com_input_ids = torch.cat(
            [source["input_ids"], target["input_ids"][..., :-1]], dim=-1
        ).to(model.device)
        com_attention_mask = torch.cat(
            [source["attention_mask"], target["attention_mask"][..., :-1]], dim=-1
        ).to(model.device)
        outputs = model(input_ids=com_input_ids, attention_mask=com_attention_mask)
        labels = target["input_ids"].to(model.device)
        label_attention_mask = target["attention_mask"].to(model.device)
        prompt_len = source["input_ids"].shape[1]
        logits = outputs.logits[:, prompt_len - 1 :]
    else:
        outputs = model(
            input_ids=source["input_ids"].to(model.device),
            attention_mask=source["attention_mask"].to(model.device),
            decoder_input_ids=target["input_ids"][..., :-1].to(model.device),
        )
        #  labels: [batch_size, sequence_length]
        labels = target["input_ids"][..., 1:].to(model.device)
        label_attention_mask = target["attention_mask"][..., 1:].to(model.device)
        #  output.logits: [batch_size, sequence_length, vocab_size]
        logits = outputs.logits

    assert decoder_only
    assert logits.shape[0] == 1

    criterion = torch.nn.CrossEntropyLoss(reduction="none")
    #  per-token loss: [batch_size, sequence_length]
    token_losses = criterion(
        logits.reshape(-1, logits.shape[-1]),
        labels.view(-1),
    ).reshape(labels.shape)
    #  masked mean per sequence: [batch_size]
    masked_total = (token_losses * label_attention_mask.float()).sum(dim=-1)
    mean_loss = masked_total / label_attention_mask.sum(dim=-1)
    return torch.exp(mean_loss).cpu().tolist()


def ppl_worker(
    tq,
    tqe,
    rq,
    gpu_id,
    oracle_model_str,
    decoder_only=False,
    tokenization_kwargs={},
    access_token=None,
):
    """Perplexity worker: score queued batches with an oracle model.

    The oracle model is loaded onto ``cuda:{gpu_id}``. Each batch taken from
    *tq* is tokenized on its ``"input"`` and ``"output"`` fields, scored with
    ``get_ppl`` and re-emitted on *rq* with an added ``"ppl"`` entry. The
    CUDA cache of the model's device is emptied after every batch. The loop
    ends once *tqe* is set and *tq* is empty.

    Args:
        tq: Task queue of batches.
        tqe: Event signalling that no more tasks will be queued.
        rq: Result queue.
        gpu_id: CUDA device index.
        oracle_model_str: Model / tokenizer identifier for ``from_pretrained``.
        decoder_only: Load a causal LM instead of a seq2seq model.
        tokenization_kwargs: Extra keyword arguments for ``tokenize_batch``.
        access_token: Passed as ``token=`` to ``from_pretrained``.

    Raises:
        AssertionError: If ``get_ppl`` does not return exactly one value
            (batches are expected to hold a single example).

    Note:
        SECURITY: earlier revisions embedded a Hugging Face access token as
        the default of ``access_token``. That credential has been in source
        control and should be revoked. With ``None`` the Hugging Face
        libraries are expected to fall back to the locally stored login /
        ``HF_TOKEN`` environment variable — confirm for the installed version.
    """
    from queue import Empty

    from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
    from transformers import AutoTokenizer

    model_cls = AutoModelForCausalLM if decoder_only else AutoModelForSeq2SeqLM
    model = model_cls.from_pretrained(oracle_model_str, token=access_token).to(
        f"cuda:{gpu_id}"
    )
    tokenizer = AutoTokenizer.from_pretrained(oracle_model_str, token=access_token)

    while not (tqe.is_set() and tq.empty()):
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue
        tbatch = tokenize_batch(
            batch,
            tokenizer,
            ["input", "output"],
            **tokenization_kwargs,
        )
        ppl = get_ppl(model, tbatch, decoder_only=decoder_only)
        # Release cached blocks on this worker's device between batches.
        with torch.cuda.device(model.device):
            torch.cuda.empty_cache()

        assert len(ppl) == 1

        rq.put(
            {
                **batch,
                "ppl": ppl,
            }
        )


def random_paraphrase(input_ids, vocab_size, device, eps):
    """Randomly replace tokens of *input_ids* with uniformly drawn token ids.

    The torch RNG is seeded from a SHA-256 hash of ``str(input_ids)``, so the
    same input yields the same result. A position keeps its token when its
    uniform draw is greater than *eps*; otherwise it receives a random id in
    ``[0, vocab_size)``.

    Note:
        This reseeds the global torch RNG as a side effect.
    """
    import hashlib

    digest = hashlib.sha256(str(input_ids).encode("utf-8")).digest()
    torch.manual_seed(int.from_bytes(digest, "big") % (2**32 - 1))

    # Draw order matters for reproducibility: keep-mask first, then replacements.
    keep_mask = torch.rand(input_ids.shape).to(device) > eps
    replacements = torch.randint_like(input_ids, low=0, high=vocab_size).to(device)

    return torch.where(keep_mask, input_ids, replacements)


@torch.no_grad()
def get_sta_score_id(vocab_size, output_ids, wp, device, test_config={}, la_wp=None, eps=0):
    """Compute per-position scores via ``wp.get_sta_score``.

    For every position ``p >= 1`` the processor is called with the prefix
    ``ids[:, :p]``, *vocab_size* and the tokens at ``p``; the stacked result
    is stored in column ``p``. Column 0 stays zero. With ``eps > 0`` the ids
    are first perturbed by ``random_paraphrase``. ``test_config`` and
    ``la_wp`` are not used.

    Returns:
        ``(scores, label_attention_mask)``; the mask is all ones with the
        shape of *output_ids*.
    """
    assert eps <= 1
    assert eps >= 0

    token_ids = output_ids.to(device)
    label_attention_mask = torch.ones_like(output_ids).to(device)

    if eps > 0:
        token_ids = random_paraphrase(token_ids, vocab_size,device, eps)

    scores = torch.zeros(token_ids.shape, device=device)
    for pos in range(1, token_ids.size(1)):
        per_sequence = wp.get_sta_score(
            token_ids[:, :pos], vocab_size, token_ids[:, pos]
        )
        scores[:, pos] = torch.stack(per_sequence).reshape(-1)
    return scores, label_attention_mask


# calc p val for EXP_edit and ITS_edit
def get_p_val_id(vocab_size, output_ids, wp, device, test_config={}, la_wp=None, eps=0,gamma=0):
    """Return ``wp.get_p_val`` for the (optionally perturbed) token ids.

    With ``eps > 0`` the ids are first perturbed by ``random_paraphrase``.
    ``test_config`` and ``la_wp`` are not used.

    Returns:
        ``(p_val, label_attention_mask)``; the mask is all ones with the
        shape of the ids.
    """
    assert eps <= 1
    assert eps >= 0

    token_ids = output_ids.to(device)
    label_attention_mask = torch.ones_like(token_ids).to(device)

    if eps > 0:
        token_ids = random_paraphrase(token_ids, vocab_size,device, eps)

    return wp.get_p_val(token_ids, vocab_size, gamma=gamma), label_attention_mask


@torch.no_grad()
def get_unigram_score_id(vocab_size, output_ids, wp, device, test_config={}, la_wp=None, eps=0):
    """Compute per-position scores via ``wp.get_unigram_score``.

    For every position ``p >= 1`` the processor is called with the prefix
    ``ids[:, :p]``, *vocab_size* and the tokens at ``p``; the stacked result
    is stored in column ``p``. Column 0 stays zero. With ``eps > 0`` the ids
    are first perturbed by ``random_paraphrase``. ``test_config`` and
    ``la_wp`` are not used.

    Returns:
        ``(scores, label_attention_mask)``; the mask is all ones.
    """
    assert eps <= 1
    assert eps >= 0

    ids = output_ids.to(device)
    label_attention_mask = torch.ones_like(ids).to(device)

    if eps > 0:
        ids = random_paraphrase(ids, vocab_size,device, eps)

    scores = torch.zeros(ids.shape, device=device)
    seq_len = ids.size(1)
    for pos in range(1, seq_len):
        prefix, token = ids[:, :pos], ids[:, pos]
        stacked = torch.stack(wp.get_unigram_score(prefix, vocab_size, token))
        scores[:, pos] = stacked.reshape(-1)

    return scores, label_attention_mask


@torch.no_grad()
def get_gumbelmax_score_id(vocab_size, output_ids, wp, device, test_config={}, la_wp=None, eps=0):
    """Compute per-position scores via ``wp.get_gumbelmax_score``.

    For every position ``p >= 1`` the processor is called with the prefix
    ``ids[:, :p]``, *vocab_size* and the tokens at ``p``; the stacked result
    is stored in column ``p``. Column 0 stays zero. With ``eps > 0`` the ids
    are first perturbed by ``random_paraphrase``. ``test_config`` and
    ``la_wp`` are not used.

    Returns:
        ``(scores, label_attention_mask)``; the mask is all ones.
    """
    assert eps <= 1
    assert eps >= 0

    sequence_ids = output_ids.to(device)
    label_attention_mask = torch.ones_like(sequence_ids).to(device)

    if eps > 0:
        sequence_ids = random_paraphrase(sequence_ids, vocab_size,device, eps)

    scores = torch.zeros(sequence_ids.shape, device=device)
    for step in range(sequence_ids.size(1) - 1):
        column = step + 1
        result = wp.get_gumbelmax_score(
            sequence_ids[:, :column], vocab_size, sequence_ids[:, column]
        )
        scores[:, column] = torch.stack(result).reshape(-1)

    return scores, label_attention_mask


# for Dipmark and gamma-reweight
@torch.no_grad()
def get_quantile_id(vocab_size, output_ids, wp, device, test_config={}, la_wp=None, eps=0):
    """Compute per-position quantiles via ``wp.get_green_token_quantile``.

    For every position ``p >= 1`` the processor is called with the prefix
    ``ids[:, :p]``, *vocab_size* and the tokens at ``p``; the stacked result
    is stored in column ``p``. Column 0 stays zero. With ``eps > 0`` the ids
    are first perturbed by ``random_paraphrase``. ``test_config`` and
    ``la_wp`` are not used.

    Returns:
        ``(quantile, label_attention_mask)``; the mask is all ones.
    """
    assert eps <= 1
    assert eps >= 0

    ids = output_ids.to(device)
    label_attention_mask = torch.ones_like(ids).to(device)

    if eps > 0:
        ids = random_paraphrase(ids, vocab_size,device, eps)

    quantile = torch.zeros(ids.shape, device=device)
    for pos in range(1, ids.size(1)):
        values = wp.get_green_token_quantile(ids[:, :pos], vocab_size, ids[:, pos])
        quantile[:, pos] = torch.stack(values).reshape(-1)

    return quantile, label_attention_mask

# for Splitmark
@torch.no_grad()
def get_split_res_id(vocab_size, output_ids, wp, device, test_config={}, la_wp=None, eps=0,split_num=None):
    """Compute per-position results via ``wp.get_n_res``.

    For every position ``p >= 1`` the processor is called with the prefix
    ``ids[:, :p]``, *vocab_size*, the tokens at ``p`` and
    ``cur_n=split_num``; the result is converted with ``torch.tensor`` and
    stored in column ``p``. Column 0 stays zero. With ``eps > 0`` the ids are
    first perturbed by ``random_paraphrase``. ``test_config`` and ``la_wp``
    are not used.

    Returns:
        ``(scores, label_attention_mask)``; the mask is all ones.
    """
    assert eps <= 1
    assert eps >= 0

    ids = output_ids.to(device)
    label_attention_mask = torch.ones_like(ids).to(device)

    if eps > 0:
        ids = random_paraphrase(ids, vocab_size,device, eps)

    scores = torch.zeros(ids.shape, device=device)
    for pos in range(1, ids.size(1)):
        raw = wp.get_n_res(ids[:, :pos], vocab_size, ids[:, pos], cur_n=split_num)
        scores[:, pos] = torch.tensor(raw).reshape(-1)
    return scores, label_attention_mask
    


from ..lm_watermarking.watermark_processor import (
    WatermarkLogitsProcessor as WatermarkLogitsProcessor_John,
)

@torch.no_grad()
def get_green_token_scores_id(
    vocab_size, output_ids, wp, device, eps=0
):
    """Mark which tokens fall in the green list computed from their prefix.

    Only a single sequence (batch size 1) is supported. For every position
    ``p >= 1`` the green list is obtained from ``wp._get_greenlist_ids`` on
    the prefix ``ids[0, :p]``; the score at ``p`` is 1 when the token is in
    that list and 0 otherwise. Position 0 stays zero. With ``eps > 0`` the
    ids are first perturbed by ``random_paraphrase``.

    Returns:
        ``(scores, label_attention_mask)``; the mask is all ones.
    """
    assert eps <= 1
    assert eps >= 0

    token_ids = output_ids.to(device)
    label_attention_mask = torch.ones_like(token_ids)

    if eps > 0:
        token_ids = random_paraphrase(token_ids, vocab_size, device,eps)

    scores = torch.zeros(token_ids.shape, device=device)

    assert token_ids.shape[0] == 1
    sequence = token_ids[0]
    for pos in range(1, token_ids.size(1)):
        token = sequence[pos]
        assert wp.select_green_tokens
        green_ids = wp._get_greenlist_ids(sequence[:pos])
        if token in green_ids:
            scores[0, pos] = 1
    return scores, label_attention_mask



@torch.no_grad()
def get_synthid_text_scores_id(
    vocab_size, output_ids, wp, device, eps=0
):
    """Compute per-position scores via ``wp.get_synthid_text_res``.

    Only a single sequence (batch size 1) is supported. For every position
    ``p >= 1`` the processor is called with the prefix ``ids[:, :p]``,
    *vocab_size* and the token at ``p``; its result is written to column
    ``p``. Column 0 stays zero. With ``eps > 0`` the ids are first perturbed
    by ``random_paraphrase``.

    Returns:
        ``(scores, label_attention_mask)``; the mask is all ones.
    """
    assert eps <= 1
    assert eps >= 0

    token_ids = output_ids.to(device)
    label_attention_mask = torch.ones_like(token_ids)

    if eps > 0:
        token_ids = random_paraphrase(token_ids, vocab_size, device,eps)

    scores = torch.zeros(token_ids.shape, device=device)

    assert token_ids.shape[0] == 1
    for pos in range(1, token_ids.size(1)):
        scores[:, pos] = wp.get_synthid_text_res(
            token_ids[:, :pos], vocab_size, token_ids[:, pos]
        )

    return scores, label_attention_mask


def beta_score_worker(
    tq, tqe, rq, gpu_id, oracle_model_str, decoder_only=False, eps=0, tokenization_kwargs={},
    access_token="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",
):
    """Score batches from ``tq`` for watermark evidence and push results to ``rq``.

    The worker loops until ``tqe`` is set and ``tq`` is empty. Every batch
    must carry a single distinct ``"watermark_processor"`` string; that string
    is evaluated to rebuild the processor and also selects (by substring,
    first match wins) which scoring routine is used.

    Args:
        tq: task queue; ``get(timeout=...)`` and ``empty()`` are used.
        tqe: event-like object; ``is_set()`` signals that no more tasks
            will arrive.
        rq: result queue; one dict is ``put`` per processed batch.
        gpu_id: index used to build the device string ``cuda:{gpu_id}``.
        oracle_model_str: model identifier for the tokenizer and for the
            model whose ``config.vocab_size`` is read.
        decoder_only: load the model with ``AutoModelForCausalLM`` when
            true, otherwise with ``AutoModelForSeq2SeqLM``.
        eps: when 0 the ids come from re-tokenising ``batch["output"]``;
            otherwise ``batch["output_ids"]`` is used. It is also forwarded
            to the scoring routines.
        tokenization_kwargs: extra keyword arguments for ``tokenize_batch``
            (only read, never mutated here).
        access_token: token passed to ``from_pretrained``.

    Results put on ``rq`` are ``{**batch, "lens": ...}`` plus one of
    ``"raw_scores"``, ``"p_val"`` or ``"beta_score"`` depending on the branch.

    Raises:
        NotImplementedError: when the processor string matches no branch.
    """
    # NOTE(review): the default `access_token` is a credential committed to
    # source. It is kept so existing callers keep working, but it should be
    # revoked and supplied by the caller / environment instead.
    import re
    from queue import Empty

    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        set_seed,
        AutoModelForCausalLM,
    )
    from transformers import LogitsProcessorList, TemperatureLogitsWarper

    device = f"cuda:{gpu_id}"

    # The model is only needed for its configured vocabulary size and is
    # dropped right after reading it.
    model_cls = AutoModelForCausalLM if decoder_only else AutoModelForSeq2SeqLM
    model = model_cls.from_pretrained(oracle_model_str, token=access_token).to(device)
    tokenizer = AutoTokenizer.from_pretrained(oracle_model_str, token=access_token)
    vocab_size = model.config.vocab_size
    del model

    def score_func(quantiles, lens, mode="scaled_sigmoid"):
        """Return ``(raw_score, expected_value)`` for the chosen ``mode``.

        params:
            quantiles: [batch_size, max_len]
        """
        if mode == "linear":
            return torch.sum(quantiles, dim=-1), lens / 2

        if mode == "test":
            return torch.sum(quantiles > 0.5, dim=-1), lens / 2

        if mode == "scaled_sigmoid":
            left = -10
            right = 10
            return (
                torch.sum(torch.sigmoid(quantiles * (right - left) + left), dim=-1),
                lens / 2,
            )

        if mode == "scaled_log":
            left = 1
            right = torch.e
            return (
                torch.sum(torch.log(quantiles * (right - left) + left), dim=-1),
                1 / (torch.e - 1) * lens,
            )

        # Previously the exception class was *returned*, which would only
        # surface later as an obscure unpacking error at the call site.
        raise NotImplementedError

    # The names below are not all referenced directly: `eval(wp_str)` in the
    # loop resolves them from this function's local scope.
    from new_unbiased_watermark import (
        Gamma_Reweight,
        Delta_Reweight,
        DeltaGumbel_Reweight,
        WatermarkLogitsProcessor,
        WatermarkLogitsProcessor_Kuditipudi_OriImplement,
        PrevN_ContextCodeExtractor,
        Beta_Reweight,
        NoKey,
        PositionHashing,
        NGramHashing,
        KeySequence,
        TokenSkipping,
        Dip_Reweight,
        Split_Reweight,
        Tri_Reweight,
        N_Reweight,
        STA_Reweight,
        GumbelMax_Reweight,
        Unigram_Reweight,
        ITS_edit_Reweight,
        EXP_edit_Reweight,
        SynthID_Text_Reweight
    )
    from ..lm_watermarking.watermark_processor import (
        WatermarkLogitsProcessor as WatermarkLogitsProcessor_John,
    )

    n_reweight_pattern = re.compile(r"N_Reweight\(n=(\d+)\)")

    def extract_n_value(text):
        """Return the integer ``n`` from ``N_Reweight(n=...)`` or ``None``."""
        found = n_reweight_pattern.search(text)
        if found:
            return int(found.group(1))
        return None

    def _fresh_key(processor, batch):
        """Reset the processor's watermark key state for this batch."""
        processor.reset_watermark_key(len(batch["watermark_processor"]))
        processor.ignore_history = True

    def _emit_raw_scores(batch, scores, label_attention_mask):
        """Mask the first two positions, sum the scores and publish them."""
        assert label_attention_mask.shape[0] == 1
        # The first two positions are excluded in every raw-score branch
        # (original note: "for fair comparison").
        label_attention_mask[0, :2] = 0

        masked = scores * label_attention_mask
        seq_len = torch.sum(label_attention_mask, dim=-1, keepdim=False)
        raw_scores = masked.sum(dim=-1)
        rq.put(
            {
                **batch,
                "lens": seq_len.cpu().tolist(),
                "raw_scores": raw_scores.cpu().tolist()
            }
        )

    while not (tqe.is_set() and tq.empty()):
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue
        assert len(set(batch["watermark_processor"])) == 1

        wp_str = batch["watermark_processor"][0]
        la_wp = None
        tbatch = tokenize_batch(
            batch, tokenizer, ["output"], **tokenization_kwargs
        )

        if eps == 0:
            output_ids = torch.tensor(tbatch["output"]['input_ids'])
        else:
            output_ids = torch.tensor(batch['output_ids'])

        # SECURITY NOTE: `wp_str` is executed with `eval` in every branch
        # below; batches must come from a trusted producer. `eval` has to
        # stay in this frame so it can see the local imports above.
        if 'N_Reweight' in wp_str:
            wp = eval(wp_str)
            _fresh_key(wp, batch)

            cur_n = extract_n_value(wp_str)
            assert cur_n is not None
            scores, label_attention_mask = get_split_res_id(
                vocab_size, output_ids, wp, device, la_wp=la_wp, eps=eps, split_num=cur_n
            )
            _emit_raw_scores(batch, scores, label_attention_mask)

        elif "John" in wp_str:
            wp = eval(wp_str)

            if wp.vocab_size == 0:
                wp.vocab_size = vocab_size

            scores, label_attention_mask = get_green_token_scores_id(
                vocab_size, output_ids, wp, device, eps=eps
            )
            _emit_raw_scores(batch, scores, label_attention_mask)

        elif 'SynthID_Text' in wp_str:
            wp = eval(wp_str)
            _fresh_key(wp, batch)

            scores, label_attention_mask = get_synthid_text_scores_id(
                vocab_size, output_ids, wp, device, eps=eps
            )
            _emit_raw_scores(batch, scores, label_attention_mask)

        elif ('ITS_edit' in wp_str) or ('EXP_edit' in wp_str):
            wp = eval(wp_str)

            # One of the two markers is guaranteed by the branch condition.
            gamma = 0.4 if 'ITS_edit' in wp_str else 0.0

            p_val, label_attention_mask = get_p_val_id(
                vocab_size, output_ids, wp, device, la_wp=la_wp, eps=eps, gamma=gamma
            )
            assert label_attention_mask.shape[0] == 1
            # No positions are masked out in this branch.
            seq_len = torch.sum(label_attention_mask, dim=-1, keepdim=False)

            rq.put(
                {
                    **batch,
                    "lens": seq_len.cpu().tolist(),
                    "p_val": p_val.view(1).cpu().tolist()
                }
            )

        elif 'Unigram' in wp_str:
            wp = eval(wp_str)
            _fresh_key(wp, batch)

            scores, label_attention_mask = get_unigram_score_id(
                vocab_size, output_ids, wp, device, la_wp=la_wp, eps=eps
            )
            _emit_raw_scores(batch, scores, label_attention_mask)

        elif "GumbelMax_Reweight" in wp_str:
            wp = eval(wp_str)
            _fresh_key(wp, batch)

            scores, label_attention_mask = get_gumbelmax_score_id(
                vocab_size, output_ids, wp, device, la_wp=la_wp, eps=eps
            )
            _emit_raw_scores(batch, scores, label_attention_mask)

        elif "STA_Reweight" in wp_str:
            wp = eval(wp_str)
            _fresh_key(wp, batch)

            scores, label_attention_mask = get_sta_score_id(
                vocab_size, output_ids, wp, device, la_wp=la_wp, eps=eps
            )
            _emit_raw_scores(batch, scores, label_attention_mask)

        elif ("Beta" in wp_str) or ("Dip_" in wp_str):
            wp = eval(wp_str)
            _fresh_key(wp, batch)

            quantiles, label_attention_mask = get_quantile_id(
                vocab_size, output_ids, wp, device, la_wp=la_wp, eps=eps
            )
            assert label_attention_mask.shape[0] == 1

            label_attention_mask[0, :2] = 0

            quantiles = quantiles * label_attention_mask
            cum_label_attention_mask = torch.cumsum(label_attention_mask, dim=-1)
            lens = cum_label_attention_mask[:, -1]

            raw_score, expected_value = score_func(quantiles, lens, mode="test")
            seq_len = torch.sum(label_attention_mask, dim=-1, keepdim=False)

            assert torch.all(seq_len == lens)
            # Centre on the expected value and scale by sqrt(length).
            final_score = (raw_score - expected_value) / torch.sqrt(seq_len)

            rq.put(
                {
                    **batch,
                    "lens": lens.cpu().tolist(),
                    "beta_score": final_score.cpu().tolist(),
                }
            )

        else:
            print(f"Unknown Watermark Processor: {wp_str}")
            raise NotImplementedError

