import re
import matplotlib.pyplot as plt
import os, json
from typing import List, Tuple, Optional, Dict, Iterable
from openai import OpenAI

from typing import Sequence, List

def nli_judge(pairs: Sequence[Sequence[str]]) -> str:
    """
    Build the prompt for an NLI judge that picks the pair with the strongest mutual entailment.

    This function only assembles and returns the prompt text; it does not call a model.

    Input: pairs like [[p1, p2], [p1, p2], ...]. Two or more pairs are supported and every
    pair is listed in the prompt as "Pair k" (1-based). With exactly two pairs the wording is
    "TWO pairs" / "(1 or 2)"; with N > 2 pairs it is "N pairs" / "(1 to N)".
    The prompt asks the judge to put the winner index on the last line as <final>k</final>.

    Raises:
        IndexError: if fewer than two pairs are given, or if a pair holds fewer than two texts.
    """
    n_pairs = len(pairs)
    if n_pairs < 2:
        # IndexError (not ValueError) so callers that guarded the indexing-based
        # implementation keep catching the same exception type.
        raise IndexError("nli_judge needs at least two pairs to compare")

    if n_pairs == 2:
        count_word, index_hint = "TWO", "1 or 2"
    else:
        count_word, index_hint = str(n_pairs), f"1 to {n_pairs}"

    # One block per pair, each terminated by a blank line.
    pair_blocks = "".join(
        f"Pair {k}:\n p1: {pair[0]}\n p2: {pair[1]}\n\n"
        for k, pair in enumerate(pairs, start=1)
    )

    return (
        "You are an NLI (natural language inference) judge.\n"
        f"Task: Compare the following {count_word} pairs of texts and decide which pair has STRONGER mutual entailment.\n"
        "Rules:\n"
        "- Mutual entailment means both directions p1 ⇒ p2 and p2 ⇒ p1 hold semantically.\n"
        "- Output only the index of the better pair.\n"
        f"On the LAST line, output ONLY the winner index ({index_hint}) wrapped as <final>k</final>.\n\n"
        f"{pair_blocks}"
        "Judge result:"
    )


def parse_final_index_from_output(nli_output, default=1) -> int:
    """
    Extract integer k from <final>k</final> in the NLI output.

    - nli_output can be a str or a dict (with keys like "text"/"raw").
      If the selected dict value is itself a dict (e.g. "raw" holding
      {"text": ..., "tokens": [...]}), its "text" entry is used.
      Any non-string payload is converted with str() before matching.
    - The first <final>k</final> tag wins (case-insensitive, whitespace tolerated).
    - Without a tag, the last standalone number in the text is used.
    - If extraction fails, return `default` (default=1), converted with int().
    """
    if isinstance(nli_output, dict):
        payload = nli_output.get("text") or nli_output.get("raw") or str(nli_output)
        if isinstance(payload, dict):
            # Nested result dict: look one level down for the actual text.
            payload = payload.get("text") or str(payload)
    else:
        payload = nli_output
    # re.search only accepts str/bytes; coerce so odd payloads cannot raise TypeError.
    text = str(payload)

    # Standard form: <final>  12  </final>
    tagged = re.search(r"<\s*final\s*>\s*([0-9]+)\s*<\s*/\s*final\s*>", text, flags=re.I)
    if tagged:
        try:
            return int(tagged.group(1))
        except ValueError:
            # int() can still reject extremely long digit strings; fall through.
            pass

    # Fallback: grab the last plain number in the string
    numbers = re.findall(r"\b([0-9]+)\b", text)
    if numbers:
        try:
            return int(numbers[-1])
        except ValueError:
            pass
    return int(default)



# -------- 1) Build prompts with <final>...</final> --------
def make_prompt_with_final_option(question: str) -> str:
    """
    Build the multiple-choice reasoning prompt.

    The prompt asks the model to:
    - write three sections: Premise/Evidence, Reasoning, Conclusion;
    - pick exactly one option LABEL (A/B/C/...) that appears in the question;
    - end with a last line holding only that label inside <final>...</final>.
    """
    placeholders = "No valid answer, N/A, Unknown, None"
    prompt_lines = [
        "You are a careful reasoner. Think first, then answer.",
        "Rules:",
        "1) Select ONLY ONE option LABEL that appears in the question (e.g., A/B/C/...).",
        "2) Do NOT include the option content in <final>; output the LABEL only.",
        f"3) Keep the content concise; no placeholders such as: {placeholders}.",
        "4) The LAST line must contain ONLY the answer label wrapped with <final>YOUR ANSWER</final>.",
        "",
        "Output format: Use EXACTLY these sections and headers:",
        "1) Premise/Evidence:",
        "2) Reasoning:",
        "3) Conclusion:",
        "",
        f"Question: {question}",
        "Answer:",
    ]
    # Lines are newline-separated; the prompt ends right after "Answer:" with no trailing newline.
    return "\n".join(prompt_lines)


def make_explain_prompt(question: str, answer_text: str) -> str:
    """
    Build the explanation prompt for an already-given answer.

    The prompt uses three sections (Premise/Evidence, Explanation, Conclusion) and
    requires the last line to repeat the original answer inside <final>...</final>.
    `answer_text` is converted to a string and stripped; a falsy value becomes "".
    """
    given = str(answer_text or "").strip()
    placeholders = "No valid answer, N/A, Unknown, None"

    rules = (
        "You already answered the question. Your ONLY task is to explain why the given answer is correct.\n"
        "Rules:\n"
        "1) DO NOT change, contradict, paraphrase, or propose any other answer.\n"
        "2) Keep the content concise; no placeholders such as: " + placeholders + ".\n"
        "3) The LAST line must contain ONLY the ORIGINAL given answer wrapped with <final>YOUR ANSWER</final>.\n"
    )
    layout = (
        "Output format: Use EXACTLY these sections and headers:\n"
        "1) Premise/Evidence:\n"
        "2) Explanation:\n"
        "3) Conclusion:\n"
    )
    task = f"Question: {question}\nGiven Answer: {given}\nExplanation:"

    # The three parts are separated by a single blank line.
    return rules + "\n" + layout + "\n" + task



def make_prompt_with_final_free(question: str) -> str:
    """
    Build the free-form reasoning prompt.

    The prompt uses three sections (Premise/Evidence, Reasoning, Conclusion) and
    requires the last line to hold only the final answer inside <final>...</final>.
    """
    placeholders = "No valid answer, N/A, Unknown, None"
    sections = ("1) Premise/Evidence:", "2) Reasoning:", "3) Conclusion:")

    parts = [
        "You are a careful reasoner. Think first, then answer.",
        "Rules:",
        f"1) Keep the content concise; no placeholders such as: {placeholders}.",
        "2) The LAST line must contain ONLY the final answer wrapped with <final>YOUR ANSWER</final>.",
        "",
        "Output format: Use EXACTLY these sections and headers:",
    ]
    parts.extend(sections)
    parts.extend(["", f"Question: {question}", "Answer:"])
    return "\n".join(parts)



# -------- 2) Extract the <final>...</final> answer from full text --------
def extract_final_span(full_text: str) -> str | None:
    """
    Return the content after <final>. If </final> exists, extract the substring between;
    otherwise, take until the end of the text.
    If the result contains '<', strip from the first '<' onward.
    """
    if not full_text:
        return None

    # Find <final>
    m = re.search(r"<final>\s*(.*)", full_text, flags=re.IGNORECASE | re.DOTALL)
    if not m:
        return None

    ans = m.group(1)
    # Cut off at </final> if present
    ans = ans.split("</final>", 1)[0]

    # Remove the first '<' and everything after it
    if "<" in ans:
        ans = ans.split("<", 1)[0]

    return ans.strip()

# -------- 3) Align the answer span to token indices --------
def align_answer_to_tokens(result: dict, answer_span: str) -> list[int]:
    """
    Align `answer_span` to the token sequence using positions in result["text"].

    Steps:
      1) Find the character range of the first occurrence of `answer_span` in the text.
         If there is no literal match, retry while ignoring all whitespace and map the
         match back to original character positions. The mapped range starts at the
         first and ends right after the last non-whitespace character of the match,
         so whitespace-only tokens before the answer are not included.
      2) Rebuild token character ranges by accumulating len(token["token"]) and collect
         the indices of tokens that overlap the range.

    NOTE: step 2 assumes the concatenated token texts reproduce result["text"]
    character for character; this is not checked here.

    Returns: a list of token indices (possibly empty).
    """
    full_text = result.get("text", "")
    tokens = result.get("tokens", [])
    if not full_text or not tokens or not answer_span:
        return []

    start_char = full_text.find(answer_span)
    if start_char >= 0:
        end_char = start_char + len(answer_span)
    else:
        # Whitespace-insensitive retry.
        norm_full = re.sub(r"\s+", "", full_text)
        norm_ans = re.sub(r"\s+", "", answer_span)
        if not norm_ans:
            # An all-whitespace span has nothing left to locate.
            return []
        start_norm = norm_full.find(norm_ans)
        if start_norm < 0:
            return []
        end_norm = start_norm + len(norm_ans)

        # Walk the original text, counting only non-whitespace characters, to
        # translate the normalized offsets back into original offsets.
        start_char = end_char = len(full_text)
        seen = 0
        for pos, ch in enumerate(full_text):
            if ch.isspace():
                continue
            if seen == start_norm:
                start_char = pos
            seen += 1
            if seen == end_norm:
                end_char = pos + 1
                break

    # Rebuild token character ranges and detect overlap with [start_char, end_char).
    hit_indices = []
    cursor = 0
    for idx, tk in enumerate(tokens):
        t_start = cursor
        t_end = cursor + len(tk["token"])
        if t_end > start_char and t_start < end_char:
            hit_indices.append(idx)
        cursor = t_end
    return hit_indices

# Helper: split indices into contiguous runs (e.g., [3,4,5,  9,10,  15] -> [[3,4,5],[9,10],[15]])
def _contiguous_runs(indices: list[int]) -> list[list[int]]:
    if not indices:
        return []
    idxs = sorted(set(indices))
    runs = []
    start = prev = idxs[0]
    for i in idxs[1:]:
        if i == prev + 1:
            prev = i
        else:
            runs.append(list(range(start, prev + 1)))
            start = prev = i
    runs.append(list(range(start, prev + 1)))
    return runs

def plot_token_probs_with_answer(result: dict, answer_token_indices, title: str = "Token Probabilities"):
    """
    Plot per-token probabilities and highlight the answer spans.

    - All token probabilities (result["tokens"][i]["prob"]) are drawn as a blue line.
    - `answer_token_indices` may be a flat list of indices, which is split into
      contiguous runs, or a list of lists, where each inner list is one span.
    - Each span is drawn in red with markers, connected only within the span.

    Prints a message and returns without plotting when there are no tokens.
    """
    token_probs = [entry["prob"] for entry in result.get("tokens", [])]
    if not token_probs:
        print("No token probabilities to plot.")
        return

    # "Grouped" means a non-empty list whose first element is itself a list.
    grouped = (
        isinstance(answer_token_indices, list)
        and answer_token_indices
        and isinstance(answer_token_indices[0], list)
    )
    if grouped:
        spans = [sorted(set(group)) for group in answer_token_indices if group]
    else:
        spans = _contiguous_runs(answer_token_indices or [])

    positions = list(range(len(token_probs)))
    plt.figure()

    # Baseline: every token.
    plt.plot(positions, token_probs, color="blue", linestyle="-", marker=None, label="All tokens")

    # One red polyline per span so separate spans are not joined together.
    drew_span = False
    for span in spans:
        if not span:
            continue
        span_probs = [token_probs[i] for i in span]
        plt.plot(span, span_probs, color="red", linestyle="-", marker="o", zorder=3)
        drew_span = True

    # Empty proxy line: gives the spans a single legend entry.
    if drew_span:
        plt.plot([], [], color="red", linestyle="-", marker="o", label="Answer spans", zorder=3)

    plt.title(title)
    plt.xlabel("Token Index")
    plt.ylabel("Probability")
    plt.legend()
    plt.tight_layout()
    plt.show()



def find_all_literal_spans(full_text: str, span: str, case_insensitive: bool = True) -> List[Tuple[int,int]]:
    """
    Find all occurrences (character ranges) of `span` in `full_text`, optionally case-insensitive.

    Overlapping occurrences are all reported (e.g. "aa" in "aaa" gives two ranges).
    Offsets always refer to `full_text` itself: matching is done on the original
    string with re.IGNORECASE rather than on a lowercased copy, because str.lower()
    can change the length of a string (e.g. U+0130) and would shift every offset
    after such a character.

    Returns a list of (start_char, end_char), in increasing start order; empty when
    `span` is empty.
    """
    if not span:
        return []

    flags = re.IGNORECASE if case_insensitive else 0
    # re.escape makes the span a literal; the zero-width lookahead lets finditer
    # step one character at a time so overlapping matches are not skipped.
    pattern = re.compile("(?=" + re.escape(span) + ")", flags)
    width = len(span)
    return [(m.start(), m.start() + width) for m in pattern.finditer(full_text)]


def char_ranges_to_token_indices(tokens: List[Dict], char_ranges: Iterable[Tuple[int,int]]) -> List[int]:
    """
    Map several character ranges to the token indices they overlap.

    Token character ranges are rebuilt by accumulating len(token["token"]).
    A token is reported once if it overlaps at least one (start, end) range.
    `char_ranges` may be any iterable, including a one-shot generator.

    Returns the matching token indices, deduplicated and in ascending order.
    """
    # Materialize once: the ranges are scanned again for every token, which
    # would exhaust a generator after the first token.
    ranges = list(char_ranges)

    hits: List[int] = []
    cursor = 0
    for idx, tk in enumerate(tokens):
        t_start = cursor
        t_end = cursor + len(tk["token"])
        if any(t_end > s and t_start < e for s, e in ranges):
            hits.append(idx)
        cursor = t_end
    # Appended in index order, at most once per token: already sorted and unique.
    return hits


def collect_answer_index_groups(
    result: dict,
    case_insensitive_ok: bool = True,
    is_longest: bool = True,
) -> tuple[str | None, list[list[int]], list[str]]:
    """
    Locate every literal occurrence of the final answer and map each to token indices.

    The final answer is taken from result["text"] via extract_final_span; each
    occurrence found by find_all_literal_spans becomes one group of token indices.
    Redundant groups are then pruned, keeping the longest (is_longest=True) or the
    shortest (is_longest=False) ones.

    Returns:
      final_answer: str | None
      index_groups: [[indices for span1], [indices for span2], ...]
      similar_spans: []  # always empty in this function

    (None, [], []) is returned when the text or tokens are missing or no final
    answer can be extracted.
    """
    text = result.get("text", "") or ""
    token_list = result.get("tokens", []) or []
    if not (text and token_list):
        return None, [], []

    answer = extract_final_span(text)
    if not answer:
        return None, [], []

    # Character ranges of every occurrence, ordered by start offset.
    occurrences = sorted(
        find_all_literal_spans(text, answer, case_insensitive=case_insensitive_ok),
        key=lambda rng: rng[0],
    )

    # One token-index group per occurrence; occurrences that hit no token are skipped.
    candidate_groups = (_char_range_to_token_indices(token_list, rng) for rng in occurrences)
    groups: list[list[int]] = [grp for grp in candidate_groups if grp]
    groups.sort(key=lambda grp: grp[0])

    prune = _keep_longest_nonredundant_groups if is_longest else _keep_shortest_nonredundant_groups
    return answer, prune(groups), []



def collect_answer_index_groups_with_similars(
    result: dict,
    llm_model: str = "gpt-4.1-mini",
    max_similar: int = 2,
    case_insensitive_ok: bool = True,
    is_longest: bool =  True,
) -> tuple[str | None, list[list[int]], list[str]]:
    """
    Like the final-answer-only collection, but also includes spans similar to the answer.

    The similar spans come from llm_find_similar_spans_in_sequence (defined outside
    this block; presumably it queries `llm_model` and returns a list of strings —
    confirm against its definition). Every literal occurrence of the final answer
    and of each similar span becomes one token-index group; redundant groups are
    pruned, keeping the longest (is_longest=True) or shortest ones.

    Returns:
      final_answer,
      index_groups: [[indices for span1], [indices for span2], ...]  # each occurrence as one group
      similar_spans: list[str]

    (None, [], []) is returned when the text or tokens are missing or no final
    answer can be extracted.
    """
    text = result.get("text", "") or ""
    token_list = result.get("tokens", []) or []
    if not (text and token_list):
        return None, [], []

    answer = extract_final_span(text)
    if not answer:
        return None, [], []

    similar_spans = llm_find_similar_spans_in_sequence(
        sequence_text=text,
        final_answer=answer,
        model=llm_model,
        max_return=max_similar,
        case_insensitive_ok=case_insensitive_ok,
    )

    # Final-answer occurrences first, then each similar span's occurrences,
    # each batch ordered by start offset.
    groups: list[list[int]] = []
    for candidate in [answer] + similar_spans:
        occurrences = sorted(
            find_all_literal_spans(text, candidate, case_insensitive=case_insensitive_ok),
            key=lambda rng: rng[0],
        )
        for rng in occurrences:
            token_idxs = _char_range_to_token_indices(token_list, rng)
            if token_idxs:
                groups.append(token_idxs)

    # Order by first token index before pruning.
    groups.sort(key=lambda grp: grp[0])

    prune = _keep_longest_nonredundant_groups if is_longest else _keep_shortest_nonredundant_groups
    return answer, prune(groups), similar_spans


from typing import List, Tuple, Dict, Any

def collect_answer_index_groups_with_prev(
    obj: Dict[str, Any],
    case_insensitive_ok: bool = True,
) -> Tuple[List[List[int]], float]:
    """
    Based on literal matches to `obj['excluded']`, return:
      - merged_groups: [[token_idx...], [token_idx...], ...]   # groups across all excluded spans
      - prev_avg:      float                                   # mean of per-group average token probabilities

    Accepts either:
      - sample: { "raw": {"text":..., "tokens":[...]}, "excluded":[...] }
      - result: { "text":..., "tokens":[...], "excluded":[...] }
    Top-level keys take precedence over the ones under "raw".

    Excluded spans are converted to strings, deduplicated in first-seen order, and
    None/empty entries are skipped. After collecting one group per occurrence, only
    the longest non-redundant groups are kept. A token without a "prob" key counts
    as probability 0.0.

    Returns ([], 0.0) when `obj` is not a dict, when text/tokens/excluded are
    missing or empty, or when nothing matches.
    """
    if not isinstance(obj, dict):
        return [], 0.0

    raw = obj.get("raw")
    if not isinstance(raw, dict):
        # Missing or malformed "raw": fall back to top-level keys only.
        raw = {}

    full_text = obj.get("text") or raw.get("text") or ""
    tokens = obj.get("tokens") or raw.get("tokens") or []
    excluded = obj.get("excluded") or raw.get("excluded") or []

    if not full_text or not tokens or not excluded:
        return [], 0.0

    token_count = len(tokens)

    def _avg_prob(group: List[int]) -> float:
        # Out-of-range indices are ignored rather than raising.
        probs = [tokens[i].get("prob", 0.0) for i in group if 0 <= i < token_count]
        return (sum(probs) / len(probs)) if probs else 0.0

    # Deduplicate while preserving order (dict keys keep insertion order).
    ordered_excluded = list(
        dict.fromkeys(str(s) for s in excluded if s is not None and str(s))
    )

    # One token-index group per literal occurrence of each excluded span.
    merged_groups: List[List[int]] = []
    for span_text in ordered_excluded:
        ranges = find_all_literal_spans(full_text, span_text, case_insensitive=case_insensitive_ok)
        for char_range in sorted(ranges, key=lambda r: r[0]):
            idxs = _char_range_to_token_indices(tokens, char_range)
            if idxs:
                merged_groups.append(idxs)

    if not merged_groups:
        return [], 0.0

    # Sort and remove redundancy (keep longest non-redundant groups)
    merged_groups.sort(key=lambda g: g[0])
    merged_groups = _keep_longest_nonredundant_groups(merged_groups)

    # Mean confidence: average of the per-group means.
    per_group_means = [_avg_prob(g) for g in merged_groups]
    prev_avg = sum(per_group_means) / len(per_group_means) if per_group_means else 0.0

    return merged_groups, prev_avg


def _char_range_to_token_indices(tokens: list[dict], char_range: tuple[int, int]) -> list[int]:
    """Map a (start_char, end_char) range to a list of token indices for that span (do not merge with others)."""
    start_char, end_char = char_range
    hit = []
    cursor = 0
    for i, tk in enumerate(tokens):
        t = tk["token"]
        t_start, t_end = cursor, cursor + len(t)
        if not (t_end <= start_char or t_start >= end_char):
            hit.append(i)
        cursor = t_end
    return hit



def _keep_longest_nonredundant_groups(groups: list[list[int]]) -> list[list[int]]:
    """
    Keep the longest segments:
      - If A is a subset of (or equal to) B, keep B (the longer/non-shorter one).
      - If they only overlap without subset relation, keep both.
      - Also deduplicate and normalize each group's indices.
    """
    # Normalize: deduplicate and sort each group
    norm = [sorted(set(g)) for g in groups if g]

    # Sort by length (desc), then start index, then lexicographically, so longer groups are kept first
    norm.sort(key=lambda g: (-len(g), g[0], g))

    kept: list[list[int]] = []
    kept_sets: list[set[int]] = []

    for g in norm:
        gs = set(g)
        # Drop if current group is a subset of any kept group (or equal)
        if any(gs.issubset(ks) for ks in kept_sets):
            continue
        kept.append(g)
        kept_sets.append(gs)

    # Final order: by start index asc, length desc
    kept.sort(key=lambda g: (g[0], -len(g)))
    return kept


def _keep_shortest_nonredundant_groups(groups: list[list[int]]) -> list[list[int]]:
    """
    Keep the shortest segments:
      - If A is a superset of (or equal to) any kept group, drop A (prefer the shorter one).
      - If they only overlap without subset/superset relation, keep both.
      - Also deduplicate and normalize each group's indices.
    """
    # Normalize: deduplicate and sort each group
    norm = [sorted(set(g)) for g in groups if g]

    # Sort by length (asc), then start index, then lexicographically, so shorter groups enter `kept` first
    norm.sort(key=lambda g: (len(g), g[0], g))

    kept: list[list[int]] = []
    kept_sets: list[set[int]] = []

    for g in norm:
        gs = set(g)
        # Drop current if it is a superset (or equal) of any kept group
        if any(gs.issuperset(ks) for ks in kept_sets):
            continue
        kept.append(g)
        kept_sets.append(gs)

    # Final order: by start index asc, length asc
    kept.sort(key=lambda g: (g[0], len(g)))
    return kept
