import os
import argparse
import ast
import pandas as pd
from typing import Dict, List


def parse_k_passes(k_passes_str):
    """Turn a K_Passes cell into a sequence of per-attempt results.

    Args:
        k_passes_str: Either the textual form of a Python list/tuple
            (e.g. ``"[1, 0, 1]"``) or an already-parsed list/tuple.

    Returns:
        The parsed (or passed-through) list/tuple. An empty list is
        returned for anything else: unparsable text, text that evaluates
        to a non-sequence such as ``"3"``, ``None``, or a float NaN from a
        missing CSV cell. Callers apply ``sum()`` and ``len()`` to the
        result, so a scalar must never leak through.
    """
    if isinstance(k_passes_str, str):
        try:
            # strip() avoids the indentation error older interpreters raise
            # for text with leading whitespace.
            parsed = ast.literal_eval(k_passes_str.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval documents all of these for malformed input.
            return []
    else:
        parsed = k_passes_str

    if isinstance(parsed, (list, tuple)):
        return parsed
    return []


def get_middle_percentile_models(results_df: pd.DataFrame) -> List[str]:
    """Return the models whose overall pass rate sits in the middle of the ranking.

    Each distinct ``ModelName`` is scored as the number of successful
    entries divided by the total number of entries across all of its
    ``K_Passes`` cells (0 when it has no entries). Models are sorted by
    ascending score, and those between the 20% and 80% positions of that
    ordering are returned.

    Args:
        results_df: Frame with ``ModelName`` and ``K_Passes`` columns.

    Returns:
        Model names in ascending score order. When the 20%-80% band is
        empty (this happens with a single model, where both bounds round
        down to 0), every ranked model is returned instead, so downstream
        difficulty scoring still has results to work with.
    """
    model_scores = {}

    for model in results_df['ModelName'].unique():
        model_rows = results_df[results_df['ModelName'] == model]
        total_correct = 0
        total_attempts = 0

        for raw_passes in model_rows['K_Passes']:
            k_passes = parse_k_passes(raw_passes)
            if k_passes:
                total_correct += sum(k_passes)
                total_attempts += len(k_passes)

        model_scores[model] = (
            total_correct / total_attempts if total_attempts > 0 else 0
        )

    # sorted() is stable, so ties keep their order of first appearance.
    ranked = [
        model
        for model, _ in sorted(model_scores.items(), key=lambda item: item[1])
    ]
    total_models = len(ranked)

    start_idx = int(total_models * 0.20)
    end_idx = int(total_models * 0.80)
    middle_models = ranked[start_idx:end_idx]

    if not middle_models:
        # Too few models for a meaningful band: use all of them rather than
        # returning nothing.
        return ranked

    return middle_models


def calculate_instance_difficulty(instance_id: str, results_df: pd.DataFrame, middle_models: List[str]) -> int:
    """Rate one instance as 0, 1 or 2 (labelled easy, medium, hard elsewhere in this file).

    Only rows of ``results_df`` whose ``InstanceId`` equals ``instance_id``
    and whose ``ModelName`` is listed in ``middle_models`` are considered.
    Each such row is scored from the success rate of its ``K_Passes``
    entries, and the mean of those scores is bucketed into the final rating.

    Args:
        instance_id: Value to match against the ``InstanceId`` column.
        results_df: Frame with ``InstanceId``, ``ModelName`` and
            ``K_Passes`` columns.
        middle_models: Model names whose results should count.

    Returns:
        0 when the mean row score is at most 0.8, 1 when it is at most
        1.5, otherwise 2. Also 2 when no row matches.
    """
    def _row_score(raw_passes) -> int:
        # Score a single result row: 0 = mostly solved, 1 = sometimes, 2 = never.
        attempts = parse_k_passes(raw_passes)
        if not attempts:
            # No usable attempts counts as a failure.
            return 2
        rate = sum(attempts) / len(attempts)
        if rate >= 0.8:
            return 0
        return 1 if rate > 0 else 2

    matches_instance = results_df['InstanceId'] == instance_id
    from_middle_model = results_df['ModelName'].isin(middle_models)
    relevant = results_df[matches_instance & from_middle_model]

    if relevant.empty:
        return 2

    # `relevant` is non-empty here, so there is at least one score to average.
    scores = [_row_score(record['K_Passes']) for _, record in relevant.iterrows()]
    average = sum(scores) / len(scores)

    if average > 1.5:
        return 2
    return 0 if average <= 0.8 else 1


def difficulty_to_label(difficulty: int) -> str:
    """Translate a numeric difficulty into its text label.

    Args:
        difficulty: 0, 1 or 2.

    Returns:
        ``"easy"`` for 0, ``"medium"`` for 1 and ``"hard"`` for 2 or any
        other value.
    """
    labels = {2: "hard", 1: "medium", 0: "easy"}
    # Unknown values are treated as the hardest tier.
    return labels[difficulty] if difficulty in labels else "hard"


def calculate_dataset_difficulty(dataset_csv: str, results_csv: str, output_csv: str):
    """Label every dataset row with a difficulty and write the result to CSV.

    Args:
        dataset_csv: Path of the dataset CSV; it must have an ``ID`` column.
        results_csv: Path of the model results CSV.
        output_csv: Path where the dataset, extended with a ``Difficulty``
            column of text labels, is written without the index.
    """
    dataset = pd.read_csv(dataset_csv)
    results = pd.read_csv(results_csv)

    # The reference models are chosen once and reused for every instance.
    reference_models = get_middle_percentile_models(results)

    labels = [
        difficulty_to_label(
            calculate_instance_difficulty(record['ID'], results, reference_models)
        )
        for _, record in dataset.iterrows()
    ]

    dataset.assign(Difficulty=labels).to_csv(output_csv, index=False)


def main():
    """Read the three CSV paths from the command line and run the labelling."""
    cli = argparse.ArgumentParser()
    # Every path is a mandatory string option.
    for flag in ('--dataset-csv', '--results-csv', '--output-csv'):
        cli.add_argument(flag, type=str, required=True)

    options = cli.parse_args()

    calculate_dataset_difficulty(
        options.dataset_csv,
        options.results_csv,
        options.output_csv,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
