import numpy as np
import random
import string


def logit(p):
    """Return the log-odds ``log(p / (1 - p))`` of ``p``.

    A tiny epsilon is added to both numerator and denominator so that
    ``p == 0`` and ``p == 1`` yield large-but-finite values instead of
    ``-inf`` / ``inf`` or a division by zero. Works element-wise on arrays.
    """
    eps = 1e-10
    odds = (p + eps) / (1 - p + eps)
    return np.log(odds)


def sample_vectorized(
    prob: np.array,
    n_samples: int,
    use_laplace_smoothing: bool = False,
    return_samples: bool = False,
    bayesian: bool = False,
    stability: bool = True,
    unwatermark_prob: np.array = None,
    disable_watermark_every: int = 0,
):
    """Estimate each categorical row of ``prob`` from ``n_samples`` multinomial draws.

    ``prob`` is flattened to ``(-1, n_choice)`` where ``n_choice`` is the size of
    its last axis; every row is sampled independently with
    ``np.random.multinomial``.

    Args:
        prob: Array of probabilities whose last axis indexes the choices.
        n_samples: Number of draws per row.
        use_laplace_smoothing: Accepted for API compatibility; currently ignored.
        return_samples: If True, return raw counts instead of frequencies.
        bayesian: Accepted for API compatibility; currently ignored.
        stability: If True, clip negative entries to 0 and renormalize each row
            (of ``prob`` and, if given, ``unwatermark_prob``) before sampling.
        unwatermark_prob: Alternative distribution with the same number of
            choices per row. Required when ``disable_watermark_every > 0``.
        disable_watermark_every: If > 0, ``n_samples // disable_watermark_every``
            draws per row come from ``unwatermark_prob`` and the remaining draws
            from ``prob``.

    Returns:
        Array with the original shape of ``prob``: the counts if
        ``return_samples`` is True, otherwise ``counts / n_samples``.

    Raises:
        ValueError: If ``disable_watermark_every > 0`` but ``unwatermark_prob``
            is None.
    """
    shape = prob.shape
    n_choice = shape[-1]

    prob = prob.reshape(-1, n_choice)
    # Counts are allocated before the float64 cast, so their dtype follows the input's.
    sample_counts = np.zeros_like(prob)
    prob = np.asarray(prob).astype("float64")  # Important for fixing casting issues

    mix = disable_watermark_every > 0
    if mix and unwatermark_prob is None:
        raise ValueError(
            "unwatermark_prob is required when disable_watermark_every > 0"
        )

    if unwatermark_prob is not None:
        # Cast regardless of `stability`, so both distributions are handled alike.
        unwatermark_prob = (
            np.asarray(unwatermark_prob).astype("float64").reshape(-1, n_choice)
        )

    if stability:
        # multinomial rejects rows whose probabilities sum to more than 1,
        # so clip negatives and renormalize.
        prob = np.maximum(prob, 0)
        prob = prob / np.sum(prob, axis=1, keepdims=True)
        if unwatermark_prob is not None:
            unwatermark_prob = np.maximum(unwatermark_prob, 0)
            unwatermark_prob = unwatermark_prob / np.sum(
                unwatermark_prob, axis=1, keepdims=True
            )

    if mix:
        # Loop-invariant split. NOTE(review): this uses floor(n_samples / every)
        # unwatermarked draws, whereas stanford_sampling swaps out draws with
        # j % every == 0, i.e. ceil(n_samples / every); the two differ when
        # n_samples is not a multiple of disable_watermark_every.
        n_unwatermark = n_samples // disable_watermark_every
        n_watermark = n_samples - n_unwatermark

    for j in range(prob.shape[0]):
        if mix:
            # Watermarked draw first, then unwatermarked, per row.
            watermarked = np.random.multinomial(n_watermark, prob[j])
            unwatermarked = np.random.multinomial(n_unwatermark, unwatermark_prob[j])
            sample_counts[j] = watermarked + unwatermarked
        else:
            sample_counts[j] = np.random.multinomial(n_samples, prob[j])

    if return_samples:
        return sample_counts.reshape(shape)

    return (sample_counts / n_samples).reshape(shape)


def stanford_sampling(
    prob: np.array,
    n_samples: int,
    key_size: int,
    return_samples: bool = False,
    bayesian: bool = False,
    unwatermark_prob: np.array = None,
    disable_watermark_every: int = 0,
):
    """Estimate each row of ``prob`` by keyed ``argmax(xi ** (1 / p))`` sampling.

    A key matrix of shape ``(n_choice, key_size)`` is drawn from U(0, 1) once.
    Each draw picks a random key column ``xi`` and selects
    ``argmax(xi ** (1 / p))``.

    Args:
        prob: Array of probabilities whose last axis indexes the choices.
        n_samples: Number of draws per row.
        key_size: Number of key columns to choose from on each draw.
        return_samples: If True, return the per-row counts instead of estimates.
        bayesian: If True, the estimate for each row is one draw from a
            Dirichlet(1 + counts) posterior instead of ``counts / n_samples``.
        unwatermark_prob: Alternative distribution with the same number of
            choices per row. It is clipped at 0 and row-normalized. Required when
            ``disable_watermark_every > 0``.
        disable_watermark_every: If > 0, draw ``j`` with
            ``j % disable_watermark_every == 0`` is sampled from
            ``unwatermark_prob`` via ``np.random.choice`` instead of the key.

    Returns:
        Array with the original shape of ``prob``: the counts if
        ``return_samples`` is True, otherwise the estimated probabilities.

    Raises:
        ValueError: If ``disable_watermark_every > 0`` but ``unwatermark_prob``
            is None.
    """
    shape = prob.shape
    n_choice = shape[-1]

    # float64 working copy: with an integer `prob`, zeros_like(prob) would
    # truncate the frequencies / Dirichlet draws written into the estimates.
    prob = np.asarray(prob, dtype=np.float64).reshape(-1, n_choice)
    n_rows = prob.shape[0]

    mix = disable_watermark_every > 0
    if unwatermark_prob is not None:
        unwatermark_prob = np.asarray(unwatermark_prob, dtype=np.float64).reshape(
            -1, n_choice
        )
        unwatermark_prob = np.maximum(unwatermark_prob, 0)
        unwatermark_prob = unwatermark_prob / np.sum(
            unwatermark_prob, axis=1, keepdims=True
        )
    elif mix:
        raise ValueError(
            "unwatermark_prob is required when disable_watermark_every > 0"
        )

    key = np.random.uniform(0, 1, size=(n_choice, key_size))

    hat_p = np.zeros((n_rows, n_choice))
    # Counts for every row (previously only the last row's counts were returned).
    all_counts = np.zeros((n_rows, n_choice))

    for i in range(n_rows):
        # p == 0 gives 1/0 -> inf, and xi ** inf == 0 for xi < 1, so such
        # choices score 0. Computed once per row with the warning silenced.
        with np.errstate(divide="ignore"):
            inv_prob = 1 / prob[i, :]

        sample_counts = np.zeros(n_choice)
        for j in range(n_samples):
            # A key column is drawn on every iteration, including unwatermarked draws.
            xi = key[:, np.random.randint(0, key_size)]

            if mix and j % disable_watermark_every == 0:
                sample = np.random.choice(n_choice, p=unwatermark_prob[i, :])
            else:
                sample = np.argmax(xi ** inv_prob)
            sample_counts[sample] += 1

        all_counts[i] = sample_counts

        if bayesian:
            # Uniform Dirichlet(1, ..., 1) prior updated with the counts; one posterior draw.
            alpha_posterior = np.ones(n_choice) + sample_counts
            hat_p[i, :] = np.random.dirichlet(alpha_posterior)
        else:
            hat_p[i, :] = sample_counts / n_samples

    if return_samples:
        return all_counts.reshape(shape)

    return hat_p.reshape(shape)


def _dip_reweight(probs, permutation, alpha):
    permuted_probs = probs[:, permutation]

    cdf = np.cumsum(permuted_probs, axis=-1)

    F = np.maximum(cdf - alpha, 0) + np.maximum(cdf - (1 - alpha), 0)
    F = F[0]
    F = np.concatenate([np.array([0]), F], axis=-1)
    reweighted_probs = F[1:] - F[:-1]
    reweighted_probs = np.maximum(
        reweighted_probs, 0
    )  # Ensure non-negativity due to numerical errors

    reweighted_probs = reweighted_probs[np.argsort(permutation)]

    return reweighted_probs.reshape(1, -1)


def dip_reweight(prob: np.array, alpha: float, same_permutation: bool = False):
    """Apply DiP reweighting to every row of a 2-D probability matrix.

    Each row is first normalized to sum to 1, then reweighted by
    ``_dip_reweight`` under a random permutation of ``range(m)`` drawn with
    ``np.random.shuffle``. The resulting rows are renormalized to sum to 1.

    Args:
        prob: 2-D array of shape ``(n, m)`` with non-negative row weights.
        alpha: Reweighting parameter forwarded to ``_dip_reweight``.
        same_permutation: If True, one permutation is shared by all rows;
            otherwise each row gets its own freshly shuffled permutation.

    Returns:
        float64 array of shape ``(n, m)``.
    """
    n, m = prob.shape
    # Explicit float64 buffer: zeros_like(prob) on an integer input would
    # truncate the fractional reweighted probabilities (possibly to all zeros,
    # which then turned into NaN after renormalization).
    p_reweighted = np.zeros((n, m), dtype=np.float64)

    prob = prob / prob.sum(axis=1, keepdims=True)

    if same_permutation:
        indices = np.arange(m)
        np.random.shuffle(indices)

    for i in range(n):
        if not same_permutation:
            indices = np.arange(m)
            np.random.shuffle(indices)

        p_reweighted[i] = _dip_reweight(prob[i].reshape(1, -1), indices, alpha)

    return p_reweighted / p_reweighted.sum(axis=1, keepdims=True)


def find_argmin(data):
    """Return the N-dimensional index of the smallest element of ``data``.

    Ties resolve to the first occurrence in C (row-major) order, as with
    ``np.argmin``. The result is a tuple with one integer per axis.
    """
    return np.unravel_index(np.argmin(data), data.shape)


def generate_random_prefix():
    """Return a random 5-character string containing at least one special character.

    One character is drawn from the special-character set and four from
    letters + digits + specials; the five are then shuffled. Uses the
    non-cryptographic ``random`` module, so this is not suitable for secrets.
    """
    specials = "!@#$%^&*()-_=+{}[];:<>,.?|`~"
    alphabet = string.ascii_letters + string.digits + specials

    chars = [random.choice(specials), *random.choices(alphabet, k=4)]
    random.shuffle(chars)
    return "".join(chars)


def get_2_random_fruits():
    """Return a list of two distinct fruit names chosen uniformly at random."""
    pool = (
        "apple banana cherry date elderberry fig grape honeydew kiwi lemon "
        "mango nectarine orange papaya quince raspberry strawberry tangerine "
        "ugli watermelon"
    ).split()
    return random.sample(pool, 2)
