import json
from pathlib import Path

import torch
torch.set_grad_enabled(False)
import torch.nn.functional as F
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams, TokensPrompt, TextPrompt
from vllm.sampling_params import BeamSearchParams
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from typing import Union, List
import itertools

import dp_tokenization as dpt
from utils import bytes_to_unicode
from text_conditioning import *
from tqdm.auto import tqdm
from statsmodels.stats.proportion import proportion_confint
from utils import read_jsonl_gz
import random
# Byte <-> character lookup tables. `B` is whatever `utils.bytes_to_unicode`
# returns (presumably the GPT-2 style byte -> printable-character table --
# confirm in utils); `Brev` is the same mapping with keys and values swapped.
B = bytes_to_unicode()
Brev = {char: byte for byte, char in B.items()}

# Model identifier shared by the vLLM engine, the HF tokenizer and the sampler.
model_dir = Path("Qwen/Qwen3-1.7B-Base")

# Module-level singletons used by the code below. The sampler is built with
# skip_model=True (presumably because vLLM does the generation here -- confirm
# against TextConditionedSampler).
llm = LLM(
    model=str(model_dir),
    enforce_eager=True,
    max_logprobs=1000,
)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
tcs = TextConditionedSampler(model_dir, skip_model=True)

class ByteConditionedLogitProcessor:
    """Logits processor that constrains generation to spell out ``byte_seq``.

    At every decoding step the already generated token ids are mapped back to
    bytes and only those tokens are left unmasked that keep the output
    consistent with the target byte sequence. Once the output has run past
    the target and the overshoot is valid UTF-8 (or is more than 4 bytes
    long), only the EOS token is allowed so generation stops.

    Attributes read from ``tcs``:
        vrev: mapping from token id to the ``bytes`` of that token.
        vtrie: nested mapping keyed by byte value (int); a ``None`` key in a
            node holds the id of the token that ends exactly at that node.
        eos: id of the end-of-sequence token.
        _get_walk_cached(bytes): returns a NumPy integer array of token ids
            (presumably the tokens whose bytes start with the given prefix --
            confirm against TextConditionedSampler).
    """

    def __init__(self, tcs, byte_seq):
        """Store the sampler and the target ``bytes`` to be reproduced."""
        self.tcs = tcs
        self.byte_seq = byte_seq

    def __call__(self, past, logits):
        """Return ``logits`` with every disallowed token set to ``-inf``.

        Args:
            past: token ids generated so far for this sequence.
            logits: 1-D tensor of next-token logits.

        Returns:
            A tensor shaped like ``logits`` in which only allowed tokens
            keep their original (finite) value.
        """
        sampler = self.tcs
        target = self.byte_seq
        mask = torch.full_like(logits, -torch.inf)

        emitted = b''.join(sampler.vrev[p] for p in past)
        index = len(emitted)

        if index > len(target):
            # The output already covers the target. Stop as soon as the
            # overshoot forms complete UTF-8; otherwise keep generating so a
            # partially emitted multi-byte character can be finished.
            try:
                emitted[len(target):].decode()
            except UnicodeDecodeError:
                pass
            else:
                mask[sampler.eos] = 0
                return logits + mask

        # 4 bytes is the longest unicode character, so an overshoot longer
        # than that will never become decodable by appending more bytes.
        if index > len(target) + 4:
            mask[sampler.eos] = 0
            return logits + mask

        # Walk the remaining target bytes through the trie, allowing every
        # token that ends exactly on the path.
        node = sampler.vtrie
        for pos in range(index, len(target)):
            node = node.get(target[pos])
            if not node:
                break
            if None in node:
                mask[node[None]] = 0
        else:
            # The whole remainder is a path in the trie (or is empty): also
            # allow the tokens returned for that remainder, which may extend
            # past the end of the target.
            allowed = torch.from_numpy(sampler._get_walk_cached(target[index:]))
            mask[allowed] = 0

        return logits + mask


def predict(tcs: TextConditionedSampler, prompt: str, backtrack=1):
    """Predict the character following ``prompt`` with byte-constrained decoding.

    The prompt is tokenized, its last ``backtrack`` tokens are dropped and
    turned back into bytes, and the model is asked to greedily regenerate
    those bytes (and then one more character) from the remaining prefix under
    a ``ByteConditionedLogitProcessor``.

    Args:
        tcs: sampler providing ``bos``/``eos``/``pad`` ids and the ``vrev``
            token-id -> bytes mapping.
        prompt: text whose next character should be predicted.
        backtrack: number of trailing prompt tokens to re-derive. ``0`` keeps
            the whole prompt as prefix; values larger than the token count
            re-derive the whole prompt.

    Returns:
        ``(char, suffix_cost)`` where ``char`` is the first character
        generated after the backtracked bytes (``' '`` when nothing decodable
        was produced) and ``suffix_cost`` is the number of generated tokens
        consumed until the backtracked bytes were exceeded or a special token
        was hit.
    """
    tokens = tokenizer.encode(prompt, add_special_tokens=False)

    # Use an explicit split index: `tokens[:-backtrack]` would give an empty
    # prefix for backtrack=0 because -0 == 0.
    split = max(len(tokens) - backtrack, 0)
    prefix = tokens[:split]
    if not prefix or prefix[0] != tcs.bos:
        prefix = [tcs.bos, *prefix]
    suffix = b''.join(tcs.vrev[tid] for tid in tokens[split:])

    out = llm.generate(
        [TokensPrompt(prompt_token_ids=prefix)],
        SamplingParams(
            logits_processors=[ByteConditionedLogitProcessor(tcs, suffix)],
            logprobs=0,
            temperature=0,
            max_tokens=100,
        ),
        use_tqdm=False,
    )
    generated = out[0].outputs[0].token_ids

    # Count the tokens needed to get past the backtracked bytes.
    special_ids = (tcs.eos, tcs.bos, tcs.pad)
    suffix_cost = 0
    covered = 0
    for tid in generated:
        suffix_cost += 1
        if tid in special_ids:
            break
        covered += len(tcs.vrev[tid])
        if covered > len(suffix):
            break

    # Everything generated after the backtracked bytes is the continuation.
    continuation = b''.join(tcs.vrev.get(tid, b'') for tid in generated)[len(suffix):]

    # Drop trailing bytes until the remainder is valid UTF-8 (the last
    # character may have been cut off mid-way).
    candidate = continuation
    while candidate:
        try:
            return candidate.decode()[0], suffix_cost
        except UnicodeDecodeError:
            candidate = candidate[:-1]

    print(f"no valid suffix of {continuation!r}")
    return ' ', suffix_cost

# Evaluation counters: number of exact next-character hits, summed
# `suffix_cost` returned by predict(), and number of documents scored.
correct = 0
cost = 0
total = 0
text = "REDACTED/madlad_shuffle/zh/zh_clean_0000.jsonl.gz"
N = 100_000

# Empty documents have no character to predict (and would make randint raise
# ValueError), so they are skipped. islice caps the run at exactly N scored
# documents, matching the progress-bar total.
docs = (doc for doc in read_jsonl_gz(text) if doc['text'])
for doc in tqdm(itertools.islice(docs, N), total=N):
    prompt = doc['text']
    # Pick a cut position within the first 501 characters; the model sees
    # prompt[:idx] and has to predict prompt[idx].
    idx = random.randint(0, min(len(prompt) - 1, 500))
    pred, dcost = predict(tcs, prompt[:idx], backtrack=4)
    correct += pred == prompt[idx]
    cost += dcost
    total += 1