import json
from pathlib import Path

import torch
torch.set_grad_enabled(False)
import torch.nn.functional as F
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams, TokensPrompt, TextPrompt
from vllm.sampling_params import BeamSearchParams
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from typing import Union, List
import itertools

import dp_tokenization as dpt
from utils import bytes_to_unicode
from text_conditioning import *
from tqdm.auto import tqdm
from statsmodels.stats.proportion import proportion_confint
from utils import read_jsonl_gz
import random
# Forward table returned by utils.bytes_to_unicode (presumably byte value ->
# printable unicode character, GPT-2 style — confirm in utils), and its
# inverse built by swapping every key/value pair.
B = bytes_to_unicode()
Brev = dict((char, byte) for byte, char in B.items())

# Checkpoint handed to the sampler; presumably a Hugging Face Hub id or a
# local directory of the same name — confirm against TextConditionedSampler.
model_dir = Path("Qwen") / "Qwen3-1.7B-Base"
# Module-level sampler shared by the helpers and the evaluation loop below.
tcs = TextConditionedSampler(model_dir)

def predict_next_char(
    tcs: TextConditionedSampler, prompt: str, *, fallback: str = ' '
):
    """Greedily predict the single character that follows ``prompt``.

    Bytes are drawn greedily (``do_sample=False``) one at a time from a
    batch-size-1 byte-level sampler and fed back as context until they form
    one complete UTF-8 character.

    Args:
        tcs: sampler providing ``btok`` (with an optional ``normalizer``) and
            ``get_bytewise_sampler``.
        prompt: text to condition on; normalized with
            ``tcs.btok.normalizer`` first when one is configured.
        fallback: value returned when the greedy bytes cannot form a valid
            UTF-8 character (UTF-8 sequences are at most 4 bytes). Defaults to
            ``' '`` for backward compatibility. NOTE: ``' '`` can coincide
            with a genuine space in reference text and then be scored as a
            correct prediction; pass e.g. ``''`` to rule that out.

    Returns:
        The decoded character, or ``fallback``.
    """
    import codecs

    if tcs.btok.normalizer is not None:
        prompt = tcs.btok.normalizer.normalize_str(prompt)
    sampler = tcs.get_bytewise_sampler(batch_size=1)
    sampler.add_context([prompt])
    # The incremental decoder returns '' while a multi-byte sequence is still
    # incomplete and raises as soon as the bytes can no longer be valid UTF-8,
    # so a hopeless prefix stops immediately instead of spending more forward
    # passes. After 4 bytes a valid sequence must have completed.
    decoder = codecs.getincrementaldecoder('utf-8')()
    for _ in range(4):
        dist = sampler.get_dists()[0]
        next_byte = bytes([sample_from_logits(dist, do_sample=False).item()])
        try:
            char = decoder.decode(next_byte)
        except UnicodeDecodeError:
            return fallback
        if char:
            return char
        # Incomplete sequence: condition on the byte just emitted and continue.
        sampler.add_context([next_byte])
    return fallback
        

def count_ours(text, sampler=None):
    """Return the size of ``streaming_bpe_open``'s output for ``text``.

    The result is ``len(trunk)`` plus the number of nodes in the nested
    ``branches`` mapping, with the root node included. Entries keyed by
    ``None`` are not counted, and neither are their subtrees.

    Args:
        text: input string, passed to
            ``streaming_bpe_open(text, inclusive=True)``.
        sampler: object providing ``streaming_bpe_open``. Defaults to the
            module-level ``tcs``; this parameter is optional for backward
            compatibility.

    Returns:
        int: trunk length plus branch-node count.
    """
    if sampler is None:
        sampler = tcs
    trunk, branches = sampler.streaming_bpe_open(text, inclusive=True)
    # Stack-based traversal instead of recursion so very deep branch trees
    # cannot hit Python's recursion limit.
    n_nodes = 0
    pending = [branches]
    while pending:
        node = pending.pop()
        n_nodes += 1
        pending.extend(sub for tid, sub in node.items() if tid is not None)
    return len(trunk) + n_nodes

# Evaluate on (at most) the first N documents of the shard. For each one, a
# cut point idx is drawn uniformly from [0, min(len(prompt) - 1, 500)], then:
#   cost    += count_ours(prefix) minus the plain tokenizer's token count
#   correct += whether the greedy next-character prediction equals prompt[idx]
# NOTE(review): predict_next_char falls back to ' ' when it cannot decode a
# character, which is scored as correct whenever the true next char is a space.
correct = 0
cost = 0
total = 0
text = "REDACTED/madlad_shuffle/zh/zh_clean_0000.jsonl.gz"
N = 100_000
# islice caps the loop at exactly N documents; the old `if i == N: break`
# check ran after processing, so it evaluated N + 1 documents.
for i, doc in enumerate(tqdm(itertools.islice(read_jsonl_gz(text), N), total=N)):
    prompt = doc['text']
    if not prompt:
        # random.randint(0, -1) would raise ValueError on an empty document.
        continue
    idx = random.randint(0, min(len(prompt) - 1, 500))
    prefix = prompt[:idx]
    try:
        extra = count_ours(prefix) - len(tcs.tokenizer.encode(prefix))
        hit = predict_next_char(tcs, prefix) == prompt[idx]
    except UnicodeDecodeError:
        # Skip the whole sample so cost, correct and total stay in sync.
        continue
    cost += extra
    correct += hit
    total += 1