"""apply_manual_merges.py

Standalone utility to apply manually curated sub-cluster merges (produced by
create_interactive_hybrid_merge_tool) and regenerate a merged visualization +
summary WITHOUT re-running the expensive full_grid_probe pipeline.

Inputs (required / optional):
  --embeddings <PATH>            : .npz file containing reduced_activations array (key: reduced_activations OR reduced)
  --subcluster-meta <PATH>       : Sub-cluster metadata JSON (files named *_subcluster_metadata_*.json)
  --merges-json <PATH>           : JSON downloaded from interactive tool (keys: merges: [{new_label, members, size_sum}])
  --output-dir <DIR>             : Directory to save outputs (created if missing)
  --layer <INT>                  : Layer index (for titles / provenance)
  --prompt-style <STR>           : Prompt style string (for filenames)
  --background-alpha <FLOAT>     : Alpha for unassigned / noise points (default 0.12)
  --point-size <INT>             : Marker size for individual points (default 6)
  --summary-only                 : If passed, skip plot and only produce merged summary JSON.
  --save-csv                     : Also export CSV listing each point -> merged_label.
  --random-seed <INT>            : Seed for the Python/NumPy RNGs (default 42); cluster colors come from a fixed palette.

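Embeddings File Expectations:
  Generated earlier (e.g., a t-SNE result). The loader looks for the keys
  'reduced_activations', 'reduced', then 'arr_0'; if the cached file used a different
  key name (e.g. via np.savez_compressed with custom keys), it falls back to the first array key.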

Subcluster Metadata Expectations:
  List of entries with keys: ['global_id','parent_l0','size','label','centroid','indices'].

Merges JSON Expectations:
  { "merges": [ { "new_label": "Some Concept", "members": [0,2,5], "size_sum": 123 }, ... ] }
  Unknown / leftover subcluster IDs will be auto-assigned their original label.

Outputs:
  merged_clusters_{prompt_style}_layer{layer}.png  (unless --summary-only)
  merged_clusters_{prompt_style}_layer{layer}.json  (summary of merged groups)
  merged_point_assignments_{prompt_style}_layer{layer}.json  (point -> merged label)
  merged_point_assignments_{prompt_style}_layer{layer}.csv    (optional if --save-csv)

Merging Logic:
  1. For every merge group, we form a union of all indices of its member subclusters.
  2. Each subcluster is consumed by at most one merge group: if a subcluster ID appears
     in more than one merge group's member list (shouldn't normally happen), the first
     merge group that lists it wins and later groups skip it.
  3. Remaining unmerged subclusters become their own groups with their original label.
  4. Points not belonging to ANY subcluster (value -1 / noise earlier) are plotted as light grey.

Example:
  python apply_manual_merges.py \
      --embeddings viz_cache/qwen_layer8_text_instruction_tsne.npz \
      --subcluster-meta visualizations/Qwen/layer_8/hybrid_merge_subcluster_metadata_text_instruction_layer8.json \
      --merges-json manual/merged_subclusters.json \
      --output-dir merged_outputs/qwen_l8_text \
      --layer 8 --prompt-style text_instruction

"""
from __future__ import annotations
import argparse
import os
import json
import random
from typing import Dict, List, Any
import numpy as np
import matplotlib.pyplot as plt

try:
    from adjustText import adjust_text  # Optional; used for label declutter
    _ADJUST_AVAILABLE = True
except Exception:
    _ADJUST_AVAILABLE = False


def load_embeddings(path: str) -> np.ndarray:
    """Load the reduced embedding array from a ``.npz`` (or plain ``.npy``) file.

    For ``.npz`` archives the keys ``reduced_activations``, ``reduced`` and
    ``arr_0`` are tried in that order; otherwise the first stored array is used.
    A ``.npy`` file is returned as-is.

    Args:
        path: Path to the cached embeddings file.

    Returns:
        The stored array (presumably shape (n_points, >=2) for plotting —
        confirm against the producer of the cache).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the ``.npz`` archive contains no arrays.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Embeddings file not found: {path}")
    loaded = np.load(path)
    # A plain .npy file yields the array directly; there are no keys to search.
    if isinstance(loaded, np.ndarray):
        return loaded
    # .npz archives come back as NpzFile objects that keep the file open; the
    # context manager closes it. Indexing reads the array fully into memory, so
    # the returned ndarray stays valid after the archive is closed.
    with loaded as data:
        for key in ('reduced_activations', 'reduced', 'arr_0'):
            if key in data:
                return data[key]
        keys = list(data.keys())
        if not keys:
            raise ValueError(f"Embeddings archive contains no arrays: {path}")
        return data[keys[0]]


def load_json(path: str):
    """Read and parse a UTF-8 encoded JSON file.

    Args:
        path: Path to the JSON file.

    Returns:
        The decoded JSON value (dict / list / scalar).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"JSON file not found: {path}")
    # JSON is specified as UTF-8; don't depend on the platform's locale encoding
    # (labels written by the interactive tool may contain non-ASCII text).
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def apply_merges(sub_meta: List[Dict[str, Any]], merges_spec: Dict[str, Any]):
    """Apply merge instructions to sub-cluster metadata.

    Returns:
        merged_groups: List[{ 'new_label', 'member_subclusters', 'size', 'indices' }]
        point_label_map: np.ndarray[int] (len = total points, cluster id enumerated)
        merged_label_strings: Dict[int, str] mapping enumerated cluster id -> label
    """
    # Collect all possible subcluster IDs
    id_to_entry = {e['global_id']: e for e in sub_meta}
    all_ids = set(id_to_entry.keys())

    merges = merges_spec.get('merges', [])
    used_ids = set()
    merged_groups = []

    # Build merged groups from user spec
    for m in merges:
        members = [mid for mid in m.get('members', []) if mid in id_to_entry]
        if not members:
            continue
        indices_union = []
        for mid in members:
            if mid in used_ids:
                # Already consumed by a previous merge; skip to enforce first-wins policy
                continue
            used_ids.add(mid)
            indices_union.extend(id_to_entry[mid]['indices'])
        if not indices_union:
            continue
        merged_groups.append({
            'new_label': m.get('new_label', f"Merged_{len(merged_groups)}"),
            'member_subclusters': members,
            'size': len(indices_union),
            'indices': sorted(indices_union)
        })

    # Remaining individual subclusters (not merged)
    remaining_ids = sorted(all_ids - used_ids)
    for rid in remaining_ids:
        entry = id_to_entry[rid]
        merged_groups.append({
            'new_label': entry['label'],
            'member_subclusters': [rid],
            'size': entry['size'],
            'indices': sorted(entry['indices'])
        })

    # Enumerate final merged cluster ids
    merged_groups = sorted(merged_groups, key=lambda g: g['new_label'])  # deterministic ordering

    # Build point->cluster id assignment
    # Determine total number of points from max index + 1
    max_point_index = max(idx for g in merged_groups for idx in g['indices'])
    point_label_map = np.full(max_point_index + 1, fill_value=-1, dtype=int)

    for cid, group in enumerate(merged_groups):
        for idx in group['indices']:
            if 0 <= idx < len(point_label_map):
                point_label_map[idx] = cid

    merged_label_strings = {cid: g['new_label'] for cid, g in enumerate(merged_groups)}
    return merged_groups, point_label_map, merged_label_strings


def plot_merged_clusters(
    reduced_activations: np.ndarray,
    point_label_map: np.ndarray,
    merged_label_strings: Dict[int, str],
    output_path: str,
    background_alpha: float = 0.12,
    point_size: int = 6,
    title: str = "Merged Clusters",
    merged_cluster_ids: set[int] | None = None,
    merged_cluster_label_fontsize: int = 16,
    singleton_cluster_label_fontsize: int = 12,
    title_fontsize: int = 24,
    axis_label_fontsize: int = 18,
    tick_fontsize: int = 14,
    show_title: bool = False,
):
    """Scatter-plot merged clusters with one boxed text label per cluster and save it.

    Points whose id in ``point_label_map`` is negative are drawn in light grey as
    "Unassigned". Every other cluster id gets a colour from the ``tab20`` palette
    (resampled to ``max(20, n_clusters)`` entries) and a label at its centroid.
    Labels are decluttered with ``adjustText`` when that package is importable.

    Args:
        reduced_activations: Embedding coordinates; columns 0 and 1 are used as
            x / y (presumably shape (n_points, >=2) — confirm against the cache).
        point_label_map: Per-point cluster id (-1 = unassigned). If its length
            differs from ``len(reduced_activations)`` it is padded with -1 or
            truncated (with a warning) so the boolean masks line up.
        merged_label_strings: Cluster id -> display label.
        output_path: Destination image path passed to ``savefig``.
        background_alpha: Alpha for unassigned points.
        point_size: Marker size for all points.
        title: Plot title; only drawn when ``show_title`` is True.
        merged_cluster_ids: Cluster ids formed from multi-subcluster merges; their
            labels use ``merged_cluster_label_fontsize``.
        merged_cluster_label_fontsize: Label font size for merged clusters.
        singleton_cluster_label_fontsize: Label font size for all other clusters.
        title_fontsize: Title font size (when shown).
        axis_label_fontsize: Axis label font size.
        tick_fontsize: Tick label font size.
        show_title: Draw ``title`` on the figure. Defaults to False, which keeps
            the previous output where the title was disabled.
    """
    n_points = len(reduced_activations)
    labels_arr = np.asarray(point_label_map)
    if len(labels_arr) != n_points:
        # Boolean-mask indexing below requires one id per embedded point.
        print(
            f"Warning: point_label_map has {len(labels_arr)} entries but there are "
            f"{n_points} embedded points; padding with -1 / truncating to align."
        )
        aligned = np.full(n_points, -1, dtype=int)
        overlap = min(n_points, len(labels_arr))
        aligned[:overlap] = labels_arr[:overlap]
        labels_arr = aligned

    # np.unique returns sorted values, so colour assignment is deterministic.
    cluster_ids = [int(cid) for cid in np.unique(labels_arr) if cid >= 0]
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    palette = plt.get_cmap('tab20', max(20, len(cluster_ids)))

    fig, ax = plt.subplots(figsize=(20, 16))
    try:
        # Background (noise/unassigned)
        noise_mask = labels_arr < 0
        if noise_mask.any():
            ax.scatter(
                reduced_activations[noise_mask, 0],
                reduced_activations[noise_mask, 1],
                c='lightgrey', s=point_size, alpha=background_alpha, label='Unassigned'
            )

        texts = []
        for i, cid in enumerate(cluster_ids):
            # cid comes from np.unique(labels_arr), so pts is never empty.
            pts = reduced_activations[labels_arr == cid]
            color = palette(i % palette.N)
            ax.scatter(pts[:, 0], pts[:, 1], c=[color], s=point_size, alpha=0.75)
            centroid = pts.mean(axis=0)
            label = merged_label_strings.get(cid, f"Cluster {cid}")
            is_merged = merged_cluster_ids is not None and cid in merged_cluster_ids
            fs = merged_cluster_label_fontsize if is_merged else singleton_cluster_label_fontsize
            texts.append(ax.text(
                centroid[0], centroid[1], label,
                ha='center', va='center', fontsize=fs,
                bbox=dict(boxstyle='round,pad=0.45', fc='white', ec='black', alpha=0.85, lw=0.8)
            ))

        if _ADJUST_AVAILABLE and texts:
            try:
                adjust_text(texts, ax=ax,
                            arrowprops=dict(arrowstyle='-', color='black', lw=0.5, alpha=0.6))
            except Exception as exc:
                # Decluttering is cosmetic: keep labels at centroids but report why.
                print(f"Warning: adjust_text failed; labels left at centroids ({exc})")

        if show_title:
            ax.set_title(title, fontsize=title_fontsize)
        ax.set_xlabel('Reduced Dim 1', fontsize=axis_label_fontsize)
        ax.set_ylabel('Reduced Dim 2', fontsize=axis_label_fontsize)
        ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
        ax.grid(True, alpha=0.25)
        fig.savefig(output_path)
    finally:
        # Always release the (large) figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"Saved merged cluster plot to {output_path}")


def main():
    """Command-line entry point.

    Loads the embeddings, sub-cluster metadata and merges JSON, applies the
    merges, writes the merged-group summary and point-assignment JSON files
    (plus an optional CSV) into ``--output-dir`` and, unless ``--summary-only``
    is given, renders the merged-cluster scatter plot.
    """
    import csv  # only needed for --save-csv

    parser = argparse.ArgumentParser(description="Apply manual sub-cluster merges and visualize result.")
    parser.add_argument('--embeddings', required=True, help='Path to .npz reduced activations file')
    parser.add_argument('--subcluster-meta', required=True, help='Path to sub-cluster metadata JSON')
    parser.add_argument('--merges-json', required=True, help='Path to manual merges JSON produced by interactive tool')
    parser.add_argument('--output-dir', required=True, help='Directory to save outputs')
    parser.add_argument('--layer', type=int, required=True, help='Layer index (for metadata/filenames)')
    parser.add_argument('--prompt-style', type=str, required=True, help='Prompt style label')
    parser.add_argument('--background-alpha', type=float, default=0.12)
    parser.add_argument('--point-size', type=int, default=6)
    parser.add_argument('--summary-only', action='store_true')
    parser.add_argument('--save-csv', action='store_true')
    parser.add_argument('--random-seed', type=int, default=42)
    parser.add_argument('--cluster-label-fontsize', type=int, default=16, help='(DEPRECATED alias) Font size for merged cluster labels; use --merged-cluster-label-fontsize')
    parser.add_argument('--merged-cluster-label-fontsize', type=int, default=None, help='Font size for merged (multi-subcluster) labels')
    parser.add_argument('--singleton-cluster-label-fontsize', type=int, default=12, help='Font size for singleton (unmerged) cluster labels')
    parser.add_argument('--title-fontsize', type=int, default=24, help='Font size for plot title')
    parser.add_argument('--axis-fontsize', type=int, default=18, help='Font size for axis labels')
    parser.add_argument('--tick-fontsize', type=int, default=14, help='Font size for tick labels')
    args = parser.parse_args()

    random.seed(args.random_seed)
    np.random.seed(args.random_seed)

    os.makedirs(args.output_dir, exist_ok=True)

    reduced = load_embeddings(args.embeddings)
    sub_meta = load_json(args.subcluster_meta)
    merges_spec = load_json(args.merges_json)

    merged_groups, point_label_map, merged_label_strings = apply_merges(sub_meta, merges_spec)

    # apply_merges may return a map shorter than the embedding array (e.g. when
    # the last points are unassigned). Pad with -1 so every embedded point has an
    # entry in the outputs and the plot's boolean masks match the embeddings.
    n_points = len(reduced)
    if len(point_label_map) < n_points:
        point_label_map = np.pad(point_label_map, (0, n_points - len(point_label_map)), constant_values=-1)
    elif len(point_label_map) > n_points:
        print(
            f"Warning: sub-cluster metadata references {len(point_label_map)} points but the "
            f"embeddings file contains only {n_points}; extra indices cannot be plotted."
        )

    # Determine which cluster ids correspond to merged (multi-member) groups
    merged_cluster_ids = {cid for cid, g in enumerate(merged_groups) if len(g.get('member_subclusters', [])) > 1}

    # Backwards compatibility: if new merged font size not provided, fall back to previous --cluster-label-fontsize
    merged_fontsize = args.merged_cluster_label_fontsize if args.merged_cluster_label_fontsize is not None else args.cluster_label_fontsize

    # --- Save summaries ---
    summary_json_path = os.path.join(
        args.output_dir,
        f"merged_clusters_{args.prompt_style}_layer{args.layer}.json"
    )
    with open(summary_json_path, 'w', encoding='utf-8') as f:
        json.dump({
            'layer': args.layer,
            'prompt_style': args.prompt_style,
            'merged_groups': merged_groups
        }, f, indent=2)
    print(f"Saved merged group summary to {summary_json_path}")

    point_map_json_path = os.path.join(
        args.output_dir,
        f"merged_point_assignments_{args.prompt_style}_layer{args.layer}.json"
    )
    with open(point_map_json_path, 'w', encoding='utf-8') as f:
        json.dump({
            'layer': args.layer,
            'prompt_style': args.prompt_style,
            'point_cluster_ids': point_label_map.tolist(),
            'label_lookup': merged_label_strings
        }, f, indent=2)
    print(f"Saved point assignment map to {point_map_json_path}")

    if args.save_csv:
        csv_path = os.path.join(
            args.output_dir,
            f"merged_point_assignments_{args.prompt_style}_layer{args.layer}.csv"
        )
        try:
            # Explicit UTF-8: labels may contain non-ASCII text.
            with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(['point_index', 'cluster_id', 'cluster_label'])
                writer.writerows(
                    (idx, cid, merged_label_strings.get(cid, '') if cid >= 0 else '')
                    for idx, cid in enumerate(point_label_map.tolist())
                )
            print(f"Saved CSV assignments to {csv_path}")
        except Exception as e:
            # The CSV is an optional extra: report the failure but still plot.
            print(f"Failed to write CSV: {e}")

    if not args.summary_only:
        plot_path = os.path.join(
            args.output_dir,
            f"merged_clusters_{args.prompt_style}_layer{args.layer}.png"
        )
        title = f"Merged Concept Clusters (Layer {args.layer}, Style: {args.prompt_style})"
        plot_merged_clusters(
            reduced,
            point_label_map,
            merged_label_strings,
            plot_path,
            background_alpha=args.background_alpha,
            point_size=args.point_size,
            title=title,
            merged_cluster_ids=merged_cluster_ids,
            merged_cluster_label_fontsize=merged_fontsize,
            singleton_cluster_label_fontsize=args.singleton_cluster_label_fontsize,
            title_fontsize=args.title_fontsize,
            axis_label_fontsize=args.axis_fontsize,
            tick_fontsize=args.tick_fontsize
        )


if __name__ == '__main__':
    main()
