from typing import Dict, List, Optional

import pandas as pd


def flatten_results(
    all_results: Dict, naming_map: Optional[Dict[str, str]] = None
) -> pd.DataFrame:
    """
    Flatten the nested results dictionary into one DataFrame row per result.

    Each top-level key of ``all_results`` maps to a dictionary that may hold an
    optional ``"info"`` entry plus any number of result sets, where every
    result set is an iterable of per-result dictionaries.

    Args:
        all_results: The nested results dictionary
        naming_map: Optional mapping from original result set names to display
                   names. When given, result sets missing from the map are
                   dropped and the remaining ones are renamed.

    Returns:
        DataFrame with the columns ``idx``, ``result_set``, ``result_index``
        and ``info``, followed by every field found in the individual results
    """
    rows = []

    for example_id, example in all_results.items():
        # The same info object is attached to every row of this example
        example_info = example.get("info", {})

        for set_name, entries in example.items():
            # "info" is metadata, not a result set
            if set_name == "info":
                continue

            if naming_map is None:
                # No mapping: keep everything under its original name
                label = set_name
            elif set_name in naming_map:
                label = naming_map[set_name]
            else:
                # A mapping was supplied and this result set is not part of it
                continue

            for position, entry in enumerate(entries):
                row = {
                    "idx": int(example_id),
                    "result_set": label,
                    "result_index": position,
                    "info": example_info,
                }
                # Copy the result's own fields verbatim; a field sharing a
                # name with one of the base columns above replaces it
                row.update(entry.items())
                rows.append(row)

    return pd.DataFrame(rows)


def filter_table(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    only_include_best_in_range: bool = True,
    min_threshold_fluency: Optional[float] = None,
    max_threshold_fluency: Optional[float] = None,
) -> pd.DataFrame:
    """
    Filter results based on specified criteria.

    Rows are first restricted to the half-open fluency range
    ``[min_threshold_fluency, max_threshold_fluency)``. If requested, only the
    row with the highest target metric is then kept for every
    ``(idx, result_set)`` pair.

    Args:
        df: DataFrame with evaluation results
        target_metric: The target metric to maximize when only_include_best_in_range is True
        fluency_metric: The fluency metric to filter on
        only_include_best_in_range: If True, keep only the point with maximum target metric value
        min_threshold_fluency: Inclusive minimum for the fluency metric (None means no minimum)
        max_threshold_fluency: Exclusive maximum for the fluency metric (None means no maximum)

    Returns:
        Filtered DataFrame. In best-in-range mode the rows keep their original
        index labels and column dtypes and are ordered by ``(idx, result_set)``;
        rows whose target metric is missing are never selected, so a group
        without any valid target value contributes no row.
    """
    # First apply fluency filtering if thresholds are specified
    if min_threshold_fluency is not None or max_threshold_fluency is not None:
        keep = pd.Series(True, index=df.index)
        if min_threshold_fluency is not None:
            keep &= df[fluency_metric] >= min_threshold_fluency
        if max_threshold_fluency is not None:
            keep &= df[fluency_metric] < max_threshold_fluency
        df = df[keep]

    # If we don't need to select best in range, return the filtered dataframe
    if not only_include_best_in_range:
        return df

    # Nothing to group: return an empty frame that keeps columns and dtypes
    if df.empty:
        return df.iloc[0:0]

    # A missing target value can never be the maximum. Dropping such rows up
    # front means idxmax() below always has at least one valid candidate
    # instead of yielding NaN (or raising) for an all-NaN group.
    scored = df.dropna(subset=[target_metric])
    if scored.empty:
        return df.iloc[0:0]

    # Work on positions rather than index labels so that duplicated labels
    # (e.g. after pd.concat) still resolve to exactly one row per group.
    by_position = scored.reset_index(drop=True)
    best_positions = [
        group[target_metric].idxmax()
        for _, group in by_position.groupby(["idx", "result_set"], observed=True)
    ]

    # Every remaining row had a missing grouping key
    if not best_positions:
        return df.iloc[0:0]

    # Slicing the original rows keeps their index labels and column dtypes
    return scored.iloc[best_positions]


def compute_result_set_comparisons(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
) -> Dict:
    """
    Computes comparison statistics between result sets.

    For each example, compares target_metric values across pairs of result sets.
    Assumes input DataFrame has already been filtered to have one value per method per example;
    if several rows exist for a pair, the first one is used.

    A value counts as missing when the example has no row for the result set,
    when the target_metric column does not exist, or when the stored value is
    NaN/NA. Missing values are reported as None and lose against any present value.

    Args:
        df: DataFrame with evaluation results (should be pre-filtered to have one value per method per example)
        target_metric: The target metric to compare

    Returns:
        Dictionary with two entries:
            "target_values": {idx: {result_set: value or None}}
            "comparisons": {rs1: {rs2: {"win_percentage", "win_count", "total"}}},
                where rs1 wins an example when its value is strictly higher than
                rs2's (or rs2's is missing). Examples where both values are
                missing are not counted. Self-comparisons are all zeros.
    """
    # Get unique examples and result sets
    examples = sorted(df["idx"].unique())
    result_sets = sorted(df["result_set"].unique())

    # Start with every (example, result set) pair marked as missing
    target_values = {
        idx: {result_set: None for result_set in result_sets} for idx in examples
    }

    # Fill in the values that exist, visiting each group once instead of
    # re-filtering the whole frame for every pair
    if target_metric in df.columns:
        grouped = df.groupby(["idx", "result_set"], sort=False, observed=True)
        for (idx, result_set), group in grouped:
            value = group[target_metric].iloc[0]
            # pandas stores missing data as NaN/NA rather than None; normalise
            # it so the comparison below really treats it as "no value"
            if pd.api.types.is_scalar(value) and pd.isna(value):
                continue
            target_values[idx][result_set] = value

    # Compute pairwise comparisons
    comparisons = {}

    for rs1 in result_sets:
        comparisons[rs1] = {}
        for rs2 in result_sets:
            if rs1 == rs2:
                # Explicit all-zero self-comparison keeps the structure uniform
                comparisons[rs1][rs2] = {
                    "win_percentage": 0,
                    "win_count": 0,
                    "total": 0,
                }
                continue

            win_count = 0
            total = 0

            for per_example in target_values.values():
                v1 = per_example.get(rs1)
                v2 = per_example.get(rs2)

                # Nothing to compare when neither result set has a value
                if v1 is None and v2 is None:
                    continue

                total += 1

                # rs1 wins when it is strictly higher; a missing value is
                # treated as lower than any present value
                if v1 is not None and (v2 is None or v1 > v2):
                    win_count += 1

            # Calculate percentage (or 0 if no valid comparisons)
            win_percentage = (win_count / total * 100) if total > 0 else 0
            comparisons[rs1][rs2] = {
                "win_percentage": win_percentage,
                "win_count": win_count,
                "total": total,
            }

    return {"target_values": target_values, "comparisons": comparisons}


def remove_entries(
    all_results: Dict[str, Dict], entries_to_remove: List[str]
) -> Dict[str, Dict]:
    """
    Return a copy of the results dictionary without the given top-level keys.

    The input dictionary is left untouched. The copy is shallow, so the values
    that remain are the same objects as in ``all_results``.

    Args:
        all_results: Dictionary containing all results
        entries_to_remove: Top-level keys to drop; keys that are not present are ignored

    Returns:
        Dictionary with specified entries removed
    """
    remaining = all_results.copy()
    for name in entries_to_remove:
        # pop with a default silently skips names that are absent
        remaining.pop(name, None)
    return remaining
