import os
import pickle
import numpy as np
import networkx as nx
import igraph as ig
import leidenalg
import community as community_louvain
import torch
import matplotlib.pyplot as plt
from glob import glob


def louvain_clusters(S, min_sub_count=1, resolution=1):
    """Cluster the nodes of a weighted adjacency matrix with Louvain.

    NOTE: This treats graphs as UNDIRECTED by symmetrizing the input.
    Entries below ``min_sub_count`` are zeroed first (per direction), then the
    matrix is symmetrized as ``M + M.T``. Each undirected edge therefore
    carries the sum of both directions, and any diagonal entry is doubled.

    Args:
        S: square 2-D array of non-negative edge weights (e.g. a confusion matrix).
        min_sub_count: entries strictly below this value are dropped before clustering.
        resolution: forwarded unchanged to ``community_louvain.best_partition``.

    Returns:
        1-D integer array of length ``S.shape[0]`` with community labels
        re-indexed to the dense range 0..K-1.
    """
    kept = np.where(S >= min_sub_count, S, 0)
    graph = nx.from_numpy_array(kept + kept.T)

    # Fixed random_state keeps the partition reproducible across runs.
    node_to_comm = community_louvain.best_partition(
        graph,
        weight='weight',
        resolution=resolution,
        random_state=27,
    )

    # best_partition returns {node_id: community_id} covering every node.
    raw_labels = np.array([node_to_comm[node] for node in range(S.shape[0])])
    return np.unique(raw_labels, return_inverse=True)[1]


def leiden_clusters(S, min_sub_count=1, resolution=1):
    """Cluster the nodes of a weighted adjacency matrix with Leiden.

    NOTE: This treats graphs as DIRECTED. ``S[i, j]`` becomes a weighted edge
    i -> j in an igraph ``Graph`` built with ``mode='directed'``.

    Args:
        S: square 2-D array of non-negative edge weights.
        min_sub_count: entries strictly below this value are dropped before clustering.
        resolution: ``resolution_parameter`` for ``RBConfigurationVertexPartition``.

    Returns:
        1-D integer array (one entry per node) of community labels re-indexed
        to the dense range 0..K-1.
    """
    kept = np.where(S >= min_sub_count, S, 0)
    graph = ig.Graph.Weighted_Adjacency(kept.tolist(), mode='directed')

    # Fixed seed keeps the partition reproducible across runs.
    result = leidenalg.find_partition(
        graph,
        leidenalg.RBConfigurationVertexPartition,
        weights=graph.es['weight'],
        resolution_parameter=resolution,
        seed=27,
    )

    # result.membership[node] is that node's community id.
    membership = np.array(result.membership)
    return np.unique(membership, return_inverse=True)[1]


def _read_token_rows(path, num_channels):
    """Parse up to ``num_channels`` non-blank lines of ``path`` into 1-D int64 arrays.

    Each line holds whitespace-separated integer token ids (one line per
    channel). Returns a possibly empty list of arrays, one per line read.
    """
    with open(path, "r") as f:
        lines = [line for line in f if line.strip()]
    return [np.fromstring(line, sep=" ", dtype=np.int64) for line in lines[:num_channels]]


def load_token_data(data_dir, num_channels=8):
    """Loads tokens into (N, C) tensors.

    Every sub-folder of ``data_dir`` that contains ``orig_tokens.txt`` is
    paired with each available variant: ``roundtrip_tokens.txt`` first, then
    every ``attacked_tokens_*.txt`` in sorted order. Each file is read as one
    line of token ids per channel, taking at most ``num_channels`` lines.

    Each (orig, variant) pair contributes T rows, where T is the length of the
    shortest channel line across BOTH files. All channels are truncated to T
    so they can be stacked column-wise.

    Args:
        data_dir: directory whose sub-folders hold the token text files.
        num_channels: maximum number of channel lines read per file.

    Returns:
        ``(x, y)``: int64 tensors of original / variant tokens, one row per
        time step over all pairs and one column per channel line read.
        NOTE(review): column counts assume all files have the same number of
        channel lines. Returns ``(None, None)`` if ``data_dir`` is missing or
        no valid pair was found.
    """
    if not os.path.exists(data_dir):
        print(f"Warning: {data_dir} not found.")
        return None, None

    orig_list, rt_list = [], []
    sample_dirs = sorted(glob(os.path.join(data_dir, "*")))
    print(f"Loading token samples from {data_dir} ({len(sample_dirs)} candidate folders)...")

    pairs_processed = 0
    for d in sample_dirs:
        if not os.path.isdir(d):
            continue

        o_path = os.path.join(d, "orig_tokens.txt")
        if not os.path.exists(o_path):
            continue
        o = _read_token_rows(o_path, num_channels)
        if not o:
            continue

        # Roundtrip variants: explicit roundtrip_tokens.txt first, then any attacked_tokens_*.txt
        variant_paths = [os.path.join(d, "roundtrip_tokens.txt")]
        variant_paths += sorted(glob(os.path.join(d, "attacked_tokens_*.txt")))
        variants = []
        for v_path in variant_paths:
            if not os.path.exists(v_path):
                continue
            r = _read_token_rows(v_path, num_channels)
            if r:
                variants.append(r)

        for r in variants:
            # T must be the minimum over *every* channel of both files. Using
            # only channel 0 made np.stack fail whenever a later channel line
            # was shorter than channel 0.
            T = min(x.size for x in o + r)
            if T == 0:
                continue

            orig_list.append(np.stack([x[:T] for x in o], axis=1))
            rt_list.append(np.stack([x[:T] for x in r], axis=1))
            pairs_processed += 1

    if not orig_list:
        print(f"No valid token pairs found in {data_dir}.")
        return None, None

    x = torch.from_numpy(np.concatenate(orig_list, axis=0)).long()
    y = torch.from_numpy(np.concatenate(rt_list, axis=0)).long()
    print(f"Loaded {pairs_processed} orig/rt pairs -> tokens shapes: {x.shape}, {y.shape}")
    return x, y


def get_cluster_stats(labels, usage_counts=None):
    """Summarize a clustering.

    Args:
        labels: array of cluster labels, one per item. Labels need not be
            contiguous; they are densified internally.
        usage_counts: optional array of per-item usage counts with the same
            shape as ``labels``. When given, it weights the effective-size estimate.

    Returns:
    - K: Number of clusters
    - S: Number of singletons (clusters with exactly one item)
    - M: Largest cluster size (0 when ``labels`` is empty)
    - K_eff: Effective vocabulary size, 2 ** (base-2 entropy of the per-cluster
      usage distribution). Falls back to K when ``usage_counts`` is None
      or sums to 0.
    """
    labels = np.asarray(labels)
    unique, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    K = len(unique)
    singletons = np.sum(counts == 1)
    max_size = np.max(counts) if len(counts) > 0 else 0

    K_eff = K  # Default if no usage info
    if usage_counts is not None:
        # Aggregate usage per cluster. Index with the dense inverse (0..K-1),
        # not the raw labels, so non-contiguous labels cannot overflow the
        # K-sized accumulator.
        cluster_usage = np.zeros(K)
        np.add.at(cluster_usage, inverse.reshape(labels.shape), usage_counts)

        total = cluster_usage.sum()
        if total > 0:
            probs = cluster_usage / total
            probs = probs[probs > 0]
            entropy = -np.sum(probs * np.log2(probs))
            K_eff = 2 ** entropy

    return K, singletons, max_size, K_eff


def _read_int_rows(path):
    """Parse every non-blank line of ``path`` into a 1-D int64 array of its whitespace-separated ints."""
    with open(path) as f:
        return [
            np.array([int(tok) for tok in line.split()], dtype=np.int64)
            for line in f
            if line.strip()
        ]


def build_confusion_and_usage(train_dirs, num_channels=8, vocab_size=2048):
    """
    Given a list of training directories, build confusion matrices (C) and usage counts (U).

    Each sub-folder of every directory is expected to hold ``orig_tokens.txt``
    and ``roundtrip_tokens.txt``, with one line of token ids per channel.
    Folders missing either file, or whose two files differ in line count or in
    any line's length, are skipped. For each of the first ``num_channels``
    lines:
      - usage_counts[ch, t] counts occurrences of original token t;
      - conf_matrices[ch, a, b] counts positions where original token a was
        reconstructed as a different token b (the diagonal is never incremented).

    Token ids are used directly as array indices, so they must lie in
    [0, vocab_size). NOTE(review): negative ids would silently wrap around.

    Returns:
    - conf_matrices: (num_channels, V, V) int array
    - usage_counts: (num_channels, V) int array
    """
    conf_matrices = np.zeros((num_channels, vocab_size, vocab_size), dtype=np.int64)
    usage_counts = np.zeros((num_channels, vocab_size), dtype=np.int64)

    for tdir in train_dirs:
        if not os.path.exists(tdir):
            print(f"Warning: train dir {tdir} not found, skipping.")
            continue
        for name in sorted(os.listdir(tdir)):
            sample_path = os.path.join(tdir, name)
            if not os.path.isdir(sample_path):
                continue

            orig_path = os.path.join(sample_path, "orig_tokens.txt")
            rt_path = os.path.join(sample_path, "roundtrip_tokens.txt")
            if not (os.path.exists(orig_path) and os.path.exists(rt_path)):
                continue

            orig = _read_int_rows(orig_path)
            rt = _read_int_rows(rt_path)

            # Same "layouts must match" rule as a whole-array shape check, but
            # applied row by row. Building one 2-D array would crash on files
            # whose channel lines have different lengths.
            if len(orig) != len(rt) or any(o.size != r.size for o, r in zip(orig, rt)):
                continue

            for ch in range(min(num_channels, len(orig))):
                row_orig, row_rt = orig[ch], rt[ch]
                # np.add.at accumulates correctly when an index repeats.
                np.add.at(usage_counts[ch], row_orig, 1)
                mask = row_orig != row_rt
                if np.any(mask):
                    np.add.at(conf_matrices[ch], (row_orig[mask], row_rt[mask]), 1)

    return conf_matrices, usage_counts


# ==========================================
# Main Execution
# ==========================================
if __name__ == "__main__":
    # Config - replace paths as needed
    SAMPLES_TRAIN_DIR = "/home/wmar/wmar_audio/outputs/rcc_reconstructions/original_encodec_32/samples"
    SAMPLES_TRAIN_AUG_DIR = "/home/wmar/wmar_audio/outputs/rcc_reconstructions/original_encodec_32_aug/samples"
    SAMPLES_DEV_DIR = "/home/wmar/wmar_audio/outputs/rcc_reconstructions/original_encodec_32_testset/samples"

    MATRIX_DIR = "/home/wmar/wmar_audio/outputs/confusion/matrices_encodec"
    PLOT_DIR = "/home/wmar/wmar_audio/outputs/confusion/plots_ablation_encodec"
    CLUSTERS_DIR = "/home/wmar/wmar_audio/models/embeddings/new_clusterings_encodec"

    for out_dir in (MATRIX_DIR, PLOT_DIR, CLUSTERS_DIR):
        os.makedirs(out_dir, exist_ok=True)

    VOCAB_SIZE = 2048
    NUM_CHANNELS = 8
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # ------------------------------------------------------------------
    # Confusion matrices / usage counts for:
    #   A) Train only
    #   B) Train + Aug (identical to A when the aug dir is missing)
    # ------------------------------------------------------------------
    print("Building confusion matrices and usage counts for Train-only and Train+Aug (if present)...")
    HAVE_AUG = os.path.exists(SAMPLES_TRAIN_AUG_DIR)
    if HAVE_AUG:
        print(f"Including augmented training samples from: {SAMPLES_TRAIN_AUG_DIR}")
    else:
        print(f"Augmented training dir not found at {SAMPLES_TRAIN_AUG_DIR}; combined == train-only.")

    conf_trainonly, usage_trainonly = build_confusion_and_usage([SAMPLES_TRAIN_DIR], NUM_CHANNELS, VOCAB_SIZE)
    if HAVE_AUG:
        # The counts are plain sums over samples, so Train+Aug == Train-only + Aug.
        # Building only the aug part avoids re-parsing the entire train dir.
        conf_aug, usage_aug = build_confusion_and_usage([SAMPLES_TRAIN_AUG_DIR], NUM_CHANNELS, VOCAB_SIZE)
        conf_combined = conf_trainonly + conf_aug
        usage_combined = usage_trainonly + usage_aug
        del conf_aug, usage_aug  # free the (NUM_CHANNELS, V, V) temporary early
    else:
        conf_combined, usage_combined = conf_trainonly, usage_trainonly

    # Save per-channel matrices
    for ch in range(NUM_CHANNELS):
        np.save(os.path.join(MATRIX_DIR, f"confusion_trainonly_{ch}.npy"), conf_trainonly[ch])
        np.save(os.path.join(MATRIX_DIR, f"confusion_combined_{ch}.npy"), conf_combined[ch])

    # ------------------------------------------------------------------
    # Token datasets: train-only, train-aug, combined, dev
    # ------------------------------------------------------------------
    print("Loading token datasets...")
    x_train_only, y_train_only = load_token_data(SAMPLES_TRAIN_DIR, NUM_CHANNELS)

    x_train_aug, y_train_aug = (None, None)
    if HAVE_AUG:
        x_train_aug, y_train_aug = load_token_data(SAMPLES_TRAIN_AUG_DIR, NUM_CHANNELS)

    if x_train_only is None and x_train_aug is None:
        raise ValueError("No training data found in either train or augmented train dirs.")
    elif x_train_only is None:
        x_train_combined, y_train_combined = x_train_aug, y_train_aug
    elif x_train_aug is None:
        x_train_combined, y_train_combined = x_train_only, y_train_only
    else:
        x_train_combined = torch.cat([x_train_only, x_train_aug], dim=0)
        y_train_combined = torch.cat([y_train_only, y_train_aug], dim=0)

    x_dev, y_dev = load_token_data(SAMPLES_DEV_DIR, NUM_CHANNELS)
    if x_train_only is not None:
        x_train_only, y_train_only = x_train_only.to(DEVICE), y_train_only.to(DEVICE)
    # The ValueError above guarantees the combined tensors exist here.
    x_train_combined, y_train_combined = x_train_combined.to(DEVICE), y_train_combined.to(DEVICE)
    if x_dev is not None:
        x_dev, y_dev = x_dev.to(DEVICE), y_dev.to(DEVICE)

    # ------------------------------------------------------------------
    # Ablation settings
    # ------------------------------------------------------------------
    # Alternative (string) channel labels:
    # CHANNELS_NAMES = ['rvq_first_0', 'rvq_rest_0', 'rvq_rest_1', 'rvq_rest_2', 'rvq_rest_3', 'rvq_rest_4', 'rvq_rest_5', 'rvq_rest_6']
    CHANNELS_NAMES = [0, 1, 2, 3]

    COUNTS = [1, 5, 10, 25, 50]
    RESOLUTIONS = [0.2, 0.5, 0.8, 1.0, 1.2, 1.5]
    DEFAULT_RES = 1.0
    DEFAULT_CNT = 1

    def marker_sizes_from_Ks(Ks, vocab_size=VOCAB_SIZE, min_s=30, max_s=400):
        """Map cluster counts K to scatter marker areas.

        The map is linear in K / vocab_size (K == 0 -> min_s, K == vocab_size -> max_s),
        clipped to [min_s, max_s]. It is anchored to the vocabulary size, not to
        the min/max K of the sweep.
        """
        Ks = np.asarray(Ks, dtype=float)
        if Ks.size == 0:
            return np.array([])
        return np.clip(min_s + (Ks / vocab_size) * (max_s - min_s), min_s, max_s)

    def raw_match_rate(x, y, ch_idx):
        """Fraction of rows whose orig and variant tokens are identical on ``ch_idx`` (0.0 if no data)."""
        if x is None:
            return 0.0
        return (x[:, ch_idx] == y[:, ch_idx]).float().mean().item()

    def cluster_match_rate(cmap, x, y, ch_idx):
        """Fraction of rows whose orig and variant tokens fall in the same cluster (0.0 if no data)."""
        if x is None:
            return 0.0
        return (cmap[x[:, ch_idx]] == cmap[y[:, ch_idx]]).float().mean().item()

    def run_experiment_for_eval(conf_matrix, usage, x_train_tokens, y_train_tokens, ch_idx, cnt, res):
        """Cluster one confusion matrix and score it on train and dev tokens.

        Returns (labels, eta_train, eta_dev, (K, singletons, max_size, K_eff)).
        """
        labels = louvain_clusters(conf_matrix, min_sub_count=cnt, resolution=res)
        stats = get_cluster_stats(labels, usage)
        # labels[token_id] -> cluster id, used as a lookup table on device.
        cmap = torch.tensor(labels, device=DEVICE)
        eta_t = cluster_match_rate(cmap, x_train_tokens, y_train_tokens, ch_idx)
        eta_d = cluster_match_rate(cmap, x_dev, y_dev, ch_idx)
        return labels, eta_t, eta_d, stats

    def plot_ablation(evals, ch_idx, r_dev, sweep, title, out_name, save_pickles):
        """Run one ablation sweep for every eval config and save a Pareto plot.

        Args:
            evals: list of eval dicts (tag, conf_matrix, usage, x_train, y_train,
                r_train, pkl_out, color).
            ch_idx: token channel (column) being evaluated.
            r_dev: raw (unclustered) dev match rate for the baseline line.
            sweep: list of ``(key, min_count, resolution)``. ``key`` indexes the
                pickled ``{key: labels}`` dict.
            title: plot title.
            out_name: PNG file name written under PLOT_DIR.
            save_pickles: if True, pickle ``{key: labels}`` to each eval's ``pkl_out``.
        """
        plt.figure(figsize=(8, 6), dpi=150)
        for e in evals:
            x_vals, y_train_vals, y_dev_vals, Ks, Ms = [], [], [], [], []
            pkl_leiden = {}
            for key, cnt, res in sweep:
                lbls, et, ed, (K, _, M, Keff) = run_experiment_for_eval(
                    e["conf_matrix"], e["usage"], e["x_train"], e["y_train"], ch_idx, cnt, res
                )
                pkl_leiden[key] = lbls
                x_vals.append(Keff)
                y_train_vals.append(et)
                y_dev_vals.append(ed)
                Ks.append(K)
                Ms.append(M)

            plt.plot(x_vals, y_train_vals, '-o', color=e["color"], label=f"{e['tag']} (Train)")
            plt.plot(x_vals, y_dev_vals, '--o', color=e["color"], alpha=0.6, label=f"{e['tag']} (Dev)")
            # Marker size grows with K so cluster counts are visible on the K_eff axis.
            plt.scatter(x_vals, y_train_vals, s=marker_sizes_from_Ks(Ks), color=e["color"],
                        alpha=0.75, edgecolor='k', linewidth=0.3)

            # Annotate only K and M to keep the plot readable
            for xv, yv, K, M in zip(x_vals, y_train_vals, Ks, Ms):
                plt.annotate(f"K={K}\nM={M}", (xv, yv),
                             xytext=(0, 10), textcoords='offset points', fontsize=8, ha='center')

            if save_pickles:
                with open(e["pkl_out"], "wb") as f:
                    pickle.dump(pkl_leiden, f)

            plt.axhline(e["r_train"], color=e["color"], linestyle=':', alpha=0.25,
                        label=f"{e['tag']} base r_train={e['r_train']:.3f}")

        if x_dev is not None:
            plt.axhline(r_dev, color='k', linestyle='--', alpha=0.3, label=f'Base Dev r={r_dev:.3f}')

        plt.title(title)
        plt.xlabel("Effective Vocabulary Size ($K_{eff}$)")
        plt.ylabel(r"Cluster Match Rate ($\eta_{ctx}$)")
        # plt.xscale("log")  # uncomment for a log-scaled K_eff axis
        plt.grid(True, which="both", alpha=0.3)
        plt.legend(loc='best', fontsize=8)
        plt.tight_layout()
        plt.savefig(os.path.join(PLOT_DIR, out_name), dpi=150)
        plt.close()

    # ------------------------------------------------------------------
    # Per-channel processing
    # ------------------------------------------------------------------
    for ch_idx, ch_name in enumerate(CHANNELS_NAMES):
        print(f"\nProcessing Channel {ch_idx}: {ch_name}")

        # NOTE(review): pickle names say "mimi_leiden" although clustering uses
        # Louvain on EnCodec data. Kept as-is because downstream code may load
        # these exact paths.
        evals = [
            {
                "tag": "TrainOnly",
                "conf_matrix": conf_trainonly[ch_idx],
                "usage": usage_trainonly[ch_idx],
                "x_train": x_train_only,
                "y_train": y_train_only,
                "pkl_out": os.path.join(CLUSTERS_DIR, f"mimi_leiden_{ch_name}_clusterings_trainonly.pkl"),
                "color": "C0",
            },
            {
                "tag": "Train+Aug",
                "conf_matrix": conf_combined[ch_idx],
                "usage": usage_combined[ch_idx],
                "x_train": x_train_combined,
                "y_train": y_train_combined,
                "pkl_out": os.path.join(CLUSTERS_DIR, f"mimi_leiden_{ch_name}_clusterings_combined.pkl"),
                "color": "C1",
            },
        ]

        # Unclustered baselines: raw token match rate per eval, plus dev (shared)
        for e in evals:
            e["r_train"] = raw_match_rate(e["x_train"], e["y_train"], ch_idx)
        r_dev = raw_match_rate(x_dev, y_dev, ch_idx)

        # Ablation 1: Min Count (Fixed Resolution); pickles saved from this sweep
        plot_ablation(
            evals, ch_idx, r_dev,
            sweep=[(cnt, cnt, DEFAULT_RES) for cnt in COUNTS],
            title=f"Pareto: Min Count Ablation (Ch {ch_name})\nRes={DEFAULT_RES}",
            out_name=f"louvain_ablation_mincount_{ch_name}.png",
            save_pickles=True,
        )

        # Ablation 2: Resolution (Fixed Min Count)
        plot_ablation(
            evals, ch_idx, r_dev,
            sweep=[(res, DEFAULT_CNT, res) for res in RESOLUTIONS],
            title=f"Pareto: Resolution Ablation (Ch {ch_name})\nMinCount={DEFAULT_CNT}",
            out_name=f"louvain_ablation_resolution_{ch_name}.png",
            save_pickles=False,
        )

    print("Done.")
