from opt_eagle_hybrid.ea_model_hybrid import EaModel
import torch
import os
import json
import re

def generate_summary_prompt(text: str) -> str:
    """Build a summarization prompt from ``text``, cut at its last full sentence.

    A sentence counts as complete when a ``.``, ``!`` or ``?`` is followed by
    whitespace or by the end of the input. Anything after the last complete
    sentence (a dangling fragment) is dropped. When no sentence terminator is
    found, the whole text is used. Surrounding whitespace is stripped in
    either case.

    Args:
        text: Source document to be summarized.

    Returns:
        The prompt: an instruction line, the truncated text under a
        ``### Text:`` marker, and a trailing ``### Summary:`` marker.
    """
    # Track the end offset of the last sentence terminator. The ``\Z``
    # alternative keeps a final sentence that ends exactly at the end of the
    # input; requiring trailing whitespace would discard it.
    last_full_sentence_end = None
    for terminator in re.finditer(r'[.!?](?=\s|\Z)', text):
        last_full_sentence_end = terminator.end()

    if last_full_sentence_end is None:
        truncated_text = text.strip()
    else:
        truncated_text = text[:last_full_sentence_end].strip()

    # Build the final prompt with clear markers for the LLM
    prompt = f"\n\nSummarize the following text into a concise summary of less than 500 words.\n### Text:\n\n{truncated_text}\n\n### Summary:\n\n"
    return prompt

############### Load input data ##############
########### GovReport ###########
# Every regular file in the data folder is read line by line, each non-blank
# line being parsed as one JSON object.
data_folder = 'data/GovReport'
subdir_path = data_folder

# Maps the 'trunc_length' of a file's last loaded record to the list of
# records loaded from that file. Files sharing a key overwrite each other.
file_to_data = {}
n = 20  # Only the first n lines of each file are considered
for fname in os.listdir(subdir_path):
    full_path = os.path.join(subdir_path, fname)
    # os.listdir also returns sub-directories; opening one would raise and
    # abort the whole load, so anything that is not a regular file is skipped.
    if not os.path.isfile(full_path):
        continue

    records = []
    with open(full_path, "r", encoding="utf-8", errors="replace") as f:
        for i, line in enumerate(f):
            if i >= n:  # Stop reading after n lines
                break
            line = line.strip()
            # Remove potential BOM characters
            line = line.lstrip('\ufeff')
            if not line:
                continue  # Skip blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                # A malformed line is reported and skipped; the rest of the
                # file is still loaded.
                print(f"JSON decode error in file {fname} line {i}: {e}")
                continue
            records.append(record)
    if records:
        # Files that produced no records are left out of the mapping.
        length = records[-1]['trunc_length']
        file_to_data[length] = records


############### Load Models #################
# Candidate target (base) checkpoints.
target_models = [
    "lmsys/vicuna-7b-v1.5-16k",
    "lmsys/longchat-7b-16k",
]

# Candidate draft checkpoints.
# NOTE(review): entries appear to pair with target_models by index - confirm.
approx_models = [
    "eagle_ckpt/EAGLE_vicuna_7B_16k",
    "eagle_ckpt/EAGLE_longchat_7B_16k",
]

# Use the first entry of each list (the vicuna pair).
base_model_path, draft_model_path = target_models[0], approx_models[0]

# Load the base model together with its draft checkpoint in half precision.
model = EaModel.from_pretrained(
    base_model_path=base_model_path,
    ea_model_path=draft_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
model.eval()

# The tokenizer is taken from the loaded model rather than loaded separately.
tokenizer = model.tokenizer

##################### Single Inference #####################
# Use the first record stored under the trunc_length key 16384 (1024*16).
input_text = file_to_data[1024*16][0]['text']
# Tokenize and move the ids to the device the base model lives on.
input_ids = tokenizer.encode(input_text, return_tensors='pt', add_special_tokens=True).to(model.base_model.device)

max_depth = 7
total_nodes = 50

temperature = 0
max_new_tokens = 256

# The settings above are passed by name so that editing them actually changes
# the run; previously the call repeated the literals and ignored them.
output_ids, results = model.eagenerate(
    input_ids,
    temperature=temperature,
    max_new_tokens=max_new_tokens,
    nodes=total_nodes,
    threshold=0.7,
    max_depth=max_depth,
    output_result_line=True,
    print_input=False,
    verbose=True,
    use_streamingLLM_cache=False,
    sink_size=16,
    recent_size=1000,
    use_retrieval_cache=False,
    retrieval_verbose=False,
    retrieval_chunk_size=32,
    retrieve_top_k=64,
    retrieve_every_n_steps=4,
    print_draft_tree=False,
    test_generate_cache=False,
    show_time=True,
    measure_time=True,
    target_use_flash_prefill=False,
    target_use_hybrid_tree_attn=True,
    draft_use_flash_prefill=False,
)

# Uncomment to print the decoded generation:
# output = model.tokenizer.decode(output_ids[0])
# print(output)