import json
from pathlib import Path

import torch
torch.set_grad_enabled(False)
import torch.nn.functional as F
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams, TokensPrompt, TextPrompt
from vllm.sampling_params import BeamSearchParams
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from typing import Union, List
import itertools

import dp_tokenization as dpt
from utils import bytes_to_unicode, read_jsonl_gz
from text_conditioning import *
from tqdm.auto import tqdm
from statsmodels.stats.proportion import proportion_confint

# Lookup table produced by the project helper `bytes_to_unicode`, plus its
# inverse (every value of B becomes a key pointing back at its original key).
# NOTE(review): presumably the GPT-2 style byte -> printable-character map;
# confirm against utils.bytes_to_unicode.
B = bytes_to_unicode()
Brev = {dst: src for src, dst in B.items()}

# Model location handed to the sampler factory that the rest of the script uses.
model_dir = Path("Qwen/Qwen3-1.7B-Base")
tcs = TextConditionedSampler(model_dir)

def next_char_prob(tcs: "TextConditionedSampler", prompt: str, char: str) -> float:
    """Return the log-probability that ``char`` follows ``prompt``.

    ``char`` is scored one UTF-8 byte at a time: for each byte the sampler's
    current distribution is log-softmaxed, the entry for that byte is added
    to the running sum, and the byte is then appended to the sampler context
    so the next byte is conditioned on it.

    Despite the function name, the result is a *log*-probability (natural
    log), not a probability.

    Args:
        tcs: Object providing ``btok`` (whose ``normalizer`` may be ``None``)
            and ``get_bytewise_sampler``.
        prompt: Context text. It is passed through
            ``tcs.btok.normalizer.normalize_str`` when a normalizer is set.
        char: Text whose UTF-8 bytes are scored. Usually a single character,
            but any string is accepted.

    Returns:
        The summed per-byte log-probability as a Python float; ``0.0`` when
        ``char`` is empty (there is nothing to predict).
    """
    normalizer = tcs.btok.normalizer
    if normalizer is not None:
        prompt = normalizer.normalize_str(prompt)

    sampler = tcs.get_bytewise_sampler(batch_size=1)
    sampler.add_context([prompt])

    log_prob = 0
    for b in char.encode():
        # Batch size is 1, so entry 0 is the only distribution; it is indexed
        # by byte value below.
        logprobs = torch.log_softmax(sampler.get_dists()[0], 0)
        log_prob += logprobs[b]
        sampler.add_context([bytes([b])])

    # float() accepts both the 0-dim tensor built in the loop and the plain
    # int 0 left over when `char` is empty (int has no .item()).
    return float(log_prob)
        

from utils import read_jsonl_zstd
import random
# Evaluation loop: for each document pick a random position within roughly the
# first 500 characters and score the character at that position given the text
# before it.
#
# NOTE: despite its name, `correct` is not a hit count. It accumulates the
# log-probabilities returned by next_char_prob; `total` counts the documents
# that were actually scored.
correct = 0
total = 0
text = "REDACTED/madlad_shuffle/zh/zh_clean_0000.jsonl.gz"
N = 100_000
# islice caps the run at exactly N documents, matching the progress-bar total.
for doc in tqdm(itertools.islice(read_jsonl_gz(text), N), total=N):
    prompt = doc['text']
    if not prompt:
        # Nothing to predict, and random.randint(0, -1) would raise ValueError.
        continue
    idx = random.randint(0, min(len(prompt) - 1, 500))
    try:
        log_prob = next_char_prob(tcs, prompt[:idx], prompt[idx])
    except UnicodeDecodeError:
        # Best-effort: skip documents the scorer cannot decode.
        # NOTE(review): the raise site is not visible here; presumably inside
        # the byte-level sampler. Confirm.
        continue
    correct += log_prob
    total += 1