import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import sentencepiece as spm
from datasets import load_dataset
from matplotlib.colors import to_rgba
from tqdm import tqdm

from spectrum import plotting
from spectrum.plotting import fmt_pow10


def load_data(vocab_sizes=(1_000, 3_162, 10_000, 31_622)):
    """Train BPE tokenizers on OpenWebText and return token/bigram counts.

    For every vocabulary size this runs, with on-disk caching:

    1. dump OpenWebText to ``openwebtext_all.txt`` (one document per line),
    2. train a SentencePiece BPE model ``sp_bpe_{vocab_size}.model``,
    3. count tokens and bigrams shard by shard into ``freq_shards_{vocab_size}/``,
    4. sum the shards into ``freqs/token_freq_total_{vocab_size}.npy`` and
       ``freqs/bigram_freq_total_{vocab_size}.npy``.

    Each step is skipped when its output already exists, and outputs are
    written via temporary files + rename, so the function can be re-run after
    an interruption without picking up truncated files.

    Args:
        vocab_sizes: SentencePiece vocabulary sizes to process.

    Returns:
        dict with keys
        ``"freqs"``: list of int64 arrays of shape ``(vocab_size,)``;
        ``"cond_freqs"``: list of int64 arrays of shape
        ``(vocab_size, vocab_size)`` where ``[a, b]`` counts token ``a``
        immediately followed by token ``b``;
        ``"vocab_sizes"``: list of the requested sizes;
        ``"vocab"``: list (one per size) of id -> piece string lists.
    """
    corpus_path = Path("openwebtext_all.txt")

    def save_atomic(path, array):
        # np.save under a hidden temporary name, then rename into place, so an
        # interrupted run never leaves a partial file the exists() checks accept.
        # The ".tmp_" prefix also keeps leftovers out of the shard glob below.
        path = Path(path)
        tmp_path = path.with_name(f".tmp_{path.name}")
        np.save(tmp_path, array)
        os.replace(tmp_path, path)

    def clean_text(text):
        # One document per line: fold internal newlines into spaces.
        return text.replace("\n", " ").strip()

    def make_single_file():
        if corpus_path.exists():
            return
        # NOTE: trust_remote_code lets `datasets` execute the dataset's loading script.
        dataset = load_dataset("openwebtext", split="train", trust_remote_code=True)
        tmp_path = corpus_path.with_name(f".tmp_{corpus_path.name}")
        with open(tmp_path, "w", encoding="utf-8") as f:
            for example in tqdm(dataset, total=len(dataset)):
                text = clean_text(example["text"])
                if text:
                    f.write(text + "\n")
        os.replace(tmp_path, corpus_path)

    def model_path(vocab_size):
        return Path(f"sp_bpe_{vocab_size}.model")

    def train_spm(vocab_size):
        if model_path(vocab_size).exists():
            return
        # The trainer reads the flat corpus file; build it on first use.
        make_single_file()
        spm.SentencePieceTrainer.train(
            input=str(corpus_path),
            model_prefix=f"sp_bpe_{vocab_size}",
            vocab_size=vocab_size,
            model_type="bpe",
            input_sentence_size=2_000_000,
            max_sentence_length=16768,
            pad_id=0,
            unk_id=1,
            bos_id=2,
            eos_id=3,
            user_defined_symbols=["<mask>"],
        )

    def process_shard_to_numpy(
        shard_index, num_shards, vocab_size, model_file, output_dir
    ):
        os.makedirs(output_dir, exist_ok=True)

        token_path = os.path.join(output_dir, f"token_freq_shard_{shard_index}.npy")
        bigram_path = os.path.join(output_dir, f"bigram_freq_shard_{shard_index}.npy")
        if os.path.exists(token_path) and os.path.exists(bigram_path):
            print(f"Shard {shard_index} already processed. Skipping.")
            return

        sp = spm.SentencePieceProcessor(model_file=model_file)
        dataset = load_dataset("openwebtext", split="train", trust_remote_code=True)
        shard = dataset.shard(num_shards, index=shard_index)

        token_freq = np.zeros(vocab_size, dtype=np.int64)
        bigram_freq = np.zeros((vocab_size, vocab_size), dtype=np.int64)

        for example in tqdm(shard, desc=f"Shard {shard_index}"):
            text = clean_text(example["text"])
            if not text:
                continue

            tokens = np.array(sp.encode(text, out_type=int), dtype=np.int32)
            # Documents with fewer than two tokens contribute no bigram and are
            # skipped entirely (including their unigram).
            if len(tokens) < 2:
                continue

            # Only tokens that have a successor are counted, so token_freq equals
            # the row sums of bigram_freq (each document's last token is excluded).
            np.add.at(token_freq, tokens[:-1], 1)
            np.add.at(bigram_freq, (tokens[:-1], tokens[1:]), 1)

        save_atomic(token_path, token_freq)
        save_atomic(bigram_path, bigram_freq)

    def compute_freqs_with_saves(vocab_size):
        # Shards are processed sequentially and summed later; the count only
        # sets checkpoint granularity (each shard holds a dense V x V matrix).
        num_shards = 20 if vocab_size in (1_000, 10_000) else 4

        output_dir = f"freq_shards_{vocab_size}"
        model_file = str(model_path(vocab_size))

        for i in range(num_shards):
            process_shard_to_numpy(i, num_shards, vocab_size, model_file, output_dir)

    def aggregate_numpy_shards(vocab_size):
        output_dir = Path("freqs")
        output_dir.mkdir(parents=True, exist_ok=True)

        token_total_path = output_dir / f"token_freq_total_{vocab_size}.npy"
        bigram_total_path = output_dir / f"bigram_freq_total_{vocab_size}.npy"
        # Require both totals: a crash between the two saves must not be taken
        # as a finished aggregation (load_freqs needs both files).
        if token_total_path.exists() and bigram_total_path.exists():
            return

        token_freq_total = np.zeros(vocab_size, dtype=np.int64)
        bigram_freq_total = np.zeros((vocab_size, vocab_size), dtype=np.int64)

        shards_dir = Path(f"freq_shards_{vocab_size}")
        shard_files = sorted(shards_dir.glob("token_freq_shard_*.npy"))

        for tf_file in tqdm(shard_files, desc="Aggregating"):
            idx = tf_file.stem.rsplit("_", 1)[-1]
            token_freq_total += np.load(tf_file)
            bigram_freq_total += np.load(shards_dir / f"bigram_freq_shard_{idx}.npy")

        save_atomic(token_total_path, token_freq_total)
        save_atomic(bigram_total_path, bigram_freq_total)

    def load_freqs(vocab_size):
        output_dir = Path("freqs")
        freqs = np.load(output_dir / f"token_freq_total_{vocab_size}.npy")
        cond_freqs = np.load(output_dir / f"bigram_freq_total_{vocab_size}.npy")
        return freqs, cond_freqs

    def process(vocab_size):
        print(vocab_size)
        train_spm(vocab_size)
        compute_freqs_with_saves(vocab_size)
        aggregate_numpy_shards(vocab_size)
        pi, cpi = load_freqs(vocab_size)
        sp = spm.SentencePieceProcessor(model_file=str(model_path(vocab_size)))
        id_to_token = [sp.id_to_piece(i) for i in range(sp.get_piece_size())]
        return pi, cpi, id_to_token

    vocab_sizes = list(vocab_sizes)

    freqs, cond_freqs, id_to_token = [], [], []
    for vocab_size in vocab_sizes:
        freqs_, cond_freqs_, id_to_token_ = process(vocab_size)
        freqs.append(freqs_)
        cond_freqs.append(cond_freqs_)
        id_to_token.append(id_to_token_)

    return {
        "freqs": freqs,
        "cond_freqs": cond_freqs,
        "vocab_sizes": vocab_sizes,
        "vocab": id_to_token,
    }


def postprocess(data):
    """Convert raw counts into rank-ordered, normalized frequency profiles.

    For each vocabulary size in ``data``:

    * ``freqs`` (unigram counts) is sorted in decreasing order and divided by
      its total, so entry ``k`` is the relative frequency of the
      ``(k + 1)``-th most common token.
    * ``cond_freqs`` (bigram counts, one row per preceding token) has its rows
      reordered by that same unigram rank; each row is then sorted in
      decreasing order and divided by its row total. Rows whose total is below
      1 (all-zero rows for count inputs) are left as zeros.

    The input count arrays are not modified. ``data["freqs"]`` and
    ``data["cond_freqs"]`` are replaced with new lists of float64 arrays and
    ``data`` itself is returned.
    """
    all_freqs, all_cond_freqs = [], []
    for counts, bigram_counts in zip(data["freqs"], data["cond_freqs"]):
        freqs = np.asarray(counts, dtype=np.float64)
        rank_order = np.argsort(-freqs)
        freqs = freqs[rank_order]  # fancy indexing copies; input is untouched
        freqs /= np.sum(freqs)
        all_freqs.append(freqs)

        # Reorder rows by unigram rank (this makes a fresh copy), then sort every
        # row in a single vectorized call instead of a Python loop over rows.
        cond_freqs = np.asarray(bigram_counts, dtype=np.float64)[rank_order]
        cond_freqs.sort(axis=1)
        cond_freqs = cond_freqs[:, ::-1]  # descending view, no extra copy

        row_totals = np.sum(cond_freqs, axis=1, keepdims=True)
        row_totals[row_totals < 1] = 1  # avoid 0/0 for rows with no counts
        cond_freqs /= row_totals
        all_cond_freqs.append(cond_freqs)

    data["freqs"] = all_freqs
    data["cond_freqs"] = all_cond_freqs
    return data


def settings(plt):
    """Apply the project's shared plotting style (three-column layout) to ``plt``."""
    style_options = {"ncols": 3}
    plotting.update_style(plt, **style_options)


def make_figure(fig, data, vocab_size=10_000, plot_individual_rows=False):
    """Draw unigram and conditional rank-frequency plots side by side.

    Left panel: the rank-ordered unigram distribution ``pi_k``.
    Right panel: across rows ``j`` of the rank-ordered conditional
    distributions ``pi_{k|j}``, the median and the 5-95 / 10-90 / 25-75
    percentile bands at each rank ``k``.
    Both panels overlay the reference curve ``1 / (k * ln V)``.

    Args:
        fig: matplotlib figure; two subplots (1x2) are added to it.
        data: output of ``postprocess`` — rank-sorted, normalized ``freqs`` and
            ``cond_freqs`` lists aligned with ``data["vocab_sizes"]``.
        vocab_size: which entry of ``data["vocab_sizes"]`` to plot.
        plot_individual_rows: also draw a subsample of individual conditional
            distributions as faint lines (slow; off by default).

    Raises:
        ValueError: if ``vocab_size`` is not in ``data["vocab_sizes"]``.
    """
    axes = [
        fig.add_subplot(121),
        fig.add_subplot(122),
    ]

    idx = list(data["vocab_sizes"]).index(vocab_size)
    freqs = data["freqs"][idx]
    cond_freqs = data["cond_freqs"][idx]
    n_tokens = len(freqs)
    xs = np.arange(1, n_tokens + 1)

    def subsample(*arrays, n=100):
        """Thin aligned 1-D arrays to ~n linearly plus ~n log-spaced indices.

        Arrays shorter than ``n`` are returned unchanged. One array is returned
        bare; several are returned as a tuple (ready to be ``*``-unpacked).
        """
        length = len(arrays[0])
        if length >= n:
            lin_idx = np.linspace(0, length - 1, n).astype(int)
            log_idx = np.logspace(0, np.log10(length - 1), n).astype(int)
            keep = np.unique(np.concatenate([lin_idx, log_idx]))
            arrays = tuple(a[keep] for a in arrays)
        return arrays if len(arrays) > 1 else arrays[0]

    def blend(color, alpha):
        """Opaque RGB equivalent of ``color`` drawn at opacity ``alpha`` on white."""
        return tuple(alpha * c + (1 - alpha) * 1.0 for c in to_rgba(color)[:3])

    reference = (1 / xs) / np.log(n_tokens)
    reference_label = "$\\propto 1/k$"

    # Left panel: unigram rank-frequency curve.
    axes[0].plot(*subsample(xs, freqs), color="k", label="data")
    axes[0].plot(
        *subsample(xs, reference),
        color="k",
        linestyle="--",
        label=reference_label,
    )

    # Right panel: median and percentile bands of the conditional distributions.
    (line_median,) = axes[1].plot(
        *subsample(xs, np.median(cond_freqs, axis=0)), color="k", label="median"
    )
    bands = ((5, 95, 0.2), (10, 90, 0.4), (25, 75, 0.6))
    quantiles = [5, 10, 25, 75, 90, 95]
    # One pass over the (large) matrix for all quantiles instead of six.
    percentiles = dict(zip(quantiles, np.percentile(cond_freqs, quantiles, axis=0)))
    fill_handles = []
    for low_q, high_q, alpha in bands:
        fill_handles.append(
            axes[1].fill_between(
                *subsample(xs, percentiles[low_q], percentiles[high_q]),
                color=blend("b", alpha=alpha),
                label=rf"{low_q}-{high_q}\%",
            )
        )

    if plot_individual_rows:
        colors = plt.get_cmap("YlOrBr")(np.linspace(0.3, 1, n_tokens))
        rows = subsample(np.arange(n_tokens), n=1000)
        # Shuffle the draw order so overlapping lines are not stacked by rank.
        for i in np.random.permutation(rows):
            ys = np.sort(cond_freqs[i])[::-1].astype(np.float64)
            total = np.sum(ys)
            if total == 0:
                continue
            axes[1].plot(*subsample(xs, ys / total), color=colors[i], alpha=0.05)

    (line_pred,) = axes[1].plot(
        *subsample(xs, reference),
        color="k",
        linestyle="--",
        label=reference_label,
    )

    for ax in axes:
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlim([1, n_tokens * 1.05])
        ax.set_ylim((10**-6.5, 10**-0.5))
        ax.set_yticks([10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1])
        ax.set_xticks([10**0, 10**1, 10**2, 10**3, 10**4])
        ax.set_xticklabels(map(fmt_pow10, [10**0, 10**1, None, 10**3, 10**4]))
        ax.set_xlabel("Rank of word $k$", labelpad=-5)

    legend_style = dict(
        frameon=False,
        fontsize=8,
        handlelength=1.5,
        handletextpad=0.4,
        labelspacing=0.2,
    )
    axes[0].legend(
        loc="lower left", borderpad=0.25, borderaxespad=0.25, **legend_style
    )
    legend_lines = axes[1].legend(
        [line_median, line_pred],
        [line_median.get_label(), line_pred.get_label()],
        loc="lower left",
        borderpad=0.25,
        borderaxespad=0.25,
        **legend_style,
    )
    # A later Axes.legend() call replaces the current legend, so keep this one
    # as an explicit artist before adding the band legend.
    axes[1].add_artist(legend_lines)
    axes[1].legend(
        fill_handles,
        [handle.get_label() for handle in fill_handles],
        loc="best",
        bbox_to_anchor=(0, 0, 1, 0.85),
        borderpad=0.0,
        borderaxespad=0.0,
        **legend_style,
    )

    axes[0].set_yticklabels(map(fmt_pow10, [None, 10**-5, None, 10**-3, None, 10**-1]))
    axes[1].set_yticklabels([])
    axes[1].set_ylabel("")
    axes[0].set_ylabel("Frequency")
    axes[0].set_title("Word frequencies\n$\\pi_k$", y=0.75)
    axes[1].set_title("Conditional frequencies\n $\\pi_{k \\vert j}$", y=0.75)

    fig.tight_layout(pad=0.05)


if __name__ == "__main__":
    settings(plt)
    fig = plt.figure()
    data = postprocess(load_data())
    make_figure(fig, data)
    # savefig does not create missing directories; make sure figs/ exists.
    output_path = Path("figs") / "freqs.pdf"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path)
    plt.close(fig)
