from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import List
from tqdm import tqdm
import time
import pickle
from pathlib import Path
import json
from dataclasses import dataclass
from jsonschema import validate, ValidationError
import pandas as pd
from argparse import ArgumentParser

from datasets import load_dataset

from truncproof.logits import GrammarLogitsProcessor, TokenEfficiencyProcessor, GroupedBeamTermsLogitsProcessor
from truncproof.generation_utils import mct_search

# Passed as `add_generation_prompt` whenever the chat template is applied for generation.
GENPROMPT = False

# Command-line interface: the model to evaluate, whether to skip the "normal"
# (fixed budget) setting, and the coefficient for the "hard" token budget.
parser = ArgumentParser()
parser.add_argument("model", choices=["gemma", "llama"])
parser.add_argument("--hard_only", action="store_true")
parser.add_argument("--limit", type=float, default=1.1, help="coeff of token limit in hard settings")
args = parser.parse_args()

# Hugging Face repository for each supported model key.
repo_map = {
    "gemma": "google/gemma-2-2b-it",
    "llama": "meta-llama/Llama-2-7b-chat-hf",
}
# Extra keyword arguments forwarded to AutoTokenizer.from_pretrained, per model key.
tokenizer_options = {
    "gemma": {},
    "llama": dict(use_fast=False, legacy=False),
}

repo = repo_map[args.model]
tokenizer = AutoTokenizer.from_pretrained(repo, **tokenizer_options[args.model])

repo_stem = repo.split("/")[-1]
limit_coeff = args.limit
# Results are written into a pre-existing folder named after the model and the limit.
OUTPUT = Path(f"outputs_{repo_stem}_{limit_coeff:.2f}_compact")
# Explicit check instead of `assert`: assertions are removed under `python -O`,
# which would postpone the failure until the first file is written.
if not OUTPUT.exists():
    raise FileNotFoundError(
        f"{OUTPUT} does not exist. We want to append experiments to the existing folder"
    )

# Grammar describing JSON text, handed to GrammarLogitsProcessor to constrain decoding.
# The rule syntax looks like Lark's EBNF dialect (NOTE(review): confirm against truncproof).
#
# The grammar is kept strictly within RFC 8259 so that everything it accepts also
# parses with `json.loads`, which is what the evaluation uses:
#   * insignificant whitespace is only space, tab, CR and LF (no form feed);
#   * strings may not contain raw control characters U+0000..U+001F.
text_grammar = r"""
// Based on RFC 8259
?start: value

_BEGIN_ARRAY:     /[ \t\r\n]*\[[ \t\r\n]*/
_BEGIN_OBJECT:    /[ \t\r\n]*\{[ \t\r\n]*/
_END_ARRAY:       /[ \t\r\n]*\][ \t\r\n]*/
_END_OBJECT:      /[ \t\r\n]*\}[ \t\r\n]*/
_NAME_SEPARATOR:  /[ \t\r\n]*:[ \t\r\n]*/
_VALUE_SEPARATOR: /[ \t\r\n]*,[ \t\r\n]*/

?value: object
| array
| STRING
| number
| "true"             -> true
| "false"            -> false
| "null"             -> null

object: _BEGIN_OBJECT [member (_VALUE_SEPARATOR member)*] _END_OBJECT
member: STRING _NAME_SEPARATOR value
array : _BEGIN_ARRAY [value (_VALUE_SEPARATOR value)*] _END_ARRAY

number: MINUS? INT FRAC? EXP?
MINUS: "-"
INT: "0" | ("1".."9") DIGIT*
DIGIT: "0".."9"
FRAC: "." DIGIT+
EXP: ("e"|"E") ["+"|"-"] DIGIT+

STRING: /"([^"\\\x00-\x1F]|\\["\\\/bfnrt]|\\u[0-9A-Fa-f]{4})*"/
"""

# Load the checkpoint with its stored dtype and place it on the GPU.
model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype="auto").to("cuda")

# Instruction appended to every prompt to encourage short, whitespace-free JSON.
_COMPACT_INSTRUCTION = "\nOnly output JSON. Eliminate white spaces and keep the output as compact as possible.\n"


def _build_problem(record):
    """Convert one dataset record into the problem dict used by the evaluation.

    The first two prompt messages are concatenated into a single user turn
    (presumably the system and user messages -- verify against the dataset),
    followed by the compactness instruction. The ground-truth completion is
    kept together with its token count (special tokens excluded).
    """
    messages = record["prompt"]
    merged_text = messages[0]["content"] + messages[1]["content"] + _COMPACT_INSTRUCTION
    completion = record["completion"]
    return {
        "prompt": [{"content": merged_text, "role": "user"}],
        "schema": record["schema"],
        "ground_truth": completion,
        "gt_tokens": len(tokenizer.encode(completion, add_special_tokens=False)),
    }


ds = load_dataset("NousResearch/json-mode-eval", split="train")
problems = [_build_problem(record) for record in ds]

# For a quick smoke test, slice `problems` here (e.g. problems[:3]).

# Show one rendered prompt and its reference answer as a sanity check.
_sample = problems[0]
print("#### sample input ####")
print(tokenizer.apply_chat_template(_sample["prompt"], tokenize=False))
print("#### sample ground truth ####")
print(_sample["ground_truth"])
print()


@dataclass
class EvalResult:
    """Per-prediction evaluation flags, each stored as 0 (fail) or 1 (pass)."""

    # Set to 1 once the prediction parses as JSON.
    syntax: int = 0
    # Set to 1 once the parsed prediction also satisfies the problem's JSON schema.
    schema: int = 0
    # Set to 1 when the parsed prediction equals the parsed ground truth.
    exact_match: int = 0


def validate_one(problem: dict, pred: str) -> EvalResult:
    """Score a single prediction against its problem.

    Args:
        problem: Problem dict; its "schema" and "ground_truth" entries are JSON text.
        pred: Raw text produced by the model.

    Returns:
        EvalResult() if `pred` is not valid JSON, EvalResult(1) if it parses but
        is not an object or fails schema validation, EvalResult(1, 1, 0) if it
        satisfies the schema but differs from the ground truth, and
        EvalResult(1, 1, 1) on an exact match of the parsed values.
    """
    expected_schema = json.loads(problem["schema"])

    try:
        candidate = json.loads(pred)
    except json.JSONDecodeError:
        # Not parseable: every flag stays at 0.
        return EvalResult()

    # Only JSON objects are considered; other top-level values count as schema failures.
    if not isinstance(candidate, dict):
        return EvalResult(1)

    try:
        validate(instance=candidate, schema=expected_schema)
    except ValidationError:
        return EvalResult(1)

    # Compare parsed values so formatting differences do not matter.
    is_match = candidate == json.loads(problem["ground_truth"])
    return EvalResult(1, 1, int(is_match))

@torch.no_grad()
def ppl(output_ids) -> float:
    """Return the perplexity the global `model` assigns to `output_ids`.

    The average is taken over every token after the first one in the tensor,
    so any prompt tokens contained in `output_ids` are included.
    """
    # The logits at position t score token t+1: drop the last prediction
    # and align the rest with the tokens shifted by one.
    next_token_logits = model(output_ids, return_dict=True).logits[:, :-1, :]
    log_probs = next_token_logits.log_softmax(-1)
    targets = output_ids[:, 1:].unsqueeze(-1)
    token_log_probs = torch.take_along_dim(log_probs, targets, dim=-1).squeeze(-1)
    mean_nll = -token_log_probs.mean()
    return torch.exp(mean_nll).item()

def predict(kwargs_getter):
    """Run greedy generation with a fixed 200-token budget on every problem.

    Args:
        kwargs_getter: Callable taking (input_ids, max_new_tokens) and returning
            extra keyword arguments for `model.generate`.

    Returns:
        A 4-tuple of per-problem lists: decoded completions, number of generated
        tokens, perplexity of the full sequence, and wall-clock time in seconds
        (covering generation and decoding).
    """
    token_budget = 200
    completions, predicted_tokens, perplexity, generation_time = [], [], [], []

    for problem in tqdm(problems):
        chat_ids = tokenizer.apply_chat_template(
            problem["prompt"], tokenize=True, add_generation_prompt=GENPROMPT
        )
        # Shape (batchsize=1, sequence).
        input_ids = torch.LongTensor(chat_ids).to(model.device).unsqueeze(0)
        prompt_len = input_ids.size(1)

        started = time.perf_counter()
        with torch.no_grad():
            extra_kwargs = kwargs_getter(input_ids, token_budget)
            output_ids = model.generate(
                input_ids,
                max_new_tokens=token_budget,
                do_sample=False,
                **extra_kwargs,
            )
        # Decode only the newly generated part.
        text = tokenizer.decode(output_ids[0, prompt_len:].tolist(), skip_special_tokens=True)
        elapsed = time.perf_counter() - started

        completions.append(text)
        predicted_tokens.append(output_ids.shape[1] - prompt_len)
        perplexity.append(ppl(output_ids))
        generation_time.append(elapsed)

    return completions, predicted_tokens, perplexity, generation_time

def predict_hard(kwargs_getter):
    """Generate completions under a tight, per-problem token budget.

    Each problem's budget is its ground-truth token count scaled by the
    `--limit` coefficient (`limit_coeff`), truncated to an integer. Decoding
    is delegated to `mct_search` (presumably a Monte Carlo tree search --
    see truncproof.generation_utils).

    Args:
        kwargs_getter: Callable taking (input_ids, max_new_tokens) and returning
            extra keyword arguments for `mct_search`.

    Returns:
        A 5-tuple of per-problem lists: decoded completions, number of generated
        tokens, the token budget used, perplexity of the full sequence, and
        wall-clock time in seconds (covering search and decoding).
    """
    completions = []
    predicted_tokens = []
    max_new_tokens_array = []
    perplexity = []
    generation_time = []
    for problem in tqdm(problems):
        input_ids = torch.LongTensor(
            tokenizer.apply_chat_template(problem["prompt"], tokenize=True, add_generation_prompt=GENPROMPT)
        ).to(model.device).unsqueeze(0)  # (batchsize, sequence)
        input_length = input_ids.size(1)
        gt_length = problem["gt_tokens"]
        # Use the CLI coefficient rather than a literal: the output folder name
        # encodes `limit_coeff`, so the budget must actually follow it.
        max_new_tokens = int(gt_length * limit_coeff)
        tic = time.perf_counter()
        with torch.no_grad():
            output_ids = mct_search(
                model,
                input_ids,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                tokenizer=tokenizer,
                **kwargs_getter(input_ids, max_new_tokens),
            )

        # Decode only the newly generated part.
        output = tokenizer.decode(output_ids.tolist()[0][input_length:], skip_special_tokens=True)
        tac = time.perf_counter()
        completions.append(output)
        predicted_tokens.append(output_ids.size(1) - input_length)
        max_new_tokens_array.append(max_new_tokens)
        perplexity.append(ppl(output_ids))
        generation_time.append(tac - tic)
    return completions, predicted_tokens, max_new_tokens_array, perplexity, generation_time


def get_proposed_kwargs_mcts(input_ids, max_new_tokens: int):
    """Build the keyword arguments for `mct_search` for one prompt.

    A fresh GrammarLogitsProcessor is created per call from the JSON grammar,
    the prompt length and the token budget.

    Args:
        input_ids: Prompt tensor; only its length along dimension 1 is read.
        max_new_tokens: Token budget for this problem.

    Returns:
        Dict with the logits processor, the search settings (`n_trials`,
        `puct_coeff`) and `max_new_tokens`.
    """
    prompt_len = input_ids.size(1)
    # First argument is prompt length plus budget -- presumably the maximum
    # total sequence length (TODO confirm against GrammarLogitsProcessor).
    grammar_processor = GrammarLogitsProcessor(
        prompt_len + max_new_tokens,
        text_grammar,
        "",
        tokenizer,
        model.config.eos_token_id,
        prompt_len,
        debug=False,
    )
    return dict(
        logits_processor=grammar_processor,
        n_trials=20,
        puct_coeff=5,
        max_new_tokens=max_new_tokens,
    )


# Method name -> [kwargs getter for the normal setting, kwargs getter for the hard setting].
# A None entry means that setting is skipped for the method.
methods = {
    "proposed_mcts": [None, get_proposed_kwargs_mcts],
}


def _ground_truth_perplexities():
    """Return the perplexity of prompt + ground-truth completion for every problem."""
    values = []
    for problem in tqdm(problems):
        chat_ids = tokenizer.apply_chat_template(
            problem["prompt"], tokenize=True, add_generation_prompt=GENPROMPT
        )
        # Shape (batchsize=1, sequence).
        prompt_ids = torch.LongTensor(chat_ids).to(model.device).unsqueeze(0)
        reference_ids = tokenizer.encode(
            problem["ground_truth"], add_special_tokens=False, return_tensors="pt"
        ).to(model.device)
        values.append(ppl(torch.cat([prompt_ids, reference_ids], dim=1)))
    return values


# Reference perplexities, stored next to the per-method results.
ppls = _ground_truth_perplexities()
pd.DataFrame({"ppl": ppls}).to_csv(OUTPUT / "ground_truth_ppl.csv")

def _write_report(key, setting, completions, predicted_tokens, max_new_tokens, perplexity, gentime):
    """Persist one finished run as `{key}_{setting}.pickle` and `{key}_{setting}.csv`.

    The pickle holds the raw completions. The CSV holds one row per problem
    with the validation flags, token counts, prompt, schema, ground truth,
    prediction, perplexity and generation time.
    """
    with (OUTPUT / f"{key}_{setting}.pickle").open("wb") as f:
        pickle.dump(completions, f)

    df = pd.DataFrame([validate_one(problem, pred) for problem, pred in zip(problems, completions)])
    df["max_new_tokens"] = max_new_tokens
    df["gt_tokens"] = [problem["gt_tokens"] for problem in problems]
    df["predicted_tokens"] = predicted_tokens
    df["prompt"] = [tokenizer.apply_chat_template(problem["prompt"], tokenize=False) for problem in problems]
    df["schema_text"] = [problem["schema"] for problem in problems]
    df["ground_truth"] = [problem["ground_truth"] for problem in problems]
    df["prediction"] = completions
    df["perplexity"] = perplexity
    df["generation_time"] = gentime
    df.to_csv(OUTPUT / f"{key}_{setting}.csv")


# Run every configured method. Output files are opened only after prediction
# has finished, so a crash mid-run cannot truncate results from an earlier run.
for key, kwargs_getter in methods.items():
    print(key)
    normal_getter = kwargs_getter[0]
    hard_getter = kwargs_getter[1]

    if (not args.hard_only) and normal_getter is not None:
        completions, predicted_tokens, perplexity, gentime = predict(normal_getter)
        # The normal setting always uses the fixed 200-token budget of `predict`.
        _write_report(
            key, "normal", completions, predicted_tokens,
            [200 for _ in problems], perplexity, gentime,
        )

    if hard_getter is not None:
        completions, predicted_tokens, max_new_tokens, perplexity, gentime = predict_hard(hard_getter)
        _write_report(
            key, "hard", completions, predicted_tokens,
            max_new_tokens, perplexity, gentime,
        )

