# %%

import asyncio
import concurrent.futures

from dotenv import load_dotenv
from prompts_and_parse_functions import (
    UNIVERSITY_ADMISSION_INPUT_TEMPLATE,
    UNIVERSITY_VARIATIONS_GENERATION_PROMPT,
    UNIVERSITY_VARIATIONS_SYSTEM_PROMPT,
    parse_university_variations_response,
)

from latent_reasoning_latents.concept_pipeline.concept_deduplicator import (
    ConceptDeduplicator,
)
from latent_reasoning_latents.concept_pipeline.concept_pipeline_dataset_preparer import (
    ConceptPipelineDatasetPreparer,
)
from latent_reasoning_latents.concept_pipeline.implicit_concept_detector import (
    ImplicitConceptDetector,
)
from latent_reasoning_latents.concept_pipeline.input_clusterer import InputClusterer
from latent_reasoning_latents.concept_pipeline.variations_generator import (
    VariationsGenerator,
)
from latent_reasoning_latents.datasets.university_admission.loaders import (
    load_university_admission_dataset,
)
from latent_reasoning_latents.util import UNIVERSITY_ADMISSION_RESULTS_PATH

# Pull variables from a local .env file into the process environment
# (python-dotenv); presumably supplies credentials used by the pipeline
# components below — verify against the .env in use.
load_dotenv()

# %%

dataset_name = "university_admission"

# %%

# Keep only the first `max_applications` records returned by the loader.
max_applications = 1500
all_applications = load_university_admission_dataset()
student_applications = all_applications[:max_applications]
print(f"Loaded {len(student_applications)} student applications")

# %%

# Pipeline components: prompt-driven variation generation, implicit concept
# detection (12 concepts requested per input), input clustering, and concept
# deduplication.
variations_generator = VariationsGenerator(
    system_prompt=UNIVERSITY_VARIATIONS_SYSTEM_PROMPT,
    generation_prompt=UNIVERSITY_VARIATIONS_GENERATION_PROMPT,
    result_parser=parse_university_variations_response,
)
implicit_concept_detector = ImplicitConceptDetector(num_concepts_per_input=12)
input_clusterer = InputClusterer()
concept_deduplicator = ConceptDeduplicator()

input_template = UNIVERSITY_ADMISSION_INPUT_TEMPLATE
# Explicit checks instead of `assert`: asserts are stripped under `python -O`,
# which would let the pipeline run with a broken template or no inputs.
if not isinstance(input_template, str):
    raise TypeError(
        "UNIVERSITY_ADMISSION_INPUT_TEMPLATE must be a str, "
        f"got {type(input_template).__name__}"
    )
if not input_template:
    raise ValueError("UNIVERSITY_ADMISSION_INPUT_TEMPLATE must not be empty")

# NOTE(review): presumably the placeholder name used inside input_template —
# confirm against UNIVERSITY_ADMISSION_INPUT_TEMPLATE.
varying_input_param_name = "student_application"

# Build varying inputs from application summaries
varying_inputs = [record.application_summary for record in student_applications]
if not varying_inputs:
    raise ValueError("No student applications loaded; varying_inputs is empty")

# No fixed parameters besides the varying student application
input_parameters: dict[str, str] = {}
# %%

# Wire every component into the dataset preparer. Results go to the
# directory passed as `output_dir`; model labels 1/0 are reported as
# "Admitted"/"Rejected".
output_dir = UNIVERSITY_ADMISSION_RESULTS_PATH
admission_label_names = {1: "Admitted", 0: "Rejected"}
preparer = ConceptPipelineDatasetPreparer(
    variations_generator=variations_generator,
    implicit_concept_detector=implicit_concept_detector,
    input_clusterer=input_clusterer,
    concept_deduplicator=concept_deduplicator,
    parsed_labels_mapping=admission_label_names,
    output_dir=output_dir,
)

# %%

def _run_prepare_dataset():
    """Run the dataset-preparation coroutine to completion on a new event loop."""
    return asyncio.run(
        preparer.prepare_dataset(
            input_template,
            input_parameters,
            varying_input_param_name,
            varying_inputs=varying_inputs,
            dataset_name=dataset_name,
        )
    )


# Detect whether an event loop is already running in this thread (e.g. Jupyter).
# Only the probe lives inside the try: a RuntimeError raised by the pipeline
# itself must propagate rather than be mistaken for "no running event loop".
try:
    asyncio.get_running_loop()
except RuntimeError:
    _loop_is_running = False
else:
    _loop_is_running = True

if _loop_is_running:
    # asyncio.run() cannot be called while another loop runs in the same
    # thread, so run it on a dedicated worker thread and wait for the result.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as ex:
        dataset = ex.submit(_run_prepare_dataset).result()
else:
    # No running event loop (plain `python script.py`) – safe to run directly.
    dataset = _run_prepare_dataset()

# %%

# Print deduplication statistics
print("\n" + "=" * 80)
print("DEDUPLICATION STATISTICS")
print("=" * 80)

if dataset.concepts is not None:
    print(f"\nOriginal concepts: {len(dataset.concepts)}")

if dataset.deduplicated_concepts is not None:
    print(f"Deduplicated concepts: {len(dataset.deduplicated_concepts)}")

if dataset.concepts is not None and dataset.deduplicated_concepts is not None:
    removed = len(dataset.concepts) - len(dataset.deduplicated_concepts)
    print(f"Concepts removed: {removed}")

if dataset.duplicate_groups is not None and len(dataset.duplicate_groups) > 0:
    print(f"\nDuplicate groups found: {len(dataset.duplicate_groups)}")

    # Index the ORIGINAL concepts by ID once (duplicate_groups includes
    # concepts that deduplication removed). setdefault keeps the first concept
    # for a repeated ID, matching the previous first-match linear search,
    # while avoiding a full scan of dataset.concepts per group member.
    concepts_by_id = {}
    for c in dataset.concepts or []:
        concepts_by_id.setdefault(c.id, c)

    for i, group in enumerate(dataset.duplicate_groups, 1):
        print("\n" + "━" * 80)
        print(f"DUPLICATE GROUP {i} - {len(group)} CONCEPTS")
        print("━" * 80)

        for concept_id in group:
            c = concepts_by_id.get(concept_id)
            if c is None:
                # IDs absent from dataset.concepts are skipped, as before.
                continue
            print(f"\n  Concept ID: {concept_id}")
            print(f"  Title: {c.title}")
            print(f"  Verbalization Check Guide: {c.verbalization_check_guide}")
            print(f"  Removal Action: {c.removal_action}")
            print(f"  Addition Action: {c.addition_action}")
            print()

        print()  # Extra spacing after each group
else:
    print("\nNo duplicate groups found")

if dataset.deduplication_comparisons is not None:
    total_comparisons = len(dataset.deduplication_comparisons)
    duplicate_comparisons = sum(
        1 for c in dataset.deduplication_comparisons if c.duplicate
    )
    different_comparisons = total_comparisons - duplicate_comparisons

    print(f"\nTotal pairwise comparisons: {total_comparisons}")
    print(f"  - Marked as duplicates: {duplicate_comparisons}")
    print(f"  - Marked as different: {different_comparisons}")
    if total_comparisons > 0:
        dup_pct = (duplicate_comparisons / total_comparisons) * 100
        print(f"  - Duplicate rate: {dup_pct:.1f}%")

print("\n" + "=" * 80)

# %%
