import hashlib
import random
import tiktoken
import numpy as np
from functools import lru_cache


# Shared module-level tiktoken encoder (the "cl100k_base" encoding), used by
# the token-based helpers in this module (cached_encode, truncate,
# sample_from_text).
tokenizer = tiktoken.get_encoding("cl100k_base")


@lru_cache(maxsize=1024)  # bounded so an unbounded stream of texts cannot leak memory
def cached_encode(text):
    """Encode ``text`` with the module tokenizer, memoizing recent results.

    The returned list of token ids is the cached object itself and is shared
    between calls with the same ``text``; callers must not mutate it.
    """
    token_ids = tokenizer.encode(text)
    return token_ids


def word_count(text):
    """Return the number of whitespace-separated words in ``text``.

    Splits on any run of whitespace (spaces, tabs, newlines), so repeated,
    leading, or trailing whitespace does not produce empty "words", and an
    empty or all-whitespace string counts as 0 words.
    """
    # str.split() with no separator collapses whitespace runs and drops
    # empty strings; split(" ") counted every extra space as a word and
    # treated newlines/tabs as part of words.
    return len(text.split())


def hash_string(text):
    """Return the hex-encoded SHA-256 digest of ``text`` encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()


def num_tokens(text):
    """Return the number of tokens in ``text``.

    Args:
        text: A string, or an iterable of strings (e.g. a list or a NumPy
            string array); nested iterables are handled recursively.

    Returns:
        An ``int`` for a single string; otherwise a NumPy array holding the
        token count of each element.
    """
    # isinstance rather than ``type(text) == str``: iterating a NumPy string
    # array yields numpy.str_ (a str subclass), which the exact-type check
    # rejected, so each element was then counted character by character.
    if isinstance(text, str):
        return len(cached_encode(text))

    return np.array([num_tokens(item) for item in text])


def truncate(text, max_length=8192, short_ok=True):
    """Truncate text to at most ``max_length`` tokens.

    Args:
        text: A string, or an iterable of strings (e.g. a list or a NumPy
            string array), in which case each element is truncated
            independently.
        max_length: Maximum number of tokens to keep; must be positive.
        short_ok: If False, print a warning to stdout for every string that
            already has fewer than ``max_length`` tokens.

    Returns:
        The input string unchanged if it fits, otherwise the decoded first
        ``max_length`` tokens; for an iterable input, a NumPy array of such
        strings.

    Raises:
        AssertionError: if ``max_length`` is not positive (note: ``assert``
            is skipped when Python runs with ``-O``).
    """
    assert max_length > 0

    # isinstance so str subclasses such as numpy.str_ (the element type when
    # iterating a NumPy string array) are treated as strings, not iterated.
    if isinstance(text, str):
        encoding = cached_encode(text)
        n_tokens = len(encoding)

        if not short_ok and n_tokens < max_length:
            print("WARNING: text is too short to be truncated")

        # Already fits: return the original text rather than a
        # decode(encode(text)) round trip.
        if n_tokens <= max_length:
            return text

        return tokenizer.decode(encoding[:max_length])

    # Forward short_ok so the warning setting applies to every element.
    return np.array([truncate(t, max_length, short_ok) for t in text])

def sample_from_text(text, pct_size):
    """Return a random contiguous span covering ``pct_size`` of the text's tokens.

    The span length is ``int(n_tokens * pct_size)`` tokens, and its start is
    drawn uniformly with the module-level ``random`` generator, so results
    depend on ``random.seed``.

    Args:
        text: A string, or an iterable of strings (e.g. a list or a NumPy
            string array), in which case each element is sampled
            independently.
        pct_size: Fraction of the tokens to keep, in ``[0, 1]``.

    Returns:
        The decoded token span, or a NumPy array of spans for an iterable
        input.

    Raises:
        ValueError: if ``pct_size`` is outside ``[0, 1]``; such values
            previously produced a negative start index and a meaningless
            slice.
    """
    if not 0 <= pct_size <= 1:
        raise ValueError(f"pct_size must be in [0, 1], got {pct_size!r}")

    # isinstance so numpy.str_ elements of a NumPy string array are sampled
    # as strings rather than iterated character by character.
    if isinstance(text, str):
        # Encode once and reuse it for both the length and the slice.
        tokens = cached_encode(text)
        n_tokens = len(tokens)

        start = int(n_tokens * random.uniform(0, 1 - pct_size))
        size = int(n_tokens * pct_size)
        # NOTE(review): slicing at token boundaries may cut a multi-byte
        # character; tiktoken's decode presumably substitutes a replacement
        # character at the edges -- confirm if exact text matters.
        return tokenizer.decode(tokens[start:start + size])

    return np.array([sample_from_text(t, pct_size) for t in text])