import dataclasses
import json
from typing import List

import einops
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import wandb


@dataclasses.dataclass
class HistoryColumns:
    """Constants for column names used in history DataFrames.

    The attributes below carry no type annotations, so ``dataclass`` does not
    turn them into fields; they are plain class-level string constants and are
    meant to be read as ``HistoryColumns.NAME`` without instantiating the class.
    """

    # Coordinates of a row: iteration index, population-member index and child
    # index within that member (child 0 is the parent itself).
    ITERATION = "iteration"
    POPULATION = "population"
    CHILD = "child"
    # Decoded text of the full token sequence.
    TEXT = "text"
    # Token ids as a list of strings, and the same ids rendered as "id(token)".
    TOKEN_IDS = "token_ids"
    TOKEN_DISPLAY = "token_display"
    # Objective value and cross-entropy of the sequence.
    TARGET = "target"
    XENTROPY = "xentropy"
    # Coordinates of the previous-iteration row a parent row was kept from;
    # history_to_dataframe stores -1 in all three when no such row is recorded.
    PARENT_ITERATION = "parent_iteration"
    PARENT_POPULATION = "parent_population"
    PARENT_CHILD = "parent_child"
    # X weight associated with the kept row (-1 when not recorded).
    X_WEIGHT = "x_weight"
    # Description of the first token that differs from the parent; the literal
    # string "None" is used for parent rows.
    CHANGES = "changes"


# This is the original History class from the EPO code. Consider using to_dataframe!
@dataclasses.dataclass
class History:
    """
    Full record of an EPO run, as returned by the `epo` function.

    Each list-valued attribute holds one entry per iteration while the run is
    in progress (see `_insert`); `_finalize` stacks those entries into numpy
    arrays whose leading axis is the iteration.
    """

    # Token ids of every population member, per iteration.
    ids: List = dataclasses.field(default_factory=list)
    # Cross-entropy loss of every population member, per iteration.
    xentropy: List = dataclasses.field(default_factory=list)
    # Target objective of every population member, per iteration.
    target: List = dataclasses.field(default_factory=list)
    # Indices of the population members retained at each iteration.
    keep: List = dataclasses.field(default_factory=list)
    # Runtime of each iteration.
    runtime: List = dataclasses.field(default_factory=list)
    # X weights used for selection at each iteration.
    x_weights: List = dataclasses.field(default_factory=list)
    # Population size used in the EPO run.
    pop_size: int = None
    # Number of children explored per parent.
    explore_per_pop: int = None

    def subset(self, slc):
        """
        Build a new History restricted to the iterations selected by `slc`.

        The per-iteration containers are indexed with `slc`; `pop_size` and
        `explore_per_pop` are carried over unchanged.
        """
        return History(
            ids=self.ids[slc],
            xentropy=self.xentropy[slc],
            target=self.target[slc],
            keep=self.keep[slc],
            runtime=self.runtime[slc],
            x_weights=self.x_weights[slc],
            pop_size=self.pop_size,
            explore_per_pop=self.explore_per_pop,
        )

    def _insert(self, new_ids, target, xentropy, keep, runtime, x_weights):
        """
        Append one iteration's results.

        Every argument except `runtime` must expose `.cpu().numpy()`
        (presumably torch tensors -- confirm against the caller); `runtime`
        is stored as given.
        """
        per_iteration = (
            (self.ids, new_ids),
            (self.target, target),
            (self.xentropy, xentropy),
            (self.keep, keep),
        )
        for store, tensor in per_iteration:
            store.append(tensor.cpu().numpy())
        self.runtime.append(runtime)
        self.x_weights.append(x_weights.cpu().numpy())

    def _finalize(self):
        """Convert the accumulated per-iteration lists into stacked numpy arrays."""
        for name in ("ids", "target", "xentropy", "keep"):
            setattr(self, name, np.stack(getattr(self, name), axis=0))
        self.runtime = np.array(self.runtime)
        self.x_weights = np.stack(self.x_weights, axis=0)

    def reshape_by_parent(self, array, pop_size, explore_per_pop):
        """Rearrange [iter, samples, ...] into [pop, iter, children, ...].

        The samples axis is split as (child, pop) with `child=explore_per_pop`
        and `pop=pop_size`; any trailing axes are preserved.
        """
        # Only mention trailing axes in the recipe when the array has them.
        tail = " ..." if len(array.shape) > 2 else ""
        recipe = f"iter (child pop){tail} -> pop iter child{tail}"
        return einops.rearrange(array, recipe, pop=pop_size, child=explore_per_pop)

    def group_by_parent(self):
        """Group history data by parent population member.

        Returns a dict whose "ids", "target" and "xentropy" entries are
        reshaped to [pop, iter, child, ...]; "keep" and "runtime" are passed
        through as stored.
        """
        # Each parent contributes itself plus `explore_per_pop` children.
        slots_per_parent = self.explore_per_pop + 1
        grouped = {
            name: self.reshape_by_parent(
                getattr(self, name), self.pop_size, slots_per_parent
            )
            for name in ("ids", "target", "xentropy")
        }
        grouped["keep"] = self.keep
        grouped["runtime"] = self.runtime
        return grouped

    def to_dataframe(self, tokenizer, iter=None, child=None):
        """
        Convert this history into a pandas DataFrame with decoded text.

        Thin wrapper around `history_to_dataframe`.

        Parameters
        ----------
        tokenizer : transformers.PreTrainedTokenizer
            Tokenizer used to decode token ids
        iter : int, optional
            Restrict the output to this iteration (default: None, all iterations)
        child : int, optional
            Restrict the output to this child index (default: None, all children)

        Returns
        -------
        pandas.DataFrame
            One row per (iteration, population, child) with the columns named
            in `HistoryColumns`
        """
        return history_to_dataframe(self, tokenizer, iter=iter, child=child)


def _to_jsonable(value):
    """Convert numpy/torch-style values (anything with ``tolist``) to native Python."""
    if hasattr(value, "tolist"):
        return value.tolist()
    if isinstance(value, (list, tuple)) and len(value) > 0 and hasattr(value[0], "tolist"):
        return [item.tolist() if hasattr(item, "tolist") else item for item in value]
    return value


def history_to_dataframe(
    history,
    tokenizer,
    save_to_jsonl=False,
    jsonl_path="epo_history.jsonl",
    iter=None,
    child=None,
):
    """
    Convert search history to a pandas DataFrame with decoded text.

    Rows are emitted in population -> iteration -> child order.

    Parameters
    ----------
    history : History-like object
        Must provide ``group_by_parent()``, ``pop_size``, ``explore_per_pop``,
        ``runtime``, ``keep``, and ``x_weights``
    tokenizer : transformers.PreTrainedTokenizer
        The tokenizer to use for decoding token IDs (only ``decode`` is called)
    save_to_jsonl : bool, optional
        If True, also write one JSON object per row to `jsonl_path`
    jsonl_path : str, optional
        Path to save the JSONL file (default: "epo_history.jsonl")
    iter : int, optional
        If provided, only process this iteration. Negative values count from
        the last iteration. A value that is still out of range after that
        adjustment is ignored and all iterations are processed (default: None)
    child : int, optional
        If provided, only process this child index. An out-of-range value is
        ignored and all children are processed (default: None)

    Returns
    -------
    pandas.DataFrame
        DataFrame with iteration, population, child, text, token_ids,
        token_display, target, xentropy, parent_iteration, parent_population,
        parent_child, x_weight and changes columns
    """
    num_iterations = len(history.runtime)

    # `iter` defaults to None, so it must be checked before the comparison;
    # comparing None with an int raises TypeError.
    if iter is not None and iter < 0:
        iter = num_iterations + iter

    grouped = history.group_by_parent()

    # Both ranges are independent of the loop variables, so build them once.
    if iter is not None and 0 <= iter < num_iterations:
        iter_range = [iter]
    else:
        iter_range = range(num_iterations)

    if child is not None and 0 <= child <= history.explore_per_pop:
        child_range = [child]
    else:
        child_range = range(history.explore_per_pop + 1)

    data = []
    for pop_idx in range(history.pop_size):
        for iter_idx in iter_range:
            for child_idx in child_range:
                token_ids = grouped["ids"][pop_idx, iter_idx, child_idx]
                text = tokenizer.decode(token_ids)

                # Per-token rendering: ids as strings and "id(token)" pairs.
                token_ids_str = [str(int(tid)) for tid in token_ids]
                token_strs = [tokenizer.decode([int(tid)]) for tid in token_ids]
                tokens_display = [
                    f"{tid}({ts})" for tid, ts in zip(token_ids_str, token_strs)
                ]

                # -1 marks "no recorded origin" (first iteration, or a child row).
                parent_iter = -1
                parent_pop = -1
                parent_child = -1
                keep_x_weight = -1

                if iter_idx > 0 and child_idx == 0:
                    # A parent row was kept from the previous iteration.
                    # `keep` holds an index into that iteration's flattened
                    # (child, pop) sample axis, so split it back into its parts.
                    # TODO make this clear with a -1 when it's created by model helper!
                    prev_keep_idx = history.keep[iter_idx - 1][pop_idx]
                    parent_iter = iter_idx - 1
                    parent_pop = prev_keep_idx % history.pop_size
                    parent_child = prev_keep_idx // history.pop_size
                    keep_x_weight = float(history.x_weights[iter_idx - 1][pop_idx])

                # Parent rows carry the literal string "None"; children describe
                # the first position where they differ from their parent.
                changes = "None"
                if child_idx > 0:
                    parent_ids = grouped["ids"][pop_idx, iter_idx, 0]
                    changes = ""
                    for pos, (parent_id, current_id) in enumerate(
                        zip(parent_ids, token_ids)
                    ):
                        if parent_id != current_id:
                            parent_token = tokenizer.decode([int(parent_id)])
                            current_token = tokenizer.decode([int(current_id)])
                            changes = f"pos {pos}: {parent_id}({parent_token}) → {current_id}({current_token})"
                            break

                data.append(
                    {
                        HistoryColumns.ITERATION: iter_idx,
                        HistoryColumns.POPULATION: pop_idx,
                        HistoryColumns.CHILD: child_idx,
                        HistoryColumns.TEXT: text,
                        HistoryColumns.TOKEN_IDS: token_ids_str,
                        HistoryColumns.TOKEN_DISPLAY: tokens_display,
                        HistoryColumns.TARGET: grouped["target"][
                            pop_idx, iter_idx, child_idx
                        ],
                        HistoryColumns.XENTROPY: grouped["xentropy"][
                            pop_idx, iter_idx, child_idx
                        ],
                        HistoryColumns.PARENT_ITERATION: parent_iter,
                        HistoryColumns.PARENT_POPULATION: parent_pop,
                        HistoryColumns.PARENT_CHILD: parent_child,
                        HistoryColumns.X_WEIGHT: keep_x_weight,
                        HistoryColumns.CHANGES: changes,
                    }
                )

    if save_to_jsonl:
        with open(jsonl_path, "w", encoding="utf-8") as f:
            for entry in data:
                # json cannot serialise numpy scalars/arrays directly.
                clean_entry = {k: _to_jsonable(v) for k, v in entry.items()}
                f.write(json.dumps(clean_entry) + "\n")

        print(f"Saved history data to {jsonl_path}")

    return pd.DataFrame(data)


## The rest of the file contains DataFrame helper functions


def filter_to_ancestry_path(df, subset_conditions):
    """
    Filter dataframe to only include nodes in the subset and their ancestors.

    Parameters
    ----------
    df : pandas.DataFrame
        The full history dataframe
    subset_conditions : dict
        Maps a column name to either a value (equality match) or a
        ``(operator, threshold)`` tuple where operator is one of
        ``>``, ``<``, ``>=``, ``<=``, ``==``, ``!=``; e.g.,
        {'iteration': 10, 'population': 2} or {'target': ('>', 0.5)}.
        All conditions must hold for a row to be part of the subset.

    Returns
    -------
    pandas.DataFrame
        Rows of `df` that match the conditions or are ancestors (as reported
        by `get_parent`) of a matching row. If nothing matches, a message is
        printed and an empty, column-less DataFrame is returned.

    Raises
    ------
    ValueError
        If a two-element tuple condition uses an unsupported operator.
        Silently skipping it would return an unfiltered (too large) result.
    """
    import operator

    comparators = {
        ">": operator.gt,
        "<": operator.lt,
        ">=": operator.ge,
        "<=": operator.le,
        "==": operator.eq,
        "!=": operator.ne,
    }

    # Step 1: narrow down to the requested subset. Boolean indexing returns
    # new frames, so `df` itself is never modified.
    subset = df
    for key, value in subset_conditions.items():
        if isinstance(value, tuple) and len(value) == 2:
            symbol, threshold = value
            compare = comparators.get(symbol) if isinstance(symbol, str) else None
            if compare is None:
                raise ValueError(
                    f"Unsupported comparison operator {symbol!r} for column {key!r}"
                )
            subset = subset[compare(subset[key], threshold)]
        else:
            subset = subset[subset[key] == value]

    if subset.empty:
        print("No nodes match the subset conditions")
        return pd.DataFrame()

    id_columns = (
        HistoryColumns.ITERATION,
        HistoryColumns.POPULATION,
        HistoryColumns.CHILD,
    )

    def row_id(row):
        return tuple(row[col] for col in id_columns)

    # Step 2: walk up the ancestry of every subset row. An explicit stack is
    # used instead of recursion because the chain length grows with the number
    # of iterations and could otherwise exceed Python's recursion limit.
    pending = [row for _, row in subset.iterrows()]
    included_rows = {row_id(row) for row in pending}
    while pending:
        parent_row = get_parent(df, pending.pop())
        if parent_row is None:
            continue
        parent_id = row_id(parent_row)
        # Ancestors already seen have had their own chain explored.
        if parent_id not in included_rows:
            included_rows.add(parent_id)
            pending.append(parent_row)

    # Step 3: keep only the rows whose coordinates were collected.
    mask = np.array(
        [
            coords in included_rows
            for coords in zip(*(df[col] for col in id_columns))
        ],
        dtype=bool,
    )
    return df[mask]


def get_parent(df, row):
    """
    Look up the parent row of a node in the history dataframe.

    A mutation (child > 0) has the child-0 row of the same iteration and
    population as its parent. A child-0 row in a later iteration points at its
    parent through the parent_iteration / parent_population / parent_child
    columns, provided all three are non-negative.

    Parameters
    ----------
    df : pandas.DataFrame
        The history dataframe
    row : pandas.Series or dict-like
        A row from the history dataframe or a dict-like object with the same keys

    Returns
    -------
    pandas.Series or None
        The first matching parent row, or None when the node has no parent
        or the parent is not present in `df`
    """

    def first_match(iter_value, pop_value, child_value):
        # First row at the given coordinates, or None when there is none.
        hits = df[
            (df[HistoryColumns.ITERATION] == iter_value)
            & (df[HistoryColumns.POPULATION] == pop_value)
            & (df[HistoryColumns.CHILD] == child_value)
        ]
        return None if hits.empty else hits.iloc[0]

    iteration = row[HistoryColumns.ITERATION]
    population = row[HistoryColumns.POPULATION]
    child = row[HistoryColumns.CHILD]

    # Mutation: parent lives in the same iteration/population at child 0.
    if child > 0:
        return first_match(iteration, population, 0)

    # Initial parents have nothing above them.
    if not iteration > 0:
        return None

    # Later-iteration parent: follow the recorded origin, if any.
    origin = (
        row[HistoryColumns.PARENT_ITERATION],
        row[HistoryColumns.PARENT_POPULATION],
        row[HistoryColumns.PARENT_CHILD],
    )
    if origin[0] >= 0 and origin[1] >= 0 and origin[2] >= 0:
        return first_match(*origin)

    return None


def get_mutations(df, parent_row):
    """
    Collect every mutation row belonging to a parent node.

    Parameters
    ----------
    df : pandas.DataFrame
        The history dataframe
    parent_row : pandas.Series or dict-like
        A row from the history dataframe representing the parent

    Returns
    -------
    pandas.DataFrame or None
        Rows sharing the parent's iteration and population with child > 0.
        None is returned when `parent_row` is not itself a parent
        (its child index is not 0).
    """
    iteration = parent_row[HistoryColumns.ITERATION]
    population = parent_row[HistoryColumns.POPULATION]
    child = parent_row[HistoryColumns.CHILD]

    # Only child-0 rows act as parents.
    if child != 0:
        return None

    same_iteration = df[HistoryColumns.ITERATION] == iteration
    same_population = df[HistoryColumns.POPULATION] == population
    is_mutation = df[HistoryColumns.CHILD] > 0
    return df[same_iteration & same_population & is_mutation]


def get_children(df, row):
    """
    Find both kinds of children of a node: its mutations and the rows that
    recorded this node as the parent they were selected from.

    Parameters
    ----------
    df : pandas.DataFrame
        The history dataframe
    row : pandas.Series or dict-like
        A row from the history dataframe representing the node

    Returns
    -------
    dict of pandas.DataFrame
        - 'mutations': rows in the same iteration and population with
          child > 0 (empty, with the same columns, unless `row` has child == 0)
        - 'selected': rows whose parent_iteration / parent_population /
          parent_child columns equal this node's coordinates
    """
    iteration = row[HistoryColumns.ITERATION]
    population = row[HistoryColumns.POPULATION]
    child = row[HistoryColumns.CHILD]

    # Only a parent node (child 0) owns mutations.
    if child == 0:
        mutations = df[
            (df[HistoryColumns.ITERATION] == iteration)
            & (df[HistoryColumns.POPULATION] == population)
            & (df[HistoryColumns.CHILD] > 0)
        ]
    else:
        mutations = df.head(0)

    # Rows that point back at this node through their parent_* columns.
    points_here = (
        (df[HistoryColumns.PARENT_ITERATION] == iteration)
        & (df[HistoryColumns.PARENT_POPULATION] == population)
        & (df[HistoryColumns.PARENT_CHILD] == child)
    )

    return {"mutations": mutations, "selected": df[points_here]}


def get_pareto_frontier_df(df, target_col=None, xentropy_col=None):
    """
    Filter a DataFrame to keep only the points on the Pareto frontier.

    A point is on the frontier when no other point has cross-entropy that is
    lower or equal together with a target that is higher or equal. Of several
    identical points only one is kept. Rows with NaN in either column cannot
    be ranked and are never part of the frontier.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing history data
    target_col : str, optional
        Name of the target column, higher is better
        (default: HistoryColumns.TARGET)
    xentropy_col : str, optional
        Name of the cross-entropy column, lower is better
        (default: HistoryColumns.XENTROPY)

    Returns
    -------
    pd.DataFrame
        The rows of `df` that are on the Pareto frontier, in their original
        order (a copy of `df` when it is empty)
    """
    if target_col is None:
        target_col = HistoryColumns.TARGET

    if xentropy_col is None:
        xentropy_col = HistoryColumns.XENTROPY

    if df.empty:
        return df.copy()

    xentropy_values = df[xentropy_col].to_numpy(dtype=float)
    target_values = df[target_col].to_numpy(dtype=float)

    # Visit points by ascending xentropy and, within equal xentropy, by
    # descending target. The tie-break matters: a plain argsort may visit the
    # weaker of two equal-xentropy points first and wrongly keep it.
    visit_order = np.lexsort((-target_values, xentropy_values))

    is_pareto = np.zeros(len(df), dtype=bool)
    best_target_so_far = float("-inf")

    for idx in visit_order:
        current_target = target_values[idx]

        # A NaN would poison `best_target_so_far` (every later comparison
        # against NaN is False), so unrankable rows are skipped.
        if np.isnan(current_target) or np.isnan(xentropy_values[idx]):
            continue

        # Everything visited earlier has xentropy at least as good, so this
        # point survives only if it strictly improves the target.
        if current_target > best_target_so_far:
            is_pareto[idx] = True
            best_target_so_far = current_target

    return df.loc[is_pareto]


def pretty_print_df(
    df,
    max_rows=20,
    max_text_length=50,
    columns=None,
    highlight_best=True,
    float_format="{:.4f}",
    print_output=True,
    sort_by=None,
):
    """
    Pretty print a DataFrame with custom formatting for history data.

    Output is grouped by iteration, then by population; within a population
    the parent (child 0) is shown first, followed by its children sorted by
    descending target. ANSI escape codes are used for colour.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame to print. An empty DataFrame yields an empty string.
    max_rows : int, optional
        Maximum number of children to display per population (default: 20)
    max_text_length : int, optional
        Accepted for interface compatibility; not used by this
        implementation, text is never truncated (default: 50)
    columns : list, optional
        Accepted for interface compatibility; not used by this implementation
    highlight_best : bool, optional
        Whether to highlight the best target and xentropy values (default: True)
    float_format : str, optional
        Format string for floating point numbers (default: "{:.4f}")
    print_output : bool, optional
        If True, print the result; if False, return the formatted string (default: True)
    sort_by : str or list, optional
        Accepted for interface compatibility; not used by this implementation

    Returns
    -------
    str or None
        If print_output is False, returns the formatted string; otherwise None
    """
    # An empty frame may not even have the history columns (e.g. a bare
    # pd.DataFrame()), so avoid touching them in that case.
    has_rows = not df.empty
    iterations = sorted(df[HistoryColumns.ITERATION].unique()) if has_rows else []

    # max()/min() are used rather than idxmax()/idxmin() + .loc so that frames
    # with duplicate index labels still yield scalar best values.
    highlight = highlight_best and has_rows
    if highlight:
        best_target_value = df[HistoryColumns.TARGET].max()
        best_xentropy_value = df[HistoryColumns.XENTROPY].min()

    # Terminal width caps the separator length; fall back if it is unavailable.
    try:
        import shutil

        term_width = shutil.get_terminal_size().columns
    except Exception:
        term_width = 120

    def ansi(code, text):
        # Wrap text in an ANSI SGR sequence and reset afterwards.
        return f"\033[{code}m{text}\033[0m"

    def bold(text):
        return ansi(1, text)

    def green(text):
        return ansi(32, text)

    def blue(text):
        return ansi(34, text)

    def yellow(text):
        return ansi(33, text)

    def red(text):
        return ansi(31, text)

    def format_metrics(row):
        # Format target/xentropy for a row, tagging overall best values.
        target_str = float_format.format(row[HistoryColumns.TARGET])
        xentropy_str = float_format.format(row[HistoryColumns.XENTROPY])
        if highlight:
            if row[HistoryColumns.TARGET] == best_target_value:
                target_str = green(target_str + " [BEST]")
            if row[HistoryColumns.XENTROPY] == best_xentropy_value:
                xentropy_str = blue(xentropy_str + " [BEST]")
        return target_str, xentropy_str

    output = []

    for iter_num in iterations:
        output.append(bold(f"{'=' * 20} ITERATION {iter_num} {'=' * 20}"))

        iter_df = df[df[HistoryColumns.ITERATION] == iter_num]

        for pop_num in sorted(iter_df[HistoryColumns.POPULATION].unique()):
            pop_df = iter_df[iter_df[HistoryColumns.POPULATION] == pop_num]

            output.append(bold(f"{'-' * 15} Population {pop_num} {'-' * 15}"))

            # Parent (child 0) first.
            parent = pop_df[pop_df[HistoryColumns.CHILD] == 0]
            if not parent.empty:
                parent_row = parent.iloc[0]
                parent_target, parent_xentropy = format_metrics(parent_row)

                parent_info = f"PARENT (C0) | T: {parent_target} | X: {parent_xentropy}"

                # Negative parent coordinates mean "no recorded origin".
                origin_iter = parent_row[HistoryColumns.PARENT_ITERATION]
                if origin_iter >= 0:
                    origin_pop = parent_row[HistoryColumns.PARENT_POPULATION]
                    origin_child = parent_row[HistoryColumns.PARENT_CHILD]
                    parent_info += (
                        f" | From: Iter {origin_iter}.Pop {origin_pop}.C{origin_child}"
                    )
                    x_weight = parent_row[HistoryColumns.X_WEIGHT]
                    if x_weight >= 0:
                        parent_info += f" | X-Weight: {float_format.format(x_weight)}"

                output.append(yellow(parent_info))

                text = parent_row[HistoryColumns.TEXT]
                output.append(f"Text: {text}")

                if HistoryColumns.TOKEN_DISPLAY in parent_row:
                    output.append(f"Tokens: {parent_row[HistoryColumns.TOKEN_DISPLAY]}")

                output.append("-" * min(len(text) + 6, term_width))

            # Then the children, best target first.
            children = pop_df[pop_df[HistoryColumns.CHILD] > 0]
            if not children.empty:
                children = children.sort_values(
                    by=HistoryColumns.TARGET, ascending=False
                )

                if len(children) > max_rows:
                    output.append(
                        f"Showing top {max_rows} of {len(children)} children..."
                    )
                    children = children.head(max_rows)

                for _, child_row in children.iterrows():
                    child_num = child_row[HistoryColumns.CHILD]
                    child_target, child_xentropy = format_metrics(child_row)

                    child_info = f"CHILD (C{child_num}) | T: {child_target} | X: {child_xentropy}"
                    output.append(yellow(child_info))

                    # "None" is the placeholder used for rows without changes.
                    if child_row[HistoryColumns.CHANGES] != "None":
                        output.append(
                            red(f"Changes: {child_row[HistoryColumns.CHANGES]}")
                        )

                    text = child_row[HistoryColumns.TEXT]
                    output.append(f"Text: {text}")

                    output.append("-" * min(len(text) + 6, term_width))

            # Blank line between populations.
            output.append("")

        # Extra spacing between iterations.
        output.append("\n")

    result = "\n".join(output)

    if print_output:
        print(result)
        return None
    return result


def simple_print_df(
    df, print_output=True, ascending=True, sort_by=HistoryColumns.XENTROPY
):
    """
    Render each row of a history DataFrame as a compact three-line entry:
    entropy and target, the text, and a separator.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing history data
    print_output : bool, optional
        If True, print the result; if False, return the formatted string (default: True)
    ascending : bool, optional
        If True, sort values in ascending order (default: True)
    sort_by : str, optional
        Column to sort by (default: HistoryColumns.XENTROPY)

    Returns
    -------
    str or None
        If print_output is False, returns the formatted string; otherwise None
    """
    separator = "-" * 40
    ordered = df.sort_values(by=sort_by, ascending=ascending)

    lines = []
    for _, record in ordered.iterrows():
        entropy = record[HistoryColumns.XENTROPY]
        target = record[HistoryColumns.TARGET]
        text = record[HistoryColumns.TEXT]
        lines.extend(
            [
                f"Entropy: {entropy:.4f} | Target: {target:.4f}",
                f"Text: {text}",
                separator,
            ]
        )

    result = "\n".join(lines)

    if not print_output:
        return result

    print(result)
    return None


def plot_target_vs_entropy(
    dfs,
    ax=None,
    color_by=None,
    size_by=None,
    highlight_pareto=True,
    save_path=None,
    names=None,
    title="",
):
    """
    Plot target vs cross-entropy values from one or more history DataFrames.

    Each DataFrame is drawn as a scatter of HistoryColumns.XENTROPY (x) against
    HistoryColumns.TARGET (y). When requested, the Pareto frontier returned by
    ``get_pareto_frontier_df`` is overlaid as a dashed line with larger markers.

    Parameters
    ----------
    dfs : pd.DataFrame or list of pd.DataFrame
        Single DataFrame or list of DataFrames containing history data
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates new figure and axes
    color_by : str, optional
        Column name to use for point colors (default: None)
    size_by : str, optional
        Column name to use for point sizes (default: None). Values are scaled
        linearly into the range [20, 70]; if all values are equal (or the
        range is undefined) every point gets size 20.
    highlight_pareto : bool, optional
        Whether to highlight the Pareto frontier (default: True)
    save_path : str, optional
        If provided, save the plot to this path (default: None)
    names : list of str, optional
        Names for each DataFrame to show in legend (default: None). Missing
        names are filled in as "Dataset N".
    title : str, optional
        Title for the plot; no title is set when empty (default: "")

    Returns
    -------
    matplotlib.axes.Axes
        The axes containing the plot
    """
    import matplotlib.pyplot as plt

    # Convert single DataFrame to list
    if not isinstance(dfs, list):
        dfs = [dfs]

    # Create default names if not provided, or pad a too-short list
    if names is None:
        names = [f"Dataset {i + 1}" for i in range(len(dfs))]
    elif len(names) < len(dfs):
        # list(...) so that a tuple of names can be padded as well
        names = list(names) + [
            f"Dataset {i + 1}" for i in range(len(names), len(dfs))
        ]

    # Different markers for each dataset
    markers = ["o", "s", "^", "D", "v", "<", ">", "p", "*", "h"]

    # Create new figure if no axes provided
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 6))
    else:
        fig = ax.figure

    # Plot each DataFrame
    for i, (df, name) in enumerate(zip(dfs, names)):
        marker = markers[i % len(markers)]

        # Basic scatter plot
        scatter_kwargs = {
            "alpha": 0.6,
            "marker": marker,
            "label": name,
        }

        if color_by is not None:
            scatter_kwargs["c"] = df[color_by]

        if size_by is not None:
            # Scale sizes to reasonable range
            sizes = df[size_by]
            size_min = sizes.min()
            size_range = sizes.max() - size_min
            if size_range > 0:
                scatter_kwargs["s"] = 50 * (sizes - size_min) / size_range + 20
            else:
                # A constant column would give 0/0 = NaN sizes; use the base
                # size instead so the points remain visible.
                scatter_kwargs["s"] = 20

        ax.scatter(
            df[HistoryColumns.XENTROPY], df[HistoryColumns.TARGET], **scatter_kwargs
        )

        # Highlight Pareto frontier if requested
        if highlight_pareto:
            pareto_df = get_pareto_frontier_df(df)
            # Sort by x-entropy for proper line connection
            pareto_df = pareto_df.sort_values(by=HistoryColumns.XENTROPY)
            ax.plot(
                pareto_df[HistoryColumns.XENTROPY],
                pareto_df[HistoryColumns.TARGET],
                "--",
                color=f"C{i}",
                label="_nolegend_",
            )
            # Highlight Pareto points but don't add to legend
            ax.scatter(
                pareto_df[HistoryColumns.XENTROPY],
                pareto_df[HistoryColumns.TARGET],
                color=f"C{i}",
                s=100,
                alpha=0.6,
                marker=marker,
                label="_nolegend_",
            )

    # Add labels and title
    ax.set_xlabel("Cross-Entropy")
    ax.set_ylabel("Target Value")
    if title:
        ax.set_title(title)

    # Add legend
    ax.legend()

    # Add grid
    ax.grid(True, alpha=0.3)

    # Save plot if path provided
    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight", dpi=300)

    return ax


def plot_target_vs_entropy_interactive(
    dfs,
    color_by=None,
    size_by=None,
    highlight_pareto=True,
    title="Target vs Cross-Entropy",
    names=None,
    height=600,
    width=900,
    save_html=None,
):
    """
    Create an interactive Plotly plot of target vs cross-entropy values with text on hover.

    Each DataFrame is drawn as a marker trace of HistoryColumns.XENTROPY (x)
    against HistoryColumns.TARGET (y), with the HistoryColumns.TEXT value shown
    on hover. When requested, the Pareto frontier returned by
    ``get_pareto_frontier_df`` is overlaid as a dashed line plus larger markers.

    Parameters
    ----------
    dfs : pd.DataFrame or list of pd.DataFrame
        Single DataFrame or list of DataFrames containing history data
    color_by : str, optional
        Column name to use for point colors (default: None). When given, the
        main scatter is colored by that column with a color scale; the Pareto
        overlay still uses the dataset's fixed palette color.
    size_by : str, optional
        Column name to use for point sizes (default: None). Values are scaled
        linearly into the range [20, 70]; if all values are equal (or the
        range is undefined) every point gets size 20.
    highlight_pareto : bool, optional
        Whether to highlight the Pareto frontier (default: True)
    title : str, optional
        Title for the plot (default: "Target vs Cross-Entropy")
    names : list of str, optional
        Names for each DataFrame to show in legend (default: None). Missing
        names are filled in as "Dataset N".
    height : int, optional
        Height of the plot in pixels (default: 600)
    width : int, optional
        Width of the plot in pixels (default: 900)
    save_html : str, optional
        If provided, save the interactive plot to this HTML file path (default: None)

    Returns
    -------
    plotly.graph_objs._figure.Figure
        The interactive Plotly figure
    """
    # Convert single DataFrame to list
    if not isinstance(dfs, list):
        dfs = [dfs]

    # Create default names if not provided, or pad a too-short list
    if names is None:
        names = [f"Dataset {i + 1}" for i in range(len(dfs))]
    elif len(names) < len(dfs):
        # list(...) so that a tuple of names can be padded as well
        names = list(names) + [
            f"Dataset {i + 1}" for i in range(len(names), len(dfs))
        ]

    # Different markers for each dataset
    markers = [
        "circle",
        "square",
        "triangle-up",
        "diamond",
        "triangle-down",
        "triangle-left",
        "triangle-right",
        "pentagon",
        "star",
        "hexagon",
    ]

    colors = [
        "#1f77b4",  # blue
        "#ff7f0e",  # orange
        "#2ca02c",  # green
        "#d62728",  # red
        "#9467bd",  # purple
        "#8c564b",  # brown
        "#e377c2",  # pink
        "#7f7f7f",  # gray
        "#bcbd22",  # yellow-green
        "#17becf",  # cyan
    ]

    # Create figure
    fig = go.Figure()

    # Plot each DataFrame
    for i, (df, name) in enumerate(zip(dfs, names)):
        symbol = markers[i % len(markers)]
        # Fixed per-dataset color; also used for the Pareto overlay below
        dataset_color = colors[i % len(colors)]

        # Prepare hover text
        hover_text = [
            f"Dataset: {name}<br>"
            f"Text: {text}<br>"
            f"Target: {target:.4f}<br>"
            f"Cross-Entropy: {xentropy:.4f}<br>"
            for text, target, xentropy in zip(
                df[HistoryColumns.TEXT],
                df[HistoryColumns.TARGET],
                df[HistoryColumns.XENTROPY],
            )
        ]

        # Basic scatter plot
        scatter_kwargs = {
            "name": name,
            "marker_symbol": symbol,
            "mode": "markers",
            "hoverinfo": "text",
            "hovertext": hover_text,
            "marker": {"color": dataset_color},
        }

        if color_by is not None:
            scatter_kwargs["marker"] = {"color": df[color_by], "showscale": True}

        if size_by is not None:
            # Scale sizes to reasonable range
            sizes = df[size_by]
            size_min = sizes.min()
            size_range = sizes.max() - size_min
            if size_range > 0:
                marker_sizes = 20 + (50 * (sizes - size_min) / size_range)
            else:
                # A constant column would give 0/0 = NaN sizes; use the base
                # size instead so the points remain visible.
                marker_sizes = 20
            scatter_kwargs["marker"]["size"] = marker_sizes

        fig.add_trace(
            go.Scatter(
                x=df[HistoryColumns.XENTROPY],
                y=df[HistoryColumns.TARGET],
                **scatter_kwargs,
            )
        )

        # Highlight Pareto frontier if requested
        if highlight_pareto:
            pareto_df = get_pareto_frontier_df(df)
            # Sort by x-entropy for proper line connection
            pareto_df = pareto_df.sort_values(by=HistoryColumns.XENTROPY)

            # Prepare hover text for Pareto points
            pareto_hover_text = [
                f"PARETO POINT<br>"
                f"Dataset: {name}<br>"
                f"Text: {text}<br>"
                f"Target: {target:.4f}<br>"
                f"Cross-Entropy: {xentropy:.4f}<br>"
                for text, target, xentropy in zip(
                    pareto_df[HistoryColumns.TEXT],
                    pareto_df[HistoryColumns.TARGET],
                    pareto_df[HistoryColumns.XENTROPY],
                )
            ]

            # Always use the dataset's palette color for the overlay. Reading
            # the color back from scatter_kwargs would yield the whole
            # df[color_by] column when color_by is set, which is not a valid
            # line color and does not match the length of pareto_df.
            point_color = dataset_color

            # Add Pareto frontier line
            fig.add_trace(
                go.Scatter(
                    x=pareto_df[HistoryColumns.XENTROPY],
                    y=pareto_df[HistoryColumns.TARGET],
                    mode="lines",
                    line=dict(dash="dash", color=point_color),
                    name=f"{name} Pareto",
                    hoverinfo="skip",
                    showlegend=False,
                )
            )

            # Add Pareto points
            fig.add_trace(
                go.Scatter(
                    x=pareto_df[HistoryColumns.XENTROPY],
                    y=pareto_df[HistoryColumns.TARGET],
                    mode="markers",
                    marker=dict(
                        symbol=symbol,
                        size=15,
                        line=dict(width=2),
                        color=point_color,
                    ),
                    name=f"{name} Pareto Points",
                    hoverinfo="text",
                    hovertext=pareto_hover_text,
                    showlegend=False,
                )
            )

    # Update layout
    fig.update_layout(
        title=title,
        xaxis_title="Cross-Entropy",
        yaxis_title="Target Value",
        height=height,
        width=width,
        hovermode="closest",
        template="plotly_white",
    )

    # Save to HTML if a path is provided
    if save_html is not None:
        fig.write_html(save_html)
        print(f"Plot saved to {save_html}")

    return fig


def log_dataframe_to_wandb(df, summary_key="data_table"):
    """
    Convert a DataFrame to wandb-friendly values and log it as a wandb.Table.

    A copy of the DataFrame is made and each column is converted based on the
    type of its first value: numpy scalars/arrays and other objects exposing
    ``tolist`` become native Python values, and any other non-basic type is
    converted with ``str``. The result is logged with ``wandb.log`` under
    ``summary_key``.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame to log to wandb. If None or empty, nothing is logged.
    summary_key : str, optional
        Key under which the table is logged (default: "data_table")

    Returns
    -------
    pd.DataFrame
        The serialization-ready DataFrame, or ``df`` itself (unchanged) when
        it is None or empty
    """
    if df is None or len(df) == 0:
        print(f"Warning: No data to log for {summary_key}")
        return df

    print(f"Logging {len(df)} rows to wandb table: {summary_key}")

    def _numpy_to_native(value):
        # Prefer tolist(): it turns numpy scalars into Python scalars and
        # arrays into (nested) lists. item() alone raises on arrays holding
        # more than one element.
        if hasattr(value, "tolist"):
            return value.tolist()
        if hasattr(value, "item"):
            return value.item()
        return value

    def _tolist_if_possible(value):
        return value.tolist() if hasattr(value, "tolist") else value

    # Make a copy to avoid modifying the original
    results_df = df.copy()

    # Convert values to Python native types, column by column. The DataFrame
    # is known to be non-empty here, so iloc[0] is always valid.
    for col in results_df.columns:
        # The first value decides how the whole column is converted
        sample_value = results_df[col].iloc[0]

        # Handle numpy scalars and arrays
        if "numpy" in str(type(sample_value)):
            results_df[col] = results_df[col].apply(_numpy_to_native)

        # Handle objects with tolist method (like torch tensors)
        elif hasattr(sample_value, "tolist") and not isinstance(sample_value, list):
            results_df[col] = results_df[col].apply(_tolist_if_possible)

        # Convert any non-serializable types to strings
        elif not isinstance(
            sample_value, (int, float, str, bool, list, dict, type(None))
        ):
            results_df[col] = results_df[col].apply(str)

    # Log the processed dataframe directly
    table = wandb.Table(dataframe=results_df)
    wandb.log({summary_key: table})

    # Log an empty dict on the active run; the original intent was to force a
    # sync so the table is uploaded.
    wandb.run.log({})
    print(f"Successfully logged {summary_key} table to W&B")

    return results_df
