import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm
from collections import defaultdict


def load_token_txt(path):
    """
    Parse a text file of whitespace-separated integer tokens.

    Every line of the file becomes one row of the result; a blank (or
    whitespace-only) line becomes an empty row.

    Returns np.ndarray with dtype int64. For a file of K lines that each
    hold T tokens the shape is [K, T].
    """
    with open(path, "r") as handle:
        # split() with no separator ignores leading/trailing whitespace and
        # returns [] for a blank line, so no separate empty-line case is needed.
        rows = [[int(field) for field in raw.split()] for raw in handle]

    return np.array(rows, dtype=np.int64)


def load_all_pairs(samples_dir):
    """
    Load every (orig, roundtrip) token pair found under ``samples_dir``.

    Each immediate subdirectory of ``samples_dir`` is treated as one sample.
    A sample is used only when it contains both ``orig_tokens.txt`` and
    ``roundtrip_tokens.txt``; entries that are not directories, or that lack
    either file, are skipped. Subdirectories are visited in sorted name order.

    Args:
        samples_dir: Directory whose subdirectories hold the token files.

    Returns:
        List of ``(orig, rt)`` tuples of arrays as returned by
        ``load_token_txt``, one per usable sample.

    Raises:
        AssertionError: If the two arrays of a sample differ in shape.
    """
    pairs = []

    for name in sorted(os.listdir(samples_dir)):
        sample_dir = os.path.join(samples_dir, name)
        if not os.path.isdir(sample_dir):
            continue

        orig_path = os.path.join(sample_dir, "orig_tokens.txt")
        rt_path = os.path.join(sample_dir, "roundtrip_tokens.txt")

        # isfile (rather than exists) so a directory that happens to carry
        # one of these names is skipped instead of failing inside open().
        if not (os.path.isfile(orig_path) and os.path.isfile(rt_path)):
            continue

        orig = load_token_txt(orig_path)
        rt = load_token_txt(rt_path)

        # Raised explicitly instead of via an `assert` statement so the check
        # still runs under `python -O`; the exception type is unchanged.
        if orig.shape != rt.shape:
            raise AssertionError(f"Shape mismatch in {name}")
        pairs.append((orig, rt))

    return pairs


def compute_substitution_confusion(pairs):
    """
    Count, per channel, how often each token was replaced by another.

    For every ``(orig, rt)`` pair the two arrays are compared position by
    position; each position where they differ adds one to the count of the
    ``(orig_token, rt_token)`` substitution for that row (channel).
    Positions where the tokens agree are not counted.

    Args:
        pairs: Iterable of ``(orig, rt)`` 2-D integer arrays (rows are
            channels) with identical shapes.

    Returns:
        conf[ch][(orig_token, rt_token)] = count

        A nested ``defaultdict``. Only channels with at least one
        substitution are present as keys.

    Raises:
        ValueError: If the two arrays of a pair differ in shape, or an
            array is not 2-D.
    """
    conf = defaultdict(lambda: defaultdict(int))

    for orig, rt in pairs:
        orig = np.asarray(orig)
        rt = np.asarray(rt)
        # A position-wise comparison is only meaningful for equal shapes;
        # fail loudly rather than truncating or hitting an IndexError.
        if orig.shape != rt.shape:
            raise ValueError(
                f"Shape mismatch between orig {orig.shape} and rt {rt.shape}"
            )

        K, _ = orig.shape
        # One vectorized comparison instead of a Python-level test per element.
        changed = orig != rt

        for ch in range(K):
            row_changed = changed[ch]
            if not row_changed.any():
                # Do not touch conf[ch]: that would create an empty entry.
                continue

            ch_conf = conf[ch]
            # Only the mismatching positions are visited in Python.
            for a, b in zip(orig[ch, row_changed], rt[ch, row_changed]):
                ch_conf[(a, b)] += 1

    return conf


def plot_sparse_confusion(conf_ch, channel_id, out_path, min_count=1):
    """
    Save a log-scaled heatmap of one channel's token substitutions.

    The substitution counts are laid out in a dense matrix spanning the
    smallest to the largest token ID seen on each side. Cells with a zero
    count are masked so that ``LogNorm`` only sees positive values. Both
    axes are labelled with real token IDs.

    Args:
        conf_ch: Mapping ``(orig_token, rt_token) -> count`` for one channel.
        channel_id: Channel identifier, used in the title and log messages.
        out_path: Destination path of the image file.
        min_count: Substitutions with a count below this value are left out
            of the plot. The axis ranges still cover all entries.

    Returns:
        None. Prints a message and returns without writing a file when
        ``conf_ch`` is empty or every count is below ``min_count``.
    """
    if not conf_ch:
        print(f"Skipping channel {channel_id}: no substitutions")
        return

    # conf_ch is non-empty here, so both token lists are non-empty as well.
    orig_tokens = [i for (i, _) in conf_ch]
    rt_tokens = [j for (_, j) in conf_ch]

    # Axis ranges in token-ID space.
    orig_min, orig_max = min(orig_tokens), max(orig_tokens)
    rt_min, rt_max = min(rt_tokens), max(rt_tokens)

    # Dense matrix; row/column 0 corresponds to orig_min / rt_min.
    cm = np.zeros((orig_max - orig_min + 1, rt_max - rt_min + 1), dtype=np.int64)
    for (i, j), c in conf_ch.items():
        if c >= min_count:
            cm[i - orig_min, j - rt_min] = c

    if cm.sum() == 0:
        print(f"Skipping channel {channel_id}: all counts below min_count")
        return

    # Mask zeros for LogNorm (log of zero is undefined).
    masked = np.ma.masked_where(cm == 0, cm)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        image = ax.imshow(
            masked,
            cmap="magma",
            norm=LogNorm(vmin=1, vmax=masked.max()),
            aspect="auto",
            interpolation="nearest",
            # Without an extent the ticks would show matrix indices, which
            # are shifted by orig_min / rt_min relative to the token IDs.
            # Order is (left, right, bottom, top); row 0 is drawn at the top.
            extent=(rt_min - 0.5, rt_max + 0.5, orig_max + 0.5, orig_min - 0.5),
        )
        fig.colorbar(image, ax=ax, label="Substitution count (log scale)")
        ax.set_xlabel("Re-encoded token ID")
        ax.set_ylabel("Original token ID")
        ax.set_title(f"Mimi token substitutions — channel {channel_id}")
        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        # Release the figure even when rendering or saving fails.
        plt.close(fig)


# ============================
# Entry point
# ============================

if __name__ == "__main__":
    # Fixed input and output locations for this analysis run.
    samples_dir = "outputs/rcc_reconstructions/original_mimi/samples"
    matrix_dir = "outputs/confusion/matrices"
    plot_dir = "outputs/confusion/plots"
    for out_dir in (matrix_dir, plot_dir):
        os.makedirs(out_dir, exist_ok=True)

    pairs = load_all_pairs(samples_dir)
    print(f"Loaded {len(pairs)} samples")

    conf = compute_substitution_confusion(pairs)

    for ch in sorted(conf.keys()):
        conf_ch = conf[ch]
        if not conf_ch:
            continue

        # Column 0 holds the original tokens, column 1 the re-encoded ones.
        keys = np.array(list(conf_ch.keys()), dtype=np.int64)
        counts = np.array(list(conf_ch.values()), dtype=np.int64)

        # Smallest and largest token ID on each side fix the matrix extent.
        lo = keys.min(axis=0)
        hi = keys.max(axis=0)

        # Dense matrix whose [0, 0] cell corresponds to the token pair
        # (lo[0], lo[1]). NOTE: these offsets are not stored in the .npy file.
        cm = np.zeros(tuple(hi - lo + 1), dtype=np.int64)
        cm[keys[:, 0] - lo[0], keys[:, 1] - lo[1]] = counts

        np.save(os.path.join(matrix_dir, f"confusion_{ch}.npy"), cm)

        # The plot is built from the sparse dict, not from the dense matrix.
        plot_sparse_confusion(conf_ch, ch, os.path.join(plot_dir, f"confusion_{ch}.png"))

    print("Done.")
