import os
import re
from typing import List, Tuple, Dict
from nltk.tokenize import sent_tokenize
from vllm import LLM, SamplingParams
from retrieval import MyRetriever
from transformers import AutoTokenizer


class BaseModel:
    """Shared base for the vLLM-backed models: holds the engine, tokenizer and sampling config."""

    def __init__(self, model_path: str, gpu_memory_utilization: float = 0.9):
        """Load a single-GPU vLLM engine and the matching tokenizer from ``model_path``.

        Args:
            model_path: Model path/name passed to both ``LLM`` and ``AutoTokenizer``.
            gpu_memory_utilization: Forwarded unchanged to the vLLM engine.

        ``self.sampling_params`` stays ``None`` until :meth:`set_sampling_params` is called.
        """
        engine_options = {
            "model": model_path,
            "gpu_memory_utilization": gpu_memory_utilization,
            "tensor_parallel_size": 1,
            "distributed_executor_backend": None,
        }
        self.model = LLM(**engine_options)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.sampling_params: SamplingParams = None

    def set_sampling_params(
        self,
        *,
        temperature: float = 0.7,
        max_tokens: int = 512,
        n: int = 1,
        seed: int = None,
        **kwargs,
    ):
        """Replace ``self.sampling_params`` with a freshly built ``SamplingParams``.

        ``seed`` is only forwarded when given; any extra keyword arguments are passed
        straight through to ``SamplingParams``.
        """
        options = {"temperature": temperature, "max_tokens": max_tokens, "n": n, **kwargs}
        if seed is not None:
            options["seed"] = seed
        self.sampling_params = SamplingParams(**options)


class Actor(BaseModel):
    """Generates candidate biographies for a batch of topics."""

    def create_messages(self, topic: str) -> List[Dict]:
        """Return the system + user chat messages requesting an eight-sentence biography of ``topic``."""
        system_prompt = (
            "You are an AI assistant that provides accurate and concise biographies of individuals. "
            "Each biography should be exactly eight sentences long, highlighting key aspects of the person's life, "
            "achievements, and significance."
        )
        user_prompt = f"Write a biography of {topic}."
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def generate(self, topics: List[str], num_generations: int = None, is_test: bool = False) -> Dict[str, List[str]]:
        """Sample biographies for every topic in one batched vLLM call.

        Args:
            topics: Topics to write about.
            num_generations: Samples per topic; when ``None`` it is 1 if ``is_test``
                else 20.
            is_test: Only consulted when ``num_generations`` is ``None``.

        Returns:
            Mapping from each topic to the list of generated texts. Repeated topics
            share one entry that accumulates all of their samples.

        Side effect: overwrites ``self.sampling_params`` (temperature 1.2).
        """
        if num_generations is None:
            num_generations = 20 if not is_test else 1

        # Pair each topic with its rendered prompt so results can be routed back.
        jobs = [
            (
                topic,
                self.tokenizer.apply_chat_template(
                    self.create_messages(topic),
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=False,
                ),
            )
            for topic in topics
        ]

        self.set_sampling_params(temperature=1.2, n=num_generations)
        results = self.model.generate([prompt for _, prompt in jobs], self.sampling_params)

        biographies: Dict[str, List[str]] = {topic: [] for topic in topics}
        for (topic, _), result in zip(jobs, results):
            biographies[topic] += [choice.text for choice in result.outputs]
        return biographies


class Critic(BaseModel):
    """LLM judge that points at the single most clearly wrong sentence in a paragraph.

    Paragraphs are split with ``sent_tokenize`` and numbered ``[1]``, ``[2]``, ...
    The model is asked to answer in three lines (``reason`` / ``sentence`` /
    ``error_fact``), which :meth:`process_response` parses back into a judgment.
    """

    def __init__(
        self,
        model_path: str,
        gpu_memory_utilization: float = 0.8,
        retrieval_db_path: str = None,
        retrieval_type: str = "gtr-t5-large",
        k: int = 5,
        include_context: bool = False,
    ):
        """Load the critic model and, when ``include_context`` is set, a retriever.

        Args:
            model_path: Model path/name passed to vLLM and the tokenizer.
            gpu_memory_utilization: Forwarded to the vLLM engine.
            retrieval_db_path: Retrieval DB path; when falsy, ``$FACTSCORE_DB`` is
                used, falling back to ``factscore_cache/enwiki-20230401.db``.
            retrieval_type: Forwarded to ``MyRetriever`` (presumably the encoder
                name — confirm against ``retrieval``).
            k: Forwarded to ``MyRetriever`` (presumably passages per query — confirm).
            include_context: Build a retriever. :meth:`create_messages` does not yet
                implement a retrieval-augmented prompt and raises in this mode.
        """
        super().__init__(model_path, gpu_memory_utilization)
        self.include_context = include_context
        if self.include_context:
            db_path = retrieval_db_path or os.environ.get("FACTSCORE_DB", "factscore_cache/enwiki-20230401.db")
            self.retriever = MyRetriever(db_path=db_path, retrieval_type=retrieval_type, k=k)

    def create_messages(self, topic: str, paragraph: str) -> List[Dict]:
        """Build the one-shot error-localisation prompt for one paragraph.

        Each sentence from ``sent_tokenize`` is prefixed with ``[i]`` (1-based) and
        given a terminal ``.`` if it does not already end in ``.``, ``!`` or ``?``, so
        the model can answer with a sentence number.

        Raises:
            NotImplementedError: if the critic was built with ``include_context=True``.
        """
        if self.include_context:
            # This branch used to be a bare ``pass`` that fell through to an
            # UnboundLocalError on ``system_content``; fail with a clear message.
            raise NotImplementedError(
                "Critic.create_messages does not support include_context=True yet"
            )

        numbered_sentences = []
        for i, s in enumerate(sent_tokenize(paragraph), start=1):
            s = s.rstrip()
            if s and s[-1] not in ".!?":
                s += "."
            numbered_sentences.append(f"[{i}] {s}")
        numbered_paragraph = "\n".join(numbered_sentences)

        example_para = (
            "[1] Albert Einstein was awarded the Nobel Prize in Physics in 1921 for his discovery of the photoelectric effect.\n"
            "[2] He was born in New York City, United States, and later moved to Europe where he continued his studies.\n"
            "[3] Einstein developed the theory of relativity, revolutionizing our understanding of space, time, and gravity.\n"
            "[4] His famous equation describes the equivalence of mass and energy."
        )
        example_answer = (
            "reason: Einstein was actually born in Ulm, Germany, not New York City.\n"
            "sentence: 2\n"
            "error_fact: Albert Einstein was born in New York City."
        )
        system_content = (
            "You are a factual checker. Based on your existing knowledge, identify exactly one sentence that contains "
            "the most clearly verifiable factual error in the paragraph. Return your answer in exactly three lines:\n"
            "reason: <briefly explaining what is wrong>\n"
            "sentence: N\n"
            "error_fact: F"
        )
        user_content = (
            "Here is an example to show the task.\n"
            "Find the sentence that contains the most clearly verifiable factual error in the paragraph about Albert Einstein.\n\n"
            "Example paragraph:\n"
            f"{example_para}\n\n"
            "Expected answer:\n"
            f"{example_answer}\n\n"
            f"Now apply the same procedure to the paragraph below about {topic}.\n\n"
            "Paragraph:\n"
            f"{numbered_paragraph}\n\n"
            "Answer:\n"
        )
        return [{"role": "system", "content": system_content}, {"role": "user", "content": user_content}]

    def evaluate(self, topics: List[str], paragraphs: List[str], num_evaluations: int = None, is_test: bool = False) -> List[Dict]:
        """Sample critic answers for each (topic, paragraph) pair in one batched call.

        Args:
            topics: One topic per paragraph.
            paragraphs: Paragraphs to critique; must be the same length as ``topics``.
            num_evaluations: Samples per paragraph; when ``None`` it is 1 if
                ``is_test`` else 4 (so omitting it outside test mode still gives 4).
            is_test: Only consulted when ``num_evaluations`` is ``None``.

        Returns:
            One dict per pair with ``prompt_messages`` (the chat messages sent) and
            ``responses`` (the sampled texts).

        Raises:
            ValueError: if ``topics`` and ``paragraphs`` differ in length.

        Side effect: overwrites ``self.sampling_params`` (temperature 1.2).
        """
        if len(topics) != len(paragraphs):
            raise ValueError(
                f"topics and paragraphs must have the same length ({len(topics)} != {len(paragraphs)})"
            )
        if num_evaluations is None:
            num_evaluations = 1 if is_test else 4

        # Build each message list once; it is both rendered into the prompt and
        # returned to the caller (previously it was rebuilt, re-tokenizing sentences).
        all_messages = [self.create_messages(topic, para) for topic, para in zip(topics, paragraphs)]
        prompts = [
            self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
            for messages in all_messages
        ]
        self.set_sampling_params(temperature=1.2, n=num_evaluations)
        results = self.model.generate(prompts, self.sampling_params)
        return [
            {"prompt_messages": messages, "responses": [o.text for o in res.outputs]}
            for res, messages in zip(results, all_messages)
        ]

    def extract_error_sentence_number(self, text: str) -> int:
        """Return the 1-based sentence number named in a critic answer, or 0 if none.

        Prefers a ``sentence: N`` / ``sentence = N`` field anywhere in ``text``;
        otherwise accepts a first non-blank line that is just a number.
        """
        m = re.search(r"sentence\s*[:=]\s*(\d+)", text, flags=re.IGNORECASE | re.MULTILINE)
        if m:
            return int(m.group(1))
        lines = text.strip().splitlines()
        # Guard empty responses (``splitlines()`` of "" is []), and use isdecimal():
        # isdigit() accepts characters like "²" that int() rejects.
        first = lines[0].strip() if lines else ""
        if first.isdecimal():
            return int(first)
        return 0

    def get_sentence_by_index(self, paragraph: str, index: int) -> str:
        """Return the ``index``-th (1-based) non-empty sentence of ``paragraph``.

        The sentence gets a trailing ``.`` if it lacks terminal punctuation; returns
        ``""`` when ``index`` is out of range.
        """
        sentences = [s.strip() for s in sent_tokenize(paragraph) if s.strip()]
        if 1 <= index <= len(sentences):
            sentence = sentences[index - 1]
            return sentence if sentence[-1] in ".!?" else sentence + "."
        return ""

    def extract_independent_fact(self, text: str) -> str:
        """Return the stripped ``error_fact:`` value from a critic answer, or ``""``."""
        m = re.search(r"error_fact\s*[:=]\s*(.+)", text, flags=re.IGNORECASE | re.MULTILINE)
        return m.group(1).strip() if m else ""

    def process_response(self, response: str, paragraph: str) -> Dict:
        """Turn one critic answer into a judgment dict.

        Returns ``{"judgment": "incorrect_at_sentence_N", "incorrect_sentence": ...,
        "independent_fact": ...}`` when the answer names an existing sentence and an
        error fact; otherwise ``{"judgment": "fail", ...}`` with ``None`` values.
        """
        error_num = self.extract_error_sentence_number(response)
        fact_str = self.extract_independent_fact(response)
        incorrect_sentence = ""
        if error_num > 0 and fact_str:
            incorrect_sentence = self.get_sentence_by_index(paragraph, error_num)
        if not incorrect_sentence:
            return {"judgment": "fail", "incorrect_sentence": None, "independent_fact": None}
        return {
            "judgment": f"incorrect_at_sentence_{error_num}",
            "incorrect_sentence": incorrect_sentence,
            "independent_fact": fact_str,
        }


class MetaVerifier(BaseModel):
    """Checks whether a critic's extracted claim actually comes from the cited sentence.

    Decoding is fixed at construction: temperature 0.0, one sample, at most 64
    tokens, seed 42.
    """

    def __init__(self, model_path: str, gpu_memory_utilization: float = 0.5):
        super().__init__(model_path, gpu_memory_utilization)
        self.set_sampling_params(temperature=0.0, max_tokens=64, n=1, seed=42)

    def create_messages(self, sentence: str, error_fact: str) -> List[Dict]:
        """Build the chat messages asking whether ``error_fact`` is drawn from ``sentence``."""
        system_content = (
            "You are an assistant that checks whether an extracted claim truly originates from a given sentence. "
            "Only decide if it is verbatim or a direct paraphrase of part of the sentence."
        )
        user_content = (
            f"Sentence: \"{sentence}\"\n"
            f"Extracted claim: \"{error_fact}\"\n\n"
            "Please answer in exactly two lines with no extra text:\n"
            "reason: <one-sentence explanation>\n"
            "verdict: yes / no\n"
        )
        return [{"role": "system", "content": system_content}, {"role": "user", "content": user_content}]

    def _parse_one(self, text: str) -> Tuple[bool, str]:
        """Parse a ``reason:`` / ``verdict:`` answer into ``(verdict, reason)``.

        ``verdict`` is True only for an explicit ``yes`` (case-insensitive); a missing
        or malformed verdict counts as False. ``reason`` is ``""`` when absent.
        """
        r_m = re.search(r"^reason\s*[:=]\s*(.+)$", text, flags=re.IGNORECASE | re.MULTILINE)
        reason = r_m.group(1).strip() if r_m else ""
        # Allow trailing punctuation/whitespace ("verdict: Yes.") but still reject an
        # echoed template such as "verdict: yes / no".
        v_m = re.search(
            r"^verdict\s*[:=]\s*(yes|no)[^\w\n]*$", text, flags=re.IGNORECASE | re.MULTILINE
        )
        verdict = bool(v_m and v_m.group(1).lower() == "yes")
        return verdict, reason

    def batch_verify(self, pairs: List[Tuple[str, str]]) -> List[Tuple[bool, str]]:
        """Verify each ``(sentence, error_fact)`` pair in one batched generation call.

        Returns one ``(verdict, reason)`` tuple per input pair, in order; an empty
        input returns ``[]`` without calling the engine.
        """
        if not pairs:
            return []
        prompts = [
            self.tokenizer.apply_chat_template(
                self.create_messages(sentence, ef),
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
            for sentence, ef in pairs
        ]
        results = self.model.generate(prompts, self.sampling_params)
        # n=1 is set in __init__, so only the first output of each request is read.
        return [self._parse_one(res.outputs[0].text.strip()) for res in results]
