def get_wps(reweight_type):
    """Build every watermark processor used in the full generation sweep.

    The returned list contains, in order:

    1. ``None`` (no logits processor at all),
    2. a ``WatermarkLogitsProcessor_Baseline`` instance,
    3. one ``WatermarkLogitsProcessor_Kuditipudi_OriImplement`` for each of
       ``ITS_edit_Reweight`` and ``EXP_edit_Reweight``,
    4. one ``WatermarkLogitsProcessor`` per (watermark key, reweight) pair,
    5. four ``lm_watermarking`` processors with ``delta`` in
       ``0.5, 1.0, 1.5, 2.0``.

    Args:
        reweight_type: Unused. The same list is returned whatever is passed.

    Returns:
        list: The processors described above.

    Note:
        This reseeds Python's process-wide ``random`` generator with 42
        before drawing the private key.
    """
    import copy
    import random

    from new_unbiased_watermark import (
        Dip_Reweight,
        EXP_edit_Reweight,
        ITS_edit_Reweight,
        N_Reweight,
        NGramHashing,
        PrevN_ContextCodeExtractor,
        STA_Reweight,
        SynthID_Text_Reweight,
        Unigram_Reweight,
        WatermarkLogitsProcessor,
        WatermarkLogitsProcessor_Baseline,
        WatermarkLogitsProcessor_Kuditipudi_OriImplement,
    )

    from ..lm_watermarking.watermark_processor import (
        WatermarkLogitsProcessor as WatermarkLogitsProcessor_John,
    )

    key_set_size = 512
    # Fixed seed so the same 1024-bit private key is drawn on every call.
    random.seed(42)
    private_key = random.getrandbits(1024).to_bytes(128, "big")

    watermark_key_list = [
        NGramHashing(PrevN_ContextCodeExtractor(2), ignore_history=False)
    ]

    n_list = [10, 20, 50, 100]
    reweight_list = [
        Unigram_Reweight(delta=0.5),
        Unigram_Reweight(delta=1.0),
        Unigram_Reweight(delta=1.5),
        Unigram_Reweight(delta=2.0),
        STA_Reweight(gamma=0.5),
        Dip_Reweight(alpha=0.5),
        Dip_Reweight(alpha=0.4),
        Dip_Reweight(alpha=0.3),
    ]
    reweight_list += [N_Reweight(n) for n in n_list]
    reweight_list += [SynthID_Text_Reweight(m=30)]
    reweight_list_stanford = [ITS_edit_Reweight(), EXP_edit_Reweight()]

    # EXP_edit and ITS_edit use their own processor class.
    wm_wps = [
        WatermarkLogitsProcessor_Kuditipudi_OriImplement(
            key_set_size=key_set_size,
            reweight=reweight,
        )
        for reweight in reweight_list_stanford
    ]

    # Common WatermarkLogitsProcessor: every processor receives its own deep
    # copy of the reweight and of the watermark key object.
    wm_wps += [
        WatermarkLogitsProcessor(
            private_key,
            reweight=copy.deepcopy(reweight),
            watermark_key_list=[copy.deepcopy(wm_key)],
        )
        for wm_key in watermark_key_list
        for reweight in reweight_list
    ]

    john_wps = [
        WatermarkLogitsProcessor_John(
            vocab_size=0,  # placeholder
            gamma=0.5,
            delta=delta,
            seeding_scheme="simple_1",
        )
        for delta in [0.5, 1.0, 1.5, 2.0]
    ]

    baseline_wp = WatermarkLogitsProcessor_Baseline()  # no watermark baseline
    return [
        None,
        baseline_wp,
        *wm_wps,
        *john_wps,
    ]
        
def get_wps_sep(reweight_type):
    """Build the watermark processors for one named experiment group.

    Args:
        reweight_type: Selects the group. Recognised values are ``'ITS'``,
            ``'EXP'``, ``'baselines'``, ``'main_exp'``, ``'nmark'``,
            ``'test'``, ``'GumbelMax'`` and ``'SynthID'``. Any other value
            prints a message and yields an empty list.

    Returns:
        list: The processors for the selected group. For ``'baselines'`` and
        ``'main_exp'`` the four ``lm_watermarking`` processors (``delta`` in
        ``0.5, 1.0, 1.5, 2.0``) are appended.

    Note:
        This reseeds Python's process-wide ``random`` generator with 42
        before drawing the private key.
    """
    import copy
    import random

    from new_unbiased_watermark import (
        Dip_Reweight,
        EXP_edit_Reweight,
        GumbelMax_Reweight,
        ITS_edit_Reweight,
        N_Reweight,
        NGramHashing,
        PrevN_ContextCodeExtractor,
        STA_Reweight,
        SynthID_Text_Reweight,
        Unigram_Reweight,
        WatermarkLogitsProcessor,
        WatermarkLogitsProcessor_Kuditipudi_OriImplement,
    )

    from ..lm_watermarking.watermark_processor import (
        WatermarkLogitsProcessor as WatermarkLogitsProcessor_John,
    )

    key_set_size = 512
    # Fixed seed so the same 1024-bit private key is drawn on every call.
    random.seed(42)
    private_key = random.getrandbits(1024).to_bytes(128, "big")

    watermark_key_list = [
        NGramHashing(PrevN_ContextCodeExtractor(2), ignore_history=False)
    ]

    # Alternative per-model n_list values, kept for reference:
    # n_list={'meta-llama/Llama-2-7b-chat-hf':[2, 3, 4, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 4000, 8000, 16000, 32000],
    #         'meta-llama/Llama-3.2-3B-Instruct':[2, 3, 4, 5, 10, 20, 50, 100, 167, 501, 1002, 2004, 4008, 8016, 16032, 32064, 64128, 128256],
    #         'mistralai/Mistral-7B-Instruct-v0.3':[2, 3, 4, 5, 8, 10, 16, 20, 32, 50, 64, 100, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768],
    #         'microsoft/Phi-3.5-mini-instruct':[2, 3, 4, 5, 10, 20, 50, 100, 167, 501, 1002, 2004, 4008, 8016, 16032, 32064],
    #         'deepseek-ai/deepseek-llm-7b-chat':[2, 3, 4, 5, 10, 20, 50, 100, 200, 400, 800, 1600, 3200, 6400, 12800, 25600, 51200, 102400],
    #         'Qwen/Qwen2.5-3B-Instruct':[2, 3, 4, 5, 10, 20, 50, 100, 250, 500, 1187, 2374, 4748, 9496, 18992, 37984, 75968, 151936]}
    n_list = [10, 20, 50, 100]

    def _baseline_reweights():
        # Shared by the 'baselines' and 'main_exp' groups.
        return [
            Unigram_Reweight(delta=0.5, gamma=0.5),
            Unigram_Reweight(delta=1.0, gamma=0.5),
            Unigram_Reweight(delta=1.5, gamma=0.5),
            Unigram_Reweight(delta=2.0, gamma=0.5),
            Dip_Reweight(alpha=0.5),
            Dip_Reweight(alpha=0.4),
            Dip_Reweight(alpha=0.3),
            STA_Reweight(gamma=0.5),
        ]

    reweight_list_stanford = []
    reweight_list = []
    if reweight_type == 'ITS':
        reweight_list_stanford = [ITS_edit_Reweight()]
    elif reweight_type == 'EXP':
        reweight_list_stanford = [EXP_edit_Reweight()]
    elif reweight_type == 'baselines':
        reweight_list = _baseline_reweights()
    elif reweight_type == 'main_exp':
        reweight_list = _baseline_reweights()
        reweight_list += [N_Reweight(n) for n in n_list]
    elif reweight_type == 'nmark':
        reweight_list = [N_Reweight(n) for n in n_list]
    elif reweight_type == 'test':
        reweight_list = [N_Reweight(20)]
    elif reweight_type == 'GumbelMax':
        reweight_list = [
            GumbelMax_Reweight(is_baseline=True),
            GumbelMax_Reweight(is_baseline=False),
        ]
    elif reweight_type == 'SynthID':
        reweight_list = [SynthID_Text_Reweight(m=30)]
    else:
        # Best effort: report and fall through with empty lists.
        print('Unknown reweight_type: ',reweight_type)

    # EXP_edit and ITS_edit use their own processor class.
    wm_wps = [
        WatermarkLogitsProcessor_Kuditipudi_OriImplement(
            key_set_size=key_set_size,
            reweight=reweight,
        )
        for reweight in reweight_list_stanford
    ]

    # Common WatermarkLogitsProcessor: every processor receives its own deep
    # copy of the reweight and of the watermark key object.
    wm_wps += [
        WatermarkLogitsProcessor(
            private_key,
            reweight=copy.deepcopy(reweight),
            watermark_key_list=[copy.deepcopy(wm_key)],
        )
        for wm_key in watermark_key_list
        for reweight in reweight_list
    ]

    john_wps = [
        WatermarkLogitsProcessor_John(
            vocab_size=0,  # placeholder
            gamma=0.5,
            delta=delta,
            seeding_scheme="simple_1",
        )
        for delta in [0.5, 1.0, 1.5, 2.0]
    ]

    if reweight_type in ('baselines', 'main_exp'):
        return [*wm_wps, *john_wps]
    return [*wm_wps]


def get_num_gpus():
    """Return the value of ``torch.cuda.device_count()``."""
    from torch import cuda

    return cuda.device_count()


def batched_wp_task_worker(tq, get_in_ds, reweight_type, batch_size=8):
    """Enqueue one generation task per (batch, watermark processor) pair.

    Args:
        tq: Task queue; receives dicts with keys ``"batch"`` and
            ``"watermark_processor"``.
        get_in_ds: Zero-argument callable returning the input dataset. The
            dataset must support ``len()`` and ``.iter(batch_size=...)``.
        reweight_type: Forwarded to ``get_wps``.
        batch_size: Number of examples per batch.
    """
    import math

    from tqdm import tqdm

    from .common import get_wps

    ds = get_in_ds()
    wps = get_wps(reweight_type)

    # Round up so a trailing partial batch is counted in the progress total.
    n_batches = math.ceil(len(ds) / batch_size)
    for batch in tqdm(ds.iter(batch_size=batch_size), total=n_batches):
        for wp in wps:
            tq.put({"batch": batch, "watermark_processor": wp})


def merged_task_worker(
    get_in_ds,
    output_filepath,
    tq,
    batch_size=8,
    watermark_only=False,
    wh_only=False,
    no_gumbel=False,
    beta_only=False
):
    """Re-join generated outputs with their inputs and enqueue them in batches.

    The JSON-lines file at ``output_filepath`` is loaded, sorted by ``"id"``,
    split per watermark-processor string and extended with the ``"input"``
    and ``"reference"`` columns via ``add_reference``. Each batch of each
    retained split is put on ``tq``.

    Args:
        get_in_ds: Zero-argument callable returning the input dataset.
        output_filepath: Path of the JSON-lines file with generated outputs.
        tq: Task queue receiving the batches.
        batch_size: Number of rows per batch.
        watermark_only: Skip splits whose processor string contains
            ``"John"`` or equals ``"None"``.
        wh_only: Skip splits whose processor string contains ``", True)"``.
        no_gumbel: Skip splits whose processor string contains ``"Gumbel"``.
        beta_only: Keep only splits whose processor string contains
            ``"Beta_Reweight"``.
    """
    import math

    from datasets import load_dataset
    from tqdm import tqdm

    in_ds = get_in_ds()

    out_ds = load_dataset("json", data_files={"test": output_filepath})["test"]
    out_ds = out_ds.sort("id")

    dss, wps = add_reference(in_ds, out_ds)

    for ds, wp_str in zip(dss, wps):
        if watermark_only and ("John" in wp_str or "None" == wp_str):
            continue
        if wh_only and ", True)" in wp_str:
            continue
        if no_gumbel and "Gumbel" in wp_str:
            continue
        if beta_only and "Beta_Reweight" not in wp_str:
            continue

        # Round up so a trailing partial batch is counted in the progress total.
        n_batches = math.ceil(len(ds) / batch_size)
        for batch in tqdm(ds.iter(batch_size=batch_size), total=n_batches):
            tq.put(batch)


def log(line: dict, f):
    """Append ``line`` to ``f`` as one JSON document followed by a newline.

    The stream is flushed after the write so the record is handed to the
    underlying file object immediately.
    """
    from json import dumps

    f.write(f"{dumps(line)}\n")
    f.flush()


def simple_store_worker(path, rq, rqe):
    """Drain result dicts from ``rq`` into a JSON-lines file at ``path``.

    The loop runs until ``rqe`` is set and ``rq`` is empty. Empty dicts are
    ignored. A dict whose first value is a list is treated as a batch: all
    values must be lists of equal length, and one line is written per index.
    Any other dict is written as a single line.

    Args:
        path: Output file path; it is opened in ``"w"`` mode. Missing parent
            directories are created. A bare filename (no directory part) is
            written to the current working directory.
        rq: Result queue providing ``get(timeout=...)`` and ``empty()``.
        rqe: Event-like object; ``is_set()`` signals that producers are done.
    """
    import os
    from queue import Empty

    parent = os.path.dirname(path)
    # dirname is "" for a bare filename and os.makedirs("") raises, so only
    # create directories when there is a directory component. exist_ok avoids
    # a race with another process creating the same directory.
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, "w") as f:
        while not (rqe.is_set() and rq.empty()):
            try:
                result = rq.get(timeout=1)
            except Empty:
                # Nothing yet; re-check the termination condition.
                continue
            assert isinstance(result, dict)
            if result == {}:
                continue
            if isinstance(next(iter(result.values())), list):
                # Batched result: unzip the column lists into one row per line.
                assert all(isinstance(v, list) for v in result.values())
                lens = [len(v) for v in result.values()]
                assert all(n == lens[0] for n in lens)
                for i in range(lens[0]):
                    log({k: v[i] for k, v in result.items()}, f)
            else:
                log(result, f)


def group_batch(batch):
    """Return a new dict mapping each key of ``batch`` to ``[value]``."""
    grouped = {}
    for key, value in batch.items():
        grouped[key] = [value]
    return grouped

from typing import Union
def tokenize_batch(example, tokenizer, fields=["input"], max_length: Union[int,dict] = 512):
    """Tokenize the requested fields of ``example`` with ``tokenizer``.

    Args:
        example: Mapping from field name to the text(s) for that field.
        tokenizer: Callable tokenizer exposing ``name_or_path``. When that
            equals ``"facebook/mbart-large-en-ro"`` its ``tgt_lang`` is set
            to ``"ro_RO"``.
        fields: Field names to tokenize; names missing from ``example`` are
            skipped.
        max_length: Either one limit for every field or a dict giving the
            limit per field name.

    Returns:
        dict: Field name -> tokenizer output. ``"output"`` and
        ``"reference"`` are passed as ``text_target``, every other field as
        ``text``; ``"output"`` is tokenized with
        ``add_special_tokens=False``.
    """
    tokenized = {}

    if tokenizer.name_or_path == "facebook/mbart-large-en-ro":
        tokenizer.tgt_lang = "ro_RO"

    limit_per_field = isinstance(max_length, dict)
    for name in fields:
        if name not in example:
            continue
        limit = max_length[name] if limit_per_field else max_length
        text_key = "text_target" if name in ("output", "reference") else "text"
        call_args = {"max_length": limit, text_key: example[name]}
        if name == "output":
            call_args["add_special_tokens"] = False
        tokenized[name] = tokenizer(
            truncation=True,
            padding=True,
            return_tensors="pt",
            **call_args,
        )

    return tokenized


def set_spawn():
    """Select the ``"spawn"`` start method for ``torch.multiprocessing``.

    A ``RuntimeError`` raised by ``set_start_method`` is swallowed.
    """
    from torch import multiprocessing

    try:
        multiprocessing.set_start_method("spawn")
    except RuntimeError:
        return


def remove_tailing_pad_s(s: str):
    """Return ``s`` cut just before its first ``"<pad>"``.

    If ``"<pad>"`` does not occur, ``s`` is returned unchanged.
    """
    kept, _, _ = s.partition("<pad>")
    return kept


def remove_tailing_pad(strs: list[str]):
    """Apply ``remove_tailing_pad_s`` to every string and return a new list."""
    return list(map(remove_tailing_pad_s, strs))


def transformer_worker(tq, tqe, rq, gpu_id, model_str, generation_kwargs={}):
    """Generation worker: sample outputs for queued batches on one GPU.

    Loads ``model_str`` as a seq2seq model on ``cuda:{gpu_id}``, patches it
    with ``patch_model`` and then consumes tasks from ``tq`` until ``tqe`` is
    set and ``tq`` is empty. Each task is a dict with keys ``"batch"`` and
    ``"watermark_processor"``. One result dict per task is put on ``rq``.

    Args:
        tq: Task queue.
        tqe: Event-like object signalling that no more tasks will arrive.
        rq: Result queue.
        gpu_id: Index of the CUDA device to use.
        model_str: Model identifier passed to ``from_pretrained``.
        generation_kwargs: Extra keyword arguments forwarded to
            ``model.generate``.
    """
    import hashlib
    from queue import Empty

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, set_seed
    from transformers import LogitsProcessorList, TemperatureLogitsWarper

    from new_unbiased_watermark import patch_model

    model = AutoModelForSeq2SeqLM.from_pretrained(model_str).to(f"cuda:{gpu_id}")
    patch_model(model)
    tokenizer = AutoTokenizer.from_pretrained(model_str)

    model.eval()
    with torch.no_grad():
        while True:
            if tqe.is_set() and tq.empty():
                break
            try:
                task = tq.get(timeout=1)
            except Empty:
                continue

            batch = task["batch"]
            tbatch = tokenize_batch(batch, tokenizer)
            wp = task["watermark_processor"]

            processors = []
            if wp is not None:
                # Processors that expose this hook are reset for the batch size.
                if 'reset_watermark_key' in dir(wp):
                    wp.reset_watermark_key(len(batch['id']))
                # Processors built with a placeholder vocab size get the real one.
                if "vocab_size" in dir(wp):
                    wp.vocab_size = model.config.vocab_size
                processors.append(wp)

            # For reproducibility and sufficient randomness: the seed is
            # derived from the batch ids and the processor's repr.
            seed_material = (str(batch["id"]) + repr(wp)).encode("utf-8")
            digest = hashlib.sha256(seed_material).digest()
            set_seed(int.from_bytes(digest, "big") % (2**32 - 1))

            encoded_input = tbatch["input"]
            outputs_ids = model.generate(
                encoded_input["input_ids"].to(device=model.device),
                attention_mask=encoded_input["attention_mask"].to(device=model.device),
                do_sample=True,
                num_beams=1,
                top_k=0,   # default
                top_p=1,
                length_penalty=1,
                early_stopping=False,
                logits_warper=LogitsProcessorList(processors),
                **generation_kwargs,
            )

            # Raw text keeps special tokens (cut at the first "<pad>");
            # display text drops them.
            outputs = remove_tailing_pad(
                tokenizer.batch_decode(outputs_ids, skip_special_tokens=False)
            )
            display_outputs = tokenizer.batch_decode(
                outputs_ids, skip_special_tokens=True
            )
            wp_str = repr(wp)
            rq.put(
                {
                    "output": outputs,
                    "display_output": display_outputs,
                    "id": batch["id"],
                    "reference_id": batch['reference_id'],
                    "watermark_processor": [wp_str] * len(outputs),
                }
            )


def add_reference(in_ds, out_ds):
    """Split ``out_ds`` per watermark processor and attach inputs/references.

    Assumes ``in_ds`` and every per-processor subset of ``out_ds`` are
    ordered by id, because the ``"input"`` and ``"reference"`` columns of
    ``in_ds`` are appended positionally.

    Returns:
        tuple: ``(datasets, wp_types)`` where ``wp_types`` is the set of
        distinct ``"watermark_processor"`` values and ``datasets`` holds one
        augmented subset per element, in the set's iteration order.
    """
    wp_types = set(out_ds["watermark_processor"])

    def _subset_with_reference(wp_type):
        subset = out_ds.filter(lambda row: row["watermark_processor"] == wp_type)
        # Every processor must have produced exactly one row per input.
        assert len(subset) == len(in_ds), f"{len(subset)}, {len(in_ds)}"
        subset = subset.add_column("input", in_ds["input"])
        return subset.add_column("reference", in_ds["reference"])

    s_out_dss = [_subset_with_reference(wp_type) for wp_type in wp_types]
    from datasets import concatenate_datasets

    return s_out_dss, wp_types

def bertscore_worker(tq, tqe, rq, gpu_id=0):
    """Worker loop that adds BERTScore precision/recall/F1 to each batch.

    Batches are taken from ``tq`` until ``tqe`` is set and ``tq`` is empty.
    ``batch["display_output"]`` is scored against ``batch["reference"]`` and
    the batch, extended with the three ``bertscore.*`` lists, is put on
    ``rq``. The scorer is created with ``lang="de"`` on ``cuda:{gpu_id}``.
    """
    from queue import Empty

    import bert_score

    scorer = bert_score.BERTScorer(
        lang="de",
        rescale_with_baseline=True,
        device=f"cuda:{gpu_id}",
        use_fast_tokenizer=True,
    )

    while True:
        if tqe.is_set() and tq.empty():
            break
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue
        precision, recall, f1 = scorer.score(
            batch["display_output"], batch["reference"]
        )
        scored = {**batch}
        scored["bertscore.precision"] = precision.tolist()
        scored["bertscore.recall"] = recall.tolist()
        scored["bertscore.f1"] = f1.tolist()
        rq.put(scored)


def rouge_worker(tq, tqe, rq):
    """Worker loop that adds per-example ROUGE-1/2/L scores to each batch.

    Batches are taken from ``tq`` until ``tqe`` is set and ``tq`` is empty.
    The scores are merged with the batch (batch keys win on collision) and
    the merged dict is put on ``rq``.
    """
    from queue import Empty

    import evaluate

    rouge = evaluate.load("rouge")

    while True:
        if tqe.is_set() and tq.empty():
            break
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue
        rouge_scores = rouge.compute(
            predictions=batch["display_output"],
            references=batch["reference"],
            rouge_types=["rouge1", "rouge2", "rougeL"],
            use_stemmer=True,
            use_aggregator=False,  # keep one score per example
        )
        merged = {**rouge_scores}
        merged.update({**batch})
        rq.put(merged)


def remove_text_worker(tq, tqe, rq):
    """Worker loop that strips the text fields from each batch.

    Batches are taken from ``tq`` until ``tqe`` is set and ``tq`` is empty.
    The keys ``"input"``, ``"output"``, ``"reference"`` and
    ``"display_output"`` are removed in place when present, and the batch is
    put on ``rq``.
    """
    from queue import Empty

    text_fields = ("input", "output", "reference", "display_output")

    while True:
        if tqe.is_set() and tq.empty():
            break
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue
        for field in text_fields:
            batch.pop(field, None)
        rq.put(batch)


import torch


@torch.no_grad()
def get_ppl(model, tbatch):
    """Compute the per-sequence perplexity of the outputs under ``model``.

    The model is run with teacher forcing: all output tokens but the last are
    fed as ``decoder_input_ids`` and all but the first are the labels.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``decoder_input_ids`` and returning an object with ``.logits``;
            must expose ``.device``.
        tbatch: Dict with ``"input"`` and ``"output"`` entries, each holding
            ``"input_ids"`` and ``"attention_mask"`` tensors.

    Returns:
        list[float]: ``exp`` of the mean token cross-entropy over the
        unmasked label positions, one value per sequence.
    """
    from torch.nn import CrossEntropyLoss

    input_ids = tbatch["input"]["input_ids"].to(model.device)
    attention_mask = tbatch["input"]["attention_mask"].to(model.device)
    decoder_input_ids = tbatch["output"]["input_ids"][..., :-1].to(model.device)
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
    )

    loss_fct = CrossEntropyLoss(reduction="none")

    #  output.logits: [batch_size, sequence_length, vocab_size]
    #  labels: [batch_size, sequence_length]
    labels = tbatch["output"]["input_ids"][..., 1:].to(model.device)
    shape = labels.shape
    #  loss: [batch_size, sequence_length]
    # ``labels`` is a slice and may be non-contiguous when ``.to`` is a no-op
    # (tensor already on the model's device); ``reshape`` copies if needed
    # where ``view`` would raise.
    losses = loss_fct(
        outputs.logits.reshape(-1, outputs.logits.shape[-1]),
        labels.reshape(-1),
    ).reshape(shape)
    label_attention_mask = tbatch["output"]["attention_mask"][..., 1:].to(model.device)
    #  loss: [batch_size] -- mean over unmasked positions only
    losses = (losses * label_attention_mask.float()).sum(
        dim=-1
    ) / label_attention_mask.sum(dim=-1)
    ppl = torch.exp(losses).cpu().tolist()
    return ppl


def ppl_worker(tq, tqe, rq, gpu_id, oracle_model_str):
    """Worker loop that adds a ``"ppl"`` list to each batch.

    Loads ``oracle_model_str`` as a seq2seq model on ``cuda:{gpu_id}``, then
    takes batches from ``tq`` until ``tqe`` is set and ``tq`` is empty. The
    ``"input"`` and ``"output"`` fields are tokenized, scored with
    ``get_ppl``, and the batch plus ``"ppl"`` is put on ``rq``.
    """
    from queue import Empty

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, set_seed
    from transformers import LogitsProcessorList, TemperatureLogitsWarper

    model = AutoModelForSeq2SeqLM.from_pretrained(oracle_model_str).to(f"cuda:{gpu_id}")
    tokenizer = AutoTokenizer.from_pretrained(oracle_model_str)

    while True:
        if tqe.is_set() and tq.empty():
            break
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue

        tbatch = tokenize_batch(batch, tokenizer, ["input", "output"])
        ppl = get_ppl(model, tbatch)
        # Release cached GPU memory on the model's device after each batch.
        with torch.cuda.device(model.device):
            torch.cuda.empty_cache()

        scored = {**batch}
        scored["ppl"] = ppl
        rq.put(scored)
        
@torch.no_grad()
def get_quantile(vocab_size, tbatch, wp, device, test_config={}, la_wp=None):
    """Query ``wp`` for the green-token quantile of each output token.

    Args:
        vocab_size: Forwarded to ``wp.get_green_token_quantile``.
        tbatch: Dict whose ``"output"`` entry holds ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        wp: Object exposing ``get_green_token_quantile(prefix, vocab_size,
            cur_token)``; its result is passed to ``torch.stack`` and
            flattened to one value per row.
        device: Device the tensors are moved to.
        test_config: Unused.
        la_wp: Unused.

    Returns:
        tuple: ``(quantile, label_attention_mask)``. Both have the shape of
        the output ids with one column removed; the mask is the output
        attention mask without its first column.
    """
    output_ids = tbatch["output"]["input_ids"]
    #  label_attention_mask : [batch_size, output_sequence_length-1]
    label_attention_mask = tbatch["output"]["attention_mask"][..., 1:].to(device)
    #  decoder_input_ids : [batch_size, output_sequence_length-1]
    decoder_input_ids = output_ids[..., :-1].to(device)

    quantile = torch.zeros(decoder_input_ids.shape, device=device)
    # NOTE(review): the loop stops one step early, so the last column of
    # ``quantile`` stays 0 and the final output token is never scored.
    # Confirm whether that is intended before changing it.
    for step in range(decoder_input_ids.size(1) - 1):
        prefix = decoder_input_ids[:, : step + 1]
        cur_token = decoder_input_ids[:, step + 1]
        per_row = wp.get_green_token_quantile(prefix, vocab_size, cur_token)
        quantile[:, step] = torch.stack(per_row).reshape(-1)
    return quantile, label_attention_mask


def beta_score_worker(tq, tqe, rq, gpu_id, model_str):
    """Worker loop that computes a normalised green-token score per sequence.

    Batches are taken from ``tq`` until ``tqe`` is set and ``tq`` is empty.
    For each batch the watermark processor is rebuilt from its repr string,
    per-token quantiles are obtained with ``get_quantile``, and the score
    ``(raw - expected) / sqrt(length)`` is put on ``rq`` as ``"beta_score"``
    together with ``"lens"``.

    Args:
        tq: Task queue of batches; all rows of a batch must share one
            ``"watermark_processor"`` string.
        tqe: Event-like object signalling that no more batches will arrive.
        rq: Result queue.
        gpu_id: Index of the CUDA device used for the tensors.
        model_str: Tokenizer identifier passed to ``from_pretrained``.
    """
    from queue import Empty

    from transformers import AutoTokenizer

    # These names must be local to this function: ``eval(wp_str)`` below
    # resolves the constructors named in the processor's repr through them.
    from new_unbiased_watermark import (
        Gamma_Reweight,
        Delta_Reweight,
        DeltaGumbel_Reweight,
        WatermarkLogitsProcessor,
        PrevN_ContextCodeExtractor,
        Beta_Reweight,
        NoKey,
        PositionHashing,
        NGramHashing,
        KeySequence,
        TokenSkipping
    )

    device = f"cuda:{gpu_id}"
    tokenizer = AutoTokenizer.from_pretrained(model_str)
    # NOTE(review): this is the largest token id, not the number of ids --
    # confirm which one get_green_token_quantile expects.
    vocab_size = max(tokenizer.vocab.values())

    def score_func(quantiles,lens,mode='test'):
        '''
        Return ``(raw_score, expected_value)`` for each sequence.

        params:
            quantiles: [batch_size, max_len]
            lens: [batch_size]
            mode: 'linear' sums the quantiles; 'test' counts quantiles > 0.5.
        '''
        if mode=='linear':
            return torch.sum(quantiles,dim=-1),lens/2

        if mode=='test':
            return torch.sum(quantiles>0.5,dim=-1),lens/2

        # Previously the exception class was *returned*, which made the
        # caller fail later while unpacking the result.
        raise NotImplementedError(f"unknown score mode: {mode!r}")

    while not (tqe.is_set() and tq.empty()):
        try:
            batch = tq.get(timeout=1)
        except Empty:
            continue
        assert len(set(batch["watermark_processor"])) == 1

        wp_str = batch["watermark_processor"][0]

        # SECURITY: eval executes arbitrary code. wp_str must only ever come
        # from result files written by this pipeline, never from untrusted
        # input.
        wp = eval(wp_str)
        wp.reset_watermark_key(len(batch["watermark_processor"]))
        wp.ignore_history = True

        la_wp = None
        tbatch = tokenize_batch(
            batch,
            tokenizer,
            ["input", "output"],
        )
        # quantiles: [batch_size, sequence_length]
        # label_attention_mask: [batch_size, sequence_length]
        quantiles, label_attention_mask = get_quantile(
            vocab_size, tbatch, wp, device, la_wp=la_wp
        )

        print('quantiles shape:',quantiles.shape)

        # Zero out padded positions before scoring.
        quantiles = quantiles* label_attention_mask

        cum_label_attention_mask = torch.cumsum(label_attention_mask, dim=-1)
        lens = cum_label_attention_mask[:, -1]
        # argmax returns the first index of the maximum, so this holds only
        # when the mask is a run of ones followed by zeros (11110000).
        _lens_m_1 = torch.argmax(
            cum_label_attention_mask,
            dim=-1,
        )
        assert torch.all(lens == _lens_m_1 + 1)

        raw_score,expected_value=score_func(quantiles,lens)
        seq_len = torch.sum(label_attention_mask,dim=-1,keepdim=False)
        final_score=(raw_score-expected_value)/torch.sqrt(seq_len)

        lens = lens.cpu().tolist()

        rq.put(
            {
                **batch,
                "beta_score":final_score.cpu().tolist(),
                "lens": lens,
            }
        )
    


@torch.no_grad()
def get_score(model, tbatch, wp, scorer, test_config={}, la_wp=None):
    """Compute per-token watermark scores and entropies (currently disabled).

    The function raises ``NotImplementedError`` on entry, so everything
    below the ``raise`` is unreachable and kept only for reference.

    As written, the disabled body would:
      * run ``model`` with teacher forcing on the tokenized batch,
      * apply the model's own logits processor and warper at each position,
      * pass the result through ``wp`` and score each label with
        ``wp.get_score(..., scorer)``,
      * optionally compute ``la_wp.get_la_score`` per position,
      * return ``(scores, entropy, label_attention_mask, la_scores)`` with
        padded positions zeroed.
    """
    raise NotImplementedError
    input_ids = tbatch["input"]["input_ids"].to(model.device)
    attention_mask = tbatch["input"]["attention_mask"].to(model.device)
    #  labels : [batch_size, output_sequence_length-1]
    labels = tbatch["output"]["input_ids"][..., 1:].to(model.device)
    label_attention_mask = tbatch["output"]["attention_mask"][..., 1:].to(model.device)
    #  decoder_input_ids : [batch_size, output_sequence_length-1]
    decoder_input_ids = tbatch["output"]["input_ids"][..., :-1].to(model.device)
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
    )

    from transformers import GenerationConfig

    # Start from the model's default generation config and apply overrides.
    generation_config = GenerationConfig.from_model_config(model.config)
    if "temperature" in test_config:
        generation_config.temperature = test_config["temperature"]
    if "top_k" in test_config:
        generation_config.top_k = test_config["top_k"]
    logits_processor = model._get_logits_processor(
        generation_config,
        input_ids_seq_length=labels.shape[-1],
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=None,
        logits_processor=[],
    )
    logits_warper = model._get_logits_warper(generation_config)

    #  logits: [batch_size, output_sequence_length-1, vocab_size]
    logits = outputs.logits
    # Drop references that are no longer needed, then release cached GPU memory.
    del outputs
    del input_ids
    del attention_mask
    with torch.cuda.device(model.device):
        torch.cuda.empty_cache()
    old_logits = torch.clone(logits)
    scores = torch.zeros(
        logits.shape[:-1] + (scorer.query_size(),), device=logits.device
    )
    if la_wp is not None:
        la_scores = torch.zeros(logits.shape[:-1], device=logits.device)
    else:
        la_scores = None
    for i in range(logits.size(1)):
        # Prefix of decoder tokens up to and including position i.
        pre = decoder_input_ids[:, : i + 1]
        t = logits[:, i]
        t = logits_processor(pre, t)
        t = logits_warper(pre, t)
        # old_logits keeps the processed/warped logits before watermarking.
        old_logits[:, i] = t
        new_logits = wp(pre, t)
        scores[:, i] = wp.get_score(labels[:, i], old_logits[:, i], new_logits, scorer)
        if la_wp is not None:
            la_scores[:, i] = la_wp.get_la_score(pre, labels[:, i], logits.shape[-1])
    del logits

    # compute entropy
    import torch.nn.functional as F

    _logits = F.log_softmax(old_logits, dim=-1)
    _logits.nan_to_num_()
    entropy = -(torch.exp(_logits) * _logits).sum(dim=-1)

    # Zero out padded positions in every returned tensor.
    scores = scores * label_attention_mask.unsqueeze(-1)
    if la_wp is not None:
        la_scores = la_scores * label_attention_mask
    entropy = entropy * label_attention_mask
    return scores, entropy, label_attention_mask, la_scores


def score_worker(tq, tqe, rq, gpu_id, model_str):
    """Worker loop computing robust-LLR watermark scores (currently disabled).

    The function raises ``NotImplementedError`` on entry, so everything
    below the ``raise`` is unreachable and kept only for reference. Note that
    the disabled body imports from ``unbiased_watermark``, not
    ``new_unbiased_watermark``.

    As written, the disabled body would, for each batch from ``tq``: rebuild
    the watermark processor from its repr string with ``eval``, call
    ``get_score``, pick the best ``dist_q`` per sequence by summed score, and
    put the batch plus the score fields on ``rq``.
    """
    raise NotImplementedError
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, set_seed
    from transformers import LogitsProcessorList, TemperatureLogitsWarper

    model = AutoModelForSeq2SeqLM.from_pretrained(model_str).to(f"cuda:{gpu_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_str)

    from unbiased_watermark import RobustLLR_Score_Batch_v1, RobustLLR_Score_Batch_v2

    # Grid of dist_q values 0.0, 0.1, ..., 1.0 queried by the scorer.
    grid_size = 10
    dist_qs = [i / grid_size for i in range(0, grid_size + 1)]
    scorer = RobustLLR_Score_Batch_v2.from_grid([0.0], dist_qs)

    from queue import Empty

    while not (tqe.is_set() and tq.empty()):
        try:
            batch = tq.get(timeout=1)
        except Empty as e:
            continue
        assert len(set(batch["watermark_processor"])) == 1

        # These names must be in scope for eval(wp_str) below.
        from unbiased_watermark import (
            Delta_Reweight,
            Gamma_Reweight,
            DeltaGumbel_Reweight,
            WatermarkLogitsProcessor,
            PrevN_ContextCodeExtractor,
        )

        wp_str = batch["watermark_processor"][0]
        # NOTE(review): eval executes arbitrary code; wp_str must come from
        # trusted pipeline output only.
        wp = eval(wp_str)
        wp.ignore_history = True
        if "get_la_score" in dir(wp.reweight):
            import copy

            la_wp = copy.deepcopy(wp)
            la_wp.ignore_history = False
        else:
            la_wp = None

        tbatch = tokenize_batch(
            batch,
            tokenizer,
            ["input", "output"],
        )
        # score: [batch_size, sequence_length, query_size]
        # entropy: [batch_size, sequence_length]
        # label_attention_mask: [batch_size, sequence_length]
        score, entropy, label_attention_mask, la_score = get_score(
            model, tbatch, wp, scorer, la_wp=la_wp
        )

        import gc

        gc.collect()
        # score: [batch_size, query_size]
        sum_score = score.sum(-2)
        # best_index: [batch_size]
        best_index = torch.argmax(sum_score, dim=-1)
        best_dist_q = [dist_qs[i] for i in best_index.cpu().tolist()]
        best_sum_score = (
            torch.gather(
                sum_score,
                -1,
                best_index.unsqueeze(-1),
            )
            .squeeze(-1)
            .cpu()
            .tolist()
        )

        # best_score: [batch_size, sequence_length]
        best_score = (
            torch.gather(
                score,
                -1,
                best_index.unsqueeze(-1).unsqueeze(-1).expand(-1, score.size(-2), -1),
            )
            .squeeze(-1)
            .cpu()
            .tolist()
        )
        # la_score: [batch_size, sequence_length]
        if la_score is not None:
            sum_la_score = la_score.sum(-1).cpu().tolist()
            la_score = la_score.cpu().tolist()
        else:
            # Placeholders so every output column has one entry per row.
            la_score = [[]] * len(best_dist_q)
            sum_la_score = [None] * len(best_dist_q)
        #  lens: [batch_size]
        cum_label_attention_mask = torch.cumsum(label_attention_mask, dim=-1)
        lens = cum_label_attention_mask[:, -1]
        #  assert attention is like 11110000
        _lens_m_1 = torch.argmax(
            cum_label_attention_mask,
            dim=-1,
        )
        assert torch.all(lens == _lens_m_1 + 1)

        lens = lens.cpu().tolist()
        # Trim the per-token lists to each sequence's unpadded length.
        best_score = [best_score[i][: lens[i]] for i in range(len(best_score))]
        la_score = [la_score[i][: lens[i]] for i in range(len(la_score))]
        entropy = entropy.cpu().tolist()
        entropy = [entropy[i][: lens[i]] for i in range(len(entropy))]

        rq.put(
            {
                **batch,
                "best_dist_q": best_dist_q,
                "best_sum_score": best_sum_score,
                "best_score": best_score,
                "la_score": la_score,
                "sum_la_score": sum_la_score,
                "lens": lens,
                "entropy": entropy,
            }
        )


def score_worker2(tq, tqe, rq, gpu_id, test_config):
    """Disabled scoring worker: calling it always raises ``NotImplementedError``.

    Everything after the initial ``raise`` is unreachable and is kept only as
    a reference implementation. As written, that code would:

    1. load the seq2seq model and tokenizer named by
       ``test_config["model_str"]`` onto ``cuda:{gpu_id}``;
    2. rebuild the watermark processor by evaluating ``test_config["wp_str"]``;
    3. read batches from ``tq`` until ``tqe.is_set()`` and ``tq.empty()`` are
       both true, skipping batches whose watermark type differs from ``wp_str``;
    4. score each batch with ``get_score`` over a grid of ``dist_q`` values,
       keep the best-scoring grid point per sample, and put the result dict
       on ``rq``.

    Args:
        tq: Task queue; batches are fetched with ``tq.get(timeout=1)``.
        tqe: Event-like object polled with ``tqe.is_set()``.
        rq: Result queue; receives one dict per processed batch.
        gpu_id: Index used to build the ``cuda:{gpu_id}`` device string.
        test_config: Mapping read for ``"model_str"``, ``"wp_str"`` and the
            optional ``"no_input"`` flag; also forwarded to ``get_score``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("score_worker2 is disabled")

    # ------------------------------------------------------------------
    # Unreachable reference implementation.
    # ------------------------------------------------------------------
    # All imports are done once, up front, instead of on every loop iteration.
    import gc
    from queue import Empty

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, set_seed
    from transformers import LogitsProcessorList, TemperatureLogitsWarper
    from unbiased_watermark import RobustLLR_Score_Batch_v1, RobustLLR_Score_Batch_v2

    # These names are not referenced directly; they are kept in local scope
    # because ``wp_str`` is evaluated below and may refer to them.
    from unbiased_watermark import (
        Delta_Reweight,
        Gamma_Reweight,
        WatermarkLogitsProcessor,
        PrevN_ContextCodeExtractor,
    )

    model_str = test_config["model_str"]
    model = AutoModelForSeq2SeqLM.from_pretrained(model_str).to(f"cuda:{gpu_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_str)

    # Candidate dist_q values: 0.0, 0.1, ..., 1.0.
    grid_size = 10
    dist_qs = [i / grid_size for i in range(0, grid_size + 1)]
    scorer = RobustLLR_Score_Batch_v2.from_grid([0.0], dist_qs)

    wp_str = test_config["wp_str"]
    # NOTE(review): eval() executes arbitrary code contained in wp_str; this is
    # only acceptable if test_config comes from a trusted source.
    wp = eval(wp_str)
    wp.ignore_history = True

    # Third comma-separated field from the end of the processor string is used
    # to identify the watermark type; it does not change between batches.
    wp_type = wp_str.split(",")[-3]

    while not (tqe.is_set() and tq.empty()):
        try:
            batch = tq.get(timeout=1)
        except Empty:
            # Nothing available yet; re-check the termination condition.
            continue

        # Every sample in a batch must share one watermark processor string.
        assert len(set(batch["watermark_processor"])) == 1
        if batch["watermark_processor"][0].split(",")[-3] != wp_type:
            # not same watermark type
            continue

        if test_config.get("no_input"):
            batch["input"] = [""] * len(batch["input"])
        tbatch = tokenize_batch(
            batch,
            tokenizer,
            ["input", "output"],
        )
        # score: [batch_size, sequence_length, query_size]
        # entropy: [batch_size, sequence_length]
        # label_attention_mask: [batch_size, sequence_length]
        score, entropy, label_attention_mask, _ = get_score(
            model, tbatch, wp, scorer, test_config
        )

        gc.collect()
        # sum_score: [batch_size, query_size]
        sum_score = score.sum(-2)
        # best_index: [batch_size]
        best_index = torch.argmax(sum_score, dim=-1)
        best_dist_q = [dist_qs[i] for i in best_index.cpu().tolist()]
        best_sum_score = (
            torch.gather(
                sum_score,
                -1,
                best_index.unsqueeze(-1),
            )
            .squeeze(-1)
            .cpu()
            .tolist()
        )

        # best_score: [batch_size, sequence_length]
        best_score = (
            torch.gather(
                score,
                -1,
                best_index.unsqueeze(-1).unsqueeze(-1).expand(-1, score.size(-2), -1),
            )
            .squeeze(-1)
            .cpu()
            .tolist()
        )
        #  lens: [batch_size]
        cum_label_attention_mask = torch.cumsum(label_attention_mask, dim=-1)
        lens = cum_label_attention_mask[:, -1]
        #  assert attention is like 11110000
        _lens_m_1 = torch.argmax(
            cum_label_attention_mask,
            dim=-1,
        )
        assert torch.all(lens == _lens_m_1 + 1)

        lens = lens.cpu().tolist()
        # Trim per-token lists down to each sample's unpadded length.
        best_score = [row[:n] for row, n in zip(best_score, lens)]
        entropy = entropy.cpu().tolist()
        entropy = [row[:n] for row, n in zip(entropy, lens)]

        rq.put(
            {
                **batch,
                "best_dist_q": best_dist_q,
                "best_sum_score": best_sum_score,
                "best_score": best_score,
                "lens": lens,
                "entropy": entropy,
                "test_config": [test_config] * len(batch["input"]),
            }
        )
