import json
from pathlib import Path

import torch
torch.set_grad_enabled(False)
import torch.nn.functional as F
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams, TokensPrompt, TextPrompt
from vllm.sampling_params import BeamSearchParams
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from typing import Union, List
import itertools

import dp_tokenization as dpt
from utils import bytes_to_unicode
from text_conditioning import *
from tqdm.auto import tqdm
from statsmodels.stats.proportion import proportion_confint

# Lookup table from the project's `bytes_to_unicode` helper, plus its inverse
# (`Brev` maps each value of `B` back to its key).
B = bytes_to_unicode()
Brev = {mapped: original for original, mapped in B.items()}

# Model identifier handed to the sampler; `tcs` is the shared sampler instance
# used by the rest of this script.
model_dir = Path("allenai/OLMo-2-0425-1B")
tcs = TextConditionedSampler(model_dir)

def next_token_prob(tcs: TextConditionedSampler, prompt, next_tid):
    """Return the model's log-probability of ``next_tid`` following ``prompt``.

    Despite the function name, the returned value is a log-probability (the
    log-softmax of the final-position logits), not a probability.

    Args:
        tcs: Sampler object; this function reads ``tcs.bos``, ``tcs.model``
            and ``tcs.model.device``.
        prompt: Sequence of token ids. ``tcs.bos`` is prepended when the
            prompt is empty or does not already start with it.
        next_tid: Token id whose log-probability is looked up.

    Returns:
        The log-probability as a Python float.
    """
    already_has_bos = bool(prompt) and prompt[0] == tcs.bos
    context = prompt if already_has_bos else [tcs.bos, *prompt]

    # A batch of one sequence; only the logits at the last position are used.
    input_ids = torch.tensor([context], device=tcs.model.device, dtype=torch.long)
    last_logits = tcs.model.forward(input_ids).logits[0, -1]

    log_probs = torch.log_softmax(last_logits, 0)
    return log_probs[next_tid].item()


# Smoke check: score the first token of " test" as the continuation of
# "This is a". The result is not stored.
next_token_prob(
    tcs,
    tcs.tokenizer.encode("This is a"),
    tcs.tokenizer.encode(" test")[0],
)

def clip_tokens_by_chars(tcs, tokens, length):
    """Return how many leading tokens are needed to cover ``length`` characters.

    The bytes of each token (looked up in ``tcs.vrev``) are fed one at a time
    into a buffer; every time the buffer decodes as UTF-8 one character is
    counted and the buffer is reset. Characters may therefore span token
    boundaries.

    Args:
        tcs: Object whose ``vrev`` maps a token id to an iterable of integer
            byte values (e.g. a ``bytes`` object).
        tokens: Sequence of token ids.
        length: Number of characters the returned prefix should cover.

    Returns:
        The smallest ``k`` such that ``tokens[:k]`` contains at least
        ``length`` complete characters, or ``len(tokens)`` if the whole
        sequence is shorter than that. Returns 0 when ``length <= 0``, since
        the empty prefix already satisfies the request.

    Note:
        If the byte stream contains a sequence that never becomes valid
        UTF-8, the buffer keeps growing and no further characters are
        counted, so the function falls through to ``len(tokens)``.
    """
    if length <= 0:
        return 0

    char_count = 0
    pending = bytearray()  # bytes of a not-yet-complete character
    for i, tid in enumerate(tokens):
        for b in tcs.vrev[tid]:
            pending.append(b)
            try:
                pending.decode()
            except UnicodeDecodeError:
                # Incomplete multi-byte character: wait for more bytes.
                continue
            char_count += 1
            pending.clear()
        # Checked per token, so the token that crosses the limit is included.
        if char_count >= length:
            return i + 1
    return len(tokens)

from utils import read_jsonl_zstd
import random
# Running totals for the evaluation loop below.
# NOTE: `correct` accumulates the summed return value of `next_token_prob`
# (a log-probability), not a count of correct predictions.
correct = 0
total = 0  # not updated anywhere in this section
total_chars = 0
total_bytes = 0
scored_bytes = 0
text = "REDACTED/pretokenization/data/olmo2_shuffle/0000.jsonl.zstd"
N = 100_000
# islice caps the run at exactly N documents, matching the progress-bar total
# (checking `i == N` after processing would score N + 1 documents).
for doc in tqdm(itertools.islice(read_jsonl_zstd(text), N), total=N):
    prompt = doc['text']
    total_bytes += len(prompt.encode())
    total_chars += len(prompt)
    tokens = tcs.tokenizer.encode(prompt)
    if not tokens:
        # Nothing to score; random.randint(0, -1) would raise ValueError.
        continue
    # Restrict the sampled position to roughly the first 1000 characters and
    # keep it a valid index into `tokens` (randint is inclusive on both ends).
    max_token = min(clip_tokens_by_chars(tcs, tokens, 1000), len(tokens) - 1)
    idx = random.randint(0, max_token)
    try:
        # Compute both values before touching the totals so a failure cannot
        # leave `correct` and `scored_bytes` out of step.
        logprob = next_token_prob(tcs, tokens[:idx], tokens[idx])
        token_bytes = len(tcs.vrev[tokens[idx]])
    except UnicodeDecodeError:
        # Best-effort skip kept from the original loop.
        # NOTE(review): neither call visibly raises this — confirm it is still needed.
        continue
    correct += logprob
    scored_bytes += token_bytes