import numpy as np

from src.code_evaluation.executor_standard import run_solutions
from src.model_querying.prompt_creation import (
    evaluate_single_test_case,
    generate_many_test_cases,
    get_difficulty_problem_solution_execution,
)
from src.model_querying.solution_extraction import (
    extract_difficulty,
    extract_single_test_case_evaluation_result,
    extract_test_cases_for_exec,
)
from src.pipeline.transformations import separate_unit_tests


def merge_unit_tests(data):
    joined_data = {}
    for key, value in data.items():
        problem_id = "-".join(key.split("-")[:-1])
        if problem_id not in joined_data:
            joined_data[problem_id] = value
        else:
            joined_data[problem_id]["test_cases"] += value["test_cases"]
    return joined_data


def average_difficulty(data):
    joined_data = {}
    for key, value in data.items():
        problem_id = "-".join(key.split("-")[:-1])
        if problem_id not in joined_data:
            joined_data[problem_id] = value
            joined_data[problem_id]["difficulty"] = [value["difficulty"]]
        else:
            joined_data[problem_id]["difficulty"].append(value["difficulty"])
    for problem_id in joined_data:
        joined_data[problem_id]["difficulty"] = np.mean(
            joined_data[problem_id]["difficulty"]
        )
    return joined_data


def add_execution_difficulty_estimation_to_pipeline(
    pipeline, strong_model, weak_model, get_solutions, team_suffix=None
):
    team = "blue"
    if team_suffix:
        team += f"_{team_suffix}"
    # generate unit tests
    generate_test_cases = pipeline.add_query_step(
        "generate_test_cases",
        strong_model,
        generate_many_test_cases,
        extract_test_cases_for_exec,
        dependencies=[get_solutions],
        team=team,
    )
    # execute them
    execute_generated_test_cases = pipeline.add_code_evaluation_step(
        "execute_gen_test_cases",
        weak_model,
        run_solutions,
        dependencies=[generate_test_cases],
    )
    split_by_unit_tests = pipeline.add_transformation_step(
        "split_by_unittests",
        separate_unit_tests,
        dependencies=[get_solutions, execute_generated_test_cases],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    evaluate_generated_test_cases = pipeline.add_query_step(
        "evaluate_generated_test_cases",
        weak_model,
        evaluate_single_test_case,
        extract_single_test_case_evaluation_result,
        dependencies=[split_by_unit_tests],
        team=team,
    )
    # merge evaluation results
    merged_evaluate_generated_test_cases = pipeline.add_transformation_step(
        "merge_unit_tests",
        merge_unit_tests,
        dependencies=[evaluate_generated_test_cases],
        strong_model=strong_model,
        weak_model=weak_model,
    )
    get_difficulty = pipeline.add_query_step(
        "get_difficulty",
        weak_model,
        get_difficulty_problem_solution_execution,
        extract_difficulty,
        dependencies=[merged_evaluate_generated_test_cases],
        team=team,
    )

    return get_difficulty
