# %%

import asyncio
import concurrent.futures

from dotenv import load_dotenv
from prompts_and_parse_functions import parse_cot_response

from latent_reasoning_latents.concept_pipeline.bias_tester import BiasTester
from latent_reasoning_latents.concept_pipeline.concept_pipeline import ConceptPipeline
from latent_reasoning_latents.concept_pipeline.concept_pipeline_dataset import (
    ConceptPipelineDataset,
)
from latent_reasoning_latents.concept_pipeline.responses_generator import (
    ResponsesGenerator,
)
from latent_reasoning_latents.concept_pipeline.verbalization_detector import (
    VerbalizationDetector,
)
from latent_reasoning_latents.util import LOAN_APPROVAL_RESULTS_PATH

# Load variables from a local .env file into the process environment
# (presumably API credentials for the remote model providers — confirm).
load_dotenv()

# %%
# Model under evaluation. Only one assignment should be active; the commented
# lines are alternatives kept for quick switching. In run_pipeline_for_seed the
# provider prefix ("anthropic" / "openai") selects that provider's batch API,
# and the part after the last "/" becomes part of each experiment key.
model_name: str = "google/gemma-3-12b-it"
# model_name = "google/gemma-3-27b-it"
# model_name = "google/gemini-2.5-flash"
# model_name = "mistralai/mistral-small-3.1-24b-instruct"
# model_name = "qwen/qwq-32b"
# model_name = "anthropic/claude-sonnet-4-20250514"

# Dataset name passed to ConceptPipelineDataset.load_by_name; also the prefix
# of every experiment key.
dataset_name: str = "loan_approval"
# Optional caps forwarded to ConceptPipeline.run for quick debugging runs.
# NOTE(review): presumably None means "no cap" — confirm in ConceptPipeline.run.
debug_n_concepts = None  # e.g. 5 for a debug run
debug_n_inputs = None  # e.g. 200 for a debug run

# Number of seeds to run; the loop at the bottom runs seeds 0..num_seeds-1.
num_seeds: int = 3

# %%

# Load dataset. output_dir is both the location the dataset is loaded from and
# the output_dir handed to ConceptPipeline in run_pipeline_for_seed.
output_dir = LOAN_APPROVAL_RESULTS_PATH
dataset = ConceptPipelineDataset.load_by_name(dataset_name, output_dir)
# Explicit check rather than `assert`: it still runs under `python -O`, gives a
# readable error, and narrows the type for static checkers just the same.
if not isinstance(dataset, ConceptPipelineDataset):
    raise TypeError(
        f"Expected a ConceptPipelineDataset for dataset {dataset_name!r}, "
        f"got {type(dataset).__name__}"
    )

# %%


def representative_inputs_k_per_stage_index_fn(stage_index: int) -> int:
    """Return the number of representative inputs to use at pipeline stage ``stage_index``.

    The count starts at 20 for stage 0 and doubles at every subsequent stage:
    20, 40, 80, 160, ... i.e. ``20 * 2**stage_index``.

    Args:
        stage_index: Zero-based pipeline stage index.

    Returns:
        The number of representative inputs for that stage.

    Raises:
        ValueError: If ``stage_index`` is negative.
    """
    if stage_index < 0:
        # The previous recursive form never reached its base case for negative
        # indices and died with RecursionError; fail clearly instead.
        raise ValueError(f"stage_index must be non-negative, got {stage_index}")
    # Closed form of "start at 20, double each stage"; no recursion-depth limit.
    return 20 * 2**stage_index


def _run_coroutine_blocking(coro):
    """Run ``coro`` to completion and return its result, even under a running event loop.

    ``asyncio.run`` refuses to start while the current thread already has a
    running event loop (as in Jupyter / IPython cells). In that case the
    coroutine is run with ``asyncio.run`` on a separate worker thread (which gets
    its own loop) and this call blocks until it finishes. Any exception raised
    by the coroutine propagates to the caller unchanged.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread (e.g. `python script.py`): run directly.
        return asyncio.run(coro)

    # A loop is already running in this thread; run on a fresh thread instead.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coro).result()


def run_pipeline_for_seed(seed: int):
    """Run the full concept pipeline once, recording results as a separate "model" per seed.

    Fresh pipeline components are built for every call. Results are stored
    under the experiment key ``f"{dataset_name}_{model_suffix}-seed-{seed}"``,
    where ``model_suffix`` is the part of the module-level ``model_name`` after
    the last ``"/"``. Within this function ``seed`` is only used to build that
    key. NOTE(review): confirm ConceptPipeline derives any per-seed randomness
    from the experiment key.

    Reads the module-level configuration ``model_name``, ``dataset_name``,
    ``dataset``, ``output_dir``, ``debug_n_concepts`` and ``debug_n_inputs``.

    Args:
        seed: Seed index, embedded in the experiment key.

    Returns:
        Whatever ``ConceptPipeline.run`` returns.

    Raises:
        Any exception raised by ``ConceptPipeline.run`` propagates unchanged.
        This holds both from a plain script and from an environment with an
        already-running event loop.
    """
    # Create experiment key with seed suffix
    model_suffix = model_name.split("/")[-1]
    experiment_key = f"{dataset_name}_{model_suffix}-seed-{seed}"
    print(f"\n{'=' * 60}")
    print(f"Running pipeline for seed {seed}: {experiment_key}")
    print(f"{'=' * 60}\n")

    # 1. Instantiate pipeline components (fresh for each seed)
    responses_generator = ResponsesGenerator(
        model_name=model_name,
        local_generation=False,
        # Use a provider's batch API when the model id is prefixed with it.
        use_anthropic_batches=model_name.startswith("anthropic"),
        use_openai_batches=model_name.startswith("openai"),
    )

    verbalization_detector = VerbalizationDetector(use_openai_batches=True)

    # Bias tester using approval rate as acceptance (1=approve, 0=reject)
    bias_tester = BiasTester(
        responses_generator=responses_generator,
        parse_response_fn=parse_cot_response,
        variating_input_parameter="loan_application",
    )
    # Narrows the module-level `dataset` for type checkers inside this function.
    assert isinstance(dataset, ConceptPipelineDataset)

    # 2. Instantiate the main pipeline
    pipeline = ConceptPipeline(
        dataset=dataset,
        responses_generator=responses_generator,
        verbalization_detector=verbalization_detector,
        bias_tester=bias_tester,
        representative_inputs_k_per_stage_index_fn=representative_inputs_k_per_stage_index_fn,
        output_dir=output_dir,
        parsed_labels_mapping={1: "Approved", 0: "Rejected"},
    )

    # 3. Run the pipeline. This works both from a plain script and when an event
    # loop is already running (e.g. Jupyter), and never swallows pipeline errors.
    return _run_coroutine_blocking(
        pipeline.run(
            experiment_key=experiment_key,
            debug_n_concepts=debug_n_concepts,
            debug_n_inputs=debug_n_inputs,
        )
    )


# %%

# Run the pipeline for each seed; results are keyed by seed index.
all_results = {}
for seed in range(num_seeds):
    result = run_pipeline_for_seed(seed)
    all_results[seed] = result

# Summarize the experiment keys the runs were stored under. The format must
# match the experiment key built in run_pipeline_for_seed.
model_suffix = model_name.split("/")[-1]  # loop-invariant; compute once
print(f"\n{'=' * 60}")
print(f"Completed all {num_seeds} seeds!")
print("Results saved with experiment keys:")
for seed in all_results:
    print(f"  - {dataset_name}_{model_suffix}-seed-{seed}")
print(f"{'=' * 60}")

# %%
