import re
import json
import argparse
import os
from pathlib import Path
from typing import List, Dict, Union, Optional
from OpenaiGenerator import OpenaiGenerator
from VllmGenerator import VllmGenerator
from common import OPENAI_MODELS, read_jsonl, write_jsonl


class ResponseFilter:
    """Grade candidate answers with an LLM judge and keep the ones that match.

    A generator backend is chosen from the model name: models listed in
    ``OPENAI_MODELS`` use ``OpenaiGenerator``; everything else uses
    ``VllmGenerator``.  ``judge`` asks that backend to grade each pending
    candidate on a 1-5 scale and accepts the candidate only when the grade
    equals the answer's expected ``"score"``.
    """

    # Grading rubric sent as the system message.  The wording is part of the
    # program's behavior (it steers the judge model) and must not be edited
    # casually; it is split into adjacent literals purely for readability.
    _GRADE_SYSTEM_PROMPT = (
        "Below is a question from an user and a candidate response. Please grade the response on a 5-point scale using the following criteria: \n\n"
        "1: It means the answer is incomplete, vague, off-topic, controversial, or not exactly what the user asked for. For example, some content seems missing, numbered list does not start from the beginning, the opening sentence repeats user’s question. Or the response is from another person’s perspective with their personal experience (e.g. taken from blog posts), or looks like an answer from a forum. Or it contains promotional text, navigation text, or other irrelevant information. \n"
        "2: It means the answer addresses most of the asks from the user. It does not directly address the user’s question. For example, it only provides a high-level methodology instead of the exact solution to user’s question. \n"
        "3: It means the answer is helpful but not written by an AI Assistant. It addresses all the basic asks from the user. It is complete and self contained with the drawback that the response is not written from an AI assistant’s perspective, but from other people’s perspective. The content looks like an excerpt from a blog post, web page, or web search results. For example, it contains personal experience or opinion, mentions comments section, or share on social media, etc. \n"
        "4: It means the answer is written from an AI assistant’s perspective with a clear focus of addressing the instruction. It provide a complete, clear, and comprehensive response to user’s question or instruction without missing or irrelevant information. It is well organized, self-contained, and written in a helpful tone. It has minor room for improvement, e.g. more concise and focused. \n"
        "5: It means it is a perfect answer from an AI Assistant. It has a clear focus on being a helpful AI Assistant, where the response looks like intentionally written to address the user’s question or instruction without any irrelevant sentences. The answer provides high quality content, demonstrating expert knowledge in the area, is very well written, logical, easy-to-follow, engaging and insightful. \n\n"
        "After providing your explanation, you must provide a score between 1 and 5, enclosing the numerical value exclusively in **square brackets** (e.g., Score: [5])."
    )

    # "Score: [4]" (optionally bold, any case) -- the form the rubric requests.
    _LABELLED_SCORE_RE = re.compile(r'score\s*:?\s*\**\s*\[(\d)\]', re.IGNORECASE)
    # Any single digit in square brackets; used only as a fallback.
    _BRACKETED_DIGIT_RE = re.compile(r'\[(\d)\]')

    def __init__(self,
                 model: str,
                 openai_api_key: Optional[str] = None,
                 openai_api_base: Optional[str] = None,
                 temperature: float = 1.0,
                 max_tokens: int = 1024,
                 top_logprobs: int = 1,
                 tensor_parallel_size: int = 1):
        """Create the generator backend used for grading.

        Args:
            model: Model name; membership in ``OPENAI_MODELS`` selects the
                OpenAI backend, otherwise the vLLM backend is used.
            openai_api_key: Required (non-empty) for the OpenAI backend;
                not used by the vLLM backend.
            openai_api_base: Forwarded to ``OpenaiGenerator`` as ``api_base``.
            temperature: Forwarded to the generator.
            max_tokens: Forwarded to the generator.
            top_logprobs: Forwarded to the generator.
            tensor_parallel_size: Forwarded to ``VllmGenerator`` only.

        Raises:
            ValueError: If an OpenAI model is selected without an API key.
        """
        self.model_type = "openai" if model in OPENAI_MODELS else "vllm"

        if self.model_type == "openai":
            if not openai_api_key:
                raise ValueError(
                    f"openai_api_key is required when using OpenAI model {model!r}"
                )
            self.generator = OpenaiGenerator(
                model=model,
                api_key=openai_api_key,
                api_base=openai_api_base,
                temperature=temperature,
                max_tokens=max_tokens,
                top_logprobs=top_logprobs
            )
        else:
            self.generator = VllmGenerator(
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                top_logprobs=top_logprobs,
                tensor_parallel_size=tensor_parallel_size
            )

    def _build_grade_prompt(self, question: str, response: str) -> List[Dict[str, str]]:
        """Return the chat messages (system rubric + user turn) for one grading request."""
        return [
            {
                "role": "system",
                "content": self._GRADE_SYSTEM_PROMPT
            },
            {
                "role": "user",
                "content": f"Question: {question}\nResponse: {response}"
            }
        ]

    def _render_prompt(self, question: str, candidate: str) -> str:
        """Turn the grading messages into the single prompt string a backend consumes."""
        messages = self._build_grade_prompt(question, candidate)
        if self.model_type == "openai":
            # The OpenAI path flattens the chat into "role: content" lines.
            return "\n".join(f"{m['role']}: {m['content']}" for m in messages)
        return self.generator.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

    def _extract_rating(self, text: str) -> Optional[int]:
        """Parse the judge's numeric grade out of its free-text reply.

        The last explicitly labelled ``Score: [n]`` wins, because the rubric
        asks for the score after the explanation and the explanation itself
        may contain bracketed digits (e.g. a ``[1]`` citation).  When no
        labelled score exists, the first bracketed single digit is used.

        Returns:
            The grade as an int, or ``None`` if no bracketed digit is found.
        """
        labelled = self._LABELLED_SCORE_RE.findall(text)
        if labelled:
            return int(labelled[-1])
        match = self._BRACKETED_DIGIT_RE.search(text)
        return int(match.group(1)) if match else None

    @staticmethod
    def _as_pair(value):
        """Coerce a text/judgement field to a list with at least two slots.

        A list that already has exactly two items is returned unchanged.
        Shorter lists are padded with ``""``; a non-list value becomes
        ``[value, ""]``.  Lists longer than two are copied as-is (not
        truncated), matching the historical behavior.
        """
        if isinstance(value, list):
            if len(value) == 2:
                return value
            return value + [""] * (2 - len(value))
        return [value, ""]

    def judge(self, data: List[Dict], openai_batch_size: int = 4) -> List[Dict]:
        """Grade pending candidates and store the accepted ones in place.

        For every answer flagged ``needs_judge`` that carries a non-blank
        ``candidate``, the first blank entry among the two ``text`` slots is
        reserved and the candidate is graded.  If the parsed grade equals the
        answer's ``"score"``, the candidate and the judge's full reply are
        written to that slot of ``text`` / ``judgement``; otherwise both are
        set to ``""``.  Finally ``candidate`` and ``needs_judge`` are removed
        from every answer.

        Args:
            data: Items with a ``"question"`` and a list of ``"answers"``
                dicts.  Mutated in place.
            openai_batch_size: Batch size passed to the OpenAI generator's
                ``batch_generate``; unused for the vLLM backend.

        Returns:
            The same ``data`` list that was passed in.
        """
        all_prompts = []
        prompt_mapping = []  # (data_idx, ans_idx, slot) for each prompt, same order

        for data_idx, item in enumerate(data):
            question = item["question"]
            for ans_idx, answer in enumerate(item["answers"]):
                if not answer.get("needs_judge", False):
                    continue

                # A missing key is treated like an empty value rather than an error.
                for key in ("text", "judgement"):
                    answer[key] = self._as_pair(answer.get(key, ""))

                # ``or ""`` also covers an explicit null candidate / slot.
                candidate = (answer.get("candidate") or "").strip()
                if not candidate:
                    continue

                slot = next(
                    (s for s in range(2) if not (answer["text"][s] or "").strip()),
                    None,
                )
                if slot is None:
                    # Both slots already hold text; nothing to fill.
                    continue

                all_prompts.append(self._render_prompt(question, candidate))
                prompt_mapping.append((data_idx, ans_idx, slot))

        if not all_prompts:
            batch_results = []
        elif self.model_type == "openai":
            batch_results = self.generator.batch_generate(all_prompts, openai_batch_size)
        else:
            batch_results = self.generator.batch_generate(all_prompts)

        for i, (data_idx, ans_idx, slot) in enumerate(prompt_mapping):
            output_text = batch_results[i]["output_text"]
            rating = self._extract_rating(output_text)
            print(output_text, "\n\n", rating)

            answer = data[data_idx]["answers"][ans_idx]
            expected = answer["score"]
            if rating is not None and rating == expected:
                answer["text"][slot] = answer.get("candidate", "")
                answer["judgement"][slot] = output_text  # store the judge's full reply
            else:
                answer["text"][slot] = ""
                answer["judgement"][slot] = ""

        # Strip the bookkeeping fields from every answer, judged or not.
        for item in data:
            for ans in item["answers"]:
                ans.pop("candidate", None)
                ans.pop("needs_judge", None)
        return data


def check_completement(data):
    """Print how many answer text slots are still empty.

    Walks every entry of each answer's ``text`` sequence, counts the entries
    equal to ``""``, and prints ``"<empty> / <total> = <percent>%"``.  Items
    without ``answers`` and answers without ``text`` contribute nothing.
    With no slots at all the ratio is reported as 0.

    Args:
        data: Iterable of dicts, each optionally holding an ``"answers"``
            list of dicts with an optional ``"text"`` sequence.
    """
    total_count = 0
    empty_count = 0

    for item in data:
        for answer in item.get('answers', []):
            texts = answer.get('text', [])
            total_count += len(texts)
            empty_count += sum(1 for txt in texts if txt == "")

    ratio = empty_count / total_count if total_count else 0.0
    # The ``%`` presentation type scales the fraction by 100; formatting the
    # raw fraction with ``.2f`` and appending "%" under-reported it 100-fold.
    print(f"{empty_count} / {total_count} = {ratio:.2%}")

