from __future__ import annotations

import heapq
import itertools
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional

import jax
import numpy as onp
import scipy as osp
import simdjson as json
import tqdm.auto as tqdm
from jax import numpy as np
from jax import scipy as sp
from jax.flatten_util import ravel_pytree
from jax.tree_util import (
    tree_flatten,
    tree_leaves,
    tree_map,
    tree_reduce,
    tree_structure,
    tree_unflatten,
)

# Programming-language labels.
# NOTE(review): presumably these match per-language data subdirectory names
# of a code corpus -- confirm against the dataset layout.
CODE_LANGS = [
    "Fortran",
    "Perl",
    "Motorola68KAssembly",
    "Ruby",
    "XML",
    "reStructuredText",
    "PowerShell",
    "Batchfile",
    "Smali",
    "VisualBasic.NET",
    "Pod6",
    "Makefile",
    "Lua",
    "JavaScript",
    "Hack",
    "Scala",
    "HTML",
    "XPixMap",
    "Python",
    "PHP",
    "CMake",
    "TSQL",
    "Haskell",
    "C++",
    "C",
    "CSS",
    "Dockerfile",
    "Objective-C",
    "Raku",
    "Java",
    "Smalltalk",
    "FortranFreeForm",
    "Shell",
    "TeX",
    "Julia",
    "Markdown",
    "Go",
]
# Text-domain labels.
# NOTE(review): presumably data-source names used as subdirectory names --
# confirm against the dataset layout.
DOMAIN_LANGS = ["arxiv", "books", "github", "web", "wikipedia"]
# Short natural-language codes.
# NOTE(review): these look like ISO-639 / Wikipedia-style language codes for
# additional (lower-resource) languages -- confirm.
OTHER_LANGS = [
    "fy",
    "sah",
    "ga",
    "sa",
    "os",
    "cv",
    "ceb",
    "af",
    "br",
    "azb",
    "hr",
    "mhr",
    "lb",
    "uz",
    "ce",
    "mg",
    "nds",
    "xmf",
    "bpy",
    "new",
    "min",
    "arz",
    "nn",
    "tk",
    "pms",
    "ms",
    "gom",
    "la",
    "jbo",
    "mt",
    "sw",
]
# Boolean flag per tokenizer family name.
# NOTE(review): the name suggests it selects character counts instead of byte
# counts for the given tokenizer; no consumer is visible in this part of the
# file -- confirm against the call sites.
DO_COUNT_CHARS = {
    'gpt2': False,
    'gpt3': False,
    'gpt3.5': False,
    'gpt4o': False,
    'llama': True,
    'llama3': False,
    'mixtral': True,
    'gemma': True,
}


def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by byte-level BPE.

    Taken from the GPT-2 encoder:
    https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9

    The BPE code operates on unicode strings, so every one of the 256 byte
    values needs a character that is neither whitespace nor a control
    character (those trip up the BPE code). Bytes that already correspond to
    such a character (``!``..``~``, ``¡``..``¬``, ``®``..``ÿ``) map to
    themselves; every remaining byte is assigned the next unused code point
    starting at 256, in increasing byte order.

    Returns:
        dict mapping each int in ``range(256)`` to a distinct one-character
        string. Self-mapped bytes come first in iteration order, followed by
        the remapped bytes.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in self_mapped}

    # Remaining bytes are shifted into the code points just above the byte
    # range so that they stay printable and unique.
    next_code_point = 2**8
    for byte in range(2**8):
        if byte in table:
            continue
        table[byte] = chr(next_code_point)
        next_code_point += 1
    return table


class PriorityQueue:
    """Heap-backed priority queue whose elements can be re-prioritised or removed.

    Entries are stored on a ``heapq`` heap as ``[key, insertion_count, element]``
    lists. ``entry_finder`` maps each live element to its entry, so removing an
    element only overwrites the entry's element slot with the ``removed``
    sentinel; the stale entry is discarded lazily when it reaches the top of
    the heap. The insertion counter breaks ties between equal keys, so elements
    themselves are never compared.

    For a max-queue (the default) the stored key is the negated priority, so
    that the smallest key on the heap is the largest priority.

    Elements must be hashable.
    """

    def __init__(self, items=None, max_queue=True):
        """Create a queue, optionally pre-filled from ``(element, priority)`` pairs.

        Args:
            items: optional iterable of ``(element, priority)`` pairs; elements
                must be distinct (checked with ``assert``).
            max_queue: if true, ``pop``/``peek`` yield the largest priority
                first; otherwise the smallest.
        """
        self.pq = []
        self.removed = object()
        self.entry_finder = {}
        self.counter = itertools.count()
        self.max_queue = max_queue
        if items is not None:
            for el, priority in items:
                if self.max_queue:
                    priority = -priority
                assert el not in self
                count = next(self.counter)
                entry = [priority, count, el]
                self.entry_finder[el] = entry
                self.pq.append(entry)
            # One O(n) heapify instead of n pushes.
            heapq.heapify(self.pq)

    def add(self, el, priority):
        """Insert ``el`` with ``priority``, replacing any existing entry for it."""
        if self.max_queue:
            priority = -priority
        if el in self:
            self.remove(el)
        count = next(self.counter)
        entry = [priority, count, el]
        self.entry_finder[el] = entry
        heapq.heappush(self.pq, entry)

    def remove(self, el):
        """Mark ``el`` as removed. Raises ``KeyError`` if it is not in the queue."""
        entry = self.entry_finder.pop(el)
        entry[-1] = self.removed

    def pop(self):
        """Remove and return ``(element, priority)`` for the best live entry.

        Raises:
            KeyError: if the queue holds no live entries.
        """
        while self.pq:
            priority, count, el = heapq.heappop(self.pq)
            if el is not self.removed:
                del self.entry_finder[el]
                if self.max_queue:
                    priority = -priority
                return el, priority
        raise KeyError("pop from an empty priority queue")

    def peek(self):
        """Return ``(element, priority)`` for the best live entry without removing it.

        Stale (removed) entries found at the top of the heap are discarded.

        Raises:
            KeyError: if the queue holds no live entries.
        """
        while self.pq:
            priority, count, el = self.pq[0]
            if el is self.removed:
                heapq.heappop(self.pq)
                continue

            if self.max_queue:
                priority = -priority
            return el, priority
        raise KeyError("peek from an empty priority queue")

    def lookup(self, el, default=None):
        """Return the priority of ``el``, or ``default`` if it is not queued.

        ``default`` is returned untouched; only a stored key is converted back
        to the caller-facing priority. (Previously the default was negated on
        a max-queue, which raised ``TypeError`` for the ``None`` default.)
        """
        entry = self.entry_finder.get(el)
        if entry is None:
            return default
        priority = entry[0]
        if self.max_queue:
            priority = -priority
        return priority

    def __getitem__(self, el):
        """Return the raw heap key stored for ``el`` (``KeyError`` if absent).

        NOTE(review): unlike ``lookup``, ``pop`` and ``peek`` this is the
        *stored* key, i.e. the negated priority on a max-queue. Left as-is
        because callers may rely on it -- confirm and prefer ``lookup``.
        """
        return self.entry_finder[el][0]

    def __contains__(self, el):
        """Return whether ``el`` is a live element of the queue."""
        return el in self.entry_finder

    def __len__(self):
        """Return the number of live elements (stale heap entries excluded)."""
        return len(self.entry_finder)


def sp_minimize(f, x0, *, disable_unary=False, **kwargs):
    """Minimize a JAX objective over a pytree with ``scipy.optimize.minimize``.

    The pytree ``x0`` is flattened to a single float64 vector for SciPy and
    unflattened again (cast to the active JAX float type) before every call
    into user code.

    Args:
        f: callable returning ``(value, grad)``, where ``grad`` has the same
            pytree structure as the arguments. It is called as ``f(*x)``:
            with one argument when ``x0`` is a single leaf ("unary" mode),
            otherwise with the top-level children of ``x0`` unpacked.
        x0: initial point. A pytree consisting of a single node is wrapped in
            a 1-tuple internally unless ``disable_unary`` is set.
        disable_unary: skip the single-leaf wrapping described above.
        **kwargs: forwarded to ``scipy.optimize.minimize``, overriding the
            defaults ``jac=True`` and ``method="L-BFGS-B"``. Special handling:

            * ``callback`` is invoked as ``callback(*x, *extra_args)`` with the
              unflattened current point.
            * ``bounds`` must be an ``(lb, ub)`` tuple or a
              ``scipy.optimize.Bounds``; ``lb``/``ub`` must have the same
              leaf shapes as ``x0``.
            * Other arguments (e.g. ``hess``, ``hessp``, ``constraints``) are
              passed through without any pytree translation.

    Returns:
        The SciPy result object, modified in place: ``fun`` becomes a JAX
        array, ``x`` the unflattened pytree (unwrapped in unary mode), and
        ``hess_inv`` (when present) a function mapping a pytree to a pytree.
        ``jac`` is left as the flat array SciPy produced.

    Raises:
        NotImplementedError: for unsupported ``bounds`` types or ``jac`` other
            than ``True``.
        AssertionError: if the bounds do not match the shape of ``x0``.
    """
    unary = False
    if tree_structure(x0).num_nodes == 1 and not disable_unary:
        unary = True
        x0 = (x0,)

    # Values handed to user code use the active JAX float width; values handed
    # to SciPy are always float64.
    ty = np.float64 if jax.config.jax_enable_x64 else np.float32
    _, unravel = ravel_pytree(x0)

    def to_np(x):
        # flat vector -> pytree of JAX arrays
        return tree_map(lambda t: t.astype(ty), unravel(x))

    def to_onp(x):
        # pytree -> flat float64 NumPy vector
        return onp.asarray(ravel_pytree(x)[0]).astype(onp.float64)

    def f_wrapper(x):
        loss, grad = f(*to_np(x))
        return onp.array(loss).astype(onp.float64), to_onp(grad)

    inner_kwargs = {"jac": True, "method": "L-BFGS-B"}
    inner_kwargs.update(kwargs)

    if "callback" in inner_kwargs:
        callback = inner_kwargs["callback"]

        def callback_wrapper(xk, *args):
            return callback(*to_np(xk), *args)

        inner_kwargs["callback"] = callback_wrapper

    if "bounds" in inner_kwargs:
        bounds = inner_kwargs["bounds"]
        keep_feasible = False
        if isinstance(bounds, osp.optimize.Bounds):
            # Normalise to the tuple form so both inputs share one code path.
            keep_feasible = bounds.keep_feasible
            bounds = (bounds.lb, bounds.ub)

        if isinstance(bounds, tuple):
            assert len(bounds) == 2, "bounds must be a (lower, upper) pair"
            lb, ub = bounds
            x0_shape = tree_map(lambda t: getattr(t, "shape", None), x0)
            lb_shape = tree_map(lambda t: getattr(t, "shape", None), lb)
            ub_shape = tree_map(lambda t: getattr(t, "shape", None), ub)
            # In unary mode x0 was wrapped in a 1-tuple but the bounds were not.
            expected_shape = x0_shape[0] if unary else x0_shape
            assert (
                lb_shape == ub_shape == expected_shape
            ), "bounds must have the same shapes as x0"
            inner_kwargs["bounds"] = osp.optimize.Bounds(
                to_onp(lb), to_onp(ub), keep_feasible
            )

        else:
            raise NotImplementedError("Can only handle tuple and Bounds for bounds")

    if inner_kwargs["jac"] is not True:
        raise NotImplementedError("Only supports jac=True right now")

    opt_min = osp.optimize.minimize(f_wrapper, to_onp(x0), **inner_kwargs)

    # TODO: mutating SciPy's result object in place is fragile and could break
    # with SciPy changes.
    opt_min.fun = np.asarray(opt_min.fun)
    opt_min.x = to_np(opt_min.x)
    if unary:
        opt_min.x = opt_min.x[0]

    if hasattr(opt_min, "hess_inv"):
        hess_inv = opt_min.hess_inv

        def hess_inv_wrapper(x):
            # Flattening handles both the wrapped and the unary input form.
            result = to_np(hess_inv(to_onp(x)))
            if unary:
                return result[0]
            return result

        opt_min.hess_inv = hess_inv_wrapper

    return opt_min


@dataclass
class Merge:
    """A single BPE merge rule and its links to related merges.

    ``postprocess_merges`` builds these: listed merges get ``rank`` equal to
    their 1-based position, and tokens that no listed merge produces are
    represented by placeholders with ``rank`` 0 and an empty right part.
    """

    # Position of the merge in the merge list (0 for placeholders).
    rank: int
    # Left and right token strings joined by this merge.
    l: str
    r: str
    # Merges registered as the producers of the left / right token.
    lp: Optional[Merge] = None
    rp: Optional[Merge] = None
    # Merges that use this merge's token as their left / right part.
    lc: list[Merge] = field(default_factory=list)
    rc: list[Merge] = field(default_factory=list)

    @property
    def m(self):
        """The merged token: left part followed by right part."""
        return "".join((self.l, self.r))

    @property
    def c(self):
        """All dependent merges: left-side users first, then right-side users."""
        return [*self.lc, *self.rc]

    def __str__(self):
        """Space-separated form of the pair."""
        return " ".join((self.l, self.r))

    def __repr__(self):
        """Compact form of the pair, joined by the wreath-product sign."""
        return "≀".join((self.l, self.r))


def postprocess_merges(merges):
    """Turn raw ``"left right"`` merge strings into linked ``Merge`` objects.

    Each line becomes a ``Merge`` with rank ``index + 1``. Its ``lp``/``rp``
    point at the entries registered for its left/right token, and it is
    appended to those entries' ``lc``/``rc`` lists. A token with no registered
    entry gets a rank-0 placeholder ``Merge(0, token, "")``.

    Processing stops at the first line that does not split into exactly two
    space-separated parts; a message is printed and the merges collected so
    far are returned.

    Args:
        merges: iterable of merge strings.

    Returns:
        ``(merge_order, producers)``: the list of parsed merges in input
        order, and a dict mapping each token string to its ``Merge``.
    """
    producers = {}
    merge_order = []

    def producer_of(token):
        # Register a rank-0 placeholder the first time a token is seen.
        known = producers.get(token)
        if known is None:
            known = producers[token] = Merge(0, token, "")
        return known

    for index, line in enumerate(merges):
        parts = line.split(" ")
        if len(parts) != 2:
            print(f"Broken merge {index}: {line!r}")
            break

        left, right = parts
        merge = Merge(index + 1, left, right)
        merge_order.append(merge)
        # The merged token is registered before its parts are resolved.
        producers[merge.m] = merge

        merge.lp = producer_of(left)
        merge.rp = producer_of(right)
        merge.lp.lc.append(merge)
        merge.rp.rc.append(merge)

    return merge_order, producers


def load_merges(fname):
    """Read a merges file and return its lines without the first (header) line.

    The file is decoded as UTF-8 regardless of the platform default, since
    merge files contain non-ASCII characters.

    Args:
        fname: path (``str`` or ``Path``) of the merges file.

    Returns:
        List of lines after the first one, each with trailing newline
        characters removed. Empty if the file has at most one line.
    """
    with Path(fname).open(encoding="utf-8") as f:
        # Unconditionally discard the first line (the header).
        next(f, None)
        return [line.rstrip("\n") for line in f]


def load_data(data_root, verbose=False):
    """Load the merge list and the per-subdirectory statistics under ``data_root``.

    Reads ``data_root/merges.txt`` and, for every non-hidden subdirectory
    ``<name>`` of ``data_root``, the files ``<name>/1e07/all_pair_counts.json``
    and ``<name>/1e07/meta.json``.

    Args:
        data_root: directory to load from (``str`` or ``Path``).
        verbose: show a tqdm progress bar over the directory entries.

    Returns:
        ``(merges, pair_counts, byte_counts)`` where ``merges`` is the ordered
        list produced by ``postprocess_merges``, ``pair_counts`` maps each
        subdirectory name to the parsed content of ``all_pair_counts.json``,
        and ``byte_counts`` maps it to the ``"byte_count"`` entry of
        ``meta.json``.
    """
    data_root = Path(data_root)
    # Only the merge order is returned; the token -> producer map is not needed.
    merges, _ = postprocess_merges(load_merges(data_root / "merges.txt"))

    pair_counts = {}
    byte_counts = {}
    P = partial(tqdm.tqdm, dynamic_ncols=True) if verbose else lambda x: x
    # Materialise the listing so the progress bar knows the total.
    for item in P(list(data_root.iterdir())):
        if not item.is_dir() or item.name.startswith("."):
            continue

        # Every remaining subdirectory is loaded. Earlier revisions filtered
        # here by CODE_LANGS / OTHER_LANGS / DOMAIN_LANGS membership.
        with (item / "1e07/all_pair_counts.json").open(encoding="utf-8") as f:
            pair_counts[item.name] = json.load(f)

        # Alternatives tried previously: the "char_count" entry of meta.json,
        # or the sum of the first pair-count table.
        with (item / "1e07/meta.json").open(encoding="utf-8") as f:
            byte_counts[item.name] = json.load(f)["byte_count"]
    return merges, pair_counts, byte_counts
