import json
import re
from pathlib import Path
from textwrap import dedent
from typing import Dict, List, Optional, Union

import pandas as pd
from scipy import stats
from tabulate import tabulate

from eliciting_contexts.benchmark.external.shared.utils import (
    compute_result_set_comparisons,
    filter_table,
    flatten_results,
    remove_entries,
)
from eliciting_contexts.utils.constants import BENCHMARK_RESULTS_DIR


def latex_from_table(
    table_data,
    headers,
    caption=(
        r"\textbf{Inpainting Stories Win Percentages.} Each cell gives the percentage "
        r"of stories in which the \emph{row} method achieves a better logit "
        r"diff than the \emph{column} method, \textit{when considering output "
        r"in the 3–9 cross entropy range}. (GCG not shown as none of its "
        r"outputs fall in this range.)"
    ),
    label="tab:results:stories_task",
):
    r"""
    Build a LaTeX ``table`` environment from `table_data` and `headers`.

    A trailing numeric count in parentheses is stripped from every cell
    (e.g. '39.4% (66)' becomes '39.4\%'), percent signs are escaped, and the
    row/column order is kept unchanged. Header cells and the first cell of
    each body row are set in bold.

    Parameters
    ----------
    table_data : list[list]
        Table body, each inner list is a row. First element is the row label.
        Cells are converted with ``str()`` before cleaning, so non-string
        values (e.g. integers) are accepted.
    headers : list[str]
        Column titles. headers[0] must label the row-label column.
    caption : str, optional
        Text for \caption{...}. Inserted verbatim (no escaping is applied).
    label : str, optional
        Identifier for \label{...}. Defaults to 'tab:results:stories_task'.

    Returns
    -------
    str
        A complete LaTeX table environment ready to paste, with no leading
        indentation on any line.
    """
    n_cols = len(headers)
    col_spec = "|".join(["c"] * n_cols)

    def clean(cell) -> str:
        # Drop a trailing " (number)" count, then escape the percent sign.
        text = re.sub(r"\s*\(\d+\)\s*$", "", str(cell))
        return text.replace("%", r"\%")

    header_row = " & ".join(rf"\textbf{{{clean(h)}}}" for h in headers) + r" \\"

    # The table is assembled line by line rather than by dedenting an
    # interpolated template: multi-line substitutions start at column 0 and
    # would otherwise prevent the common indentation from being removed.
    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\renewcommand{\arraystretch}{1.3}",
        r"\setlength{\tabcolsep}{4pt}",
        rf"\begin{{tabular}}{{|{col_spec}|}}",
        rf"\multicolumn{{{n_cols}}}{{c}}{{\small\textit{{Row beats Column (\%)}}}}\\[2pt]",
        r"\hline",
        header_row,
        r"\hline",
    ]

    for row in table_data:
        cells = [
            rf"\textbf{{{clean(cell)}}}" if i == 0 else clean(cell)
            for i, cell in enumerate(row)
        ]
        lines.append(" & ".join(cells) + r" \\")
        lines.append(r"\hline")  # horizontal rule after every body row

    lines.extend(
        [
            r"\end{tabular}",
            rf"\caption{{{caption}}}",
            rf"\label{{{label}}}",
            r"\end{table}",
        ]
    )

    return "\n".join(lines).strip()


def create_comparison_table(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    result_sets: Optional[List[str]] = None,
    latex_output_path: Optional[Union[str, Path]] = "latex_table.txt",
) -> str:
    """
    Create a formatted table showing pairwise comparisons between result sets.

    Each cell reads "<win percentage>% (<total>)" for the row method versus
    the column method, "-" on the diagonal, and "N/A" when the comparison
    data has no entry for that pair.

    As a side effect, a LaTeX rendering of the same table (built with
    `latex_from_table` and its default caption/label) is written to
    `latex_output_path`.

    Args:
        df: DataFrame with evaluation results; must contain a "result_set"
            column when `result_sets` is not given.
        target_metric: The target metric to compare
        result_sets: Optional list of result sets to include in comparison.
            Defaults to all values of df["result_set"], sorted.
        latex_output_path: Where to write the LaTeX table. Defaults to
            "latex_table.txt" in the current working directory (the
            historical behaviour). Pass None to skip writing the file.

    Returns:
        Formatted table string (tabulate "grid" format)
    """
    # Get comparison data
    comparison_data = compute_result_set_comparisons(
        df,
        target_metric=target_metric,
    )
    comparisons = comparison_data["comparisons"]

    # Use provided result sets or get all from data. Copy into a list so any
    # sequence (e.g. a tuple) can be concatenated with the header label below.
    if result_sets is None:
        result_sets = sorted(df["result_set"].unique().tolist())
    else:
        result_sets = list(result_sets)

    # Create table data
    headers = ["Method"] + result_sets
    table_data = []

    for rs1 in result_sets:
        row = [rs1]
        for rs2 in result_sets:
            if rs1 == rs2:
                row.append("-")  # Self-comparison
                continue
            if rs1 not in comparisons or rs2 not in comparisons[rs1]:
                row.append("N/A")  # Missing comparison data
                continue
            pair = comparisons[rs1][rs2]
            row.append(f"{pair['win_percentage']:.1f}% ({pair['total']})")
        table_data.append(row)

    # Create formatted table
    table = tabulate(
        table_data, headers=headers, tablefmt="grid", numalign="right", stralign="left"
    )

    if latex_output_path is not None:
        latex_table = latex_from_table(table_data, headers)
        # Explicit UTF-8: the default caption contains a non-ASCII en dash.
        Path(latex_output_path).write_text(latex_table, encoding="utf-8")

    return table


def create_summary_table(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    result_sets: Optional[List[str]] = None,
) -> str:
    """
    Build a per-result-set table of descriptive statistics for one metric.

    For every result set the mean, median, standard deviation, minimum and
    maximum of `target_metric` are reported (three decimals each), followed
    by the number of rows belonging to that result set.

    Args:
        df: DataFrame with evaluation results; rows are selected by their
            "result_set" column.
        target_metric: Column whose values are summarised
        fluency_metric: Accepted for signature parity with the other table
            builders; this function does not read it.
        result_sets: Optional list of result sets to include. Defaults to
            all values of df["result_set"], sorted.

    Returns:
        Formatted table string (tabulate "grid" format)
    """
    names = (
        result_sets
        if result_sets is not None
        else sorted(df["result_set"].unique().tolist())
    )

    column_titles = ["Method", "Mean", "Median", "Std", "Min", "Max", "Count"]
    # Series reductions, listed in the same order as the numeric columns.
    reductions = ("mean", "median", "std", "min", "max")

    rows = []
    for name in names:
        belongs_to_set = df["result_set"] == name
        values = df[belongs_to_set][target_metric]
        formatted = [f"{getattr(values, stat)():.3f}" for stat in reductions]
        rows.append([name, *formatted, len(values)])

    return tabulate(
        rows,
        headers=column_titles,
        tablefmt="grid",
        numalign="right",
        stralign="left",
    )


def perform_anova_analysis(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
) -> Dict[str, Dict[str, float]]:
    """
    Perform one-way ANOVA on target and fluency metrics across result sets.

    For each metric, the rows of `df` are grouped by their "result_set"
    value, missing values are dropped, and `scipy.stats.f_oneway` is run
    over the groups.

    Result sets that have no non-missing values for a metric are left out of
    that metric's test. If fewer than two groups remain, the test is
    undefined and both statistics are reported as NaN instead of raising.

    Args:
        df: DataFrame with evaluation results; must contain a "result_set"
            column plus the two metric columns.
        target_metric: The target metric to analyze
        fluency_metric: The fluency metric to analyze

    Returns:
        Dictionary keyed by metric name; each value is a dict with
        "f_statistic" and "p_value" as plain Python floats.
    """
    # The grouping keys are the same for both metrics, so compute them once.
    result_set_names = sorted(df["result_set"].unique())
    results = {}

    # Perform ANOVA for each metric
    for metric in [target_metric, fluency_metric]:
        groups = []
        for result_set in result_set_names:
            values = df[df["result_set"] == result_set][metric].dropna()
            if len(values) > 0:
                groups.append(values.to_numpy())

        if len(groups) < 2:
            # f_oneway needs at least two samples to compare.
            f_stat, p_value = float("nan"), float("nan")
        else:
            f_stat, p_value = stats.f_oneway(*groups)

        # Cast to built-in floats so the result matches the annotation and
        # serialises as ordinary JSON numbers.
        results[metric] = {"f_statistic": float(f_stat), "p_value": float(p_value)}

    return results


def create_anova_table(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
) -> str:
    """
    Render the ANOVA results for the target and fluency metrics as a table.

    The statistics come from `perform_anova_analysis`; each metric becomes
    one row with its F-statistic (four decimals) and p-value (scientific
    notation, four decimals).

    Args:
        df: DataFrame with evaluation results
        target_metric: The target metric to analyze
        fluency_metric: The fluency metric to analyze

    Returns:
        Formatted table string (tabulate "grid" format)
    """
    stats_by_metric = perform_anova_analysis(
        df, target_metric=target_metric, fluency_metric=fluency_metric
    )

    rows = [
        [name, f"{entry['f_statistic']:.4f}", f"{entry['p_value']:.4e}"]
        for name, entry in stats_by_metric.items()
    ]

    return tabulate(
        rows,
        headers=["Metric", "F-statistic", "p-value"],
        tablefmt="grid",
        numalign="right",
        stralign="left",
    )


def save_tables(
    df: pd.DataFrame,
    target_metric: str = "logit_diff_improvement",
    fluency_metric: str = "cross_entropy",
    result_sets: Optional[List[str]] = None,
    benchmark_task_name: str = "tiny_stories",
    output_dir: Union[str, Path] = BENCHMARK_RESULTS_DIR / "tables",
    min_threshold_fluency: Optional[float] = None,
    max_threshold_fluency: Optional[float] = None,
) -> None:
    """
    Generate and save comparison, summary, and ANOVA tables.

    The comparison table is built from the fluency-filtered data only; the
    summary and ANOVA tables are produced for both the full and the filtered
    data. Text tables are written as ``.txt`` files and the raw ANOVA
    numbers as ``.json`` files under ``output_dir / benchmark_task_name``.
    When either threshold is set, every file name carries the suffix
    ``_{min_threshold_fluency}_{max_threshold_fluency}``.

    Args:
        df: DataFrame with evaluation results
        target_metric: The target metric to analyze
        fluency_metric: The fluency metric used for filtering and for the
            fluency ANOVA
        result_sets: Optional list of result sets to include in the
            comparison and summary tables
        benchmark_task_name: Name of the benchmark task for organizing tables
        output_dir: Base directory for the tables; the task name is appended
            and the directory is created if missing
        min_threshold_fluency: Minimum threshold for fluency metric (None means no minimum)
        max_threshold_fluency: Maximum threshold for fluency metric (None means no maximum)
    """
    # Create output directory using standardized path
    output_dir = Path(output_dir) / benchmark_task_name
    output_dir.mkdir(parents=True, exist_ok=True)

    has_threshold = (
        min_threshold_fluency is not None or max_threshold_fluency is not None
    )
    suffix = (
        f"_{min_threshold_fluency}_{max_threshold_fluency}" if has_threshold else ""
    )

    df_filtered = filter_table(
        df,
        target_metric=target_metric,
        fluency_metric=fluency_metric,
        only_include_best_in_range=True,
        min_threshold_fluency=min_threshold_fluency,
        max_threshold_fluency=max_threshold_fluency,
    )
    print(f"Rows after fluency filtering: {len(df_filtered)}")

    # Generate tables
    comparison_table = create_comparison_table(
        df_filtered, target_metric=target_metric, result_sets=result_sets
    )
    summary_table = create_summary_table(
        df,
        target_metric=target_metric,
        fluency_metric=fluency_metric,
        result_sets=result_sets,
    )
    summary_table_filtered = create_summary_table(
        df_filtered,
        target_metric=target_metric,
        fluency_metric=fluency_metric,
        result_sets=result_sets,
    )

    # Forward fluency_metric explicitly; otherwise the ANOVA helpers fall
    # back to their own default and ignore the caller's choice.
    anova_table = create_anova_table(
        df, target_metric=target_metric, fluency_metric=fluency_metric
    )
    anova_table_filtered = create_anova_table(
        df_filtered, target_metric=target_metric, fluency_metric=fluency_metric
    )
    # Get raw ANOVA results for JSON
    anova_results = perform_anova_analysis(
        df, target_metric=target_metric, fluency_metric=fluency_metric
    )
    anova_results_filtered = perform_anova_analysis(
        df_filtered, target_metric=target_metric, fluency_metric=fluency_metric
    )

    # Save tables
    text_outputs = {
        f"comparison_table{suffix}.txt": comparison_table,
        f"summary_table_full{suffix}.txt": summary_table,
        f"summary_table_filtered{suffix}.txt": summary_table_filtered,
        f"anova_table_full{suffix}.txt": anova_table,
        f"anova_table_filtered{suffix}.txt": anova_table_filtered,
    }
    for filename, content in text_outputs.items():
        (output_dir / filename).write_text(content, encoding="utf-8")

    # Save raw ANOVA results as JSON
    json_outputs = {
        f"anova_results_full{suffix}.json": anova_results,
        f"anova_results_filtered{suffix}.json": anova_results_filtered,
    }
    for filename, payload in json_outputs.items():
        (output_dir / filename).write_text(
            json.dumps(payload, indent=4), encoding="utf-8"
        )

    print(f"Tables saved to {output_dir}")


if __name__ == "__main__":
    # Script entry point: load the raw benchmark results, optionally drop
    # some entries, flatten them into a DataFrame and write all tables.

    # Load data
    json_path = "/workspace/eliciting-contexts/src/eliciting_contexts/benchmark/results/results_final_latest_last.json"

    min_threshold_fluency = 3.0
    max_threshold_fluency = 100.0

    target_metric = "logit_diff_improvement"

    # Entry identifiers to drop before processing, given as strings
    # (here "38" through "76"). Use an empty list to keep every entry.
    entries_to_remove = [f"{i}" for i in range(38, 77)]

    # Each tuple is (json_name, display_name, color)
    result_sets_config = [
        ("human", "Human", "#d62728"),  # red
        ("GPT4o", "GPT4o", "#2ca02c"),  # green
        ("EPO", "EPO", "#1f77b4"),  # blue
        ("EPOAssistHobbled-Final", "EPO-Ast.", "#ff7f0e"),  # orange
        ("GCG", "GCG", "#9467bd"),  # purple
        # ("EPOAssistHobbled", "EPOAssistNormal", "#ff7f0e"),  # orange
        # ("EPOAssist", "EPOAssistsdfs", "#ff7f0e"),  # orange
        # ("EPOAssist-Final", "EPOAssistsdfsFinal", "#ff7f0e"),  # orange
        # ("GCG", "GCG", "#9467bd"),  # purple
    ]

    # Extract naming_map for flatten_results (the color is not needed here)
    naming_map = {
        json_name: display_name for json_name, display_name, _ in result_sets_config
    }

    with open(json_path, "r") as f:
        all_results = json.load(f)

    # Remove specified entries before processing. Fall back to the unmodified
    # results when nothing is to be removed, so the name below is always bound.
    if entries_to_remove:
        print(f"Removing entries: {entries_to_remove}")
        all_results_filtered = remove_entries(all_results, entries_to_remove)
    else:
        all_results_filtered = all_results

    # Pass the naming map to flatten_results to filter and rename in one step
    df_flattened = flatten_results(all_results_filtered, naming_map)

    print("Saving tables...")
    # Save tables
    save_tables(
        df_flattened,
        target_metric=target_metric,
        min_threshold_fluency=min_threshold_fluency,
        max_threshold_fluency=max_threshold_fluency,
    )
