import json
from pathlib import Path

import torch
torch.set_grad_enabled(False)
import torch.nn.functional as F
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams, TokensPrompt, TextPrompt
from vllm.sampling_params import BeamSearchParams
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from typing import Union, List
import itertools

import dp_tokenization as dpt
from utils import bytes_to_unicode
from text_conditioning import *
from tqdm.auto import tqdm
from statsmodels.stats.proportion import proportion_confint
from utils import read_jsonl_zstd
import random
# Byte <-> unicode lookup tables: ``B`` comes from utils.bytes_to_unicode
# (presumably the GPT-2 byte-level BPE mapping — confirm in utils) and
# ``Brev`` is its inverse. Neither table is referenced by the code visible
# in this file.
B = bytes_to_unicode()
Brev = dict(zip(B.values(), B.keys()))

# Checkpoint identifier and the module-level sampler instance read by the
# code below.
model_dir = Path("allenai") / "OLMo-2-0425-1B"
tcs = TextConditionedSampler(model_dir)


def byte_cover(text: bytes, device=None, *, sampler=None):
    """Build the tree of every way ``text`` can be covered by vocabulary tokens.

    The returned tree is a nested dict:

    * each non-``None`` key is the id (``node[None]`` from the vocabulary
      trie) of a token whose bytes are a *proper* prefix of the remaining
      bytes; its value is the cover tree of the bytes following that token;
    * the key ``None``, present only when non-empty, holds a tensor built
      from ``sampler._get_walk_cached(remaining_bytes)`` — presumably the ids
      of tokens whose bytes start with the remaining bytes (TODO confirm in
      text_conditioning).

    Args:
        text: Non-empty byte string to cover, e.g. ``char.encode()``.
        device: If given, every ``None``-keyed tensor is moved to this device.
        sampler: Sampler whose ``vtrie`` / ``_get_walk_cached`` are used.
            Defaults to the module-level ``tcs``.

    Returns:
        The nested cover tree described above.

    Raises:
        TypeError: If ``text`` is a ``str`` instead of bytes.
        ValueError: If ``text`` is empty.
    """
    if sampler is None:
        sampler = tcs
    if isinstance(text, str):
        raise TypeError("byte_cover expects bytes; encode the string first")
    if not text:
        raise ValueError("byte_cover requires a non-empty byte string")

    def cover(rest: bytes) -> dict:
        # ``rest`` is never empty: it is either ``text`` or rest[j:] with
        # 1 <= j < len(rest).
        tree = {}
        node = sampler.vtrie
        # Walk the trie along the proper prefixes rest[:j]. Tokens spanning
        # all of ``rest`` (or extending past it) come from the walk lookup
        # below instead.
        for j in range(1, len(rest)):
            node = node.get(rest[j - 1])  # indexing bytes yields an int
            if not node:
                break  # no vocabulary entry continues with rest[:j]
            if None in node:
                tree[node[None]] = cover(rest[j:])

        extensions = sampler._get_walk_cached(rest)
        if len(extensions) > 0:
            ext = torch.from_numpy(extensions)
            tree[None] = ext if device is None else ext.to(device)
        return tree

    return cover(text)


# Import-time smoke check: exercise byte_cover on a 3-byte UTF-8 character;
# the result is discarded.
byte_cover("女".encode())


def next_char_prob(tcs: TextConditionedSampler, prompt: str, char: str):
    """Score ``char`` as the continuation of ``prompt`` under ``tcs``'s model.

    The prompt is normalized (when the tokenizer has a normalizer), tokenized,
    and prefixed with ``tcs.bos`` if it does not already start with it. The
    resulting token chain is extended with the tree of every tokenization of
    ``char``'s UTF-8 bytes (``byte_cover``), and that tree is scored through
    ``tcs.tree_inference_batched2`` / ``tcs._compute_prob_trees2`` as a
    single-element batch.

    Args:
        tcs: Sampler providing the tokenizer, vocabulary trie, and model.
        prompt: Text preceding the character.
        char: The character to score.

    Returns:
        The sum of the ``lp`` values on the linear trunk beyond the prompt
        tokens plus ``tcs._integrate_prob_tree`` of the branching part. Going
        by the ``lp`` names this is presumably a log-probability — confirm in
        text_conditioning.

    Raises:
        UnicodeEncodeError: If ``char`` cannot be UTF-8 encoded (e.g. a lone
            surrogate).
    """
    # NOTE(review): only the prompt is normalized; ``char`` is scored as
    # given. Confirm this is intended when a normalizer is present.
    if tcs.btok.normalizer is not None:
        prompt = tcs.btok.normalizer.normalize_str(prompt)

    trunk = tcs.tokenizer.encode(prompt)
    if not trunk or trunk[0] != tcs.bos:
        trunk = [tcs.bos, *trunk]

    # Build the branches from the same sampler that scores them, so their
    # token ids come from the same vocabulary as ``trunk``.
    full_tree = byte_cover(char.encode(), sampler=tcs)
    # Nest the prompt tokens so the tree reads bos -> ... -> last prompt
    # token -> char branches.
    for tid in reversed(trunk):
        full_tree = {tid: full_tree}

    # The same batch object is passed to inference and to prob computation.
    batch_trees = [[full_tree]]
    output, backref_tree = tcs.tree_inference_batched2(batch_trees)
    prob_trees = tcs._compute_prob_trees2(batch_trees, output, backref_tree)
    lptrunk, lpbranches = prob_trees[0][0]
    # Entries of lptrunk past len(trunk) presumably belong to char tokens that
    # were merged into the linear trunk — TODO confirm.
    return sum(lp for _, lp in lptrunk[len(trunk):]) + tcs._integrate_prob_tree(
        lpbranches
    )

# Evaluation: for each of the first N documents, pick a random character
# position (within the first 1001 characters) and score that character as
# the continuation of the text before it.
random.seed(0)  # fixed seed so the sampled positions are reproducible

# Despite its name, ``correct`` accumulates the values returned by
# next_char_prob (a score sum, not a count of correct predictions).
correct = 0
total = 0  # number of documents successfully scored
text = "REDACTED/pretokenization/data/olmo2_shuffle/0000.jsonl.zstd"  # input corpus
N = 100_000  # number of documents to read
# islice yields exactly N documents, consistent with tqdm's total.
for doc in tqdm(itertools.islice(read_jsonl_zstd(text), N), total=N):
    prompt = doc['text']
    if not prompt:
        continue  # random.randint(0, -1) would raise on an empty document
    idx = random.randint(0, min(len(prompt) - 1, 1000))
    try:
        score = next_char_prob(tcs, prompt[:idx], prompt[idx])
    except UnicodeError:
        # UnicodeError covers both decode and encode failures; e.g. a lone
        # surrogate in the document cannot be UTF-8 encoded. Skip such docs.
        continue
    correct += score
    total += 1

if total:
    print(f"scored {total} docs; mean next_char_prob = {float(correct) / total:.6f}")
else:
    print("no documents were scored")