import json
import argparse
import os
import sys
from typing import List, Dict, Optional
from copy import deepcopy

from common import OPENAI_MODELS, read_jsonl, write_jsonl
from OpenaiGenerator import OpenaiGenerator
from VllmGenerator import VllmGenerator


class ResponseGenerator:
    """Generate answers of graded quality (scores 1-5) for a list of questions.

    For every question, one prompt per score level is built; each prompt asks
    the model to write an answer that would deserve exactly that score. The
    prompts are sent to an OpenAI-backed generator when ``model`` is listed in
    ``OPENAI_MODELS`` and to a vLLM-backed generator otherwise.

    Attributes set by ``__init__``:
        model: The model name/path given by the caller.
        model_type: ``"openai"`` or ``"vllm"``.
        generator: The backend generator instance.
        tokenizer: The vLLM generator's tokenizer (vLLM backend only).
        openai_batch_size: Batch size forwarded to the OpenAI backend.
        score_prompts: ``{score: [system_message, user_message]}`` chat
            templates; the user message contains a ``{question}`` placeholder.
    """

    # The five system prompts differ only in the score number and the rubric
    # text, so the shared pieces are kept once here and assembled by
    # ``_default_score_prompts``.
    _SYSTEM_PREFIX = "You are an AI assistant needing to generate a response that would receive a score of {score} out of 5. "
    _SYSTEM_SUFFIX = " **ONLY provide the Response to the Question. Prohibit disclosing scoring details.**"
    _USER_TEMPLATE = "Question: {question}\nAnswer:"
    # NOTE: rubric texts are runtime prompt content and are kept verbatim.
    _RUBRICS = {
        1: "It means the answer is incomplete vague, off-topic, controversial, or not exactly what the user asked for. For example, some content seems missing, numbered list does not start from the beginning, the opening sentence repeats user’s question. Or the response is from another person’s perspective with their personal experience (e.g. taken from blog posts), or looks like an answer from a forum. Or it contains promotional text, navigation text, or other irrelevant information.",
        2: "It means the answer addresses most of the asks from the user. It does not directly address the user’s question. For example, it only provides a high-level methodology instead of the exact solution to user’s question.",
        3: "It means the answer is helpful but not written by an AI Assistant. It addresses all the basic asks from the user. It is complete and self contained with the drawback that the response is not written from an AI assistant’s perspective, but from other people’s perspective. The content looks like an excerpt from a blog post, web page, or web search results. For example, it contains personal experience or opinion, mentions comments section, or share on social media, etc.",
        4: "It means the answer is written from an AI assistant’s perspective with a clear focus of addressing the instruction. It provide a complete, clear, and comprehensive response to user’s question or instruction without missing or irrelevant information. It is well organized, self-contained, and written in a helpful tone. It has minor room for improvement, e.g. more concise and focused.",
        5: "It means it is a perfect answer from an AI Assistant. It has a clear focus on being a helpful AI Assistant, where the response looks like intentionally written to address the user’s question or instruction without any irrelevant sentences. The answer provides high quality content, demonstrating expert knowledge in the area, is very well written, logical, easy-to-follow, engaging and insightful.",
    }

    def __init__(self,
                 model: str,
                 openai_api_key: Optional[str] = None,
                 openai_api_base: Optional[str] = None,
                 temperature: float = 1.0,
                 max_tokens: int = 1024,
                 top_logprobs: int = 1,
                 tensor_parallel_size: int = 1,
                 openai_batch_size: int = 4):
        """Create the backend generator and the per-score prompt templates.

        Args:
            model: Model name or path. Names found in ``OPENAI_MODELS`` select
                the OpenAI backend; anything else selects the vLLM backend.
            openai_api_key: API key; required for the OpenAI backend.
            openai_api_base: Optional API base, forwarded to the OpenAI
                generator.
            temperature: Forwarded to the backend generator.
            max_tokens: Forwarded to the backend generator.
            top_logprobs: Forwarded to the backend generator.
            tensor_parallel_size: Forwarded to the vLLM generator only.
            openai_batch_size: Passed to the OpenAI generator's
                ``batch_generate`` on every ``generate_batch`` call.

        Raises:
            ValueError: If an OpenAI model is requested without an API key.
        """
        self.model = model
        self.openai_batch_size = openai_batch_size

        if model in OPENAI_MODELS:
            self.model_type = "openai"
            if not openai_api_key:
                raise ValueError(
                    f"openai_api_key is required when using OpenAI model {model!r}."
                )
            self.generator = OpenaiGenerator(
                model=model,
                api_key=openai_api_key,
                api_base=openai_api_base,
                temperature=temperature,
                max_tokens=max_tokens,
                top_logprobs=top_logprobs
            )
        else:
            self.model_type = "vllm"
            self.generator = VllmGenerator(
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                top_logprobs=top_logprobs,
                tensor_parallel_size=tensor_parallel_size
            )
            self.tokenizer = self.generator.tokenizer

        # Fresh per-instance copy so callers may tweak prompts on one instance
        # without affecting others.
        self.score_prompts = self._default_score_prompts()

    @classmethod
    def _default_score_prompts(cls) -> Dict[int, List[Dict[str, str]]]:
        """Build the default ``{score: [system_message, user_message]}`` table."""
        return {
            score: [
                {
                    "role": "system",
                    "content": cls._SYSTEM_PREFIX.format(score=score) + rubric + cls._SYSTEM_SUFFIX,
                },
                {"role": "user", "content": cls._USER_TEMPLATE},
            ]
            for score, rubric in cls._RUBRICS.items()
        }

    def _build_prompt(self, score: int, question: str) -> str:
        """Render the prompt string for one (score, question) pair.

        Args:
            score: Key into ``self.score_prompts``.
            question: Question text substituted into the user message.

        Returns:
            For the OpenAI backend, the messages joined as ``"role: content"``
            lines. For the vLLM backend, the text obtained by applying the
            tokenizer's chat template and decoding the resulting token ids.

        Raises:
            KeyError: If ``score`` has no entry in ``self.score_prompts``.
        """
        # Copy so that filling in the question never mutates the shared template.
        messages = deepcopy(self.score_prompts[score])
        messages[1]["content"] = messages[1]["content"].format(question=question)

        if self.model_type == "openai":
            return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

        # For model paths whose final component contains "gemma", the system
        # text is folded into a single user turn instead of being sent as a
        # separate system message.
        if "gemma" in os.path.basename(os.path.normpath(self.model)):
            messages = [{
                "role": "user",
                "content": messages[0]["content"] + "\n\n" + messages[1]["content"],
            }]
        prompt_tokens = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        )
        return self.tokenizer.decode(prompt_tokens[0])

    def generate_batch(self, questions: List[Dict]) -> List[Dict]:
        """Generate one answer per score level for every question.

        Args:
            questions: Dicts with at least ``"question_id"`` and
                ``"question"`` keys.

        Returns:
            One dict per input question, in input order, with keys
            ``"question_id"``, ``"question"`` and ``"answers"``; ``"answers"``
            is a list of ``{"score": int, "text": str}`` ordered by score.

        Raises:
            ValueError: If the backend returns a different number of outputs
                than the number of prompts submitted.
        """
        if not questions:
            return []

        # Score levels come from the prompt table so the two cannot drift apart.
        scores = sorted(self.score_prompts)

        # Prompts are laid out question-major: all scores of question 0, then
        # all scores of question 1, and so on.
        prompts = [
            self._build_prompt(score, question["question"])
            for question in questions
            for score in scores
        ]

        if self.model_type == "openai":
            outputs = self.generator.batch_generate(prompts, openai_batch_size=self.openai_batch_size)
        else:
            outputs = self.generator.batch_generate(prompts)

        if len(outputs) != len(prompts):
            raise ValueError(
                f"Generator returned {len(outputs)} outputs for {len(prompts)} prompts."
            )

        results = []
        for q_idx, question in enumerate(questions):
            base = q_idx * len(scores)
            answers = [
                {"score": score, "text": outputs[base + offset]["output_text"]}
                for offset, score in enumerate(scores)
            ]
            results.append({
                "question_id": question["question_id"],
                "question": question["question"],
                "answers": answers
            })
        return results

    @staticmethod
    def load_questions(input_file: str) -> List[Dict]:
        """Read a JSONL file and keep only ``question_id`` and ``question``.

        Args:
            input_file: Path handed to ``read_jsonl``.

        Returns:
            A list of ``{"question_id": ..., "question": ...}`` dicts, one per
            record returned by ``read_jsonl``.

        Raises:
            KeyError: If a record lacks ``"question_id"`` or ``"question"``.
        """
        return [
            {"question_id": item["question_id"], "question": item["question"]}
            for item in read_jsonl(input_file)
        ]

