import asyncio
from collections import Counter
from functools import partial
import httpx
import logging
import os
import re
import sys
import time  # noqa

import openai
from openai import AsyncOpenAI
from datasets import load_dataset
from transformers import AutoTokenizer  # noqa

from core.llm import LLM
from core.messages import Message, Role
from evaluation.utils import async_wrapper

# Configure the root logger at import time: INFO level, "<timestamp> - <level> - <message>" lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Number of unusable completions generate_questions_async tolerates for one prompt
# before it gives up on that prompt (compared against its cumulative failure counter).
MAX_FAILURES = 10

def extract_questions(input_string):
    """Return the text enclosed by every ``<question>...</question>`` pair.

    Matching is non-greedy and spans newlines, so a multi-line question is
    returned as a single string. The enclosed text is returned verbatim
    (no stripping).

    Args:
        input_string: Text to scan for question tags.

    Returns:
        A list of the enclosed strings in order of appearance; empty when no
        complete tag pair is present.
    """
    question_pattern = re.compile(r'<question>(.*?)<\/question>', re.DOTALL)
    return [match.group(1) for match in question_pattern.finditer(input_string)]


def generate_prompt_async(context, llm):
    """Build the prompt that asks the model for five trivia questions.

    Despite its name this function is synchronous: it only formats text and
    delegates to ``llm.messages_to_prompt``.

    Args:
        context: Paragraph of text the questions should be based on; it is
            interpolated verbatim into the instruction.
        llm: Object exposing ``messages_to_prompt``, which is given a single
            user message (presumably it applies the model's chat template --
            confirm against ``core.llm``).

    Returns:
        Whatever ``llm.messages_to_prompt`` returns for the one-message
        conversation.
    """
    # The instruction shows the exact <question>...</question> output format
    # that the response parser looks for.
    instruction = f"""Here is a paragraph of text:
{context}

Please generate challenging five trivia questions based on this paragraph. Do not make the questions multiple-choice. Do not assume that the person answering the questions has access to the paragraph. The questions must be understandable without access to the text. Do not output anything except the questions and format your output as in the followimg example:
<question>What is the capital of Japan?</question>
<question>How many months are there in a year?</question>
<question>What was the first name of Reagan?</question>
<question>How many goals did Messi score during the calendar year 2012</question>
<question>Where is the Santa Monica pier located?</question>"""

    return llm.messages_to_prompt([Message(Role.USER, instruction)])

            
async def generate_questions_async(client, model, vllm_prompt, extra_body, temperature, max_tokens, index, total, question_calls):
    """Request completions for one prompt until enough of them are usable.

    A completion is usable when it contains at least one
    ``<question>...</question>`` pair and the first extracted question is not
    the literal ``q1`` (case-insensitive, ignoring surrounding whitespace).
    Unusable completions are counted cumulatively, not consecutively; once the
    count equals ``MAX_FAILURES`` the prompt is abandoned and whatever was
    collected so far is returned.

    Args:
        client: Async client exposing ``completions.create``.
        model: Model identifier forwarded to the completion request.
        vllm_prompt: Fully rendered prompt string.
        extra_body: Extra request fields forwarded unchanged.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens per completion.
        index: Zero-based prompt index, used only in the progress message.
        total: Not used by this function.
        question_calls: Number of usable completions to collect.

    Returns:
        A flat list with every question extracted from the usable completions.
    """
    collected = []
    usable_calls = 0
    unusable_calls = 0

    # The request is identical for every attempt, so assemble it once.
    request = dict(
        model=model,
        prompt=vllm_prompt,
        stream=False,
        extra_body=extra_body,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    while usable_calls < question_calls:
        response = await client.completions.create(**request)
        text = response.choices[0].text.strip()
        found = re.findall(r'<question>(.*?)<\/question>', text, re.DOTALL)

        if found and found[0].strip().lower() != 'q1':
            usable_calls += 1
            collected.extend(found)
        else:
            unusable_calls += 1

        # Checked after every attempt, exactly like the failure budget test
        # it mirrors; it can only trigger right after an unusable completion.
        if unusable_calls == MAX_FAILURES:
            print(f"Failed to create proper questions for vllm prompt {vllm_prompt}", flush=True)
            break

    print(f"Generated question for prompt {index+1}", flush=True)
    return collected


def _resolve_model(base):
    """Return ``(opening_message, model_id)`` for a supported ``base`` name.

    Raises:
        NotImplementedError: If ``base`` is not one of the supported models.
    """
    llama3_system = "You are a knowledgeable assistant trained to provide accurate and helpful information. Please respond to the user's queries promptly."
    qwen_system = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
    supported = {
        "llama3-8b-instruct": (llama3_system, "meta-llama/Meta-Llama-3-8B-Instruct"),
        "llama3-70b-instruct": (llama3_system, "meta-llama/Meta-Llama-3-70B-Instruct"),
        "qwen2.5-7b-instruct": (qwen_system, "Qwen/Qwen2.5-7B-Instruct"),
        "qwen2.5-14b-instruct": (qwen_system, "Qwen/Qwen2.5-14B-Instruct"),
        "qwen2.5-3b-instruct": (qwen_system, "Qwen/Qwen2.5-3B-Instruct"),
    }
    if base not in supported:
        raise NotImplementedError(f"Unknown base {base}")
    system_prompt, model_id = supported[base]
    return Message(Role.SYSTEM, system_prompt), model_id


def _to_output_field(text):
    """Flatten ``text`` to one line and replace the ';' field separator with ','."""
    return text.replace('\n', ' ').replace(';', ',')


def main(
    base: str = "llama3-8b-instruct",
    dataset: str = "nyt",
    max_items: int = 1000,
    max_tokens: int = 512,
    processes: int = -1,
    prompt_type: str = "default",
    questions: int = 30,
    temperature: float = 1.5,
    test: bool = False,
    verbose: bool = False,
    vllm_hostname: str = "",
):
    """Generate trivia questions for SQuADShifts contexts and write them to a file.

    Each context of the ``test`` split is turned into a prompt, sent to a vLLM
    server through its OpenAI-compatible completions endpoint, and every
    extracted question is written as one ``question;context`` line.

    Args:
        base: Short model name; selects the system prompt and the model id.
        dataset: SQuADShifts configuration name; also used as the output
            sub-directory name.
        max_items: Maximum number of contexts to process. Zero or a negative
            value means "no limit".
        max_tokens: Maximum number of tokens per completion.
        processes: Not used by this function.
        prompt_type: Not used by this function.
        questions: Target number of questions per context; the model is asked
            for five at a time, so ``questions // 5`` usable completions are
            requested.
        temperature: Sampling temperature.
        test: Not used by this function.
        verbose: Not used by this function.
        vllm_hostname: Host name of the vLLM server (port 8000 is assumed).

    Raises:
        NotImplementedError: If ``base`` is not a supported model name.
        SystemExit: With status 1 if the output file already exists.
    """
    question_calls = questions // 5  # We generate five questions at a time
    opening_message, model_id = _resolve_model(base)

    output_directory_name = dataset
    # NOTE(review): trust_remote_code=True executes the dataset's loading script.
    records = load_dataset("squadshifts", dataset, trust_remote_code=True)["test"]

    # Note: initialize LLM but do not load the model
    llm = LLM(base, opening_message=opening_message)
    vllm_client = AsyncOpenAI(
        base_url=f"http://{vllm_hostname}:8000/v1",
        api_key="token-abc123",
    )

    # 128009 is presumably Llama 3's end-of-turn token id -- confirm against the tokenizer.
    extra_body = {"stop_token_ids": [128009]} if "llama3" in base.lower() else {}
    extra_body['top_k'] = 50
    extra_body['include_stop_str_in_output'] = True
    extra_body['skip_special_tokens'] = False

    if max_items > 0:
        output_filename = f"questions_{question_calls}_{temperature}_{max_items}.csv"
    else:
        output_filename = f"questions_{question_calls}_{temperature}.csv"
    output_dir = f"questions/{output_directory_name}/{base}"
    os.makedirs(output_dir, exist_ok=True)
    output_file = f"{output_dir}/{output_filename}"

    if os.path.exists(output_file):
        print(f"Question path {output_file} exists already, exiting")
        # sys.exit instead of the builtin exit(), which only exists when the
        # site module has been loaded.
        sys.exit(1)

    print(f"Output file for questions: {output_file}", flush=True)
    print(f"Starting to generate questions to be written into {output_file}", flush=True)

    # Only a positive max_items limits the run, matching the file-name logic
    # above; a negative value must not cut the loop off at the first item.
    limit = max_items if max_items > 0 else None
    contexts = []
    for i, item in enumerate(records):
        if limit is not None and i >= limit:
            break
        contexts.append(item['context'])

    prompts = [generate_prompt_async(context, llm) for context in contexts]

    print(f"Number of text snippets: {len(contexts)}", flush=True)
    print(f"Number of prompts: {len(prompts)}", flush=True)

    start_time = time.time()
    generated = asyncio.run(async_wrapper(vllm_client, model_id, prompts,
                                          extra_body, temperature, max_tokens,
                                          custom_fnc=generate_questions_async,
                                          custom_fnc_extra_kwargs={"question_calls": question_calls}))
    end_time = time.time()
    print(f"Generation time: {end_time - start_time:.4f} s", flush=True)

    # zip() below would silently truncate on a length mismatch.
    assert len(contexts) == len(generated), (
        f"Expected one question list per context, got {len(generated)} for {len(contexts)} contexts"
    )

    questions_written = 0
    # Explicit UTF-8 so non-ASCII questions do not depend on the platform locale.
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        for context, item_questions in zip(contexts, generated):
            context_field = _to_output_field(context)
            for question in item_questions:
                f.write(f"{_to_output_field(question)};{context_field}\n")
                questions_written += 1
    print(f"{questions_written} questions generated and written into {output_file}, exiting", flush=True)

if __name__ == "__main__":
    # Imported here so that importing this module does not require jsonargparse.
    import jsonargparse

    # Build the command-line interface from main's signature and run it.
    jsonargparse.CLI(main)

