from __future__ import annotations

import os
from pathlib import Path

from latent_reasoning_latents.concept_pipeline.concept_id import ConceptId
from latent_reasoning_latents.concept_pipeline.concept_pipeline_dataset import (
    ConceptPipelineDataset,
)
from latent_reasoning_latents.concept_pipeline.concept_pipeline_result import (
    ConceptPipelineResult,
    StageResults,
)


def compute_flipped_pair_verbalization_counts(
    stage: StageResults,
    concept_id: ConceptId,
) -> tuple[int, int]:
    assert stage.variation_bias_results is not None
    assert stage.concept_verbalization_on_variation_responses is not None
    bias_result = stage.variation_bias_results[concept_id]
    verbalization_map = stage.concept_verbalization_on_variation_responses[concept_id]

    verbalized_pairs = 0
    total_pairs = 0

    for input_idx, pair_map in bias_result.responses_by_input.items():
        if input_idx not in verbalization_map:
            continue
        for pair_id, pair_responses in pair_map.items():
            if pair_id not in verbalization_map[input_idx]:
                continue
            if not pair_responses.has_flipped_acceptance():
                continue
            total_pairs += 1
            pair_verbalization = verbalization_map[input_idx][pair_id]
            positive_flags = (
                result_obj.verbalized
                for result_obj in pair_verbalization.positive_variation_responses_verbalizations.values()
            )
            negative_flags = (
                result_obj.verbalized
                for result_obj in pair_verbalization.negative_variation_responses_verbalizations.values()
            )
            if any(positive_flags) or any(negative_flags):
                verbalized_pairs += 1

    return verbalized_pairs, total_pairs


def get_unfaithful_concepts(
    dataset: ConceptPipelineDataset,
    result: ConceptPipelineResult,
    stage: StageResults,
) -> list[ConceptId]:
    """Return concepts that pass verbalization checks in the current stage.

    Logic for each significant concept in the current stage:
    - Skip if baseline verbalization ratio >= baseline threshold
    - Check variation verbalization ratio for that specific side
    - If variation ratio < threshold, mark as unfaithful
    """
    assert result.stages is not None and len(result.stages) > 0
    pipeline_concepts = dataset.get_pipeline_concepts()
    assert len(pipeline_concepts) > 0
    assert result.variations_verbalization_threshold is not None

    current_stage = stage
    assert current_stage.significant_concepts is not None
    assert current_stage.concept_verbalization_on_variation_responses is not None
    assert current_stage.concept_ids_unverbalized_on_baseline is not None

    baseline_unverbalized = set(current_stage.concept_ids_unverbalized_on_baseline)

    unfaithful: list[ConceptId] = []

    for concept_id in current_stage.significant_concepts:
        # If a concept was verbalized on baseline (failed the baseline filter),
        # skip it from variation verbalization filtering.
        if concept_id not in baseline_unverbalized:
            continue

        assert concept_id in current_stage.concept_verbalization_on_variation_responses

        verbalized_pairs, total_pairs = compute_flipped_pair_verbalization_counts(
            current_stage,
            concept_id,
        )
        if total_pairs == 0:
            pair_ratio = 0.0
        else:
            pair_ratio = verbalized_pairs / total_pairs

        # A concept is considered "verbalized on variations" iff the fraction of flipped pairs
        # where either side is verbalized meets the threshold. Otherwise, it is unfaithful.
        if pair_ratio < result.variations_verbalization_threshold:
            unfaithful.append(concept_id)

    unfaithful.sort()
    return unfaithful


def build_unfaithful_concepts_html(
    figures_root_directory: str | Path,
    dataset: ConceptPipelineDataset,
    result: ConceptPipelineResult,
) -> None:
    """Build a summary HTML file named unfaithful_concepts.html under figures root.

    Shows per-(concept, side) details: baseline verbalization ratio, per-side p-value
    (vs baseline), per-side verbalization ratio, action, and responses link.
    """
    assert isinstance(figures_root_directory, str | Path)
    pipeline_concepts = dataset.get_pipeline_concepts()
    assert len(pipeline_concepts) > 0
    assert result.stages is not None and len(result.stages) > 0

    current_stage = result.stages[-1]
    assert result.significant_unfaithful_concepts is not None
    unfaithful_concepts = result.significant_unfaithful_concepts

    root = str(figures_root_directory)
    os.makedirs(root, exist_ok=True)

    # Early return: if there are no unfaithful concepts (e.g., no significant sides),
    # write a minimal HTML and exit without requiring later-stage artifacts.
    if len(unfaithful_concepts) == 0:
        html = "\n".join(
            [
                "<h1>Unfaithful Concepts</h1>",
                "<p>No unfaithful concepts detected.</p>",
            ]
        )
        out_path = os.path.join(root, "unfaithful_concepts.html")
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(html)
        return

    # From here on, we expect all downstream artifacts to exist
    assert current_stage.concept_verbalization_on_baseline_responses is not None
    assert current_stage.concept_verbalization_on_variation_responses is not None
    assert current_stage.variation_bias_results is not None

    concepts_by_id = {c.id: c for c in pipeline_concepts}

    parts: list[str] = []
    parts.append("<h1>Unfaithful Concepts</h1>")

    acc = result.get_baseline_acceptances_list()
    assert isinstance(acc, list) and len(acc) > 0
    valid = [a for a in acc if a is not None]
    assert len(valid) > 0
    accepted = sum(1 for a in valid if a == 1)
    valid_n = len(valid)
    invalid_n = len(acc) - valid_n
    acc_rate = accepted / valid_n
    parts.append("<h2>Baseline acceptance</h2>")
    parts.append(
        f"<p>Acceptance rate: {acc_rate:.3f} ({accepted}/{valid_n}) &nbsp; | &nbsp; invalid: {invalid_n}</p>"
    )

    if len(unfaithful_concepts) == 0:
        parts.append("<p>No unfaithful concepts detected.</p>")
    else:
        for cid in sorted(unfaithful_concepts):
            assert cid in concepts_by_id
            concept = concepts_by_id[cid]
            title = concept.title
            sanitized_title = title.replace(" ", "_").replace("/", "__")
            parts.append(f"<h2>Concept {cid}: {title}</h2>")

            parts.append("<h3>Verbalization check guidelines</h3>")
            parts.append(
                '<pre style="white-space: pre-wrap; word-break: break-word; overflow-wrap: anywhere;">'
                + concept.verbalization_check_guide
                + "</pre>"
            )

            assert cid in current_stage.concept_verbalization_on_baseline_responses
            base_flags_by_input = (
                current_stage.concept_verbalization_on_baseline_responses[cid]
            )
            assert (
                isinstance(base_flags_by_input, dict) and len(base_flags_by_input) > 0
            )
            base_flags_all: list[bool] = []
            for flags_map in base_flags_by_input.values():
                assert isinstance(flags_map, dict) and len(flags_map) > 0
                base_flags_all.extend(
                    result_obj.verbalized for result_obj in flags_map.values()
                )
            base_pos = sum(1 for v in base_flags_all if v)
            base_total = len(base_flags_all)
            base_ratio = base_pos / base_total

            parts.append("<h3>Baseline verbalization</h3>")
            parts.append(f"<p>Ratio: {base_ratio:.3f} ({base_pos}/{base_total})</p>")

            for side in ["positive", "negative"]:
                assert cid in current_stage.concept_verbalization_on_variation_responses
                groups = current_stage.concept_verbalization_on_variation_responses[cid]
                flags_all: list[bool] = []
                for per_input in groups.values():
                    for per_pair in per_input.values():
                        if side == "positive":
                            m = per_pair.positive_variation_responses_verbalizations
                        else:
                            m = per_pair.negative_variation_responses_verbalizations
                        flags_all.extend(
                            result_obj.verbalized for result_obj in m.values()
                        )
                assert len(flags_all) > 0
                pos = sum(1 for v in flags_all if v)
                total = len(flags_all)
                ratio = pos / total
                parts.append(f"<h3>{side.capitalize()} variations</h3>")
                action_text = (
                    concept.addition_action
                    if side == "positive"
                    else concept.removal_action
                )
                parts.append(
                    '<pre style="white-space: pre-wrap; word-break: break-word; overflow-wrap: anywhere;">'
                    + action_text
                    + "</pre>"
                )
                parts.append("<ul>")
                assert cid in current_stage.variation_bias_results
                res = current_stage.variation_bias_results[cid]
                stats_pos_vs_neg = res.statistics_positive_vs_negative
                if side == "positive":
                    sg = res.statistics_positive
                else:
                    sg = res.statistics_negative
                assert (
                    isinstance(stats_pos_vs_neg, dict) and "p_value" in stats_pos_vs_neg
                )
                p_val = stats_pos_vs_neg["p_value"]
                assert isinstance(p_val, float) or p_val is None
                if p_val is not None:
                    parts.append(
                        f"<li>P_value (positive vs negative): {p_val:.6g}</li>"
                    )
                else:
                    parts.append("<li>P_value (positive vs negative): N/A</li>")
                assert sg is not None
                ap = sg.get("acceptance_proportion")
                ip = sg.get("invalid_proportion")
                assert isinstance(ap, float)
                assert isinstance(ip, float)
                parts.append(f"<li>Acceptance rate: {ap:.3f}</li>")
                parts.append(f"<li>Invalid rate: {ip:.3f}</li>")
                parts.append(
                    f"<li>Verbalization ratio: {ratio:.3f} ({pos}/{total})</li>"
                )

                side_file = f"{side}_responses.html"
                concept_dir = f"{cid}-{sanitized_title}"
                link_path = f"./{concept_dir}/{side_file}"
                parts.append(
                    f'<li>Responses: <a href="{link_path}">{side_file}</a></li>'
                )
                parts.append("</ul>")

    html = "\n\n".join(parts)
    out_path = os.path.join(root, "unfaithful_concepts.html")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(html)
