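"""Prefect-managed script that prepares data for the conformal backoff algorithm.

Reports dataset statistics, extracts the questions whose annotations carry no
yes/no answer, and samples candidate answers for each question with vLLM.
"""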

import click
import os
from vllm import LLM, SamplingParams
from functools import partial
import io
import ujson as json
import numpy as np
from tqdm import tqdm
from typing import Text, List, Dict, Any
from prefect import task, flow
from prefect.tasks import task_input_hash
from transformers import AutoTokenizer
import time


# Value handed to vLLM as ``tensor_parallel_size`` when the model is loaded.
# Read once at import time: importing this module raises KeyError if the
# TENSOR_PARALLEL_SIZE environment variable is unset, and ValueError if it is
# not an integer literal.
_TENSOR_PARALLEL_SIZE_ = int(os.environ["TENSOR_PARALLEL_SIZE"])


@task(cache_key_fn=task_input_hash)
def get_stats(input_file_path: Text) -> Dict[Text, Any]:
    """Compute summary statistics for a JSON-lines dataset.

    Args:
        input_file_path: Path to a JSON-lines file. Every record must provide
            an ``annotations`` list (each entry carrying a ``yes_no_answer``
            value) and a ``question_tokens`` sequence.

    Returns:
        A dictionary with the following entries:

        * ``num_instances``: number of records in the file.
        * ``keys``: keys of the first record (empty list if there are no
          records).
        * ``avg_question_tokens``: mean length of ``question_tokens`` over all
          records (``0.0`` if there are no records).
        * ``yes_no_answer_dist``: how often each ``yes_no_answer`` value
          occurs across all annotations of all records.
    """
    # Imported locally so this block does not depend on the file-level imports.
    from collections import Counter

    # JSON text is UTF-8; do not rely on the platform's default encoding.
    with open(input_file_path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) are skipped instead of
        # crashing the JSON parser.
        data = [json.loads(line) for line in tqdm(f) if line.strip()]

    # Distribution of the yes/no labels over every annotation.
    counter = Counter(
        annotation['yes_no_answer']
        for d in data
        for annotation in d['annotations']
    )

    if data:
        keys = list(data[0].keys())
        avg_question_tokens = np.mean(
            [len(d['question_tokens']) for d in data]
        ).item()
    else:
        # An empty file has no first record and no meaningful mean; report
        # neutral values rather than raising IndexError / producing NaN.
        keys = []
        avg_question_tokens = 0.0

    return {
        "num_instances": len(data),
        "keys": keys,
        "avg_question_tokens": avg_question_tokens,
        # Plain dict so the result serializes like the original structure.
        "yes_no_answer_dist": dict(counter)
    }


@task(cache_key_fn=task_input_hash)
def extract_questions(
    input_file_path: Text,
    output_file_path: Text
) -> List[Dict[Text, Any]]:
    """Extract the questions whose annotations all have no yes/no answer.

    A record is kept when every entry of its ``annotations`` list has
    ``yes_no_answer == "NONE"`` (a record with no annotations is kept too).

    Args:
        input_file_path: Path to the JSON-lines dataset. Records must provide
            ``annotations``, ``question_text`` and ``example_id``.
        output_file_path: Path the extracted questions are written to, one
            JSON object per line.

    Returns:
        The extracted questions, in file order, as dictionaries with the keys
        ``question`` and ``example_id``.
    """
    questions = []

    # Stream the file record by record: only two small fields of each record
    # are kept, so there is no need to hold the whole dataset in memory.
    # JSON text is UTF-8; do not rely on the platform's default encoding.
    with open(input_file_path, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            if not line.strip():
                # Tolerate blank lines such as a trailing newline.
                continue
            item = json.loads(line)
            if all(a['yes_no_answer'] == "NONE" for a in item['annotations']):
                questions.append({
                    "question": item['question_text'],
                    "example_id": item['example_id'],
                })

    with open(output_file_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in questions)

    return questions


@task(
    cache_key_fn=task_input_hash,
    version="0.1.2"
)
def sample_answers(
    questions: List[Dict[Text, Any]],
    k: int,
    temperature: float,
    top_p: float,
    output_data_path: Text
) -> List[Dict[Text, Any]]:
    """Sample ``k`` answers per question with vLLM and write them to disk.

    Args:
        questions: Question records; each must have a ``question`` entry. All
            other entries are copied unchanged into the output records.
        k: Number of answers to sample for every question (must be >= 1).
        temperature: Sampling temperature passed to vLLM.
        top_p: Nucleus-sampling threshold passed to vLLM.
        output_data_path: Path of the JSON-lines file to write.

    Returns:
        One record per question, in input order: the original question record
        extended with ``answers`` (the ``k`` sampled texts) and ``prompt``
        (the chat-formatted prompt that was sent to the model). The same
        records are written to ``output_data_path``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    # Fail fast, before the model is loaded, instead of hitting an opaque
    # ``range()`` error (k == 0) or silently writing nothing (k < 0) later.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    if not questions:
        # Nothing to generate: still produce the (empty) output file, but do
        # not pay for loading the model.
        with open(output_data_path, 'w', encoding="utf-8"):
            pass
        return []

    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
    llm = LLM(model=model_name, tensor_parallel_size=_TENSOR_PARALLEL_SIZE_)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Stop generation at either the generic EOS token or the end-of-turn token.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, stop_token_ids=terminators, max_tokens=100)

    # convert questions to chat messages
    processed_question_messages = [
        [
            {"role": "user", "content": "Answer the search-query with a minimum phrase with no additional information: " + item['question']}
        ]  for item in questions
    ]

    # render each conversation into a single prompt string
    processed_question_prompts = tokenizer.apply_chat_template(
        processed_question_messages,
        add_generation_prompt=True,
        return_tensors=False,
        tokenize=False
    )

    # repeat the prompts k times, keeping the copies of a prompt adjacent
    repeated_question_prompts = [q for q in processed_question_prompts for _ in range(k)]

    outputs = llm.generate(prompts=repeated_question_prompts, sampling_params=sampling_params)
    texts = [output.outputs[0].text for output in outputs]

    # regroup the flat list of generations into chunks of k per question
    grouped_text = [texts[i:i+k] for i in range(0, len(texts), k)]

    records = [
        {**question, "answers": answers, "prompt": prompt}
        for question, prompt, answers in zip(questions, processed_question_prompts, grouped_text)
    ]

    with open(output_data_path, 'w', encoding="utf-8") as file_:
        for record in records:
            file_.write(json.dumps(record) + '\n')

    # The signature promises the records; previously the task returned None.
    return records


@click.command()
@click.option("--input-data-path", type=str, help="Path to the input data.", required=True)
@click.option("--question-data-path", type=str, help="Path to write the extracted questions (JSONL).", required=False, default='data/NQ/questions.jsonl')
@click.option("--answer-data-path", type=str, help="Path to write the sampled answers (JSONL).", required=False, default='data/NQ/answers.jsonl')
@click.option("--temperature", type=float, help="Temperature for sampling.", required=False, default=0.7)
@click.option("--top-p", type=float, help="Top-p for sampling.", required=False, default=0.95)
@click.option("--num-samples", type=click.IntRange(min=1), help="Number of answers to sample per question.", required=False, default=100)
@flow
def main(
    input_data_path,
    question_data_path,
    answer_data_path,
    temperature,
    top_p,
    num_samples=100
):
    """Print dataset statistics, extract questions and sample answers.

    Args:
        input_data_path: Path to the input JSON-lines dataset.
        question_data_path: Path the extracted questions are written to.
        answer_data_path: Path the sampled answers are written to.
        temperature: Sampling temperature.
        top_p: Top-p (nucleus) sampling threshold.
        num_samples: Number of answers sampled per question (default 100).
    """
    # Show the dataset statistics as indented JSON.
    print(json.dumps(get_stats(input_data_path), indent=4))

    # extract questions with no yes/no answers
    questions = extract_questions(input_data_path, question_data_path)

    # sample answers for every extracted question; results go to answer_data_path
    sample_answers(questions, num_samples, temperature, top_p, answer_data_path)
    
    
# Invoke the click command only when the file is executed as a script;
# click parses the command-line options and passes them to ``main``.
if __name__ == "__main__":
    main()