import os
from typing import List, Optional

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import umap
from IPython.display import HTML, display
from mpl_toolkits.axes_grid1 import make_axes_locatable
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D

# Assuming .atlas_utils is a local module
# from .atlas_utils import StatsCollector


def _activation_to_rgba(a: float, alpha_floor=0.15) -> str:
    """Return a CSS ``rgba(...)`` string for pure red whose alpha encodes ``a``.

    The alpha channel is the larger of ``a`` and ``alpha_floor``, rendered with
    three decimals, so low values still produce a faint tint.
    """
    alpha = a if a > alpha_floor else alpha_floor
    return "rgba(255, 0, 0, {:.3f})".format(alpha)


def html_escape(s: str) -> str:
    """Replace ``&``, ``<`` and ``>`` in ``s`` with their character entities.

    Quote characters are left untouched.
    """
    entity_table = str.maketrans({"&": "&amp;", "<": "&lt;", ">": "&gt;"})
    return s.translate(entity_table)


def save_neuron_svg(
    neuron_id: int,
    stats: "StatsCollector",
    sae_kit,
    layer_id: int,
    window: int = 8,
    limit: int = 20,
    value_fmt: str = "{:.2f}",
    threshold: float = 6e-3,
    output_dir: str = "neuron_svgs",
):
    """
    Generate and save an SVG visualization of top activating sequences for a neuron.

    Every entry is re-encoded through ``sae_kit`` and drawn as one row of token
    chips. A row starts ``window // 2`` tokens before the token with the largest
    masked activation (clamped to the sequence start), shows at most ``window``
    tokens and stops after the first ``<eos>``. Chip opacity is the activation
    min-max normalized within the row; the peak token gets a black outline.

    Args:
        neuron_id: Index of the neuron in the encoded activations.
        stats: Object exposing ``top_sequences_for(neuron_id)``, which must return
            a sliceable sequence of ``(best_score, count, tokens)`` triples.
        sae_kit: Object exposing ``tokenizer.token_to_id(token)``,
            ``get_encoded(ids, pos, layer_id=...)`` and ``mask_fn(ids)``.
        layer_id: Forwarded to ``sae_kit.get_encoded`` and used in the file name.
        window: Maximum number of tokens drawn per row.
        limit: Maximum number of entries drawn.
        value_fmt: ``str.format`` template used for the raw activation labels.
        threshold: A label is drawn only where ``abs(activation) > threshold``.
        output_dir: Directory (created when missing) that receives
            ``neuron_{neuron_id}_layer_{layer_id}.svg``.

    Returns:
        None. When the neuron has no entries a message is printed and no file
        is written.
    """
    raw_entries = stats.top_sequences_for(neuron_id)[:limit]
    if not raw_entries:
        print(f"Neuron {neuron_id} never fired – nothing to show.")
        return

    # Layout constants.
    content_height = 30  # vertical space reserved for the tokens and text
    token_height = 22  # height of the token background rectangle
    sequence_gap = 10  # vertical margin between two rows
    total_line_height = content_height + sequence_gap
    token_padding = 12  # horizontal padding inside a token chip
    token_margin = 2  # horizontal gap between chips
    font_size = 12
    label_font_size = 6
    left_margin = 10  # x offset of every row group
    min_svg_width = 800  # canvas never gets narrower than this

    # Geometry that does not depend on the row or the token.
    rect_y = (content_height - token_height) / 2
    text_y = content_height / 2
    header_y = content_height / 2
    half = window // 2

    row_lines: List[str] = []
    widest_row = 0.0
    y_cursor = 40

    for rank, (_best_score, count, tokens) in enumerate(raw_entries, 1):
        ids = jnp.array([[sae_kit.tokenizer.token_to_id(t) for t in tokens]])
        pos = jnp.arange(ids.shape[1])[None, :]
        acts1d = np.asarray(
            sae_kit.get_encoded(ids, pos, layer_id=layer_id)[0, :, neuron_id]
        )
        mask1d = np.asarray(sae_kit.mask_fn(ids)[0])
        acts1d = acts1d * mask1d

        max_i = int(acts1d.argmax())
        start = max(0, max_i - half)
        end = min(len(tokens), start + window)

        # Collect the window, keeping the first <eos> and dropping what follows.
        tokens_w = []
        display_acts = []
        for i in range(start, end):
            tokens_w.append(tokens[i])
            display_acts.append(acts1d[i])
            if tokens[i] == "<eos>":
                break

        acts_w = np.array(display_acts)
        ptp = np.ptp(acts_w)
        # A flat window would divide by zero; render it with zero intensity.
        norm_acts = np.zeros_like(acts_w) if ptp == 0 else (acts_w - acts_w.min()) / ptp

        row_lines.append(f'<g transform="translate({left_margin}, {y_cursor})">')

        header_text = f'<text x="0" y="{header_y}" dominant-baseline="middle" font-family="sans-serif" font-size="12px">'
        header_text += f'<tspan font-weight="bold">#{rank}</tspan>'
        if count > 1:
            header_text += f'<tspan fill="#888"> ×{count}</tspan>'
        header_text += "</text>"
        row_lines.append(header_text)

        x_cursor = 60

        for j, (tok, raw_a, norm_a) in enumerate(zip(tokens_w, acts_w, norm_acts)):
            # Rough monospace estimate: 0.65 em per character plus padding.
            token_width = len(tok) * (font_size * 0.65) + 2 * token_padding
            center_x = x_cursor + token_width / 2

            bg_color = _activation_to_rgba(float(norm_a))
            stroke = 'stroke="black" stroke-width="1.5"' if (start + j) == max_i else ""
            row_lines.append(
                f'<rect x="{x_cursor}" y="{rect_y}" width="{token_width}" height="{token_height}" rx="5" fill="{bg_color}" {stroke} />'
            )
            row_lines.append(
                f'<text x="{center_x}" y="{text_y}" class="token-text" fill="black" text-anchor="middle" dominant-baseline="middle">{html_escape(tok)}</text>'
            )

            # Raw activation value above the chip.
            if abs(raw_a) > threshold:
                row_lines.append(
                    f'<text x="{center_x}" y="{rect_y - 4}" class="label-text">{value_fmt.format(float(raw_a))}</text>'
                )

            x_cursor += token_width + token_margin

        row_lines.append("</g>")
        widest_row = max(widest_row, x_cursor)
        y_cursor += total_line_height

    # Grow the canvas when a row would otherwise be clipped at the right edge.
    svg_height = total_line_height * len(raw_entries) + 40
    svg_width = max(min_svg_width, int(left_margin + widest_row + left_margin) + 1)

    # Every fragment containing CSS braces is an f-string, so "{{" / "}}" are
    # escapes that emit a single brace each.
    style_block = (
        "<style>"
        f".token-text {{ font-family: monospace; font-size: {font_size}px; font-weight: 600; }} "
        f".label-text {{ font-family: sans-serif; font-size: {label_font_size}px; text-anchor: middle; fill: #555; }}"
        "</style>"
    )

    svg_lines: List[str] = [
        f'<svg width="{svg_width}" height="{svg_height}" xmlns="http://www.w3.org/2000/svg">',
        style_block,
        f'<text x="10" y="20" font-family="sans-serif" font-size="14px" fill="black">Neuron {neuron_id} - top {len(raw_entries)} unique max-prefix patterns</text>',
    ]
    svg_lines.extend(row_lines)
    svg_lines.append("</svg>")

    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, f"neuron_{neuron_id}_layer_{layer_id}.svg")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("".join(svg_lines))

    print(f"Neuron SVG for Neuron {neuron_id} (Layer {layer_id}) saved to {file_path}")


def save_neuron_display(
    neuron_id: int,
    stats: "StatsCollector",
    sae_kit,
    layer_id: int,
    window: int = 8,
    limit: int = 20,
    value_fmt: str = "{:.2f}",
    threshold: float = 6e-3,
    output_dir: str = "neuron_displays",
):
    """
    Save a standalone HTML page showing a neuron's top activating sequences.

    Every entry is re-encoded through ``sae_kit`` and rendered as one line of
    token chips. A line starts ``window // 2`` tokens before the token with the
    largest masked activation (clamped to the sequence start), shows at most
    ``window`` tokens and stops after the first ``<eos>``. Chip opacity is the
    activation min-max normalized within the line; the peak token is bordered.

    Args:
        neuron_id: Index of the neuron in the encoded activations.
        stats: Object exposing ``top_sequences_for(neuron_id)``, which must return
            a sliceable sequence of ``(best_score, count, tokens)`` triples.
        sae_kit: Object exposing ``tokenizer.token_to_id(token)``,
            ``get_encoded(ids, pos, layer_id=...)`` and ``mask_fn(ids)``.
        layer_id: Forwarded to ``sae_kit.get_encoded`` and used in the file name.
        window: Maximum number of tokens shown per line.
        limit: Maximum number of entries shown.
        value_fmt: ``str.format`` template used for the raw activation labels.
        threshold: A label is shown only where ``abs(activation) > threshold``.
        output_dir: Directory (created when missing) that receives
            ``neuron_{neuron_id}_layer_{layer_id}.html``.

    Returns:
        None. When the neuron has no entries a message is printed and no file
        is written.
    """
    raw_entries = stats.top_sequences_for(neuron_id)[:limit]
    if not raw_entries:
        print(f"Neuron {neuron_id} never fired – nothing to show.")
        return

    html_lines: List[str] = []

    # Light page styling. Label and border colours are dark so they stay visible
    # on the white background, and the charset matches the encoding used when
    # the file is written below.
    html_lines.append("""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
  body { background-color: #FFFFFF; color: #1a1a1a; font-family: sans-serif; }
  .tok { position:relative; padding:2px 8px; margin:1px;
         border-radius:8px; font-family:monospace; display:inline-block;
         color: #1a1a1a; /* Dark text for tokens on colored background */
         background-color: rgba(0, 255, 0, 0.15); /* Default green for low activation */
  }
  .tok-label { position:absolute; top:-1.15em; left:50%;
               transform:translateX(-50%); font-size:0.65em; color:#555;
               pointer-events:none; user-select:none; white-space:nowrap; }
  .line { margin-bottom:8px; }
  .max  { border:2px solid black; /* Dark border for max activation */ }
  .count-badge { opacity:.6; color: #a0a0a0; }
  h3 { color: #1a1a1a; }
</style>
</head>
<body>
""")
    html_lines.append(
        f'<h3 style="color: black;">Neuron {neuron_id} - top {len(raw_entries)} unique max-prefix patterns</h3>'
    )

    half = window // 2

    for rank, (_best_score, count, tokens) in enumerate(raw_entries, 1):
        ids = jnp.array([[sae_kit.tokenizer.token_to_id(t) for t in tokens]])
        pos = jnp.arange(ids.shape[1])[None, :]
        acts1d = np.asarray(
            sae_kit.get_encoded(ids, pos, layer_id=layer_id)[0, :, neuron_id]
        )
        mask1d = np.asarray(sae_kit.mask_fn(ids)[0])
        acts1d = acts1d * mask1d

        max_i = int(acts1d.argmax())
        start = max(0, max_i - half)
        end = min(len(tokens), start + window)

        # Collect the window, keeping the first <eos> and dropping what follows.
        tokens_w = []
        display_acts = []
        for i in range(start, end):
            tokens_w.append(tokens[i])
            display_acts.append(acts1d[i])
            if tokens[i] == "<eos>":
                break

        acts_w = np.array(display_acts)
        ptp = np.ptp(acts_w)
        # A flat window would divide by zero; render it with zero intensity.
        norm_acts = np.zeros_like(acts_w) if ptp == 0 else (acts_w - acts_w.min()) / ptp

        # Header with optional count badge.
        header = f"<b>#{rank}</b>"
        if count > 1:
            header += f" <span class='count-badge'>×{count}</span>"
        line_parts = [f"<div class='line'>{header}&nbsp;"]

        # Token chips.
        for j, (tok, raw_a, norm_a) in enumerate(zip(tokens_w, acts_w, norm_acts)):
            cls = "tok" + (" max" if (start + j) == max_i else "")
            bg = _activation_to_rgba(float(norm_a))
            esc = html_escape(tok)

            chip = f"<span class='{cls}' style='background:{bg}'>"
            if abs(raw_a) > threshold:
                chip += (
                    f"<span class='tok-label'>{value_fmt.format(float(raw_a))}</span>"
                )
            chip += f"{esc}</span>"
            line_parts.append(chip)

        line_parts.append("</div>")
        html_lines.append("".join(line_parts))

    html_lines.append("</body></html>")

    os.makedirs(output_dir, exist_ok=True)

    file_path = os.path.join(output_dir, f"neuron_{neuron_id}_layer_{layer_id}.html")
    # Explicit UTF-8: the page contains "×" and arbitrary token text, and the
    # platform default encoding is not guaranteed to handle either.
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("".join(html_lines))

    print(
        f"Neuron display for Neuron {neuron_id} (Layer {layer_id}) saved to {file_path}"
    )


def display_neuron(
    neuron_id: int,
    stats: "StatsCollector",
    sae_kit,
    layer_id: int,
    window: int = 8,
    limit: int = 20,
    value_fmt: str = "{:.2f}",
    threshold: float = 6e-3,
    cut_eos: bool = False,
):
    """
    Show a neuron's top activating sequences inline via ``IPython.display``.

    Every entry from ``stats.top_sequences_for(neuron_id)`` (at most ``limit``)
    is re-encoded through ``sae_kit`` and rendered as one line of token chips.
    A line starts ``window // 2`` tokens before the token with the largest
    masked activation (clamped to the sequence start) and shows at most
    ``window`` tokens. With ``cut_eos`` set, the line also ends right after the
    first ``<eos>`` inside that window.

    Chip opacity is the activation min-max normalized within the line, the peak
    token carries the ``max`` CSS class, and raw values whose magnitude exceeds
    ``threshold`` are printed above the chip using ``value_fmt``.
    """
    top_patterns = stats.top_sequences_for(neuron_id)[:limit]
    if not top_patterns:
        display(HTML(f"<b>Neuron {neuron_id} never fired – nothing to show.</b>"))
        return

    stylesheet = """
<style>
  .tok { position:relative; padding:2px 8px; margin:1px;
         border-radius:8px; font-family:monospace; display:inline-block; }
  .tok-label { position:absolute; top:-1.15em; left:50%;
               transform:translateX(-50%); font-size:0.65em; color:#fff;
               pointer-events:none; user-select:none; white-space:nowrap; }
  .line { margin-bottom:8px; }
  .max  { border:2px solid black; }
  .count-badge { opacity:.6; }
</style>"""
    fragments: List[str] = [
        f"<h3>Neuron {neuron_id} - top {len(top_patterns)} unique max-prefix patterns</h3>",
        stylesheet,
    ]
    lead = window // 2

    for rank, (_score, count, tokens) in enumerate(top_patterns, start=1):
        ids = jnp.array([[sae_kit.tokenizer.token_to_id(t) for t in tokens]])
        pos = jnp.arange(ids.shape[1])[None, :]
        encoded = np.asarray(
            sae_kit.get_encoded(ids, pos, layer_id=layer_id)[0, :, neuron_id]
        )
        acts1d = encoded * np.asarray(sae_kit.mask_fn(ids)[0])

        peak = int(acts1d.argmax())
        lo = max(0, peak - lead)
        hi = min(len(tokens), lo + window)

        # Optionally shrink the window so it ends on the first <eos>.
        if cut_eos:
            for k in range(lo, hi):
                if tokens[k] == "<eos>":
                    hi = k + 1
                    break

        shown_tokens = [tokens[k] for k in range(lo, hi)]
        shown_acts = np.array([acts1d[k] for k in range(lo, hi)])

        spread = np.ptp(shown_acts)
        if spread == 0:
            scaled = np.zeros_like(shown_acts)
        else:
            scaled = (shown_acts - shown_acts.min()) / spread

        # Line header: rank, then the occurrence count when it is above one.
        pieces = [f"<div class='line'><b>#{rank}</b>"]
        if count > 1:
            pieces.append(f" <span class='count-badge'>×{count}</span>")
        pieces.append("&nbsp;")

        for offset, (tok, raw_a, norm_a) in enumerate(
            zip(shown_tokens, shown_acts, scaled)
        ):
            css_class = "tok max" if lo + offset == peak else "tok"
            background = _activation_to_rgba(float(norm_a))
            escaped = html_escape(tok)

            pieces.append(f"<span class='{css_class}' style='background:{background}'>")
            if abs(raw_a) > threshold:
                pieces.append(
                    f"<span class='tok-label'>{value_fmt.format(float(raw_a))}</span>"
                )
            pieces.append(f"{escaped}</span>")

        pieces.append("</div>")
        fragments.append("".join(pieces))

    display(HTML("".join(fragments)))


def save_activating_molecules(
    neuron_id: int,
    stats: "StatsCollector",
    layer_id: int,
    limit: int = 20,
    output_dir: str = "activating_molecules",
):
    """
    Draw the top activating sequences of a neuron as PNG molecule images.

    Args:
        neuron_id: The neuron whose sequences are fetched from ``stats``.
        stats: Object exposing ``top_sequences_for(neuron_id)``, which must return
            a sliceable sequence of ``(best_score, count, tokens)`` triples.
        layer_id: Used in the log messages and in the output directory name.
        limit: Maximum number of sequences handed to the drawing function.
        output_dir: Root directory; images go into the subdirectory
            ``neuron_{neuron_id}_layer_{layer_id}``, created when missing.

    Returns:
        None. When the neuron has no entries a message is printed and nothing
        is created on disk.
    """
    top_entries = stats.top_sequences_for(neuron_id)[:limit]

    # Nothing recorded for this neuron: report and bail out early.
    if not top_entries:
        print(
            f"Neuron {neuron_id} (Layer {layer_id}) never fired – no sequences to draw."
        )
        return

    target_dir = os.path.join(output_dir, f"neuron_{neuron_id}_layer_{layer_id}")
    os.makedirs(target_dir, exist_ok=True)

    print(
        f"Drawing top {len(top_entries)} activating sequences for Neuron {neuron_id}..."
    )

    # One image per entry; the file name encodes the rank and the best score.
    for position, (score, _count, sequence) in enumerate(top_entries, start=1):
        draw_smiles_to_png(
            tokens=sequence,
            filename=f"rank_{position:02d}_score_{score:.4f}",
            output_dir=target_dir,
        )

    print(f"\nFinished. Images saved in: {target_dir}")


def draw_smiles_to_png(
    tokens: List[str],
    filename: str,
    output_dir: str = "molecule_images",
    mol_size: tuple = (400, 400),
    kekulize: bool = True,
    wedge_bonds: bool = True,
    add_hs: bool = False,
    draw_options: Optional[rdMolDraw2D.MolDrawOptions] = None,
) -> Optional[str]:
    """
    Draw a molecule from a list of tokens (SMILES string) and save it as a PNG.

    The SMILES string is the concatenation of ``tokens`` with every ``<bos>``
    removed, truncated at the first ``<eos>``.

    Args:
        tokens: Token strings that concatenate to a SMILES string.
        filename: Output file name without the ``.png`` extension.
        output_dir: Directory that receives the image; created when missing.
        mol_size: ``(width, height)`` unpacked into ``MolDraw2DCairo``.
        kekulize: When true, ``Chem.Kekulize`` is attempted on the molecule
            before drawing; a failure there is ignored.
        wedge_bonds: Not read by this function; kept so existing callers that
            pass it keep working.
        add_hs: When true, the molecule is passed through ``Chem.AddHs``.
        draw_options: Options applied to the drawer. When omitted, the drawer's
            own options are adjusted in place (background cleared, bond line
            width 2, highlights not filled, stereo annotations on).

    Returns:
        The path of the written PNG, or ``None`` when the token list yields no
        SMILES text, the SMILES cannot be parsed, or drawing raises.
    """
    smiles_string = "".join(tok for tok in tokens if tok != "<bos>").split("<eos>")[0]

    if not smiles_string:
        print(
            f"Warning: No valid SMILES string found for '{filename}' after filtering tokens."
        )
        return None

    try:
        mol = Chem.MolFromSmiles(smiles_string)
        if mol is None:
            print(
                f"Error: Could not parse SMILES '{smiles_string}' for file '{filename}'"
            )
            return None

        # Kekulization is performed on the molecule object itself before drawing.
        if kekulize:
            try:
                Chem.Kekulize(mol)
            except Exception:
                # Best effort: a molecule that cannot be kekulized is drawn as
                # parsed. Only Exception is caught so KeyboardInterrupt and
                # SystemExit still propagate.
                pass

        if add_hs:
            mol = Chem.AddHs(mol)

        AllChem.Compute2DCoords(mol)

        drawer = rdMolDraw2D.MolDraw2DCairo(*mol_size)
        if draw_options:
            drawer.SetDrawOptions(draw_options)
        else:
            opts = drawer.drawOptions()
            opts.clearBackground = True
            opts.bondLineWidth = 2
            opts.fillHighlights = False
            opts.addStereoAnnotation = True

        drawer.DrawMolecule(mol)
        drawer.FinishDrawing()

        os.makedirs(output_dir, exist_ok=True)
        file_path = os.path.join(output_dir, f"{filename}.png")
        drawer.WriteDrawingText(file_path)

        return file_path

    except Exception as e:
        # Per-molecule failures are reported and skipped so a batch of drawings
        # is not aborted by one bad sequence.
        print(f"An error occurred while drawing SMILES '{smiles_string}': {e}")
        return None


def plot_selectivity(collectors, metric, save_to=None, x_label="Selectivity"):
    """Overlay one kernel-density curve per layer on a single figure.

    ``collectors`` is iterated with ``.items()``; each value is passed to
    ``metric`` and the result is handed to ``seaborn.kdeplot`` under the label
    ``L{layer_id}``. The figure is written to ``save_to`` when that is truthy,
    otherwise it is shown with ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(7, 5))

    curve_style = {"fill": True, "alpha": 0.2, "linewidth": 1.5}
    for layer_id, collector_obj in collectors.items():
        sns.kdeplot(metric(collector_obj), label=f"L{layer_id}", ax=ax, **curve_style)

    ax.set_title("Kernel Density Estimate of Neuron Selectivity")
    ax.set_xlabel(x_label)
    ax.set_ylabel("Density")
    ax.legend(title="Layer ID", loc="upper right")
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    fig.tight_layout()

    if not save_to:
        plt.show()
        return
    fig.savefig(save_to, bbox_inches="tight")


def umap_selectivity(matrix, collector, metric, random_state=2002, save_to=None):
    """Scatter a 2-D UMAP embedding of ``matrix`` coloured by ``metric(collector)``.

    The embedding uses cosine distance with 15 neighbours and ``min_dist`` 0.1.
    Colours follow the ``magma`` colormap with the scale fixed to ``[0, 1]``.
    The figure is written to ``save_to`` when that is truthy, otherwise shown.

    Note: this sets the global ``font.size`` entry of ``plt.rcParams`` to 12.
    """
    sel_scores = metric(collector)

    reducer = umap.UMAP(
        n_neighbors=15,
        min_dist=0.1,
        n_components=2,
        metric="cosine",
        random_state=random_state,
        verbose=False,
    )
    embedding = reducer.fit_transform(matrix)
    print(f"UMAP embedding computed with shape: {embedding.shape}")

    plt.rcParams["font.size"] = 12
    fig, ax = plt.subplots(figsize=(7, 6))

    xs = embedding[:, 0]
    ys = embedding[:, 1]
    points = ax.scatter(
        xs, ys, c=sel_scores, cmap="magma", s=20, alpha=0.8, vmin=0.0, vmax=1.0
    )

    # Axes cosmetics.
    ax.set_title(r"UMAP of SAE Decoder Weights ($W_{dec}$)", pad=15, fontsize=14)
    ax.set_xlabel("UMAP-1", fontsize=12)
    ax.set_ylabel("UMAP-2", fontsize=12)
    ax.tick_params(direction="out", labelsize=10)
    ax.grid(True, linestyle="--", alpha=0.5)

    # Colorbar in its own narrow axes to the right of the scatter plot.
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.15)
    colorbar = fig.colorbar(points, cax=colorbar_ax)
    colorbar.set_label(
        r"Token Selectivity Score ($S_d$)", rotation=270, labelpad=20, fontsize=12
    )
    colorbar.ax.tick_params(direction="out", labelsize=10)

    fig.tight_layout()
    if not save_to:
        plt.show()
        return
    plt.savefig(save_to, bbox_inches="tight")


def display_clusters(coords_umap, labels):
    """Show a scatter plot of 2-D coordinates coloured by cluster label.

    The first two columns of ``coords_umap`` are the x and y positions and
    ``labels`` supplies the colour values (``tab20`` colormap). The figure is
    always shown with ``plt.show()``.

    Note: this sets the global ``font.size`` entry of ``plt.rcParams`` to 12.
    """
    plt.rcParams["font.size"] = 12
    fig, ax = plt.subplots(figsize=(7, 6))

    xs = coords_umap[:, 0]
    ys = coords_umap[:, 1]
    points = ax.scatter(xs, ys, c=labels, cmap="tab20", s=20, alpha=0.3)

    ax.set_title(r"UMAP of $W_{dec}$ coloured by HDBSCAN clusters")
    ax.set_xlabel("UMAP-1")
    ax.set_ylabel("UMAP-2")

    # Colorbar in its own narrow axes to the right of the scatter plot.
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.15)
    colorbar = fig.colorbar(points, cax=colorbar_ax)
    colorbar.set_label("Cluster ID (-1 = noise)")

    fig.tight_layout()
    plt.show()
