# %%

import asyncio
import concurrent.futures

from dotenv import load_dotenv
from prompts_and_parse_functions import parse_cot_response

from latent_reasoning_latents.concept_pipeline.bias_tester import BiasTester
from latent_reasoning_latents.concept_pipeline.concept_pipeline import ConceptPipeline
from latent_reasoning_latents.concept_pipeline.concept_pipeline_dataset import (
    ConceptPipelineDataset,
)
from latent_reasoning_latents.concept_pipeline.responses_generator import (
    ResponsesGenerator,
)
from latent_reasoning_latents.concept_pipeline.verbalization_detector import (
    VerbalizationDetector,
)
from latent_reasoning_latents.util import UNIVERSITY_ADMISSION_RESULTS_PATH

# Populate os.environ from a local .env file (presumably provider API
# credentials for the model/judge services used below — confirm).
load_dotenv()

# %%
# Model under test. Other models previously configured in this script:
#   "google/gemma-3-27b-it", "qwen/qwq-32b", "google/gemini-2.5-flash"
model_name = "anthropic/claude-sonnet-4-20250514"

dataset_name = "university_admission"
# Experiment key is "<dataset>_<model id without its provider prefix>".
experiment_key = f"{dataset_name}_{model_name.rpartition('/')[2]}"
# Optional caps for quick debug runs (e.g. 5 concepts / 200 inputs);
# None presumably means "no cap" — confirm against ConceptPipeline.run.
debug_n_concepts = None
debug_n_inputs = None

# %%

# Load the dataset by name from the admissions results directory. The same
# directory is also passed to the pipeline as its output directory further down.
output_dir = UNIVERSITY_ADMISSION_RESULTS_PATH
dataset = ConceptPipelineDataset.load_by_name(dataset_name, output_dir)
assert isinstance(dataset, ConceptPipelineDataset)

# %%

# 1. Instantiate pipeline components
# Remote generation; Anthropic's batch API is used only for "anthropic/..." models.
responses_generator = ResponsesGenerator(
    model_name=model_name,
    use_anthropic_batches=model_name.startswith("anthropic"),
    local_generation=False,
)

verbalization_detector = VerbalizationDetector(use_openai_batches=True)

# Bias tester: parsed decisions are 1 = admit, 0 = reject. The input field it
# varies is presumably the application text itself — confirm in BiasTester.
varying_input_param_name = "student_application"
bias_tester = BiasTester(
    variating_input_parameter=varying_input_param_name,
    parse_response_fn=parse_cot_response,
    responses_generator=responses_generator,
)


def representative_inputs_k_per_stage_index_fn(stage_index: int) -> int:
    """Return how many representative inputs (k) to use at a pipeline stage.

    The count starts at 20 for stage 0 and doubles with each subsequent
    stage, i.e. ``20 * 2**stage_index``.

    Args:
        stage_index: Zero-based index of the pipeline stage; must be >= 0.

    Returns:
        The number of representative inputs for that stage.

    Raises:
        ValueError: If ``stage_index`` is negative.
    """
    if stage_index < 0:
        # A negative stage has no meaningful k; fail loudly instead of
        # recursing without a base case.
        raise ValueError(f"stage_index must be non-negative, got {stage_index}")
    # Closed form of "start at 20, double each stage" — O(1), no recursion.
    return 20 * 2**stage_index


# %%

# 2. Instantiate the main pipeline, wiring in the dataset and the
# components built above.
pipeline = ConceptPipeline(
    dataset=dataset,
    output_dir=output_dir,
    responses_generator=responses_generator,
    verbalization_detector=verbalization_detector,
    bias_tester=bias_tester,
    representative_inputs_k_per_stage_index_fn=representative_inputs_k_per_stage_index_fn,
    # Display names for the parsed decision values (1 = admit, 0 = reject).
    parsed_labels_mapping={1: "Admitted", 0: "Rejected"},
)

# %%

# 3. Run the pipeline
#
# `pipeline.run(...)` is a coroutine. If this code executes inside an
# already-running event loop (e.g. a Jupyter kernel), `asyncio.run` refuses to
# start a nested loop, so we run it on a fresh loop in a worker thread.
# Otherwise (plain `python script.py`) we run it directly.


def _event_loop_is_running() -> bool:
    """Return True if an asyncio event loop is running in the current thread."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Raised by get_running_loop() exactly when no loop is running.
        return False
    return True


def _run_pipeline_blocking():
    """Run the pipeline to completion on a new event loop and return its result.

    The coroutine is created inside this function, i.e. in whichever thread
    calls it, so it works both directly and from a worker thread.
    """
    return asyncio.run(
        pipeline.run(
            experiment_key=experiment_key,
            debug_n_concepts=debug_n_concepts,
            debug_n_inputs=debug_n_inputs,
        )
    )


if _event_loop_is_running():
    # Run on a separate thread with its own loop to avoid the
    # "cannot be called from a running event loop" error. Future.result()
    # blocks until done and re-raises any pipeline exception here, so
    # failures are never silently swallowed.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        pipeline_result = executor.submit(_run_pipeline_blocking).result()
else:
    # No running event loop – safe to run directly.
    pipeline_result = _run_pipeline_blocking()

# %%
