from syncode.parsers.grammars import Grammar
from syncode import SyncodeLogitsProcessor
import xgrammar as xgr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
import time
from pathlib import Path
from truncproof.logits import GrammarLogitsProcessor
from truncproof.generation_utils import mct_search

class FallbackSyncodeLogitsProcessor(SyncodeLogitsProcessor):
    """SyncodeLogitsProcessor that handles every batch row independently.

    Syncode's incremental parser does not cope with beam search, so this
    variant is meant for beam search. Before processing each row it resets
    the incremental parser, so no parser state carries over from one row to
    the next.
    """
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Constrain ``scores`` row by row; rows are updated in place and ``scores`` is returned."""
        row_count = len(input_ids)
        for row in range(row_count):
            # Throw away whatever the incremental parser remembered so far.
            self.grammar_engine.inc_parser.reset()
            one_row = slice(row, row + 1)
            scores[one_row] = super().__call__(input_ids[one_row], scores[one_row])
        return scores

# Grammar text used by both constrained decoders, plus the tokenizer/model
# shared by every experiment in this script.
text_grammar = Path("c_subset.lark").read_text(encoding="utf-8")  # explicit: don't depend on the locale encoding
repo = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype="auto")
# Prefer the GPU, but still run (slowly) on machines without CUDA instead of crashing.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

@torch.no_grad()
def ppl(output_ids, *, start: int = 1, lm=None) -> float:
    """Return the perplexity a causal LM assigns to ``output_ids``.

    Each target token ``output_ids[:, t]`` is scored with the log-probability
    predicted at position ``t - 1``. The result is ``exp`` of the negative mean
    of those log-probabilities over all scored positions and all rows.

    Args:
        output_ids: integer token-id tensor indexed as ``[row, position]``.
            NOTE(review): no attention mask is passed, so rows are assumed to
            be unpadded.
        start: index of the first target token to score. The default ``1``
            scores every token after the first (the historical behaviour).
            Pass the prompt length to score only the generated continuation.
        lm: causal LM to evaluate with. Defaults to the module-level ``model``.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: if ``start`` is smaller than 1 (token 0 has no prediction).
    """
    if start < 1:
        raise ValueError(f"start must be >= 1, got {start}")
    if lm is None:
        lm = model
    logits = lm(output_ids, return_dict=True).logits
    # Upcast before log_softmax. Logits may come back in half precision
    # (e.g. with torch_dtype="auto"), which loses accuracy over a large vocabulary.
    log_probs = logits[:, start - 1:-1, :].float().log_softmax(-1)
    targets = output_ids[:, start:]
    target_log_probs = log_probs.take_along_dim(targets.unsqueeze(-1), dim=-1).squeeze(-1)
    return torch.exp(-target_log_probs.mean()).item()

print("### C generation ###")
prompt_info = [{
    "content": "Write a C function that sums up 1 to N. Only output the code without codeblock quotations.\n",
    "role": "user"
}]
# Show the prompt exactly as it is tokenized below, including the generation prompt.
print(tokenizer.apply_chat_template(prompt_info, tokenize=False, add_generation_prompt=True))

input_ids = torch.LongTensor(
    tokenizer.apply_chat_template(prompt_info, tokenize=True, add_generation_prompt=True)
).to(model.device).unsqueeze(0)
input_length = input_ids.size(1)
print("prompt tokens:", input_length)


def _truncproof_processor(max_new_tokens):
    """Build a TruncProof GrammarLogitsProcessor for a run of ``max_new_tokens`` new tokens."""
    return GrammarLogitsProcessor(
        input_length + max_new_tokens,
        text_grammar,
        "WS",
        tokenizer,
        model.config.eos_token_id,
        input_length,
    )


def _syncode_processor():
    """Build a new SynCode processor instance for a single run."""
    return FallbackSyncodeLogitsProcessor(Grammar(text_grammar), tokenizer)


def _greedy(max_new_tokens, processor=None):
    """Greedy-decode from the prompt, constrained by ``processor`` when one is given."""
    extra = {} if processor is None else {"logits_processor": [processor]}
    with torch.no_grad():
        return model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            **extra,
        )


def _mcts(max_new_tokens, processor):
    """Decode from the prompt with ``mct_search`` under ``processor``."""
    with torch.no_grad():
        return mct_search(
            model,
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # mct_search receives the processor itself, not a list as generate() does.
            logits_processor=processor,
            n_trials=20,
            puct_coeff=5,
        )


def _report(output_ids):
    """Print the decoded continuation, the number of generated tokens, and the PPL."""
    generated = output_ids.tolist()[0][input_length:]
    print(tokenizer.decode(generated, skip_special_tokens=True))
    print("generated tokens:", len(generated))
    print("PPL:", ppl(output_ids))


_report(_greedy(400))

print("==== TruncProof 400 ====")
_report(_greedy(400, _truncproof_processor(400)))

print("==== SynCode 400 ====")
_report(_greedy(400, _syncode_processor()))


print("==== TruncProof 40 ====")
_report(_greedy(40, _truncproof_processor(40)))


print("==== SynCode 40 ====")
_report(_greedy(40, _syncode_processor()))


print("==== TruncProof 40 MCTS ====")
_report(_mcts(40, _truncproof_processor(40)))


print("==== SynCode 40 MCTS ====")
_report(_mcts(40, _syncode_processor()))

