import os
import numpy as np
import matplotlib.pyplot as plt


def load_token_txt(path):
    """Parse a text file of whitespace-separated integers into an array.

    Every line of the file becomes one row of the result. A blank or
    whitespace-only line becomes an empty row.

    Args:
        path: Path of the token text file to read.

    Returns:
        An ``np.int64`` array built from the per-line integer lists; shape
        ``[K, T]`` when the file holds K lines of T tokens each.

    Raises:
        ValueError: If a token is not a valid integer literal. NumPy also
            rejects rows of unequal length, since they cannot form a
            rectangular integer array.
    """
    with open(path, "r") as handle:
        # str.split() without arguments discards surrounding whitespace and
        # returns [] for a blank line, so no explicit empty-line branch is
        # needed to produce an empty row.
        rows = [[int(tok) for tok in raw.split()] for raw in handle]

    return np.array(rows, dtype=np.int64)  # [K, T]


def load_all_samples(samples_dir):
    """Load ``orig_tokens.txt`` from every sample sub-directory.

    Entries of ``samples_dir`` are visited in sorted name order. Entries that
    are not directories, and directories that contain no ``orig_tokens.txt``,
    are skipped silently.

    Args:
        samples_dir: Directory whose immediate sub-directories are samples.

    Returns:
        A list with one ``load_token_txt`` result per sample that has a token
        file, in sorted sub-directory order. Empty if nothing matched.
    """
    # Lazily build the token-file path for each sub-directory; plain files
    # sitting directly in samples_dir are filtered out here.
    token_files = (
        os.path.join(samples_dir, entry, "orig_tokens.txt")
        for entry in sorted(os.listdir(samples_dir))
        if os.path.isdir(os.path.join(samples_dir, entry))
    )

    # Keep only the sub-directories that actually contain a token file.
    return [
        load_token_txt(token_file)
        for token_file in token_files
        if os.path.exists(token_file)
    ]


if __name__ == "__main__":
    # Script entry point: for every channel, count how often each token id
    # occurs across all loaded samples, report how many ids of the vocabulary
    # never occur, and plot/save the counts sorted from most to least used.
    samples_dir = "outputs/rcc_reconstructions/original_mimi/samples"
    plots_dir = "outputs/confusion/plots"
    vocab_size = 2048

    # Presumably the output recorded from an earlier run of this script (the
    # lines match the print format used in the loop below):
    # Channel 0: unseen 341/2048
    # Channel 1: unseen 155/2048
    # Channel 2: unseen 47/2048
    # Channel 3: unseen 15/2048
    # Channel 4: unseen 12/2048
    # Channel 5: unseen 12/2048
    # Channel 6: unseen 3/2048
    # Channel 7: unseen 2/2048

    samples = load_all_samples(samples_dir)
    if not samples:
        # Fail with a readable message instead of an IndexError on samples[0].
        raise SystemExit(
            f"No orig_tokens.txt samples found under {samples_dir!r}"
        )

    # The first sample's leading dimension fixes the number of channels; every
    # sample is indexed with the same channel range below.
    K = samples[0].shape[0]

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(plots_dir, exist_ok=True)

    for ch in range(K):
        # Pool this channel's tokens from every sample into one flat array.
        ch_tokens = np.concatenate([s[ch].ravel() for s in samples])
        # minlength guarantees one bin per vocabulary id, including ids that
        # never occur.
        counts = np.bincount(ch_tokens, minlength=vocab_size)

        sorted_counts = np.sort(counts)[::-1]  # descending usage
        unseen = np.sum(sorted_counts == 0)

        print(f"Channel {ch}: unseen {unseen}/{vocab_size}")

        fig = plt.figure(figsize=(6, 4))
        plt.plot(sorted_counts)
        plt.yscale("log")
        plt.xlabel("Token rank")
        plt.ylabel("Count (log)")
        plt.title(f"Sorted token usage — channel {ch}")
        plt.tight_layout()

        # Save BEFORE showing: with a blocking GUI backend the figure is
        # released once its window is closed, so saving afterwards would
        # write an empty image.
        save_path = os.path.join(plots_dir, f"histogram_ch{ch}.png")
        fig.savefig(save_path)
        plt.show()

        # Release the figure so figures do not pile up across channels.
        plt.close(fig)
