from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from queue import Queue
from threading import Thread
import pandas as pd

from sentence_transformers import LoggingHandler, SentenceTransformer
from sentence_transformers.util import semantic_search
from tqdm import tqdm
import jsonlines
import time
import json
from speculative_online import rag_drafting_generator_local
import argparse
from argparse import Namespace

from passage_retrieval import Retriever  

# Device shared by the sentence embedder and the passage retriever.
retrieve_device = "cuda:0"

# Sentence embedder handed to rag_drafting_generator_local for draft clustering.
embedding_model = SentenceTransformer("BAAI/bge-large-en-v1.5", device=retrieve_device)

# Draft LLM: tokenizer plus an fp16 model moved onto its own device.
draft_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
draft_device = "cuda:0"
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_model_name,
    torch_dtype=torch.float16,
).to(draft_device)



def init_retriever(topn):
    """Create a ``Retriever`` configured to return ``topn`` passages per search.

    The settings are packed into an ``argparse.Namespace`` because that is the
    object ``Retriever`` is constructed from. ``setup_retriever()`` is called
    before returning, so the result is ready for ``search_document``.

    Args:
        topn: Number of documents returned for each query (stored as ``n_docs``).

    Returns:
        The initialised ``Retriever`` instance.
    """
    retriever_options = dict(
        query=None,  # Queries are supplied later, at search time.
        passages="./data/psgs_w100.tsv",  # Raw passage collection.
        passages_embeddings="./data/wikipedia_embeddings/*",
        output_dir="path/to/output",  # NOTE(review): placeholder path — confirm whether Retriever uses it.
        device=retrieve_device,
        n_docs=topn,  # Documents returned per search.
        validation_workers=32,
        per_gpu_batch_size=64,
        save_or_load_index=False,  # Do not save/load a prebuilt index.
        model_name_or_path="facebook/contriever-msmarco",
        no_fp16=False,
        question_maxlength=512,
        indexing_batch_size=1000000,
        projection_size=768,
        n_subquantizers=0,
        n_bits=8,
        lang=["en"],
        dataset="none",
        lowercase=True,
        normalize_text=True,
    )

    ready_retriever = Retriever(Namespace(**retriever_options))
    ready_retriever.setup_retriever()
    return ready_retriever

# Work queues shared by the two worker threads (rebound again in __main__).
# need_retrieve_queue:   query strings for retrieval_thread; None stops it.
need_retrieve_queue = Queue()
# retrieve_result_queue: retrieved documents consumed by inference_thread.
retrieve_result_queue = Queue()
# NOTE(review): inference_thread uses its own max_rounds rather than this
# module-level value, so changing it here does not affect generation.
max_rounds = 10

# retrieval thread
def retrieval_thread(need_retrieve_queue):
    """Answer retrieval requests until a ``None`` sentinel is received.

    Each item read from ``need_retrieve_queue`` is a query string. It is searched
    with the module-level ``retriever``, and the returned documents are pushed
    onto the module-level ``retrieve_result_queue`` in request order.

    Args:
        need_retrieve_queue: Queue of query strings, terminated by ``None``.
    """
    # iter(callable, sentinel) keeps calling get() until it yields None.
    for pending_query in iter(need_retrieve_queue.get, None):
        found_docs = retriever.search_document(pending_query, top_n=retriever.args.n_docs)
        retrieve_result_queue.put(found_docs)


# inference thread
def inference_thread(query, num_max_new_tokens, num_subsets, num_clusters, output_path, max_rounds=10):
    """Generate an answer for ``query`` round by round and append it to ``output_path``.

    Each round takes the newest documents produced by ``retrieval_thread`` (via
    ``retrieve_result_queue``) and extends ``generated_text``. It then requests a
    new retrieval for ``query + generated_text`` (via ``need_retrieve_queue``).

    - If no retrieval result has ever arrived, the documents are searched
      synchronously and the draft model is prompted directly.
    - Otherwise ``rag_drafting_generator_local`` produces candidate drafts, and
      its ``best_answer`` becomes the new ``generated_text``.

    Generation stops when the EOS token appears in the newly generated tokens,
    or after ``max_rounds`` rounds. One line of the form
    ``{"query", "generated_text", "time"}`` is then appended to ``output_path``.

    ``time`` accumulates per-round wall-clock time. In drafting rounds, the
    times of all drafts except the slowest are subtracted and ``select_time``
    is added. In other words, drafts are costed as if they ran concurrently.

    Args:
        query: Question to answer.
        num_max_new_tokens: Maximum new tokens generated per round.
        num_subsets: Forwarded to ``rag_drafting_generator_local``.
        num_clusters: Forwarded to ``rag_drafting_generator_local``.
        output_path: JSONL file the final result is appended to.
        max_rounds: Maximum number of generation rounds (default 10).
    """
    generation_config = {
        "do_sample": False,
        "max_new_tokens": num_max_new_tokens,
        "repetition_penalty": 1.1,
        "eos_token_id": draft_tokenizer.eos_token_id,
    }

    retrieve_result_last_time = ""
    generated_text = ""
    eos_token_id = draft_tokenizer.eos_token_id
    total_cost_time = 0
    round_i = 0

    while True:
        round_i += 1
        chunk_start_time = time.time()

        # Several retrieval results may have queued up since the last round.
        # The queue is FIFO, so a single get() would hand back the oldest one,
        # built from an earlier, shorter prefix. Drain it and keep the newest.
        # This thread is the only consumer, so get() after a non-empty check
        # does not block.
        got_new_result = False
        while not retrieve_result_queue.empty():
            r = retrieve_result_queue.get()
            got_new_result = True
        if got_new_result:
            retrieve_result_last_time = r
        else:
            print("Take last time retrieve")
            r = retrieve_result_last_time

        if r =="":
            # No retrieval result available yet: search synchronously and let
            # the draft model continue the answer directly.
            r = retriever.search_document(query, top_n=retriever.args.n_docs)
            retrieve_result_last_time = r
            # NOTE(review): the prompt starts with a literal "<s>" and the
            # tokenizer is called with default special-token handling, which may
            # prepend a second BOS token — confirm against the Mistral format.
            context = f"""<s>[INST]
                        ## Instruction: Write a clear, informative, and balanced answer to the following question, based on the documents below.
                        If the question is ambiguous or has multiple interpretations, explain the possible answers based on the context provided.

                        ## Documents:
                        {r}

                        ## Question:
                        {query}

                        ## Response:
                        [/INST]{generated_text}
                        """
            input_ids = draft_tokenizer(context, return_tensors="pt")["input_ids"].to(draft_device)
            with torch.no_grad():
                outputs = draft_model.generate(input_ids, **generation_config)
            # Keep only the continuation, not the echoed prompt.
            new_tokens = outputs[0][len(input_ids[0]):]
            new_text = draft_tokenizer.decode(new_tokens, skip_special_tokens=True)
            generated_text += new_text
            chunk_end_time = time.time()
            total_cost_time = total_cost_time + chunk_end_time - chunk_start_time

        else:
            results = rag_drafting_generator_local(
                num_subsets=num_subsets,
                num_clusters=num_clusters,
                num_max_new_tokens=num_max_new_tokens,
                embedding_model=embedding_model,
                draft_model=draft_model,
                draft_tokenizer=draft_tokenizer,
                instruction=query,
                docs=r,
                generated_text=generated_text
            )
            responses = results["responses"]
            best_idx = results['best_index']
            best_answer = results['best_answer']
            select_time = results['select_time']
            draft_times = [resp["time"] for resp in responses]
            # Cost drafts as if they ran concurrently: subtract every draft time
            # except the single slowest one. The previous filter
            # (`time != max_time`) also kept every draft that tied with the
            # maximum, so ties were counted more than once.
            chunk_other_time_sum = sum(draft_times) - max(draft_times)
            new_tokens = responses[best_idx]["new_tokens"]
            generated_text = best_answer
            chunk_end_time = time.time()
            total_cost_time = total_cost_time + chunk_end_time - chunk_start_time - chunk_other_time_sum + select_time

        # Stop on EOS or once `max_rounds` rounds have run. round_i counts from
        # 1, so the previous `>` check allowed one extra round.
        if eos_token_id in new_tokens or round_i >= max_rounds:
            result = {"query": query, "generated_text": generated_text, "time": total_cost_time}
            with jsonlines.open(output_path, mode='a') as writer:
                writer.write(result)
            break

        # Request documents for the answer-so-far, for use in the next round.
        need_retrieve_queue.put(query + generated_text)




if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run inference over dataset")
    parser.add_argument("--num_max_new_tokens", type=int, default=50, help="Maximum new tokens to generate")
    parser.add_argument("--n", type=int, default=10, help="top_n documents")
    parser.add_argument("--m", type=int, default=5, help="m draft candidates")
    parser.add_argument("--k", type=int, default=5, help="k clusters")
    parser.add_argument("--output_path", type=str, default="output.jsonl", help="output data path")
    # This flag was `required=True`, which made its default unreachable. It is
    # now optional with the same default. The file is parsed with json.load
    # (a JSON list), not as JSONL.
    parser.add_argument(
        "--input_data_path",
        type=str,
        default="./data/asqa_eval_gtr_top100.json",
        help="Path to input JSON data file (a list of items with a 'question' field)",
    )
    args = parser.parse_args()

    # Module-level name used by retrieval_thread / inference_thread.
    retriever = init_retriever(args.n)

    # Rebind the module-level queues the worker threads communicate through.
    retrieve_result_queue = Queue()
    need_retrieve_queue = Queue()

    with open(args.input_data_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    queries = [item["question"] for item in data if "question" in item]

    for query in tqdm(queries, desc="Processing queries", unit="query"):
        # Clear anything left over from the previous query.
        for leftover_queue in (need_retrieve_queue, retrieve_result_queue):
            while not leftover_queue.empty():
                leftover_queue.get()

        # Seed the retrieval worker with the bare question.
        need_retrieve_queue.put(query)

        retrieve_worker = Thread(target=retrieval_thread, args=(need_retrieve_queue,))
        inference_worker = Thread(
            target=inference_thread,
            args=(query, args.num_max_new_tokens, args.m, args.k, args.output_path),
        )
        retrieve_worker.start()
        inference_worker.start()

        # join() returns only after the inference thread has finished, so the
        # retrieval worker can always be told to stop here. The previous
        # `is_alive()` check was always true at this point.
        inference_worker.join()
        need_retrieve_queue.put(None)
        retrieve_worker.join()


