"""testing_total_flow_transformers.py

End-to-end fact-checking pipeline. An answer is processed in three stages,
each backed by its own causal language model:

1. **Claim extraction**       – identify atomic factual claims in an answer.
2. **Evidence retrieval**     – select primary & supporting sentences required
                                to verify each claim.
3. **Fact checking**          – classify each claim as *supported*, *not
                                supported*, *unverifiable*, or *irrelevant*.
"""

from __future__ import annotations

import ast
import re
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import pysbd
import torch
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GPTQConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

# NOTE: the three tokenizer/model pairs below are created by module-level
# statements, so importing this module performs all of the loads.
# Each pair is requested in float16 with device_map="auto".

# Stage 1 - claim extraction.
# Weights can be downloaded from:
# https://drive.google.com/file/d/1ZzVkxuRRIX_h7QgKnCNVlmA31cjX1olJ/view?usp=sharing
# NOTE(review): "claim_extractor" is presumably the local directory the
# archive is unpacked into - confirm.
claim_extra_tokenizer = AutoTokenizer.from_pretrained("claim_extractor")
claim_extra_model = AutoModelForCausalLM.from_pretrained(
    "claim_extractor", device_map="auto", torch_dtype=torch.float16
)

# Stage 2 - evidence retrieval.
# Weights can be downloaded from:
# https://drive.google.com/file/d/1ysPpsqtKx_W3hpGTBwdDjmce8EWcgqIT/view?usp=sharing
# NOTE(review): the name is spelled "evidence_retrievel" in both calls; it
# presumably has to match the on-disk name, so do not correct it here alone.
evi_extra_tokenizer = AutoTokenizer.from_pretrained("evidence_retrievel")
evi_extra_model = AutoModelForCausalLM.from_pretrained(
    "evidence_retrievel", device_map="auto", torch_dtype=torch.float16
)

# Stage 3 - fact checking.
# Weights can be downloaded from:
# https://drive.google.com/file/d/1JZ_waTwMy9Uw6tgvOHjT06kL1QPPGkxi/view?usp=sharing
fact_checker_tokenizer = AutoTokenizer.from_pretrained("fact_checker")
fact_checker_model = AutoModelForCausalLM.from_pretrained(
    "fact_checker", device_map="auto", torch_dtype=torch.float16
)

# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------


def split_into_sentences(text: str) -> List[str]:
    """Break *text* into sentences with an English PySBD segmenter.

    The segmenter is created with ``char_span=True``, so it yields span
    objects; only the ``sent`` string of each span is kept.

    Returns
    -------
    List[str]
        The sentence strings in order, with the character-span metadata
        discarded.
    """
    segmenter = pysbd.Segmenter(language="en", char_span=True)
    return [span.sent for span in segmenter.segment(text)]


# ---------------------------------------------------------------------------
# 1. Claim extraction
# ---------------------------------------------------------------------------


def _get_claim_extractor_prompt(question: str, answer: str) -> str:
    """Compose the system prompt for the claim‑extraction model."""
    return f"""You are a helpful assistant specialised in extracting factual claims from a given Question and Answer. ...  # (prompt truncated for brevity in docstring)

*** FULL prompt is preserved below ***

You are a helpful assistant specialized in extracting factual claims from a given Question and Answer. Your task is to return each factual claim stated or implied in the Answer, along with the sentence number it originates from. Each sentence in the Answer is already numbered for your convenience.

If the Answer is short or minimal, use the Question for necessary context to reconstruct a clear and self-contained factual claim. Always combine Question and Answer when needed to form a meaningful claim.

Definitions:
A **claim** is a statement that asserts something as true or false and can be verified or refuted using external evidence.

Do not include:
- Opinions
- Vague or rhetorical expressions
- Hypothetical conditionals (unless they assert a verifiable fact)

Each claim must:
- Be concise
- Contain only one main idea (atomic)
- Be standalone and understandable without referring to the full Answer

Note: If a single sentence contains multiple factual claims, extract each one independently with the correct sentence number.

---

**Input Format**:

**Question**: {question}

**Answer**: {answer}

---

**Output Format**:
You should return each extracted factual claim wrapped in <claim> tags.
Inside each <claim> block:

The first line contains the text of the claim in plain language.

The second part is wrapped in <sentence> tags, which indicate the sentence number from the original Answer where the claim comes from.

<claim>
{{extracted factual claim}}
<sentence>{{sentence number from which the claim was derived}}</sentence>
</claim>
<claim>
{{extracted factual claim}}
<sentence>{{sentence number from which the claim was derived}}</sentence>
</claim>
....
output: """


def _infer_claim_extractor(messages: List[dict]) -> str:
    """Generate the claim-extraction model's reply to *messages*.

    The messages are rendered with the tokenizer's chat template (with the
    generation prompt appended), tokenized, moved to the model's device and
    passed to ``generate`` with a budget of 1024 new tokens.

    Returns
    -------
    str
        The decoded continuation of the first (only) sequence, without the
        prompt tokens and without special tokens.
    """
    prompt_text = claim_extra_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    encoded = claim_extra_tokenizer([prompt_text], return_tensors="pt")
    encoded = encoded.to(claim_extra_model.device)

    full_sequences = claim_extra_model.generate(**encoded, max_new_tokens=1024)

    # Each generated sequence starts with its prompt; keep only what follows.
    continuations = []
    for prompt_ids, sequence in zip(encoded.input_ids, full_sequences):
        continuations.append(sequence[len(prompt_ids):])

    decoded = claim_extra_tokenizer.batch_decode(
        continuations, skip_special_tokens=True
    )
    return decoded[0]


def _number_answer_sentences(answer: str) -> str:
    """Return *answer* with every sentence prefixed by ``<n> `` (1-based).

    Sentences come from ``split_into_sentences`` and the numbered pieces are
    concatenated with no separator between them.
    """
    pieces: List[str] = []
    for position, sentence in enumerate(split_into_sentences(answer), start=1):
        pieces.append(f"<{position}> {sentence}")
    return "".join(pieces)


def _parse_claim_blocks(text: str) -> List[Tuple[str, int]]:
    """Extract (claim, sentence_number) pairs from the model's XML output."""
    claim_pattern = re.compile(
        r"(?s)<claim>\s*(.*?)\s*<sentence>\s*(\d+)\s*</sentence>\s*</claim>",
        re.IGNORECASE,
    )
    return [(claim.strip(), int(num)) for claim, num in claim_pattern.findall(text)]


def extract_claims(question: str, answer: str) -> List[Tuple[str, int]]:
    """Run claim extraction on *answer* and return the parsed claims.

    The answer is sentence-numbered, embedded in the claim-extraction prompt
    together with *question*, sent to the model as a single user message,
    and the reply is parsed into ``(claim, sentence_number)`` pairs.
    """
    numbered_answer = _number_answer_sentences(answer)
    user_message = {
        "role": "user",
        "content": _get_claim_extractor_prompt(question, numbered_answer),
    }
    model_reply = _infer_claim_extractor([user_message])
    return _parse_claim_blocks(model_reply)


# ---------------------------------------------------------------------------
# 2. Evidence extraction
# ---------------------------------------------------------------------------


def _get_evidence_extractor_prompt(
    question: str, answer: str, evidence: str
) -> str:
    """Compose the system prompt for evidence selection."""
    # NOTE: The template is long; retained verbatim for fidelity.
    return f"""
You are an intelligent assistant designed to identify **all necessary evidence sentences index numbers** required to fact-check a given **Question and Answer pair**.
... (full prompt text unchanged) ...
Output: """


def _wrap_evidence_sentences(text: str) -> Tuple[str, List[str]]:
    """Wrap every sentence of *text* in numbered tags ``<n>...</n>``.

    Returns
    -------
    Tuple[str, List[str]]
        The concatenated tagged sentences (1-based numbering, no separator)
        and the list of untagged sentences they were built from.
    """
    sentences = split_into_sentences(text)
    tagged: List[str] = []
    for number, sentence in enumerate(sentences, start=1):
        tagged.append(f"<{number}>{sentence}</{number}>")
    return "".join(tagged), sentences


def _infer_evidence_extractor(messages: List[dict]) -> str:
    """Generate the evidence-retrieval model's reply to *messages*.

    The chat template is applied with the generation prompt appended, the
    result is tokenized and moved to the model's device, and up to 1024 new
    tokens are generated. Only the newly generated part of the first
    sequence is decoded (special tokens skipped) and returned.
    """
    rendered = evi_extra_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    batch = evi_extra_tokenizer([rendered], return_tensors="pt")
    batch = batch.to(evi_extra_model.device)

    sequences = evi_extra_model.generate(**batch, max_new_tokens=1024)

    # Drop the leading prompt tokens from every generated sequence.
    new_tokens = []
    for prompt_ids, sequence in zip(batch.input_ids, sequences):
        new_tokens.append(sequence[len(prompt_ids):])

    replies = evi_extra_tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    return replies[0]


def _parse_main_and_support(xml_text: str) -> List[Tuple[int, List[int]]]:
    """Return list of (main_sentence_idx, [support_idx, ...]) tuples."""
    pattern = r"<main>\s*sentence_number=(\d+)\s*<support>(.*?)</support>\s*</main>"
    matches = re.findall(pattern, xml_text)
    result: List[Tuple[int, List[int]]] = []
    for main_sentence, support in matches:
        support_list = [int(s.strip()) for s in support.split(",") if s.strip()]
        result.append((int(main_sentence), support_list))
    return result


def _collect_evidence_lines(
    full_text: str, indices: List[Tuple[int, List[int]]]
) -> List[Tuple[str, List[str]]]:
    """Translate 1-based sentence indices into the sentences of *full_text*.

    Parameters
    ----------
    full_text:
        Text that is re-split into sentences for the lookup.
    indices:
        ``(main_idx, [support_idx, ...])`` tuples with 1-based indices.

    Returns
    -------
    List[Tuple[str, List[str]]]
        One ``(main_sentence, support_sentences)`` tuple per input tuple.
        An out-of-range main index is replaced by the placeholder string
        ``"<index N out of range>"``; out-of-range support indices are
        dropped without any message.
    """
    sentences = split_into_sentences(full_text)
    sentence_count = len(sentences)

    def is_valid(position: int) -> bool:
        # Indices are 1-based, so the valid range is 1..sentence_count.
        return 1 <= position <= sentence_count

    gathered: List[Tuple[str, List[str]]] = []
    for main_idx, support_indices in indices:
        if is_valid(main_idx):
            main_sentence = sentences[main_idx - 1]
        else:
            main_sentence = f"<index {main_idx} out of range>"

        support_sentences: List[str] = []
        for support_idx in support_indices:
            if is_valid(support_idx):
                support_sentences.append(sentences[support_idx - 1])

        gathered.append((main_sentence, support_sentences))
    return gathered


def retrieve_evidence(
    question: str, answer: str, evidence_corpus: str
) -> List[Tuple[str, List[str]]]:
    """Select the evidence sentences in *evidence_corpus* relevant to *answer*.

    The corpus is sentence-tagged, placed in the evidence-selection prompt
    with *question* and *answer*, and sent to the model as one user message.
    The indices parsed from the reply are mapped back to sentences of the
    untagged corpus.

    Returns
    -------
    List[Tuple[str, List[str]]]
        ``(main_sentence, support_sentences)`` tuples.
    """
    tagged_corpus = _wrap_evidence_sentences(evidence_corpus)[0]
    user_message = {
        "role": "user",
        "content": _get_evidence_extractor_prompt(
            question=question, answer=answer, evidence=tagged_corpus
        ),
    }
    model_reply = _infer_evidence_extractor([user_message])
    selected = _parse_main_and_support(model_reply)
    return _collect_evidence_lines(evidence_corpus, selected)


# ---------------------------------------------------------------------------
# 3. Fact checking
# ---------------------------------------------------------------------------


def _get_factchecker_prompt(question: str, claim: str, evidence: List[Tuple[str, List[str]]]) -> str:
    """Compose prompt that asks the fact‑checker to label *claim*."""
    evidence_text_parts: List[str] = []
    for idx, (primary, supports) in enumerate(evidence, 1):
        support_block = "\n".join(supports)
        evidence_text_parts.append(
            f"""Primary Evidence {idx}:
{primary}
Support Evidence {idx}:
{support_block}

""")
    evidence_text = "".join(evidence_text_parts)

    # NOTE: Prompt body unchanged (truncated above for brevity).
    return f"""You are an intelligent fact-checking system. ... (full prompt) ...

**question**: 
{question}

**sub-claim**:
{claim}

**evidence**: 
{evidence_text}


Output: """


def _infer_factchecker(messages: List[dict]) -> str:
    """Generate the fact-checking model's reply to *messages*.

    The chat template is applied with the generation prompt appended, the
    text is tokenized and moved to the model's device, and up to 1024 new
    tokens are generated. The prompt tokens are removed before decoding
    (special tokens skipped); the first decoded sequence is returned.
    """
    chat_text = fact_checker_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = fact_checker_tokenizer([chat_text], return_tensors="pt")
    inputs = inputs.to(fact_checker_model.device)

    outputs = fact_checker_model.generate(**inputs, max_new_tokens=1024)

    # Keep only the tokens produced after each prompt.
    tails = []
    for prompt_ids, sequence in zip(inputs.input_ids, outputs):
        tails.append(sequence[len(prompt_ids):])

    texts = fact_checker_tokenizer.batch_decode(tails, skip_special_tokens=True)
    return texts[0]


def _extract_label_and_errors(text: str) -> Tuple[str | None, List[str]]:
    """Parse <label> and <error> blocks from *text*."""
    label_match = re.search(r"<label>\s*(.*?)\s*</label>", text, re.IGNORECASE | re.DOTALL)
    label = label_match.group(1).strip() if label_match else None

    error_match = re.search(r"<error>\s*(.*?)\s*</error>", text, re.IGNORECASE | re.DOTALL)
    if error_match:
        error_content = error_match.group(1).strip()
        errors = [phrase.strip() for phrase in error_content.split(",") if phrase.strip()]
    else:
        errors = []
    return label, errors


# ---------------------------------------------------------------------------
# High‑level orchestration helpers
# ---------------------------------------------------------------------------


def _chunk_text(text: str, *, chunk_size: int = 7000) -> List[str]:
    """Split *text* into non-overlapping chunks for model input.

    Parameters
    ----------
    text:
        The text to split.
    chunk_size:
        Size limit handed to ``RecursiveCharacterTextSplitter``.

    Returns
    -------
    List[str]
        The chunks produced by the splitter. Separators are kept in the
        output (``keep_separator=True``) and the overlap is zero.
    """
    boundary_candidates = [".", "?", "!", "\n", " ", ""]
    splitter = RecursiveCharacterTextSplitter(
        separators=boundary_candidates,
        keep_separator=True,
        chunk_size=chunk_size,
        chunk_overlap=0,
    )
    chunks = splitter.split_text(text)
    return chunks


def classify_claim(
    question: str, claim: Tuple[str, int], evidence_corpus: str, *, threshold: float = 0.75
) -> str:
    """Fact-check one claim against *evidence_corpus*.

    The corpus is chunked, evidence sets are retrieved from every chunk, and
    each evidence set is labelled separately by the fact-checking model.

    Parameters
    ----------
    question:
        The question the claim relates to.
    claim:
        ``(claim_text, sentence_number)``; only the text is used here.
    evidence_corpus:
        Plain-text evidence.
    threshold:
        Minimum fraction of labelled evidence sets that must be
        ``"supported"`` for the claim to count as a fact.

    Returns
    -------
    str
        ``"unverifiable"`` when no evidence set produced a label, otherwise
        ``"fact"`` or ``"not fact"``.
    """
    claim_text = claim[0]

    evidence_sets: List[Tuple[str, List[str]]] = []
    for chunk in _chunk_text(evidence_corpus):
        evidence_sets.extend(retrieve_evidence(question, claim_text, chunk))

    labels: List[str] = []
    for evidence in evidence_sets:
        prompt = _get_factchecker_prompt(question, claim_text, [evidence])
        label_output = _infer_factchecker([{"role": "user", "content": prompt}])
        label, _ = _extract_label_and_errors(label_output)
        if label:
            # Normalise case so that "Supported" / "SUPPORTED" are counted
            # the same as "supported" in the comparison below.
            labels.append(label.strip().casefold())

    if not labels:
        return "unverifiable"

    supported_fraction = sum(lab == "supported" for lab in labels) / len(labels)
    return "fact" if supported_fraction >= threshold else "not fact"


def _canonicalise_evidence(evidence_corpus: str) -> str:
    """Turn a stringified list of strings into newline-joined plain text.

    Anything that is not a literal list of strings is returned unchanged.
    """
    stripped = evidence_corpus.strip()
    if not (stripped.startswith("[") and stripped.endswith("]")):
        return evidence_corpus
    try:
        # SECURITY: this used to be eval(), which executes arbitrary code
        # found in the evidence text. literal_eval only accepts Python
        # literals, which is all a stringified list needs.
        parsed = ast.literal_eval(stripped)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Ordinary prose that merely starts with "[" and ends with "]".
        return evidence_corpus
    if isinstance(parsed, list) and all(isinstance(item, str) for item in parsed):
        return "\n".join(parsed)
    return evidence_corpus


def classify_answer(
    question: str, answer: str, evidence_corpus: str, *, threshold: float = 0.75
) -> str:
    """Label the whole *answer* as ``"fact"`` or ``"not fact"``.

    Parameters
    ----------
    question:
        The question the answer responds to.
    answer:
        The answer to fact-check.
    evidence_corpus:
        Evidence as plain text, or as the string form of a Python list of
        strings (which is joined with newlines before use).
    threshold:
        Passed to ``classify_claim`` and also used as the minimum fraction
        of claims that must be classified ``"fact"``.

    Returns
    -------
    str
        ``"fact"`` when at least *threshold* of the extracted claims are
        classified ``"fact"``, else ``"not fact"``. When no claims are
        extracted the fraction is taken as 0.0.
    """
    claims = extract_claims(question, answer)
    evidence_corpus = _canonicalise_evidence(evidence_corpus)

    results = [
        classify_claim(question, claim, evidence_corpus, threshold=threshold)
        for claim in claims
    ]
    # Claims classified "unverifiable" or "not fact" both count against.
    supported_fraction = results.count("fact") / len(results) if results else 0.0
    return "fact" if supported_fraction >= threshold else "not fact"


# ---------------------------------------------------------------------------
# Demo / CLI entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Command-line demo: fact-check one answer against an evidence file.
    import argparse

    parser = argparse.ArgumentParser(description="Quick sanity‑check demo.")
    parser.add_argument("--question", default="", help="Question for context.")
    parser.add_argument("--answer", required=True, help="Answer to be fact‑checked.")
    parser.add_argument("--evidence", required=True, help="Evidence corpus file path.")
    args = parser.parse_args()

    # Read the evidence corpus from a text file. The codec name is now the
    # canonical ASCII spelling "utf-8"; it was previously written with a
    # non-breaking hyphen (U+2011), which is not a registered codec name.
    evidence_text = Path(args.evidence).read_text(encoding="utf-8")

    verdict = classify_answer(args.question, args.answer, evidence_text)
    print(f"\nOverall verdict: {verdict.upper()}\n")
