"""
Token analysis utilities for visualization generation.
"""

import pandas as pd
import numpy as np
from typing import Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats


def _usage_value(token_usage: dict, key: str, default=0):
    """
    Return ``token_usage[key]``, falling back to *default* when it is unusable.

    Missing keys, ``None`` and other falsy values fall back to *default*
    (matching the ``value or default`` idiom). Float NaN also falls back,
    because NaN is truthy and would otherwise slip through ``or``.
    """
    value = token_usage.get(key)
    if isinstance(value, float) and np.isnan(value):
        return default
    return value or default


def prepare_token_data(data: pd.DataFrame) -> pd.DataFrame:
    """
    Prepare evaluation data for token analysis by extracting token metrics.

    Rows whose ``token_usage`` is not a non-empty dict are dropped. Every
    remaining row keeps all of its original columns, plus the token fields
    from ``token_usage`` (counts default to 0, ``has_thinking_section`` to
    False, ``estimation_method`` to "unknown") and two derived metrics:

    - ``efficiency_score``: ``correct`` per 1k total tokens (0 if no tokens)
    - ``reasoning_efficiency``: ``correct`` per 100 reasoning tokens
      (0 if no reasoning tokens)

    Parameters:
    - data: DataFrame with evaluation results including ``token_usage``;
      rows that have token usage must also provide a numeric ``correct`` value

    Returns:
    - DataFrame with flattened token metrics (no columns if no row qualified)
    """
    token_rows = []

    for _, row in data.iterrows():
        token_usage = row.get("token_usage")
        # Type check first: truth-testing pd.NA raises TypeError, so the
        # emptiness test must only run once we know we have a dict.
        if not isinstance(token_usage, dict) or not token_usage:
            continue

        # Start from the original row so all evaluation columns are kept.
        token_row = row.to_dict()
        token_row.update(
            {
                "input_tokens": _usage_value(token_usage, "input_tokens"),
                "output_tokens": _usage_value(token_usage, "output_tokens"),
                "total_tokens": _usage_value(token_usage, "total_tokens"),
                "reasoning_tokens": _usage_value(token_usage, "reasoning_tokens"),
                "answer_tokens": _usage_value(token_usage, "answer_tokens"),
                "reasoning_ratio": _usage_value(token_usage, "reasoning_ratio"),
                "has_thinking_section": _usage_value(
                    token_usage, "has_thinking_section", False
                ),
                "estimation_method": _usage_value(
                    token_usage, "estimation_method", "unknown"
                ),
            }
        )

        # Derived metrics; guard the denominators against zero.
        if token_row["total_tokens"] > 0:
            token_row["efficiency_score"] = (
                token_row["correct"] / token_row["total_tokens"] * 1000
            )  # per 1k tokens
        else:
            token_row["efficiency_score"] = 0

        if token_row["reasoning_tokens"] > 0:
            token_row["reasoning_efficiency"] = (
                token_row["correct"] / token_row["reasoning_tokens"] * 100
            )  # per 100 reasoning tokens
        else:
            token_row["reasoning_efficiency"] = 0

        token_rows.append(token_row)

    return pd.DataFrame(token_rows)


def calculate_task_difficulty(data: pd.DataFrame) -> pd.DataFrame:
    """
    Calculate per-benchmark difficulty metrics from accuracy and token usage.

    For each ``benchmark``, aggregates ``correct`` (mean, count) and
    ``total_tokens`` / ``reasoning_tokens`` / ``reasoning_ratio`` (mean, std),
    rounded to 3 decimals, with flattened column names such as
    ``correct_mean`` or ``total_tokens_std``.

    ``difficulty_score`` = 0.6 * (1 - accuracy) + 0.4 * (mean total tokens
    normalised by the largest benchmark mean). If no benchmark used any
    tokens, the token term is 0 rather than NaN.

    Parameters:
    - data: DataFrame with token data (e.g. from ``prepare_token_data``);
      needs ``benchmark``, ``correct``, ``total_tokens``,
      ``reasoning_tokens`` and ``reasoning_ratio`` columns

    Returns:
    - DataFrame with task difficulty metrics, sorted by ``difficulty_score``
      descending; an empty DataFrame if any required column is missing
    """
    required_columns = (
        "benchmark",
        "correct",
        "total_tokens",
        "reasoning_tokens",
        "reasoning_ratio",
    )
    if any(col not in data.columns for col in required_columns):
        return pd.DataFrame()

    task_stats = (
        data.groupby("benchmark")
        .agg(
            {
                "correct": ["mean", "count"],
                "total_tokens": ["mean", "std"],
                "reasoning_tokens": ["mean", "std"],
                "reasoning_ratio": ["mean", "std"],
            }
        )
        .round(3)
    )

    # Flatten ("correct", "mean") -> "correct_mean", etc.
    task_stats.columns = ["_".join(col).strip() for col in task_stats.columns.values]
    task_stats = task_stats.reset_index()

    # Normalise mean token usage by the largest benchmark mean. A max of 0
    # (or all-NaN) would give 0/0 = NaN for every score, so fall back to 0.
    token_means = task_stats["total_tokens_mean"]
    token_max = token_means.max()
    if pd.notna(token_max) and token_max > 0:
        token_component = token_means / token_max
    else:
        token_component = 0.0

    # Lower accuracy + higher token usage = harder.
    task_stats["difficulty_score"] = (
        1 - task_stats["correct_mean"]
    ) * 0.6 + token_component * 0.4  # accuracy weight 0.6, token weight 0.4

    return task_stats.sort_values("difficulty_score", ascending=False)


def create_task_difficulty_analysis(
    data: pd.DataFrame, figsize: Tuple[int, int] = (16, 8)
) -> plt.Figure:
    """
    Create analysis of task difficulty vs reasoning token requirements.

    Draws one pair of bars per benchmark: accuracy (``correct_mean``, left
    axis) and average reasoning tokens (right axis). Benchmarks are ordered
    by ascending accuracy, so the lowest-accuracy tasks come first. When more
    than two benchmarks have non-constant values, the Pearson correlation
    between accuracy and reasoning tokens is annotated; a negative value
    means lower-accuracy tasks tend to use more reasoning tokens.

    Parameters:
    - data: DataFrame with token data (passed to ``calculate_task_difficulty``)
    - figsize: Figure size

    Returns:
    - matplotlib Figure (a message-only figure if there is no task data)
    """
    task_stats = calculate_task_difficulty(data)

    if task_stats.empty:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(
            0.5,
            0.5,
            "No task data available for analysis",
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=14,
        )
        ax.set_title("Task Difficulty vs Reasoning Requirements")
        return fig

    # Two y-axes sharing the x-axis: accuracy (left), reasoning tokens (right).
    fig, ax1 = plt.subplots(figsize=figsize)
    ax2 = ax1.twinx()

    # Order by ascending accuracy (hardest first). This replaces the
    # difficulty_score ordering returned by calculate_task_difficulty.
    task_stats = task_stats.sort_values("correct_mean")

    x_pos = np.arange(len(task_stats))

    # Accuracy bars (left y-axis), offset left of each tick.
    ax1.bar(
        x_pos - 0.2,
        task_stats["correct_mean"],
        width=0.4,
        label="Accuracy",
        color="skyblue",
        alpha=0.7,
    )

    # Average reasoning-token bars (right y-axis), offset right of each tick.
    ax2.bar(
        x_pos + 0.2,
        task_stats["reasoning_tokens_mean"],
        width=0.4,
        label="Avg Reasoning Tokens",
        color="lightcoral",
        alpha=0.7,
    )

    # Value labels above each bar. Skip NaN means: int(NaN) raises ValueError
    # and a "nan" label carries no information.
    for i, (acc, tokens) in enumerate(
        zip(task_stats["correct_mean"], task_stats["reasoning_tokens_mean"])
    ):
        if pd.notna(acc):
            ax1.text(
                i - 0.2, acc + 0.01, f"{acc:.3f}", ha="center", va="bottom", fontsize=8
            )
        if pd.notna(tokens):
            ax2.text(
                i + 0.2,
                tokens + tokens * 0.01,
                f"{int(tokens)}",
                ha="center",
                va="bottom",
                fontsize=8,
            )

    # Customize plot
    ax1.set_xlabel("Tasks (sorted by difficulty)")
    ax1.set_ylabel("Accuracy", color="blue")
    ax2.set_ylabel("Average Reasoning Tokens", color="red")
    ax1.set_title("Do Harder Tasks Require More Reasoning Tokens?")

    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(task_stats["benchmark"], rotation=45, ha="right")
    ax1.set_ylim(0, 1)

    # Combine legends from both axes and place at top center
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    fig.legend(
        handles1 + handles2,
        labels1 + labels2,
        loc="upper center",
        bbox_to_anchor=(0.5, 0.95),
        ncol=2,
    )
    # Adjust layout to reserve space above the axes for the legend and title
    fig.tight_layout(rect=[0, 0, 1, 0.85])

    # Pearson r is undefined for constant input (SciPy returns NaN with a
    # warning), so only annotate when both series vary across >2 benchmarks.
    paired = task_stats[["correct_mean", "reasoning_tokens_mean"]].dropna()
    if (
        len(paired) > 2
        and paired["correct_mean"].nunique() > 1
        and paired["reasoning_tokens_mean"].nunique() > 1
    ):
        r_value, p_value = stats.pearsonr(
            paired["correct_mean"], paired["reasoning_tokens_mean"]
        )
        # Put the box on the twin axis, which is rendered after ax1, so the
        # reasoning-token bars cannot cover it.
        ax2.text(
            0.02,
            0.98,
            f"Correlation: {r_value:.3f} (p={p_value:.3f})",
            transform=ax2.transAxes,
            va="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

    return fig


def create_question_type_reasoning_analysis(
    data: pd.DataFrame, figsize: Tuple[int, int] = (14, 8)
) -> plt.Figure:
    """
    Create analysis of reasoning effort by question type.

    The left panel shows a box plot of reasoning tokens per question type.
    The right panel shows a heatmap of the mean ``correct`` value per
    (question type, reasoning-token bucket) cell, keeping only cells with at
    least 3 samples. Only rows with a non-missing question type and a
    positive reasoning-token count are used.

    Parameters:
    - data: DataFrame with token data; expects ``question_type``,
      ``reasoning_tokens`` and ``correct`` columns
    - figsize: Figure size

    Returns:
    - matplotlib Figure (a message-only figure if the required data is missing)
    """

    def _message_figure(message: str) -> plt.Figure:
        # Single-axes figure that only displays *message* (missing-data case).
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(
            0.5,
            0.5,
            message,
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=14,
        )
        ax.set_title("Reasoning Effort by Question Type")
        return fig

    if "question_type" not in data.columns:
        return _message_figure("No question type data available")

    # Keep rows with positive reasoning tokens and a known question type.
    # NaN question types would break sorted() when mixed with strings.
    if "reasoning_tokens" in data.columns:
        mask = (
            (data["reasoning_tokens"] > 0)
            & data["reasoning_tokens"].notna()
            & data["question_type"].notna()
        )
        reasoning_data = data[mask].copy()
    else:
        reasoning_data = pd.DataFrame()

    if reasoning_data.empty:
        return _message_figure("No reasoning token data available")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Left plot: Reasoning tokens by question type
    question_types = sorted(reasoning_data["question_type"].unique())
    reasoning_by_qtype = [
        reasoning_data.loc[
            reasoning_data["question_type"] == qt, "reasoning_tokens"
        ].values
        for qt in question_types
    ]

    # Boxes sit at positions 1..n by default. Tick labels are set explicitly
    # because boxplot(labels=...) is deprecated (renamed to tick_labels in
    # Matplotlib 3.9); this form works on old and new versions.
    box_plot = ax1.boxplot(reasoning_by_qtype, patch_artist=True)
    ax1.set_xticks(np.arange(1, len(question_types) + 1))
    ax1.set_xticklabels(question_types)

    # Color the boxes
    colors = plt.cm.Set3(np.linspace(0, 1, len(question_types)))
    for patch, color in zip(box_plot["boxes"], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax1.set_xlabel("Question Type")
    ax1.set_ylabel("Reasoning Tokens")
    ax1.set_title("Reasoning Token Distribution by Question Type")
    ax1.tick_params(axis="x", rotation=45)

    # Right plot: Success rate by reasoning intensity.
    # Right-closed bins: (0, 500], (500, 1000], (1000, 2000], (2000, inf).
    reasoning_data["reasoning_category"] = pd.cut(
        reasoning_data["reasoning_tokens"],
        bins=[0, 500, 1000, 2000, float("inf")],
        labels=["Low (0-500)", "Medium (500-1000)", "High (1000-2000)", "Very High (2000+)"],
    )

    # observed=True: unobserved category combinations would have count 0 and
    # be filtered out below anyway; it also avoids pandas' FutureWarning.
    success_by_reasoning = (
        reasoning_data.groupby(["question_type", "reasoning_category"], observed=True)[
            "correct"
        ]
        .agg(["mean", "count"])
        .reset_index()
    )
    # Require at least 3 samples per cell for a meaningful success rate.
    success_by_reasoning = success_by_reasoning[success_by_reasoning["count"] >= 3]

    if not success_by_reasoning.empty:
        pivot_data = success_by_reasoning.pivot(
            index="question_type", columns="reasoning_category", values="mean"
        )

        sns.heatmap(
            pivot_data,
            annot=True,
            fmt=".3f",
            cmap="YlOrRd",
            ax=ax2,
            cbar_kws={"label": "Success Rate"},
        )
        ax2.set_title("Success Rate by Question Type and Reasoning Effort")
        ax2.set_xlabel("Reasoning Category")
        ax2.set_ylabel("Question Type")
    else:
        ax2.text(
            0.5,
            0.5,
            "Insufficient data for heatmap",
            ha="center",
            va="center",
            transform=ax2.transAxes,
        )
        ax2.set_title("Success Rate by Reasoning Effort")

    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    return fig
