import json
import os
import uuid

import numpy as np
from IPython.display import HTML, display

from custom_dreamy.history import HistoryColumns, filter_to_ancestry_path


def visualize_population_tree(df, tokenizer, iteration=None, population_idx=None):
    """
    Print a text view of one iteration's parents and the mutations derived from them.

    Within the chosen iteration, rows with ``CHILD == 0`` are treated as the
    parent of their population slot and rows with ``CHILD > 0`` as that
    parent's mutations. Output is written to stdout; the function returns None.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame from history.to_dataframe()
    tokenizer : transformers.PreTrainedTokenizer
        Tokenizer for decoding text. Not used by this function; accepted so
        the signature matches the other visualizers.
    iteration : int, optional
        Specific iteration to visualize. If None, shows the last iteration
    population_idx : int, optional
        Specific population member to visualize. If None, shows all
    """
    from itertools import zip_longest

    # Fill value for the shorter token list, so positions that exist in only
    # one of parent/mutation are reported as changed instead of silently
    # dropped (plain zip() would truncate them).
    missing = object()

    # Default to the last iteration recorded in the history.
    if iteration is None:
        iteration = df[HistoryColumns.ITERATION].max()

    iter_df = df[df[HistoryColumns.ITERATION] == iteration]

    if population_idx is None:
        pop_indices = sorted(iter_df[HistoryColumns.POPULATION].unique())
    else:
        pop_indices = [population_idx]

    print(f"\n=== ITERATION {iteration} ===\n")

    for pop_idx in pop_indices:
        pop_rows = iter_df[iter_df[HistoryColumns.POPULATION] == pop_idx]

        # The parent of a population slot is its child-0 row.
        parent = pop_rows[pop_rows[HistoryColumns.CHILD] == 0]
        if parent.empty:
            continue
        parent_row = parent.iloc[0]

        parent_text = parent_row[HistoryColumns.TEXT]
        parent_target = parent_row[HistoryColumns.TARGET]
        parent_xentropy = parent_row[HistoryColumns.XENTROPY]
        parent_tokens = parent_row[HistoryColumns.TOKEN_DISPLAY]

        # Coordinates of the row this parent was derived from.
        parent_iter = parent_row[HistoryColumns.PARENT_ITERATION]
        parent_pop = parent_row[HistoryColumns.PARENT_POPULATION]
        parent_child = parent_row[HistoryColumns.PARENT_CHILD]

        # Track "source row found" separately from the text itself so a
        # found source with empty text is not reported as missing.
        source_found = False
        parent_source_text = ""
        # Only look up a source when all three coordinates are non-negative.
        if parent_iter >= 0 and parent_pop >= 0 and parent_child >= 0:
            source_entry = df[
                (df[HistoryColumns.ITERATION] == parent_iter)
                & (df[HistoryColumns.POPULATION] == parent_pop)
                & (df[HistoryColumns.CHILD] == parent_child)
            ]
            if not source_entry.empty:
                source_found = True
                parent_source_text = source_entry.iloc[0][HistoryColumns.TEXT]

        print(f"\nPOPULATION {pop_idx}:")
        print(f"PARENT: {parent_text}")

        if source_found:
            # int() so float-typed parent columns print as "3", not "3.0".
            print(
                f"DERIVED FROM: Iteration {int(parent_iter)}, "
                f"Population {int(parent_pop)}, Child {int(parent_child)}"
            )
            print(f"SOURCE TEXT: {parent_source_text}")
        elif iteration > 0:
            print("DERIVED FROM: [Source information not available]")
        else:
            print("DERIVED FROM: [Initial population]")

        print(f"TOKENS: {' '.join(parent_tokens)}")
        print(f"Target: {parent_target:.4f}, Cross-entropy: {parent_xentropy:.4f}")

        # Every child > 0 row in this slot is a mutation of the parent.
        mutations = pop_rows[pop_rows[HistoryColumns.CHILD] > 0]
        if mutations.empty:
            continue

        print("\nMUTATIONS FROM THIS PARENT:")
        for _, mutation in mutations.iterrows():
            mutation_tokens = mutation[HistoryColumns.TOKEN_DISPLAY]

            # Positions where parent and mutation disagree, including
            # positions present in only one of the two token lists.
            diff_indices = [
                i
                for i, (p_tok, m_tok) in enumerate(
                    zip_longest(parent_tokens, mutation_tokens, fillvalue=missing)
                )
                if p_tok != m_tok
            ]
            changed = set(diff_indices)  # O(1) membership for highlighting

            highlighted_tokens = [
                f"[{tok}]" if i in changed else tok
                for i, tok in enumerate(mutation_tokens)
            ]

            print(
                f"  Mutation {int(mutation[HistoryColumns.CHILD])}: {mutation[HistoryColumns.TEXT]}"
            )
            print(f"  TOKENS: {' '.join(highlighted_tokens)}")
            print(f"  Changed positions: {diff_indices}")
            print(
                f"  Target: {mutation[HistoryColumns.TARGET]:.4f}, Cross-entropy: {mutation[HistoryColumns.XENTROPY]:.4f}"
            )
            print()


def visualize_ancestry_path(df, tokenizer, subset_conditions, output_path=None):
    """
    Render the interactive tree restricted to the ancestry of matching nodes.

    The history is first narrowed with ``filter_to_ancestry_path`` and the
    result is handed to ``visualize_population_tree_interactive``.

    Parameters
    ----------
    df : pandas.DataFrame
        The full history dataframe
    tokenizer : transformers.PreTrainedTokenizer
        Tokenizer for decoding text
    subset_conditions : dict
        Dictionary of conditions to filter the subset
    output_path : str, optional
        Path to save the HTML visualization, if None uses default

    Returns
    -------
    Whatever ``visualize_population_tree_interactive`` returns, or None
    (after printing a notice) when the filtered ancestry is empty.
    """
    ancestry_df = filter_to_ancestry_path(df, subset_conditions)

    if not ancestry_df.empty:
        return visualize_population_tree_interactive(
            ancestry_df, tokenizer, output_path=output_path
        )

    print("No ancestry path to visualize")
    return None


class _NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python types."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def _node_id(iteration, population, child):
    """Return the graph node id for an (iteration, population, child) triple.

    Each component is cast to ``int`` so ids built from float-typed columns
    (e.g. ``3.0``) match ids built from integer columns (``3``).
    """
    return f"i{int(iteration)}_p{int(population)}_c{int(child)}"


def _build_graph(df):
    """Build the ``(nodes, links)`` lists consumed by the HTML/JS template.

    Nodes are one dict per row of *df*. Links are of two kinds:
    ``"mutation"`` links (child-0 parent -> child > 0 row in the same
    iteration/population) and ``"evolution"`` links (source row -> child-0
    parent of a later iteration). Evolution links are added only when all
    three parent coordinates are non-negative.
    """
    nodes = []
    links = []

    for _, row in df.iterrows():
        iteration = int(row[HistoryColumns.ITERATION])
        population = int(row[HistoryColumns.POPULATION])
        child = int(row[HistoryColumns.CHILD])
        node_id = _node_id(iteration, population, child)

        # Native Python types only, so the JSON dump needs no special cases.
        node = {
            "id": node_id,
            "iteration": iteration,
            "population": population,
            "child": child,
            "text": row[HistoryColumns.TEXT],
            "target": float(row[HistoryColumns.TARGET]),
            "xentropy": float(row[HistoryColumns.XENTROPY]),
            "tokens": row[HistoryColumns.TOKEN_DISPLAY],
            "type": "parent" if child == 0 else "mutation",
        }

        if child > 0:
            node["changes"] = row[HistoryColumns.CHANGES]
        if child == 0:
            node["x_weight"] = float(row[HistoryColumns.X_WEIGHT])

        nodes.append(node)

        # Parent (child 0) -> mutation link within the same iteration/population.
        if child > 0:
            links.append(
                {
                    "source": _node_id(iteration, population, 0),
                    "target": node_id,
                    "type": "mutation",
                }
            )

        # Source row -> this parent, for parents after the first iteration.
        if child == 0 and iteration > 0:
            parent_iter = row[HistoryColumns.PARENT_ITERATION]
            parent_pop = row[HistoryColumns.PARENT_POPULATION]
            parent_child = row[HistoryColumns.PARENT_CHILD]

            if parent_iter >= 0 and parent_pop >= 0 and parent_child >= 0:
                links.append(
                    {
                        "source": _node_id(parent_iter, parent_pop, parent_child),
                        "target": node_id,
                        "type": "evolution",
                    }
                )

    return nodes, links


def _to_embedded_json(obj):
    """Serialize *obj* to JSON that is safe to inline in an HTML document.

    ``<``, ``>`` and ``&`` only ever occur inside JSON strings, so they are
    re-emitted as the equivalent ``\\u`` escapes. The parsed value is
    unchanged, but a text such as ``</script>`` can no longer terminate an
    enclosing script element. The template is assumed to inline the data
    inside a ``<script>``; confirm against population_tree_template.html.
    """
    encoded = json.dumps(obj, cls=_NumpyEncoder)
    return (
        encoded.replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )


def _render_html(graph_data, config, asset_dir):
    """Fill the HTML template in *asset_dir* with CSS, JS, graph data and config.

    Raises OSError if any of the template/style/script files cannot be read.
    """
    import re

    def read_asset(name):
        with open(os.path.join(asset_dir, name), "r", encoding="utf-8") as f:
            return f.read()

    template_content = read_asset("population_tree_template.html")
    replacements = {
        "css_content": read_asset("population_tree_style.css"),
        "data": _to_embedded_json(graph_data),
        "config": _to_embedded_json(config),
        "js_content": read_asset("population_tree_script.js"),
    }

    # Single pass over the template only. Text substituted in for one
    # placeholder is never re-scanned, so a generated text containing e.g.
    # "{{js_content}}" stays literal. A callable replacement also keeps
    # backslashes in the JS/CSS from being treated as regex escapes.
    return re.sub(
        r"\{\{(css_content|data|config|js_content)\}\}",
        lambda match: replacements[match.group(1)],
        template_content,
    )


def visualize_population_tree_interactive(
    df, tokenizer, max_iterations=None, output_path=None
):
    """
    Create an interactive visualization of the population tree showing evolution across iterations.

    The HTML is written to *output_path* and displayed with IPython's
    ``display``. The text view for the last iteration is then printed via
    ``visualize_population_tree``. The caller's DataFrame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame from history.to_dataframe()
    tokenizer : transformers.PreTrainedTokenizer
        Tokenizer for decoding text
    max_iterations : int, optional
        Maximum number of iterations to visualize. If None, shows all iterations.
    output_path : str, optional
        Path to save the HTML visualization. Defaults to
        ``population_tree_visualization.html`` next to this module.

    Returns
    -------
    The return value of ``visualize_population_tree`` (it only prints), or
    None after printing a notice when there is no data to visualize.
    """
    # Keep only iterations below the limit, if one was given.
    if max_iterations is not None:
        df = df[df[HistoryColumns.ITERATION] < max_iterations]

    # Previously max() on an empty sequence raised ValueError here.
    if df.empty:
        print("No population data to visualize")
        return None

    iterations = sorted(df[HistoryColumns.ITERATION].unique())

    nodes, links = _build_graph(df)
    graph_data = {"nodes": nodes, "links": links}

    config = {
        "maxIteration": int(iterations[-1]),
        # Unique ID for this visualization (for internal tracking).
        "vizId": f"viz_{uuid.uuid4().hex[:8]}",
    }

    # Template, CSS and JS live alongside this module.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    html_content = _render_html(graph_data, config, current_dir)

    if output_path is None:
        output_path = os.path.join(current_dir, "population_tree_visualization.html")

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(html_content)

    # Display in notebook if possible.
    display(HTML(html_content))

    print(f"Visualization saved to {output_path}")
    print(f"Interactive visualization created with {len(iterations)} iterations.")
    print(
        "If you're not seeing the interactive visualization, you may be in a non-notebook environment."
    )
    # Also print the text view for non-notebook environments.
    return visualize_population_tree(df, tokenizer)
