import os
import torch
from time import time
from time import perf_counter
from safetensors import safe_open
from argparse import ArgumentParser
from fast_compression import decode
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import LogitsProcessor

from configs import Config
from paper_prompts import PROMPTS
from inference.torch.model import LZWModel
from inference.tokenizer import LZWTokenizer
from utils import get_device, unflatten, dataclass_from_dict, adapt_model, setup_seed

# Command-line interface: which model variant to benchmark and with which prompt.
parser = ArgumentParser()
# Benchmark the unmodified base model instead of the LZW-compressed one.
parser.add_argument(
    "--original",
    default=False,
    action="store_true",
)
# Key into PROMPTS selecting the benchmark prompt.
parser.add_argument(
    "--prompt-length",
    required=True,
    type=str,
    choices=PROMPTS.keys(),
)
args = parser.parse_args()

# Runtime setup: seeding (via the project helper), target device, and
# allowing reduced-precision (e.g. TF32) float32 matmuls for speed.
setup_seed()
device = get_device()
torch.set_float32_matmul_precision("high")

# Relative path of the adapter checkpoint inside the Hub repository.
ADAPTER_FILENAME = "QQsH/model_9500.safetensors"

# Download only the adapter checkpoint (not the whole repo) and resolve its local path.
local_folder_path = snapshot_download(
    repo_id="XXX/online-zip2zip-2",
    allow_patterns=[ADAPTER_FILENAME],
)
adapter_path = os.path.join(local_folder_path, ADAPTER_FILENAME)

# Read every tensor plus the flat string metadata (which carries the training config).
metadata = {}
adapter_state_dict = {}
with safe_open(adapter_path, framework="pt", device=device) as f:
    for k in f.keys():
        # Strip the "model._orig_mod." wrapper prefix (presumably left by saving a
        # torch.compile'd module) so keys match the uncompiled model's names.
        adapter_state_dict[k.replace("model._orig_mod.", "")] = f.get_tensor(k)

    for k, v in f.metadata().items():
        metadata[k] = v

dict_config = unflatten(metadata)["config"]
# Drop fields presumably not accepted by the current Config dataclass.
# pop(..., None) keeps loading working for checkpoints that never stored them,
# where the previous `del` raised KeyError.
for obsolete_key in ("early_stopping_patience", "epochs"):
    dict_config.pop(obsolete_key, None)
dict_config.get("lora", {}).pop("use_rslora", None)
config = dataclass_from_dict(Config, dict_config)

# Base (uncompressed) tokenizer. It is used directly for --original runs, and its
# added-vocab ids are passed to the LZW tokenizer as `disabled_ids`.
original_tokenizer = AutoTokenizer.from_pretrained(
    config.pretrained_tokenizer_name_or_path,
)

if args.original:
    # Baseline: the unmodified pretrained model, cast to the configured dtype.
    model = AutoModelForCausalLM.from_pretrained(
        config.pretrained_model_name_or_path,
        torch_dtype=config.dtype,
    ).to(device)

    model.eval()

    tokenizer = original_tokenizer
else:
    # Compressed path: LZW tokenizer whose codebook size is bounded by
    # config.extra_vocab_size. Added/special token ids are passed as disabled_ids
    # (presumably so they are never merged into codes — confirm in LZWTokenizer).
    tokenizer = LZWTokenizer.from_pretrained(
        config.pretrained_tokenizer_name_or_path,
        max_codebook_size=config.extra_vocab_size,
        max_subtokens=config.compression.max_subtokens,
        disabled_ids=list(original_tokenizer.get_added_vocab().values()),
    )

    model = LZWModel.from_config(config, device, tokenizer.pad_token_id)

    # Adapt the model structure before loading weights, so the checkpoint keys
    # line up. NOTE(review): merge=True presumably folds the adapter into the base
    # weights — confirm in utils.adapt_model.
    adapt_model(model, config, merge=True)

    model._load_state_dict(adapter_state_dict)

    # Cast and switch to inference mode via LZWModel's own helper methods.
    model._to(config.dtype)
    model._eval()
    # NOTE(review): compilation is disabled; re-enabling it changes the timings.
    # model.compile()


class TimeLogitsProcessor(LogitsProcessor):
    """No-op logits processor that records a timestamp at every decoding step.

    ``generate`` invokes its logits processors once per decoding step, right after
    the forward pass produced the next-token scores. Consecutive timestamps
    therefore bracket prefill and per-token decode. Scores are returned unchanged.

    Accelerator kernels run asynchronously. Without a device sync, the timestamp
    would mark when the forward pass was *queued* rather than finished, which
    skews the prefill/decode split. We therefore synchronize the scores' device
    before reading the clock. ``perf_counter`` is used because it is monotonic
    and high-resolution, unlike ``time.time``.
    """

    def __init__(self):
        super().__init__()
        self.timestamps = []

    def __call__(
        self, _: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        self.add_timestamp(scores.device)
        return scores

    def add_timestamp(self, device=None):
        """Append the current ``perf_counter`` reading to ``self.timestamps``.

        Args:
            device: optional ``torch.device`` (or device string). If it is a
                CUDA or MPS device, wait for its queued work to finish first, so
                the timestamp reflects completed computation. ``None`` (the
                default) records the time immediately.
        """
        self._synchronize(device)
        self.timestamps.append(perf_counter())

    @staticmethod
    def _synchronize(device) -> None:
        """Block until queued kernels on ``device`` (CUDA or MPS) have finished."""
        if device is None:
            return
        device = torch.device(device)
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        elif device.type == "mps":
            torch.mps.synchronize()


messages = [
    {
        "role": "user",
        "content": PROMPTS[args.prompt_length],
    }
]

# Render the chat template to text first, so that tokenization can be timed on its own.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# perf_counter: monotonic, high-resolution clock suited to short intervals.
start_tokenize = perf_counter()
inputs = tokenizer(prompt, return_tensors="pt").to(device)
tokenize_time = perf_counter() - start_tokenize

input_length = inputs["input_ids"].shape[1]

# Timestamp layout:
#   [0]  just before generate()
#   [1]  first decoding step (prefill finished)
#   [-1] after generate() returned
time_logits_processor = TimeLogitsProcessor()
time_logits_processor.add_timestamp()

outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.6,
    max_new_tokens=256,
    cache_implementation="dynamic",
    logits_processor=[time_logits_processor],
)

time_logits_processor.add_timestamp()

# Keep only the newly generated tokens (drop the echoed prompt).
generated_outputs = outputs[:, input_length:]
generated_outputs_length = generated_outputs.shape[1]
print(tokenizer.batch_decode(generated_outputs, skip_special_tokens=False)[0])

timestamps = time_logits_processor.timestamps
prefill_time = timestamps[1] - timestamps[0]
# NOTE(review): the first generated token comes from the prefill forward pass,
# so this window contains generated_outputs_length - 1 decode forwards.
generation_time = timestamps[-1] - timestamps[1]

prompt_tps = input_length / prefill_time
generation_tps = generated_outputs_length / generation_time

print("=" * 10)
# With --original nothing is compressed, so label the stats accordingly.
print("Original:" if args.original else "Compressed:")
print(f"Prompt: {input_length} tokens, {prompt_tps:.3f} tokens-per-sec")
print(
    f"Generation: {generated_outputs_length} tokens, {generation_tps:.3f} tokens-per-sec"
)

# Throughput in base-tokenizer (uncompressed) units. With --original the
# token counts are already in those units.
original_input_length = input_length

if not args.original:
    # Decode the LZW codes back to base-tokenizer ids to count the original tokens.
    original_inputs = decode(
        inputs["input_ids"][0].tolist(),
        config.initial_vocab_size,
        config.extra_vocab_size,
        config.compression.max_subtokens,
    )
    original_input_length = len(original_inputs)

    original_generated_outputs = decode(
        generated_outputs[0].tolist(),
        config.initial_vocab_size,
        config.extra_vocab_size,
        config.compression.max_subtokens,
    )

    original_generated_outputs_length = len(original_generated_outputs)
    print("=" * 10)
    print("Uncompressed:")
    # Same wall-clock windows as the compressed stats, with uncompressed token
    # counts. Equal to count * (tps / compressed_count), without dividing by a
    # token count.
    print(
        f"Prompt: {original_input_length} tokens, {original_input_length / prefill_time:.3f} tokens-per-sec"
    )
    print(
        f"Generation: {original_generated_outputs_length} tokens, {original_generated_outputs_length / generation_time:.3f} tokens-per-sec"
    )

total_time = prefill_time + generation_time
# A very fast tokenizer call can measure as 0 s; avoid ZeroDivisionError.
tokenize_tps = (
    original_input_length / tokenize_time if tokenize_time > 0 else float("inf")
)

print("=" * 10)
# PT/GT/Total are durations in seconds; TT is tokenization throughput.
print(
    f"PT: {prefill_time:.3f}s, GT: {generation_time:.3f}s, Total: {total_time:.3f}s, TT: {tokenize_tps:.3f} tokens-per-sec"
)
