import asyncio
import csv
import logging
import os
import re
import time
import warnings

from datasets import load_dataset
from openai import OpenAI, AsyncOpenAI
from rank_bm25 import BM25Okapi
from tokenizers import Tokenizer
from tokenizers.models import WordPiece

from core.llm import LLM
from core.messages import Message, Role, merge_messages
from evaluation.utils import async_wrapper, get_bm25_context, get_prompt_context, get_gt_answer


# Captures the payload of a <question>...</question> tag. DOTALL lets the
# payload span several lines; without it any rewritten question containing a
# line break would be rejected and replaced by the original question.
_QUESTION_TAG_RE = re.compile(r"<question>(.*?)</question>", re.DOTALL)


def _build_rewrite_prompt(context: str, question: str, gt_answer: str) -> str:
    """Return the user prompt asking the model to make *question* self-contained."""
    return f"""Here is a piece of text:
{context}

Here is a question related to the text:
{question}

Here is a list of valid ground-truth answers:
{gt_answer}

Please re-write the question such that it can be fully understood and it makes sense without access to the text. Output the new question inside <question> xml tags, like this:

<question>Rewritten question</question>"""


def _extract_rewritten_question(answer: object, fallback: str) -> str:
    """Return the single rewritten question found in *answer*, else *fallback*.

    The answer is accepted only when it contains exactly one non-blank
    ``<question>`` tag. Whitespace runs (including line breaks) are collapsed
    to single spaces so every CSV row stays on one line without gluing
    adjacent words together.
    """
    # NOTE(review): async_wrapper's failure behaviour is not visible here, so a
    # non-string answer is treated as "no match" instead of crashing the run.
    candidates = _QUESTION_TAG_RE.findall(answer) if isinstance(answer, str) else []
    if len(candidates) == 1:
        rewritten = " ".join(candidates[0].split())
        if rewritten:
            return rewritten
    print(f"Answer {answer} invalid. Resorting to original")
    return " ".join(fallback.split())


def main(
    base: str = "llama3-70b-instruct",
    vllm_hostname: str = "",
    temperature: float = 0.1,
    dataset: str = "nyt",
    n_questions: int = 1000,
):
    """Rewrite dataset questions so they make sense without their context.

    Loads the ``test`` split of the ``squadshifts`` configuration named by
    ``dataset``, asks the model served at ``http://{vllm_hostname}:8000/v1`` to
    rewrite each question, and writes one question per row to
    ``./datasets/{dataset}_filtered.csv`` (``;``-delimited). When the model's
    answer does not contain exactly one usable ``<question>`` tag, the original
    question is written instead.

    Args:
        base: Name of the base model; only ``"llama3-70b-instruct"`` is handled.
        vllm_hostname: Host name of the vLLM server (port 8000 is assumed).
        temperature: Sampling temperature forwarded to the generation call.
        dataset: ``squadshifts`` configuration name; also used in the output
            file name.
        n_questions: Maximum number of questions to process; a value ``<= 0``
            processes the whole split.

    Raises:
        NotImplementedError: If ``base`` is not a supported model name.
        SystemExit: With code 0 if the output file already exists.
    """
    if base == "llama3-70b-instruct":
        opening_message = Message(
            Role.SYSTEM,
            "You are a knowledgeable assistant trained to provide accurate and helpful information. Please respond to the user's queries promptly."
        )
        model_id = "meta-llama/Meta-Llama-3-70B-Instruct"
    else:
        raise NotImplementedError(f"Unknown base {base}")

    # Check for existing output before the expensive dataset load and model
    # wrapper setup, so a no-op run returns immediately.
    output_file = f"./datasets/{dataset}_filtered.csv"
    if os.path.exists(output_file):
        print(f"File {output_file} exists already. Not re-creating it", flush=True)
        # Same effect as the site-provided exit(0), but also works when the
        # interpreter runs without the site module.
        raise SystemExit(0)

    records = load_dataset("squadshifts", dataset, trust_remote_code=True)["test"]

    # Note: initialize LLM but do not load the model
    llm = LLM(base, opening_message=opening_message)
    vllm_client = AsyncOpenAI(
        base_url=f"http://{vllm_hostname}:8000/v1",
        api_key="token-abc123",
    )

    # 128009 is presumably Llama 3's end-of-turn token id — TODO confirm.
    extra_body = {"stop_token_ids": [128009]} if "llama" in base else {}
    extra_body["top_k"] = 50
    extra_body["include_stop_str_in_output"] = True
    extra_body["skip_special_tokens"] = False

    prompts = []
    questions = []
    for i, item in enumerate(records):
        if n_questions > 0 and i >= n_questions:
            break
        question = item["question"]
        questions.append(question)

        user_prompt = _build_rewrite_prompt(
            context=get_prompt_context(item),
            question=question,
            gt_answer=str(get_gt_answer(item)),
        )
        messages = merge_messages([Message(Role.USER, user_prompt)])
        prompts.append(llm.messages_to_prompt(messages))

    # Create the output directory up front so a missing ./datasets folder fails
    # (or is fixed) before the generation time has been spent.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    print(f"Number of prompts: {len(prompts)}", flush=True)
    start_time = time.time()
    answers = asyncio.run(
        async_wrapper(vllm_client, model_id, prompts, extra_body, temperature, max_tokens=500)
    )
    end_time = time.time()
    print(f"Generation time: {end_time - start_time:.4f} s", flush=True)
    assert len(prompts) == len(answers), "expected exactly one answer per prompt"

    print(f"Writing answers to {output_file}", flush=True)
    # Extract every row before touching the file: an error during extraction
    # must not leave behind a truncated CSV that later runs would treat as done.
    rows = [
        [_extract_rewritten_question(answer, original)]
        for answer, original in zip(answers, questions)
    ]
    with open(output_file, "w", newline="", encoding="utf-8") as file:
        csv.writer(file, delimiter=";").writerows(rows)
    print("Writing finished", flush=True)

if __name__ == "__main__":
    # Script entry point: jsonargparse exposes main()'s parameters as
    # command-line options. Imported here so that importing this module does
    # not require jsonargparse.
    import jsonargparse

    jsonargparse.CLI(main)
