from collections import Counter
import json
import hashlib
import datetime
from typing import Any, Set, List
import numpy as np

class Colors:
    """ANSI SGR escape sequences used to colour and style terminal output.

    Wrap text as ``f"{Colors.RED}text{Colors.END}"``; ``END`` resets every
    attribute so the styling does not leak into subsequent output.
    """

    HEADER = '\033[95m'     # bright magenta foreground
    BLUE = '\033[94m'       # bright blue foreground
    CYAN = '\033[96m'       # bright cyan foreground
    GREEN = '\033[92m'      # bright green foreground
    YELLOW = '\033[93m'     # bright yellow foreground
    RED = '\033[91m'        # bright red foreground
    BOLD = '\033[1m'        # bold / increased intensity
    UNDERLINE = '\033[4m'   # underline
    END = '\033[0m'         # reset all attributes

# Field positions inside a beam tuple. ``YS`` selects the symbol sequence
# (see ``flatten_beam``); ``LOGP`` is presumably the matching log-probability
# slot -- it is not read anywhere in the visible code, so confirm with callers.
YS = 0
LOGP = 1

def flatten(xs):
    """Unroll a left-nested chain of pairs into a flat tuple.

    The input encodes a sequence as nested two-element containers whose
    first slot holds "everything before" and whose second slot holds the
    last item, terminated by an empty container::

        flatten((((), 'a'), 'b'))  ->  ('a', 'b')

    The chain is walked iteratively, so arbitrarily long chains are handled
    without hitting the interpreter recursion limit, and in linear time
    (no repeated tuple concatenation).

    Parameters
    ----------
    xs
        Either an empty sized container or a two-element container
        ``(rest, last)`` where ``rest`` has the same structure.

    Returns
    -------
    tuple
        The items in their original (outermost-last) order.

    Raises
    ------
    ValueError
        If a non-empty link does not contain exactly two elements.
    TypeError
        If a link does not support ``len()``.
    """
    collected = []
    # Peel items off from the outside in; they come out last-to-first.
    while len(xs) != 0:
        xs, last = xs
        collected.append(last)
    collected.reverse()
    return tuple(collected)


def flatten_beam(fst, C, is_bytes: bool = True) -> List[str]:
    """Render every beam in ``C`` as a plain string.

    Each beam is indexed with ``YS`` to obtain its symbol sequence. A symbol
    is kept verbatim when it is one of the marker glyphs ``⭑``, ``🦄`` or
    ``🙈``; any other symbol is converted with ``chr(int(symbol))``.

    Parameters
    ----------
    fst
        Unused by this function.
    C
        Collection of beams. An empty/falsy collection yields ``[]``.
    is_bytes
        Must be true; the non-byte path is not implemented.

    Returns
    -------
    list[str]
        One string per beam, in iteration order of ``C``.

    Raises
    ------
    NotImplementedError
        If the beams are tuples and ``is_bytes`` is false.
    """
    if not C:
        return []

    first_beam = list(C)[0]
    if not isinstance(first_beam, tuple):
        # NOTE(review): non-tuple beams are not handled; the function yields
        # None here despite the List[str] annotation -- confirm intent.
        return None

    if not is_bytes:
        raise NotImplementedError

    markers = ("⭑", "🦄", "🙈")
    rendered = []
    for beam in C:
        pieces = []
        for symbol in beam[YS]:
            pieces.append(symbol if symbol in markers else chr(int(symbol)))
        rendered.append("".join(pieces))
    return rendered

def _log(level: str, *msgs: Any) -> None:
    """Print one timestamped, ANSI-coloured log line to stdout.

    Parameters
    ----------
    level
        Severity name (case-insensitive). A falsy value suppresses the
        message entirely; a bare ``True`` is treated as ``"DEBUG"``.
        Names without a dedicated colour are printed in blue.
    *msgs
        Message parts; each is converted with ``str`` and the parts are
        joined with single spaces.
    """
    if not level:
        return

    # A boolean flag (e.g. ``verbose=True``) maps to the DEBUG level.
    label = "DEBUG" if isinstance(level, bool) else level

    # Slice off the last three digits: microseconds -> milliseconds.
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

    palette = {
        "DEBUG": Colors.CYAN,
        "INFO": Colors.GREEN,
        "WARNING": Colors.YELLOW,
        "ERROR": Colors.RED,
        "CRITICAL": Colors.RED + Colors.BOLD,
    }
    tag = label.upper()
    tint = palette.get(tag, Colors.BLUE)

    text = " ".join(map(str, msgs))
    print(
        f"| {Colors.BLUE}{stamp}{Colors.END} "
        f"| {tint}[{tag}]{Colors.END} "
        f"| {Colors.BOLD}{text}{Colors.END}"
    )


def _hash_cfg(cfg) -> str:
    """Return a short, deterministic fingerprint of a config-like object.

    ``None`` maps to the empty string. A ``str`` is hashed as-is; anything
    else is first serialised with ``json.dumps`` using sorted keys (so dict
    ordering does not matter) and ``str`` as the fallback encoder for values
    JSON cannot represent. The result is the 16-hex-digit BLAKE2b digest
    (8 bytes) of the UTF-8 encoded text.
    """
    if cfg is None:
        return ""

    if isinstance(cfg, str):
        text = cfg
    else:
        text = json.dumps(cfg, sort_keys=True, default=str)

    digest = hashlib.blake2b(text.encode(), digest_size=8)
    return digest.hexdigest()


def _find_duplicates(seq: List[str]) -> List[str]:
    """Return the elements that occur more than once in *seq*.

    Each duplicated element is reported exactly once, ordered by its first
    appearance in *seq*. (The previous implementation filtered on a
    comparison that was always true, so it returned every occurrence of
    every duplicate and did a linear ``index`` scan per element.)

    Parameters
    ----------
    seq
        Sequence of hashable items to inspect.

    Returns
    -------
    list[str]
        Duplicated items, de-duplicated, in first-occurrence order; empty
        when *seq* has no repeats.
    """
    # Counter preserves insertion order, which gives first-occurrence order.
    return [item for item, count in Counter(seq).items() if count > 1]


def is_decomposition_prefix_free(q_list: List[str], r_list: List[str]) -> bool:
    """Validate that a quotient/remainder decomposition is prefix free.

    Three properties are checked, in this order:

    1. Neither ``q_list`` nor ``r_list`` contains duplicate elements.
    2. No element of ``q_list`` is a prefix of a different element of
       ``q_list``.
    3. No element of ``q_list`` is a prefix of (or equal to) any element of
       ``r_list``.

    Parameters
    ----------
    q_list : list[str]
        Quotient strings.
    r_list : list[str]
        Remainder strings.

    Returns
    -------
    bool
        True if all checks pass.

    Raises
    ------
    AssertionError
        If any check fails. The exception is raised explicitly rather than
        through ``assert`` statements so the validation is not silently
        skipped when Python runs with ``-O``.
    """
    for label, items in (("q", q_list), ("r", r_list)):
        duplicates = [x for x, c in Counter(items).items() if c > 1]
        if duplicates:
            raise AssertionError(f"Duplicate elements in {label}: {duplicates}")

    q: Set[str] = set(q_list)
    r: Set[str] = set(r_list)

    # q must be prefix free with respect to itself.
    for a in q:
        for b in q:
            if a != b and b.startswith(a):
                raise AssertionError(
                    f"Prefix free error in q: element '{a}' is a prefix of '{b}'. "
                )

    # No remainder string may start with (or equal) a quotient string.
    for q_elem in q:
        for r_elem in r:
            if r_elem.startswith(q_elem):
                raise AssertionError(
                    f"Prefix free error in q: element '{q_elem}' is a prefix of r element '{r_elem}'. "
                )

    return True

def kl_divergence(p, q):
    """Compute the KL divergence K(p || q) of two vectors, in bits.

    Entries where ``p`` is zero contribute nothing and are skipped (the
    usual ``0 * log 0 = 0`` convention).

    NOTE: If ``q`` is 0.0 at any position where ``p`` is non-zero, the
    divergence is infinite (NumPy also emits a divide-by-zero warning).

    Parameters
    ----------
    p, q : array_like
        Probability vectors of identical shape. Any array-like (list,
        tuple, ndarray) is accepted.

    Returns
    -------
    float
        ``sum(p * (log(p) - log(q))) / log(2)`` over the support of ``p``.
    """
    # Accept plain sequences as well as ndarrays.
    p = np.asarray(p)
    q = np.asarray(q)

    support = p.nonzero()
    p_pos = p[support]
    q_pos = q[support]
    # Natural-log divergence divided by ln(2) converts nats to bits.
    return p_pos.dot(np.log(p_pos) - np.log(q_pos)) / np.log(2)

def get_jsd(res_prune, res_no_prune, verbose=True):
    """Compute the Jensen-Shannon divergence for paired log-distributions.

    The two arguments are iterated in lockstep; each yielded pair must be
    two mappings from outcome to log-probability. Outcomes missing from one
    mapping are given log-probability ``-inf`` (probability 0). Pairs whose
    divergence evaluates to NaN are counted and left out of the result.

    NOTE(review): besides being iterated, both arguments are also indexed as
    ``arg['times'][i]`` to collect timings for the non-NaN pairs. A plain
    list of dicts does not support that and a plain dict does not iterate
    over distributions, so the expected container type cannot be determined
    from this function alone -- confirm against the callers.

    Parameters
    ----------
    res_prune, res_no_prune
        Results of the two runs being compared (see the note above).
    verbose : bool
        When true, print the NaN count, the mean JSD and the mean timings.

    Returns
    -------
    numpy.ndarray
        One JSD value per non-NaN pair, in input order.
    """
    divergences = []
    times_prune = []
    times_no_prune = []
    nan_count = 0

    for idx, (P_log, Q_log) in enumerate(zip(res_prune, res_no_prune)):
        # Align both distributions on the union of their outcomes.
        support = sorted(set(P_log) | set(Q_log))
        p = np.exp(np.array([P_log.get(x, -np.inf) for x in support]))
        q = np.exp(np.array([Q_log.get(x, -np.inf) for x in support]))

        mixture = 0.5 * (p + q)
        jsd = 0.5 * (kl_divergence(p, mixture) + kl_divergence(q, mixture))

        if np.isnan(jsd):
            nan_count += 1
            continue

        divergences.append(jsd)
        times_prune.append(res_prune['times'][idx])
        times_no_prune.append(res_no_prune['times'][idx])

    divergences = np.array(divergences)
    if verbose:
        total = len(divergences) + nan_count
        print(f"Number of NaNs: {nan_count} out of {total}")
        print(f"JSD: {np.mean(divergences)}")
        print(f"Speed no prune: {np.mean(times_no_prune)}")
        print(f"Speed prune: {np.mean(times_prune)}")
    return divergences


def get_jsd_against_genlm(res_prune, genlm, verbose=True, atol=1e-6):
    """Compute per-pair Jensen-Shannon divergences with strict sanity checks.

    ``res_prune`` and ``genlm`` are iterated in lockstep; every pair must be
    two non-empty dicts mapping outcome to log-probability. Outcomes missing
    from one dict get log-probability ``-inf``. Both distributions must sum
    to 1 within ``atol``.

    Parameters
    ----------
    res_prune, genlm
        Iterables of ``dict`` (outcome -> log-probability).
    verbose : bool
        When true, print the mean JSD.
    atol : float
        Absolute tolerance for the normalisation check.

    Returns
    -------
    numpy.ndarray
        One JSD value per pair; tiny negative round-off is clamped to 0.

    Raises
    ------
    AssertionError
        If a pair is not two non-empty dicts or is not normalised.
    ValueError
        If a computed JSD is non-finite or below ``-1e-12``.
    """
    negative_slack = 1e-12  # round-off allowance below zero
    scores = []

    for idx, (P_log, Q_log) in enumerate(zip(res_prune, genlm)):
        assert isinstance(P_log, dict) and isinstance(Q_log, dict), \
            f"Pair {idx}: inputs must be dicts."
        assert P_log, f"Pair {idx}: P_log is empty."
        assert Q_log, f"Pair {idx}: Q_log is empty."

        # Align both distributions on the union of their outcomes.
        support = sorted(set(P_log) | set(Q_log))
        p = np.exp(np.array([P_log.get(x, -np.inf) for x in support]))
        q = np.exp(np.array([Q_log.get(x, -np.inf) for x in support]))

        assert np.isclose(p.sum(), 1.0, atol=atol), \
           f"Pair {idx}: P not normalised (sum = {p.sum():.6g})."
        assert np.isclose(q.sum(), 1.0, atol=atol), \
           f"Pair {idx}: Q not normalised (sum = {q.sum():.6g})."

        mixture = 0.5 * (p + q)
        assert np.all((p == 0) | (mixture > 0)), f"Pair {idx}: mixture has zero where P>0."
        assert np.all((q == 0) | (mixture > 0)), f"Pair {idx}: mixture has zero where Q>0."

        jsd = 0.5 * (kl_divergence(p, mixture) + kl_divergence(q, mixture))
        if not np.isfinite(jsd) or jsd < -negative_slack:
            raise ValueError(f"Pair {idx}: JSD invalid ({jsd})")
        scores.append(max(jsd, 0.0))

    scores = np.array(scores)
    if verbose:
        print(f"JSD: {np.mean(scores)}")
    return scores