from decimal import Decimal
from copy import deepcopy
import logging
from collections import Counter
from decimal import getcontext
import numpy as np
from tqdm.auto import tqdm
from sklearn.cluster import KMeans
import scipy.stats
import torch
from torch.optim import SGD, Adam
import torch.nn as nn


def get_kmeans_symbols_and_codebook(vec, levels, codebook_dtype):
    """Quantize ``vec`` by clustering its entries with k-means.

    Args:
        vec: Array of values; it is reshaped to a single column so that
            every entry is treated as one 1-D sample.
        levels: Number of clusters requested from ``KMeans``.
        codebook_dtype: dtype the cluster centers are cast to.

    Returns:
        ``(symbols, codebook)`` where ``symbols`` are the fitted cluster
        labels (one per entry of ``vec``) and ``codebook`` is the first
        (only) column of the cluster centers, cast to ``codebook_dtype``.
    """
    samples = vec.reshape(-1, 1)
    model = KMeans(n_clusters=levels)
    model.fit(samples)
    centers = model.cluster_centers_.astype(codebook_dtype)
    return model.labels_, centers[:, 0]


def get_random_symbols_and_codebook(vec, levels, codebook_dtype):
    """Quantize ``vec`` onto ``levels`` equal-width bins over a symmetric range.

    The bin edges span ``[-m - 1e-6, m + 1e-6]`` where ``m`` is the larger of
    ``max(vec)`` and ``|min(vec)|``; the small margin keeps every entry
    strictly inside the outermost edges.

    Args:
        vec: NumPy array of values to quantize (boolean-mask indexing is used).
        levels: Number of bins / codebook entries.
        codebook_dtype: dtype of the returned codebook.

    Returns:
        ``(symbols, codebook)``: ``symbols`` holds the bin index of each entry
        of ``vec``; ``codebook[i]`` is the mean of the entries that fell into
        bin ``i``, or the bin's lower edge when the bin is empty.
    """
    bound = max(np.max(vec), np.abs(np.min(vec)))
    edges = np.linspace(-bound - 1e-6, bound + 1e-6, levels + 1)
    bin_ids = np.digitize(vec, edges) - 1

    def _centroid(idx):
        # Empty bins fall back to their lower edge so the codebook always
        # has exactly ``levels`` entries.
        members = vec[bin_ids == idx]
        return np.mean(members) if len(members) > 0 else edges[idx]

    codebook = np.array(
        [_centroid(idx) for idx in range(levels)], dtype=codebook_dtype
    )
    return np.array(bin_ids), codebook

def do_arithmetic_encoding(symbols, probabilities, levels):
    """Return the (estimated or actual) arithmetic-coded size of ``symbols``.

    For more than 10**4 symbols the coder is not run; the size is estimated
    as ``ceil(len(symbols) * entropy) + 1`` (a NumPy float). Otherwise the
    decimal precision is raised and ``encode`` is executed, and the size is
    the length of its result (an int).

    Args:
        symbols: Sequence of symbol indices.
        probabilities: Per-symbol probabilities; used for the entropy
            estimate and forwarded to ``encode``.
        levels: Alphabet size, used only to choose the decimal precision.

    Returns:
        ``(symbols, coded_symbols_size)``; ``symbols`` is returned unchanged.

    Note:
        The small-input branch sets ``decimal.getcontext().prec`` and does
        not restore it. ``encode`` is defined outside this block; presumably
        it relies on that precision -- confirm before scoping the change
        with ``localcontext``.
    """
    entropy_est = scipy.stats.entropy(probabilities, base=2)
    logging.info(f"Entropy: {entropy_est:.2f} bits")
    if len(symbols) > int(1e4):
        # Too large to run the exact coder: use the entropy bound instead.
        coded_symbols_size = np.ceil(len(symbols) * entropy_est) + 1.
    else:
        digits = int(1.1 * np.log10(levels) * len(symbols))
        # ``prec`` must be >= 1; the product truncates to 0 for very short
        # inputs or ``levels == 1`` and would otherwise raise ValueError.
        getcontext().prec = max(1, digits)
        coded_symbols_size = len(encode(symbols, probabilities))
    return symbols, coded_symbols_size

def do_huffman_encoding(vec):
    """Build a Huffman code for the symbols in ``vec`` and measure its length.

    Each symbol is keyed by its ``str()`` form. Frequencies are counted per
    whole symbol (not per character), so symbols whose string form has more
    than one character (e.g. ``10`` or ``-1``) get their own code.

    Args:
        vec: Sequence of symbols.

    Returns:
        ``(encoding, coded_symbols_len)``: ``encoding`` is whatever
        ``huffman_code_tree`` returns (indexed here by the symbol's string
        form), and ``coded_symbols_len`` is the summed code length over all
        entries of ``vec``.
    """
    keys = [str(symbol) for symbol in vec]
    # (key, count) pairs, most frequent first; the sort is stable so ties
    # keep first-appearance order.
    freq = sorted(Counter(keys).items(), key=lambda x: x[1], reverse=True)
    node = make_tree(freq)
    encoding = huffman_code_tree(node)

    coded_symbols_len = sum(len(encoding[key]) for key in keys)
    return encoding, coded_symbols_len

def get_message_len(coded_symbols_size, codebook, max_count):
    """Return the total message length in bits.

    The total is the coded payload plus the cost of transmitting the
    codebook and one count per codebook entry.

    Args:
        coded_symbols_size: Size of the encoded symbol stream, in bits.
        codebook: NumPy array of codebook values; its dtype determines the
            bits charged per entry (16 for float16, 32 for float32, 64 for
            float64, ...).
        max_count: Upper bound on a symbol count; each count is charged
            ``ceil(log2(max_count))`` bits.

    Returns:
        ``coded_symbols_size + codebook_bits + probability_bits``.
    """
    # Derive the per-entry width from the dtype instead of assuming that
    # every non-float16 codebook is 32 bits wide.
    codebook_bits_size = codebook.dtype.itemsize * 8
    probability_bits = int(np.ceil(np.log2(max_count)) * len(codebook))
    codebook_bits = len(codebook) * codebook_bits_size
    summary = f"encoding {coded_symbols_size}, codebook {codebook_bits} probs {probability_bits}"
    logging.info(summary)
    message_len = coded_symbols_size + codebook_bits + probability_bits
    return message_len

def quantize_vector(
    vec, levels=2**2 + 1, use_kmeans=False, encoding_type="arithmetic"
):
    """Quantize ``vec`` to ``levels`` values and compute its coded length.

    Args:
        vec: 1-D NumPy array of values to quantize.
        levels: Number of quantization levels (codebook entries).
        use_kmeans: If true, build the codebook with k-means; otherwise use
            equal-width bins over a symmetric range.
        encoding_type: ``"arithmetic"`` or ``"huff"``.

    Returns:
        ``(decoded_vec, message_len)``: ``decoded_vec`` is a float64 array of
        length ``len(vec)`` where each entry is replaced by its codebook
        value; ``message_len`` is the total message length in bits as
        computed by ``get_message_len``.

    Raises:
        NotImplementedError: If ``encoding_type`` is not a supported value.
    """
    # The codebook is always stored in half precision.
    codebook_dtype = np.float16
    if use_kmeans:
        symbols, codebook = get_kmeans_symbols_and_codebook(vec, levels, codebook_dtype)
    else:
        symbols, codebook = get_random_symbols_and_codebook(vec, levels, codebook_dtype)

    logging.info(f"KMeans: {use_kmeans}, Levels: {levels}, Algorithm: {encoding_type}")
    # Empirical frequency of each symbol index.
    probabilities = np.array([np.mean(symbols == i) for i in range(levels)])
    logging.info(f"probs {probabilities}")

    if encoding_type == "arithmetic":
        _, coded_symbols_size = do_arithmetic_encoding(
            symbols, probabilities, levels
        )
    elif encoding_type == "huff":
        _, coded_symbols_size = do_huffman_encoding(symbols)
    else:
        # Actually raise: a bare ``NotImplementedError`` expression is a
        # no-op and would surface later as an unbound-variable error.
        raise NotImplementedError(
            f"Unsupported encoding_type: {encoding_type!r}"
        )

    # Reconstruct the quantized vector by looking each symbol up in the
    # codebook.
    decoded_vec = np.zeros(shape=(len(vec)))
    for k in range(len(codebook)):
        decoded_vec[symbols == k] = codebook[k]

    message_len = get_message_len(coded_symbols_size, codebook, len(symbols))
    logging.info(f"Message Len: {message_len}")
    return decoded_vec, message_len