from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from queue import Queue
from threading import Thread
import pandas as pd

from sentence_transformers import LoggingHandler, SentenceTransformer,CrossEncoder
from sentence_transformers.util import semantic_search
from tqdm import tqdm
import jsonlines
import time
import json
from speculative_asqa_alpaca import rag_drafting_generator_local
import argparse
from argparse import Namespace


# Device for the retrieval embedding model: GPU when CUDA is available, otherwise CPU.
if torch.cuda.is_available():
    retrieve_device = "cuda"
else:
    retrieve_device = "cpu"

# Dense embedding model that `inference` passes to the drafter as `embedding_model`.
_EMBEDDING_MODEL_NAME = "BAAI/bge-large-en-v1.5"
embedding_model = SentenceTransformer(_EMBEDDING_MODEL_NAME, device=retrieve_device)

# Draft language model and its tokenizer, loaded in fp16 with automatic device placement.
# NOTE(review): "ALPACA_MODEL_PATH" looks like a placeholder that must be replaced with a
# real local path or hub id before running — confirm.
draft_model_name = "ALPACA_MODEL_PATH"
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_model_name,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Token id that `inference` looks for in a draft's new tokens to decide generation is done.
eos_token_id = draft_tokenizer.eos_token_id


def inference(query, ctxs, num_max_new_tokens, m, k, output_bath, max_rounds=10):
    """Answer one query by iterative speculative RAG drafting and append the result to a JSONL file.

    Each round calls ``rag_drafting_generator_local`` with the text generated so far and
    adopts the selected best answer as the new running text. Drafting stops when the best
    draft's new tokens contain the draft tokenizer's EOS token, or after ``max_rounds``
    rounds, whichever comes first. At least one round is always run.

    Args:
        query: Question / instruction passed to the drafter as ``instruction``.
        ctxs: Retrieved documents passed to the drafter as ``docs``.
        num_max_new_tokens: Maximum new tokens per drafting round.
        m: Number of draft subsets (``num_subsets``); the reported time is divided by it.
        k: Number of clusters (``num_cluster``).
        output_bath: Path of the JSONL file the result is appended to (the name is kept
            for backward compatibility; it is an output *path*).
        max_rounds: Maximum number of drafting rounds (default 10).

    Returns:
        The result dict that was written: ``{"query", "generated_text", "time"}``.
    """
    generated_text = ""
    start_time = time.perf_counter()  # monotonic clock for interval measurement

    # Bounded loop: exactly `max_rounds` rounds at most (the old `round_i > max_rounds`
    # check ran one extra round). `max(1, ...)` keeps the original "at least one round".
    for _ in range(max(1, max_rounds)):
        results = rag_drafting_generator_local(
            num_cluster=k,
            num_max_new_token=num_max_new_tokens,
            num_subsets=m,
            embedding_model=embedding_model,
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            instruction=query,
            docs=ctxs,
            generated_text=generated_text,
        )
        best_idx = results["best_index"]
        new_tokens = results["responses"][best_idx]["new_tokens"]
        generated_text = results["best_answer"]

        # Some tokenizers have no EOS token (eos_token_id is None); `None in tensor`
        # would raise, so only test membership when an id exists.
        if eos_token_id is not None and eos_token_id in new_tokens:
            break

    elapsed = time.perf_counter() - start_time
    # NOTE(review): elapsed time is divided by m (number of draft subsets), as in the
    # original code — presumably a per-draft normalization; confirm this is intended.
    result = {"query": query, "generated_text": generated_text, "time": elapsed / m}
    with jsonlines.open(output_bath, mode="a") as writer:
        writer.write(result)
    return result






if __name__ == "__main__":
    # CLI entry point: run speculative RAG inference for every record in an input JSON file.
    parser = argparse.ArgumentParser(description="Run inference over dataset")
    parser.add_argument("--num_max_new_tokens", type=int, default=50, help="Maximum new tokens to generate")
    parser.add_argument("--m", type=int, default=5, help="m draft candidates")
    parser.add_argument("--k", type=int, default=5, help="k clusters")
    parser.add_argument("--output_path", type=str, default="output.jsonl", help="output data path")
    # The file is parsed with json.load below, so it must be a single JSON document
    # (a list of records with "question" and "docs" keys), not line-delimited JSONL.
    parser.add_argument("--input_data_path", type=str, required=True, help="Path to input JSON data file (a list of records)")
    # Optional: the default of 10 applies when --n is omitted (required=True made it unreachable).
    parser.add_argument("--n", type=int, default=10, help="top n documents to retrieve")
    args = parser.parse_args()

    # `inference` divides elapsed time by m, and a negative n would silently slice
    # documents from the end of the list, so reject both up front.
    if args.m < 1:
        parser.error("--m must be >= 1")
    if args.n < 0:
        parser.error("--n must be >= 0")

    with open(args.input_data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for record in tqdm(data, desc="Processing queries", unit="query"):
        question = record["question"]
        ctxs = record["docs"][: args.n]  # keep only the top-n documents
        inference(question, ctxs, args.num_max_new_tokens, args.m, args.k, args.output_path)
        


