import numpy as np
import matplotlib
from IPython.display import display, HTML
import torch
import html

import numpy as np
def format_and_print_cluster(
    result, output_result, model, cluster_id,
    gap_symbol="_", collapse_symbol="[...]",
    print_input=True, print_output=True
):
    """
    Print one line per sample of ``result`` whose label equals ``cluster_id``.

    Args:
        result: mapping with parallel sequences under 'patterns', 'outputs'
            and 'labels'. Only 'patterns' and 'labels' are read; 'outputs'
            still takes part in the zip, so the shortest of the three bounds
            the iteration.
        output_result: mapping that may contain 'output_patterns', indexed
            by the same sample index as ``result['patterns']``.
        model: object exposing ``model.tokenizer.decode(list_of_ids)``.
        cluster_id: label value selecting which samples are printed.
        gap_symbol: token printed for each zero entry in a short gap.
        collapse_symbol: token printed in place of a run of more than three
            zeros in the input pattern.
        print_input: when True, print the input pattern; the last nonzero
            entry is wrapped as ``<<<tok>>>``.
        print_output: when True, print the sample's output pattern (right
            after the highlighted input token, or on its own when
            ``print_input`` is False, with its first token highlighted if
            that token is nonzero).

    Zeros in either pattern are treated as gaps. Newlines and tabs inside
    decoded tokens are shown as the two-character escapes ``\\n`` / ``\\t``.
    """

    def _decode(token_id):
        # Keep every sample on a single printed line.
        text = model.tokenizer.decode([int(token_id)])
        return text.replace("\n", "\\n").replace("\t", "\\t")

    def _gap_tokens(run_length):
        # Long runs collapse to one marker; short runs keep one symbol per gap.
        if run_length > 3:
            return [collapse_symbol]
        return [gap_symbol] * run_length

    def _output_tokens(pattern, highlight_first=False):
        rendered = []
        for position, tok in enumerate(pattern):
            if tok == 0:
                rendered.append(gap_symbol)
                continue
            text = _decode(tok)
            if highlight_first and position == 0:
                text = f"<<<{text}>>>"
            rendered.append(text)
        return rendered

    has_output_patterns = 'output_patterns' in output_result

    samples = zip(result['patterns'], result['outputs'], result['labels'])
    for idx, (pat, _out, label) in enumerate(samples):
        if label != cluster_id:
            continue

        out_pat = output_result['output_patterns'][idx] if has_output_patterns else None
        # Explicit None/length test: plain truthiness raises ValueError for
        # numpy arrays or tensors holding more than one element.
        show_output = print_output and out_pat is not None and len(out_pat) > 0

        tokens = []

        if print_input:
            pat_arr = np.atleast_1d(np.asarray(pat))

            # The last nonzero entry is the one that gets highlighted.
            nonzero_indices = np.nonzero(pat_arr)[0]
            max_idx = nonzero_indices[-1] if len(nonzero_indices) > 0 else -1

            gap_run = 0
            for i, val in enumerate(pat_arr):
                if val == 0:
                    gap_run += 1
                    continue

                if gap_run > 0:
                    tokens.extend(_gap_tokens(gap_run))
                    gap_run = 0

                decoded = _decode(val) or "<unk>"
                if i == max_idx:
                    tokens.append(f"<<<{decoded}>>>")
                    if show_output:
                        tokens.extend(_output_tokens(out_pat))
                else:
                    tokens.append(decoded)

            # Trailing gaps follow the same rule as interior gaps (one
            # separate token per gap) instead of one fused "___" string.
            if gap_run > 0:
                tokens.extend(_gap_tokens(gap_run))

        elif show_output:
            tokens.extend(_output_tokens(out_pat, highlight_first=True))

        print(" ".join(tokens))




def safe_token_str(s: str) -> str:
    """Render a token so that it is visible and safe inside HTML output.

    - A lone newline, tab or carriage return becomes the two-character
      escape ``\\n``, ``\\t`` or ``\\r``.
    - A token made up only of whitespace and/or non-printable characters
      (e.g. ``" "`` or a zero-width space) is returned as its ``repr``
      without the surrounding quotes, so ``"\\u200b"`` shows as the text
      ``\\u200b``.
    - Anything else is passed through ``html.escape``.
    """
    if s == "\n":
        return "\\n"
    if s == "\t":
        return "\\t"
    if s == "\r":
        return "\\r"
    # str.strip() does not remove zero-width/format characters such as
    # U+200B, so also treat tokens without any printable character as
    # invisible; otherwise they would be emitted unchanged and stay unseen.
    invisible = s.strip() == "" or not any(ch.isprintable() for ch in s)
    if invisible:
        return repr(s)[1:-1]
    return html.escape(s)

def visualize_results_with_outputs(
    model,
    result,
    cluster_id=None,       # optional: filter by cluster
    cmap_name="Greens",
    window=10
):
    """
    Render per-sample token activations as colored HTML and display it.

    For every sample (optionally only those whose label equals
    ``cluster_id``) a window of ``window`` tokens on each side of the most
    activated token is shown. Each token is shaded by its activation,
    min-max normalized within the sample, and the most activated token gets
    a red border. If present, the sample's greedy continuation is appended
    in yellow.

    Args:
        model: object exposing ``model.tokenizer.decode(list_of_ids)``.
        result: dict with keys
            - 'inputs': list of (token_sequence, max_token_pos); tokens
              after ``max_token_pos`` are dropped.
            - 'activations': per-sample activations as a numpy array, a
              list, or an object supporting ``.detach().cpu().numpy()``.
              NOTE(review): the code indexes ``acts[j]`` as a scalar, so a
              1-D per-token array appears to be required — confirm what
              callers pass.
            - 'outputs' (optional): list of dicts; 'greedy_continuation'
              is read from each.
            - 'labels' (optional): cluster labels; -1 gets the outlier color.
        cluster_id: when not None and 'labels' is present, only samples with
            this label are rendered.
        cmap_name: name of the matplotlib colormap used for shading.
        window: number of tokens shown on each side of the peak token.

    Returns:
        None. The HTML is shown through ``IPython.display.display``.
    """
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; use the colormap
    # registry when available and fall back to the legacy call otherwise.
    registry = getattr(matplotlib, "colormaps", None)
    cmap = registry[cmap_name] if registry is not None else matplotlib.cm.get_cmap(cmap_name)

    cluster_colors = ["#ff9999", "#99ff99", "#9999ff", "#ffcc99", "#cc99ff"]
    outlier_color = "#f0f0f0"

    # Escaped so the placeholder is displayed literally rather than being
    # parsed by the browser as an unknown (invisible) tag.
    unknown = html.escape("<unk>")

    def _decode(token_id):
        # Token text comes from arbitrary data and must be HTML-escaped.
        text = model.tokenizer.decode([int(token_id)]).strip()
        return html.escape(text) if text else unknown

    has_labels = 'labels' in result

    # Collect fragments and join once instead of repeated string +=.
    parts = ["<div style='font-family: monospace; white-space: pre-wrap; line-height:1.4em;'>"]

    for i, sample in enumerate(result['inputs']):
        if cluster_id is not None and has_labels and result['labels'][i] != cluster_id:
            continue

        input_tokens, max_token_pos = sample
        input_tokens = input_tokens[:max_token_pos + 1]
        decoded_tokens = [_decode(t) for t in input_tokens]

        acts = result['activations'][i]
        if not isinstance(acts, (np.ndarray, list)):
            acts = acts.detach().cpu().numpy()
        acts = np.asarray(acts)
        if acts.size == 0:
            # Nothing to normalize or highlight for this sample.
            continue

        # Per-sample min/max used for normalization
        act_min, act_max = acts.min(), acts.max()

        # Most activated token and the window around it; the upper bound is
        # also limited by the number of activations to avoid indexing past them.
        max_i = int(np.argmax(acts))
        max_val = float(acts[max_i])
        start = max(0, max_i - window)
        end = min(len(input_tokens), len(acts), max_i + window + 1)

        if has_labels:
            cluster_idx = result['labels'][i]
            if cluster_idx == -1:
                header_color = outlier_color
            else:
                header_color = cluster_colors[cluster_idx % len(cluster_colors)]
        else:
            header_color = "#66ccff"

        parts.append(
            f"<div><span style='font-weight:bold; color:{header_color};'>S{i} [max:{max_val:.2f}]</span> "
        )

        for j in range(start, end):
            norm_act = (acts[j] - act_min) / (act_max - act_min + 1e-8)
            # Use only the lower 70% of the colormap so black text stays readable.
            norm_act *= 0.7
            hex_color = matplotlib.colors.rgb2hex(cmap(norm_act))
            border = "2px solid red" if j == max_i else "1px solid transparent"

            parts.append(
                f"<span style='background:{hex_color}; color:black; font-weight:normal; "
                f"padding:1px 3px; margin:0 1px; border-radius:3px; border:{border};'>"
                f"{decoded_tokens[j]}"
                f"</span>"
            )

        # Generated tokens appended after the input window
        if 'outputs' in result:
            gen_tokens = result['outputs'][i].get('greedy_continuation', [])
            # Explicit length test: truthiness is ambiguous for arrays/tensors.
            if gen_tokens is not None and len(gen_tokens) > 0:
                parts.append(" ")
                for t in gen_tokens:
                    parts.append(
                        f"<span style='background:#ffff99;color:black; padding:1px 3px; margin:0 1px; border-radius:3px; border:1px solid orange;font-weight:normal;'>"
                        f"{_decode(t)}"
                        f"</span>"
                    )

        parts.append("</div>")

    parts.append("</div>")
    display(HTML("".join(parts)))