import argparse
import pandas as pd
import ast


def calculate_model_scores(evaluation_csv, dataset_csv):
    """Compute first-attempt pass percentages per model and difficulty bin.

    The evaluation CSV must provide ``InstanceId``, ``ModelName`` and
    ``K_Passes`` columns, where ``K_Passes`` holds a Python list literal
    (parsed with ``ast.literal_eval``). The dataset CSV must provide ``ID``
    and ``Difficulty`` columns. Evaluation rows are joined to their
    difficulty label via ``InstanceId == ID``; with pandas' default inner
    join, rows without a matching ``ID`` do not contribute to any score.

    Args:
        evaluation_csv: Path (or file-like object) of the evaluation results.
        dataset_csv: Path (or file-like object) of the dataset description.

    Returns:
        A dict keyed by model name, in order of first appearance. Each value
        is a dict with the keys ``'easy'``, ``'medium'``, ``'hard'`` and
        ``'overall'`` holding the mean of the first ``K_Passes`` entry scaled
        to a percentage. A bin with no rows yields NaN (mean of an empty
        Series).
    """
    def first_attempt(passes):
        # An empty result list is scored as a failed first attempt.
        if len(passes) > 0:
            return passes[0]
        return 0

    results = pd.read_csv(evaluation_csv)
    dataset = pd.read_csv(dataset_csv)

    results['K_Passes'] = results['K_Passes'].apply(ast.literal_eval)
    results['Pass1'] = results['K_Passes'].apply(first_attempt)

    labelled = pd.merge(
        results,
        dataset[['ID', 'Difficulty']],
        left_on='InstanceId',
        right_on='ID',
    )

    model_scores = {}
    for model in labelled['ModelName'].unique():
        rows = labelled.loc[labelled['ModelName'] == model]

        # Same key order as the chart expects: the three bins, then overall.
        summary = {
            level: rows.loc[rows['Difficulty'] == level, 'Pass1'].mean() * 100
            for level in ('easy', 'medium', 'hard')
        }
        summary['overall'] = rows['Pass1'].mean() * 100

        model_scores[model] = summary

    return model_scores


def get_top_models(model_scores, n=6):
    """Select the ``n`` models with the highest ``'overall'`` score.

    Args:
        model_scores: Mapping of model name to a score dict that contains an
            ``'overall'`` entry.
        n: Number of leading entries to keep after ranking.

    Returns:
        A new dict holding the kept entries ordered from highest to lowest
        overall score; models with equal scores keep their input order.
    """
    def overall_of(entry):
        _, scores = entry
        return scores['overall']

    ranked = list(model_scores.items())
    # list.sort is stable, so ties stay in their original relative order.
    ranked.sort(key=overall_of, reverse=True)
    return {name: scores for name, scores in ranked[:n]}


def format_model_name(model_name):
    """Return the short chart label for a model identifier.

    Identifiers listed in the lookup table get their hand-picked label; any
    other identifier is reduced to the text after its last ``'/'`` (or left
    unchanged when it contains no slash).

    Args:
        model_name: Model identifier string, e.g. ``'openai/o4-mini'``.

    Returns:
        The display label as a string.
    """
    # Derive the fallback first so a non-string argument fails the same way
    # regardless of whether it happens to be a known key.
    fallback = model_name.rsplit('/', 1)[-1]

    display_names = {
        'openai/chatgpt-4o-latest': 'GPT-4o',
        'openai/o4-mini': 'o4-mini',
        'google/gemini-2.5-pro': 'Gemini-Pro',
        'google/gemini-2.5-flash': 'Gemini-Flash',
        'anthropic/claude-sonnet-4': 'Claude-Sonnet',
        'anthropic/claude-opus-4': 'Claude-Opus',
        'qwen/qwen2.5-vl-72b-instruct': 'Qwen2.5-VL',
        'meta-llama/llama-4-maverick': 'Llama-4',
        'mistralai/mistral-medium-3': 'Mistral-Medium',
        'microsoft/phi-4-multimodal-instruct': 'Phi-4',
    }

    if model_name in display_names:
        return display_names[model_name]
    return fallback


def generate_tikz_code(top_models):
    model_names = list(top_models.keys())
    short_names = [format_model_name(name) for name in model_names]

    tikz_code = []
    tikz_code.append("\\begin{tikzpicture}")
    tikz_code.append("\\begin{axis}[")
    tikz_code.append("    ybar,")
    tikz_code.append("    bar width=4pt,")
    tikz_code.append("    enlarge x limits=0.18,")
    tikz_code.append("    enlarge y limits={upper,value=0.2},")
    tikz_code.append("    ylabel={Score},")
    tikz_code.append(f"    symbolic x coords={{{', '.join(short_names)}}},")
    tikz_code.append("    xtick=data,")
    tikz_code.append("    ymin=0, ymax=60,")
    tikz_code.append("    axis y line*=left,")
    tikz_code.append("    axis x line*=bottom,")
    tikz_code.append("    legend style={")
    tikz_code.append("        at={(0.95,1.10)},")
    tikz_code.append("        anchor=north east,")
    tikz_code.append("        font=\\scriptsize,")
    tikz_code.append("        draw=none,")
    tikz_code.append("        fill=white,")
    tikz_code.append("        fill opacity=0.8,")
    tikz_code.append("        text opacity=1,")
    tikz_code.append("        inner sep=4pt")
    tikz_code.append("    },")
    tikz_code.append("    legend columns=1,")
    tikz_code.append("    width=0.95\\linewidth,")
    tikz_code.append("    height=0.95\\linewidth,")
    tikz_code.append("    tick label style={font=\\scriptsize}")
    tikz_code.append("]")
    tikz_code.append("")

    easy_coords = []
    medium_coords = []
    hard_coords = []
    overall_coords = []
    overall_labels = []

    for i, (model_name, scores) in enumerate(top_models.items()):
        short_name = short_names[i]
        easy_coords.append(f"({short_name},{scores['easy']:.0f})")
        medium_coords.append(f"({short_name},{scores['medium']:.0f})")
        hard_coords.append(f"({short_name},{scores['hard']:.0f})")
        overall_coords.append(f"({short_name},{scores['overall']:.2f})")
        overall_labels.append(
            f"\\node at (axis cs:{short_name},{scores['overall']+2:.0f}) {{\\scriptsize {scores['overall']:.1f}}};")

    tikz_code.append(
        "\\addplot+[style={fill=green!30}] coordinates {" + " ".join(easy_coords) + "};")
    tikz_code.append("\\addlegendentry{Easy Bin}")
    tikz_code.append("")

    tikz_code.append(
        "\\addplot+[style={fill=yellow!40}] coordinates {" + " ".join(medium_coords) + "};")
    tikz_code.append("\\addlegendentry{Medium Bin}")
    tikz_code.append("")

    tikz_code.append(
        "\\addplot+[style={fill=red!30}] coordinates {" + " ".join(hard_coords) + "};")
    tikz_code.append("\\addlegendentry{Hard Bin}")
    tikz_code.append("")

    tikz_code.append(
        "\\addplot+[sharp plot, color=blue, mark=*, thick] coordinates {")
    tikz_code.append("    " + " ".join(overall_coords))
    tikz_code.append("};")
    tikz_code.append("\\addlegendentry{Overall Ability}")
    tikz_code.append("")

    for label in overall_labels:
        tikz_code.append(label)

    tikz_code.append("")
    tikz_code.append("\\end{axis}")
    tikz_code.append("\\end{tikzpicture}")

    return "\n".join(tikz_code)


def main():
    """Parse the command line, score the models and print the TikZ chart.

    Reads ``--evaluation-csv`` and ``--dataset-csv`` (both required) plus the
    optional ``--top-n`` (default 6), then writes the generated TikZ source
    to standard output.
    """
    cli = argparse.ArgumentParser(
        description="Generate TikZ code for model difficulty distribution")

    # Both CSV inputs are mandatory string paths; register them uniformly.
    required_paths = (
        ("--evaluation-csv", "Path to evaluation results CSV"),
        ("--dataset-csv", "Path to dataset CSV with difficulty labels"),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)

    cli.add_argument(
        "--top-n", type=int, default=6,
        help="Number of top models to include")

    options = cli.parse_args()

    all_scores = calculate_model_scores(
        options.evaluation_csv, options.dataset_csv)
    best = get_top_models(all_scores, options.top_n)
    print(generate_tikz_code(best))


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
