from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from queue import Queue
from threading import Thread
import pandas as pd

from sentence_transformers import LoggingHandler, SentenceTransformer,CrossEncoder
from sentence_transformers.util import semantic_search
from tqdm import tqdm
import jsonlines
import time
import json
from speculative_pub_mis import rag_drafting_generator_local
import argparse
from argparse import Namespace


# Device for the sentence-embedding model that is handed to the drafting
# helper; prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    retrieve_device = "cuda"
else:
    retrieve_device = "cpu"

# Embedding model passed through to `rag_drafting_generator_local`.
embedding_model = SentenceTransformer(
    "BAAI/bge-large-en-v1.5",
    device=retrieve_device,
)


# Causal LM (and its tokenizer) used to produce the draft answers.
draft_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
# Loaded in float16; device_map="auto" delegates weight placement to
# transformers. NOTE(review): float16 on a CPU-only host may be slow or
# unsupported for some ops -- confirm this script is only run with a GPU.
draft_model = AutoModelForCausalLM.from_pretrained(
    draft_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)




def inference(query,ctxs,answer,label,m,k,output_bath,num_max_new_tokens):
    """Draft an answer for one query and append the outcome to a JSONL file.

    Runs ``rag_drafting_generator_local`` with the module-level embedding
    model, draft model and tokenizer, then writes a single JSON line with
    the query, the selected answer, a timing figure, and the gold
    answer/label.

    Args:
        query: Instruction/claim text given to the drafter.
        ctxs: Retrieved documents passed to the drafter as ``docs``.
        answer: Gold answer, copied verbatim into the output record.
        label: Gold label, copied verbatim into the output record.
        m: Number of subsets (``num_subsets``); also the divisor of the
            reported time.
        k: Number of clusters (``num_cluster``).
        output_bath: Path of the JSONL file to append to (the name is a
            long-standing typo for "output_path" and is kept for callers).
        num_max_new_tokens: Generation budget (``num_max_new_token``).

    Raises:
        ValueError: If the drafter returns no responses (``max`` of an
            empty sequence).
    """
    drafting = rag_drafting_generator_local(
        num_cluster=k,
        num_max_new_token=num_max_new_tokens,
        num_subsets=m,
        embedding_model=embedding_model,
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        instruction=query,
        docs=ctxs,
    )
    drafts = drafting["responses"]
    chosen = drafting["best_answer"]
    selection_seconds = drafting["select_time"]
    # Each draft reports its own "time"; the slowest one is taken as the
    # drafting cost (presumably modelling drafts run in parallel -- confirm).
    slowest_draft = max(draft["time"] for draft in drafts)

    # NOTE(review): the total is divided by m; the intent of this
    # normalisation is not visible here -- confirm with the metric's owner.
    record = {
        "query": query,
        "generated_text": chosen,
        "time": (slowest_draft + selection_seconds) / m,
        "answer": answer,
        "label": label,
    }
    with jsonlines.open(output_bath, mode='a') as sink:
        sink.write(record)





if __name__ == "__main__":
    # Command-line driver: read a JSONL dataset and run `inference` on every
    # record, appending one result line per record to --output_path.
    parser = argparse.ArgumentParser(description="Run inference over dataset")
    parser.add_argument("--num_max_new_tokens", type=int, default=50, help="Maximum new tokens to generate")
    parser.add_argument("--n", type=int, default=10, help="top_n documents")
    parser.add_argument("--m", type=int, default=5, help="m draft candidates")
    parser.add_argument("--k", type=int, default=5, help="k clusters")
    parser.add_argument("--output_path", type=str, default="output.jsonl", help="output data path")
    parser.add_argument("--input_data_path", type=str, required=True, help="Path to input JSONL data file")
    args = parser.parse_args()

    # Fail fast on arguments that would otherwise break late or silently:
    # `inference` divides by m, and a non-positive n would slice the context
    # list from the end (or leave it empty).
    if args.m < 1:
        parser.error("--m must be a positive integer")
    if args.n < 1:
        parser.error("--n must be a positive integer")

    data = []
    with open(args.input_data_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. an extra empty line at the end of
            # the file); json.loads would raise on them.
            if not line.strip():
                continue
            data.append(json.loads(line))

    for record in tqdm(data, desc="Processing queries", unit="query"):
        # Each record is expected to carry "claim", a non-empty "answers"
        # list, "label", and a "ctxs" list of retrieved documents.
        claim_text = record["claim"]
        answer = record["answers"][0]
        label = record["label"]
        ctxs = record["ctxs"][:args.n]
        inference(
            claim_text,
            ctxs,
            answer,
            label,
            args.m,
            args.k,
            args.output_path,
            args.num_max_new_tokens,
        )
        


