from opt_classic_hybrid.model_hybrid import SPModel
from torch.cuda import empty_cache
import torch
import os
import json
import re
from accelerate import Accelerator 

def generate_summary_prompt(text: str) -> str:
    """Build a summarization prompt from *text*, cut at its last full sentence.

    A sentence counts as complete when it ends in ``.``, ``!`` or ``?``
    followed by whitespace or by the end of the text. Anything after the
    last complete sentence (a trailing fragment) is dropped. If no sentence
    terminator is found, the whole stripped text is used.

    Args:
        text: Raw source text, possibly cut off mid-sentence.

    Returns:
        The prompt string: an instruction line, the trimmed text under
        ``### Text:``, and an empty ``### Summary:`` section.
    """
    # `\Z` lets a terminator at the very end of the text count as a sentence
    # boundary. Matching only `\s+` would discard a complete final sentence
    # whenever the text has no trailing whitespace.
    sentence_endings = list(re.finditer(r'(?<=[.!?])(?:\s+|\Z)', text))

    if sentence_endings:
        last_full_sentence_end = sentence_endings[-1].end()
        truncated_text = text[:last_full_sentence_end].strip()
    else:
        # No terminator anywhere: keep everything rather than return nothing.
        truncated_text = text.strip()

    # Build the final prompt with clear markers for the LLM
    prompt = f"\n\nSummarize the following text into a summary of less than 800 words.\n### Text:\n\n{truncated_text}\n\n### Summary:\n\n"
    return prompt

############### Load input data ##############
########### BookSum ###########
# Every regular file in the data folder is read as JSON Lines (one JSON
# object per line); only the first `n` lines of each file are considered.
data_folder = 'data/BookSum'
# data_folder = 'data/GovReport'
subdir_path = os.path.join(data_folder)

# Maps the 'trunc_length' of a file's last parsed record to the list of
# records parsed from that file.
file_to_data = {}
n = 20
# Sorted so that, if two files share a 'trunc_length', which one ends up in
# file_to_data does not depend on os.listdir's arbitrary ordering.
for fname in sorted(os.listdir(subdir_path)):
    full_path = os.path.join(subdir_path, fname)
    # Skip sub-directories and other non-regular entries; open() would raise
    # on them and abort the whole run.
    if not os.path.isfile(full_path):
        continue

    records = []
    with open(full_path, "r", encoding="utf-8", errors="replace") as f:
        for i, line in enumerate(f):
            if i >= n:  # Stop reading after n lines (blank lines count too)
                break
            line = line.strip()
            # Remove potential BOM characters
            line = line.lstrip('\ufeff')
            if not line:
                continue  # Skip blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                # Best effort: report the bad line (0-based index) and go on.
                print(f"JSON decode error in file {fname} line {i}: {e}")
                continue
            records.append(record)
    if records:
        length = records[-1]['trunc_length']
        file_to_data[length] = records
############### Load Models #################
# Define the model lists
# Candidate target (base) checkpoints.
target_models = [
    "lmsys/vicuna-7b-v1.5-16k",
    "lmsys/vicuna-13b-v1.5-16k"
]

# Candidate draft checkpoints to pair with a target model.
approx_models = [
    "double7/vicuna-68m",
    "JackFram/llama-68m"
]

# Select model paths from the lists
base_model_path = target_models[1]   # "lmsys/vicuna-13b-v1.5-16k"
draft_model_path = approx_models[1]  # "JackFram/llama-68m"

# Build the combined model from the base and draft checkpoints, requesting
# float16 weights and automatic device placement.
model = SPModel.from_pretrained(
    base_model_path=base_model_path,
    draft_model_path=draft_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto"
)

model.eval()
# The tokenizer is taken from the model object rather than loaded separately.
tokenizer = model.tokenizer

accelerator = Accelerator()
# NOTE(review): the model was already placed via device_map="auto", and a
# tokenizer is presumably passed through prepare() unchanged — confirm this
# call is needed and behaves as intended with SPModel.
model, tokenizer = accelerator.prepare(model, tokenizer)
##################### Single Inference #####################
# Pick one sample: the fifth-from-last record stored under key 1024
# (file_to_data is keyed by the 'trunc_length' read from each file).
input_text = file_to_data[1024 * 1][-5]['text']
# input_text = generate_summary_prompt(input_text)
encoded_prompt = tokenizer.encode(
    input_text,
    return_tensors='pt',
    add_special_tokens=True,
)
input_ids = encoded_prompt.to(accelerator.device)

# input_ids = input_ids[:, :500]

# Draft-tree size settings forwarded to spgenerate.
max_depth = 10
total_nodes = 50

# Sampling / length settings forwarded to spgenerate.
temperature = 0
max_new_tokens = 256

output_ids, results = model.spgenerate(
    input_ids,
    temperature=temperature,
    max_new_tokens=max_new_tokens,
    nodes=total_nodes,
    threshold=0.7,
    max_depth=max_depth,
    # Output / logging switches.
    output_result_line=True,
    print_input=False,
    verbose=True,
    # Cache options (streamingLLM-style vs. retrieval-based).
    use_streamingLLM_cache=False,
    sink_size=16,
    recent_size=512,
    use_retrieval_cache=True,
    retrieval_verbose=False,
    retrieval_chunk_size=32,
    retrieve_top_k=32,
    # Debugging and timing switches.
    print_draft_tree=False,
    test_generate_cache=False,
    show_time=True,
    measure_time=True,
    retrieve_every_n_steps=4,
    # Attention implementation toggles for the target and draft models.
    target_use_flash_prefill=True,
    target_use_hybrid_tree_attn=False,
    draft_use_flash_prefill=True,
)
