# full_grid_probe.py
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from transformer_lens import HookedTransformer
# This is the standard Hugging Face library for loading models
from transformers import AutoModelForCausalLM
from intervention.tictactoe_task import TicTacToeTask
from feature_discovery import run_feature_discovery, sanitize_model_name
import os
import random
from typing import Optional
import seaborn as sns
from sklearn.cluster import KMeans, DBSCAN
from scipy.spatial import ConvexHull
from adjustText import adjust_text
from collections import Counter
from collections import defaultdict
from scipy import stats
from typing import Dict, Any, List, Tuple
import json
from functools import lru_cache
import math
import argparse

"""
# 1. Rebuild / reopen interactive tool later
python generate_interactive_merge_tool.py \
  --embeddings tsne_cache/your_model_layer8_text_instruction.npz \
  --subcluster-meta visualizations/YourModel/layer_8/hybrid_merge_subcluster_metadata_text_instruction_layer8.json \
  --output-dir manual_merge_html \
  --layer 8 \
  --prompt-style text_instruction

# Open the generated HTML, perform merges, download merged_subclusters.json

# 2. Apply merges
python apply_manual_merges.py \
  --embeddings tsne_cache/Qwen-Finetuned-GRPO__Best_Move_Checkpoint_Trained_from_Base__layer_12_text_instruction.npz \
  --subcluster-meta manual_merge_html/hybrid_merge_subcluster_metadata_text_instruction_layer12.json \
  --merges-json merged_subclusters.json \
  --output-dir merged_outputs/layer12_text \
  --layer 12 \
  --prompt-style text_instruction \
  --save-csv
"""

# Invariance-mode switches. Per the original note these are overwritten in
# main() from CLI flags (main is not part of this section).
INVARIANCE_ACTIVE = False
INVARIANCE_TOKEN_PAIR = None  # Tuple like ('P','Q') or ('A','B')

# Output roots for visualizations.
# NOTE(review): the *default* root's name also contains "invariance_appendix";
# confirm that is intended.
VIZ_ROOT_DEFAULT = "visualizations_iclr__invariance_appendix"
VIZ_ROOT_INVARIANCE = "visualizations_invariance_iclr_appendix"

def get_viz_root():
    """Return the visualization output root for the current invariance mode."""
    if INVARIANCE_ACTIVE:
        return VIZ_ROOT_INVARIANCE
    return VIZ_ROOT_DEFAULT

# --- Configuration ---
# Dataset JSON files on the shared filesystem (how they are consumed is not
# visible in this section of the file).
DATASET_FULL_PATH = "/mnt/shared/data/stlm-logic/datasets/tictactoe_dataset.json"
ILLEGAL_DATASET_FULL_PATH = "/mnt/shared/data/stlm-logic/datasets/illegal_boards_x_DOUBLE_WIN+COUNT_DIFF_GT1_any_5000_20250906_152524.json"
# Models whose activations are probed. Each entry is a dict with:
#   "display_name"      - human-readable label (presumably used in titles/output paths; confirm at call sites)
#   "load_path"         - Hugging Face hub id or a local fine-tuned checkpoint directory
#   "architecture_name" - id of the base architecture (presumably used to build the
#                         HookedTransformer; confirm at call sites)
#   "full_dataset_inference_path_best_move" (optional) - inference-results JSON,
#                         read via get_inference_path() / load_inference_results()
# Commented-out entries are alternative checkpoints kept for quick toggling.
MODELS_TO_PROBE = [
    {
        "display_name": "Qwen/Qwen2.5-1.5B-Instruct",
        "load_path": "Qwen/Qwen2.5-1.5B-Instruct",
        "architecture_name": "Qwen/Qwen2.5-1.5B-Instruct",
        "full_dataset_inference_path_best_move": "/mnt/shared/stlm-logic/results/full_dataset/best_move/random_xy_moves_False/Qwen-Qwen2.5-1.5B-Instruct_results.json"
    },
    # {
    #     "display_name": "Qwen-Finetuned-GRPO (Best Move Checkpoint Trained from Base) Ckpt 300",
    #     "load_path": "/mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move/checkpoint-300",
    #     "architecture_name": "Qwen/Qwen2.5-1.5B-Instruct",
    #     # Optional param which we can provide if we want to visualize the board states labeled by board states where the best move was correctly computed
    #     # "full_dataset_inference_path_best_move": "/mnt/shared/stlm-logic/results/full_dataset/best_move/random_xy_moves_False/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move-checkpoint-1800_results.json"
    # },
    {
        "display_name": "Qwen-Finetuned-GRPO (Best Move Checkpoint Trained from Base) Ckpt 600",
        "load_path": "/mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move/checkpoint-600",
        "architecture_name": "Qwen/Qwen2.5-1.5B-Instruct",
        # Optional param which we can provide if we want to visualize the board states labeled by board states where the best move was correctly computed
        # "full_dataset_inference_path_best_move": "/mnt/shared/stlm-logic/results/full_dataset/best_move/random_xy_moves_False/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move-checkpoint-1800_results.json"
    },
    # {
    #     "display_name": "Qwen-Finetuned-GRPO (Best Move Checkpoint Trained from Base) Ckpt 900",
    #     "load_path": "/mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move/checkpoint-900",
    #     "architecture_name": "Qwen/Qwen2.5-1.5B-Instruct",
    #     # Optional param which we can provide if we want to visualize the board states labeled by board states where the best move was correctly computed
    #     # "full_dataset_inference_path_best_move": "/mnt/shared/stlm-logic/results/full_dataset/best_move/random_xy_moves_False/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move-checkpoint-1800_results.json"
    # },
    {
        "display_name": "Qwen-Finetuned-GRPO (Best Move Checkpoint Trained from Base) Ckpt 1200",
        "load_path": "/mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move/checkpoint-1200",
        "architecture_name": "Qwen/Qwen2.5-1.5B-Instruct",
        # Optional param which we can provide if we want to visualize the board states labeled by board states where the best move was correctly computed
        # "full_dataset_inference_path_best_move": "/mnt/shared/stlm-logic/results/full_dataset/best_move/random_xy_moves_False/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move-checkpoint-1800_results.json"
    },
    # {
    #     "display_name": "Qwen-Finetuned-GRPO (Best Move Checkpoint Trained from Base) Ckpt 1500",
    #     "load_path": "/mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move/checkpoint-1500",
    #     "architecture_name": "Qwen/Qwen2.5-1.5B-Instruct",
    #     # Optional param which we can provide if we want to visualize the board states labeled by board states where the best move was correctly computed
    #     # "full_dataset_inference_path_best_move": "/mnt/shared/stlm-logic/results/full_dataset/best_move/random_xy_moves_False/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move-checkpoint-1800_results.json"
    # },
    # {
    #     "display_name": "Qwen-Finetuned-GRPO (Best Move Checkpoint Trained from Base) Ckpt 2250, Overtrained",
    #     "load_path": "/mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move/checkpoint-2250",
    #     "architecture_name": "Qwen/Qwen2.5-1.5B-Instruct",
    #     # Optional param which we can provide if we want to visualize the board states labeled by board states where the best move was correctly computed
    #     # "full_dataset_inference_path_best_move": "/mnt/shared/stlm-logic/results/full_dataset/best_move/random_xy_moves_False/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move-checkpoint-1800_results.json"
    # },
    {
        "display_name": "Qwen-Finetuned-GRPO (Best Move Checkpoint Trained from Base)",
        "load_path": "/mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move/checkpoint-1800",
        "architecture_name": "Qwen/Qwen2.5-1.5B-Instruct",
        # Optional param which we can provide if we want to visualize the board states labeled by board states where the best move was correctly computed
        "full_dataset_inference_path_best_move": "/mnt/shared/stlm-logic/results/full_dataset/best_move/random_xy_moves_False/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping_best_move-checkpoint-1800_results.json"
    },
    # {
    #     "display_name": "Qwen-Finetuned-GRPO (Best Move Checkpoint Trained from Legal Move)",
    #     "load_path": "/mnt/shared/data/stlm-logic/grpo_v2_expt_lr_1e-6_best_move_from_legal/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping/checkpoint-1800",
    #     "architecture_name": "Qwen/Qwen2.5-1.5B-Instruct",
    #     "full_dataset_inference_path_best_move": "/mnt/shared/stlm-logic/results/full_dataset/best_move/random_xy_moves_False/Qwen_Qwen2.5-1.5B-Instruct_nl_canconical-symmetry-grouping-checkpoint-1800_results.json"
    # }
    
]
# Transformer layer indices whose activations are probed. Earlier selections are
# kept commented out for reference.
# LAYERS_TO_PROBE = [12, 16, 20, 24]
# LAYERS_TO_PROBE = [4, 8, 13, 14, 15, 17, 18]
LAYERS_TO_PROBE = [4, 8, 12, 13, 14, 15, 16, 17, 18, 20, 24]
# LAYERS_TO_PROBE = [12]
# Use the GPU when one is available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Batch size for the visualization pass — presumably the number of prompts per
# forward pass when collecting activations; TODO confirm at the call site.
VISUALIZATION_BATCH_SIZE = 16

# Helper that extracts the optional inference-results path from a model config entry.
def get_inference_path(model_info: Dict[str, Any]) -> Optional[str]:
    """Return the entry's "full_dataset_inference_path_best_move" value, or None when absent."""
    path_key = "full_dataset_inference_path_best_move"
    if path_key in model_info:
        return model_info[path_key]
    return None

def load_inference_results(path: str) -> Dict[str, Any]:
    """Load an inference-results JSON file and return its per-board outputs.

    Args:
        path: Path to the results JSON. A falsy value or a non-existent path is
            reported with a warning and treated as "no results".

    Returns:
        The value stored under the file's top-level ``"individual_outputs"`` key
        (``{}`` if that key is absent). Returns ``{}`` as well when the file is
        missing, unreadable, not valid UTF-8 JSON, or its top-level JSON value is
        not an object.
    """
    if not path or not os.path.exists(path):
        print(f"Warning: Inference results file not found at {path}")
        return {}
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
        print(f"Error loading or parsing inference results from {path}: {e}")
        return {}

    # A top-level list/scalar has no "individual_outputs" key; calling .get() on
    # it used to raise an uncaught AttributeError. Treat it as a malformed file.
    if not isinstance(data, dict):
        print(
            f"Error loading or parsing inference results from {path}: "
            f"expected a JSON object at top level, got {type(data).__name__}"
        )
        return {}

    # The actual results are nested under the "individual_outputs" key
    return data.get("individual_outputs", {})

def get_model_prediction_correctness(inference_data: Dict[str, Any], board_state: List[int]) -> bool:
    """
    Look up whether the model's recorded prediction for ``board_state`` was correct.

    Boards are keyed by ``str(board_state)`` in ``inference_data`` (the mapping
    returned by ``load_inference_results``). The ``"is_correct"`` field of the
    first recorded entry is returned (False if that field is absent). An empty
    mapping, an unknown board (which is also printed), or a non-list entry all
    yield False.
    """
    if not inference_data:
        return False

    lookup_key = str(board_state)
    entries = inference_data.get(lookup_key)
    if not entries:
        print(f"Board state {lookup_key} not found in inference data.")
        return False

    # Entries are expected to be a list of per-run records; only the first is consulted.
    # (A truthy list is necessarily non-empty, so indexing [0] is safe here.)
    if not isinstance(entries, list):
        return False
    first_record = entries[0]
    return first_record.get("is_correct", False)

# --- Minimax and Game Logic Helpers (for Strategic Analysis) ---
# These functions are adapted from your dataset generation script to analyze game states on the fly.

def _minimax_check_winner(board: List[int]) -> tuple[int, bool]:
    """
    Checks for a winner on the board.
    Returns: (winner (0=draw/none, 1=P1, 2=P2), is_terminal)
    """
    lines = [(0,1,2), (3,4,5), (6,7,8), (0,3,6), (1,4,7), (2,5,8), (0,4,8), (2,4,6)]
    for a, b, c in lines:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a], True
    if 0 not in board:
        return 0, True
    return 0, False

def _minimax_get_score(board: List[int], player: int, depth: int = 0) -> tuple[float, Optional[int]]:
    """
    Minimax that returns the game-theoretic score for a position.
    Prefers faster wins and slower losses.
    Returns: (score, best_move_index)
    """
    winner, is_terminal = _minimax_check_winner(board)
    if is_terminal:
        if winner == 1: return 10 - depth, None
        if winner == 2: return -10 + depth, None
        return 0, None

    empty_cells = [i for i, cell in enumerate(board) if cell == 0]
    
    if player == 1:  # Maximizing
        max_eval = -math.inf
        best_move = None
        for move in empty_cells:
            new_board = board[:]
            new_board[move] = 1
            evaluation, _ = _minimax_get_score(new_board, 2, depth + 1)
            if evaluation > max_eval:
                max_eval = evaluation
                best_move = move
        return max_eval, best_move
    else:  # Minimizing
        min_eval = math.inf
        best_move = None
        for move in empty_cells:
            new_board = board[:]
            new_board[move] = 2
            evaluation, _ = _minimax_get_score(new_board, 1, depth + 1)
            if evaluation < min_eval:
                min_eval = evaluation
                best_move = move
        return min_eval, best_move
    

class TicTacToeClusterHypothesisTester:
    """
    Analyzes clustered Tic-Tac-Toe boards to test various hypotheses.

    The input is a JSON object mapping cluster id -> list of board strings,
    given either as a path ending in ``.json`` or as the JSON text itself.
    Boards are parsed into 3x3 int arrays (X = 1, O = -1, anything else = 0)
    and summarised per cluster in ``self.analysis_results``.
    """
    def __init__(self, file_path_or_json_string: str):
        # Anything ending in ".json" is treated as a file path, otherwise as JSON text.
        if file_path_or_json_string.endswith('.json'):
            self.file_path = file_path_or_json_string
            self.clusters = self._load_clusters_from_file()
        else:
            self.file_path = "in-memory-data"
            self.clusters = self._load_clusters_from_string(file_path_or_json_string)

        # Winning lines as indices into a row-major flattened board: rows, columns, diagonals.
        self.LINES = [[0,1,2],[3,4,5],[6,7,8],[0,3,6],[1,4,7],[2,5,8],[0,4,8],[2,4,6]]
        self.LINE_NAMES = {0:"R0", 1:"R1", 2:"R2", 3:"C0", 4:"C1", 5:"C2", 6:"D1", 7:"D2"}
        # Cell value -> display character (not referenced by any method of this class).
        self.PIECE_MAP = {1: 'X', -1: 'O', 0: '.'}

        if self.clusters:
            self.parsed_clusters = self._parse_all_clusters()
            self.analysis_results = self._analyze_clusters()
        else:
            self.parsed_clusters, self.analysis_results = {}, {}

    def _load_clusters_from_file(self) -> Dict[str, Any]:
        """Read ``self.file_path`` as JSON; on any read/parse failure print it and return {}."""
        try:
            with open(self.file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        # OSError covers FileNotFoundError as well as permission / is-a-directory
        # errors, which previously escaped and crashed the constructor.
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
            print(f"Error loading cluster file {self.file_path}: {e}")
            return {}

    def _load_clusters_from_string(self, json_string: str) -> Dict[str, Any]:
        """Parse ``json_string``; on malformed JSON print the error and return {}."""
        try:
            return json.loads(json_string)
        except json.JSONDecodeError as e:
            print(f"Error loading cluster data from string: {e}")
            return {}

    # Deprecated: kept for backward compatibility; identical to _load_clusters_from_file.
    def _load_clusters(self) -> Dict[str, Any]:
        return self._load_clusters_from_file()

    def _parse_board(self, board_str: str) -> Optional[np.ndarray]:
        """
        Parse one board string into a 3x3 int array (X = 1, O = -1, other = 0).

        Two layouts are accepted:
          * "Row <i>: a, b, c" segments ('.' characters are stripped first), or
          * one text line per row, cells separated by spaces.
        Returns None when the string cannot be parsed (bad row index, too many cells, ...).
        """
        board = np.zeros((3, 3), dtype=int)
        try:
            if "Row" in board_str:
                parts = board_str.replace('.', '').split('Row ')
                for part in parts:
                    if not part.strip():
                        continue
                    row_index = int(part.split(': ')[0])
                    cells = [cell.strip() for cell in part.split(': ')[1].split(',')]
                    for col_index, cell in enumerate(cells):
                        if cell == 'X':
                            board[row_index, col_index] = 1
                        elif cell == 'O':
                            board[row_index, col_index] = -1
            else:
                for r, line in enumerate(board_str.strip().split('\n')):
                    cells = [cell for cell in line.split(' ') if cell]
                    for c, piece in enumerate(cells):
                        if piece == 'X':
                            board[r, c] = 1
                        elif piece == 'O':
                            board[r, c] = -1
            return board
        except (ValueError, IndexError):
            return None

    def _parse_all_clusters(self) -> Dict[str, list]:
        """Parse every board of every cluster, silently dropping unparseable boards."""
        parsed = defaultdict(list)
        for cluster_id, boards in self.clusters.items():
            for b_str in boards:
                parsed_board = self._parse_board(b_str)
                if parsed_board is not None:
                    parsed[cluster_id].append(parsed_board)
        return parsed

    # --- Feature Extraction Methods ---
    def get_piece_count(self, board):
        """Number of occupied cells."""
        return np.count_nonzero(board)

    def get_open_squares(self, board):
        """Number of empty cells."""
        return 9 - self.get_piece_count(board)

    def get_center_piece(self, board):
        """Value of the center cell (1, -1 or 0)."""
        return board[1, 1]

    def check_win(self, board):
        """Return the winner's value (1 or -1) on the first completed line, else 0."""
        board_flat = board.flatten()
        for line in self.LINES:
            if abs(sum(board_flat[line])) == 3:
                return board_flat[line[0]]
        return 0

    def get_canonical_form(self, board: np.ndarray) -> tuple:
        """
        Calculates the canonical representation of a board by applying all 8 symmetries
        (4 rotations, 4 rotations + reflection) and choosing the lexicographically smallest.
        This allows us to treat strategically identical boards as the same.
        """
        symmetries = []
        current_board = board.copy()
        for _ in range(4):  # 4 rotations
            symmetries.append(tuple(current_board.flatten()))
            symmetries.append(tuple(np.fliplr(current_board).flatten()))  # Flipped
            current_board = np.rot90(current_board)
        return min(symmetries)

    def get_line_threat_count(self, board: np.ndarray, player: int) -> int:
        """Count lines holding two of ``player``'s pieces and one empty cell."""
        threats = 0
        board_flat = board.flatten()
        for line in self.LINES:
            pieces = board_flat[line]
            if np.sum(pieces) == 2 * player and 0 in pieces:
                threats += 1
        return threats

    def has_fork(self, board: np.ndarray, player: int) -> bool:
        """True if a single move by ``player`` creates at least two threats."""
        open_squares = np.where(board == 0)
        for r, c in zip(*open_squares):
            temp_board = board.copy()
            temp_board[r, c] = player
            if self.get_line_threat_count(temp_board, player) >= 2:
                return True
        return False

    def get_player_turn(self, board: np.ndarray) -> int:
        """Side to move: 1 (X) when piece counts are equal, otherwise -1 (O)."""
        return 1 if np.count_nonzero(board == 1) == np.count_nonzero(board == -1) else -1

    def get_corner_edge_counts(self, board: np.ndarray, player: int):
        """Return (#corners, #edges) occupied by ``player``."""
        corners = [board[0,0], board[0,2], board[2,0], board[2,2]]
        edges = [board[0,1], board[1,0], board[1,2], board[2,1]]
        return corners.count(player), edges.count(player)

    def _analyze_clusters(self) -> Dict[str, Any]:
        """
        Compute per-cluster summary metrics, then min-max normalize each numeric
        metric across clusters.

        For every numeric metric ``m`` two extra keys are added per cluster:
        ``m_raw`` (a copy of the value) and ``m_normalized`` (scaled to [0, 1],
        or 0.5 when all clusters share the same value).
        """
        results = {}
        for cid, boards in self.parsed_clusters.items():
            if not boards:
                continue
            features = []
            for b in boards:
                turn = self.get_player_turn(b)
                x_corners, x_edges = self.get_corner_edge_counts(b, 1)
                features.append({
                    'piece_count': self.get_piece_count(b), 'center_piece': self.get_center_piece(b),
                    'open_squares': self.get_open_squares(b), 'canonical_form': self.get_canonical_form(b),
                    'winner': self.check_win(b), 'x_threats': self.get_line_threat_count(b, 1),
                    'fork_opportunity': self.has_fork(b, turn), 'x_corners': x_corners, 'x_edges': x_edges,
                })

            num_boards = len(boards)
            board_flats = [b.flatten() for b in boards]

            # Mean pairwise Hamming distance. Each row i is compared against all
            # later rows in one vectorized step; the pair order (i < j) matches a
            # nested i/j loop, so the mean is identical to the pairwise version.
            if num_boards > 1:
                stacked = np.stack(board_flats)
                pair_distances = np.concatenate([
                    (stacked[i + 1:] != stacked[i]).sum(axis=1) for i in range(num_boards - 1)
                ])
                avg_hamming = np.mean(pair_distances)
            else:
                avg_hamming = 0

            # Most common (line index, line contents) patterns across the cluster.
            all_line_patterns = [
                (i, tuple(bf[line_indices]))
                for bf in board_flats
                for i, line_indices in enumerate(self.LINES)
            ]
            top_10_patterns = Counter(all_line_patterns).most_common(10) if all_line_patterns else []
            # Each board contributes a given (line, pattern) at most once, so this is in [0, 1].
            max_purity = top_10_patterns[0][1] / num_boards if top_10_patterns else 0

            results[cid] = {
                'size': num_boards,
                'avg_piece_count': np.mean([f['piece_count'] for f in features]),
                'center_x_prop': [f['center_piece'] for f in features].count(1) / num_boards,
                'x_win_prop': [f['winner'] for f in features].count(1) / num_boards,
                'o_win_prop': [f['winner'] for f in features].count(-1) / num_boards,
                'canonical_ratio': len(set(f['canonical_form'] for f in features)) / num_boards,
                'avg_x_threats': np.mean([f['x_threats'] for f in features]),
                'fork_prop': [f['fork_opportunity'] for f in features].count(True) / num_boards,
                'avg_x_corners': np.mean([f['x_corners'] for f in features]),
                'avg_x_edges': np.mean([f['x_edges'] for f in features]),
                'avg_hamming_dist': avg_hamming,
                'dominant_line_purity': max_purity,
                'top_10_line_patterns': top_10_patterns,
            }

        # Min-max normalize every numeric metric across clusters (metric names
        # are taken from the first cluster; all clusters share the same keys).
        numeric_metrics = [k for k, v in next(iter(results.values()), {}).items() if isinstance(v, (int, float))]
        global_stats = {metric: [r[metric] for r in results.values() if metric in r and r[metric] is not None] for metric in numeric_metrics}

        for metric, values in global_stats.items():
            if not values:
                continue
            min_val, max_val = min(values), max(values)
            value_range = max_val - min_val
            for cid in results.keys():
                if metric in results[cid] and results[cid][metric] is not None:
                    # Store raw value for explicit labeling
                    results[cid][f'{metric}_raw'] = results[cid][metric]
                    # Constant metrics get the neutral midpoint instead of dividing by ~0.
                    if value_range > 1e-6:
                        results[cid][f'{metric}_normalized'] = (results[cid][metric] - min_val) / value_range
                    else:
                        results[cid][f'{metric}_normalized'] = 0.5
        return results


    def run_statistical_tests(self):
        """
        Performs and prints a comprehensive suite of statistical tests comparing all clusters.

        ANOVA is used for count-valued features and chi-squared contingency tests
        for categorical ones; failures are printed rather than raised.

        Returns:
            ``self.analysis_results`` (unchanged).
        """
        if not self.analysis_results or len(self.analysis_results) < 2:
            print("Not enough valid clusters to run statistical tests.")
            return self.analysis_results

        print(f"\n--- Statistical Hypothesis Tests for {os.path.basename(self.file_path)} ---")
        cluster_ids = list(self.analysis_results.keys())

        def run_anova(feature_func, hypothesis_name):
            data = [[feature_func(board) for board in self.parsed_clusters[cid]] for cid in cluster_ids]
            if not all(d for d in data):
                return
            try:
                f_stat, p = stats.f_oneway(*data)
                if np.isnan(p):
                    print(f"{hypothesis_name}: ANOVA resulted in NaN (likely due to no variance).")
                else:
                    print(f"{hypothesis_name}: ANOVA F={f_stat:.2f}, p={p:.4f} -> {'Significant' if p < 0.05 else 'Not Significant'}")
            except ValueError as e:
                print(f"{hypothesis_name}: ANOVA could not be computed. Reason: {e}")

        def run_chi2(feature_func, hypothesis_name):
            # One Counter per cluster (previously rebuilt once per category value).
            table = []
            for cid in cluster_ids:
                counts = Counter(feature_func(b) for b in self.parsed_clusters[cid])
                table.append([counts[val] for val in (1, -1, 0)])
            if len(table) > 1 and np.sum(table) > 0:
                try:
                    chi2, p, _, _ = stats.chi2_contingency(table)
                    print(f"{hypothesis_name}: Chi-squared={chi2:.2f}, p={p:.4f} -> {'Dependent' if p < 0.05 else 'Independent'}")
                except ValueError as e:
                    print(f"{hypothesis_name}: Chi-squared test failed. Reason: {e}")

        run_anova(self.get_piece_count, 'H1 (Game Stage / Piece Count)')
        run_anova(self.get_open_squares, 'H10 (Game Openness / Open Squares)')
        run_chi2(self.get_center_piece, 'H2 (Center Control)')
        run_chi2(self.check_win, 'H4 (Game Outcome)')
        run_anova(lambda b: self.get_corner_edge_counts(b, 1)[0], 'H8a (X Corner Control)')
        run_anova(lambda b: self.get_corner_edge_counts(b, 1)[1], 'H8b (X Edge Control)')
        run_anova(lambda b: self.get_line_threat_count(b, 1), 'H11 (X Threat Count)')

        # H9: [boards with a fork for the side to move, boards without] per cluster.
        # has_fork is evaluated once per board (previously twice).
        fork_table = []
        for cid in cluster_ids:
            forks = sum(1 for b in self.parsed_clusters[cid] if self.has_fork(b, self.get_player_turn(b)))
            fork_table.append([forks, self.analysis_results[cid]['size'] - forks])
        if len(fork_table) > 1 and np.sum(fork_table) > 0:
            try:
                chi2, p, _, _ = stats.chi2_contingency(fork_table)
                print(f"H9 (Fork Opportunity): Chi-squared={chi2:.2f}, p={p:.4f} -> {'Dependent' if p < 0.05 else 'Independent'}")
            except ValueError as e:
                print(f"H9 (Fork Opportunity): Chi-squared test failed. Reason: {e}")

        print("\n--- Descriptive Statistics (per cluster) ---")
        for cid, r in self.analysis_results.items():
            print(f"Cluster {cid}: Canonical Ratio={r['canonical_ratio']:.2f}, Avg Hamming Dist={r['avg_hamming_dist']:.2f}, Line Purity={r['dominant_line_purity']:.2f}")

        print("\n--- End of Statistical Tests ---\n")
        return self.analysis_results

class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays into built-in Python types.

    ``json`` only calls :meth:`default` for objects it cannot serialize natively,
    so ordinary Python values are unaffected.
    """
    def default(self, obj):
        # np.bool_ is neither an np.integer nor a Python bool, so without this
        # branch json raised TypeError for NumPy booleans.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Anything else: defer to the base class, which raises TypeError.
        return super().default(obj)


def save_analysis_results(analysis_results, output_dir, layer, prompt_style, algorithm_name):
    """
    Write per-cluster analysis results (raw and normalized metrics) to a JSON file.

    The file is ``<output_dir>/<algorithm_name>_analysis_results_<prompt_style>.json``;
    ``layer`` is accepted but not used in the filename. ``top_10_line_patterns``
    entries (``((line_index, pattern), count)`` pairs) are expanded into
    ``{'line_index', 'pattern', 'count'}`` dicts; every other value is written
    as-is via ``NpEncoder``. Write failures are printed, not raised.
    """
    output_path = os.path.join(output_dir, f'{algorithm_name}_analysis_results_{prompt_style}.json')

    def _expand_patterns(patterns):
        # Each entry is ((line_index, pattern_tuple), count).
        expanded = []
        for entry in patterns:
            line_key, count = entry[0], entry[1]
            expanded.append({'line_index': line_key[0], 'pattern': list(line_key[1]), 'count': int(count)})
        return expanded

    serializable_results = {
        cid: {
            key: (
                _expand_patterns(value)
                if key == 'top_10_line_patterns' and value is not None
                else value
            )
            for key, value in data.items()
        }
        for cid, data in analysis_results.items()
    }

    try:
        with open(output_path, 'w') as f:
            json.dump(serializable_results, f, indent=4, cls=NpEncoder)
        print(f"Successfully saved analysis results to {output_path}")
    except Exception as e:
        print(f"Error saving analysis results to {output_path}: {e}")

def visualize_cluster_extremes_normalized(reduced_activations, labels, analysis_results, output_dir, layer, prompt_style, algorithm_name):
    """
    Visualizes clusters and annotates the ones that are global leaders for each NORMALIZED metric.

    For every ``*_normalized`` metric, the cluster with the largest normalized value
    is labelled at its 2-D centroid with the metric's RAW value. A cluster that leads
    several metrics is labelled only for the first such metric.

    Args:
        reduced_activations: 2-D array; columns 0 and 1 are plotted.
        labels: per-point cluster labels, compared elementwise against ``int(cid)``.
        analysis_results: cid -> metrics dict containing ``*_normalized`` and ``*_raw`` keys.
        output_dir: directory the PNG is saved into (not created here).
        layer, prompt_style, algorithm_name: used in the title and filename.
    """
    print(f"Visualizing NORMALIZED global cluster extremes for {algorithm_name}...")
    if not analysis_results or len(analysis_results) < 2:
        return

    n_clusters = len(analysis_results)
    fig, ax = plt.subplots(figsize=(22, 18))
    cmap = plt.get_cmap('tab20b', n_clusters)
    ax.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=labels, cmap=cmap, alpha=0.3, s=20)

    # Centroid per cluster; each cluster mask is computed once.
    centroids = {}
    for cid in analysis_results.keys():
        members = reduced_activations[labels == int(cid)]
        if len(members) > 0:
            centroids[cid] = np.mean(members, axis=0)

    # Identify NORMALIZED metrics to plot
    metrics_to_plot = [key for key in next(iter(analysis_results.values())) if key.endswith('_normalized')]

    texts = []
    annotated_centroids = set()

    for metric_normalized in metrics_to_plot:
        # strip(): dropping the "avg" prefix leaves a leading space that otherwise
        # produced double spaces such as "Highest  Piece Count".
        metric_name_clean = (
            metric_normalized.replace('_normalized', '').replace('_', ' ')
            .replace('prop', 'Rate').replace('avg', '').title().strip()
        )

        leader_cid = max(
            analysis_results,
            key=lambda cid, m=metric_normalized: analysis_results[cid].get(m, -1),
        )
        if leader_cid not in centroids:
            continue

        leader_centroid_tuple = tuple(centroids[leader_cid])
        if leader_centroid_tuple in annotated_centroids:
            continue
        annotated_centroids.add(leader_centroid_tuple)

        # Display the RAW value for interpretability, even though we sort by normalized
        metric_raw = metric_normalized.replace('_normalized', '_raw')
        value = analysis_results[leader_cid].get(metric_raw, 'N/A')

        label_text = f"Highest {metric_name_clean}\n"
        if isinstance(value, float):
            # Rates, ratios and purities are fractions in [0, 1]: show them as
            # percentages, consistent with visualize_kmeans_hypothesis_normalized.
            is_fraction = any(tag in label_text for tag in ("Rate", "Ratio", "Purity"))
            label_text += f"({value:.1%})" if is_fraction else f"({value:.2f})"
        else:
            label_text += f"({value})"

        texts.append(ax.text(centroids[leader_cid][0], centroids[leader_cid][1], label_text, ha='center', va='center',
                             bbox=dict(boxstyle="round,pad=0.5", fc="ivory", ec="black", lw=1, alpha=0.9),
                             fontsize=11, weight='bold'))

    if texts:
        adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="->", color='black', lw=1.5))

    ax.set_title(f'{algorithm_name} Analysis: Global Metric Leaders (Normalized) (Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.grid(True)
    filename = f'viz_{algorithm_name}_global_extremes_normalized_{prompt_style}.png'
    plt.savefig(os.path.join(output_dir, filename))
    plt.close(fig)
    print(f"Normalized global extremes visualization saved to {os.path.join(output_dir, filename)}")


def visualize_kmeans_hypothesis_normalized(reduced_activations, kmeans_results, analysis_results, output_dir, layer, prompt_style, num_annotations=18):
    """
    Visualizes K-Means clusters, annotating based on the most extreme NORMALIZED metric for each cluster.

    For each cluster, the ``*_normalized`` metric furthest from 0.5 is its
    "defining feature". The ``num_annotations`` clusters with the most extreme
    defining features are labelled at their centroid with the metric's raw value.

    Args:
        reduced_activations: 2-D array; columns 0 and 1 are plotted.
        kmeans_results: fitted clustering object exposing ``n_clusters`` and
            ``labels_`` (e.g. sklearn ``KMeans``).
        analysis_results: cid -> metrics dict with ``*_normalized`` / ``*_raw`` keys.
        output_dir: directory the PNG is saved into (not created here).
        layer, prompt_style: used in the title and filename.
        num_annotations: maximum number of clusters to label (default 18;
            ``None`` labels every cluster).
    """
    print("Visualizing K-Means hypothesis summary plot (Normalized)...")
    if not analysis_results or len(analysis_results) < 2:
        return

    n_clusters = kmeans_results.n_clusters
    labels = kmeans_results.labels_

    fig, ax = plt.subplots(figsize=(22, 18))
    cmap = plt.get_cmap('tab20b', n_clusters)
    ax.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=labels, cmap=cmap, alpha=0.4, s=20)

    cluster_defining_features = {}
    for cid, results in analysis_results.items():
        max_extremeness = -1
        defining_feature = None
        for metric, value in results.items():
            if not metric.endswith('_normalized') or value is None:
                continue
            # Normalized values are expected in [0, 1]; distance from the 0.5
            # midpoint is used as a proxy for how extreme the cluster is.
            extremeness = abs(value - 0.5)
            if extremeness > max_extremeness:
                max_extremeness = extremeness
                defining_feature = (metric, value)

        if defining_feature:
            cluster_defining_features[cid] = defining_feature

    sorted_clusters = sorted(cluster_defining_features.items(), key=lambda item: abs(item[1][1] - 0.5), reverse=True)

    texts = []
    for cid, (metric_norm, norm_value) in sorted_clusters[:num_annotations]:
        direction = "Highest" if norm_value > 0.5 else "Lowest"
        # strip(): dropping the "avg" prefix leaves a leading space that otherwise
        # produced double spaces such as "Highest  Piece Count".
        metric_name = (
            metric_norm.replace('_normalized', '').replace('_', ' ')
            .replace('prop', 'Rate').replace('avg', '').title().strip()
        )

        # Get the raw value for the label
        raw_metric_name = metric_norm.replace('_normalized', '_raw')
        raw_value = analysis_results[cid].get(raw_metric_name, "N/A")

        label_text = f"{direction} {metric_name}\n"
        if isinstance(raw_value, float):
            # Rates, ratios and purities are fractions: render as percentages.
            if "Rate" in label_text or "Ratio" in label_text or "Purity" in label_text:
                label_text += f"({raw_value:.1%})"
            else:
                label_text += f"({raw_value:.2f})"
        else:
            label_text += f"({raw_value})"

        cluster_points = reduced_activations[labels == int(cid)]
        if len(cluster_points) > 0:
            centroid = np.mean(cluster_points, axis=0)
            texts.append(ax.text(centroid[0], centroid[1], label_text, ha='center', va='center',
                                 bbox=dict(boxstyle="round,pad=0.5", fc="ivory", ec="black", lw=1, alpha=0.9),
                                 fontsize=11, weight='bold'))

    if texts:
        adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="->", color='black', lw=1.5))

    ax.set_title(f'K-Means Analysis: Defining Features (Normalized) (Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.grid(True)
    filename = f'viz_kmeans_hypothesis_summary_normalized_{prompt_style}.png'
    plt.savefig(os.path.join(output_dir, filename))
    plt.close(fig)
    print(f"Normalized K-Means hypothesis visualization saved to {os.path.join(output_dir, filename)}")

def visualize_cluster_extremes(reduced_activations, labels, analysis_results, output_dir, layer, prompt_style, algorithm_name):
    """
    Visualize clusters and annotate the ones that are global leaders for each raw metric.

    For every key ending in ``_raw`` (keys are taken from the first cluster's results),
    the cluster with the largest numeric value is labelled "Highest <metric>" at its
    centroid. Clusters whose value for that metric is missing or not a real number do
    not compete. A cluster that leads several metrics is labelled only once, for the
    first metric it leads.

    Args:
        reduced_activations: 2-D embedding; columns 0 and 1 are plotted.
        labels: Per-point cluster labels, compared against ``int(cid)``.
        analysis_results: Mapping of cluster id -> dict of metric values.
        output_dir: Directory the PNG is written to.
        layer: Shown in the title.
        prompt_style: Shown in the title and used in the filename.
        algorithm_name: Shown in the title and used in the filename.
    """
    import numbers

    print(f"Visualizing global cluster extremes for {algorithm_name}...")
    if not analysis_results or len(analysis_results) < 2: return

    n_clusters = len(analysis_results)
    fig, ax = plt.subplots(figsize=(22, 18))
    cmap = plt.get_cmap('tab20b', n_clusters)
    ax.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=labels, cmap=cmap, alpha=0.3, s=20)

    # Centroid of every cluster that has at least one point in the embedding.
    centroids = {}
    for cid in analysis_results:
        members = reduced_activations[labels == int(cid)]
        if len(members) > 0:
            centroids[cid] = np.mean(members, axis=0)

    # Identify metrics to plot (the raw ones)
    metrics_to_plot = [key for key in next(iter(analysis_results.values())) if key.endswith('_raw')]

    texts = []
    annotated_centroids = set()

    for metric_raw in metrics_to_plot:
        metric_name_clean = metric_raw.replace('_raw', '').replace('_', ' ').replace('prop', 'Rate').replace('avg', '').title()

        # Only clusters that report a real number for this metric may lead. A None
        # value would make max() raise TypeError. A default of -1 could also pick a
        # cluster that lacks the metric entirely.
        candidates = {
            cid: res[metric_raw]
            for cid, res in analysis_results.items()
            if isinstance(res.get(metric_raw), numbers.Real)
        }
        if not candidates:
            continue
        leader_cid = max(candidates, key=candidates.get)

        if leader_cid not in centroids:
            continue
        leader_centroid_tuple = tuple(centroids[leader_cid])
        # Avoid re-annotating the same point if it's a leader in multiple metrics
        if leader_centroid_tuple in annotated_centroids:
            continue
        annotated_centroids.add(leader_centroid_tuple)

        value = candidates[leader_cid]
        label_text = f"Highest {metric_name_clean}\n"
        # Same percent-vs-decimal rule as the other cluster annotators in this file.
        if "Rate" in label_text or "Ratio" in label_text or "Purity" in label_text:
            label_text += f"({value:.1%})"
        else:
            label_text += f"({value:.2f})"

        texts.append(ax.text(centroids[leader_cid][0], centroids[leader_cid][1], label_text, ha='center', va='center',
                             bbox=dict(boxstyle="round,pad=0.5", fc="ivory", ec="black", lw=1, alpha=0.9),
                             fontsize=11, weight='bold'))

    if texts:
        adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="->", color='black', lw=1.5))

    ax.set_title(f'{algorithm_name} Cluster Analysis: Global Metric Leaders (Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1'); ax.set_ylabel('t-SNE Dimension 2'); ax.grid(True)
    filename = f'viz_{algorithm_name}_global_extremes_{prompt_style}.png'
    plt.savefig(os.path.join(output_dir, filename)); plt.close(fig)
    print(f"Global extremes visualization saved to {os.path.join(output_dir, filename)}")

def visualize_prediction_correctness(
    reduced_activations: np.ndarray,
    boards_list: list,
    inference_data: Dict[str, Any],
    output_dir: str,
    layer: int,
    prompt_style: str
):
    """
    Scatter the 2-D embedding with each point colored by prediction correctness.

    Correctness for each board comes from
    ``get_model_prediction_correctness(inference_data, entry['board'])``. A True
    result is drawn green and False red; any other result raises KeyError during
    the color lookup. The figure is saved as
    ``viz_hypothesis_prediction_correctness_{prompt_style}.png`` in ``output_dir``.
    Nothing is plotted when ``inference_data`` is empty.
    """
    print("Generating prediction correctness plot... ✅❌")
    if not inference_data:
        print("Skipping correctness plot: No inference data provided.")
        return

    correct_color = '#2ca02c'    # green
    incorrect_color = '#d62728'  # red
    palette = {True: correct_color, False: incorrect_color}

    # Evaluate every board first, then translate outcomes into colors.
    outcomes = [get_model_prediction_correctness(inference_data, entry['board']) for entry in boards_list]
    point_colors = [palette[outcome] for outcome in outcomes]

    plt.figure(figsize=(20, 16))
    xs = reduced_activations[:, 0]
    ys = reduced_activations[:, 1]
    plt.scatter(xs, ys, c=point_colors, alpha=0.7, s=25)

    legend_entries = []
    for name, swatch in (('Correct', correct_color), ('Incorrect', incorrect_color)):
        legend_entries.append(
            plt.Line2D([0], [0], marker='o', color='w', label=name, markersize=10, markerfacecolor=swatch)
        )
    plt.legend(handles=legend_entries, title="Prediction Correctness")
    plt.title(f't-SNE Colored by Prediction Correctness (Layer {layer}, Style: {prompt_style})')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.grid(True)

    save_path = os.path.join(output_dir, f'viz_hypothesis_prediction_correctness_{prompt_style}.png')
    plt.savefig(save_path)
    plt.close()
    print(f"Saved prediction correctness plot to {save_path}")

def visualize_kmeans_hypothesis(reduced_activations, kmeans_results, analysis_results, output_dir, layer, prompt_style):
    """
    Visualize K-Means clusters and annotate the most distinctive ones.

    For every numeric metric except ``size`` and the pattern-list keys, the mean and
    standard deviation across clusters are computed. Each cluster's "defining feature"
    is the metric with the largest absolute Z-score for that cluster. The 18 clusters
    whose defining features are most extreme are labelled "Highest"/"Lowest <metric>"
    at their centroid.

    Args:
        reduced_activations: 2-D embedding; columns 0 and 1 are plotted.
        kmeans_results: Fitted estimator exposing ``n_clusters`` and ``labels_``.
        analysis_results: Mapping of cluster id -> dict of metrics. Metric names are
            read from the first cluster; values that are not real numbers are ignored.
        output_dir: Directory the PNG is written to.
        layer: Shown in the title.
        prompt_style: Shown in the title and used in the filename.
    """
    import numbers

    print("Visualizing K-Means hypothesis summary plot...")
    if not analysis_results or len(analysis_results) < 2:
        print("No analysis results to visualize.")
        return

    n_clusters = kmeans_results.n_clusters
    labels = kmeans_results.labels_

    fig, ax = plt.subplots(figsize=(22, 18))
    cmap = plt.get_cmap('tab20b', n_clusters)
    ax.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=labels, cmap=cmap, alpha=0.4, s=20)

    # --- Find the Defining Characteristic of Each Cluster via Z-Scores ---

    # 'size' is a count rather than a strategic metric; the other two keys hold
    # nested pattern data. They are never compared across clusters.
    excluded_metrics = {'size', 'dominant_line_info', 'top_10_line_patterns'}

    # 1. Global mean / std per metric. Only real-number values take part, so any
    # other non-numeric field (string, list, None) is ignored instead of making
    # np.mean raise.
    metrics = list(next(iter(analysis_results.values())).keys())
    global_stats = {}
    for metric in metrics:
        if metric in excluded_metrics:
            continue
        values = [res[metric] for res in analysis_results.values()
                  if isinstance(res.get(metric), numbers.Real)]
        if len(values) > 1:
            global_stats[metric] = {'mean': np.mean(values), 'std': np.std(values)}

    # 2. For each cluster, find its most extreme metric (highest absolute Z-score)
    cluster_defining_features = {}
    for cid, results in analysis_results.items():
        max_z = -1
        defining_feature = None
        for metric, value in results.items():
            # Non-numeric values cannot be turned into a Z-score.
            if not isinstance(value, numbers.Real): continue
            if metric in global_stats and global_stats[metric]['std'] > 1e-6:
                mean = global_stats[metric]['mean']
                std = global_stats[metric]['std']
                z_score = (value - mean) / std
                if abs(z_score) > max_z:
                    max_z = abs(z_score)
                    defining_feature = (metric, value, z_score)
        if defining_feature:
            cluster_defining_features[cid] = defining_feature

    # 3. Select the top N most unique clusters to annotate
    num_annotations = 18
    sorted_clusters = sorted(cluster_defining_features.items(), key=lambda item: abs(item[1][2]), reverse=True)

    # --- Create human-readable labels and annotate the plot ---
    texts = []
    for cid, (metric, value, z_score) in sorted_clusters[:num_annotations]:
        direction = "Highest" if z_score > 0 else "Lowest"
        metric_name = metric.replace('_', ' ').replace('prop', 'Rate').replace('avg', '').title()
        label_text = f"{direction} {metric_name}\n"
        if "Rate" in label_text or "Ratio" in label_text or "Purity" in label_text:
            label_text += f"({value:.1%})"
        else:
            label_text += f"({value:.2f})"

        cluster_points = reduced_activations[labels == int(cid)]
        if len(cluster_points) > 0:
            centroid = np.mean(cluster_points, axis=0)
            texts.append(ax.text(centroid[0], centroid[1], label_text, ha='center', va='center',
                                 bbox=dict(boxstyle="round,pad=0.5", fc="ivory", ec="black", lw=1, alpha=0.9),
                                 fontsize=11, weight='bold'))

    if texts:
        adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="->", color='black', lw=1.5))

    ax.set_title(f'K-Means Cluster Analysis: Defining Features (Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1'); ax.set_ylabel('t-SNE Dimension 2'); ax.grid(True)
    filename = f'viz_kmeans_hypothesis_summary_{prompt_style}.png'
    plt.savefig(os.path.join(output_dir, filename)); plt.close(fig)
    print(f"K-Means hypothesis visualization saved to {os.path.join(output_dir, filename)}")

def visualize_line_purity_per_cluster(analysis_results, output_dir, layer, prompt_style, algorithm_name):
    """
    Save one horizontal bar chart per cluster showing its most frequent line patterns.

    For each cluster, the entries of ``top_10_line_patterns`` are plotted. Each entry is
    unpacked as ``((line_index, pattern_tuple), count)`` and drawn as a percentage of
    the cluster's ``size``. Clusters with no patterns, or with a missing or zero
    ``size``, are skipped. Charts are written to
    ``<output_dir>/<algorithm_name>_line_purity_plots_<prompt_style>/``, which is
    created if needed.
    """
    print(f"Visualizing top 10 line patterns for each {algorithm_name} cluster...")
    purity_plot_dir = os.path.join(output_dir, f'{algorithm_name}_line_purity_plots_{prompt_style}')
    os.makedirs(purity_plot_dir, exist_ok=True)

    LINE_NAMES = {0:"R0", 1:"R1", 2:"R2", 3:"C0", 4:"C1", 5:"C2", 6:"D1", 7:"D2"}
    PIECE_MAP = {1: 'X', -1: 'O', 0: '.'}

    for cid, results in analysis_results.items():
        top_patterns_data = results.get('top_10_line_patterns', [])
        num_boards = results.get('size')
        # Percentages divide by the cluster size, so a missing or zero size cannot be
        # plotted. This is the same guard visualize_dominant_line_patterns applies.
        if not top_patterns_data or not num_boards:
            continue

        labels, percentages = [], []
        for (line_idx, pattern_tuple), count in top_patterns_data:
            line_name = LINE_NAMES.get(line_idx, f"L{line_idx}")
            pattern_str = "".join(PIECE_MAP.get(p, '?') for p in pattern_tuple)
            labels.append(f"{line_name}: ({pattern_str})")
            percentages.append((count / num_boards) * 100)

        fig, ax = plt.subplots(figsize=(12, 10))  # tall enough for 10 bars
        bars = ax.barh(labels, percentages, color='c')
        ax.invert_yaxis()  # most frequent pattern (first entry) on top
        for bar in bars:
            width = bar.get_width()
            ax.text(width + 0.5, bar.get_y() + bar.get_height()/2, f'{width:.1f}%', va='center')

        ax.set_xlabel('Occurrence Frequency (% of Boards in Cluster)')
        # Report how many patterns are actually shown; clusters can have fewer than 10.
        ax.set_title(f'Top {len(labels)} Line Patterns in Cluster {cid} (n={num_boards})\nLayer {layer}, Style: {prompt_style}')
        ax.set_xlim(0, max(percentages) * 1.15 if percentages else 1)
        plt.tight_layout()
        plt.savefig(os.path.join(purity_plot_dir, f'viz_line_purity_cluster_{cid}.png'))
        plt.close(fig)

    print(f"Line purity visualizations saved to {purity_plot_dir}")

def load_model(model_path_or_name: str, architecture_name: str) -> HookedTransformer:
    """
    Load a causal-LM checkpoint with ``transformers`` and wrap it as a HookedTransformer.

    Args:
        model_path_or_name: Local directory or Hub id given to
            ``AutoModelForCausalLM.from_pretrained``. The weights come from here.
        architecture_name: Name given to ``HookedTransformer.from_pretrained`` together
            with the already-loaded ``hf_model``.

    Returns:
        The wrapped model, moved to the module-level ``DEVICE``.

    NOTE(review): ``HookedTransformer.from_pretrained`` receives ``architecture_name``,
    so any config/tokenizer it resolves presumably comes from that name rather than
    from ``model_path_or_name``. Confirm this matches the fine-tuned checkpoint.
    """
    print(f"Loading HuggingFace model from: {model_path_or_name}...")
    # Raw transformers model: for a local path this reads the checkpoint in that directory.
    raw_hf_model = AutoModelForCausalLM.from_pretrained(model_path_or_name, trust_remote_code=True)

    print(f"Wrapping model in HookedTransformer with architecture: {architecture_name}...")
    # Wrap on CPU, then move the finished wrapper to the target device in one step.
    hooked = HookedTransformer.from_pretrained(
        architecture_name,
        hf_model=raw_hf_model,
        device="cpu",
        trust_remote_code=True,
    )
    hooked.to(DEVICE)

    print("Model loaded successfully.")
    return hooked

def get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, batch_size):
    """
    Return SAE reconstructions of final-position activations, keeping only selected features.

    Prompts are processed ``batch_size`` at a time to bound peak memory. For each batch,
    the activation at ``sae.cfg.hook_name`` for the last position (``[:, -1, :]``) is
    encoded by the SAE. Every feature not listed in ``cluster_indices`` is zeroed, and
    the result is decoded.

    Args:
        model: Object exposing ``run_with_cache`` (a HookedTransformer in this file).
        sae: Object exposing ``cfg.hook_name``, ``encode`` and ``decode``.
        prompts: Sliceable sequence of prompts.
        cluster_indices: Feature indices to keep (list, tuple, NumPy array or tensor).
            If None or empty, every feature is zeroed before decoding.
        batch_size: Number of prompts per forward pass.

    Returns:
        CPU tensor with the per-batch reconstructions concatenated along dim 0.
    """
    # `if cluster_indices:` raises for multi-element NumPy arrays/tensors and is False
    # for a one-element array such as [0], so test emptiness explicitly.
    keep_selected = cluster_indices is not None and torch.as_tensor(cluster_indices).numel() > 0

    all_reconstructed_activations = []
    with torch.no_grad():
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i + batch_size]

            _, cache = model.run_with_cache(batch_prompts, names_filter=sae.cfg.hook_name)

            original_activations = cache[sae.cfg.hook_name][:, -1, :].to(DEVICE)
            feature_acts = sae.encode(original_activations)

            mask = torch.zeros_like(feature_acts)
            if keep_selected:
                mask[:, cluster_indices] = 1

            reconstructed_activations = sae.decode(feature_acts * mask)
            # Move each batch off the accelerator right away so memory does not accumulate.
            all_reconstructed_activations.append(reconstructed_activations.cpu())

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return torch.cat(all_reconstructed_activations, dim=0)


def analyze_and_visualize_grid(model, task, sae, clusters, layer, display_name):
    """
    Save one PCA scatter per board square, colored by that square's content.

    For each of the 9 squares, boards where the square holds 'X', 'O' or is 'empty'
    are fetched. The same number of boards is taken from each state: the first n,
    with n = min(25, smallest state count). Their SAE reconstructions, restricted to
    the largest feature cluster, are projected to 2 PCA components. A square is
    skipped when n < 5. Plots go to
    ``<viz root>/<model>/layer_<layer>/viz_square_<k>_content.png``.
    """
    print(f"\n--- VISUALIZING GRID (Per-Square Content) for layer {layer} ---")
    if not clusters:
        return
    biggest_cluster = max(clusters, key=lambda k: len(clusters[k]))
    feature_ids = clusters[biggest_cluster]

    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    os.makedirs(output_dir, exist_ok=True)

    # (argument passed to the task, name shown in the legend)
    states = (('X', 'X'), ('O', 'O'), ('empty', 'Empty'))

    for square in range(9):
        groups = [task.find_boards_by_square_state(square, query) for query, _ in states]
        n = min(25, *(len(group) for group in groups))
        if n < 5:
            continue

        chosen = [board for group in groups for board in group[:n]]
        prompts = [task.get_prompt(board) for board in chosen]
        recon = get_reconstructed_activations_batched(model, sae, prompts, feature_ids, VISUALIZATION_BATCH_SIZE)
        projected = PCA(n_components=2).fit_transform(recon.detach().numpy())

        plt.figure(figsize=(10, 8))
        # Rows are laid out as [X block | O block | Empty block], n rows each.
        for slot, (_, shown) in enumerate(states):
            segment = projected[slot * n:(slot + 1) * n]
            plt.scatter(segment[:, 0], segment[:, 1], alpha=0.7, label=f'Square {square + 1}: {shown}')
        plt.title(f'PCA of SAE Cluster (Layer {layer}) - Colored by State of Square {square + 1}')
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(output_dir, f'viz_square_{square + 1}_content.png'))
        plt.close()
    print(f"Per-square content visualizations saved to {output_dir}")

def visualize_spatial_grid_representation(model, task, sae, clusters, layer, display_name):
    """
    Plot a 2-D PCA of one "concept vector" per board square.

    How each square's vector is built:
      - For each piece state ('X', 'O', 'empty') with more than 5 matching boards, up
        to 50 boards are sampled at random.
      - Their SAE reconstructions, restricted to the largest feature cluster, are
        averaged.
      - The square's vector is the mean of those per-state means.

    A square with no usable state is skipped, and at least 3 squares are needed to draw
    the plot. The figure is written to
    ``<viz root>/<model>/layer_<layer>/viz_spatial_grid_geometry.png``; the directory is
    created if needed.
    """
    print(f"\n--- VISUALIZING SPATIAL GEOMETRY of the grid for layer {layer} ---")
    if not clusters: return
    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]

    square_concept_vectors = []
    valid_square_indices = []
    for i in range(9):
        mean_vectors_for_square = []
        for piece in ['X', 'O', 'empty']:
            boards = task.find_boards_by_square_state(i, piece)
            if len(boards) > 5:
                sample_size = min(50, len(boards))
                prompts = [task.get_prompt(b) for b in random.sample(boards, sample_size)]
                recons = get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, VISUALIZATION_BATCH_SIZE)
                mean_vectors_for_square.append(recons.mean(dim=0))

        if not mean_vectors_for_square:
            print(f"Skipping spatial vector for square {i+1}: Not enough data.")
            continue

        valid_square_indices.append(i + 1)
        square_concept_vectors.append(torch.stack(mean_vectors_for_square).mean(dim=0))

    if len(square_concept_vectors) < 3:
        print("Could not generate enough concept vectors. Aborting spatial visualization.")
        return

    all_concepts_tensor = torch.stack(square_concept_vectors)
    pca = PCA(n_components=2)
    reduced_concepts = pca.fit_transform(all_concepts_tensor.cpu().detach().numpy())

    plt.figure(figsize=(10, 10))
    plt.scatter(reduced_concepts[:, 0], reduced_concepts[:, 1], s=120, c=valid_square_indices, cmap='viridis')

    for i, square_num in enumerate(valid_square_indices):
        plt.annotate(
            f"Square {square_num}", xy=(reduced_concepts[i, 0], reduced_concepts[i, 1]),
            xytext=(15, 15), textcoords='offset points', ha='center',
            arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2")
        )

    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    # This plot may be the first (or only) output for this layer, so the directory
    # must exist before savefig is called.
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.join(output_dir, 'viz_spatial_grid_geometry.png')

    plt.title(f'PCA of Spatial Concept Vectors for Each Square (Layer {layer})')
    plt.xlabel('Principal Component 1'); plt.ylabel('Principal Component 2'); plt.grid(True)
    plt.savefig(filename); plt.close()

    print(f"Spatial grid geometry visualization saved to {filename}")

def visualize_decomposed_spatial_concepts(model, task, sae, clusters, layer, display_name):
    """
    Plot a 2-D PCA of up to 27 concept vectors, one per (square, piece state) pair.

    Each vector is the mean SAE reconstruction, restricted to the largest feature
    cluster, of up to 50 randomly sampled boards where that square holds that state
    ('X', 'O' or 'empty'). Pairs with 5 or fewer matching boards are skipped, and at
    least 3 vectors are needed to draw the plot. Marker shape encodes the piece state
    and color encodes the square.

    The figure is written to ``<viz root>/<model>/layer_<layer>/viz_decomposed_concepts.png``;
    the directory is created if needed.
    """
    print(f"\n--- VISUALIZING DECOMPOSED SPATIAL CONCEPTS for layer {layer} ---")
    if not clusters: return
    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]

    decomposed_vectors = []
    plot_labels = []

    for i in range(9): # For each square
        for piece in ['X', 'O', 'empty']: # For each piece type
            boards = task.find_boards_by_square_state(i, piece)
            if len(boards) > 5:
                sample_size = min(50, len(boards))
                prompts = [task.get_prompt(b) for b in random.sample(boards, sample_size)]
                recons = get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, VISUALIZATION_BATCH_SIZE)

                decomposed_vectors.append(recons.mean(dim=0))
                plot_labels.append({'square': i + 1, 'piece': piece})

    if len(decomposed_vectors) < 3:
        print("Could not generate enough decomposed concept vectors. Aborting visualization.")
        return

    all_decomposed_tensor = torch.stack(decomposed_vectors)
    pca = PCA(n_components=2)
    reduced_decomposed = pca.fit_transform(all_decomposed_tensor.cpu().detach().numpy())

    plt.figure(figsize=(15, 15))

    markers = {'X': 'x', 'O': 'o', 'empty': '.'}
    # plt.get_cmap, as used elsewhere in this file: `plt.cm.get_cmap` (matplotlib.cm.get_cmap)
    # was deprecated in Matplotlib 3.7 and removed in 3.9.
    colors = plt.get_cmap('tab10', 9)

    # No legend is drawn, so points carry no legend label; the text tags below identify them.
    for i, label in enumerate(plot_labels):
        plt.scatter(
            reduced_decomposed[i, 0], reduced_decomposed[i, 1],
            marker=markers[label['piece']],
            color=colors(label['square'] - 1),
            s=150,
        )

    # Tags like "5X" / "5e", placed slightly farther from the origin than the point.
    for i, label in enumerate(plot_labels):
        plt.text(reduced_decomposed[i, 0] * 1.05, reduced_decomposed[i, 1] * 1.05,
                 f"{label['square']}{label['piece'][0]}", fontsize=9)

    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    # Make sure the directory exists even if no other plot for this layer ran first.
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.join(output_dir, 'viz_decomposed_concepts.png')

    plt.title(f'PCA of Decomposed Spatial & Content Vectors (Layer {layer})')
    plt.xlabel('Principal Component 1'); plt.ylabel('Principal Component 2'); plt.grid(True)

    plt.savefig(filename); plt.close()
    print(f"Decomposed spatial concepts visualization saved to {filename}")


def run_causal_interventions_for_grid(model, task, sae, clusters, layer):
    """
    Patch the largest SAE feature cluster between board pairs and print whether the
    model's preferred move follows the patch.

    For each square i (0-8), ``task.find_board_pair_differing_at_square(i)`` yields a
    ``(dirty_board, clean_board)`` pair; squares without a pair are skipped. The patch
    is applied at ``sae.cfg.hook_name`` on the last position of the dirty prompt:
      - the SAE feature activations at the cluster's indices are replaced by those
        from the clean prompt;
      - the residual is overwritten with the SAE decoding.

    The argmax over the tokens " 1" .. " 9" before and after patching is compared with
    ``task.get_best_move_str(clean_board)``. Results are only printed; nothing is
    returned.

    NOTE(review): the patched residual is the full SAE reconstruction, whereas
    ``original_move`` comes from the unpatched model, so any change also includes SAE
    reconstruction error. Consider adding the reconstruction error back if that
    confound matters.
    """
    print(f"\n--- INTERVENING ON GRID for layer {layer} ---")
    if not clusters: return
    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]

    # The token ids for " 1".." 9" do not depend on the board, so resolve them once
    # instead of re-tokenizing nine strings on every prediction.
    move_token_ids = [model.to_tokens(f" {n}")[0, -1].item() for n in range(1, 10)]

    def logits_to_best_move(logits):
        """Return "1".."9": the move whose token has the highest last-position logit."""
        last_token_logits = logits[0, -1, :]
        move_logits = last_token_logits[move_token_ids]
        return str(torch.argmax(move_logits).item() + 1)

    for i in range(9):
        pair = task.find_board_pair_differing_at_square(i)
        if not pair: continue
        dirty_board, clean_board = pair
        dirty_prompt = task.get_prompt(dirty_board)
        clean_prompt = task.get_prompt(clean_board)

        with torch.no_grad():
            _, clean_cache = model.run_with_cache(clean_prompt, names_filter=sae.cfg.hook_name)
            clean_activation = clean_cache[sae.cfg.hook_name][0, -1, :].to(DEVICE)
            # Encode under no_grad as well: nothing is backpropagated, so an autograd
            # graph for the SAE encoder would only cost memory.
            clean_feature_acts = sae.encode(clean_activation)

        def patch_hook(resid_post, hook):
            # Swap the cluster's features at the last position for the clean prompt's,
            # then write the SAE decoding back into the activation in place.
            dirty_feature_acts = sae.encode(resid_post[0, -1, :])
            dirty_feature_acts[cluster_indices] = clean_feature_acts[cluster_indices]
            resid_post[0, -1, :] = sae.decode(dirty_feature_acts)
            return resid_post

        with torch.no_grad():
            original_logits = model(dirty_prompt)
            patched_logits = model.run_with_hooks(dirty_prompt, fwd_hooks=[(sae.cfg.hook_name, patch_hook)])

        original_move = logits_to_best_move(original_logits)
        patched_move = logits_to_best_move(patched_logits)
        clean_gt_move = task.get_best_move_str(clean_board)
        print(f"\n--- Intervention for Square {i+1} ---")
        print(f"Original Move: {original_move} | Patched Move: {patched_move} | Expected Move: {clean_gt_move}")
        if patched_move == clean_gt_move and original_move != clean_gt_move:
            print("  -> SUCCESS: Causal link established.")
        else:
            print("  -> FAILED: No clear causal link found.")
            
def create_relabeled_plot(reduced_activations, sampled_boards, colors, base_title, label_type, output_dir, layer):
    """
    Redraw a 2-D embedding scatter and tag a random subset of points with move labels.

    At most 75 points are tagged; ``random.sample`` picks them when there are more
    boards than that. Terminal boards are never tagged. ``adjust_text`` is used to
    keep the tags from overlapping.

    Args:
        reduced_activations: 2-D embedding, one row per entry of ``sampled_boards``.
        sampled_boards: Board dicts with ``'board'`` and ``'is_terminal'``. For
            ``label_type='best'``, ``'best_moves'`` is also read if present.
        colors: Passed as the scatter ``c`` argument.
        base_title: Title prefix. It is also lower-cased, with spaces replaced by '_',
            for the filename.
        label_type: ``'legal'`` tags ``L:`` plus the 1-based indices of cells equal
            to 0. ``'best'`` tags ``B:`` plus the sorted best moves. Any other value
            adds no tags.
        output_dir: Directory for the saved PNG.
        layer: Shown in the title.
    """
    print(f"Re-labeling plot '{base_title}' with {label_type} moves...")

    fig, ax = plt.subplots(figsize=(20, 16))
    ax.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=colors, alpha=0.7, s=20)

    max_labels = 75
    n_boards = len(sampled_boards)
    chosen = random.sample(range(n_boards), max_labels) if n_boards > max_labels else range(n_boards)

    def label_for(entry):
        """Build the tag text for one board dict, or "" when it gets no tag."""
        if label_type == 'legal':
            if entry['is_terminal']:
                return ""
            open_cells = [str(idx + 1) for idx, cell in enumerate(entry['board']) if cell == 0]
            return f"L:{','.join(open_cells)}"
        if label_type == 'best':
            if entry['is_terminal']:
                return ""
            cells = entry['board']
            first_player_to_move = cells.count(1) == cells.count(2)
            # Assumes best_moves hold tokens 1-9 for player 1 and 10-18 for player 2,
            # so player 2's are shifted by 9. TODO confirm the encoding.
            moves = [str(m) if first_player_to_move else str(m - 9) for m in entry.get('best_moves', [])]
            if not moves:
                return ""
            moves.sort(key=int)
            return f"B:{','.join(moves)}"
        return ""

    annotations = []
    for idx in chosen:
        tag = label_for(sampled_boards[idx])
        if not tag:
            continue
        annotations.append(ax.text(reduced_activations[idx, 0], reduced_activations[idx, 1], tag,
                                   ha='center', va='center', fontsize=7,
                                   bbox=dict(boxstyle="round,pad=0.2", fc='white', ec='black', lw=0.5, alpha=0.8)))

    if annotations:
        adjust_text(annotations, ax=ax, arrowprops=dict(arrowstyle="-", color='gray', lw=0.5, alpha=0.7))

    ax.set_title(f"{base_title} - Labeled by Next {label_type.capitalize()} Moves (Layer {layer})")
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.grid(True)

    out_name = f"viz_relabeled_{base_title.lower().replace(' ', '_')}_{label_type}_moves.png"
    plt.savefig(os.path.join(output_dir, out_name))
    plt.close(fig)
    
# We want to plot the line patterns on the clusters so that we can see how the clusters relate to the strategic patterns.
def visualize_dominant_line_patterns(
    reduced_activations: np.ndarray,
    labels: np.ndarray,
    analysis_results: Dict[str, Any],
    output_dir: str,
    layer: int,
    prompt_style: str,
    algorithm_name: str
):
    """
    Scatter the embedding colored by cluster and tag each cluster centroid with its
    dominant line patterns.

    A cluster's ``top_10_line_patterns`` entries, each ``((line_index, pattern_tuple), count)``,
    are taken in the given order until their summed counts reach 90% of the cluster's
    ``size``. Each selected entry is rendered as ``<line>:(<pattern>) [<count/size>%]``.
    No tag is drawn for a cluster that has no size, no patterns, or no points in
    ``labels``.
    """
    print(f"Visualizing dominant line patterns for {algorithm_name} clusters...")
    if not analysis_results:
        print("No analysis results to visualize.")
        return

    fig, ax = plt.subplots(figsize=(24, 20))
    palette = plt.get_cmap('tab20b', len(analysis_results))
    ax.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=labels, cmap=palette, alpha=0.25, s=20)

    line_names = {0: "R0", 1: "R1", 2: "R2", 3: "C0", 4: "C1", 5: "C2", 6: "D1", 7: "D2"}
    glyphs = {1: 'X', -1: 'O', 0: '.'}

    def describe_dominant(patterns, size):
        """Return tag lines for the leading patterns that together cover >= 90% of size."""
        cutoff = size * 0.90
        covered = 0
        lines_out = []
        for (line_idx, cells), count in patterns:
            if covered >= cutoff:
                break
            covered += count
            name = line_names.get(line_idx, f"L{line_idx}")
            shape = "".join(glyphs.get(cell, '?') for cell in cells)
            share = count / size * 100
            lines_out.append(f"{name}:({shape}) [{share:.0f}%]")
        return lines_out

    annotations = []
    for cid, stats_for_cluster in analysis_results.items():
        size = stats_for_cluster.get('size')
        patterns = stats_for_cluster.get('top_10_line_patterns', [])
        if not size or not patterns:
            continue

        summary = describe_dominant(patterns, size)
        if not summary:
            continue

        members = reduced_activations[labels == int(cid)]
        if len(members) == 0:
            continue
        center = np.mean(members, axis=0)

        annotations.append(ax.text(center[0], center[1], "\n".join(summary), ha='center', va='center',
                                   bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="black", lw=1, alpha=0.9),
                                   fontsize=9, weight='bold'))

    if annotations:
        adjust_text(annotations, ax=ax, arrowprops=dict(arrowstyle="->", color='gray', lw=1.0, alpha=0.8))

    ax.set_title(f'{algorithm_name} Clusters Labeled by Dominant Line Patterns (>90% Purity)\n(Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.grid(True)

    out_path = os.path.join(output_dir, f'viz_{algorithm_name}_dominant_line_patterns_{prompt_style}.png')
    plt.savefig(out_path)
    plt.close(fig)
    print(f"Dominant line pattern visualization saved to {out_path}")
    
def run_and_visualize_clustering_analysis(
    algorithm_name: str,
    clusters: np.ndarray,
    reduced_activations: np.ndarray,
    final_boards_list: list,
    output_dir: str,
    layer: int,
    prompt_style: str
):
    """
    Run the per-cluster analysis pipeline for one clustering result.

    Steps:
      1. Group each point's board text by cluster label, skipping label -1 (noise).
         The text is the ``'ascii_board'`` field for the ascii_board prompt style and
         ``'text_instruction'`` otherwise, defaulting to ''. The groups are written to
         ``<output_dir>/<algorithm_name>_clusters_<prompt_style>.json``.
      2. Run ``TicTacToeClusterHypothesisTester`` on that file.
      3. If it returns results, save them and draw the line-purity,
         global-extremes, normalized-extremes and dominant-line-pattern plots.
    """
    n_clusters = len(set(clusters)) - (1 if -1 in clusters else 0)
    print(f"\n--- Analyzing clusters from {algorithm_name} ({n_clusters} clusters found) ---")

    board_field = 'ascii_board' if prompt_style == 'ascii_board' else 'text_instruction'
    grouped = defaultdict(list)
    for point_idx, cluster_label in enumerate(clusters):
        if cluster_label == -1:
            continue  # label -1 is treated as noise
        grouped[str(cluster_label)].append(final_boards_list[point_idx].get(board_field, ''))

    json_path = os.path.join(output_dir, f'{algorithm_name}_clusters_{prompt_style}.json')
    with open(json_path, 'w') as handle:
        json.dump(grouped, handle, indent=4)
    print(f"Saved {algorithm_name} cluster data to {json_path}")

    analysis_results = TicTacToeClusterHypothesisTester(json_path).run_statistical_tests()
    if not analysis_results:
        return

    save_analysis_results(analysis_results, output_dir, layer, prompt_style, algorithm_name)
    visualize_line_purity_per_cluster(analysis_results, output_dir, layer, prompt_style, algorithm_name)

    # The remaining plots share one positional signature.
    plot_args = (reduced_activations, clusters, analysis_results, output_dir, layer, prompt_style, algorithm_name)
    for plot_fn in (visualize_cluster_extremes, visualize_cluster_extremes_normalized, visualize_dominant_line_patterns):
        plot_fn(*plot_args)

# *** CHANGE 4: New visualization for detailed strategic situations ***
def visualize_strategic_situation(reduced_activations, boards_list, output_dir, layer, prompt_style):
    """Scatter the t-SNE embedding colored by each board's strategic situation.

    Terminal boards become 'X Won' / 'O Won' / 'Draw' based on their 'winner' field.
    For non-terminal boards the side to move is X when both players have placed the
    same number of pieces, otherwise O. The board is scored with
    ``_minimax_get_score(board, side_to_move)[0]`` (cached per board tuple). A score
    above 5 is labelled 'X Guaranteed Win' and below -5 'O Guaranteed Win'. Otherwise
    the mover is tagged '<P> Must Block' if the opponent owns any line with two pieces
    and one empty cell, else '<P> to Play'.

    Args:
        reduced_activations: Embedding array; columns 0 and 1 are plotted.
        boards_list: Dicts with 'board' (cells 0 = empty, 1 = X, 2 = O), 'is_terminal'
            and, for terminal boards, 'winner'. Must be aligned with the embedding rows.
        output_dir: Directory the PNG is written to.
        layer: Layer index shown in the title.
        prompt_style: Prompt style shown in the title and used in the filename.
    """
    print("Generating strategic situation plot... ♟️")

    symbol = {1: 'X', 2: 'O'}
    winning_lines = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
        (0, 4, 8), (2, 4, 6),              # diagonals
    )
    score_cache = {}  # board tuple -> minimax score (side to move is implied by the board)

    def open_two_count(cells, who):
        """Count lines where `who` holds two cells and the third cell is empty."""
        total = 0
        for triple in winning_lines:
            values = [cells[k] for k in triple]
            if values.count(who) == 2 and values.count(0) == 1:
                total += 1
        return total

    labels = []
    for entry in boards_list:
        cells = entry['board']
        key = tuple(cells)

        if entry['is_terminal']:
            if entry['winner'] == 1:
                labels.append('X Won')
            elif entry['winner'] == 2:
                labels.append('O Won')
            else:
                labels.append('Draw')
            continue

        mover = 1 if cells.count(1) == cells.count(2) else 2
        if key not in score_cache:
            score_cache[key] = _minimax_get_score(cells, mover)[0]
        value = score_cache[key]

        if value > 5:
            labels.append('X Guaranteed Win')
        elif value < -5:
            labels.append('O Guaranteed Win')
        else:
            rival = 2 if mover == 1 else 1
            suffix = 'Must Block' if open_two_count(cells, rival) > 0 else 'to Play'
            labels.append(f'{symbol[mover]} {suffix}')

    palette = {
        'X Won': '#8B0000', 'O Won': '#00008B', 'Draw': '#808080',
        'X Guaranteed Win': '#DC143C', 'O Guaranteed Win': '#4169E1',
        'X Must Block': '#F08080', 'O Must Block': '#87CEFA',
        'X to Play': '#FFA07A', 'O to Play': '#ADD8E6'
    }
    point_colors = [palette.get(lbl, 'black') for lbl in labels]

    plt.figure(figsize=(20, 16))
    plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=point_colors, alpha=0.8, s=25)
    # The legend lists every category in the palette, whether or not it occurs in the data.
    legend_entries = [
        plt.Line2D([0], [0], marker='o', color='w', label=name, markersize=10, markerfacecolor=hexcode)
        for name, hexcode in palette.items()
    ]
    plt.legend(handles=legend_entries, title="Strategic Situation")
    plt.title(f't-SNE of Board States by Strategic Situation (Layer {layer}, Style: {prompt_style})')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, f'viz_hypothesis_strategic_situation_{prompt_style}.png'))
    plt.close()


# Hierarchical breakdown of the clusters: recursively subdivide each cluster until every sub-cluster reaches >= 90% purity for a single line pattern (or hits the size/depth limits).

def visualize_hierarchical_breakdown(
    reduced_activations: np.ndarray,
    l0_labels: np.ndarray,
    l0_analysis_results: Dict[str, Any], # New argument for L0 annotations
    leaf_cluster_results: Dict[str, Any],
    output_dir: str,
    layer: int,
    prompt_style: str
):
    """
    Draw one figure combining the L0 clusters with their leaf sub-clusters.

    - L0 clusters are visited in ``l0_analysis_results`` order. Each gets a convex-hull
      outline in a 'tab20' color. When its analysis has a truthy 'size' and
      'top_10_line_patterns', it also gets a centroid label listing patterns in order
      until their cumulative count reaches 90% of the cluster size.
    - Leaves are grouped under their parent using the part of the leaf id before the
      first '-'. They are scattered in distinct 'viridis' colors, or in the parent color
      when the parent has only one leaf. Each leaf whose id differs from its parent's id
      is labelled with its dominant pattern and purity.

    Args:
        reduced_activations: Embedding array; columns 0 and 1 are plotted.
        l0_labels: L0 cluster label per embedding row.
        l0_analysis_results: Per-L0-cluster analysis keyed by the cluster id as a string.
        leaf_cluster_results: Leaf id -> dict with 'indices', 'dominant_line_purity'
            (a fraction, shown as a percentage) and 'dominant_line_str'.
        output_dir: Directory the PNG is written to.
        layer: Layer index shown in the title.
        prompt_style: Prompt style shown in the title and used in the filename.
    """
    print("Visualizing combined hierarchical breakdown plot...")

    fig, ax = plt.subplots(figsize=(30, 24))  # oversized canvas: many labels to place
    outline_cmap = plt.get_cmap('tab20', len(np.unique(l0_labels)))

    # Bucket leaf ids by their L0 parent (the id prefix before the first '-').
    leaves_by_parent = defaultdict(list)
    for leaf_key in leaf_cluster_results.keys():
        leaves_by_parent[leaf_key.split('-')[0]].append(leaf_key)

    line_labels = {0: "R0", 1: "R1", 2: "R2", 3: "C0", 4: "C1", 5: "C2", 6: "D1", 7: "D2"}
    cell_symbols = {1: 'X', -1: 'O', 0: '.'}

    def describe_patterns(size, patterns):
        """Render patterns in order, stopping once their cumulative count reaches 90% of `size`."""
        rows, covered, cutoff = [], 0, size * 0.90
        for (line_idx, cells), freq in patterns:
            if covered >= cutoff:
                break
            covered += freq
            name = line_labels.get(line_idx, f"L{line_idx}")
            rendered = "".join([cell_symbols.get(c, '?') for c in cells])
            rows.append(f"{name}:({rendered}) [{(freq / size) * 100:.0f}%]")
        return rows

    annotations = []
    for color_idx, parent_key in enumerate(l0_analysis_results.keys()):
        parent_cid = int(parent_key)
        members = reduced_activations[l0_labels == parent_cid]
        if len(members) == 0:
            continue
        parent_color = outline_cmap(color_idx)

        # Outline the L0 cluster; if the hull fails, fall back to a faint scatter.
        try:
            if len(members) > 2:
                for edge in ConvexHull(members).simplices:
                    ax.plot(members[edge, 0], members[edge, 1], color=parent_color, lw=3.0, alpha=0.7, zorder=1)
        except Exception as e:
            print(f"⚠️ Warning: Could not compute Convex Hull for L0 cluster {parent_cid} due to error {e}. Skipping boundary drawing.")
            ax.scatter(members[:, 0], members[:, 1], color=parent_color, alpha=0.1, s=20)

        # Centroid label with the patterns dominating this L0 cluster.
        summary = l0_analysis_results[parent_key]
        size = summary.get('size')
        patterns = summary.get('top_10_line_patterns', [])
        if size and patterns:
            centre = np.mean(members, axis=0)
            caption = f"L0 Cluster {parent_cid}\n" + "\n".join(describe_patterns(size, patterns))
            annotations.append(ax.text(centre[0], centre[1], caption, ha='center', va='center',
                                       bbox=dict(boxstyle="round,pad=0.5", fc=parent_color, ec="black", lw=1.5, alpha=0.8),
                                       fontsize=10, weight='bold', color='white'))

        # Leaves belonging to this L0 cluster.
        children = leaves_by_parent.get(parent_key, [])
        multi = len(children) > 1
        child_cmap = plt.get_cmap('viridis', len(children)) if multi else None
        for j, child_key in enumerate(children):
            child = leaf_cluster_results[child_key]
            child_pts = reduced_activations[child['indices']]
            ax.scatter(child_pts[:, 0], child_pts[:, 1], color=child_cmap(j) if multi else parent_color,
                       s=25, alpha=0.6, zorder=2)

            if child_key == parent_key:
                continue  # an un-split L0 cluster gets no separate leaf label
            centre = np.mean(child_pts, axis=0)
            caption = f"ID: {child_key}\n{child['dominant_line_str']} [{child['dominant_line_purity'] * 100:.0f}%]"
            annotations.append(ax.text(centre[0], centre[1], caption, ha='center', va='center',
                                       bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=0.8, alpha=0.9),
                                       fontsize=8))

    if annotations:
        adjust_text(annotations, ax=ax, arrowprops=dict(arrowstyle="-", color='gray', lw=0.5, alpha=0.7))

    ax.set_title(f'Combined Hierarchical Breakdown (Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.grid(True)

    out_path = os.path.join(output_dir, f'viz_hierarchical_combined_{prompt_style}.png')
    plt.savefig(out_path)
    plt.close(fig)
    print(f"Saved combined hierarchical plot to {out_path}")

def run_hierarchical_analysis(
    reduced_activations: np.ndarray,
    final_boards_list: list,
    output_dir: str,
    layer: int,
    prompt_style: str
):
    """
    Recursively split the embedding into line-pattern-pure leaf clusters and plot them.

    Level 0 is a KMeans partition into 18 clusters, or fewer when there are fewer than
    18 points. Each L0 cluster is then split recursively with 3-way KMeans. A node
    becomes a leaf when any of these holds:

    - its dominant line-pattern purity is >= 0.90;
    - it has fewer than 20 points;
    - it has reached depth 4 (L0 nodes start at depth 1).

    Leaves are keyed by a dash-joined path id such as "3-0-2", whose first component
    is the L0 cluster id. The visualizer uses that prefix to group leaves under their
    parent.

    Args:
        reduced_activations: Embedding coordinates, one row per board; the plot uses
            columns 0 and 1.
        final_boards_list: Board metadata dicts aligned row-for-row with
            ``reduced_activations``; only their 'ascii_board' entry is read here.
        output_dir: Directory the combined plot is written to.
        layer: Model layer index (plot title only).
        prompt_style: Prompt style tag (plot title and filename).

    Returns:
        None. Results are only visualized.
    """
    print("\n--- Running Hierarchical Purity-Based Clustering ---")

    # Stopping rules for the recursive breakdown.
    MAX_DEPTH = 4
    MIN_CLUSTER_SIZE = 20
    PURITY_THRESHOLD = 0.90

    n_samples = len(reduced_activations)
    if n_samples == 0:
        print("[Hierarchical] No activations provided; skipping analysis.")
        return

    # --- Level-0 Clustering (initial breakdown) ---
    # KMeans raises when n_clusters > n_samples, so cap it for very small inputs.
    n_clusters_l0 = min(18, n_samples)
    kmeans_l0 = KMeans(n_clusters=n_clusters_l0, random_state=42, n_init='auto').fit(reduced_activations)
    l0_labels = kmeans_l0.labels_

    # Analyze L0 clusters up front so the plot can annotate their dominant patterns.
    print("Analyzing Level-0 clusters for overview plot...")
    l0_clustered_boards = defaultdict(list)
    for i, label in enumerate(l0_labels):
        l0_clustered_boards[str(label)].append(final_boards_list[i].get('ascii_board', ''))
    l0_tester = TicTacToeClusterHypothesisTester(json.dumps(l0_clustered_boards))
    l0_analysis_results = l0_tester.analysis_results

    leaf_cluster_results = {}

    def _analyze(indices):
        """Run the hypothesis tester on the boards at `indices`; return (tester, analysis)."""
        cluster_dict = {'0': [final_boards_list[i].get('ascii_board', '') for i in indices]}
        tester = TicTacToeClusterHypothesisTester(json.dumps(cluster_dict))
        return tester, tester.analysis_results.get('0', {})

    def _dominant_line_str(tester, analysis):
        """Format the top line pattern as '<line name>: <pattern>', or 'N/A' if there is none."""
        top_patterns = analysis.get('top_10_line_patterns')
        if not top_patterns:
            return "N/A"
        line_index, pattern_tuple = top_patterns[0][0][0], top_patterns[0][0][1]
        line_name = tester.LINE_NAMES.get(line_index, '')
        return line_name + ": " + "".join([tester.PIECE_MAP.get(p, '?') for p in pattern_tuple])

    def _recursive_breakdown(current_indices: np.ndarray, parent_id: str, depth: int):
        """Record `current_indices` as a leaf if a stopping rule holds; otherwise split into 3."""
        tester, analysis = _analyze(current_indices)
        purity = analysis.get('dominant_line_purity', 0)

        if (depth >= MAX_DEPTH
                or len(current_indices) < MIN_CLUSTER_SIZE
                or purity >= PURITY_THRESHOLD):
            leaf_cluster_results[parent_id] = {
                'indices': current_indices,
                'dominant_line_purity': purity,
                'dominant_line_str': _dominant_line_str(tester, analysis),
            }
            return

        print(f"Subdividing cluster {parent_id} (Purity: {purity:.2f}, Size: {len(current_indices)}) at depth {depth}...")
        # len(current_indices) >= MIN_CLUSTER_SIZE here, so a 3-way split is always valid.
        kmeans_sub = KMeans(n_clusters=3, random_state=depth, n_init='auto').fit(reduced_activations[current_indices])
        for sub_cid in range(3):
            sub_cluster_indices = current_indices[kmeans_sub.labels_ == sub_cid]
            if len(sub_cluster_indices) > 0:
                _recursive_breakdown(sub_cluster_indices, f"{parent_id}-{sub_cid}", depth + 1)

    # --- Kick off the recursion for each Level-0 cluster ---
    for cid in range(n_clusters_l0):
        l0_cluster_indices = np.where(l0_labels == cid)[0]
        if len(l0_cluster_indices) > 0:
            _recursive_breakdown(l0_cluster_indices, str(cid), 1)

    # --- Final Visualization ---
    visualize_hierarchical_breakdown(
        reduced_activations,
        l0_labels,
        l0_analysis_results,
        leaf_cluster_results,
        output_dir,
        layer,
        prompt_style
    )
    
# Agglomerative approach
def visualize_agglomerated_clusters(
    reduced_activations: np.ndarray,
    final_labels: np.ndarray,
    concept_definitions: Dict[int, str],
    output_dir: str,
    layer: int,
    prompt_style: str
):
    """
    Plot agglomerated concept clusters over a grey background of unassigned points.

    Points labelled -1 are drawn light grey. Each concept id (in sorted order) is drawn
    in its own 'tab20' color with a convex-hull outline when it has more than two points.
    Its centroid label is the text from ``concept_definitions``.

    Args:
        reduced_activations: Embedding array; columns 0 and 1 are plotted.
        final_labels: Concept id per embedding row, or -1 for unassigned points.
        concept_definitions: Concept id -> label text.
        output_dir: Directory the PNG is written to.
        layer: Layer index shown in the title.
        prompt_style: Prompt style shown in the title and used in the filename.
    """
    print("Visualizing agglomerated concept clusters...")

    fig, ax = plt.subplots(figsize=(28, 22))
    concept_ids = sorted(concept_definitions)
    palette = plt.get_cmap('tab20', len(concept_ids))

    # Unassigned points form a faint backdrop beneath the concept clusters.
    is_noise = final_labels == -1
    ax.scatter(reduced_activations[is_noise, 0], reduced_activations[is_noise, 1],
               c='lightgray', s=10, alpha=0.3, zorder=1)

    annotations = []
    for idx, concept_id in enumerate(concept_ids):
        members = reduced_activations[final_labels == concept_id]
        if len(members) == 0:
            continue
        shade = palette(idx)

        ax.scatter(members[:, 0], members[:, 1], color=shade, s=25, alpha=0.7, zorder=2)

        if len(members) > 2:
            try:
                for edge in ConvexHull(members).simplices:
                    ax.plot(members[edge, 0], members[edge, 1], color=shade, lw=3.0, alpha=0.85, zorder=3)
            except Exception:
                print(f"⚠️ Warning: Could not compute Convex Hull for concept ID {concept_id}.")

        centre = np.mean(members, axis=0)
        annotations.append(ax.text(centre[0], centre[1], concept_definitions[concept_id],
                                   ha='center', va='center',
                                   bbox=dict(boxstyle="round,pad=0.5", fc="white", ec="black", lw=1.5, alpha=0.9),
                                   fontsize=11, weight='bold', zorder=4))

    if annotations:
        adjust_text(annotations, ax=ax, force_points=(0.5, 0.5), expand_points=(1.5, 1.5),
                    arrowprops=dict(arrowstyle="-", color='black', lw=1.0, alpha=0.8))

    ax.set_title(f'Agglomerative Concept Clustering (Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.grid(True)

    out_path = os.path.join(output_dir, f'viz_agglomerative_concepts_{prompt_style}.png')
    plt.savefig(out_path)
    plt.close(fig)
    print(f"Saved agglomerative concept plot to {out_path}")


# Pure agglomerative clustering based on dominant line patterns.
def run_agglomerative_analysis(
    reduced_activations: np.ndarray,
    final_boards_list: list,
    output_dir: str,
    layer: int,
    prompt_style: str
):
    """
    Find pure micro-clusters and merge those that share a dominant line pattern.

    1. KMeans into 100 micro-clusters, or fewer when there are fewer than 100 points.
    2. Each micro-cluster with at least 10 points is scored with
       ``TicTacToeClusterHypothesisTester``. If its dominant line-pattern purity is
       >= 0.70, it is filed under the key ``"<line name>:(<pattern>)"`` built from its
       top pattern.
    3. All micro-clusters sharing a key become one concept cluster. Every other point
       keeps label -1.
    4. The concepts are plotted with ``visualize_agglomerated_clusters``.

    Args:
        reduced_activations: Embedding coordinates, one row per board.
        final_boards_list: Board metadata dicts aligned row-for-row with
            ``reduced_activations``; only their 'ascii_board' entry is read here.
        output_dir: Directory the plot is written to.
        layer: Model layer index (plot title only).
        prompt_style: Prompt style tag (plot title and filename).

    Returns:
        None. Results are only visualized.
    """
    print("\n--- Running Agglomerative Purity-Based Analysis ---")

    n_samples = len(reduced_activations)
    if n_samples == 0:
        print("[Agglomerative] No activations provided; skipping analysis.")
        return

    # 1. High-granularity initial clustering to find "micro-concepts".
    # KMeans raises when n_clusters > n_samples, so cap it for very small inputs.
    n_micro_clusters = min(100, n_samples)
    kmeans = KMeans(n_clusters=n_micro_clusters, random_state=42, n_init='auto').fit(reduced_activations)

    # 2. Profile each micro-cluster; keep only the pure ones, grouped by pattern key.
    PURITY_THRESHOLD = 0.70
    MIN_MICRO_SIZE = 10  # smaller micro-clusters are skipped as noise
    pattern_to_micro_indices = defaultdict(list)  # pattern key -> list of index arrays

    print(f"Analyzing {n_micro_clusters} micro-clusters to find pure concepts...")
    for cid in range(n_micro_clusters):
        indices = np.where(kmeans.labels_ == cid)[0]
        if len(indices) < MIN_MICRO_SIZE:
            continue

        cluster_dict = {str(cid): [final_boards_list[i].get('ascii_board', '') for i in indices]}
        tester = TicTacToeClusterHypothesisTester(json.dumps(cluster_dict))
        analysis = tester.analysis_results.get(str(cid), {})

        purity = analysis.get('dominant_line_purity', 0)
        if purity >= PURITY_THRESHOLD and analysis.get('top_10_line_patterns'):
            (line_index, pattern_tuple), _count = analysis['top_10_line_patterns'][0]
            # Canonical string for the pattern, used as the merge key.
            line_name = tester.LINE_NAMES.get(line_index, f"L{line_index}")
            pattern_str = "".join([tester.PIECE_MAP.get(p, '?') for p in pattern_tuple])
            pattern_to_micro_indices[f"{line_name}:({pattern_str})"].append(indices)

    # 3. Agglomerate: every micro-cluster sharing a key gets the same concept id,
    #    assigned in order of first appearance; everything else stays -1.
    final_labels = np.full(n_samples, -1, dtype=int)
    concept_definitions = {}

    print("Agglomerating pure micro-clusters into final concept clusters...")
    for concept_id, (pattern_key, index_groups) in enumerate(pattern_to_micro_indices.items()):
        concept_definitions[concept_id] = pattern_key
        for group in index_groups:
            final_labels[group] = concept_id

    # 4. Visualize the final, merged concept clusters.
    visualize_agglomerated_clusters(
        reduced_activations,
        final_labels,
        concept_definitions,
        output_dir,
        layer,
        prompt_style
    )

# Hybrid hierarchical-agglomerative clustering

import matplotlib.patches as mpatches # Add this import if you don't have it

def visualize_hybrid_breakdown(
    reduced_activations: np.ndarray,
    l0_labels: np.ndarray,
    l0_analysis_results: Dict[str, Any],
    l0_to_sub_clusters_map: Dict[str, List[Dict[str, Any]]],
    output_dir: str,
    layer: int,
    prompt_style: str
):
    """
    Plot the hybrid analysis: L0 outlines, labelled pure sub-clusters and a side legend.

    L0 clusters with fewer than 3 points are skipped entirely. Each remaining L0
    cluster (in ``l0_analysis_results`` order) gets:

    - a convex-hull outline in a 'tab20' color;
    - a very faint scatter of all its points;
    - a legend entry listing its patterns in order, until their cumulative count
      reaches 90% of the cluster size.

    Its sub-clusters are drawn in 'viridis' colors with their 'label' text at the
    centroid.

    Args:
        reduced_activations: Embedding array; columns 0 and 1 are plotted.
        l0_labels: L0 cluster label per embedding row.
        l0_analysis_results: Per-L0-cluster analysis keyed by the cluster id as a string.
        l0_to_sub_clusters_map: L0 id (string) -> list of dicts with 'indices' and 'label'.
        output_dir: Directory the PNG is written to.
        layer: Layer index shown in the title.
        prompt_style: Prompt style shown in the title and used in the filename.
    """
    print("Visualizing hybrid hierarchical breakdown plot with legend...")

    fig, ax = plt.subplots(figsize=(30, 24))
    parent_cmap = plt.get_cmap('tab20', len(np.unique(l0_labels)))

    line_labels = {0: "R0", 1: "R1", 2: "R2", 3: "C0", 4: "C1", 5: "C2", 6: "D1", 7: "D2"}
    cell_symbols = {1: 'X', -1: 'O', 0: '.'}

    def describe_patterns(summary):
        """Pattern lines for the legend: stop once cumulative counts reach 90% of 'size'."""
        size = summary.get('size')
        patterns = summary.get('top_10_line_patterns', [])
        rows = []
        if not (size and patterns):
            return rows
        covered, cutoff = 0, size * 0.90
        for (line_idx, cells), freq in patterns:
            if covered >= cutoff:
                break
            covered += freq
            name = line_labels.get(line_idx, f"L{line_idx}")
            rendered = "".join([cell_symbols.get(c, '?') for c in cells])
            rows.append(f"{name}:({rendered}) [{(freq / size) * 100:.0f}%]")
        return rows

    annotations = []
    handles = []
    for color_idx, parent_key in enumerate(l0_analysis_results.keys()):
        parent_cid = int(parent_key)
        members = reduced_activations[l0_labels == parent_cid]
        if len(members) < 3:
            continue
        parent_color = parent_cmap(color_idx)

        # Outline of the L0 cluster.
        try:
            for edge in ConvexHull(members).simplices:
                ax.plot(members[edge, 0], members[edge, 1], color=parent_color, lw=3.0, alpha=0.7, zorder=1)
        except Exception as e:
            print(f"⚠️ Warning: Could not compute Convex Hull for L0 cluster {parent_cid} due to error {e}.")

        # Legend entry summarising this L0 cluster's dominant patterns.
        legend_text = f"L0 Cluster {parent_cid}\n" + "\n".join(describe_patterns(l0_analysis_results[parent_key]))
        handles.append(mpatches.Patch(color=parent_color, label=legend_text))

        # Faint backdrop of the whole L0 cluster, then each pure sub-cluster on top.
        subs = l0_to_sub_clusters_map.get(str(parent_cid), [])
        sub_cmap = plt.get_cmap('viridis', max(len(subs), 1))
        ax.scatter(members[:, 0], members[:, 1], color=parent_color, alpha=0.05, s=15, zorder=2)

        for j, sub in enumerate(subs):
            sub_pts = reduced_activations[sub['indices']]
            if len(sub_pts) == 0:
                continue
            ax.scatter(sub_pts[:, 0], sub_pts[:, 1], color=sub_cmap(j), s=25, alpha=0.8, zorder=3)
            centre = np.mean(sub_pts, axis=0)
            annotations.append(ax.text(centre[0], centre[1], sub['label'], ha='center', va='center',
                                       bbox=dict(boxstyle="round,pad=0.4", fc="white", ec="black", lw=1, alpha=0.9),
                                       fontsize=9, weight='bold', zorder=4))

    if annotations:
        adjust_text(annotations, ax=ax, arrowprops=dict(arrowstyle="-", color='gray', lw=0.5, alpha=0.7))

    # Legend sits outside the axes on the right; tight_layout reserves room for it.
    ax.legend(handles=handles, title="L0 Cluster Dominant Patterns (>90%)",
              loc='upper left', bbox_to_anchor=(1.02, 1), borderaxespad=0.,
              fontsize='small', title_fontsize='medium')

    ax.set_title(f'Hybrid Hierarchical-Agglomerative Analysis (Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.grid(True)

    plt.tight_layout(rect=[0, 0, 0.85, 1])

    out_path = os.path.join(output_dir, f'viz_hybrid_agglomerative_{prompt_style}.png')
    plt.savefig(out_path)
    plt.close(fig)
    print(f"Saved hybrid analysis plot to {out_path}")

def run_hybrid_hierarchical_agglomerative_analysis(
    reduced_activations: np.ndarray,
    final_boards_list: list,
    output_dir: str,
    layer: int,
    prompt_style: str,
    return_components: bool = False
):
    """
    Cluster into L0 groups, then merge pure micro-clusters inside each L0 group.

    1. KMeans into 18 L0 clusters, or fewer when there are fewer than 18 points.
    2. Inside every L0 cluster with at least 20 points, run KMeans with
       ``max(5, min(30, size // 10))`` micro-clusters. Each micro-cluster with at least
       5 points and a dominant line-pattern purity >= 0.70 is filed under the key
       ``"<line name>:(<pattern>)"``. Micro-clusters sharing a key within the same L0
       cluster are merged into one sub-cluster.
    3. Plot the result with ``visualize_hybrid_breakdown``.

    Args:
        reduced_activations: Embedding coordinates, one row per board.
        final_boards_list: Board metadata dicts aligned row-for-row with
            ``reduced_activations``; only their 'ascii_board' entry is read here.
        output_dir: Directory the plot (and optional metadata JSON) is written to.
        layer: Model layer index (plot title and metadata filename).
        prompt_style: Prompt style tag (plot title and filenames).
        return_components: When True, also save a flat sub-cluster metadata file
            ``hybrid_merge_subcluster_metadata_{prompt_style}_layer{layer}.json`` in
            ``output_dir`` (best effort: failures are printed, not raised) and return
            the components.

    Returns:
        ``(l0_labels, l0_analysis_results, l0_to_sub_clusters_map)`` when
        ``return_components`` is True, otherwise None.
    """
    print("\n--- Running Hybrid Hierarchical-Agglomerative Analysis ---")

    n_samples = len(reduced_activations)
    if n_samples == 0:
        print("[HybridAnalysis] No activations provided; skipping analysis.")
        if return_components:
            return np.empty(0, dtype=int), {}, defaultdict(list)
        return

    # 1. Level-0 Clustering (initial high-level breakdown).
    # KMeans raises when n_clusters > n_samples, so cap it for very small inputs.
    n_clusters_l0 = min(18, n_samples)
    kmeans_l0 = KMeans(n_clusters=n_clusters_l0, random_state=42, n_init='auto').fit(reduced_activations)
    l0_labels = kmeans_l0.labels_

    # Analyze L0 clusters up front; the result feeds the plot legend.
    print("Analyzing Level-0 clusters for overview plot legend...")
    l0_clustered_boards = defaultdict(list)
    for i, label in enumerate(l0_labels):
        l0_clustered_boards[str(label)].append(final_boards_list[i].get('ascii_board', ''))
    l0_tester = TicTacToeClusterHypothesisTester(json.dumps(l0_clustered_boards))
    l0_analysis_results = l0_tester.analysis_results

    l0_to_sub_clusters_map = defaultdict(list)
    PURITY_THRESHOLD = 0.70  # purity for defining a pure sub-concept

    # 2. Loop through each L0 cluster and perform agglomeration inside it.
    for l0_cid in range(n_clusters_l0):
        l0_indices = np.where(l0_labels == l0_cid)[0]
        if len(l0_indices) < 20:
            continue

        print(f"\nAnalyzing inside L0 Cluster {l0_cid} (size: {len(l0_indices)})...")
        l0_activations = reduced_activations[l0_indices]

        # len(l0_indices) >= 20 here, so at least 5 micro-clusters is always feasible.
        n_micro_clusters = max(5, min(30, len(l0_indices) // 10))
        kmeans_micro = KMeans(n_clusters=n_micro_clusters, random_state=42, n_init='auto').fit(l0_activations)

        pattern_to_indices = defaultdict(list)
        for micro_cid in range(n_micro_clusters):
            micro_indices_local = np.where(kmeans_micro.labels_ == micro_cid)[0]
            if len(micro_indices_local) < 5:
                continue

            # Map local (within-L0) positions back to global row indices.
            micro_indices_global = l0_indices[micro_indices_local]

            cluster_dict = {str(micro_cid): [final_boards_list[i].get('ascii_board', '') for i in micro_indices_global]}
            tester = TicTacToeClusterHypothesisTester(json.dumps(cluster_dict))
            analysis = tester.analysis_results.get(str(micro_cid), {})

            purity = analysis.get('dominant_line_purity', 0)
            if purity >= PURITY_THRESHOLD and analysis.get('top_10_line_patterns'):
                (line_index, pattern_tuple), _ = analysis['top_10_line_patterns'][0]
                line_name = tester.LINE_NAMES.get(line_index, f"L{line_index}")
                pattern_str = "".join([tester.PIECE_MAP.get(p, '?') for p in pattern_tuple])
                pattern_to_indices[f"{line_name}:({pattern_str})"].extend(micro_indices_global)

        for pattern_key, collected_indices in pattern_to_indices.items():
            l0_to_sub_clusters_map[str(l0_cid)].append({
                'indices': np.array(collected_indices),
                'label': pattern_key
            })

    # 3. Visualize the final result together with the L0 analysis.
    visualize_hybrid_breakdown(
        reduced_activations,
        l0_labels,
        l0_analysis_results,
        l0_to_sub_clusters_map,
        output_dir,
        layer,
        prompt_style
    )

    if return_components:
        # Persist a flat sub-cluster metadata JSON so downstream scripts (e.g. the
        # manual merge tooling) can read indices and labels directly, without
        # invoking the interactive HTML generator. Best effort: failures are reported.
        try:
            os.makedirs(output_dir, exist_ok=True)
            flat_entries = []
            running_id = 0
            for parent_id, sub_list in l0_to_sub_clusters_map.items():
                for sub in sub_list:
                    idxs = np.array(sub['indices'])
                    if len(idxs) == 0:
                        continue
                    flat_entries.append({
                        'global_id': running_id,
                        'parent_l0': parent_id,
                        'size': int(len(idxs)),
                        'label': sub.get('label', f"{parent_id}_sub"),
                        'indices': idxs.tolist()
                    })
                    running_id += 1
            meta_filename = f"hybrid_merge_subcluster_metadata_{prompt_style}_layer{layer}.json"
            meta_path = os.path.join(output_dir, meta_filename)
            with open(meta_path, 'w') as f:
                json.dump(flat_entries, f, indent=2, cls=NpEncoder)
            print(f"[HybridAnalysis] Saved sub-cluster metadata (auto) to {meta_path}")
        except Exception as e:
            print(f"[HybridAnalysis] Failed to auto-save sub-cluster metadata: {e}")
        return l0_labels, l0_analysis_results, l0_to_sub_clusters_map

# ==================== NEW ADDITIONS (Do NOT modify existing functions) ====================
def apply_massive_plot_style(font_scale: float = 2.8, base_style: str = 'whitegrid'):
    """Globally enlarge Matplotlib / Seaborn fonts for every plot created afterwards.

    Intended to be called once, before any figures are generated. This gives the
    existing plotting functions presentation-sized text without editing them.

    Args:
        font_scale: Multiplier passed to ``seaborn.set_context("talk", ...)``;
            values above ~2.5 produce very large text.
        base_style: Seaborn base style, e.g. 'white', 'whitegrid' or 'darkgrid'.
    """
    import matplotlib as _mpl
    import seaborn as _sns

    # Theme first (it resets rcParams), then the scaled "talk" context.
    _sns.set_theme(style=base_style)
    _sns.set_context("talk", font_scale=font_scale)

    # Explicit overrides so every code path gets the same sizes, plus thicker
    # lines and markers for dense scatter plots.
    overrides = {
        'axes.titlesize': 34,
        'axes.labelsize': 30,
        'xtick.labelsize': 26,
        'ytick.labelsize': 26,
        'legend.title_fontsize': 28,
        'legend.fontsize': 24,
        'figure.titlesize': 38,
        'savefig.bbox': 'tight',
        'lines.markersize': 10,
        'lines.linewidth': 2.2,
    }
    _mpl.rcParams.update(overrides)
    print(f"[Style] Applied massive plot style (font_scale={font_scale}).")


def visualize_strategic_situation_aggregated(
    reduced_activations: np.ndarray,
    boards_list: list,
    output_dir: str,
    layer: int,
    prompt_style: str,
    apply_style: bool = False
):
    """Scatter-plot reduced embeddings colored by five aggregated strategic states.

    This is an additive alternative to the detailed `visualize_strategic_situation`
    (which is intentionally left untouched). It collapses the per-player categories
    into five high-contrast groups:

        Player Won      (Blue)   -> terminal board whose 'winner' is 1 or 2
        Guaranteed Win  (Green)  -> non-terminal board whose minimax score is > 5 or < -5
        Must Block      (Red)    -> otherwise, if the opponent of the side to move has an
                                    open two-in-a-row (two pieces + one empty cell in a line)
        To Play         (Orange) -> any other non-terminal board
        Draw            (Yellow) -> terminal board without a winner

    Args:
        reduced_activations: Array with one row per board; only the first two columns
            are plotted.
        boards_list: Board metadata dicts aligned one-to-one with ``reduced_activations``.
            Each needs a 'board' sequence of 9 cells (0 = empty, 1 / 2 = player pieces)
            and may carry 'is_terminal' and 'winner'.
        output_dir: Directory to save the plot (created if missing).
        layer: Model layer index (used in the title).
        prompt_style: Prompt modality (used in the title and the filename).
        apply_style: If True, call `apply_massive_plot_style()` first. Note this changes
            global Matplotlib rcParams for all later plots.

    Returns:
        ``None`` if there is no data. Otherwise a dict with ``'counts'`` (number of
        points per category, in legend order) and ``'output_path'`` (the saved PNG).

    Raises:
        ValueError: If ``reduced_activations`` and ``boards_list`` differ in length.
    """
    if reduced_activations is None or len(reduced_activations) == 0 or not boards_list:
        print("[AggregatedStrategic] No data provided; skipping plot.")
        return None

    # Points and boards are paired by position. Fail fast, before the expensive minimax
    # pass, instead of letting matplotlib reject a mis-sized color list (and leak the figure).
    if len(reduced_activations) != len(boards_list):
        raise ValueError(
            f"[AggregatedStrategic] reduced_activations has {len(reduced_activations)} rows "
            f"but boards_list has {len(boards_list)} entries; they must be aligned one-to-one."
        )

    if apply_style:
        # Apply large fonts on demand without forcing global side-effects for users who prefer not to.
        apply_massive_plot_style()

    print("Generating aggregated strategic situation plot (compressed categories)... 🧭")

    # The 8 winning lines of a 3x3 board (rows, columns, diagonals).
    win_lines = ((0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6), (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6))

    def count_open_twos(board, player):
        """Count lines holding exactly two of ``player``'s pieces and one empty cell."""
        threat_count = 0
        for line in win_lines:
            pieces = [board[i] for i in line]
            if pieces.count(player) == 2 and pieces.count(0) == 1:
                threat_count += 1
        return threat_count

    # Minimax is expensive; identical boards (possible in the input) reuse the score.
    memoized_scores: dict[tuple, float] = {}

    aggregated_categories = []
    for board_data in boards_list:
        board = board_data['board']
        board_tuple = tuple(board)

        # Terminal cases
        if board_data.get('is_terminal'):
            if board_data.get('winner') in (1, 2):
                aggregated_categories.append('Player Won')
            else:
                aggregated_categories.append('Draw')
            continue

        # Side to move: player 1 when both have placed equally many pieces, else player 2.
        current_player = 1 if board.count(1) == board.count(2) else 2

        if board_tuple not in memoized_scores:
            memoized_scores[board_tuple] = _minimax_get_score(board, current_player)[0]
        score = memoized_scores[board_tuple]

        # |score| > 5 is treated as a forced outcome for either side (sign ignored here).
        if score > 5 or score < -5:
            aggregated_categories.append('Guaranteed Win')
            continue

        opponent = 2 if current_player == 1 else 1
        if count_open_twos(board, opponent) > 0:
            aggregated_categories.append('Must Block')
        else:
            aggregated_categories.append('To Play')

    # Distinct, high-contrast color palette (colorblind-aware emphasis); dict order = legend order.
    color_map = {
        'Player Won': '#0057B8',      # Strong blue
        'Guaranteed Win': '#238823',   # Green (accessible)
        'Must Block': '#D62828',       # Bold red
        'To Play': '#F77F00',          # Vivid orange
        'Draw': '#FFD60A'              # Warm yellow
    }

    category_counts = Counter(aggregated_categories)
    # Fallback to black if any unexpected label slips through.
    point_colors = [color_map.get(cat, '#000000') for cat in aggregated_categories]
    # If more than 2 reduced dims were passed, plot only the first two.
    coords = np.asarray(reduced_activations)[:, :2]

    import matplotlib.pyplot as _plt
    from matplotlib.lines import Line2D as _Line2D

    fig = _plt.figure(figsize=(24, 20))
    try:
        _plt.scatter(coords[:, 0], coords[:, 1], c=point_colors, alpha=0.8, s=55, edgecolors='none')

        # Build legend manually to ensure ordering
        legend_handles = [
            _Line2D([0], [0], marker='o', color='w', label=f"{label} (n={category_counts[label]})",
                    markerfacecolor=col, markersize=16)
            for label, col in color_map.items()
        ]
        _plt.legend(handles=legend_handles, title="Strategic Situation (Aggregated)",
                    frameon=True, fancybox=True, borderpad=0.8)

        _plt.title(f'Aggregated Strategic Situation (Layer {layer}, Style: {prompt_style})')
        _plt.xlabel('t-SNE Dimension 1')
        _plt.ylabel('t-SNE Dimension 2')
        _plt.grid(True, alpha=0.25)

        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, f'viz_hypothesis_strategic_situation_aggregated_{prompt_style}.png')
        _plt.savefig(out_path)
    finally:
        # Always release the figure, even if saving fails, so repeated calls do not leak figures.
        _plt.close(fig)
    print(f"[AggregatedStrategic] Saved aggregated strategic plot to {out_path}")

    return {
        'counts': {cat: category_counts[cat] for cat in color_map},
        'output_path': out_path
    }


# ==================== Interactive Post-hoc Merge Tool ====================
def _json_for_script_tag(payload: Any, encoder: Optional[type] = None) -> str:
    """Serialize ``payload`` to JSON that can be embedded verbatim inside an HTML <script>.

    In JSON text the characters '<', '>' and '&' can only occur inside string literals, so
    replacing them with their JSON unicode escapes yields equivalent JSON. The escaping stops
    a value such as "</script>" from terminating the enclosing element early.

    Args:
        payload: JSON-serializable object.
        encoder: Optional ``json.JSONEncoder`` subclass (e.g. one that handles numpy types).
    """
    text = json.dumps(payload, cls=encoder)
    return text.replace('<', '\\u003c').replace('>', '\\u003e').replace('&', '\\u0026')


def _subcluster_legend_table_html(
    sub_entries: List[Dict[str, Any]],
    color_map: Dict[int, str],
) -> str:
    """Render the static legend table (color swatch, SubID, L0 parent, size, label).

    Free-text fields are HTML-escaped so labels containing '<' or '&' render literally
    instead of being interpreted as markup.
    """
    import html as _html
    rows = []
    for entry in sub_entries:
        cid = entry['global_id']
        color = color_map.get(cid, '#CCCCCC')
        rows.append(
            f"<tr><td style='padding:2px 6px;background:{color};width:24px'>&nbsp;</td>"
            f"<td>{cid}</td><td>{_html.escape(str(entry['parent_l0']))}</td>"
            f"<td>{entry['size']}</td><td>{_html.escape(str(entry['label']))}</td></tr>"
        )
    return (
        "<table border='1' cellspacing='0' cellpadding='2' style='font-size:12px;border-collapse:collapse'>"
        "<thead><tr><th>Color</th><th>SubID</th><th>L0</th><th>Size</th><th>Label</th></tr></thead><tbody>"
        + "".join(rows) + "</tbody></table>"
    )


def create_interactive_hybrid_merge_tool(
    reduced_activations: np.ndarray,
    l0_labels: np.ndarray,
    l0_analysis_results: Dict[str, Any],
    l0_to_sub_clusters_map: Dict[str, List[Dict[str, Any]]],
    final_boards_list: List[Dict[str, Any]],
    output_dir: str,
    layer: int,
    prompt_style: str,
    filename_prefix: str = "hybrid_merge"
):
    """Generate an interactive HTML (Plotly) page for manually merging nearby sub-clusters.

    Motivation:
        The hybrid hierarchical-agglomerative process often yields many spatially adjacent
        sub-clusters whose dominant line pattern purities are similar. Automatic merging
        (by pattern key) may split semantically identical regions. This tool produces:

        1. An interactive scatter plot (single Scattergl trace, no convex hulls) of all points.
           Points are colored by sub-cluster. Colors cycle through Plotly's qualitative
           Alphabet palette, so they repeat once there are more sub-clusters than palette
           entries. Points outside any sub-cluster are light grey. Hovering a point shows
           SubID, parent L0 cluster, label and size.
        2. A static HTML legend table listing every sub-cluster with its color swatch.
        3. A JSON sidecar listing all sub-clusters and their metadata.
        4. A small client-side JS editor (in the same HTML) that lets the user:
             - Enter a comma-separated list of SubIDs to merge plus a new label, then click
               "Merge". Re-merging a SubID replaces any earlier group containing it.
             - Delete merge groups, import a previously downloaded JSON, and click
               "Download Result" to export the merges as JSON ({"merges": [...]}).

    Args:
        reduced_activations: Array with one row per point; columns 0 and 1 are plotted.
        l0_labels: Currently unused; accepted for interface compatibility.
        l0_analysis_results: Currently unused; accepted for interface compatibility.
        l0_to_sub_clusters_map: Parent L0 id -> list of sub-cluster dicts, each with
            'indices' (row indices into ``reduced_activations``) and optionally 'label'.
        final_boards_list: Currently unused; accepted for interface compatibility.
        output_dir: Directory for the output files (created if missing).
        layer: Model layer index (titles and filenames).
        prompt_style: Prompt modality (titles and filenames).
        filename_prefix: Prefix for both output filenames.

    Output Files:
        {prefix}_interactive_{prompt_style}_layer{layer}.html   (interactive tool)
        {prefix}_subcluster_metadata_{prompt_style}_layer{layer}.json (raw metadata)

    NOTE: This function does NOT alter any Python-side state; merging is performed client-side.
          The downloaded JSON can be re-ingested by a separate script to apply the relabeling.
          Returns None in all cases (including when Plotly is missing or no sub-clusters exist).
    """
    try:
        import plotly.graph_objects as go
        import plotly.express as px
    except ImportError:
        print("Plotly not installed. Please install plotly to use interactive merge tool.")
        return
    import html as _html

    os.makedirs(output_dir, exist_ok=True)

    # Assemble sub-cluster entries: {global_id, parent_l0, size, label, centroid:[x, y], indices}
    sub_entries = []
    # Per-point sub-cluster id; -1 marks points not covered by any sub-cluster.
    all_points_cluster_id = np.full(reduced_activations.shape[0], fill_value=-1)

    def _centroid(idxs: np.ndarray):
        return reduced_activations[idxs].mean(axis=0).tolist()

    running_id = 0
    for parent_id, sub_list in l0_to_sub_clusters_map.items():
        for sub in sub_list:
            # Integer dtype so indices loaded from JSON (possibly floats) still index correctly.
            idxs = np.asarray(sub['indices'], dtype=int)
            if len(idxs) == 0:
                continue
            # Global ids are assigned sequentially in map iteration order.
            gid = running_id
            running_id += 1
            all_points_cluster_id[idxs] = gid
            sub_entries.append({
                'global_id': gid,
                'parent_l0': parent_id,
                'size': int(len(idxs)),
                'label': sub.get('label', f"{parent_id}_sub"),
                'centroid': _centroid(idxs),
                'indices': idxs.tolist()
            })

    if not sub_entries:
        print("No sub-clusters available for interactive merge tool.")
        return

    point_x = reduced_activations[:, 0]
    point_y = reduced_activations[:, 1]
    cluster_ids = all_points_cluster_id

    # Color only points that belong to a sub-cluster; noise stays light grey.
    unique_ids = sorted(e['global_id'] for e in sub_entries)
    palette = px.colors.qualitative.Alphabet
    color_map = {cid: palette[cid % len(palette)] for cid in unique_ids}

    colors = [color_map[cid] if cid in color_map else '#DDDDDD' for cid in cluster_ids]
    entry_map = {e['global_id']: e for e in sub_entries}
    hover_text = []
    for cid in cluster_ids:
        if cid in entry_map:
            meta = entry_map[cid]
            hover_text.append(
                f"SubID: {meta['global_id']}<br>Parent: {meta['parent_l0']}<br>Label: {meta['label']}<br>Size: {meta['size']}"  # noqa: E501
            )
        else:
            hover_text.append("Unassigned / Noise")

    fig = go.Figure(
        data=[go.Scattergl(
            x=point_x,
            y=point_y,
            mode='markers',
            marker=dict(color=colors, size=6, opacity=0.8),
            text=hover_text,
            hoverinfo='text'
        )]
    )
    fig.update_layout(
        title=f"Hybrid Sub-Clusters (Layer {layer}, Style: {prompt_style}) - Interactive Merge Tool",
        xaxis_title="t-SNE / Reduced Dim 1",
        yaxis_title="t-SNE / Reduced Dim 2",
        template="plotly_white",
        legend=dict(itemsizing='constant')
    )

    # A single Scattergl trace has no per-cluster legend items, so render an HTML table instead.
    legend_table_html = _subcluster_legend_table_html(sub_entries, color_map)

    # Client-side merge logic. Table cells are filled via textContent (not innerHTML) so
    # user-typed or imported labels cannot inject markup.
    merge_js = r"""
    <script>
    const subMeta = JSON.parse(document.getElementById('subcluster-metadata-json').textContent);
    let merges = []; // {new_label, members:[ids], size_sum}

    function performMerge(){
        const idsStr = document.getElementById('merge_ids').value.trim();
        const newLabel = document.getElementById('merge_label').value.trim();
        if(!idsStr || !newLabel){
            alert('Please enter comma-separated SubIDs and a New Label.');
            return;
        }
        // De-duplicate so "3,3" does not list the same member twice.
        const ids = [...new Set(idsStr.split(',').map(s=>parseInt(s.trim(), 10)).filter(n=>!isNaN(n)))];
        if(ids.length === 0){
            alert('No valid SubIDs parsed.');
            return;
        }
        // Validate IDs exist
        const valid = ids.every(id => subMeta.some(e => e.global_id === id));
        if(!valid){
            alert('One or more SubIDs not found.');
            return;
        }
        const sizeSum = subMeta.filter(e => ids.includes(e.global_id)).reduce((a,b)=>a+b.size,0);
        // OVERWRITE LOGIC: remove any existing merge entries that contain *any* of these ids
        merges = merges.filter(m => !m.members.some(x => ids.includes(x)));
        merges.push({new_label: newLabel, members: ids, size_sum: sizeSum});
        renderMergeTable();
        document.getElementById('merge_ids').value='';
        document.getElementById('merge_label').value='';
    }

    function renderMergeTable(){
        const tbody = document.getElementById('merge-results-body');
        tbody.innerHTML = '';
        merges.forEach((m,i)=>{
            const tr = document.createElement('tr');
            [i, m.new_label, m.members.join(', '), m.size_sum].forEach(val => {
                const td = document.createElement('td');
                td.textContent = String(val);
                tr.appendChild(td);
            });
            const actionTd = document.createElement('td');
            const btn = document.createElement('button');
            btn.textContent = 'Delete';
            btn.addEventListener('click', () => deleteMerge(i));
            actionTd.appendChild(btn);
            tr.appendChild(actionTd);
            tbody.appendChild(tr);
        });
    }

    function deleteMerge(idx){
        if(idx < 0 || idx >= merges.length) return;
        merges.splice(idx,1);
        renderMergeTable();
    }

    function downloadJSON(){
        const blob = new Blob([JSON.stringify({merges: merges}, null, 2)], {type: 'application/json'});
        const url = URL.createObjectURL(blob);
        const a = document.createElement('a');
        a.href = url; a.download = 'merged_subclusters.json'; a.click();
        URL.revokeObjectURL(url);
    }

    function importJSON(evt){
        const file = evt.target.files[0];
        if(!file) return;
        const reader = new FileReader();
        reader.onload = (e)=>{
            try {
                const data = JSON.parse(e.target.result);
                if(!data.merges || !Array.isArray(data.merges)){
                    alert('File missing "merges" array.');
                    return;
                }
                // Validate members exist in current metadata; skip invalid entries
                const filtered = [];
                data.merges.forEach(m => {
                    if(!m.members || !Array.isArray(m.members)) return;
                    const members = [...new Set(m.members.filter(id => subMeta.some(e=>e.global_id===id)))];
                    if(members.length === 0) return;
                    const sizeSum = subMeta.filter(e => members.includes(e.global_id)).reduce((a,b)=>a+b.size,0);
                    filtered.push({new_label: m.new_label || 'Imported', members: members, size_sum: sizeSum});
                });
                merges = filtered;
                renderMergeTable();
                document.getElementById('import_status').textContent = `Imported ${filtered.length} merge groups.`;
            } catch(err){
                alert('Failed to parse JSON: '+err);
            }
        };
        reader.readAsText(file);
    }
    </script>
    """

    html_path = os.path.join(
        output_dir,
        f"{filename_prefix}_interactive_{prompt_style}_layer{layer}.html"
    )
    meta_path = os.path.join(
        output_dir,
        f"{filename_prefix}_subcluster_metadata_{prompt_style}_layer{layer}.json"
    )

    # Save metadata JSON separately (for offline reuse)
    try:
        with open(meta_path, 'w', encoding='utf-8') as f:
            json.dump(sub_entries, f, indent=2, cls=NpEncoder)
        print(f"Saved sub-cluster metadata to {meta_path}")
    except Exception as e:
        print(f"Failed to write metadata JSON: {e}")

    # Same encoder as the sidecar file so numpy scalars in ids/labels serialize identically,
    # escaped so no value can close the <script> element early.
    embedded_meta_json = _json_for_script_tag(sub_entries, encoder=NpEncoder)
    safe_prompt_style = _html.escape(str(prompt_style))

    fig_html = fig.to_html(include_plotlyjs='cdn', full_html=False)
    full_html = f"""
<!DOCTYPE html>
<html lang='en'>
<head>
  <meta charset='UTF-8'/>
  <title>Hybrid Sub-Cluster Merge Tool</title>
  <style>
    body {{ font-family: Arial, sans-serif; }}
    .panel {{ margin: 10px 0; padding:10px; border:1px solid #aaa; border-radius:6px; }}
    input[type=text] {{ width: 300px; }}
    table {{ margin-top:10px; }}
  </style>
</head>
<body>
  <h2>Hybrid Sub-Cluster Interactive Merge Tool</h2>
  <p>Layer {layer} | Prompt Style: {safe_prompt_style}</p>
  <div class='panel'>
    <h3>Instructions</h3>
    <ol>
      <li>Inspect nearby sub-clusters (hover for metadata). Similar semantics + purity can be merged.</li>
      <li>Use the table below to note SubIDs you wish to merge.</li>
      <li>Enter comma-separated SubIDs and a new label, then click <b>Merge</b>.</li>
      <li>Re-merging any SubID overwrites its previous assignment automatically.</li>
      <li>Use the Delete button in the table to remove an entry.</li>
      <li>Repeat as needed; when finished, click <b>Download Result</b> to save your manual merges.</li>
      <li>Apply merges later in Python by loading the JSON and remapping points.</li>
    </ol>
  </div>
  <div class='panel'>
    <h3>Merge Editor</h3>
    <label>SubIDs (comma separated): <input type='text' id='merge_ids' placeholder='e.g. 0,2,5'/></label><br/>
    <label>New Label: <input type='text' id='merge_label' placeholder='e.g. Fork Threat Cluster'/></label><br/>
    <button onclick='performMerge()'>Merge</button>
    <button onclick='downloadJSON()'>Download Result</button>
    <label style='margin-left:12px;'>Import Existing: <input type='file' accept='.json' onchange='importJSON(event)'/></label>
    <span id='import_status' style='margin-left:10px;font-size:12px;color:#333;'></span>
    <table border='1' cellspacing='0' cellpadding='4' style='border-collapse:collapse; font-size:12px;'>
      <thead><tr><th>#</th><th>New Label</th><th>Members</th><th>Total Size</th><th>Actions</th></tr></thead>
      <tbody id='merge-results-body'></tbody>
    </table>
  </div>
  <div class='panel'>
    <h3>Sub-Cluster Legend</h3>
    {legend_table_html}
  </div>
  <div class='panel'>
    <h3>Interactive Plot</h3>
    {fig_html}
  </div>
  <script id='subcluster-metadata-json' type='application/json'>{embedded_meta_json}</script>
  {merge_js}
</body>
</html>
"""

    try:
        # The page declares UTF-8, so write it as UTF-8 regardless of the platform default.
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(full_html)
        print(f"Interactive merge tool saved to {html_path}")
    except Exception as e:
        print(f"Failed to write interactive HTML: {e}")



def visualize_board_state_hypotheses(model, task, sae, clusters, layer, display_name, prompt_style='text_instruction', hypotheses_to_run: Optional[set] = None):
    """
    Generates multiple t-SNE plots to test various hypotheses.
    
    Args:
        hypotheses_to_run (Optional[set]): A set of strings specifying which plots to run.
            e.g., {'game_turn', 'trajectories'}. If None, all plots are generated.
            Keys: 'full_board', 'game_turn', 'threat', 'symmetry', 'progression', 
                  'turn_by_turn', 'winner', 'best_vs_legal', 'trajectories', 'kmeans_dbscan', 'strategy'.
    """
    print(f"\n--- VISUALIZING BOARD STATE HYPOTHESES for layer {layer} (Prompt Style: {prompt_style}) ---")
    if not clusters: return None, None
    
    # Cache logic for tsne
    cache_dir = "tsne_cache"
    os.makedirs(cache_dir, exist_ok=True)
    sanitized_name = sanitize_model_name(display_name)
    cache_file_base = f"{sanitized_name}_layer_{layer}_{prompt_style}"
    cache_path_npz = os.path.join(cache_dir, f"{cache_file_base}.npz")
    cache_path_json = os.path.join(cache_dir, f"{cache_file_base}_boards.json")
    # If cache exists, load it and skip the expensive computation
    if os.path.exists(cache_path_npz) and os.path.exists(cache_path_json):
        print(f"✅ Loading cached t-SNE results from {cache_dir}...")
        with np.load(cache_path_npz) as data:
            reduced_activations = data['reduced_activations']
        with open(cache_path_json, 'r') as f:
            final_boards_list = json.load(f)
        # We still need to generate the plots, but we've skipped the main computation
        # If the caller did not specify which hypotheses to run (None) we must
        # still materialize the default set here, otherwise membership tests
        # like `if 'full_board' in hypotheses_to_run` will raise a TypeError.
        if hypotheses_to_run is None:
            hypotheses_to_run = {
                'full_board', 'game_turn', 'threat', 'symmetry', 'progression',
                'turn_by_turn', 'winner', 'best_vs_legal', 'trajectories', 'kmeans_dbscan', 'strategy'
            }
        # When using cached results, helper variables for trajectory plotting
        # (x_win_board, o_win_board, draw_board, reconstruct_board) were not
        # constructed previously (they normally live in the non-cache branch).
        # Define them here so the 'trajectories' visualization does not crash.
        if 'reconstruct_board' not in locals():
            def reconstruct_board(move_seq):
                board = [0] * 9
                for i, move_token in enumerate(move_seq):
                    player = 1 if i % 2 == 0 else 2
                    square_idx = (move_token - 1) % 9
                    board[square_idx] = player
                return tuple(board)
        if 'x_win_board' not in locals():
            x_win_board = next((b for b in final_boards_list if b.get('is_terminal') and b.get('winner') == 1), None)
        if 'o_win_board' not in locals():
            o_win_board = next((b for b in final_boards_list if b.get('is_terminal') and b.get('winner') == 2), None)
        if 'draw_board' not in locals():
            draw_board = next((b for b in final_boards_list if b.get('is_terminal') and b.get('winner') == 0), None)
    else:
        print("No cache found. Running t-SNE computation...")
        # If no specific hypotheses are requested, run all of them.
        if hypotheses_to_run is None:
            hypotheses_to_run = {
                'full_board', 'game_turn', 'threat', 'symmetry', 'progression', 
                'turn_by_turn', 'winner', 'best_vs_legal', 'trajectories', 'kmeans_dbscan', 'strategy'
            }
        
        print(f"Running the following visualizations: {hypotheses_to_run}")

        largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
        cluster_indices = clusters[largest_cluster_id]
        unique_boards_data = list({tuple(b['board']): b for b in task.dataset}.values())

        # --- Step 1: Identify ALL boards needed for plotting ---
        sampled_boards_map = {tuple(b['board']): b for b in unique_boards_data}

        def reconstruct_board(move_seq):
            board = [0] * 9
            for i, move_token in enumerate(move_seq):
                player = 1 if i % 2 == 0 else 2
                square_idx = (move_token - 1) % 9
                board[square_idx] = player
            return tuple(board)

        full_dataset_map = {tuple(b['board']): b for b in unique_boards_data}
        x_win_board = next((b for b in unique_boards_data if b['is_terminal'] and b['winner'] == 1), None)
        o_win_board = next((b for b in unique_boards_data if b['is_terminal'] and b['winner'] == 2), None)
        draw_board = next((b for b in unique_boards_data if b['is_terminal'] and b['winner'] == 0), None)
        trajectories_to_find = [x_win_board, o_win_board, draw_board]
        trajectory_boards_map = {}

        for terminal_board in trajectories_to_find:
            if not terminal_board or not terminal_board.get('move_sequences') or not terminal_board['move_sequences'][0]:
                continue
            move_sequence = terminal_board['move_sequences'][0] # Use the first sequence
            path_board_tuples = [reconstruct_board(move_sequence[:i]) for i in range(len(move_sequence) + 1)]
            for board_tuple in path_board_tuples:
                if board_tuple not in trajectory_boards_map:
                    trajectory_boards_map[board_tuple] = full_dataset_map.get(board_tuple)

        final_boards_map = {**sampled_boards_map, **trajectory_boards_map}
        final_boards_list = [b for b in final_boards_map.values() if b is not None]
        
        print(f"Total points for t-SNE: {len(final_boards_list)}")

        # --- Step 2: Run t-SNE ONCE on the final combined set of boards ---
        prompts = [task.get_prompt(b, style=prompt_style) for b in final_boards_list]
        reconstructed_activations = get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, 16)

        print("Running t-SNE on combined set... (this may take a moment)")
        tsne = TSNE(n_components=2, perplexity=50, learning_rate=200, n_iter=1000, random_state=42, n_jobs=-1)
        reduced_activations = tsne.fit_transform(reconstructed_activations.detach().numpy())
        
        # --- NEW: Save the results to cache ---
        print(f"💾 Saving t-SNE results to cache in {cache_dir}...")
        np.savez_compressed(cache_path_npz, reduced_activations=reduced_activations)
        with open(cache_path_json, 'w') as f:
            json.dump(final_boards_list, f)


    board_to_coord_map = {tuple(b['board']): reduced_activations[i] for i, b in enumerate(final_boards_list)}
    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    os.makedirs(output_dir, exist_ok=True)

    # --- Step 3: Generate all requested plots ---

    if 'full_board' in hypotheses_to_run:
        print("Generating full board state plot with labels...")
        plt.figure(figsize=(20, 16))
        plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], alpha=0.5, s=15, c='lightblue')
        num_labels = 40
        if len(final_boards_list) > num_labels:
            indices_to_label = random.sample(range(len(final_boards_list)), num_labels)
            for i in indices_to_label:
                label_key = 'ascii_board' if prompt_style == 'ascii_board' else 'text_instruction'
                board_text = final_boards_list[i].get(label_key, str(final_boards_list[i]['board']))
                if label_key == 'text_instruction':
                    board_text = (board_text[:75] + '...') if len(board_text) > 75 else board_text
                plt.text(reduced_activations[i, 0], reduced_activations[i, 1], board_text, ha='center', va='bottom', fontsize=7, 
                         bbox=dict(boxstyle="round,pad=0.3", fc="yellow", ec="black", lw=1, alpha=0.7))
        plt.title(f't-SNE of Full Board State Representations (Layer {layer}, Style: {prompt_style})')
        plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True)
        plt.savefig(os.path.join(output_dir, f'viz_full_board_tsne_{prompt_style}.png')); plt.close()

    if 'game_turn' in hypotheses_to_run:
        print("Generating game turn plot...")
        categories = []
        for board in final_boards_list:
            if board['is_terminal']: categories.append('Game Over')
            elif board['board'].count(1) == board['board'].count(2): categories.append("X's Turn")
            else: categories.append("O's Turn")
        category_map = {"X's Turn": 'blue', "O's Turn": 'red', 'Game Over': 'grey'}
        colors = [category_map[c] for c in categories]
        plt.figure(figsize=(20, 16))
        plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=colors, alpha=0.6, s=20)
        handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat, markersize=10, markerfacecolor=col) for cat, col in category_map.items()]
        plt.legend(handles=handles, title="Board State"); plt.title(f't-SNE of Board States by Game Turn (Layer {layer}, Style: {prompt_style})')
        plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True)
        plt.savefig(os.path.join(output_dir, f'viz_hypothesis_game_turn_{prompt_style}.png')); plt.close()

    if 'threat' in hypotheses_to_run:
        print("Generating threat detection plot...")
        def check_threat(board, player):
            lines = [[0,1,2],[3,4,5],[6,7,8],[0,3,6],[1,4,7],[2,5,8],[0,4,8],[2,4,6]]
            for line in lines:
                pieces = [board[i] for i in line]
                if pieces.count(player) == 2 and pieces.count(0) == 1: return True
            return False
        threat_categories = ["X has threat" if check_threat(b['board'], 1) else "O has threat" if check_threat(b['board'], 2) else "No threat" for b in final_boards_list]
        threat_map = {"X has threat": 'cyan', "O has threat": 'magenta', "No threat": 'lightgrey'}
        threat_colors = [threat_map[c] for c in threat_categories]
        plt.figure(figsize=(20, 16)); plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=threat_colors, alpha=0.7, s=20)
        handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat, markersize=10, markerfacecolor=col) for cat, col in threat_map.items()]
        plt.legend(handles=handles, title="Threat State"); plt.title(f't-SNE of Board States by Threat Detection (Layer {layer}, Style: {prompt_style})')
        plt.savefig(os.path.join(output_dir, f'viz_hypothesis_threat_detection_{prompt_style}.png')); plt.close()

    if 'symmetry' in hypotheses_to_run:
        print("Generating board symmetry plot...")
        try: symmetry_ids_as_tuples = [tuple(board['canonical_symmetry_id']) for board in final_boards_list]
        except TypeError: symmetry_ids_as_tuples = [board['canonical_symmetry_id'] for board in final_boards_list]
        id_counts = Counter(symmetry_ids_as_tuples)
        unique_symmetries_count = len(id_counts)
        MANAGEABLE_COUNT = 20; OTHER_COLOR = 'lightgrey'
        if unique_symmetries_count > MANAGEABLE_COUNT:
            top_ids = [item[0] for item in id_counts.most_common(MANAGEABLE_COUNT)]
            cmap = plt.get_cmap('tab20', MANAGEABLE_COUNT)
            symmetry_map = {sid: cmap(i) for i, sid in enumerate(top_ids)}; symmetry_map['Other'] = OTHER_COLOR
            symmetry_colors = [symmetry_map.get(sid, symmetry_map['Other']) for sid in symmetry_ids_as_tuples]
            handles = [plt.Line2D([0], [0], marker='o', color='w', label=f"Symmetry Group {i+1}", markersize=10, markerfacecolor=col) for i, col in enumerate(cmap.colors)]
            handles.append(plt.Line2D([0], [0], marker='o', color='w', label=f"Other ({unique_symmetries_count - MANAGEABLE_COUNT} groups)", markersize=10, markerfacecolor=OTHER_COLOR))
        else:
            unique_symmetries = sorted(list(id_counts.keys()))
            cmap = plt.get_cmap('tab20', len(unique_symmetries))
            symmetry_map = {sid: cmap(i) for i, sid in enumerate(unique_symmetries)}
            symmetry_colors = [symmetry_map[sid] for sid in symmetry_ids_as_tuples]
            handles = [plt.Line2D([0], [0], marker='o', color='w', label=f"Symmetry Group {i+1}", markersize=10, markerfacecolor=col) for i, col in enumerate(symmetry_map.values())]
        plt.figure(figsize=(20, 16)); plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=symmetry_colors, alpha=0.7, s=20)
        plt.legend(handles=handles, title="Canonical Symmetry ID"); plt.title(f't-SNE of Board States by Canonical Symmetry (Layer {layer}, Style: {prompt_style})')
        plt.savefig(os.path.join(output_dir, f'viz_hypothesis_symmetry_{prompt_style}.png')); plt.close()

    if 'progression' in hypotheses_to_run:
        print("Generating game progression plot...")
        progression_categories = []
        for board in final_boards_list:
            piece_count = 9 - board['board'].count(0)
            if piece_count <= 2: progression_categories.append("Early Game (0-2 pieces)")
            elif piece_count <= 5: progression_categories.append("Mid Game (3-5 pieces)")
            else: progression_categories.append("Late Game (6+ pieces)")
        progression_map = {"Early Game (0-2 pieces)": '#2ca02c', "Mid Game (3-5 pieces)": '#ff7f0e', "Late Game (6+ pieces)": '#9467bd'}
        progression_colors = [progression_map[c] for c in progression_categories]
        plt.figure(figsize=(20, 16)); plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=progression_colors, alpha=0.7, s=20)
        handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat, markersize=10, markerfacecolor=col) for cat, col in progression_map.items()]
        plt.legend(handles=handles, title="Game Progression"); plt.title(f't-SNE of Board States by Game Progression (Layer {layer}, Style: {prompt_style})')
        plt.savefig(os.path.join(output_dir, f'viz_hypothesis_game_progression_{prompt_style}.png')); plt.close()

    if 'turn_by_turn' in hypotheses_to_run:
        print("Generating turn-by-turn game progression plot... 🔢")
        turn_categories = []
        for board in final_boards_list:
            if board['is_terminal']: turn_categories.append("Game Over")
            else:
                pc = 9 - board['board'].count(0)
                if pc == 0: turn_categories.append("Turn 0")
                elif pc in [1, 2]: turn_categories.append("Turns 1-2")
                elif pc in [3, 4]: turn_categories.append("Turns 3-4")
                elif pc in [5, 6]: turn_categories.append("Turns 5-6")
                elif pc in [7, 8]: turn_categories.append("Turns 7-8")
                else: turn_categories.append("Game Over") # Should cover turn 9

        unique_turns = sorted(list(set(turn_categories)), key=lambda x: 10 if 'Over' in x else (0 if '0' in x else int(x.split(' ')[-1].split('-')[0])))
        cmap = plt.get_cmap('plasma', len(unique_turns))
        turn_map = {turn: cmap(i) for i, turn in enumerate(unique_turns)}
        turn_colors = [turn_map[c] for c in turn_categories]

        plt.figure(figsize=(20, 16))
        plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=turn_colors, alpha=0.7, s=20)
        handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat, markersize=10, markerfacecolor=col) for cat, col in turn_map.items()]
        plt.legend(handles=handles, title="Game Turn")
        base_title = f't-SNE of Board States by Game Turn (Layer {layer}, Style: {prompt_style})'
        plt.title(base_title)
        plt.savefig(os.path.join(output_dir, f'viz_hypothesis_turn_by_turn_paired_{prompt_style}.png')); plt.close()

        # Automatically create a relabeled version of this new plot with best moves
        create_relabeled_plot(
            reduced_activations=reduced_activations, sampled_boards=final_boards_list,
            colors=turn_colors, base_title=base_title, label_type='best',
            output_dir=output_dir, layer=layer
        )
        
    if 'strategy' in hypotheses_to_run:
        visualize_strategic_situation(reduced_activations, final_boards_list, output_dir, layer, prompt_style)
        visualize_strategic_situation_aggregated(reduced_activations, final_boards_list, output_dir, layer, prompt_style)
        visualize_strategic_situation_aggregated_game_theoretic(reduced_activations, final_boards_list, output_dir, layer, prompt_style)

    if 'winner' in hypotheses_to_run:
        print("Generating winner plot... 🏆")
        winner_categories = ['X Wins' if b.get('winner') == 1 else 'O Wins' if b.get('winner') == 2 else 'Draw / In-Progress' for b in final_boards_list]
        winner_map = {'X Wins': '#FFD700', 'O Wins': '#FF8C00', 'Draw / In-Progress': '#D3D3D3'}
        winner_colors = [winner_map[c] for c in winner_categories]
        plt.figure(figsize=(20, 16)); plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=winner_colors, alpha=0.7, s=20)
        handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat, markersize=10, markerfacecolor=col) for cat, col in winner_map.items()]
        plt.legend(handles=handles, title="Game Outcome"); plt.title(f't-SNE of Board States by Winner (Layer {layer}, Style: {prompt_style})')
        plt.savefig(os.path.join(output_dir, f'viz_hypothesis_winner_{prompt_style}.png')); plt.close()

    if 'best_vs_legal' in hypotheses_to_run:
        print("Generating Best vs. Legal move plot...")
        TARGET_SQUARE_INDEX = 4; TARGET_SQUARE_LABEL = 5
        plot_data = []
        for board in final_boards_list:
            if board['is_terminal']: plot_data.append({'label': 'Game Over'}); continue
            is_legal = board['board'][TARGET_SQUARE_INDEX] == 0
            if not is_legal: plot_data.append({'label': 'Illegal'}); continue
            is_player1_turn = board['board'].count(1) == board['board'].count(2)
            target_move_token = TARGET_SQUARE_LABEL if is_player1_turn else TARGET_SQUARE_LABEL + 9
            is_best = target_move_token in board.get('best_moves', [])
            plot_data.append({'label': 'Best & Legal' if is_best else 'Legal, Not Best'})
        category_map_styles = {'Best & Legal':{'color':'gold','marker':'*','size':150}, 'Legal, Not Best':{'color':'limegreen','marker':'.','size':30}, 'Illegal':{'color':'tomato','marker':'x','size':30}, 'Game Over':{'color':'lightgrey','marker':'o','size':20}}
        plt.figure(figsize=(20, 16))
        for cat, style in category_map_styles.items():
            indices = [i for i, data in enumerate(plot_data) if data['label'] == cat]
            if indices: plt.scatter(reduced_activations[indices, 0], reduced_activations[indices, 1], c=style['color'], marker=style['marker'], s=style['size'], label=cat, alpha=0.8)
        plt.legend(title=f"Status of Move in Square {TARGET_SQUARE_LABEL}", markerscale=2); plt.title(f't-SNE by Best vs. Legal Status of Move (Layer {layer}, Style: {prompt_style})')
        plt.savefig(os.path.join(output_dir, f'viz_hypothesis_best_vs_legal_{TARGET_SQUARE_LABEL}_{prompt_style}.png')); plt.close()


    if 'trajectories' in hypotheses_to_run:
        print("Generating game trajectory plot with fewer, clearer paths... 🗺️")
        progression_map = {"Early Game (0-2 pieces)": '#2ca02c', "Mid Game (3-5 pieces)": '#ff7f0e', "Late Game (6+ pieces)": '#9467bd'}
        progression_colors = []
        for board in final_boards_list:
            piece_count = 9 - board['board'].count(0)
            if piece_count <= 2: progression_colors.append(progression_map["Early Game (0-2 pieces)"])
            elif piece_count <= 5: progression_colors.append(progression_map["Mid Game (3-5 pieces)"])
            else: progression_colors.append(progression_map["Late Game (6+ pieces)"])

        fig, ax = plt.subplots(figsize=(20, 16))
        ax.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=progression_colors, alpha=0.3, s=20, zorder=1)
        prog_handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat, markersize=10, markerfacecolor=col) for cat, col in progression_map.items()]

        empty_board_tuple = (0, 0, 0, 0, 0, 0, 0, 0, 0)
        
        start_handle = None
        if empty_board_tuple in board_to_coord_map:
            x0, y0 = board_to_coord_map[empty_board_tuple]
            ax.scatter(x0, y0, c='black', marker='*', s=600, zorder=10, ec='white', lw=1.0)
            start_handle = plt.Line2D([0], [0], marker='*', color='w', label='Game Start', markersize=20, markerfacecolor='black', markeredgecolor='white')

        # Define and plot a SINGLE, clear trajectory for each outcome
        trajectories_to_plot = {
            'X Win Trajectory': (x_win_board, 'deepskyblue'),
            'O Win Trajectory': (o_win_board, 'magenta'),
            'Draw Trajectory': (draw_board, "#FF0000")  # High-visibility orange
        }
        plotted_trajectories = {}

        for name, (terminal_board, color) in trajectories_to_plot.items():
            if not terminal_board or not terminal_board.get('move_sequences'): continue
            
            # Use only the FIRST move sequence for a much cleaner plot
            move_sequence = terminal_board.get('move_sequences', [])[0]
            if not move_sequence: continue

            path_tuples = [reconstruct_board(move_sequence[:i]) for i in range(len(move_sequence) + 1)]
            for i in range(len(path_tuples) - 1):
                parent_tuple, child_tuple = path_tuples[i], path_tuples[i+1]
                if parent_tuple in board_to_coord_map and child_tuple in board_to_coord_map:
                    x1, y1 = board_to_coord_map[parent_tuple]
                    x2, y2 = board_to_coord_map[child_tuple]
                    ax.annotate("", xy=(x2, y2), xytext=(x1, y1), zorder=2, arrowprops=dict(arrowstyle="->", color=color, lw=2.0, alpha=0.8, mutation_scale=25))
                    plotted_trajectories[name] = color

        # Finalize plot with a combined legend
        traj_handles = [plt.Line2D([0], [0], color=color, lw=2.5, label=name) for name, color in plotted_trajectories.items()]
        all_handles = prog_handles + traj_handles
        if start_handle: all_handles.insert(0, start_handle)

        ax.legend(handles=all_handles, title="Legend"); ax.set_title(f't-SNE of Game States with Labeled Game Trajectories (Layer {layer}, Style: {prompt_style})')
        ax.set_xlabel('t-SNE Dimension 1'); ax.set_ylabel('t-SNE Dimension 2'); ax.grid(True)
        plt.savefig(os.path.join(output_dir, f'viz_hypothesis_game_trajectories_{prompt_style}.png')); plt.close(fig)

    # --- Clustering Analysis ---
    if 'kmeans_dbscan' in hypotheses_to_run:
        print("\n--- Running K-Means Clustering ---")
        n_clusters_kmeans = 18
        kmeans = KMeans(n_clusters=n_clusters_kmeans, random_state=42, n_init='auto').fit(reduced_activations)
        run_and_visualize_clustering_analysis('kmeans', kmeans.labels_, reduced_activations, final_boards_list, output_dir, layer, prompt_style)
        
        if kmeans.labels_ is not None:
             tester = TicTacToeClusterHypothesisTester(os.path.join(output_dir, f'kmeans_clusters_{prompt_style}.json'))
             analysis_results = tester.analysis_results
             if analysis_results:
                 visualize_kmeans_hypothesis_normalized(reduced_activations, kmeans, analysis_results, output_dir, layer, prompt_style)
    
        print("\n--- Running DBSCAN Clustering ---")
        # *** CHANGE 1: Adjusted DBSCAN parameters for better clustering ***
        dbscan = DBSCAN(eps=5.0, min_samples=200).fit(reduced_activations)
        run_and_visualize_clustering_analysis('dbscan', dbscan.labels_, reduced_activations, final_boards_list, output_dir, layer, prompt_style)
    
    print(f"All requested hypothesis visualizations for style '{prompt_style}' saved to {output_dir}")
    # Note: We return None for kmeans_results now as it's handled internally
    return reduced_activations, final_boards_list


# def visualize_invariance_with_random_chars(model, task, sae, clusters, layer, display_name, prompt_style='text_instruction'):
#     """
#     Tests if the board representation is invariant to the tokens used for X and O.
#     This version is parameterized by prompt style.
#     """
#     print(f"\n--- VISUALIZING REPRESENTATION INVARIANCE for layer {layer} (Prompt Style: {prompt_style}) ---")
#     if not clusters: return

#     largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
#     cluster_indices = clusters[largest_cluster_id]
#     unique_boards_data = list({tuple(b['board']): b for b in task.dataset}.values())
    
#     sample_size = min(500, len(unique_boards_data))
#     boards_to_test = random.sample(unique_boards_data, sample_size)

#     mappings = {'Original (X, O)': ('X', 'O'), 'Random Set 1 (A, B)': ('A', 'B'), 'Random Set 2 (P, Q)': ('P', 'Q')}
#     all_prompts = []
#     plot_labels = []

#     for name, (p1_char, p2_char) in mappings.items():
#         def get_remapped_prompt(board_data, p1, p2):
#             # CORRECTED: Use prompt_style to get the base text
#             base_text = board_data.get(prompt_style, '')
#             # Replace the standard player characters with the new ones
#             return base_text.replace('X', p1).replace('O', p2)

#         for board in boards_to_test:
#             all_prompts.append(get_remapped_prompt(board, p1_char, p2_char))
#             plot_labels.append(name)

#     all_activations = get_reconstructed_activations_batched(model, sae, all_prompts, cluster_indices, 16)

#     tsne = TSNE(n_components=2, perplexity=50, learning_rate=200, n_iter=1000, random_state=42)
#     reduced_activations = tsne.fit_transform(all_activations.detach().numpy())

#     plt.figure(figsize=(20, 16))
#     colors = {'Original (X, O)': 'red', 'Random Set 1 (A, B)': 'blue', 'Random Set 2 (P, Q)': 'green'}
    
#     for label, color in colors.items():
#         indices = [i for i, l in enumerate(plot_labels) if l == label]
#         plt.scatter(reduced_activations[indices, 0], reduced_activations[indices, 1], c=color, label=label, alpha=0.7, s=30)
    
#     for i in range(sample_size):
#         indices = [i, i + sample_size, i + (2 * sample_size)]
#         points = reduced_activations[indices]
#         plt.plot(points[:, 0], points[:, 1], color='black', alpha=0.2, lw=0.5)

#     plt.legend(title="Token Mapping")
#     plt.title(f't-SNE of Board Representations with Remapped Tokens (Layer {layer}, Style: {prompt_style})')
#     plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True)
    
#     output_dir = f"visualizations/{sanitize_model_name(display_name)}/layer_{layer}"
#     os.makedirs(output_dir, exist_ok=True)
#     # CORRECTED: Add prompt_style to filename
#     plt.savefig(os.path.join(output_dir, f'viz_invariance_remapped_tokens_{prompt_style}.png')); plt.close()
#     print(f"Token invariance visualization for style '{prompt_style}' saved to {output_dir}")



def visualize_invariance_with_random_chars(model, task, sae, clusters, layer, display_name, prompt_style='text_instruction'):
    """Tests if board representation is invariant to tokens and visualizes higher-order concepts.

    Up to 500 unique boards (deduplicated by ``board['board']``) are sampled
    from ``task.dataset``. Each board's prompt text (``board[prompt_style]``)
    is rendered once per token mapping: the original 'X'/'O' symbols and two
    substitute pairs ('A'/'B', 'P'/'Q'). Reconstructed activations for the
    largest SAE cluster are embedded jointly with t-SNE, and three plots are
    saved under the layer's visualization directory:

    1. points colored by token mapping, with thin lines joining the renderings
       of the same board;
    2. the same embedding colored by whose turn it is (or 'Game Over');
    3. the same embedding colored by game progression (pieces on the board).

    Args:
        model, sae: forwarded to ``get_reconstructed_activations_batched``.
        task: object exposing ``dataset``, a list of board dicts; this function
            reads the 'board', 'is_terminal' and ``prompt_style`` keys.
        clusters: mapping of cluster id -> feature indices. The largest cluster
            is used; nothing is done if ``clusters`` is empty.
        layer, display_name: used for plot titles and the output directory.
        prompt_style: board-dict key holding the prompt text.

    Returns:
        None. Returns early (without plotting) if there are no clusters or no
        boards to test.
    """
    print(f"\n--- VISUALIZING INVARIANCE for layer {layer} (Style: {prompt_style}) ---")
    if not clusters: return

    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]
    # Deduplicate by board contents; the last dict seen for a given board wins.
    unique_boards_data = list({tuple(b['board']): b for b in task.dataset}.values())
    sample_size = min(500, len(unique_boards_data))
    if sample_size == 0:
        # t-SNE cannot embed zero samples, so there is nothing to visualize.
        print("No boards available for the invariance test. Skipping.")
        return
    boards_to_test = random.sample(unique_boards_data, sample_size)

    mappings = {'Original (X, O)': ('X', 'O'), 'Random Set 1 (A, B)': ('A', 'B'), 'Random Set 2 (P, Q)': ('P', 'Q')}
    all_prompts, plot_labels = [], []
    # Prompts are generated mapping-major: all boards for mapping 0, then mapping 1, ...
    for name, (p1_char, p2_char) in mappings.items():
        # Simultaneous substitution, so a token inserted for 'X' can never be
        # rewritten again by the 'O' substitution.
        # NOTE(review): every 'X'/'O' character in the prompt is rewritten, including
        # any inside instruction words -- confirm the prompt templates only use
        # these letters for pieces.
        swap_table = str.maketrans({'X': p1_char, 'O': p2_char})
        for board in boards_to_test:
            all_prompts.append(board.get(prompt_style, '').translate(swap_table))
            plot_labels.append(name)

    all_activations = get_reconstructed_activations_batched(model, sae, all_prompts, cluster_indices, 16)
    # t-SNE requires perplexity < n_samples; cap it for small samples.
    perplexity_value = min(50, len(all_prompts) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity_value, random_state=42)
    # .cpu() so this also works when activations live on a GPU.
    reduced_activations = tsne.fit_transform(all_activations.detach().cpu().numpy())
    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    os.makedirs(output_dir, exist_ok=True)

    # --- Plot 1: Standard Invariance Plot ---
    plt.figure(figsize=(20, 16))
    mapping_colors = {'Original (X, O)': 'red', 'Random Set 1 (A, B)': 'blue', 'Random Set 2 (P, Q)': 'green'}
    for label, color in mapping_colors.items():
        indices = [i for i, l in enumerate(plot_labels) if l == label]
        plt.scatter(reduced_activations[indices, 0], reduced_activations[indices, 1], c=color, label=label, alpha=0.7, s=30)
    n_mappings = len(mappings)
    for i in range(sample_size):
        # Because of the mapping-major layout, board i appears at i, i + sample_size, ...
        indices = [i + k * sample_size for k in range(n_mappings)]
        plt.plot(reduced_activations[indices, 0], reduced_activations[indices, 1], color='black', alpha=0.2, lw=0.5)
    plt.legend(title="Token Mapping")
    plt.title(f'Invariance to Remapped Tokens (Layer {layer}, Style: {prompt_style})')
    plt.savefig(os.path.join(output_dir, f'viz_invariance_remapped_tokens_{prompt_style}.png')); plt.close()

    # --- Plot 2: Invariance colored by GAME TURN ---
    turn_categories = []
    for board in boards_to_test:
        if board['is_terminal']: turn_categories.append('Game Over')
        elif board['board'].count(1) == board['board'].count(2): turn_categories.append("X's Turn")
        else: turn_categories.append("O's Turn")

    full_turn_categories = turn_categories * n_mappings  # Tile to match the mapping-major prompt order
    category_map = {"X's Turn": 'blue', "O's Turn": 'red', 'Game Over': 'grey'}
    turn_colors = [category_map[c] for c in full_turn_categories]

    plt.figure(figsize=(20, 16))
    plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=turn_colors, alpha=0.6, s=20)
    handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat, markersize=10, markerfacecolor=col) for cat, col in category_map.items()]
    plt.legend(handles=handles, title="Game Turn")
    plt.title(f'Invariance Plot Colored by Game Turn (Layer {layer}, Style: {prompt_style})')
    plt.savefig(os.path.join(output_dir, f'viz_invariance_colored_by_turn_{prompt_style}.png')); plt.close()

    # --- Plot 3: Invariance colored by GAME PROGRESSION ---
    progression_categories = []
    for board in boards_to_test:
        pc = 9 - board['board'].count(0)  # number of pieces on the board
        if pc <= 2: progression_categories.append("Early Game")
        elif pc <= 5: progression_categories.append("Mid Game")
        else: progression_categories.append("Late Game")

    full_prog_categories = progression_categories * n_mappings
    prog_map = {"Early Game": '#2ca02c', "Mid Game": '#ff7f0e', "Late Game": '#9467bd'}
    prog_colors = [prog_map[c] for c in full_prog_categories]

    plt.figure(figsize=(20, 16))
    plt.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=prog_colors, alpha=0.6, s=20)
    handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat, markersize=10, markerfacecolor=col) for cat, col in prog_map.items()]
    plt.legend(handles=handles, title="Game Progression")
    plt.title(f'Invariance Plot Colored by Game Progression (Layer {layer}, Style: {prompt_style})')
    plt.savefig(os.path.join(output_dir, f'viz_invariance_colored_by_progression_{prompt_style}.png')); plt.close()
    print(f"Token invariance visualizations for style '{prompt_style}' saved to {output_dir}")

def create_winner_plot_with_best_move_labels(reduced_activations, sampled_boards, output_dir, layer, prompt_style='text_instruction'):
    """
    Creates a single, readable plot showing game outcomes colored by winner
    and labeled with a sample of best moves.

    Points are colored by ``board.get('winner', 0)``: 1 -> 'X Wins',
    2 -> 'O Wins', anything else -> 'Draw / In-Progress'. Up to 40 randomly
    chosen points get a text box listing their best moves as square numbers
    1-9, and ``adjust_text`` is used to reduce label overlap.

    Args:
        reduced_activations: 2-D embedding; row i is assumed to correspond to
            ``sampled_boards[i]`` (NOTE(review): alignment is not checked here).
        sampled_boards: list of board dicts; reads 'winner' and 'best_moves'.
        output_dir: directory the PNG is written to. It is not created here.
        layer: layer index shown in the title.
        prompt_style: suffix used in the output filename.
    """
    print("Generating winner plot with best move labels...")

    # --- Step 1: Recreate the base coloring for the plot ---
    winner_categories = []
    for board in sampled_boards:
        winner = board.get('winner', 0)
        if winner == 1: winner_categories.append('X Wins')
        elif winner == 2: winner_categories.append('O Wins')
        else: winner_categories.append('Draw / In-Progress')

    winner_map = {'X Wins': '#FFD700', 'O Wins': '#FF8C00', 'Draw / In-Progress': '#D3D3D3'}
    winner_colors = [winner_map[c] for c in winner_categories]

    # --- Step 2: Create the plot from scratch ---
    fig, ax = plt.subplots(figsize=(20, 16))
    ax.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=winner_colors, alpha=0.6, s=25, label=None)

    # --- Step 3: Add a sample of non-overlapping text labels ---
    num_labels = 40  # A manageable number for readability
    indices_to_label = random.sample(range(len(sampled_boards)), min(num_labels, len(sampled_boards)))

    def format_moves(move_tokens):
        """Map move tokens to square numbers 1-9 (tokens for both players share a square)."""
        if not move_tokens: return "None"
        return str(sorted([(t - 1) % 9 + 1 for t in move_tokens]))

    texts = []
    for i in indices_to_label:
        best_moves_str = format_moves(sampled_boards[i].get('best_moves', []))
        texts.append(ax.text(reduced_activations[i, 0], reduced_activations[i, 1], best_moves_str,
                             ha='center', va='center', fontsize=8,
                             bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="blue", lw=0.5, alpha=0.8)))

    # Only run the overlap solver when there is something to move; with no boards
    # there are no labels, and calling adjust_text with an empty list is pointless
    # (and not guaranteed to be handled by every adjustText version).
    if texts:
        adjust_text(texts, arrowprops=dict(arrowstyle='->', color='black', lw=0.5))

    # --- Step 4: Add legend and save ---
    handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat, markersize=10, markerfacecolor=col)
               for cat, col in winner_map.items()]
    ax.legend(handles=handles, title="Game Outcome")
    ax.set_title(f"Winner States with Best Move Labels (Layer {layer})")
    ax.grid(True)

    output_filename = os.path.join(output_dir, f"viz_hypothesis_winner_best_labels_{prompt_style}.png")
    fig.savefig(output_filename)
    plt.close(fig)
    print(f"Saved labeled winner plot to {output_filename}")

 
def visualize_spatial_grid_tsne(model, task, sae, clusters, layer, display_name):
    """
    Uses t-SNE to visualize the 'concept vectors' for each of the 9 squares,
    testing for non-linear geometric structure.

    For each square (0-8) and each state ('X', 'O', 'empty') with more than
    five matching boards, up to 50 boards are sampled and the mean
    reconstructed activation for the largest SAE cluster is computed. A
    square's concept vector is the mean of its per-state means. Squares with
    no usable state are skipped; at least 3 concept vectors are required.

    Output: ``viz_spatial_grid_geometry_tsne.png`` in the layer directory.
    NOTE(review): visualize_spatial_grid_representation_tsne writes the same
    filename, so whichever runs last overwrites the other -- confirm intended.
    """
    print(f"\n--- VISUALIZING SPATIAL GEOMETRY (t-SNE) for layer {layer} ---")
    if not clusters: return
    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]

    square_concept_vectors = []; valid_square_indices = []
    for i in range(9):
        mean_vectors_for_square = []
        for piece in ['X', 'O', 'empty']:
            boards = task.find_boards_by_square_state(i, piece)
            if len(boards) > 5:
                sample_size = min(50, len(boards))
                prompts = [task.get_prompt(b) for b in random.sample(boards, sample_size)]
                recons = get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, VISUALIZATION_BATCH_SIZE)
                mean_vectors_for_square.append(recons.mean(dim=0))
        if not mean_vectors_for_square: continue
        valid_square_indices.append(i + 1)  # 1-based square labels for the plot
        square_concept_vectors.append(torch.stack(mean_vectors_for_square).mean(dim=0))

    if len(square_concept_vectors) < 3:
        print("Could not generate enough concept vectors. Aborting spatial visualization.")
        return

    all_concepts_tensor = torch.stack(square_concept_vectors)
    # Adjust perplexity for t-SNE; must be less than the number of samples
    perplexity = min(30, len(all_concepts_tensor) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    reduced_concepts = tsne.fit_transform(all_concepts_tensor.cpu().detach().numpy())

    plt.figure(figsize=(10, 10))
    plt.scatter(reduced_concepts[:, 0], reduced_concepts[:, 1], s=120, c=valid_square_indices, cmap='viridis')

    for i, square_num in enumerate(valid_square_indices):
        plt.annotate(f"Square {square_num}", xy=(reduced_concepts[i, 0], reduced_concepts[i, 1]),
                     xytext=(15, 15), textcoords='offset points', ha='center',
                     arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"))

    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    # savefig does not create directories; make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.join(output_dir, 'viz_spatial_grid_geometry_tsne.png')
    plt.title(f't-SNE of Spatial Concept Vectors for Each Square (Layer {layer})')
    plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True)
    plt.savefig(filename); plt.close()
    print(f"Spatial grid t-SNE visualization saved to {filename}")

def visualize_grid_topology_heatmap(model, task, sae, clusters, layer, display_name):
    """
    Calculates and plots a heatmap of cosine similarities between the 'X' concept
    vectors of each square to test for topological understanding.

    For each square (0-8) with more than five boards holding 'X' there, up to
    50 boards are sampled and the mean reconstructed activation (largest SAE
    cluster) is used as that square's 'X' vector. Squares without enough data
    get an all-zero placeholder so the 9x9 matrix keeps square indices
    aligned; their similarities come out as 0. If no square has enough data,
    nothing is plotted.

    Output: ``viz_grid_topology_heatmap.png`` in the layer directory.
    """
    print(f"\n--- VISUALIZING GRID TOPOLOGY (Heatmap) for layer {layer} ---")
    if not clusters: return
    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]

    # One mean 'X' vector per square; None marks squares without enough data.
    per_square_vectors = []
    for i in range(9):
        boards = task.find_boards_by_square_state(i, 'X')
        if len(boards) > 5:
            prompts = [task.get_prompt(b) for b in random.sample(boards, min(50, len(boards)))]
            recons = get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, VISUALIZATION_BATCH_SIZE)
            per_square_vectors.append(recons.mean(dim=0).detach())
        else:
            print(f"Skipping heatmap vector for square {i+1}: Not enough data.")
            per_square_vectors.append(None)

    reference = next((v for v in per_square_vectors if v is not None), None)
    if reference is None:
        print("No square had enough data for an 'X' concept vector. Aborting grid topology heatmap.")
        return

    # Placeholders must share the real vectors' shape, dtype and device, otherwise
    # torch.stack fails (e.g. CPU zeros next to GPU activations, or a size that
    # differs from model.cfg.d_model).
    square_x_concept_vectors = [v if v is not None else torch.zeros_like(reference) for v in per_square_vectors]
    all_concepts_tensor = torch.stack(square_x_concept_vectors)

    # Normalize vectors to have unit norm for cosine similarity calculation
    # Add a small epsilon to avoid division by zero for the zero vectors
    norm_concepts = all_concepts_tensor / (all_concepts_tensor.norm(dim=1, keepdim=True) + 1e-8)
    similarity_matrix = torch.matmul(norm_concepts, norm_concepts.T).cpu().numpy()

    plt.figure(figsize=(12, 10))
    sns.heatmap(similarity_matrix, annot=True, cmap='viridis', fmt=".2f",
                xticklabels=range(1, 10), yticklabels=range(1, 10))

    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    # savefig does not create directories; make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.join(output_dir, 'viz_grid_topology_heatmap.png')
    plt.title(f"Cosine Similarity of 'X' Representation Across Squares (Layer {layer})")
    plt.xlabel("Square Index"); plt.ylabel("Square Index")
    plt.savefig(filename); plt.close()
    print(f"Grid topology heatmap saved to {filename}")
    
def visualize_spatial_grid_representation_tsne(model, task, sae, clusters, layer, display_name):
    """Generates a t-SNE plot of the 'concept vectors' for each of the 9 squares.

    For each square (0-8) and state ('X', 'O', 'empty') with more than five
    matching boards, up to 50 boards are sampled and the mean reconstructed
    activation for the largest SAE cluster is computed. A square's concept
    vector is the mean of its per-state means. At least 3 concept vectors are
    required.

    Output: ``viz_spatial_grid_geometry_tsne.png`` in the layer directory.
    NOTE(review): visualize_spatial_grid_tsne writes the same filename, so
    whichever runs last overwrites the other -- confirm intended.
    """
    print(f"\n--- VISUALIZING SPATIAL GEOMETRY (t-SNE) of the grid for layer {layer} ---")
    if not clusters: return
    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]

    square_concept_vectors = []
    valid_square_indices = []
    for i in range(9):
        mean_vectors_for_square = []
        for piece in ['X', 'O', 'empty']:
            boards = task.find_boards_by_square_state(i, piece)
            if len(boards) > 5:
                sample_size = min(50, len(boards))
                prompts = [task.get_prompt(b) for b in random.sample(boards, sample_size)]
                recons = get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, 16)
                mean_vectors_for_square.append(recons.mean(dim=0))

        if not mean_vectors_for_square:
            continue

        valid_square_indices.append(i + 1)  # 1-based square labels for the plot
        square_concept_vectors.append(torch.stack(mean_vectors_for_square).mean(dim=0))

    if len(square_concept_vectors) < 3:
        print("Could not generate enough concept vectors. Aborting spatial visualization.")
        return

    all_concepts_tensor = torch.stack(square_concept_vectors)

    # Perplexity must be less than the number of samples.
    perplexity_value = min(5, len(all_concepts_tensor) - 1)

    # The iteration count is left at scikit-learn's default (1000): `n_iter` was
    # deprecated in 1.5 and removed in 1.7, while `max_iter` does not exist before 1.5.
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity_value)
    reduced_concepts = tsne.fit_transform(all_concepts_tensor.cpu().detach().numpy())

    plt.figure(figsize=(10, 10))
    plt.scatter(reduced_concepts[:, 0], reduced_concepts[:, 1], s=120, c=valid_square_indices, cmap='viridis')

    for i, square_num in enumerate(valid_square_indices):
        plt.annotate(
            f"Square {square_num}", xy=(reduced_concepts[i, 0], reduced_concepts[i, 1]),
            xytext=(15, 15), textcoords='offset points', ha='center',
            arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2")
        )

    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    # savefig does not create directories; make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.join(output_dir, 'viz_spatial_grid_geometry_tsne.png')

    plt.title(f't-SNE of Spatial Concept Vectors for Each Square (Layer {layer})')
    plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True)
    plt.savefig(filename); plt.close()

    print(f"Spatial grid geometry t-SNE visualization saved to {filename}")

def visualize_decomposed_spatial_concepts_tsne(model, task, sae, clusters, layer, display_name):
    """Generates a t-SNE plot of up to 27 concept vectors (one per piece type per square).

    For each square (0-8) and state ('X', 'O', 'empty') with more than five
    matching boards, up to 50 boards are sampled and the mean reconstructed
    activation for the largest SAE cluster becomes one point. Points are
    colored by square, shaped by state, and labeled like '5X', '5O' or '5e'.
    At least 3 points are required.

    Output: ``viz_decomposed_concepts_tsne.png`` in the layer directory.
    """
    print(f"\n--- VISUALIZING DECOMPOSED SPATIAL CONCEPTS (t-SNE) for layer {layer} ---")
    if not clusters: return
    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]

    decomposed_vectors = []
    plot_labels = []

    for i in range(9):
        for piece in ['X', 'O', 'empty']:
            boards = task.find_boards_by_square_state(i, piece)
            if len(boards) > 5:
                sample_size = min(50, len(boards))
                prompts = [task.get_prompt(b) for b in random.sample(boards, sample_size)]
                recons = get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, 16)
                decomposed_vectors.append(recons.mean(dim=0))
                plot_labels.append({'square': i + 1, 'piece': piece})

    if len(decomposed_vectors) < 3:
        print("Could not generate enough decomposed concept vectors. Aborting visualization.")
        return

    all_decomposed_tensor = torch.stack(decomposed_vectors)

    # Adjust perplexity for the number of available points (must be < n_samples)
    perplexity_value = min(20, len(all_decomposed_tensor) - 1)

    tsne_kwargs = dict(n_components=2, random_state=42, perplexity=perplexity_value, learning_rate='auto')
    try:
        # scikit-learn >= 1.5 names the iteration budget `max_iter` (`n_iter` was removed in 1.7).
        tsne = TSNE(max_iter=1200, **tsne_kwargs)
    except TypeError:
        # Older scikit-learn only accepts `n_iter`.
        tsne = TSNE(n_iter=1200, **tsne_kwargs)
    reduced_decomposed = tsne.fit_transform(all_decomposed_tensor.cpu().detach().numpy())

    plt.figure(figsize=(15, 15))
    markers = {'X': 'x', 'O': 'o', 'empty': '.'}
    # pyplot.get_cmap, since matplotlib.cm.get_cmap (plt.cm.get_cmap) was removed in Matplotlib 3.9.
    square_cmap = plt.get_cmap('tab10', 9)

    for i, label in enumerate(plot_labels):
        x, y = reduced_decomposed[i, 0], reduced_decomposed[i, 1]
        plt.scatter(
            x, y,
            marker=markers[label['piece']],
            color=square_cmap(label['square'] - 1),
            s=150
        )
        # Fixed screen-space offset; scaling data coordinates (x * 1.05) put
        # labels on top of points near the origin.
        plt.annotate(f"{label['square']}{label['piece'][0]}", xy=(x, y),
                     xytext=(5, 5), textcoords='offset points', fontsize=9)

    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    # savefig does not create directories; make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.join(output_dir, 'viz_decomposed_concepts_tsne.png')

    plt.title(f't-SNE of Decomposed Spatial & Content Vectors (Layer {layer})')
    plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True)
    plt.savefig(filename); plt.close()

    print(f"Decomposed spatial concepts t-SNE visualization saved to {filename}")
    
def analyze_and_visualize_grid_tsne(model, task, sae, clusters, layer, display_name):
    """Generates t-SNE plots for all 9 squares of the board.

    For each square, the first ``sample_size`` boards (at most 25, equal
    counts per state) where the square holds 'X', holds 'O', or is empty are
    selected. Their reconstructed activations for the largest SAE cluster are
    embedded with t-SNE and plotted, colored by the square's state. Squares
    with fewer than 10 boards for any state are skipped. One PNG per square
    (``viz_square_{n}_content_tsne.png``) is written to the layer directory.
    """
    print(f"\n--- VISUALIZING GRID (Per-Square Content, t-SNE) for layer {layer} ---")
    print("Note: This may be slow as it runs t-SNE nine times.")
    if not clusters: return
    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]

    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    os.makedirs(output_dir, exist_ok=True)

    for i in range(9):
        boards_x = task.find_boards_by_square_state(i, 'X')
        boards_o = task.find_boards_by_square_state(i, 'O')
        boards_empty = task.find_boards_by_square_state(i, 'empty')

        sample_size = min(25, len(boards_x), len(boards_o), len(boards_empty))
        if sample_size < 10: # t-SNE needs a reasonable number of points
            print(f"Skipping square {i+1} due to insufficient data for t-SNE.")
            continue

        # Layout: [X boards | O boards | empty boards], each sample_size long.
        selected_boards = boards_x[:sample_size] + boards_o[:sample_size] + boards_empty[:sample_size]
        prompts = [task.get_prompt(b) for b in selected_boards]
        reconstructed_activations = get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, 16)

        # t-SNE rejects perplexity >= n_samples. With the minimum of 10 boards per
        # state there are only 30 points, so cap perplexity below the sample count.
        # Iterations stay at scikit-learn's default (1000); `n_iter` was removed in 1.7.
        perplexity_value = min(30, len(selected_boards) - 1)
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity_value)
        # .cpu() so this also works when activations live on a GPU.
        reduced_activations = tsne.fit_transform(reconstructed_activations.detach().cpu().numpy())

        plt.figure(figsize=(10, 8))
        plt.scatter(reduced_activations[:sample_size, 0], reduced_activations[:sample_size, 1], alpha=0.7, label=f'Square {i+1}: X')
        plt.scatter(reduced_activations[sample_size:2*sample_size, 0], reduced_activations[sample_size:2*sample_size, 1], alpha=0.7, label=f'Square {i+1}: O')
        plt.scatter(reduced_activations[2*sample_size:, 0], reduced_activations[2*sample_size:, 1], alpha=0.7, label=f'Square {i+1}: Empty')
        plt.title(f't-SNE of SAE Cluster (Layer {layer}) - Colored by State of Square {i+1}')
        plt.legend(); plt.grid(True)
        plt.savefig(os.path.join(output_dir, f'viz_square_{i+1}_content_tsne.png')); plt.close()

    print(f"Per-square content t-SNE visualizations saved to {output_dir}")

def visualize_illegal_legal_pattern_agglomeration(
    reduced_activations: np.ndarray,
    combined_records: list,
    cluster_labels: np.ndarray,
    output_dir: str,
    layer: int,
    prompt_style: str,
    min_cluster_size: int = 30,
    min_pattern_size: int = 15,
    purity_threshold: float = 0.5,
    max_patterns_per_cluster: int = 3,
    illegal_dominance_threshold: Optional[float] = 0.7,
    create_illegal_focus_plot: bool = True
):
    """
    Plot legal/illegal points and annotate clusters with their dominant line patterns.

    Point COLOR encodes legality only (legal=blue, illegal=red, anything else grey).
    For every cluster in ``cluster_labels``, boards are grouped by identical
    (line_index, pattern) pairs over the 8 lines (rows R0-R2, columns C0-C2,
    diagonals D1/D2). The dominant groups are written in a text box at the cluster
    centroid with their share of the cluster and their legal/illegal counts. A faint
    convex hull is drawn around every cluster that passes the size filter.

    Strategy:
      1. Skip clusters with fewer than ``min_cluster_size`` points.
      2. Count every (line_index, pattern_tuple) across the cluster's boards.
      3. Keep pattern groups with count >= ``min_pattern_size`` and
         count / cluster_size >= ``purity_threshold``; sort by share, then count.
      4. Take the top ``max_patterns_per_cluster`` of those. Of these, keep only
         groups whose illegal ratio is >= ``illegal_dominance_threshold`` (no filter
         when it is None). Note that the top-N cut happens BEFORE this filter.
         Clusters left with no pattern are not annotated.

    Parameters:
      reduced_activations: 2-D embedding; columns 0 and 1 are plotted as x/y.
      combined_records: one dict per row of ``reduced_activations``. Each dict has
        'board' (9 cell values; 0/1/2 are rendered as ./X/O, others as '?') and
        'label_type' ('legal' or 'illegal').
      cluster_labels: cluster id per row of ``reduced_activations``.
      output_dir: directory that receives the PNG(s) and the JSON summary.
      layer, prompt_style: used in plot titles and file names.
      min_cluster_size: skip clusters smaller than this (likely noise).
      min_pattern_size: minimum absolute size for a pattern group to consider.
      purity_threshold: minimum fraction (pattern_count / cluster_size) for display.
      max_patterns_per_cluster: number of top pattern groups considered per cluster.
      illegal_dominance_threshold: minimum illegal ratio for a pattern group to be
        annotated; ``None`` disables the filter.
      create_illegal_focus_plot: also write a plot that highlights only the
        annotated clusters.

    Returns:
      Dict keyed by cluster id (as found in ``cluster_labels``). Each value holds
      'cluster_size', 'patterns' (list of per-pattern stat dicts) and
      'illegal_dominance_threshold'. In the JSON file the keys are stringified.
    """
    print("Building legality-colored pattern agglomeration plot...")
    try:
        from scipy.spatial import ConvexHull as _ConvexHull
    except Exception:
        _ConvexHull = None
    from matplotlib.lines import Line2D

    LINES = [(0,1,2),(3,4,5),(6,7,8),(0,3,6),(1,4,7),(2,5,8),(0,4,8),(2,4,6)]
    LINE_NAMES = {0:"R0",1:"R1",2:"R2",3:"C0",4:"C1",5:"C2",6:"D1",7:"D2"}
    PIECE_MAP = {0:'.',1:'X',2:'O'}

    def _pattern_str(patt):
        """Render a line pattern as ./X/O symbols ('?' for unknown cell values)."""
        return ''.join(PIECE_MAP.get(x, '?') for x in patt)

    fig, ax = plt.subplots(figsize=(24, 20))
    legality_color_map = {'legal': '#1f77b4', 'illegal': '#d62728'}
    labels_list = [rec['label_type'] for rec in combined_records]
    point_colors = [legality_color_map.get(l, 'grey') for l in labels_list]
    ax.scatter(reduced_activations[:, 0], reduced_activations[:, 1], c=point_colors, s=18, alpha=0.7, edgecolors='none')

    texts = []
    annotation_summaries = {}

    for cid in sorted(set(cluster_labels)):
        indices = np.where(cluster_labels == cid)[0]
        cluster_size = len(indices)
        if cluster_size < min_cluster_size:
            continue

        cluster_points = reduced_activations[indices]
        # Faint hull for spatial context. Best-effort: Qhull raises on degenerate
        # (e.g. collinear) point sets, in which case the outline is simply omitted.
        if _ConvexHull is not None and cluster_size > 2:
            try:
                hull = _ConvexHull(cluster_points)
                for simplex in hull.simplices:
                    ax.plot(cluster_points[simplex, 0], cluster_points[simplex, 1], color='black', lw=1.0, alpha=0.25)
            except Exception:
                pass

        # Count (line_index, pattern) occurrences and their legality split.
        pattern_counter = Counter()
        pattern_legal_illegal = defaultdict(lambda: {'legal': 0, 'illegal': 0})
        for idx in indices:
            board = combined_records[idx]['board']
            label_type = combined_records[idx]['label_type']
            for li, line in enumerate(LINES):
                key = (li, tuple(board[i] for i in line))
                pattern_counter[key] += 1
                pattern_legal_illegal[key][label_type] += 1

        # (key, count, fraction_of_cluster, legal_count, illegal_count, illegal_ratio)
        pattern_stats = []
        for key, count in pattern_counter.items():
            fraction = count / cluster_size
            if count < min_pattern_size or fraction < purity_threshold:
                continue
            legal_c = pattern_legal_illegal[key]['legal']
            illegal_c = pattern_legal_illegal[key]['illegal']
            illegal_ratio = illegal_c / max(legal_c + illegal_c, 1)
            pattern_stats.append((key, count, fraction, legal_c, illegal_c, illegal_ratio))
        if not pattern_stats:
            continue
        # Highest share first, ties broken by absolute count.
        pattern_stats.sort(key=lambda x: (x[2], x[1]), reverse=True)

        lines_text = []
        selected_patterns = []
        for (li, patt), count, frac, legal_c, illegal_c, illegal_ratio in pattern_stats[:max_patterns_per_cluster]:
            if illegal_dominance_threshold is not None and illegal_ratio < illegal_dominance_threshold:
                continue
            line_name = LINE_NAMES.get(li, f'L{li}')
            lines_text.append(f"{line_name}:{_pattern_str(patt)} {frac*100:.0f}% L{legal_c}/I{illegal_c} ({illegal_ratio*100:.0f}% I)")
            selected_patterns.append((li, patt, count, frac, legal_c, illegal_c, illegal_ratio))
        if not selected_patterns:
            # No pattern passed the illegal-dominance filter: leave cluster unannotated.
            continue

        centroid = cluster_points.mean(axis=0)
        label_text = f"C{cid} (n={cluster_size})\n" + '\n'.join(lines_text)
        texts.append(ax.text(centroid[0], centroid[1], label_text, ha='center', va='center',
                             fontsize=8, bbox=dict(boxstyle='round,pad=0.45', fc='white', ec='black', alpha=0.85)))
        annotation_summaries[cid] = {
            'cluster_size': cluster_size,
            'patterns': [
                {
                    'line_index': li,
                    'line_name': LINE_NAMES.get(li, f'L{li}'),
                    'pattern': patt,
                    'pattern_str': _pattern_str(patt),
                    'count': count,
                    'fraction': frac,
                    'legal_count': legal_c,
                    'illegal_count': illegal_c,
                    'illegal_ratio': illegal_ratio
                } for (li, patt, count, frac, legal_c, illegal_c, illegal_ratio) in selected_patterns
            ],
            'illegal_dominance_threshold': illegal_dominance_threshold
        }

    if texts:
        try:
            adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='-', color='gray', lw=0.5, alpha=0.6))
        except Exception as e:
            # Label de-overlapping is cosmetic; keep the plot but report the failure.
            print(f"adjust_text failed ({e}); labels left at cluster centroids.")

    legend_handles = [
        Line2D([0],[0], marker='o', color='w', label=f'Legal (n={labels_list.count("legal")})', markerfacecolor=legality_color_map['legal'], markersize=10),
        Line2D([0],[0], marker='o', color='w', label=f'Illegal (n={labels_list.count("illegal")})', markerfacecolor=legality_color_map['illegal'], markersize=10)
    ]
    ax.legend(handles=legend_handles, title='Board Type', loc='upper left', bbox_to_anchor=(1.02,1))
    ax.set_title(f'Legal vs Illegal with Pattern Agglomeration (Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1'); ax.set_ylabel('t-SNE Dimension 2'); ax.grid(True, alpha=0.25)
    fig.tight_layout(rect=[0,0,0.82,1])
    out_path = os.path.join(output_dir, f'viz_illegal_legal_pattern_agglomeration_{prompt_style}.png')
    fig.savefig(out_path); plt.close(fig)
    print(f"Saved {out_path}")

    # JSON summary. json only accepts str/int/float/bool/None keys, and numpy
    # integer scalars (e.g. KMeans labels_) are not ``int`` subclasses, so the
    # cluster ids are stringified here; otherwise json.dump raises TypeError.
    json_path = os.path.join(output_dir, f'illegal_legal_pattern_agglomeration_{prompt_style}.json')
    json_payload = {
        str(cid): {
            'cluster_size': int(meta['cluster_size']),
            'patterns': [
                {
                    'line_index': p['line_index'],
                    'line_name': p['line_name'],
                    'pattern': list(p['pattern']),
                    'pattern_str': p['pattern_str'],
                    'count': p['count'],
                    'fraction': p['fraction'],
                    'legal_count': p['legal_count'],
                    'illegal_count': p['illegal_count'],
                    'illegal_ratio': p['illegal_ratio']
                } for p in meta['patterns']
            ]
        } for cid, meta in annotation_summaries.items()
    }
    try:
        with open(json_path, 'w') as f:
            json.dump(json_payload, f, indent=2)
    except (OSError, TypeError, ValueError) as e:
        print(f"Failed to save pattern agglomeration JSON: {e}")
    else:
        print(f"Saved {json_path}")

    # Optional: focused plot of only the annotated (illegal-dominant) clusters.
    if create_illegal_focus_plot and annotation_summaries:
        fig2, ax2 = plt.subplots(figsize=(24, 20))
        # Desaturate all points first so the annotated clusters stand out.
        ax2.scatter(reduced_activations[:,0], reduced_activations[:,1], c='lightgrey', s=10, alpha=0.15, edgecolors='none')
        focus_texts = []
        for cid, meta in annotation_summaries.items():
            pts = reduced_activations[cluster_labels == cid]
            if len(pts) == 0:
                continue
            # Color intensity grows with the mean illegal ratio of the kept patterns.
            mean_illegal_ratio = np.mean([p['illegal_ratio'] for p in meta['patterns']]) if meta['patterns'] else 0.0
            color = plt.cm.Reds(min(0.99, 0.3 + 0.7 * mean_illegal_ratio))
            ax2.scatter(pts[:,0], pts[:,1], c=[color], s=20, alpha=0.85, edgecolors='none')
            centroid = pts.mean(axis=0)
            lines_text = [
                f"{p['line_name']}:{p['pattern_str']} {p['fraction']*100:.0f}% I{p['illegal_count']}/L{p['legal_count']} ({p['illegal_ratio']*100:.0f}% I)"
                for p in meta['patterns']
            ]
            label_text = f"C{cid} (n={meta['cluster_size']})\n" + '\n'.join(lines_text)
            focus_texts.append(ax2.text(centroid[0], centroid[1], label_text, ha='center', va='center', fontsize=8,
                                        bbox=dict(boxstyle='round,pad=0.45', fc='white', ec='black', alpha=0.9)))
        if focus_texts:
            try:
                adjust_text(focus_texts, ax=ax2, arrowprops=dict(arrowstyle='-', color='gray', lw=0.5, alpha=0.6))
            except Exception as e:
                print(f"adjust_text failed ({e}); labels left at cluster centroids.")
        # The threshold may legitimately be None (filter disabled); don't format None*100.
        if illegal_dominance_threshold is None:
            focus_heading = 'Annotated Pattern Clusters (no illegal-dominance filter)'
        else:
            focus_heading = f'Illegal-Dominant Pattern Clusters (>= {illegal_dominance_threshold*100:.0f}% illegal)'
        ax2.set_title(f'{focus_heading}\nLayer {layer}, Style: {prompt_style}')
        ax2.set_xlabel('t-SNE Dimension 1'); ax2.set_ylabel('t-SNE Dimension 2'); ax2.grid(True, alpha=0.25)
        focus_path = os.path.join(output_dir, f'viz_illegal_dominant_pattern_clusters_{prompt_style}.png')
        fig2.savefig(focus_path); plt.close(fig2)
        print(f"Saved {focus_path}")

    return annotation_summaries

def run_hybrid_illegal_legal_hierarchical_agglomerative_analysis(
    reduced_activations: np.ndarray,
    combined_records: list,
    output_dir: str,
    layer: int,
    prompt_style: str,
    n_l0_clusters: int = 18,
    micro_min_size: int = 8,
    micro_cluster_divisor: int = 12,
    pattern_purity_threshold: float = 0.65,
    dominant_coverage_threshold: float = 0.90,
    illegal_ratio_highlight: float = 0.70
):
    """Hybrid hierarchical+agglomerative analysis tailored for illegal vs legal data.

    Steps:
      1. Level-0 KMeans on all points. ``n_l0_clusters`` is capped at the number of
         points, because KMeans rejects n_clusters > n_samples.
      2. For each L0 cluster with at least ``micro_min_size`` points, run a micro
         KMeans with k = clamp(size // micro_cluster_divisor, 3, 25). If that k is
         >= size, k falls back to max(2, size // 2).
      3. For each micro-cluster with at least ``micro_min_size`` points, find its most
         common (line_index, pattern_tuple). Its purity is the share of the
         micro-cluster's boards that contain it, together with the legal/illegal split.
      4. Keep micro-clusters whose purity >= ``pattern_purity_threshold``.
      5. Merge kept micro-clusters of the same L0 that share the same dominant
         (line_index, pattern_tuple). The merged group's stats are reported for that
         shared pattern.
      6. Visualize:
         - outline the L0 hulls and draw faint L0 points;
         - draw each merged group colored by its illegal ratio (deeper red means
           more illegal);
         - append ' *' to a group's label when its ratio is >= ``illegal_ratio_highlight``.
      7. Annotate L0 clusters and merged groups (pattern, purity %, illegal/legal counts).

    Parameters:
      reduced_activations: 2-D embedding; columns 0 and 1 are plotted.
      combined_records: one dict per row. Each dict has 'board' (9 cells;
        0/1/2 -> ./X/O) and 'label_type' ('legal' or 'illegal').
      output_dir: directory receiving the PNG and the JSON summary.
      layer, prompt_style: used in the title and file names.
      dominant_coverage_threshold: currently unused; kept so existing callers that
        pass it keep working.

    Returns:
      Dict keyed by str(L0 cluster id). Each value is a list of merged-group summaries;
      the same data is saved as JSON. Returns {} for empty input.
    """
    print("Running hybrid illegal/legal hierarchical-agglomerative analysis ...")
    if reduced_activations is None or len(reduced_activations) == 0:
        return {}

    from collections import defaultdict as _dd
    from sklearn.cluster import KMeans as _HybridKMeans
    try:
        from scipy.spatial import ConvexHull as _ConvHull
    except Exception:
        _ConvHull = None

    LINES = [(0,1,2),(3,4,5),(6,7,8),(0,3,6),(1,4,7),(2,5,8),(0,4,8),(2,4,6)]
    LINE_NAMES = {0:"R0",1:"R1",2:"R2",3:"C0",4:"C1",5:"C2",6:"D1",7:"D2"}
    PIECE_MAP = {0:'.',1:'X',2:'O'}

    def _pattern_str(patt):
        """Render a line pattern as ./X/O symbols ('?' for unknown cell values)."""
        return ''.join(PIECE_MAP.get(x, '?') for x in patt)

    # L0 clustering. Cap k so inputs smaller than n_l0_clusters don't make KMeans raise.
    n_l0 = max(1, min(n_l0_clusters, len(reduced_activations)))
    kmeans_l0 = _HybridKMeans(n_clusters=n_l0, random_state=42, n_init='auto').fit(reduced_activations)
    l0_labels = kmeans_l0.labels_

    def dominant_pattern(indices, key=None):
        """Return stats for the most common (line_index, pattern) in `indices`, or for `key` if given.

        Each board contributes exactly one pattern per line, so purity = count / len(indices) <= 1.
        Returns None for empty `indices`.
        """
        if len(indices) == 0:
            return None
        counter = Counter()
        legal_illegal = _dd(lambda: {'legal': 0, 'illegal': 0})
        for idx in indices:
            rec = combined_records[idx]
            board = rec['board']
            label_type = rec['label_type']
            for li, line in enumerate(LINES):
                patt_key = (li, tuple(board[i] for i in line))
                counter[patt_key] += 1
                legal_illegal[patt_key][label_type] += 1
        if key is None:
            if not counter:
                return None
            key, top_count = counter.most_common(1)[0]
        else:
            top_count = counter[key]
        li, patt = key
        legal_c = legal_illegal[key]['legal']
        illegal_c = legal_illegal[key]['illegal']
        total = legal_c + illegal_c
        return {
            'line_index': li,
            'pattern_tuple': patt,
            'purity': top_count / len(indices),
            'count': top_count,
            'size': len(indices),
            'legal_count': legal_c,
            'illegal_count': illegal_c,
            'illegal_ratio': (illegal_c / total) if total > 0 else 0.0
        }

    # Micro-cluster each L0 cluster and keep pattern-pure micro-clusters.
    l0_to_groups = _dd(list)
    for cid in range(n_l0):
        l0_indices = np.where(l0_labels == cid)[0]
        if len(l0_indices) < micro_min_size:
            continue
        activ_sub = reduced_activations[l0_indices]
        # Adaptive micro cluster count; must stay below the number of points.
        k_micro = max(3, min(25, len(l0_indices) // micro_cluster_divisor))
        if k_micro >= len(l0_indices):
            k_micro = max(2, len(l0_indices) // 2)
        try:
            micro_kmeans = _HybridKMeans(n_clusters=k_micro, random_state=cid + 7, n_init='auto').fit(activ_sub)
        except Exception as e:
            print(f"Micro KMeans failed for L0 cluster {cid} (k={k_micro}): {e}; skipping.")
            continue
        micro_labels = micro_kmeans.labels_
        for mc in range(k_micro):
            mc_indices_local = np.where(micro_labels == mc)[0]
            if len(mc_indices_local) < micro_min_size:
                continue
            global_indices = l0_indices[mc_indices_local]
            dom = dominant_pattern(global_indices)
            if dom and dom['purity'] >= pattern_purity_threshold:
                l0_to_groups[cid].append({'indices': global_indices, 'dominant': dom})

    # Merge groups with identical dominant pattern inside each L0.
    merged_results = {}
    for cid, groups in l0_to_groups.items():
        key_to_indices = _dd(list)
        for g in groups:
            d = g['dominant']
            key_to_indices[(d['line_index'], d['pattern_tuple'])].extend(g['indices'].tolist())
        merged_groups = []
        for key, idx_list in key_to_indices.items():
            merged_idx = sorted(set(idx_list))  # deterministic order
            # Report stats for the shared merge key. The union's most common pattern
            # can differ from it, which would mislabel the group (and could duplicate
            # another merged group's label within the same L0).
            dom = dominant_pattern(merged_idx, key=key)
            if dom:
                merged_groups.append({'indices': np.array(merged_idx, dtype=int), 'dominant': dom})
        merged_results[cid] = merged_groups

    # Visualization: L0 hulls/background + merged pattern groups.
    fig, ax = plt.subplots(figsize=(30, 24))
    cmap_l0 = plt.get_cmap('tab20', n_l0)
    sub_cmap = plt.get_cmap('Reds')
    texts = []
    for cid in range(n_l0):
        l0_idx = np.where(l0_labels == cid)[0]
        l0_points = reduced_activations[l0_idx]
        if len(l0_points) < 3:
            continue
        l0_color = cmap_l0(cid)
        if _ConvHull is not None:
            try:
                hull = _ConvHull(l0_points)
                for simplex in hull.simplices:
                    ax.plot(l0_points[simplex, 0], l0_points[simplex, 1], color=l0_color, lw=2.5, alpha=0.5)
            except Exception:
                pass  # degenerate (e.g. collinear) point sets: omit the hull outline
        ax.scatter(l0_points[:, 0], l0_points[:, 1], c=[l0_color], s=8, alpha=0.08)
        # Annotate the L0 cluster with its overall dominant line pattern.
        l0_dom = dominant_pattern(l0_idx)
        if l0_dom:
            centroid_l0 = l0_points.mean(axis=0)
            txt = (f"L0 {cid}\n{LINE_NAMES.get(l0_dom['line_index'])}:{_pattern_str(l0_dom['pattern_tuple'])} "
                   f"{l0_dom['purity']*100:.0f}%\nI{l0_dom['illegal_count']}/L{l0_dom['legal_count']}")
            texts.append(ax.text(centroid_l0[0], centroid_l0[1], txt, ha='center', va='center', fontsize=9,
                                 bbox=dict(boxstyle='round,pad=0.4', fc=l0_color, ec='black', alpha=0.6), color='white'))
        for g in merged_results.get(cid, []):
            dom = g['dominant']
            pts = reduced_activations[g['indices']]
            if len(pts) == 0:
                continue
            ir = dom['illegal_ratio']
            color = sub_cmap(min(0.99, 0.3 + 0.7 * ir))  # deeper red if more illegal
            ax.scatter(pts[:, 0], pts[:, 1], c=[color], s=20, alpha=0.85, edgecolors='none')
            centroid = pts.mean(axis=0)
            label = (f"{LINE_NAMES.get(dom['line_index'])}:{_pattern_str(dom['pattern_tuple'])}\n"
                     f"Pur {dom['purity']*100:.0f}% I{dom['illegal_count']}/L{dom['legal_count']} ({ir*100:.0f}% I)"
                     + (" *" if ir >= illegal_ratio_highlight else ""))
            texts.append(ax.text(centroid[0], centroid[1], label, ha='center', va='center', fontsize=7,
                                 bbox=dict(boxstyle='round,pad=0.3', fc='white', ec='black', alpha=0.9)))
    if texts:
        try:
            adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='-', color='gray', lw=0.5, alpha=0.5))
        except Exception as e:
            # Label de-overlapping is cosmetic; keep the plot but report the failure.
            print(f"adjust_text failed ({e}); labels left at centroids.")
    ax.set_title(f'Hybrid Illegal vs Legal Pattern Hierarchy (Layer {layer}, Style: {prompt_style})')
    ax.set_xlabel('t-SNE Dimension 1'); ax.set_ylabel('t-SNE Dimension 2'); ax.grid(True, alpha=0.25)
    out_fig = os.path.join(output_dir, f'viz_illegal_legal_hybrid_hierarchy_{prompt_style}.png')
    fig.savefig(out_fig); plt.close(fig)
    print(f"Saved {out_fig}")

    # JSON summary (keys stringified; values cast to plain Python numbers).
    json_summary = {}
    for cid, groups in merged_results.items():
        entries = []
        for g in groups:
            d = g['dominant']
            entries.append({
                'indices_count': int(len(g['indices'])),
                'line_index': int(d['line_index']),
                'line_name': LINE_NAMES.get(d['line_index']),
                'pattern': list(d['pattern_tuple']),
                'pattern_str': _pattern_str(d['pattern_tuple']),
                'purity': float(d['purity']),
                'size': int(d['size']),
                'legal_count': int(d['legal_count']),
                'illegal_count': int(d['illegal_count']),
                'illegal_ratio': float(d['illegal_ratio'])
            })
        json_summary[str(cid)] = entries
    out_json = os.path.join(output_dir, f'illegal_legal_hybrid_hierarchy_{prompt_style}.json')
    try:
        with open(out_json, 'w') as f:
            json.dump(json_summary, f, indent=2)
        print(f"Saved {out_json}")
    except (OSError, TypeError, ValueError) as e:
        print(f"Failed to save hybrid hierarchy JSON: {e}")
    return json_summary

def plot_illegal_vs_legal(
    model,
    task,
    sae,
    clusters,
    layer: int,
    display_name: str,
    illegal_dataset_path: str,
    legal_dataset_path: str = DATASET_FULL_PATH,
    prompt_style: str = 'text_instruction',
    max_legal: int = 6000,
    max_illegal: int = 5000,
    random_seed: int = 42,
    cache: bool = True
):
    """
    NEW: Generates two t-SNE plots contrasting legal vs illegal boards.

    Plot 1: Legal (blue) vs Illegal (red).
    Plot 2: Legal (light grey) vs Illegal colored by reason-combination (each unique sorted tuple(reasons)).

    Assumptions about illegal dataset structure (adjust if generator differs):
      [
        { "board": [int x9], "reasons": ["DOUBLE_WIN", "COUNT_DIFF_GT1"], ... },
        ...
      ]
    Legal dataset is the main TicTacToe dataset used elsewhere (DATASET_FULL_PATH).

    We recompute a fresh t-SNE including BOTH legal+illegal boards to avoid projection drift; we use
    the largest SAE cluster mask (same as other visualization utilities) for reconstruction filtering.
    Caching: saves / loads from tsne_cache/illegal_merge_{sanitized}_{layer}_{prompt_style}_{legN}_{illN}.npz
    so repeated calls are fast if underlying files unchanged.
    """
    try:
        import hashlib
    except ImportError:
        hashlib = None

    seed_everything(random_seed)
    if not clusters:
        print("No clusters available; aborting illegal vs legal plot.")
        return

    largest_cluster_id = max(clusters, key=lambda k: len(clusters[k]))
    cluster_indices = clusters[largest_cluster_id]

    # --- Load legal dataset ---
    with open(legal_dataset_path, 'r') as f:
        legal_data_full = json.load(f)
    # Deduplicate by board
    legal_unique_map = {tuple(item['board']): item for item in legal_data_full}
    legal_items = list(legal_unique_map.values())
    random.shuffle(legal_items)
    if max_legal and len(legal_items) > max_legal:
        legal_items = legal_items[:max_legal]

    # --- Load illegal dataset ---
    if not os.path.exists(illegal_dataset_path):
        print(f"Illegal dataset not found at {illegal_dataset_path}")
        return
    with open(illegal_dataset_path, 'r') as f:
        illegal_data_raw = json.load(f)
    # Some generators may wrap data; accept either list or dict with 'data' key.
    if isinstance(illegal_data_raw, dict) and 'data' in illegal_data_raw:
        illegal_data_raw = illegal_data_raw['data']
    # Filter malformed
    illegal_clean = []
    for rec in illegal_data_raw:
        if not isinstance(rec, dict):
            continue
        board = rec.get('board')
        reasons = rec.get('reasons') or rec.get('reason') or rec.get('illegal_reasons')
        if board and isinstance(board, list) and len(board) == 9:
            if not isinstance(reasons, list):
                reasons = [reasons] if reasons else ["UNSPECIFIED"]
            # Normalize reason strings
            reasons_norm = [str(r).upper() for r in reasons]
            rec['reasons'] = reasons_norm
            illegal_clean.append(rec)
    # Deduplicate illegal boards and drop those that are actually legal dataset duplicates
    illegal_unique_map = {}
    for rec in illegal_clean:
        bt = tuple(rec['board'])
        if bt in legal_unique_map:
            continue  # skip boards that are legal (paranoia)
        if bt not in illegal_unique_map:
            illegal_unique_map[bt] = rec
    illegal_items = list(illegal_unique_map.values())
    random.shuffle(illegal_items)
    if max_illegal and len(illegal_items) > max_illegal:
        illegal_items = illegal_items[:max_illegal]

    if not illegal_items:
        print("No illegal items to plot after filtering.")
        return

    print(f"Preparing visualization with {len(legal_items)} legal and {len(illegal_items)} illegal boards.")

    # --- Prompt construction helpers ---
    def board_to_ascii_grid(board_list):
        # Apply invariance tokens only to board glyphs (not elsewhere) if active
        if INVARIANCE_ACTIVE and INVARIANCE_TOKEN_PAIR:
            x_tok, o_tok = INVARIANCE_TOKEN_PAIR
            symbol = {0:'.', 1:x_tok, 2:o_tok}
        else:
            symbol = {0:'.', 1:'X', 2:'O'}
        rows = []
        for r in range(3):
            rows.append(' '.join(symbol[board_list[3*r + c]] for c in range(3)))
        return '\n'.join(rows)

    def board_to_row_sentence(board_list):
        if INVARIANCE_ACTIVE and INVARIANCE_TOKEN_PAIR:
            x_tok, o_tok = INVARIANCE_TOKEN_PAIR
            symbol = {0:'empty', 1:x_tok, 2:o_tok}
        else:
            symbol = {0:'empty', 1:'X', 2:'O'}
        return f"Row 0: {symbol[board_list[0]]}, {symbol[board_list[1]]}, {symbol[board_list[2]]}. " \
               f"Row 1: {symbol[board_list[3]]}, {symbol[board_list[4]]}, {symbol[board_list[5]]}. " \
               f"Row 2: {symbol[board_list[6]]}, {symbol[board_list[7]]}, {symbol[board_list[8]]}."

    def build_prompt(board_rec, style):
        b = board_rec['board']
        if style == 'ascii_board':
            return board_to_ascii_grid(b)
        # text_instruction style (minimal) – keep close to task prompts
        row_text = board_to_row_sentence(b)
        p1 = b.count(1); p2 = b.count(2)
        next_player = 'Player 1' if p1 == p2 else 'Player 2'
        return f"Board state: {row_text} It is {next_player}'s turn."

    # Build combined list preserving labels
    combined = []
    for item in legal_items:
        combined.append({
            'board': item['board'],
            'label_type': 'legal',
            'prompt': build_prompt(item, prompt_style)
        })
    for item in illegal_items:
        combined.append({
            'board': item['board'],
            'label_type': 'illegal',
            'reasons': item.get('reasons', ['UNSPECIFIED']),
            'prompt': build_prompt(item, prompt_style)
        })

    # Cache key (optional) – incorporate counts + file mtimes so changes invalidate
    cache_dir = 'tsne_cache'
    os.makedirs(cache_dir, exist_ok=True)
    sanitized_name = sanitize_model_name(display_name)
    mtime_legal = os.path.getmtime(legal_dataset_path)
    mtime_illegal = os.path.getmtime(illegal_dataset_path)
    cache_sig_raw = f"{sanitized_name}|{layer}|{prompt_style}|{len(legal_items)}|{len(illegal_items)}|{mtime_legal:.0f}|{mtime_illegal:.0f}".encode()
    cache_hash = hashlib.md5(cache_sig_raw).hexdigest()[:10] if hashlib else 'nohash'
    cache_npz = os.path.join(cache_dir, f"illegal_merge_{cache_hash}.npz")
    cache_meta = os.path.join(cache_dir, f"illegal_merge_{cache_hash}.json")

    if cache and os.path.exists(cache_npz) and os.path.exists(cache_meta):
        try:
            print(f"Loading cached illegal vs legal t-SNE from {cache_npz}")
            with np.load(cache_npz) as data:
                reduced = data['reduced']
            with open(cache_meta, 'r') as f:
                cached_meta = json.load(f)
            if cached_meta.get('count_total') == len(combined):
                reduced_activations = reduced
            else:
                raise ValueError('Cache count mismatch; recomputing')
        except Exception as e:
            print(f"Cache load failed ({e}); recomputing.")
            reduced_activations = None
    else:
        reduced_activations = None

    if reduced_activations is None:
        prompts = [rec['prompt'] for rec in combined]
        reconstructed_acts = get_reconstructed_activations_batched(model, sae, prompts, cluster_indices, VISUALIZATION_BATCH_SIZE)
        # t-SNE
        perplexity = min(50, max(5, len(combined)//50))  # heuristic; ensures enough neighbors but not > n_samples
        perplexity = max(5, min(perplexity, len(combined)-5))
        print(f"Running t-SNE for illegal merge (n={len(combined)}, perplexity={perplexity})...")
        tsne = TSNE(n_components=2, perplexity=perplexity, learning_rate=200, n_iter=1000, random_state=42)
        reduced_activations = tsne.fit_transform(reconstructed_acts.detach().numpy())
        if cache:
            try:
                np.savez_compressed(cache_npz, reduced=reduced_activations)
                with open(cache_meta, 'w') as f:
                    json.dump({
                        'count_total': len(combined),
                        'n_legal': len(legal_items),
                        'n_illegal': len(illegal_items),
                        'prompt_style': prompt_style,
                        'layer': layer,
                        'model': sanitized_name,
                        'illegal_dataset': os.path.basename(illegal_dataset_path)
                    }, f, indent=2)
            except Exception as e:
                print(f"Failed to write cache: {e}")

    # Split indices
    labels = [rec['label_type'] for rec in combined]
    is_illegal = [1 if l == 'illegal' else 0 for l in labels]

    output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"
    os.makedirs(output_dir, exist_ok=True)

    # --- Plot 1: Legal vs Illegal ---
    plt.figure(figsize=(20,16))
    legal_mask = [l == 'legal' for l in labels]
    illegal_mask = [l == 'illegal' for l in labels]
    arr = reduced_activations
    plt.scatter(arr[np.where(legal_mask)[0],0], arr[np.where(legal_mask)[0],1], c='#1f77b4', s=18, alpha=0.35, label=f'Legal (n={sum(legal_mask)})')
    plt.scatter(arr[np.where(illegal_mask)[0],0], arr[np.where(illegal_mask)[0],1], c='#d62728', s=30, alpha=0.75, label=f'Illegal (n={sum(illegal_mask)})')
    plt.legend(title='State Type')
    plt.title(f't-SNE: Legal vs Illegal Boards (Layer {layer}, Style: {prompt_style})')
    plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True, alpha=0.25)
    out1 = os.path.join(output_dir, f'viz_illegal_vs_legal_{prompt_style}.png')
    plt.savefig(out1); plt.close()
    print(f"Saved {out1}")

    # --- Plot 2: Illegal reason combinations ---
    # Build color map for reason combos
    reason_combos = []
    for rec in combined:
        if rec['label_type'] == 'illegal':
            combo = '+'.join(sorted(set(rec.get('reasons', ['UNSPECIFIED']))))
            reason_combos.append(combo)
        else:
            reason_combos.append(None)

    # Count illegal combos
    combo_counts = Counter([c for c in reason_combos if c])
    # Limit legend explosion: top K combos, rest grouped
    TOP_K = 18
    top_combos = {c for c,_ in combo_counts.most_common(TOP_K)}
    combo_labels_display = []
    for c in reason_combos:
        if c is None:
            combo_labels_display.append('LEGAL')
        elif c in top_combos:
            combo_labels_display.append(c)
        else:
            combo_labels_display.append('OTHER')

    unique_labels = [lab for lab in combo_labels_display if lab != 'LEGAL']
    unique_labels = list(dict.fromkeys(unique_labels))  # preserve order
    palette = sns.color_palette('tab20', n_colors=max(3, len(unique_labels)))
    color_map = {lab: palette[i % len(palette)] for i, lab in enumerate(unique_labels)}
    color_map['LEGAL'] = (0.8, 0.8, 0.8)
    if 'OTHER' in color_map:
        # Make OTHER a distinct dark grey
        color_map['OTHER'] = (0.3,0.3,0.3)

    point_colors = [color_map[lab] for lab in combo_labels_display]

    plt.figure(figsize=(24,18))
    plt.scatter(arr[:,0], arr[:,1], c=point_colors, s=22, alpha=0.75, edgecolors='none')
    # Build legend handles (skip LEGAL if too many?)
    handles = []
    # Ensure LEGAL first
    handles.append(plt.Line2D([0],[0], marker='o', color='w', label=f'LEGAL (n={sum(1 for l in combo_labels_display if l=="LEGAL")})', markersize=10, markerfacecolor=color_map['LEGAL']))
    for lab in unique_labels:
        if lab == 'LEGAL':
            continue
        count = combo_labels_display.count(lab)
        handles.append(plt.Line2D([0],[0], marker='o', color='w', label=f'{lab} (n={count})', markersize=9, markerfacecolor=color_map[lab]))
    plt.legend(handles=handles, title='Illegal Reason Combination', loc='upper left', bbox_to_anchor=(1.02,1))
    plt.title(f't-SNE: Illegal Reason Combinations (Layer {layer}, Style: {prompt_style})')
    plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True, alpha=0.25)
    plt.tight_layout(rect=[0,0,0.82,1])
    out2 = os.path.join(output_dir, f'viz_illegal_reason_combos_{prompt_style}.png')
    plt.savefig(out2); plt.close()
    print(f"Saved {out2}")

    # --- Plot 3: Cluster-based dominant line purity visualization (NEW) ---
    try:
        from sklearn.cluster import KMeans as _KMeansForIllegal
        print("Computing cluster-based line purity for legal+illegal boards...")
        n_points = reduced_activations.shape[0]
        # Heuristic for number of clusters: sqrt(n/40) capped between 8 and 35
        k_est = int(max(8, min(35, np.sqrt(n_points / 40))))
        kmeans_illegal_merge = _KMeansForIllegal(n_clusters=k_est, random_state=42, n_init='auto').fit(reduced_activations)
        cluster_labels = kmeans_illegal_merge.labels_

        # New clearer visualization: legality color + pattern agglomeration annotations
        visualize_illegal_legal_pattern_agglomeration(
            reduced_activations=reduced_activations,
            combined_records=combined,
            cluster_labels=cluster_labels,
            output_dir=output_dir,
            layer=layer,
            prompt_style=prompt_style,
            min_cluster_size=30,
            min_pattern_size=15,
            purity_threshold=0.50,
            max_patterns_per_cluster=3,
            illegal_dominance_threshold=0.70,
            create_illegal_focus_plot=True
        )

        LINES = [(0,1,2),(3,4,5),(6,7,8),(0,3,6),(1,4,7),(2,5,8),(0,4,8),(2,4,6)]
        LINE_NAMES = {0:"R0",1:"R1",2:"R2",3:"C0",4:"C1",5:"C2",6:"D1",7:"D2"}
        PIECE_MAP = {0:'.',1:'X',2:'O'}

        cluster_meta = {}
        for cid in range(k_est):
            idxs = np.where(cluster_labels == cid)[0]
            if len(idxs) == 0:
                continue
            line_pattern_counter = Counter()
            legal_count = 0
            illegal_count = 0
            for idx in idxs:
                board_list = combined[idx]['board']
                if combined[idx]['label_type'] == 'legal':
                    legal_count += 1
                else:
                    illegal_count += 1
                for li, line in enumerate(LINES):
                    pattern_tuple = tuple(board_list[i] for i in line)
                    line_pattern_counter[(li, pattern_tuple)] += 1
            top_patterns = line_pattern_counter.most_common(10)
            dominant_line_purity = top_patterns[0][1] / len(idxs) if top_patterns else 0.0
            cluster_meta[cid] = {
                'size': len(idxs),
                'legal_count': legal_count,
                'illegal_count': illegal_count,
                'dominant_line_purity': dominant_line_purity,
                'top_line_patterns': [
                    {
                        'line_index': p[0][0],
                        'line_name': LINE_NAMES.get(p[0][0], f'L{p[0][0]}'),
                        'pattern': list(p[0][1]),
                        'pattern_str': ''.join(PIECE_MAP.get(x,'?') for x in p[0][1]),
                        'count': p[1],
                        'pct': p[1] / len(idxs)
                    } for p in top_patterns
                ]
            }

        # Prepare purity color mapping
        purity_values = np.array([cluster_meta[c]['dominant_line_purity'] for c in cluster_labels])
        plt.figure(figsize=(20,16))
        sc = plt.scatter(reduced_activations[:,0], reduced_activations[:,1], c=purity_values, cmap='viridis', s=22, alpha=0.8)
        cbar = plt.colorbar(sc)
        cbar.set_label('Dominant Line Purity')
        # Annotate highly pure clusters
        for cid, meta in cluster_meta.items():
            if meta['dominant_line_purity'] >= 0.85 and meta['top_line_patterns']:
                centroid = reduced_activations[cluster_labels==cid].mean(axis=0)
                top = meta['top_line_patterns'][0]
                label_text = f"C{cid} {top['line_name']}:{top['pattern_str']}\n{meta['dominant_line_purity']*100:.0f}%"
                plt.text(centroid[0], centroid[1], label_text, ha='center', va='center',
                         fontsize=8, bbox=dict(boxstyle='round,pad=0.3', fc='white', ec='black', alpha=0.85))
        plt.title(f't-SNE: Cluster Dominant Line Purity (k={k_est}) (Layer {layer}, Style: {prompt_style})')
        plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True, alpha=0.25)
        purity_path = os.path.join(output_dir, f'viz_illegal_merge_line_purity_{prompt_style}.png')
        plt.savefig(purity_path); plt.close()
        print(f"Saved {purity_path}")

        # --- Plot 4: Highlight clusters containing ONLY illegal boards ---
        illegal_only_clusters = [cid for cid, meta in cluster_meta.items() if meta['legal_count'] == 0 and meta['illegal_count'] > 0]
        if illegal_only_clusters:
            plt.figure(figsize=(20,16))
            # Fade all points
            plt.scatter(reduced_activations[:,0], reduced_activations[:,1], c='lightgrey', s=12, alpha=0.2, label='Other / Mixed')
            palette = sns.color_palette('tab10', n_colors=max(3, len(illegal_only_clusters)))
            legend_handles = []
            for i, cid in enumerate(illegal_only_clusters):
                mask = (cluster_labels == cid)
                color = palette[i % len(palette)]
                plt.scatter(reduced_activations[mask,0], reduced_activations[mask,1], c=[color], s=28, alpha=0.85, label=f'Cluster {cid}')
                meta = cluster_meta[cid]
                centroid = reduced_activations[mask].mean(axis=0)
                if meta['top_line_patterns']:
                    top = meta['top_line_patterns'][0]
                    label = f"C{cid} {top['line_name']}:{top['pattern_str']}\nPurity {meta['dominant_line_purity']*100:.0f}%\nSize {meta['size']}"
                    plt.text(centroid[0], centroid[1], label, ha='center', va='center', fontsize=8,
                             bbox=dict(boxstyle='round,pad=0.35', fc='white', ec=color, lw=1, alpha=0.9))
                    from matplotlib.lines import Line2D
                    legend_handles.append(Line2D([0],[0], marker='o', color='w', label=label.replace('\n', ' | '),
                                                 markerfacecolor=color, markersize=10))
            if legend_handles:
                plt.legend(handles=legend_handles, title='Illegal-Only Clusters', loc='upper left', bbox_to_anchor=(1.02,1))
            plt.title(f'Illegal-Only Clusters (Line Pattern Purity) (Layer {layer}, Style: {prompt_style})')
            plt.xlabel('t-SNE Dimension 1'); plt.ylabel('t-SNE Dimension 2'); plt.grid(True, alpha=0.25)
            plt.tight_layout(rect=[0,0,0.8,1])
            illegal_only_path = os.path.join(output_dir, f'viz_illegal_only_clusters_{prompt_style}.png')
            plt.savefig(illegal_only_path); plt.close()
            print(f"Saved {illegal_only_path}")
        else:
            illegal_only_path = None
            print("No illegal-only clusters identified.")

        # Save cluster meta analysis JSON
        analysis_json_path = os.path.join(output_dir, f'illegal_vs_legal_cluster_line_analysis_{prompt_style}.json')
        try:
            with open(analysis_json_path, 'w') as f:
                json.dump({cid: meta for cid, meta in cluster_meta.items()}, f, indent=2)
            print(f"Saved cluster line analysis JSON to {analysis_json_path}")
        except Exception as e:
            print(f"Failed to write analysis JSON: {e}")

        # --- Hybrid hierarchical + agglomerative analysis (NEW) ---
        try:
            run_hybrid_illegal_legal_hierarchical_agglomerative_analysis(
                reduced_activations=reduced_activations,
                combined_records=combined,
                output_dir=output_dir,
                layer=layer,
                prompt_style=prompt_style,
                n_l0_clusters=18,
                micro_min_size=10,
                micro_cluster_divisor=10,
                pattern_purity_threshold=0.60,
                dominant_coverage_threshold=0.90,
                illegal_ratio_highlight=0.70
            )
        except Exception as e:
            print(f"[WARN] Hybrid illegal/legal hierarchy failed: {e}")
    except Exception as e:
        print(f"[WARN] Failed to compute line purity / illegal-only cluster analysis: {e}")
        cluster_meta = {}
        illegal_only_clusters = []
        illegal_only_path = None

    # Optional: return data for further programmatic analysis (extended)
    return {
        'legal_count': len(legal_items),
        'illegal_count': len(illegal_items),
        'reason_combo_counts': combo_counts,
        'cache_used': os.path.exists(cache_npz) if cache else False,
        'embedding_shape': reduced_activations.shape,
        'cluster_meta': cluster_meta,
        'illegal_only_clusters': illegal_only_clusters
    }


def seed_everything(seed=42):
    """Seed every RNG this script relies on so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every visible CUDA device), and puts cuDNN into deterministic,
    non-autotuning mode.

    Args:
        seed: Integer seed applied to every generator (default 42).

    Note:
        Setting ``PYTHONHASHSEED`` here only influences subprocesses launched
        afterwards; the running interpreter fixed its hash seed at startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode selects kernels by timing them, which can differ run-to-run.
    torch.backends.cudnn.benchmark = False


# ==================== LEGEND EXTERNALIZATION (NEW) ====================
def apply_external_legends(bbox=(1.02, 1.0), loc='upper left'):
    """Monkey-patch matplotlib so legends default to sitting outside the axes.

    Wraps both ``pyplot.legend`` and ``Axes.legend``. For every legend call:
    ``bbox_to_anchor`` defaults to ``bbox``, ``loc`` defaults to ``loc`` and
    ``frameon`` defaults to True. Keyword arguments passed explicitly by the
    caller always take precedence. Existing plotting code that calls
    ``plt.legend(...)`` or ``ax.legend(...)`` without a ``bbox_to_anchor`` is
    therefore anchored past the right margin automatically.

    The patch is process-wide and is applied at most once; subsequent calls
    only print a notice.

    Args:
        bbox: Default ``bbox_to_anchor`` (the default pushes legends to the right, outside).
        loc:  Default legend anchor corner relative to ``bbox``.
    """
    import matplotlib.pyplot as _plt
    import matplotlib.axes
    from functools import wraps as _wraps

    # Skip if already applied
    if getattr(_plt, '_external_legend_patched', False):
        print('[LegendPatch] External legend patch already active.')
        return

    def _with_external_defaults(kwargs):
        """Fill in outside-placement defaults without overriding caller choices."""
        kwargs.setdefault('bbox_to_anchor', bbox)
        kwargs.setdefault('loc', loc)
        # Keep a frame so an outside legend stays visually separated.
        kwargs.setdefault('frameon', True)
        return kwargs

    original_pyplot_legend = _plt.legend
    original_axes_legend = matplotlib.axes.Axes.legend

    @_wraps(original_pyplot_legend)
    def patched_pyplot_legend(*args, **kwargs):
        return original_pyplot_legend(*args, **_with_external_defaults(kwargs))

    # wraps keeps Axes.legend's name/docstring intact for help() and introspection.
    @_wraps(original_axes_legend)
    def patched_axes_legend(self, *args, **kwargs):
        return original_axes_legend(self, *args, **_with_external_defaults(kwargs))

    _plt.legend = patched_pyplot_legend
    matplotlib.axes.Axes.legend = patched_axes_legend
    _plt._external_legend_patched = True
    print(f"[LegendPatch] Applied external legend patch (bbox={bbox}, loc='{loc}').")


# ==================== GAME-THEORETIC STRATEGIC CLASSIFICATION (NEW) ====================
# Winning lines on a row-major 3x3 board (cell indices 0..8).
_GT_WIN_LINES = ((0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6), (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6))


def _gt_check_terminal(board):
    """Return the winner (1 or 2), 0 for a full-board draw, or None if play continues."""
    for a, b, c in _GT_WIN_LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    if 0 not in board:
        return 0
    return None


@lru_cache(maxsize=None)
def _gt_evaluate(board_tuple, player):
    """Perfect-play value for ``player`` to move: 1 forced win, 0 draw, -1 forced loss.

    Memoized at module level so the table is shared across calls; the
    tic-tac-toe state space is small, so the unbounded cache stays small.
    """
    term = _gt_check_terminal(board_tuple)
    if term is not None:
        if term == 0:
            return 0
        return 1 if term == player else -1
    opponent = 2 if player == 1 else 1
    can_draw = False
    for mv, cell in enumerate(board_tuple):
        if cell != 0:
            continue
        child = board_tuple[:mv] + (player,) + board_tuple[mv + 1:]
        # The child's value is from the opponent's perspective: their loss is our win.
        outcome = _gt_evaluate(child, opponent)
        if outcome == -1:
            return 1
        if outcome == 0:
            can_draw = True
    return 0 if can_draw else -1


def _gt_opponent_has_immediate_win(board, player_to_move):
    """True if the opponent of ``player_to_move`` has two in a line with the third cell empty."""
    opp = 2 if player_to_move == 1 else 1
    for a, b, c in _GT_WIN_LINES:
        line = (board[a], board[b], board[c])
        if line.count(opp) == 2 and line.count(0) == 1:
            return True
    return False


def _gt_classify_board(board, treat_nonterminal_draws_as_draw):
    """Map one board to its game-theoretic strategic-situation category label."""
    term = _gt_check_terminal(board)
    if term is not None:
        return 'Draw' if term == 0 else 'Player Won'
    # X (1) moves first, so equal piece counts mean it is X's turn.
    player_to_move = 1 if board.count(1) == board.count(2) else 2
    gt_eval = _gt_evaluate(tuple(board), player_to_move)
    if gt_eval == 1:
        return 'Guaranteed Win'
    if _gt_opponent_has_immediate_win(board, player_to_move):
        return 'Must Block'
    if gt_eval == 0 and treat_nonterminal_draws_as_draw:
        return 'Draw'
    # Non-forced draws (when not treated as Draw) and losing positions without a threat.
    return 'To Play'


def visualize_strategic_situation_aggregated_game_theoretic(
    reduced_activations: np.ndarray,
    boards_list: list,
    output_dir: str,
    layer: int,
    prompt_style: str,
    treat_nonterminal_draws_as_draw: bool = True,
    apply_style: bool = False
):
    """Improved aggregated strategic situation plot using full game-theoretic minimax evaluation.

    This addresses incorrect counts (e.g., very low draw state counts) by classifying EVERY
    non-terminal board according to perfect play outcome rather than only terminal boards.

    Categories:
        Player Won     : Terminal position with a winner.
        Guaranteed Win : Non-terminal; side to move can force a win (game-theoretic win).
        Draw           : (a) Terminal draw OR (b) game-theoretic draw for side to move
                         (if treat_nonterminal_draws_as_draw is True) AND no forced immediate block.
        Must Block     : Side to move must block an immediate opponent win to preserve
                         a draw (game-theoretic draw) OR is in a losing position with an
                         immediate opponent threat (triage urgency).
        To Play        : Neutral exploratory / eventual loss with no immediate opponent threat.

    Logic precedence order:
        1. Terminal -> Player Won / Draw
        2. Eval == Win -> Guaranteed Win
        3. Immediate opponent threat -> Must Block (unless terminal handled earlier)
        4. Eval == Draw -> Draw (if treat_nonterminal_draws_as_draw else To Play)
        5. Eval == Loss -> To Play (no immediate threat) / Must Block (if threat)

    Args:
        reduced_activations: 2-D embedding; row i is plotted for ``boards_list[i]``
            (only the first two columns are used).
        boards_list: Dicts with a ``'board'`` entry — a length-9 sequence where
            0 is empty, 1 is player 1 and 2 is player 2 (as the solver assumes).
        output_dir: Directory for the PNG (created if missing).
        layer: Layer index, used in the title.
        prompt_style: Prompt style, used in the title and the filename.
        treat_nonterminal_draws_as_draw: Label non-terminal forced draws 'Draw'
            (True) or 'To Play' (False).
        apply_style: Call ``apply_massive_plot_style()`` before plotting.

    Returns:
        None if there is no data; otherwise a dict with per-category ``counts``,
        ``total``, ``output_path`` and the ``treat_nonterminal_draws_as_draw`` flag.

    Raises:
        ValueError: If ``reduced_activations`` and ``boards_list`` differ in length.
    """
    if reduced_activations is None or len(reduced_activations) == 0 or not boards_list:
        print('[GT-Strategic] No data provided; skipping game-theoretic plot.')
        return None
    if len(reduced_activations) != len(boards_list):
        raise ValueError(
            f'[GT-Strategic] reduced_activations has {len(reduced_activations)} rows but '
            f'boards_list has {len(boards_list)} entries; they must align one-to-one.'
        )

    if apply_style:
        try:
            apply_massive_plot_style()
        except Exception as e:
            # Styling is cosmetic: report it, but still produce the plot.
            print(f'[GT-Strategic] Could not apply plot style: {e}')

    categories = [
        _gt_classify_board(data['board'], treat_nonterminal_draws_as_draw)
        for data in boards_list
    ]
    counts = Counter(categories)

    # Palette consistent with earlier aggregated function
    color_map = {
        'Player Won': '#0057B8',
        'Guaranteed Win': '#238823',
        'Must Block': '#D62828',
        'To Play': '#F77F00',
        'Draw': '#FFD60A'
    }

    import matplotlib.pyplot as _plt
    from matplotlib.lines import Line2D as _Line2D

    coords = np.asarray(reduced_activations)[:, :2]
    _plt.figure(figsize=(24, 20))
    _plt.scatter(coords[:, 0], coords[:, 1], c=[color_map[c] for c in categories],
                 s=55, alpha=0.82, edgecolors='none')

    handles = [
        _Line2D([0], [0], marker='o', color='w', label=f"{label} (n={counts[label]})",
                markerfacecolor=col, markersize=15)
        for label, col in color_map.items()
    ]
    _plt.legend(handles=handles, title='Strategic Situation (Game-Theoretic)', frameon=True,
                loc='upper left', bbox_to_anchor=(1.02, 1))
    _plt.title(f'Game-Theoretic Strategic Situation (Layer {layer}, Style: {prompt_style})')
    _plt.xlabel('t-SNE Dimension 1'); _plt.ylabel('t-SNE Dimension 2'); _plt.grid(True, alpha=0.25)
    # Reserve the right margin for the outside legend so savefig does not clip it.
    _plt.tight_layout(rect=[0, 0, 0.82, 1])

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f'viz_hypothesis_strategic_situation_aggregated_game_theoretic_{prompt_style}.png')
    _plt.savefig(out_path); _plt.close()
    print(f"[GT-Strategic] Saved game-theoretic strategic plot to {out_path}")

    return {
        'counts': {k: counts[k] for k in color_map},
        'total': len(categories),
        'output_path': out_path,
        'treat_nonterminal_draws_as_draw': treat_nonterminal_draws_as_draw
    }


def main():
    """Command-line entry point: probe every configured model and layer.

    Parses CLI flags, configures invariance mode, then for each entry in
    ``MODELS_TO_PROBE``:
      * loads its inference results and the model,
      * prints one sanity-check generation,
      * for each layer in ``LAYERS_TO_PROBE``, runs feature discovery and the
        t-SNE / prediction-correctness / hierarchical / hybrid analyses for
        both prompt styles, plus the illegal-vs-legal visualization for the
        styles selected by ``--illegal-style``.
    The per-layer ``output_dir`` is built from ``get_viz_root()``.
    """
    global INVARIANCE_ACTIVE, INVARIANCE_TOKEN_PAIR

    # -------------------- Argument Parsing --------------------
    parser = argparse.ArgumentParser(description="Probe TicTacToe model representations and optionally visualize illegal boards.")
    parser.add_argument('--illegal-dataset', type=str, default=None, help='Path to JSON file containing illegal boards (defaults to ILLEGAL_DATASET_FULL_PATH).')
    parser.add_argument('--illegal-max', type=int, default=4000, help='Max illegal boards to sample for visualization.')
    parser.add_argument('--legal-max', type=int, default=4000, help='Max legal boards to sample for illegal vs legal visualization.')
    parser.add_argument('--illegal-style', type=str, default='both', choices=['ascii_board','text_instruction','both'], help='Prompt style(s) for illegal vs legal visualization.')
    parser.add_argument('--no-illegal-cache', action='store_true', help='Disable caching of illegal vs legal t-SNE embeddings.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility.')
    parser.add_argument('--invariance-tokens', type=str, choices=['PQ','AB'], default=None, help=f'If set, replace board X/O with provided token pair (PQ->P/Q, AB->A/B) only in board state. Outputs saved under {VIZ_ROOT_INVARIANCE}/.')
    args, unknown = parser.parse_known_args()

    if unknown:
        print(f"[WARN] Ignoring unknown args: {unknown}")

    seed_everything(args.seed)

    # Configure invariance mode
    invariance_map = {'PQ': ('P','Q'), 'AB': ('A','B')}
    INVARIANCE_TOKEN_PAIR = invariance_map.get(args.invariance_tokens)
    INVARIANCE_ACTIVE = INVARIANCE_TOKEN_PAIR is not None
    if INVARIANCE_ACTIVE:
        print(f"[Invariance] Activated with tokens: {INVARIANCE_TOKEN_PAIR}. Visualization root: {get_viz_root()}")
    else:
        print("[Invariance] Not active. Using canonical X/O board states.")

    # Honour --illegal-dataset (it was previously parsed but never used).
    illegal_dataset_path = args.illegal_dataset or ILLEGAL_DATASET_FULL_PATH

    task = TicTacToeTask(dataset_path=DATASET_FULL_PATH, invariance_token_pair=INVARIANCE_TOKEN_PAIR)
    for model_info in MODELS_TO_PROBE:
        display_name = model_info["display_name"]
        load_path = model_info["load_path"]
        architecture_name = model_info["architecture_name"]

        # Load the inference results ONCE per model
        inference_path = get_inference_path(model_info)
        inference_results = load_inference_results(inference_path)

        print(f"\n{'='*20} PROBING MODEL: {display_name} {'='*20}")
        model = load_model(load_path, architecture_name)

        # ==================== Model behavior sanity check ====================
        print("\n--- Verifying Model Behavior ---")
        # This prompt is specific to the Tic-Tac-Toe task format
        verification_prompt = ("""
You are a helpful assistant skilled at reasoning for tic tac toe. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>
...
</think>
<answer>
...
</answer>
You are a helpful assistant skilled at reasoning for tic tac toe. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>
...
</think>
<answer>
...
</answer>.
Board state:
Row 0: X, X, O. Row 1: empty, empty, O. Row 2: O, X, X.
It is Player 2's turn.
Recommend the best move which the player can play. Here is the definition of best move:
The determination of a "best move" follows a strict hierarchy of objectives:

1.  **Priority 1: Fastest Win.** If the current player can force a win, the 'best_move' will be the move that leads to the quickest possible victory (a win in the minimum number of subsequent turns).

2.  **Priority 2: Secure a Draw.** If a win is not possible, but the player can force a draw, the 'best_move' will contain all moves that guarantee at least a draw.

3.  **Priority 3: Slowest Loss.** If the player is in a losing position where every move leads to an eventual loss, the 'best_move' will be the move that prolong the game as long as possible before the loss occurs.

4.  **Terminal State: No Moves.** If the game has already concluded (a player has won, or the board is full), no further moves can be made. In this case, the 'best_move' will be None.Mapping:
Player 1 (X) Tokens:
1 -> (0,0), 2 -> (0,1), 3 -> (0,2), 4 -> (1,0), 5 -> (1,1), 6 -> (1,2), 7 -> (2,0), 8 -> (2,1), 9 -> (2,2)
Player 2 (O) Tokens:
10 -> (0,0), 11 -> (0,1), 12 -> (0,2), 13 -> (1,0), 14 -> (1,1), 15 -> (1,2), 16 -> (2,0), 17 -> (2,1), 18 -> (2,2), None -> No Move can be played
Thus your final answer should be one of the following if the next player to move is player 1: 1 or 2 or 3 or 4 or 5 or 6 or 7 or 8 or 9 or None
And your final answer should be one of the following if the next player to move is player 2: 10 or 11 or 12 or 13 or 14 or 15 or 16 or 17 or 18 or None
Please provide your reasoning in the following format:
<think> Your chain-of-thought reasoning here </think>
<answer> Your final move here </answer>
Remember to output exactly one of the best moves. You can only have one set of <think>...</think> and <answer>...</answer> in your response. The think section should be at the beginning of your response.
                               """
        )

        # Generate text using the loaded model
        generated_text = model.generate(verification_prompt, max_new_tokens=256, temperature=0.1)
        print(f"PROMPT: {verification_prompt}")
        print(f"GENERATED TEXT: {generated_text}")
        print("--- Verification Complete ---\n")
        # ===================================================================

        for layer in LAYERS_TO_PROBE:
            print(f"\n--- Starting analysis for Layer {layer} ---")

            sae, clusters = run_feature_discovery(
                model=model,
                display_name=display_name,
                architecture_name=architecture_name,
                layer=layer,
                device=DEVICE
            )

            output_dir = f"{get_viz_root()}/{sanitize_model_name(display_name)}/layer_{layer}"

            # --- Run visualizations and cluster analyses for each prompt style ---
            for style in ['text_instruction', 'ascii_board']:
                # Runs the main t-SNE and most hypothesis plots; the returned
                # embedding and sampled boards are reused by the analyses below.
                reduced_activations, sampled_boards = visualize_board_state_hypotheses(
                    model, task, sae, clusters, layer, display_name, prompt_style=style,
                    # hypotheses_to_run={'full_board', 'kmeans_dbscan'} # Just run the t-SNE and the basic plot
                )

                # This function call tests token invariance for each style.
                # visualize_invariance_with_random_chars(
                #     model, task, sae, clusters, layer, display_name, prompt_style=style
                # )

                # The analyses below reuse the t-SNE results rather than recomputing them.
                if reduced_activations is not None and sampled_boards is not None:
                    visualize_prediction_correctness(
                        reduced_activations,
                        sampled_boards,
                        inference_results,
                        output_dir,
                        layer,
                        style
                    )

                    # Winner status combined with best-move labels.
                    create_winner_plot_with_best_move_labels(
                        reduced_activations, sampled_boards, output_dir, layer, prompt_style=style
                    )

                    # Hierarchical analysis on the t-SNE results
                    run_hierarchical_analysis(
                        reduced_activations,
                        sampled_boards,
                        output_dir,
                        layer,
                        style
                    )

                    # Pure agglomerative analysis
                    # run_agglomerative_analysis(
                    #     reduced_activations,
                    #     sampled_boards,
                    #     output_dir,
                    #     layer,
                    #     style
                    # )

                    run_hybrid_hierarchical_agglomerative_analysis(
                        reduced_activations,
                        sampled_boards,
                        output_dir,
                        layer,
                        style,
                        return_components=True
                    )

                    print(f"Starting illegal vs legal visualization for style '{style}'")

                    # --------------- Illegal vs Legal Visualization ---------------
                    style_ok = (args.illegal_style == 'both') or (args.illegal_style == style)
                    if style_ok:
                        try:
                            print(f"\n[IllegalViz] Generating illegal vs legal plots for style '{style}' using {illegal_dataset_path}")
                            plot_illegal_vs_legal(
                                model=model,
                                task=task,
                                sae=sae,
                                clusters=clusters,
                                layer=layer,
                                display_name=display_name,
                                illegal_dataset_path=illegal_dataset_path,
                                legal_dataset_path=DATASET_FULL_PATH,
                                prompt_style=style,
                                max_legal=args.legal_max if args.legal_max > 0 else None,
                                max_illegal=args.illegal_max if args.illegal_max > 0 else None,
                                random_seed=args.seed,
                                cache=not args.no_illegal_cache
                            )
                        except Exception as e:
                            print(f"[IllegalViz] Failed to generate illegal vs legal visualization: {e}")
                    else:
                        print(f"[IllegalViz] Skipping style '{style}' (filter={args.illegal_style})")

        # Release this model before the next iteration loads another one, so
        # peak memory is not two models at once.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    # Apply large fonts and externalize legends globally BEFORE any plots are generated.
    # apply_external_legends is defined earlier in this module, so it is always available.
    apply_massive_plot_style()
    apply_external_legends()
    main()
    # Example (post-run) usage of interactive merge tool (uncomment and adapt):
    # After running a layer analysis and obtaining reduced_activations & sampled_boards, you can do:
    # l0_labels, l0_analysis, sub_map = run_hybrid_hierarchical_agglomerative_analysis(
    #     reduced_activations, sampled_boards, output_dir, layer, style, return_components=True
    # )
    # create_interactive_hybrid_merge_tool(
    #     reduced_activations=reduced_activations,
    #     l0_labels=l0_labels,
    #     l0_analysis_results=l0_analysis,
    #     l0_to_sub_clusters_map=sub_map,
    #     final_boards_list=sampled_boards,
    #     output_dir=output_dir,
    #     layer=layer,
    #     prompt_style=style
    # )

