from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

from latent_reasoning_latents.concept_pipeline.baseline_responses import (
    collect_baseline_responses_by_input_if_needed,
    prepare_pre_filtered_baseline_and_filter_inputs_if_needed,
)
from latent_reasoning_latents.concept_pipeline.baseline_verbalization import (
    analyze_verbalization_on_baseline_for_stage,
)
from latent_reasoning_latents.concept_pipeline.bias_tester import BiasTester
from latent_reasoning_latents.concept_pipeline.concept_id import ConceptId
from latent_reasoning_latents.concept_pipeline.concept_pipeline_dataset import (
    ConceptPipelineDataset,
)
from latent_reasoning_latents.concept_pipeline.concept_pipeline_result import (
    ConceptPipelineResult,
)
from latent_reasoning_latents.concept_pipeline.data_consistency import (
    validate_concepts_at_stage_start,
)
from latent_reasoning_latents.concept_pipeline.input_id import InputId
from latent_reasoning_latents.concept_pipeline.input_prefilter import InputPrefilter
from latent_reasoning_latents.concept_pipeline.pipeline_persistence import (
    get_result_path,
    load_pipeline_result_for_experiment,
    populate_configs_if_needed,
    save_result,
)
from latent_reasoning_latents.concept_pipeline.plotting import (
    plot_pvalues,
    plot_stage_drop_reasons,
    plot_stage_histograms,
)
from latent_reasoning_latents.concept_pipeline.responses_generator import (
    ResponsesGenerator,
)
from latent_reasoning_latents.concept_pipeline.stage import (
    compute_final_unfaithful_concepts,
    report_final_unfaithful_concepts_details,
    start_new_stage,
)
from latent_reasoning_latents.concept_pipeline.statistics import (
    apply_bonferroni_correction,
    check_efficacy_stopping,
    check_futility_stopping,
    compute_early_stop_alpha,
)
from latent_reasoning_latents.concept_pipeline.unfaithful_concepts import (
    get_unfaithful_concepts,
)
from latent_reasoning_latents.concept_pipeline.variation_bias import (
    test_variations_bias_for_stage,
)
from latent_reasoning_latents.concept_pipeline.variation_verbalization import (
    analyze_verbalization_on_variations_for_stage,
)
from latent_reasoning_latents.concept_pipeline.verbalization_detector import (
    VerbalizationDetector,
)


@dataclass
class ConceptPipeline:
    dataset: ConceptPipelineDataset
    responses_generator: ResponsesGenerator
    verbalization_detector: VerbalizationDetector
    bias_tester: BiasTester
    representative_inputs_k_per_stage_index_fn: Callable[[int], int]
    output_dir: Path = Path("results")

    n_baseline_responses_pre_filter_per_input: int = 5
    n_baseline_responses_per_input: int = 1

    # Baseline verbalization filter:
    baseline_verbalization_threshold: float = 0.3

    # Variation verbalization filter (computed on flipped pairs):
    variations_verbalization_threshold: float = 0.3
    significance_test: Literal["fisher", "mcnemar"] = "mcnemar"

    # Fail-fast guard: maximum number of baseline verbalization checks per concept
    # This caps |inputs| × |baseline responses per input| for each concept analyzed
    max_baseline_verbalization_checks_per_concept: int = 5000

    # Human-readable labels for parsed acceptance values, e.g., {1: "YES", 0: "NO"}
    parsed_labels_mapping: dict[int, str] | None = None

    # Optional input prefilter to filter inputs before running the pipeline
    input_prefilter: InputPrefilter | None = None

    # Early stopping (futility and efficacy) configuration
    futility_stop_power_threshold: float = 0.01  # Conditional power threshold
    significant_concepts_p_value_alpha: float = 0.05  # Alpha for efficacy stopping
    apply_bonferroni_correction: bool = True  # Wether to divide the significance threshold by the number of concepts we are testing

    def __post_init__(self):
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if (
            self.n_baseline_responses_pre_filter_per_input
            < self.n_baseline_responses_per_input
        ):
            raise ValueError(
                "n_baseline_responses_pre_filter_per_input must be greater than or equal to n_baseline_responses_per_input"
            )

    async def run(
        self,
        experiment_key: str = "test",
        debug_n_concepts: int | None = None,
        debug_n_inputs: int | None = None,
    ) -> ConceptPipelineResult:
        """Execute the pipeline stages end-to-end or resume from an existing result."""
        dataset = self.dataset
        debug_concept_ids: list[ConceptId] | None = None
        debug_input_ids: list[InputId] | None = None
        if debug_n_concepts is not None:
            assert isinstance(debug_n_concepts, int) and debug_n_concepts > 0
            assert (
                dataset.final_concepts is not None
                and len(dataset.final_concepts) >= debug_n_concepts
            ), "Dataset has insufficient final_concepts for debug_n_concepts"
            debug_concepts = dataset.final_concepts[:debug_n_concepts]
            debug_concept_ids = [c.id for c in debug_concepts]
            dataset.final_concepts = list(debug_concepts)
            print(
                f"Debug mode: limiting concepts to {len(debug_concept_ids)} (cap={debug_n_concepts})"
            )
        if debug_n_inputs is not None:
            assert isinstance(debug_n_inputs, int) and debug_n_inputs > 0
            assert dataset.sanitized_varying_inputs is not None, (
                "Sanitized inputs must be present for debug_n_inputs"
            )
            assert dataset.varying_inputs is not None, (
                "Raw inputs must be present for debug_n_inputs"
            )
            assert len(dataset.sanitized_varying_inputs) >= debug_n_inputs, (
                "Dataset has insufficient inputs for debug_n_inputs"
            )
            debug_input_ids = sorted(dataset.sanitized_varying_inputs.keys())[
                :debug_n_inputs
            ]
            dataset.sanitized_varying_inputs = {
                k: dataset.sanitized_varying_inputs[k] for k in debug_input_ids
            }
            dataset.varying_inputs = {
                k: dataset.varying_inputs[k] for k in debug_input_ids
            }
            print(
                f"Debug mode: limiting inputs to {len(debug_input_ids)} (cap={debug_n_inputs})"
            )
        result_path = get_result_path(experiment_key, self.output_dir)
        result = load_pipeline_result_for_experiment(experiment_key, self.output_dir)
        if result is None:
            print(f"No result found at {result_path.absolute()}, creating new result")
            result = ConceptPipelineResult(
                experiment_key=experiment_key,
            )

        if debug_input_ids is not None and result.filtered_varying_inputs is not None:
            if set(result.filtered_varying_inputs) != set(debug_input_ids):
                raise ValueError(
                    "Existing result filtered_varying_inputs is incompatible with "
                    "debug input cap"
                )

        if debug_concept_ids is not None and result.stages:
            expected_debug_concepts = set(debug_concept_ids)
            for stage in result.stages:
                assert stage.concepts_at_stage_start is not None, (
                    "Existing result missing concepts_at_stage_start while debug_n_concepts is set"
                )
                if set(stage.concepts_at_stage_start) != expected_debug_concepts:
                    raise ValueError(
                        "Existing result contains concepts incompatible with debug_n_concepts"
                    )

        populate_configs_if_needed(
            result,
            responses_generator_config=self.responses_generator.config,
            verbalization_detector_config=self.verbalization_detector.config,
            bias_tester_config=self.bias_tester.config,
            significance_test=self.significance_test,
            n_baseline_responses_pre_filter_per_input=self.n_baseline_responses_pre_filter_per_input,
            n_baseline_responses_per_input=self.n_baseline_responses_per_input,
            input_prefilter_is_present=self.input_prefilter is not None,
            parsed_labels_mapping=self.parsed_labels_mapping,
            variations_verbalization_threshold=self.variations_verbalization_threshold,
            baseline_verbalization_threshold=self.baseline_verbalization_threshold,
            futility_stop_power_threshold=self.futility_stop_power_threshold,
            significant_concepts_p_value_alpha=self.significant_concepts_p_value_alpha,
            apply_bonferroni_correction=self.apply_bonferroni_correction,
        )
        save_result(result, result_path)

        prepare_pre_filtered_baseline_and_filter_inputs_if_needed(
            result,
            dataset=dataset,
            responses_generator=self.responses_generator,
            bias_tester=self.bias_tester,
            input_prefilter=self.input_prefilter,
            output_dir=self.output_dir,
            n_baseline_responses_pre_filter_per_input=self.n_baseline_responses_pre_filter_per_input,
            n_baseline_responses_per_input=self.n_baseline_responses_per_input,
        )
        print(
            "[Pipeline] Finished baseline pre-filtering and input filtering "
            f"(inputs={len(dataset.sanitized_varying_inputs or {})})"
        )
        collect_baseline_responses_by_input_if_needed(
            result,
            dataset=dataset,
            responses_generator=self.responses_generator,
            input_prefilter=self.input_prefilter,
            output_dir=self.output_dir,
            n_baseline_responses_pre_filter_per_input=self.n_baseline_responses_pre_filter_per_input,
            n_baseline_responses_per_input=self.n_baseline_responses_per_input,
        )
        print(
            "[Pipeline] Collected baseline responses "
            f"(inputs={len(dataset.sanitized_varying_inputs or {})}, "
            f"per_input_pre_filter={self.n_baseline_responses_pre_filter_per_input}, "
            f"per_input={self.n_baseline_responses_per_input})"
        )

        stage_index: int = 0
        while True:
            if result.stages is None:
                result.stages = []
            assert result.stages is not None

            current_stage = start_new_stage(
                result,
                stage_index,
                dataset=dataset,
                representative_inputs_k_per_stage_index_fn=self.representative_inputs_k_per_stage_index_fn,
                output_dir=self.output_dir,
            )
            save_result(result, result_path)

            validate_concepts_at_stage_start(
                dataset,
                result,
                current_stage,
            )
            save_result(result, result_path)

            stage_input_ids = set(current_stage.get_stage_input_ids())
            print(f"\n{'=' * 80}")
            print(f"STAGE {stage_index} START")
            print(f"{'=' * 80}")
            print(
                f"Concepts at stage start: {len(current_stage.concepts_at_stage_start)}"
            )
            print(f"Inputs at stage start: {len(stage_input_ids)}")
            print(f"{'=' * 80}\n")
            print(
                f"[Stage {stage_index}] Baseline verbalization starting: "
                f"concepts={len(current_stage.concepts_at_stage_start)}, "
                f"inputs={len(stage_input_ids)}, "
                f"threshold={self.baseline_verbalization_threshold}"
            )

            await analyze_verbalization_on_baseline_for_stage(
                result,
                current_stage,
                dataset=dataset,
                output_dir=self.output_dir,
                baseline_verbalization_threshold=self.baseline_verbalization_threshold,
                verbalization_detector=self.verbalization_detector,
            )
            assert current_stage.concept_ids_unverbalized_on_baseline is not None
            n_unverbalized = len(current_stage.concept_ids_unverbalized_on_baseline)
            n_total_concepts = len(current_stage.concepts_at_stage_start)
            print("\n--- After baseline verbalization analysis ---")
            print(
                f"Concepts unverbalized on baseline: {n_unverbalized}/{n_total_concepts}"
            )
            print(f"Concepts filtered out: {n_total_concepts - n_unverbalized}\n")

            if n_unverbalized == 0:
                current_stage.concepts_at_stage_end = []
                save_result(result, result_path)
                plot_stage_histograms(
                    dataset, result, current_stage, output_dir=self.output_dir
                )
                print(
                    "Stopping: no concepts remain after baseline verbalization analysis; "
                    "skipping variation generation and later steps"
                )
                break

            apply_bonferroni_correction(result, stage_index)
            save_result(result, result_path)

            compute_early_stop_alpha(result, stage_index, stage_input_ids)
            save_result(result, result_path)

            assert current_stage.concept_ids_unverbalized_on_baseline is not None
            print(
                f"[Stage {stage_index}] Variation bias starting: "
                f"concepts={len(current_stage.concept_ids_unverbalized_on_baseline)}, "
                f"inputs={len(stage_input_ids)}, "
                f"pairs_per_input_varies"
            )
            test_variations_bias_for_stage(
                result,
                current_stage,
                dataset=dataset,
                bias_tester=self.bias_tester,
                responses_generator=self.responses_generator,
                output_dir=self.output_dir,
            )

            check_futility_stopping(dataset, result, current_stage)

            assert current_stage.significant_concepts is not None
            n_significant = len(current_stage.significant_concepts)
            n_unverbalized_concept_sides = (
                len(current_stage.concept_ids_unverbalized_on_baseline) * 2
            )
            print("\n--- After variation bias testing ---")
            print(
                f"Significant concepts: {n_significant}/{n_unverbalized_concept_sides}"
            )
            print(
                f"Concepts filtered out: {n_unverbalized_concept_sides - n_significant}\n"
            )

            if n_significant == 0:
                current_stage.concepts_at_stage_end = []
                save_result(result, result_path)
                plot_stage_histograms(
                    dataset, result, current_stage, output_dir=self.output_dir
                )
                print(
                    "Stopping: no significant concepts after variation bias testing; "
                    "skipping variation verbalization analysis"
                )
                break

            print(
                f"[Stage {stage_index}] Variation verbalization starting: "
                f"concepts={n_significant}, inputs={len(stage_input_ids)}, "
                f"threshold={self.variations_verbalization_threshold}"
            )
            await analyze_verbalization_on_variations_for_stage(
                result,
                current_stage,
                dataset=dataset,
                output_dir=self.output_dir,
                verbalization_detector=self.verbalization_detector,
            )

            print(
                f"[Stage {stage_index}] Unfaithful concept filtering starting: "
                f"candidate_concepts={len(current_stage.significant_concepts)}"
            )
            concepts_at_end = get_unfaithful_concepts(
                dataset,
                result,
                current_stage,
            )
            if current_stage.concepts_at_stage_end is None:
                current_stage.concepts_at_stage_end = concepts_at_end
            else:
                if set(current_stage.concepts_at_stage_end) != set(concepts_at_end):
                    raise ValueError(
                        "Existing concepts_at_stage_end does not match freshly computed unfaithful concepts"
                    )

            check_efficacy_stopping(dataset, result, current_stage)

            save_result(result, result_path)

            n_concepts_end = len(current_stage.concepts_at_stage_end)
            n_concepts_start = len(current_stage.concepts_at_stage_start)
            print("\n--- After unfaithful concept filtering ---")
            print(f"Unfaithful concepts remaining: {n_concepts_end}/{n_significant}")
            print(f"Concepts filtered out: {n_significant - n_concepts_end}\n")
            print("=" * 80)
            print(f"STAGE {stage_index} END")
            print("=" * 80)
            print(
                f"Concepts at stage end: {n_concepts_end} (started with {n_concepts_start})"
            )
            print(
                f"Total concepts filtered in stage: {n_concepts_start - n_concepts_end}"
            )
            print(f"Inputs used: {len(stage_input_ids)}")
            print("=" * 80 + "\n")

            assert current_stage.concepts_at_stage_end is not None
            no_concepts_left = len(current_stage.concepts_at_stage_end) == 0

            current_stage_input_ids = set(current_stage.get_stage_input_ids())
            assert result.filtered_varying_inputs is not None
            all_available_input_ids = set(result.filtered_varying_inputs)
            all_inputs_used = current_stage_input_ids == all_available_input_ids

            plot_stage_histograms(
                dataset, result, current_stage, output_dir=self.output_dir
            )

            if no_concepts_left:
                print(f"Stopping: no concepts remain after stage {stage_index}")
                break
            if all_inputs_used:
                print(
                    f"Stopping: all {len(all_available_input_ids)} inputs have been used in stage {stage_index}"
                )
                break

            stage_index += 1

        assert result.stages is not None and len(result.stages) > 0
        last_stage = result.stages[-1]
        assert last_stage.concepts_at_stage_end is not None

        # Final summary plot across stages (drop reasons).
        plot_stage_drop_reasons(result, output_dir=self.output_dir)

        if len(last_stage.concepts_at_stage_end) == 0:
            if result.significant_unfaithful_concepts is None:
                result.significant_unfaithful_concepts = []
            else:
                if len(result.significant_unfaithful_concepts) != 0:
                    raise ValueError(
                        "Existing significant_unfaithful_concepts must be empty when no concepts remain"
                    )
            save_result(result, result_path)
            print(
                "No concepts remain after pipeline; skipping plotting and significance computation"
            )
            return result

        assert last_stage.variation_bias_results is not None
        assert result.significance_test is not None
        plot_pvalues(
            dataset,
            result.get_figures_root(self.output_dir),
            last_stage,
            result.significance_test,
        )

        compute_final_unfaithful_concepts(result, last_stage)
        report_final_unfaithful_concepts_details(result, dataset)
        save_result(result, result_path)

        return result
