from typing import List, Optional

from src.generation_utils import (
    extract_alternative_paths,
    extract_context,
    extract_equivalent_classes,
    self_complete,
    verify_correctness_pairwise,
)
from src.global_edit_utils import clean_up_text
from src.text_poa_graph import TextPOAGraph

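"""
Decodes a TextPOAGraph object into a string by walking the consensus path and additionally selecting
non-consensus successor nodes whose support (fraction of sequences passing through them) meets the selection threshold.
For selected nodes that have variations, the longest text among the node's own text and its variations is used.
The joined text is then edited with clean_up_text (e.g. to remove incoherencies, disfluencies, and redundancies).

Args:
    text_poa_graph: The TextPOAGraph object to decode.
    selection_threshold: The minimum fraction of sequences that must pass through a non-consensus node for it to be selected.
    task: The task name passed to clean_up_text.
    **kwargs: Additional keyword arguments passed to clean_up_text.

Returns:
    A string of the decoded text, or "Abstain" if the graph failed.
"""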


def decode_consensus(
    text_poa_graph: TextPOAGraph,
    selection_threshold: Optional[float] = 0.5,
    task: str = "bio",
    **kwargs,
) -> str:
    """Decode a TextPOAGraph into a single cleaned-up string.

    Walks the consensus path (skipping the start/end sentinel nodes). After each
    consensus node it also selects every non-consensus successor whose support
    (number of sequence labels on the node divided by ``num_sequences``) is at
    least ``selection_threshold``. Each selected node is selected at most once.
    For nodes with variations, the longest text among the variations and the
    node's own text is used. The texts are joined with spaces and edited by
    ``clean_up_text``.

    Args:
        text_poa_graph: The TextPOAGraph object to decode.
        selection_threshold: Minimum support for a non-consensus successor to
            be included.
        task: Task name forwarded to ``clean_up_text``.
        **kwargs: Extra keyword arguments forwarded to ``clean_up_text``.
            ``api`` defaults to ``"openai"`` and may be overridden here.

    Returns:
        The edited decoded text, or ``"Abstain"`` if the graph is marked failed.
    """
    if text_poa_graph.failed:
        return "Abstain"

    text_poa_graph.toposort()

    consensus_node_ids = text_poa_graph.consensus_node_ids
    # Set copy for O(1) membership tests; the list keeps the traversal order.
    consensus_set = set(consensus_node_ids)
    boundary_ids = {text_poa_graph.start_id, text_poa_graph.end_id}

    selected_node_ids = []
    # Non-consensus nodes already selected: a node reachable from two consensus
    # nodes must not contribute its text twice.
    selected_extra = set()

    for node_id in consensus_node_ids:
        if node_id in boundary_ids:
            continue

        selected_node_ids.append(node_id)

        for neighbor_id in text_poa_graph.nodedict[node_id].outEdges:
            if (
                neighbor_id in consensus_set
                or neighbor_id in boundary_ids
                or neighbor_id in selected_extra
            ):
                continue

            support = (
                len(text_poa_graph.nodedict[neighbor_id].labels) / text_poa_graph.num_sequences
            )
            if support >= selection_threshold:
                selected_extra.add(neighbor_id)
                selected_node_ids.append(neighbor_id)

    texts = []
    for node_id in selected_node_ids:
        node = text_poa_graph.nodedict[node_id]
        if node.variations:
            # Longest candidate wins; on ties max() keeps the first, i.e. a variation.
            texts.append(max([*node.variations.values(), node.text], key=len))
        else:
            texts.append(node.text)

    # Default the API without clashing with a caller-supplied ``api`` keyword.
    kwargs.setdefault("api", "openai")
    return clean_up_text(text=" ".join(texts), task=task, **kwargs)


def _find_high_uncertainty_nodes(text_poa_graph: TextPOAGraph, uncertainty_threshold: float) -> List:
    """Return consensus node ids (excluding start/end) whose out-degree per sequence exceeds the threshold."""
    boundary_ids = (text_poa_graph.start_id, text_poa_graph.end_id)
    return [
        node_id
        for node_id in text_poa_graph.consensus_node_ids
        if node_id not in boundary_ids
        and len(text_poa_graph.nodedict[node_id].outEdges) / text_poa_graph.num_sequences
        > uncertainty_threshold
    ]


def _build_masked_candidates(text_poa_graph: TextPOAGraph, labels, high_uncertainty_ids) -> dict:
    """Render each label's path as text, wrapping the node after each high-uncertainty node in separators tagged with that node's id."""
    candidates = {}
    for label in labels:
        pieces = []
        # Id of the high-uncertainty node visited immediately before the current
        # node, if any. Reset per label so a region cannot leak into the next path.
        open_region = None
        for node_id in text_poa_graph._seq_paths[label]:
            node = text_poa_graph.nodedict[node_id]
            if open_region is not None:
                pieces.append(f" *START_SEPARATOR*_{open_region} ")
            if node.variations:
                # Variations are keyed by label; fall back to the node text if absent.
                pieces.append(node.variations.get(label, node.text) + " ")
            else:
                pieces.append(node.text + " ")
            if open_region is not None:
                pieces.append(f" *END_SEPARATOR*_{open_region} ")
            # Tag with the high-uncertainty node's id so the verification loop,
            # which iterates those ids, can find and re-mark this region.
            open_region = node_id if node_id in high_uncertainty_ids else None
        candidates[label] = "".join(pieces)
    return candidates


def _mark_possible_error(text: str, node_id) -> str:
    """Turn the uncertain-region separators tagged with ``node_id`` into possible-error markers."""
    return text.replace(f" *START_SEPARATOR*_{node_id} ", "*START_POSSIBLE_ERROR*").replace(
        f" *END_SEPARATOR*_{node_id} ", "*END_POSSIBLE_ERROR*"
    )


def decode_self_verified(
    text_poa_graph: TextPOAGraph,
    problem: str,
    uncertainty_threshold: float = 0.6,
    verification_api: str = "openai",
    verification_model: str = "gpt-4o-mini",
    grace_period: bool = True,
):
    """Decode a TextPOAGraph by pairwise self-verification of divergent branches.

    Consensus nodes whose out-degree divided by ``num_sequences`` exceeds
    ``uncertainty_threshold`` are treated as high-uncertainty branch points.
    Each sequence's text is rendered with the node following such a branch point
    wrapped in ``*START_SEPARATOR*_<id>`` / ``*END_SEPARATOR*_<id>`` markers.

    At every branch point with more than one equivalence class of labels, the
    classes are compared in disjoint pairs via ``verify_correctness_pairwise``.
    A label scored below 1.0 has its region re-marked as a possible error. With
    ``grace_period=True``, its class is pruned only if the label was also flagged
    at its previous verification (a score of exactly 1.0 clears the flag). With
    ``grace_period=False``, the class is pruned on the first incorrect verdict.
    If every label is pruned, a "patch" prompt over all approaches is used;
    otherwise a prompt over the surviving approaches is used.

    Args:
        text_poa_graph: Graph of aligned candidate solutions.
        problem: Problem statement passed to the verifier and the final prompt.
        uncertainty_threshold: Branching ratio above which a node is uncertain.
        verification_api: API name forwarded to the verifier and ``self_complete``.
        verification_model: Model name forwarded to the verifier and ``self_complete``.
        grace_period: Whether the first incorrect verdict is forgiven.

    Returns:
        ``(completion, masked_candidates)``: the result of ``self_complete`` for
        the chosen prompt, and a dict mapping every label to its annotated text.
    """
    high_uncertainty_nodes = _find_high_uncertainty_nodes(text_poa_graph, uncertainty_threshold)

    selected_labels = list(text_poa_graph._seq_paths.keys())
    masked_candidates = _build_masked_candidates(
        text_poa_graph, selected_labels, set(high_uncertainty_nodes)
    )

    patch_start_node = None

    # give a grace period for the first incorrect step:
    # prev_step[label] is True when the label was flagged incorrect at its most
    # recent verification, False after a correct verdict, None if never verified.
    prev_step = {label: None for label in selected_labels}

    for node_id in high_uncertainty_nodes:
        equivalent_classes = extract_equivalent_classes(text_poa_graph, node_id, selected_labels)

        # Only do self-verification for labels from different semantically equivalent branches
        if len(equivalent_classes) <= 1:
            continue

        context_before = extract_context(text_poa_graph, node_id)
        alternative_paths = extract_alternative_paths(text_poa_graph, node_id)
        new_labels = selected_labels.copy()

        # Classes are verified in disjoint pairs (0, 1), (2, 3), ...; with an odd
        # number of classes the last one is not verified at this node.
        for i in range(0, len(equivalent_classes) - 1, 2):
            class_a, class_b = equivalent_classes[i], equivalent_classes[i + 1]
            label_a, label_b = class_a[0], class_b[0]

            score = verify_correctness_pairwise(
                full_text_1=context_before[label_a] + alternative_paths[label_a],
                full_text_2=context_before[label_b] + alternative_paths[label_b],
                verification_model=verification_model,
                problem=problem,
                api=verification_api,
            )

            for label, members, raw_score in (
                (label_a, class_a, score[0]),
                (label_b, class_b, score[1]),
            ):
                verdict = float(raw_score)
                if verdict < 1.0:
                    print(f"Label {label} is incorrect at node {node_id}")
                    masked_candidates[label] = _mark_possible_error(
                        masked_candidates[label], node_id
                    )
                    flagged_before = bool(prev_step.get(label))
                    prev_step[label] = True
                    if flagged_before or not grace_period:
                        for member in members:
                            if member in new_labels:
                                new_labels.remove(member)
                                print(f"\nSequence {member} pruned at node {node_id} (pairwise)")
                elif verdict == 1.0:
                    prev_step[label] = False

        if not new_labels:
            # Every branch was pruned: switch to the patch prompt below.
            patch_start_node = node_id
            break

        selected_labels = new_labels

    # These are the pruned approaches with masking
    print(masked_candidates)
    masked_approaches = "\n".join(
        [
            f"Approach {label}: {masked_candidates[label].replace('START_SEPARATOR', 'START_UNCERTAIN_REGION').replace('END_SEPARATOR', 'END_UNCERTAIN_REGION')}"
            for label in selected_labels
        ]
    )
    # These are all approaches with masking
    all_approaches = "\n".join(
        [f"Approach {label}: {masked_candidates[label]}" for label in masked_candidates]
    )

    default_prompt = f"""
    Solve the following math problem with mathematical precision and clarity.

    Problem: {problem}

    Below are potential solution approaches with sections marked as uncertain (between *START_UNCERTAIN_REGION* and *END_UNCERTAIN_REGION*). 
    These sections may contain conceptual or computational errors.

    There are also sections marked as *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR*.
    A verification step indicated that these steps are highly likely to contain errors.

    Potential Approaches:
    {masked_approaches}

    Your task:
    1. Analyze all potential approaches critically, identifying their mathematical strengths and weaknesses
       If the approaches contain different answers, think carefully about why they are different, and use this to identify potential errors.
    2. Using the sections with special markers, identify potential errors.
    3. Develop a rigorous, step-by-step solution based on sound mathematical principles
    4. For uncertain regions:
       - Verify each step using algebraic or numerical validation
       - If correct, incorporate these steps with appropriate justification
       - If incorrect, provide clear corrections with mathematical reasoning for your changes
    5. Follow a comparative approach, using the differences between approaches to identify potential errors.
    6. Do not blindly follow the approaches, but rather use them to identify potential errors.

    Guidelines for your solution:
    - Begin with a strategic overview of your chosen approach
    - Present each mathematical step with clear notation and justification
    - Pay special attention to areas that were previously marked uncertain

    Conclude your solution with:
    Therefore, the final answer is: $\\boxed{{answer}}$.

    Solution:
    """

    patch_prompt = f"""
    Solve the following mathematical problem with precision and clarity.

    Problem: {problem}

    You have been provided with several partial solution approaches that attempted to solve this problem. 
    None of these approaches are correct, but may contain valuable insights.
    Sections marked between *START_POSSIBLE_ERROR* and *END_POSSIBLE_ERROR* indicate steps where previous solutions showed uncertainty.
    A verification step indicated that these steps are likely to contain errors.

    INSTRUCTIONS:
    1. Synthesize a correct solution using insights from the previous approaches
    2. Pay special attention to fixing the problematic areas marked by separators
    3. Develop your solution step-by-step, showing clear mathematical reasoning
    4. Focus especially on mathematical correctness in areas where previous solutions diverged
    5. Present your work in a logical, sequential manner suitable for an advanced reader

    GUIDELINES FOR MATHEMATICAL RIGOR:
    1. MAINTAIN MATHEMATICAL RIGOR
    - Verify that all mathematical operations follow from established principles and definitions
    - Ensure dimensional consistency throughout calculations
    - Check that algebraic manipulations preserve equality and do not introduce errors
   
    2. CONSIDER ALTERNATIVE PERSPECTIVES
    - Even when approaches reach the same conclusion, examine their reasoning independently
    - Look for more elegant or insightful connections that may be missed across all approaches
    - Consider whether fundamental mathematical principles suggest a different path
   
    3. CRITICAL VALIDATION
    - Test conclusions using known mathematical properties and relationships
    - When possible, verify results using alternative methods
    - Be especially cautious when all approaches agree on a result but use similar reasoning
   
    4. USE PRECISION IN CORRECTIONS
    - When correcting uncertain regions, specify exactly what was incorrect and why
    - Provide clear mathematical justification for any changes
    - Ensure corrections align with standard mathematical principles and notations

    Previous Approaches (for reference only):
{all_approaches}

Your Solution:
[Begin with a clear statement of your approach]
[Provide detailed mathematical steps]
[Ensure correct handling of complex mathematical operations]
[Verify your work at key points, especially in previously problematic areas]

Always conclude with:
Therefore, the final answer is: $\\boxed{{answer}}$ 
    """

    if patch_start_node is not None or len(masked_candidates) == 1:
        print("None correct, patching")
        prompt = patch_prompt
    else:
        prompt = default_prompt

    return self_complete(
        verification_prompt=prompt, verification_model=verification_model, api=verification_api
    ), masked_candidates
