import logging

import tqdm
from datasets import load_dataset
from more_itertools import batched
from concurrent.futures import ThreadPoolExecutor, as_completed
import math

from vllm import LLM, SamplingParams
from openai import OpenAI
import json

if __name__ == "__main__":
    # Generate one short search query per MS MARCO passage by calling a chat
    # model behind an OpenAI-compatible endpoint (presumably a local vLLM
    # server, given the localhost:8000/v1 base URL), and write one JSON object
    # {"id": <doc_id>, "query": <generated text>} per line to OUTPUT_FILE.

    logging.basicConfig(level=logging.CRITICAL)
    # Silence per-request log lines from the HTTP client used by `openai`.
    logging.getLogger("httpx").setLevel(logging.CRITICAL)

    MODEL = "google/gemma-2-2b-it"
    LOG_EVERY = 20  # print a sample DOC/QUERY pair every LOG_EVERY batches
    BATCH_SIZE = 32  # documents per batch == concurrent in-flight requests
    OUTPUT_FILE = "doc_id_to_query_short_gemma.jsonl"
    openai_api_key = "EMPTY"
    openai_api_base = "http://localhost:8000/v1"
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    SYSTEM_PROMPT = "You are an expert search engine query generator. Your task is to create a concise and effective query that could be used to retrieve a specific document provided by the user. The query should consist of key terms or phrases that are highly relevant to the content of the document. Your response must include only the generated query and nothing else. Document: "

    def generate_query(document):
        """Return the model's generated query text for a single *document*.

        The instructions are prepended to the document in the *user* turn
        rather than sent as a ``system`` message -- presumably because
        Gemma's chat template does not accept a system role; confirm before
        changing. Only ``max_tokens`` is set per request, so sampling
        (temperature, top_p, ...) uses the server's defaults. May return
        ``None`` if the response carries no message content.
        """
        messages = [
            {"role": "user", "content": SYSTEM_PROMPT + document},
        ]
        response = client.chat.completions.create(
            model=MODEL,
            messages=messages,
            max_tokens=30,  # queries are meant to be short
        )
        return response.choices[0].message.content

    train_dataset = load_dataset("irds/msmarco-passage", "docs")
    # NOTE(review): without `split=`, `load_dataset` usually returns a
    # DatasetDict; the loop below assumes that iterating `train_dataset`
    # yields records with "doc_id" and "text" keys -- confirm for this loader.
    num_batches = math.ceil(len(train_dataset) / BATCH_SIZE)

    # One executor for the whole run (instead of one per batch); the file is
    # closed -- and its buffer flushed -- even if a request raises.
    with open(OUTPUT_FILE, "w", encoding="utf-8") as output_file, \
            ThreadPoolExecutor(max_workers=BATCH_SIZE) as executor:
        with tqdm.tqdm(total=num_batches, desc="Training") as bar:
            for batch_count, batch_samples in enumerate(
                batched(train_dataset, BATCH_SIZE)
            ):
                documents = [doc["text"] for doc in batch_samples]
                doc_ids = [doc["doc_id"] for doc in batch_samples]

                # `executor.map` runs requests concurrently but yields results
                # in input order, so queries[i] belongs to doc_ids[i]. The
                # first failed request re-raises here and aborts the run.
                queries = list(executor.map(generate_query, documents))

                # JSON Lines: exactly one object per line.
                output_file.writelines(
                    json.dumps({"id": doc_id, "query": query}) + "\n"
                    for doc_id, query in zip(doc_ids, queries)
                )

                if batch_count % LOG_EVERY == 0:
                    bar.write(f"DOC: {documents[0]}")
                    bar.write(f"QUERY: {queries[0]}")

                bar.update(1)
