import argparse
import json
import os
from datetime import datetime
from typing import Dict, List, Tuple

import pandas as pd
from jinja2 import Environment, FileSystemLoader


def load_results(json_path: str) -> Dict:
    """Load the evaluation results from a JSON file.

    The file is decoded explicitly as UTF-8 (the encoding JSON mandates, and
    the one the HTML report is written in) instead of the platform's locale
    default, so non-ASCII generations load identically on every machine.

    Args:
        json_path: Path to the JSON results file.

    Returns:
        The parsed JSON object. ``flatten_results`` expects a dict mapping
        example ids to per-example dicts.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def flatten_results(all_results: Dict) -> pd.DataFrame:
    """Flatten the nested results mapping into one DataFrame row per result.

    ``all_results`` maps an example id (converted with ``int``) to a dict
    whose ``"info"`` entry holds example metadata and whose every other entry
    is a named list of result dicts. Each produced row carries ``idx``,
    ``result_set``, ``result_index`` and the example's whole ``info`` dict,
    followed by every key of the result dict itself. A result key that reuses
    one of those bookkeeping names overrides it.
    """
    rows = []
    for example_key, example in all_results.items():
        example_info = example.get("info", {})
        for set_name, set_results in example.items():
            if set_name == "info":
                continue  # metadata, not a result set
            rows.extend(
                {
                    "idx": int(example_key),
                    "result_set": set_name,
                    "result_index": position,
                    "info": example_info,
                    **entry,
                }
                for position, entry in enumerate(set_results)
            )
    return pd.DataFrame(rows)


def compute_result_set_comparisons(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    fluency_range: Tuple[float, float] = (3, 9),
) -> Dict:
    """
    Compare result sets head-to-head on their best in-range target score.

    Step 1: for every example and result set, take the maximum non-null
    ``target_metric`` among rows whose ``fluency_metric`` lies in the
    half-open interval ``[fluency_range[0], fluency_range[1])``. The value is
    ``None`` when no such row exists or the target column is absent.

    Step 2: for every ordered pair of distinct result sets (A, B), count the
    examples where A's best value beats B's. A missing value always loses to
    a present one, examples where both are missing are skipped, and ties are
    not wins. Self-pairs are filled with zeros so the nested dict is square.

    Returns:
        ``{"max_target_values": {idx: {result_set: value_or_None}},
        "comparisons": {A: {B: {"win_percentage", "win_count", "total"}}}}``
    """
    example_ids = sorted(df["idx"].unique())
    set_names = sorted(df["result_set"].unique())

    def best_in_range(subset: pd.DataFrame):
        # Highest non-null target among rows inside the fluency window, else None.
        fluency = subset[fluency_metric]
        in_range = subset[(fluency >= fluency_range[0]) & (fluency < fluency_range[1])]
        if in_range.empty or target_metric not in in_range.columns:
            return None
        scores = in_range[target_metric].dropna()
        return None if scores.empty else scores.max()

    max_target_values = {}
    for example_id in example_ids:
        example_rows = df[df["idx"] == example_id]
        max_target_values[example_id] = {
            name: best_in_range(example_rows[example_rows["result_set"] == name])
            for name in set_names
        }

    def head_to_head(first, second) -> Dict:
        # How often `first` strictly beats `second`; None counts as the lowest value.
        wins = 0
        total = 0
        for example_id in example_ids:
            a = max_target_values[example_id].get(first)
            b = max_target_values[example_id].get(second)
            if a is None and b is None:
                continue
            total += 1
            if a is not None and (b is None or a > b):
                wins += 1
        return {
            "win_percentage": (wins / total * 100) if total > 0 else 0,
            "win_count": wins,
            "total": total,
        }

    comparisons = {
        first: {
            second: (
                head_to_head(first, second)
                if first != second
                else {"win_percentage": 0, "win_count": 0, "total": 0}
            )
            for second in set_names
        }
        for first in set_names
    }

    return {"max_target_values": max_target_values, "comparisons": comparisons}


def get_numeric_metrics(df: pd.DataFrame) -> List[str]:
    """Return the numeric, not-entirely-null columns that look like metrics.

    Bookkeeping columns (``idx``, ``result_set``, ``result_index``) are never
    reported. The result follows the order of ``df.columns``.
    """
    bookkeeping = {"idx", "result_set", "result_index"}
    return [
        name
        for name in df.columns
        if name not in bookkeeping
        and pd.api.types.is_numeric_dtype(df[name])
        and df[name].notna().any()
    ]


# Qualitative palette shared by every chart built below. The result set at
# position ``i`` in sorted order is always drawn with ``_CHART_COLORS[i % 10]``.
_CHART_COLORS = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # gray
    "#bcbd22",  # olive
    "#17becf",  # cyan
]


def _pretty_metric_name(metric: str) -> str:
    """Turn a column name such as ``logit_diff`` into a title like ``Logit Diff``."""
    return metric.replace("_", " ").title()


def _uniform_bins(values: List[float], n_bins: int = 30):
    """Return ``(start, end, size)`` describing ``n_bins`` equal bins over ``values``.

    The range is padded by 5% of the data span on each side so the extreme
    values fall inside the first and last bin. If every value is identical,
    the span is zero and so would be the bin size, which Plotly cannot draw.
    In that case a fixed +/-0.5 window around the value is used instead.
    Returns ``None`` when ``values`` is empty.
    """
    if not values:
        return None
    low, high = min(values), max(values)
    span = high - low
    margin = span * 0.05 if span > 0 else 0.5
    start = low - margin
    size = (high + margin - start) / n_bins
    # `end` is computed as start + n * size (not high + margin) so the last
    # edge lines up exactly with the bin grid.
    return start, start + n_bins * size, size


def _histogram_trace(values: List[float], name: str, color: str, bins) -> Dict:
    """One overlaid, percent-normalised histogram trace with fixed binning."""
    start, end, size = bins
    return {
        "x": values,
        "name": name,
        "type": "histogram",
        "opacity": 0.7,
        "histnorm": "percent",  # Show as percentage
        "xbins": {"start": start, "end": end, "size": size},
        "autobinx": False,  # Disable automatic bin size calculation
        "marker": {"color": color},
    }


def _overlay_histogram_layout(title: str, x_title: str) -> Dict:
    """Layout shared by the overlaid percentage histograms."""
    return {
        "title": title,
        "xaxis": {"title": x_title},
        "yaxis": {"title": "Percentage (%)"},
        "barmode": "overlay",  # Overlay histograms for better comparison
        "bargap": 0.1,
        "legend": {"orientation": "h", "y": -0.2},
    }


def _band_values(
    subset: pd.DataFrame,
    target_metric: str,
    fluency_metric: str,
    band_low: float,
    band_high,
) -> List[float]:
    """Target values of rows whose fluency lies in ``[band_low, band_high)``.

    Rows missing either metric are dropped first. ``band_high=None`` makes the
    band open-ended, so infinite fluency values are still counted.
    """
    complete = subset.dropna(subset=[target_metric, fluency_metric])
    fluency = complete[fluency_metric]
    mask = fluency >= band_low
    if band_high is not None:
        mask = mask & (fluency < band_high)
    return complete[mask][target_metric].tolist()


def _scatter_with_marginals(
    df: pd.DataFrame,
    result_sets: List[str],
    target_metric: str,
    fluency_metric: str,
) -> Dict:
    """Fluency (x) vs target (y) scatter with marginal histograms.

    Only rows where both metrics are present are plotted. The scatter lives on
    axes ``x``/``y``. The fluency marginal sits on ``y2`` (top strip) and the
    target marginal sits on ``x2`` (right strip).
    """
    scatter_traces, x_marginals, y_marginals = [], [], []
    for i, result_set in enumerate(result_sets):
        subset = df[df["result_set"] == result_set].dropna(
            subset=[fluency_metric, target_metric]
        )
        if subset.empty:
            continue  # Only add traces for result sets that have data

        color = _CHART_COLORS[i % len(_CHART_COLORS)]
        # .tolist() yields plain Python scalars, which stay JSON-serialisable.
        x_values = subset[fluency_metric].tolist()
        y_values = subset[target_metric].tolist()
        if "input_text" in subset.columns:
            input_texts = subset["input_text"].tolist()
        else:
            input_texts = [""] * len(subset)
        hover_texts = [
            f"Example {idx}<br>Input: {input_text}<br>{fluency_metric}: {ce:.4f}<br>{target_metric}: {ld:.4f}"
            for idx, input_text, ce, ld in zip(
                subset["idx"].tolist(), input_texts, x_values, y_values
            )
        ]

        scatter_traces.append(
            {
                "x": x_values,
                "y": y_values,
                "mode": "markers",
                "type": "scatter",
                "name": result_set,
                "text": hover_texts,
                "hoverinfo": "text",
                "marker": {"size": 8, "opacity": 0.7, "color": color},
                "xaxis": "x",
                "yaxis": "y",
            }
        )
        x_marginals.append(
            {
                "x": x_values,
                "type": "histogram",
                "name": result_set,
                "marker": {"color": color, "opacity": 0.7},
                "xaxis": "x",
                "yaxis": "y2",
                "showlegend": False,
            }
        )
        y_marginals.append(
            {
                "y": y_values,
                "type": "histogram",
                "name": result_set,
                "marker": {"color": color, "opacity": 0.7},
                "xaxis": "x2",
                "yaxis": "y",
                "showlegend": False,
            }
        )

    fluency_title = _pretty_metric_name(fluency_metric)
    target_title = _pretty_metric_name(target_metric)
    hidden_axis = {
        "domain": [0.87, 1],
        "showgrid": False,
        "zeroline": False,
        "showticklabels": False,
    }
    layout = {
        "title": f"{fluency_title} vs {target_title}",
        "xaxis": {"title": fluency_title, "domain": [0, 0.85]},
        "yaxis": {"title": target_title, "domain": [0, 0.85]},
        "xaxis2": dict(hidden_axis),
        "yaxis2": dict(hidden_axis),
        "bargap": 0.1,
        "hovermode": "closest",
        "showlegend": True,
        "legend": {"orientation": "h", "y": -0.2},
    }
    return {"data": scatter_traces + x_marginals + y_marginals, "layout": layout}


def _violin_chart(df: pd.DataFrame, result_sets: List[str], metric: str) -> Dict:
    """One violin (with box and mean line) of ``metric`` per non-empty result set."""
    traces = []
    for i, result_set in enumerate(result_sets):
        values = df.loc[df["result_set"] == result_set, metric].dropna().tolist()
        if values:
            traces.append(
                {
                    "type": "violin",
                    "y": values,
                    "name": result_set,
                    "box": {"visible": True},
                    "meanline": {"visible": True},
                    "line": {"color": _CHART_COLORS[i % len(_CHART_COLORS)]},
                    "fillcolor": "",
                    "opacity": 0.6,
                    "points": "outliers",
                }
            )
    title = _pretty_metric_name(metric)
    layout = {
        "title": f"Distribution of {title} by Result Set",
        "yaxis": {"title": title},
        "violingap": 0,
        "violingroupgap": 0,
        "violinmode": "group",
        "legend": {"orientation": "h", "y": -0.2},
    }
    return {"data": traces, "layout": layout}


def create_charts_data(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    fluency_range: Tuple[float, float] = (3, 9),
) -> Dict:
    """Build Plotly-style chart specifications for the report's charts section.

    Nothing is produced unless both ``target_metric`` and ``fluency_metric``
    are columns of ``df``. When they are, the returned dict contains:

    * ``"result_set_comparison"``: pairwise win statistics from
      ``compute_result_set_comparisons`` over ``fluency_range``;
    * ``"metric_scatter"``: fluency vs. target scatter with marginal histograms;
    * ``"{target}_histogram_{all,low,medium,high}"``: overlaid target
      histograms, over all data and split into fluency bands ``[0, lo)``,
      ``[lo, hi)`` and ``[hi, inf)`` where ``(lo, hi) = fluency_range``;
    * ``"{fluency}_histogram"``: overlaid fluency histogram;
    * ``"{target}_violin"`` and ``"{fluency}_violin"``: per-result-set violins.

    Args:
        df: Flattened results with ``idx`` and ``result_set`` columns.
        target_metric: Column plotted on the y axis and histogrammed.
        fluency_metric: Column plotted on the x axis and used for banding.
        fluency_range: ``(lo, hi)`` bounds of the "medium" fluency band. The
            default ``(3, 9)`` reproduces the previous hard-coded bands.

    Returns:
        Mapping from chart key to ``{"data": [...], "layout": {...}}`` (or the
        comparison dict). Empty if either metric column is missing.
    """
    charts_data = {}
    if target_metric not in df.columns or fluency_metric not in df.columns:
        return charts_data

    comparison_data = compute_result_set_comparisons(
        df,
        target_metric=target_metric,
        fluency_metric=fluency_metric,
        fluency_range=fluency_range,
    )
    if comparison_data and "comparisons" in comparison_data:
        charts_data["result_set_comparison"] = {
            "comparisons": comparison_data["comparisons"]
        }

    # Sorted once so every chart assigns the same colour to the same result set.
    result_sets = sorted(df["result_set"].unique().tolist())

    charts_data["metric_scatter"] = _scatter_with_marginals(
        df, result_sets, target_metric, fluency_metric
    )

    target_title = _pretty_metric_name(target_metric)
    fluency_title = _pretty_metric_name(fluency_metric)

    # One shared binning per metric so the overlaid histograms are comparable.
    # Every per-set value list below is a subset of these values, so the bins
    # are never None when a trace is actually built.
    target_bins = _uniform_bins(df[target_metric].dropna().tolist())
    low, high = fluency_range
    bands = [
        # (key suffix, title suffix, lower bound, upper bound); None = unbounded
        ("all", "All Data", None, None),
        ("low", f"Low {fluency_title} (0-{low:g})", 0, low),
        ("medium", f"Medium {fluency_title} ({low:g}-{high:g})", low, high),
        ("high", f"High {fluency_title} ({high:g}+)", high, None),
    ]
    for band_name, band_title, band_low, band_high in bands:
        traces = []
        for i, result_set in enumerate(result_sets):
            subset = df[df["result_set"] == result_set]
            if band_low is None:
                # "All Data" ignores fluency entirely, so NaN fluency is kept.
                values = subset[target_metric].dropna().tolist()
            else:
                values = _band_values(
                    subset, target_metric, fluency_metric, band_low, band_high
                )
            if values:
                traces.append(
                    _histogram_trace(
                        values,
                        result_set,
                        _CHART_COLORS[i % len(_CHART_COLORS)],
                        target_bins,
                    )
                )
        charts_data[f"{target_metric}_histogram_{band_name}"] = {
            "data": traces,
            "layout": _overlay_histogram_layout(
                f"Distribution of {target_title} Values - {band_title}", target_title
            ),
        }

    fluency_bins = _uniform_bins(df[fluency_metric].dropna().tolist())
    fluency_traces = []
    for i, result_set in enumerate(result_sets):
        values = df.loc[df["result_set"] == result_set, fluency_metric].dropna().tolist()
        if values:
            fluency_traces.append(
                _histogram_trace(
                    values,
                    result_set,
                    _CHART_COLORS[i % len(_CHART_COLORS)],
                    fluency_bins,
                )
            )
    charts_data[f"{fluency_metric}_histogram"] = {
        "data": fluency_traces,
        "layout": _overlay_histogram_layout(
            f"Distribution of {fluency_title} Values", fluency_title
        ),
    }

    charts_data[f"{target_metric}_violin"] = _violin_chart(df, result_sets, target_metric)
    charts_data[f"{fluency_metric}_violin"] = _violin_chart(df, result_sets, fluency_metric)

    return charts_data


def _is_missing(value) -> bool:
    """True only for scalar missing markers (None, NaN, NaT, pd.NA).

    ``pd.isna`` on a list returns an element-wise array, and negating that
    raises "truth value of an array is ambiguous". Container cells (lists,
    dicts, arrays) are therefore always treated as present.
    """
    return pd.api.types.is_scalar(value) and bool(pd.isna(value))


def _example_scatter_plot(
    idx_data: pd.DataFrame,
    target_metric: str,
    fluency_metric: str,
    color_of: Dict,
):
    """Per-example fluency vs. target scatter, one trace per result set.

    ``color_of`` maps result-set name to colour, so colours match the global
    charts even when this example lacks some result sets. Returns ``None``
    when no row has both metrics.
    """
    valid_data = idx_data.dropna(subset=[target_metric, fluency_metric])
    if valid_data.empty:
        return None

    target_display = target_metric.replace("_", " ")
    fluency_display = fluency_metric.replace("_", " ")

    scatter_data = []
    for result_set, result_data in valid_data.groupby("result_set"):
        x_values = result_data[fluency_metric].tolist()
        y_values = result_data[target_metric].tolist()
        if "input_text" in result_data.columns:
            input_texts = result_data["input_text"].tolist()
        else:
            input_texts = [""] * len(result_data)
        scatter_data.append(
            {
                "x": x_values,
                "y": y_values,
                "mode": "markers",
                "type": "scatter",
                "name": result_set,
                "marker": {"color": color_of[result_set], "size": 10, "opacity": 0.7},
                # `pos` is the position within this result set's plotted points.
                "text": [
                    f"Result #{pos}<br>{fluency_metric}: {x:.4f}<br>{target_metric}: {y:.4f}<br>{text}"
                    for pos, (x, y, text) in enumerate(zip(x_values, y_values, input_texts))
                ],
                "hoverinfo": "text+name",
            }
        )

    scatter_layout = {
        "title": f"{fluency_display} vs {target_display}",
        "xaxis": {"title": fluency_display},
        "yaxis": {"title": target_display},
        "height": 300,
        "margin": {"l": 60, "r": 30, "b": 60, "t": 40, "pad": 4},
        "hovermode": "closest",
    }
    return {"data": scatter_data, "layout": scatter_layout}


def _example_result_sets(idx_data: pd.DataFrame) -> List[Dict]:
    """Table data for one example: each row's non-missing cells, per result set.

    The bookkeeping columns (``idx``, ``result_set``, ``result_index``,
    ``info``) are not displayed.
    """
    exclude_cols = {"idx", "result_set", "result_index", "info"}
    columns = [col for col in idx_data.columns if col not in exclude_cols]

    result_sets = []
    for result_set in idx_data["result_set"].unique():
        result_set_data = idx_data[idx_data["result_set"] == result_set]
        generations = [
            {col: row[col] for col in columns if not _is_missing(row[col])}
            for _, row in result_set_data.iterrows()
        ]
        result_sets.append(
            {"name": result_set, "generations": generations, "columns": list(columns)}
        )
    return result_sets


def generate_html_report(
    df: pd.DataFrame,
    numeric_metrics: List[str],
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
) -> str:
    """Render the interactive HTML report with Jinja2.

    The template ``templates/report_template.html`` is loaded from the
    directory containing this file. It receives the global chart data, a
    per-example list (result-set tables, optional scatter plot, info dict) and
    display metadata.

    Args:
        df: Flattened results with ``idx``, ``result_set`` and (optionally)
            ``info`` columns.
        numeric_metrics: Metric column names forwarded to the template.
        target_metric: Metric plotted on the y axis.
        fluency_metric: Metric plotted on the x axis.

    Returns:
        The rendered HTML document as a string.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    env = Environment(loader=FileSystemLoader(os.path.join(current_dir, "templates")))
    template = env.get_template("report_template.html")

    charts_data = create_charts_data(
        df, target_metric=target_metric, fluency_metric=fluency_metric
    )

    target_display = target_metric.replace("_", " ")
    fluency_display = fluency_metric.replace("_", " ")

    # Same palette and same sorted order as the global charts, so a result
    # set keeps one colour across the whole report.
    palette = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    all_result_sets = sorted(df["result_set"].unique())
    color_of = {
        name: palette[i % len(palette)] for i, name in enumerate(all_result_sets)
    }
    has_metrics = target_metric in df.columns and fluency_metric in df.columns

    examples = []
    for idx in sorted(df["idx"].unique()):
        idx_data = df[df["idx"] == idx]

        # All rows of an example share the same info dict (see flatten_results).
        info_data = None
        if "info" in idx_data.columns and not idx_data["info"].isna().all():
            first_info = idx_data["info"].iloc[0]
            if first_info:
                info_data = first_info

        scatter_plot = (
            _example_scatter_plot(idx_data, target_metric, fluency_metric, color_of)
            if has_metrics
            else None
        )

        examples.append(
            {
                "idx": int(idx),
                "result_sets": _example_result_sets(idx_data),
                "scatter_plot": scatter_plot,
                "info_data": info_data,
            }
        )

    return template.render(
        examples=examples,
        charts_data=charts_data,
        numeric_metrics=numeric_metrics,
        available_result_sets=all_result_sets,
        current_datetime=datetime.now(),
        target=target_metric,
        target_display=target_display,
        fluency=fluency_metric,
        fluency_display=fluency_display,
    )


def filter_table(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    fluency_range: Tuple[float, float] = (3, 9),
) -> pd.DataFrame:
    """
    Keep, for each example and result set, only the row with the highest
    target metric among rows whose fluency lies inside ``fluency_range``.

    The fluency window is half-open, ``[fluency_range[0], fluency_range[1])``
    (default ``[3, 9)``), the same convention as
    ``compute_result_set_comparisons``. Rows whose target value is missing are
    ignored. A group with no in-window row carrying a target value is dropped.
    Ties keep the first row.

    Args:
        df: DataFrame with evaluation results (``idx`` and ``result_set`` columns)
        target_metric: The target metric to maximize
        fluency_metric: The fluency metric to filter on
        fluency_range: Inclusive lower / exclusive upper fluency bound

    Returns:
        DataFrame with at most one row per (idx, result_set), or an empty
        DataFrame with ``df``'s columns when nothing qualifies.
    """
    low, high = fluency_range
    best_rows = []

    for _, group_df in df.groupby(["idx", "result_set"]):
        fluency = group_df[fluency_metric]
        candidates = group_df[(fluency >= low) & (fluency < high)].dropna(
            subset=[target_metric]
        )
        # Without this guard idxmax fails (or yields NaN) on an all-NaN target.
        if candidates.empty:
            continue
        # Positional lookup: label-based .loc would return several rows if the
        # index contained duplicate labels.
        best_pos = candidates[target_metric].reset_index(drop=True).idxmax()
        best_rows.append(candidates.iloc[best_pos])

    if best_rows:
        return pd.DataFrame(best_rows)
    # Return empty DataFrame with same columns if no rows match criteria
    return pd.DataFrame(columns=df.columns)


def main() -> int:
    """Command-line entry point: load results, optionally filter, write the report.

    Returns:
        Process exit status: 0 on success, 1 if any step raised. The error
        message and traceback are printed to stderr.
    """
    import sys
    import traceback

    parser = argparse.ArgumentParser(
        description="Generate interactive dashboard for TinyStories evaluation results"
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        default="evaluation_results.json",
        help="Path to input JSON file with evaluation results",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default="evaluation_report.html",
        help="Path to output HTML report",
    )

    parser.add_argument(
        "--target",
        type=str,
        default="logit_diff_improvement",
        help="Target metric to use for charts (default: logit_diff_improvement)",
    )
    parser.add_argument(
        "--fluency",
        type=str,
        default="cross_entropy",
        help="Fluency metric to use for charts (default: cross_entropy)",
    )

    parser.add_argument(
        "--filter_table",
        action="store_true",
        help="Filter results to keep only the max target between fluency 3 and 9",
    )

    args = parser.parse_args()

    print(f"Loading data from {args.input}")
    try:
        all_results = load_results(args.input)
        df = flatten_results(all_results)

        if args.filter_table:
            print("Filtering to max target between fluency 3 and 9...")
            df = filter_table(
                df, target_metric=args.target, fluency_metric=args.fluency
            )
            print(f"Filtered to {len(df)} rows")

        numeric_metrics = get_numeric_metrics(df)
        print(
            f"Found {len(numeric_metrics)} numeric metrics: {', '.join(numeric_metrics)}"
        )

        print("Generating HTML report...")
        html_content = generate_html_report(
            df, numeric_metrics, target_metric=args.target, fluency_metric=args.fluency
        )

        with open(args.output, "w", encoding="utf-8") as f:
            f.write(html_content)

        print(f"Report saved to {args.output}")

    except Exception as e:  # top-level CLI boundary: report and signal failure
        # Diagnostics belong on stderr so they are not mixed into stdout.
        print(f"Error: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    # Propagate main()'s status (1 on failure) to the shell instead of always exiting 0.
    raise SystemExit(main())
