import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import List, Optional
import argparse
from scripts.utils import load_single_dataset, save_dataset, split_batch
import datasets
from tqdm import tqdm


# @torch.no_grad()
# def contrastive_decoding(
#     max_prompt_length: int,
#     max_response_length: int,
#     prompts: list,
#     expert_model,
#     base_model,
#     tokenizer,
#     candidate_threshold,
#     score_threshhold_value,
#     eos_token_id=None,
#     device='cuda'
# ):
#     batch_size = len(prompts)
#     # 1. Tokenize & left pad
#     encoded = tokenizer(prompts, return_tensors='pt', padding='longest', truncation=True, max_length=max_prompt_length)
#     input_ids = encoded['input_ids']  # [B, L]
#     attention_mask = encoded['attention_mask']
#     pad_token_id = tokenizer.pad_token_id
#     eos_token_id = eos_token_id if eos_token_id is not None else tokenizer.eos_token_id

#     def left_pad(tensor, pad_value):
#         seq_lens = attention_mask.sum(-1)
#         max_len = tensor.size(1)
#         new_tensor = torch.full_like(tensor, pad_value)
#         for i, l in enumerate(seq_lens):
#             new_tensor[i, -l:] = tensor[i, -l:]
#         return new_tensor

#     input_ids = left_pad(input_ids, pad_token_id)
#     attention_mask = (input_ids != pad_token_id).long()
#     position_ids = attention_mask.cumsum(-1) - 1
#     position_ids[input_ids == pad_token_id] = 0

#     input_ids = input_ids.to(device)
#     attention_mask = attention_mask.to(device)
#     position_ids = position_ids.to(device)

#     expert_past = None
#     base_past = None
#     generated = [[] for _ in range(batch_size)]
#     done = [False] * batch_size

#     for step in range(max_response_length):
#         with torch.no_grad():
#             if expert_model is not None:
#                 expert_outputs = expert_model(
#                     input_ids=input_ids,
#                     attention_mask=attention_mask,
#                     position_ids=position_ids,
#                     use_cache=True,
#                     past_key_values=expert_past,
#                 )
#                 expert_logits = expert_outputs.logits[:, -1, :]
#                 expert_past = expert_outputs.past_key_values
#             else:
#                 expert_logits = None

#             base_outputs = base_model(
#                 input_ids=input_ids,
#                 attention_mask=attention_mask,
#                 position_ids=position_ids,
#                 use_cache=True,
#                 past_key_values=base_past,
#             )
#             base_logits = base_outputs.logits[:, -1, :]
#             base_past = base_outputs.past_key_values

#             base_logp = F.log_softmax(base_logits, dim=-1)      # [B, vocab]
#             base_prob = base_logp.exp()  # [B, vocab]

#             if expert_logits is not None:
#                 expert_logp = F.log_softmax(expert_logits, dim=-1)  # [B, vocab]

#             next_tokens = []
#             for i in range(batch_size):
#                 if done[i]:
#                     next_tokens.append(eos_token_id)
#                     continue

#                 if expert_logits is None:
#                     # Base-only decoding: take the highest-probability token
#                     next_token = base_prob[i].argmax().item()
#                 else:
#                     # Candidate set: tokens with base prob > candidate_threshold (e.g. 0.1); if none, fall back to the argmax token
#                     prob_i = base_prob[i]
#                     cand_mask = prob_i > candidate_threshold
#                     if cand_mask.sum() == 0:
#                         cands = prob_i.argmax().unsqueeze(0)
#                     else:
#                         cands = cand_mask.nonzero().squeeze(-1)
#                     contrastive_score = (expert_logp[i][cands] - base_logp[i][cands])
#                     idx = contrastive_score.argmax()
#                     next_token = cands[idx].item()

#                 generated[i].append(next_token)
#                 if next_token == eos_token_id:
#                     done[i] = True

#             next_tokens = torch.tensor(next_tokens, dtype=torch.long, device=device).unsqueeze(1)  # [B, 1]
#             input_ids = next_tokens
#             attention_mask = torch.ones_like(input_ids, device=device)
#             position_ids = (position_ids[:, -1:] + 1)  # [B, 1]

#         if all(done):
#             break

#     # Decode
#     decoded = []
#     for prompt, out_ids in zip(prompts, generated):
#         text = tokenizer.decode(out_ids, skip_special_tokens=True)
#         decoded.append(text)

#     return decoded


def _left_padded_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    """Build position ids for a left-padded batch.

    Each row's first real (attended) token gets position 0, so a short prompt
    is positioned exactly as it would be if it were decoded on its own. Padding
    slots get a dummy position of 1 (the Hugging Face ``generate`` convention);
    they are masked out by ``attention_mask`` anyway.

    Args:
        attention_mask: ``[batch, seq]`` tensor, 1 for real tokens, 0 for padding.

    Returns:
        A ``[batch, seq]`` long tensor of position ids.
    """
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids


def _select_token(
    base_logp: torch.Tensor,
    expert_logp: Optional[torch.Tensor],
    candidate_threshold: float,
    score_threshold: float,
) -> int:
    """Pick the next token id for a single sequence.

    Candidates are the tokens whose base probability exceeds
    ``candidate_threshold`` (or the single most likely token if none does).
    Without an expert, the most likely candidate is returned. With an expert,
    the candidate maximising ``expert_logp - base_logp`` is returned when that
    score exceeds ``score_threshold``; otherwise the most likely candidate is.

    Args:
        base_logp: 1-D base-model log-probabilities over the vocabulary.
        expert_logp: 1-D expert-model log-probabilities, or ``None``.
        candidate_threshold: Minimum base probability for a candidate.
        score_threshold: Minimum contrastive score to prefer the expert's pick.

    Returns:
        The selected token id as a Python int.
    """
    base_prob = base_logp.exp()
    candidate_mask = base_prob > candidate_threshold
    if not candidate_mask.any():
        candidate_mask[base_prob.argmax()] = True
    candidate_ids = candidate_mask.nonzero(as_tuple=True)[0]
    greedy_token = candidate_ids[base_prob[candidate_ids].argmax()]
    if expert_logp is None:
        return greedy_token.item()
    score = expert_logp[candidate_ids] - base_logp[candidate_ids]
    if score.max() > score_threshold:
        return candidate_ids[score.argmax()].item()
    return greedy_token.item()


def contrastive_decoding(
    max_prompt_length: int,
    max_response_length: int,
    prompts: List[str],
    base_model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    expert_model: Optional["PreTrainedModel"] = None,
    temperature: float = 1.0,
    candidate_threshold: float = 0.1,
    score_threshhold_value: float = 0.2,
) -> List[str]:
    """Decode a batch of prompts token by token, optionally steered by an expert.

    At every step the base model proposes candidates: tokens whose
    temperature-scaled probability exceeds ``candidate_threshold``, or its single
    most likely token if none does. Without ``expert_model`` the most likely
    candidate is emitted (greedy decoding). With ``expert_model`` the candidate
    with the largest ``expert log-prob - base log-prob`` is emitted when that
    score exceeds ``score_threshhold_value``; otherwise decoding falls back to
    the base model's most likely candidate.

    Both models are fed the same token ids, so they are assumed to share the
    tokenizer's vocabulary (TODO confirm for each base/expert pair).

    Args:
        max_prompt_length: Prompts are truncated from the left to this many tokens.
        max_response_length: Maximum number of tokens generated per prompt.
        prompts: Already-formatted prompt strings.
        base_model: Causal LM proposing candidates; inputs are placed on its device.
        tokenizer: Tokenizer for ``base_model``. Its ``padding_side`` and
            ``truncation_side`` are set to ``"left"`` as a side effect.
        expert_model: Optional causal LM used to rescore the candidates.
        temperature: Positive divisor applied to both models' logits.
        candidate_threshold: Minimum base probability for a token to be a candidate.
        score_threshhold_value: Minimum contrastive score needed to prefer the
            expert's choice over the base model's most likely candidate.

    Returns:
        One decoded response per prompt (special tokens skipped), in input order.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if not prompts:
        return []

    device = base_model.device
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    # Step 1: tokenize with left padding so every row's last column is a real token.
    batch = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_prompt_length,
    ).to(device)
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    batch_size = input_ids.size(0)
    # Positions must ignore the left padding, otherwise shorter prompts are shifted.
    position_ids = _left_padded_position_ids(attention_mask)

    generated: List[List[int]] = [[] for _ in range(batch_size)]
    # Plain Python bools: checking them does not force a GPU synchronisation.
    ended = [False] * batch_size
    past_key_values_base = None
    past_key_values_expert = None

    for step in range(max_response_length):
        # After the first step the KV cache holds the prefix, so only the newest
        # token (and its position) is fed; the attention mask covers everything.
        model_inputs = {
            "input_ids": input_ids if step == 0 else input_ids[:, -1:],
            "attention_mask": attention_mask,
            "position_ids": position_ids if step == 0 else position_ids[:, -1:],
            "use_cache": True,
            "past_key_values": past_key_values_base,
        }
        with torch.no_grad():
            base_outputs = base_model(**model_inputs)
        past_key_values_base = base_outputs.past_key_values
        log_probs_base = torch.nn.functional.log_softmax(
            base_outputs.logits[:, -1, :] / temperature, dim=-1
        )

        log_probs_expert = None
        if expert_model is not None:
            # Same inputs, but the expert keeps its own KV cache.
            model_inputs["past_key_values"] = past_key_values_expert
            with torch.no_grad():
                expert_outputs = expert_model(**model_inputs)
            past_key_values_expert = expert_outputs.past_key_values
            log_probs_expert = torch.nn.functional.log_softmax(
                expert_outputs.logits[:, -1, :] / temperature, dim=-1
            )

        next_tokens = []
        for i in range(batch_size):
            if ended[i]:
                # Finished rows are fed padding; their output is no longer recorded.
                next_tokens.append(tokenizer.pad_token_id)
                continue
            token = _select_token(
                log_probs_base[i],
                None if log_probs_expert is None else log_probs_expert[i],
                candidate_threshold,
                score_threshhold_value,
            )
            next_tokens.append(token)
            generated[i].append(token)
            if token == tokenizer.eos_token_id:
                ended[i] = True

        if all(ended):
            break

        # Append the chosen tokens; positions continue from each row's last one.
        next_column = torch.tensor(next_tokens, dtype=torch.long, device=device).unsqueeze(-1)
        input_ids = torch.cat([input_ids, next_column], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_column)], dim=1)
        position_ids = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=1)

    return tokenizer.batch_decode(
        generated,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )



def main() -> None:
    """Decode every prompt in a dataset slice and save the responses.

    Loads the base model (and the optional expert model) in float16 with
    ``device_map="auto"``. It renders each ``--prompt_key`` value with the base
    tokenizer's chat template and runs ``contrastive_decoding`` batch by batch.
    The slice, with an added ``--output_key`` column, is written to
    ``--output_jsonl``. The rows whose response is empty are also written to
    ``<output_jsonl>debug.jsonl``.
    """
    parser = argparse.ArgumentParser(
        description="Greedy or expert-guided contrastive decoding over a dataset slice."
    )
    parser.add_argument("--base_model", required=True, type=str,
                        help="Base model path; its tokenizer is used for both models.")
    parser.add_argument("--expert_model", required=False, default=None, type=str,
                        help="Optional expert model path; omit for plain greedy decoding.")
    parser.add_argument("--data", required=False, default=None, type=str,
                        help="Dataset passed to load_single_dataset.")
    parser.add_argument("--prompt_key", required=False, default="prompt", type=str,
                        help="Column whose values are rendered with the chat template.")
    parser.add_argument("--output_key", required=False, default="response", type=str,
                        help="Name of the column that receives the generated responses.")
    # Required: the path is also used to build the debug-file name, so a missing
    # value used to crash only after the whole generation had finished.
    parser.add_argument("--output_jsonl", required=True, type=str,
                        help="Where to save the dataset slice with responses.")
    parser.add_argument("--dataset_split", required=False, default="train", type=str,
                        help="Split passed to load_single_dataset.")
    parser.add_argument("--batch_size", required=False, default=1, type=int,
                        help="Number of prompts decoded together.")
    parser.add_argument("--max_prompt_length", required=False, default=512, type=int,
                        help="Prompts are left-truncated to this many tokens.")
    parser.add_argument("--max_response_length", required=False, default=2048, type=int,
                        help="Maximum number of generated tokens per prompt.")
    parser.add_argument("--begin", required=False, default=None, type=int,
                        help="First row index to process (inclusive).")
    parser.add_argument("--end", required=False, default=None, type=int,
                        help="Last row index to process (exclusive).")
    parser.add_argument("--candidate_threshold", required=False, default=0.1, type=float,
                        help="Minimum base-model probability for a token to be a candidate.")
    parser.add_argument("--score_threshhold_value", required=False, default=0.2, type=float,
                        help="Minimum (expert - base) log-prob score to take the expert's pick.")
    args = parser.parse_args()

    # init models and dataset
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model, device_map="auto", torch_dtype=torch.float16
    ).eval()
    expert_model = None
    if args.expert_model is not None:
        expert_model = AutoModelForCausalLM.from_pretrained(
            args.expert_model, device_map="auto", torch_dtype=torch.float16
        ).eval()
    test_set = load_single_dataset(args.data, dataset_split=args.dataset_split)
    if isinstance(test_set, datasets.DatasetDict):
        # Merge all splits into one dataset.
        test_set = datasets.concatenate_datasets(list(test_set.values()))

    # Keep rows [begin, end), clamped to the dataset bounds.
    begin = 0 if args.begin is None else max(args.begin, 0)
    end = len(test_set) if args.end is None else min(args.end, len(test_set))
    test_set = test_set.select(range(begin, end))

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    prompts = [
        tokenizer.apply_chat_template(p, tokenize=False, add_generation_prompt=True)
        for p in test_set[args.prompt_key]
    ]

    # generate (contrastive_decoding itself switches the tokenizer to left padding)
    batched_prompts = split_batch(prompts, args.batch_size)
    outputs = []
    for batch in tqdm(batched_prompts):
        outputs.extend(
            contrastive_decoding(
                max_prompt_length=args.max_prompt_length,
                max_response_length=args.max_response_length,
                prompts=batch,
                expert_model=expert_model,
                base_model=base_model,
                tokenizer=tokenizer,
                candidate_threshold=args.candidate_threshold,
                score_threshhold_value=args.score_threshhold_value,
            )
        )

    # write back the results
    test_set = test_set.add_column(args.output_key, outputs)
    save_dataset(test_set, args.output_jsonl)
    # Rows with an empty response go to a side file for inspection.
    empty_rows = test_set.filter(lambda x: len(x[args.output_key]) == 0)
    save_dataset(empty_rows, args.output_jsonl + "debug.jsonl")


if __name__ == "__main__":
    main()


"""



CUDA_VISIBLE_DEVICES=6 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/checkpoint-7179 \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/checkpoint-7179/math-test-contrastive_decoding.jsonl \
    --batch_size 8 \
    --begin 0 \
    --end 105


CUDA_VISIBLE_DEVICES=7 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/7b_model/Qwen-Qwen3-0.6B-nothink \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/7b_model/Qwen-Qwen3-0.6B-nothink/math-test-contrastive_decoding.jsonl \
    --batch_size 8 \
    --begin 0 \
    --end 105







CUDA_VISIBLE_DEVICES=6 ./.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/7b_model/Qwen-Qwen2.5-Math-7B-Instruct \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/7b_model/Qwen-Qwen2.5-Math-7B-Instruct/math-test-tanlan.jsonl \
    --batch_size 2 

    

CUDA_VISIBLE_DEVICES=6 ./.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/7b_model/PRIME-RL-EurusPRM-Stage1 \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/7b_model/PRIME-RL-EurusPRM-Stage1/math-test-tanlan.jsonl \
    --batch_size 2 

CUDA_VISIBLE_DEVICES=2,3 ./.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/7b_model/PRIME-RL-EurusPRM-Stage1 \
    --base_model ~/7b_model/Qwen-Qwen2.5-Math-7B-Instruct \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/7b_model/PRIME-RL-EurusPRM-Stage1/math-test-contrastive_decoding_421_842.jsonl \
    --batch_size 2 \
    --begin 421 \
    --end 842
    







    
CUDA_VISIBLE_DEVICES=0,1 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/7b_model/PRIME-RL-EurusPRM-Stage1 \
    --base_model ~/7b_model/Qwen-Qwen2.5-Math-7B-Instruct \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/7b_model/PRIME-RL-EurusPRM-Stage1/math-test-contrastive_decoding_0_210_new.jsonl \
    --batch_size 2 \
    --begin 0 \
    --end 210 \
    --candidate_threshold 0.1

CUDA_VISIBLE_DEVICES=2,3 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/7b_model/PRIME-RL-EurusPRM-Stage1 \
    --base_model ~/7b_model/Qwen-Qwen2.5-Math-7B-Instruct \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/7b_model/PRIME-RL-EurusPRM-Stage1/math-test-contrastive_decoding_210_421_new.jsonl \
    --batch_size 2 \
    --begin 210 \
    --end 421 \
    --candidate_threshold 0.1

CUDA_VISIBLE_DEVICES=4,5 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/7b_model/PRIME-RL-EurusPRM-Stage1 \
    --base_model ~/7b_model/Qwen-Qwen2.5-Math-7B-Instruct \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/7b_model/PRIME-RL-EurusPRM-Stage1/math-test-contrastive_decoding_421_631_new.jsonl \
    --batch_size 2 \
    --begin 421 \
    --end 631 \
    --candidate_threshold 0.1
    
CUDA_VISIBLE_DEVICES=6,7 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/7b_model/PRIME-RL-EurusPRM-Stage1 \
    --base_model ~/7b_model/Qwen-Qwen2.5-Math-7B-Instruct \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/7b_model/PRIME-RL-EurusPRM-Stage1/math-test-contrastive_decoding_631_842_new.jsonl \
    --batch_size 2 \
    --begin 631 \
    --end 842 \
    --candidate_threshold 0.1


    



























CUDA_VISIBLE_DEVICES=0 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/math-test-contrastive_decoding_0_105_new.jsonl \
    --batch_size 8 \
    --begin 0 \
    --end 105 &

CUDA_VISIBLE_DEVICES=1 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/math-test-contrastive_decoding_105_211_new.jsonl \
    --batch_size 8 \
    --begin 105 \
    --end 211 &

CUDA_VISIBLE_DEVICES=2 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/math-test-contrastive_decoding_211_315_new.jsonl \
    --batch_size 8 \
    --begin 211 \
    --end 315 &

CUDA_VISIBLE_DEVICES=3 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/math-test-contrastive_decoding_315_420_new.jsonl \
    --batch_size 8 \
    --begin 315 \
    --end 420 &

    
CUDA_VISIBLE_DEVICES=4 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/math-test-contrastive_decoding_420_525_new.jsonl \
    --batch_size 8 \
    --begin 420 \
    --end 525 &

CUDA_VISIBLE_DEVICES=5 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/math-test-contrastive_decoding_525_630_new.jsonl \
    --batch_size 8 \
    --begin 525 \
    --end 630 &

CUDA_VISIBLE_DEVICES=6 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/math-test-contrastive_decoding_630_735_new.jsonl \
    --batch_size 8 \
    --begin 630 \
    --end 735 &

CUDA_VISIBLE_DEVICES=7 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/math-test-contrastive_decoding_735_842_new.jsonl \
    --batch_size 8 \
    --begin 735 \
    --end 842 &


    



    





CUDA_VISIBLE_DEVICES=0 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward/math-test-contrastive_decoding_0_105_new.jsonl \
    --batch_size 4 \
    --begin 0 \
    --end 105 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=1 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward/math-test-contrastive_decoding_105_211_new.jsonl \
    --batch_size 4 \
    --begin 105 \
    --end 211 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=2 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward/math-test-contrastive_decoding_211_315_new.jsonl \
    --batch_size 4 \
    --begin 211 \
    --end 315 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=3 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward/math-test-contrastive_decoding_315_420_new.jsonl \
    --batch_size 4 \
    --begin 315 \
    --end 420 \
    --candidate_threshold 0.2 &

    
CUDA_VISIBLE_DEVICES=4 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward/math-test-contrastive_decoding_420_525_new.jsonl \
    --batch_size 4 \
    --begin 420 \
    --end 525 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=5 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward/math-test-contrastive_decoding_525_630_new.jsonl \
    --batch_size 4 \
    --begin 525 \
    --end 630 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=6 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward/math-test-contrastive_decoding_630_735_new.jsonl \
    --batch_size 4 \
    --begin 630 \
    --end 735 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=7 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/verl_adpa/checkpoints/test/training_reward_model/global_step_90/reward/math-test-contrastive_decoding_735_842_new.jsonl \
    --batch_size 4 \
    --begin 735 \
    --end 842 \
    --candidate_threshold 0.2 &




    






    

CUDA_VISIBLE_DEVICES=0 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700 \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700/math-test-contrastive_decoding_0_105_new.jsonl \
    --batch_size 4 \
    --begin 0 \
    --end 105 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=1 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700 \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700/math-test-contrastive_decoding_105_211_new.jsonl \
    --batch_size 4 \
    --begin 105 \
    --end 211 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=2 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700 \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700/math-test-contrastive_decoding_211_315_new.jsonl \
    --batch_size 4 \
    --begin 211 \
    --end 315 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=3 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700 \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700/math-test-contrastive_decoding_315_420_new.jsonl \
    --batch_size 4 \
    --begin 315 \
    --end 420 \
    --candidate_threshold 0.2 &

    
CUDA_VISIBLE_DEVICES=4 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700 \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700/math-test-contrastive_decoding_420_525_new.jsonl \
    --batch_size 4 \
    --begin 420 \
    --end 525 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=5 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700 \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700/math-test-contrastive_decoding_525_630_new.jsonl \
    --batch_size 4 \
    --begin 525 \
    --end 630 \
    --candidate_threshold 0.2 &

CUDA_VISIBLE_DEVICES=6 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700 \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/every_ce_reward_model/checkpoint-700/math-test-contrastive_decoding_630_735_new.jsonl \
    --batch_size 4 \
    --begin 630 \
    --end 735 \
    --candidate_threshold 0.2 &

    






    
CUDA_VISIBLE_DEVICES=2 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/eurusprm_everyce_beta_11/checkpoint-700 \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/eurusprm_everyce_beta_11/checkpoint-700/math-test-contrastive_decoding_0_421_new.jsonl \
    --batch_size 4 \
    --begin 0 \
    --end 421 \
    --candidate_threshold 0.2 &
    
CUDA_VISIBLE_DEVICES=3 ~/verl_cs/.conda/bin/python \
~/verl_cs/scripts/contrastive_decoding.py \
    --expert_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/eurusprm_everyce_beta_11/checkpoint-700 \
    --base_model ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/ \
    --data ~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-0-480000-not-exceeed-catagory5/valid.parquet \
    --prompt_key prompt \
    --output_key response \
    --output_jsonl ~/LLaMA-Factory-250514/saves/qwen3-0.6B/eurusprm_everyce_beta_11/checkpoint-700/math-test-contrastive_decoding_421_842_new.jsonl \
    --batch_size 4 \
    --begin 421 \
    --end 842 \
    --candidate_threshold 0.2 &






"""