
import sys

import numpy as np
import torch

from watermark.mersenne import mersenne_rng

import pyximport
pyximport.install(reload_support=True, language_level=sys.version_info[0],
                  setup_args={"include_dirs":np.get_include()})
from watermark.levenshtein import levenshtein

def exp_sampling(probs, u):
    """Select, for each row, the index that maximises ``u ** (1 / probs)``.

    Args:
        probs: tensor of probabilities; the argmax is taken over dimension 1.
        u: tensor of random values, broadcastable against ``probs``.

    Returns:
        Integer tensor holding the argmax index along dimension 1, with a
        trailing singleton dimension appended.
    """
    exponents = 1 / probs
    scores = torch.pow(u, exponents)
    winners = scores.argmax(dim=1)
    return winners[..., None]

def generate_single(model, prompt, vocab_size, n, m, key):
    """Append ``m`` tokens to ``prompt``, sampling each one with ``exp_sampling``.

    A key table ``xi`` of shape ``(n, vocab_size)`` is filled from
    ``mersenne_rng(key)``. Step ``i`` uses row ``(shift + i) % n`` of that
    table, where ``shift`` is a random starting row.

    Args:
        model: callable exposing ``.device`` and returning an object with a
            ``.logits`` attribute (presumably a Hugging Face causal LM --
            confirm against callers).
        prompt: tensor of token ids; assumed to be shaped ``(1, length)`` so
            it lines up with the single key row used per step -- TODO confirm.
        vocab_size: number of leading logit columns considered when sampling.
        n: number of rows in the key table.
        m: number of tokens to generate.
        key: seed handed to ``mersenne_rng``.

    Returns:
        CPU tensor holding the prompt followed by the ``m`` generated tokens.
    """
    rng = mersenne_rng(key)
    xi = torch.tensor([rng.rand() for _ in range(n * vocab_size)]).view(n, vocab_size)
    # The starting row comes from torch's global RNG, not from ``key``.
    # It has shape (1,), so ``xi[(shift + i) % n, :]`` keeps a leading
    # dimension of size 1.
    shift = torch.randint(n, (1,))

    inputs = prompt.to(model.device)
    for i in range(m):
        # The whole sequence is fed to the model again on every step.
        with torch.no_grad():
            output = model(inputs)

        # Only the last position and the first ``vocab_size`` logits are used.
        probs = torch.nn.functional.softmax(output.logits[:, -1, :vocab_size], dim=-1).cpu()
        token = exp_sampling(probs, xi[(shift + i) % n, :]).to(model.device)
        inputs = torch.cat([inputs, token], dim=-1)

    return inputs.detach().cpu()

def transform_key_func(generator, n, vocab_size, eff_vocab_size=None):
    """Draw a ``(xi, pi)`` key pair from ``generator``.

    Args:
        generator: ``torch.Generator`` supplying the randomness.
        n: number of rows of ``xi``.
        vocab_size: length of the permutation ``pi``.
        eff_vocab_size: accepted but not used by this function.

    Returns:
        Tuple ``(xi, pi)`` where ``xi`` is an ``(n, 1)`` tensor of uniform
        samples in ``[0, 1)`` and ``pi`` is a random permutation of
        ``range(vocab_size)``.
    """
    # The permutation is drawn first; swapping the two draws would change
    # the values produced from a given generator state.
    permutation = torch.randperm(vocab_size, generator=generator)
    uniforms = torch.rand(n, 1, generator=generator)
    return uniforms, permutation

def transform_sampling(probs, pi, xi):
    """Sample tokens by inverting the CDF of ``probs`` taken in ``pi`` order.

    Each row of ``probs`` is reordered by the matching row of ``pi`` and
    cumulatively summed. ``xi`` is then located in that CDF, and the position
    found is mapped back through ``pi``.

    Args:
        probs: 2-D tensor of probabilities, one distribution per row.
        pi: 2-D integer tensor of column indices into ``probs``, one ordering
            per row.
        xi: 2-D tensor of values to look up in each row's CDF.

    Returns:
        Integer tensor with the same shape as ``xi`` containing entries of
        ``pi``.
    """
    cdf = torch.cumsum(torch.gather(probs, 1, pi), 1)
    # ``searchsorted`` returns ``cdf.shape[1]`` when ``xi`` is larger than the
    # last CDF entry. That happens when rounding leaves the total just below
    # 1, or when ``probs`` does not sum to 1. The index would then be out of
    # range for the gather below, so clamp it to the last valid position.
    positions = torch.searchsorted(cdf, xi).clamp(max=cdf.shape[1] - 1)
    return torch.gather(pi, 1, positions)

def gumbel_sampling(probs,pi,xi):
    """Return, per row, the argmax of ``xi ** (1 / probs[pi])``.

    Args:
        probs: 2-D tensor of probabilities.
        pi: 2-D integer tensor of column indices used to gather from ``probs``.
        xi: tensor of random values, broadcastable against the gathered
            probabilities.

    Returns:
        Integer tensor of argmax positions along dimension 1 with a trailing
        singleton dimension. The position refers to the gathered (``pi``)
        order and is not mapped back through ``pi``.
    """
    permuted = torch.gather(probs, 1, pi)
    scores = torch.pow(xi, 1/permuted)
    best = scores.argmax(dim=1)
    return best[..., None]

def gumbel_key_func(generator, n, vocab_size, eff_vocab_size=None):
    """Draw a ``(xi, pi)`` key pair whose permutation is the identity.

    Args:
        generator: ``torch.Generator`` supplying the randomness for ``xi``.
        n: number of rows of ``xi``.
        vocab_size: width used when ``eff_vocab_size`` is ``None``.
        eff_vocab_size: optional width that overrides ``vocab_size``.

    Returns:
        Tuple ``(xi, pi)`` where ``xi`` is an ``(n, width)`` tensor of uniform
        samples in ``[0, 1)`` and ``pi`` is ``torch.arange(width)``.
    """
    width = vocab_size if eff_vocab_size is None else eff_vocab_size

    identity = torch.arange(width)
    uniforms = torch.rand(n, width, generator=generator)
    return uniforms, identity

def generate_batch(model, prompts, vocab_size, n, m, seeds, random_offset=True):
    """Append ``m`` tokens to every prompt, each row sampled with its own key.

    One ``(xi, pi)`` key is built per entry of ``seeds`` with
    ``gumbel_key_func``. At step ``i``, row ``b`` of the batch is sampled by
    ``gumbel_sampling`` using key row ``(offset[b] + i) % n``.

    Args:
        model: callable exposing ``.device`` and returning an object with
            ``.logits`` and ``.past_key_values`` (presumably a Hugging Face
            causal LM -- confirm against callers).
        prompts: 2-D tensor of token ids, one prompt per row. No padding mask
            is supplied for the first forward pass.
        vocab_size: width of each key passed to ``gumbel_key_func``.
        n: number of rows in each key.
        m: number of tokens to generate.
        seeds: iterable of integer-convertible seeds; expected to hold one
            seed per prompt.
        random_offset: when true each row starts at a random key row,
            otherwise every row starts at row 0.

    Returns:
        CPU tensor holding the prompts followed by the generated tokens.
    """
    batch_size = len(prompts)

    # A single generator is re-seeded before each key is drawn.
    generator = torch.Generator()
    xi_per_seed, pi_per_seed = [], []
    for seed in seeds:
        generator.manual_seed(int(seed))
        key_xi, key_pi = gumbel_key_func(generator, n, vocab_size)
        xi_per_seed.append(key_xi)
        pi_per_seed.append(key_pi)
    xis = torch.stack(xi_per_seed)
    pis = torch.stack(pi_per_seed)

    # deliberately not controlling this randomness with the generator
    offset = (
        torch.randint(n, size=(batch_size,))
        if random_offset
        else torch.zeros(size=(batch_size,), dtype=torch.int64)
    )
    start_rows = offset.squeeze()
    batch_rows = torch.arange(batch_size)

    inputs = prompts.to(model.device)
    attn = torch.ones_like(inputs)
    past = None
    for step in range(m):
        with torch.no_grad():
            if not past:
                # No cache yet: run the model on the full sequence.
                output = model(inputs)
            else:
                # Cache available: feed only the newest token.
                output = model(inputs[:,-1:], past_key_values=past, attention_mask=attn)

        last_logits = output.logits[:,-1]
        probs = torch.nn.functional.softmax(last_logits, dim=-1).cpu()
        step_xi = xis[batch_rows, (start_rows+step)%n]
        tokens = gumbel_sampling(probs, pis, step_xi).to(model.device)
        inputs = torch.cat([inputs, tokens], dim=-1)

        past = output.past_key_values
        # Grow the mask by one column so it matches the sequence on the next call.
        extra_col = attn.new_ones((attn.shape[0], 1))
        attn = torch.cat([attn, extra_col], dim=-1)

    return inputs.detach().cpu()

def permutation_test(tokens,key,n,k,vocab_size,n_runs=100):
    """Estimate a p-value for ``tokens`` under the key derived from ``key``.

    The statistic from ``detect`` is computed once with the key table built
    from ``mersenne_rng(key)``. It is then recomputed ``n_runs`` times with
    freshly drawn random tables from NumPy's global RNG.

    Args:
        tokens: token sequence forwarded to ``detect``.
        key: seed handed to ``mersenne_rng``.
        n: number of rows in each key table.
        k: block length forwarded to ``detect``.
        vocab_size: number of columns in each key table.
        n_runs: number of random tables to compare against.

    Returns:
        ``(count + 1) / (n_runs + 1)``, where ``count`` is the number of
        random tables whose statistic is less than or equal to the observed
        one.
    """
    key_rng = mersenne_rng(key)
    draws = [key_rng.rand() for _ in range(n*vocab_size)]
    xi = np.array(draws, dtype=np.float32).reshape(n,vocab_size)
    observed = detect(tokens,n,k,xi)

    def null_statistic():
        # Statistic under a table that is unrelated to ``key``.
        xi_alternative = np.random.rand(n, vocab_size).astype(np.float32)
        return detect(tokens,n,k,xi_alternative)

    # assuming lower test values indicate presence of watermark
    as_extreme = [null_statistic() <= observed for _ in range(n_runs)]

    return (sum(as_extreme)+1.0)/(n_runs+1.0)


def detect(tokens,n,k,xi,gamma=0.0):
    """Return the smallest ``levenshtein`` cost over all token/key block pairs.

    Every length-``k`` window of ``tokens`` is compared with every length-``k``
    run of consecutive rows of ``xi``. Row runs wrap around the end of ``xi``.
    The minimum cost found is returned.

    Args:
        tokens: token sequence supporting ``len`` and slicing.
        n: ignored. The number of key rows is always taken from ``len(xi)``;
            the parameter is kept for interface compatibility.
        k: block length.
        xi: key table, indexed by row with an integer array (a NumPy array in
            the callers visible in this file).
        gamma: forwarded unchanged to ``levenshtein``.

    Returns:
        The minimum entry of the cost matrix.

    Raises:
        ValueError: if ``tokens`` has fewer than ``k`` entries or ``xi`` has
            no rows, so that no block pair exists.
    """
    m = len(tokens)
    n_rows = len(xi)
    n_blocks = m-(k-1)
    if n_blocks <= 0:
        raise ValueError(
            f"need at least k={k} tokens to form one block, got {m} tokens"
        )
    if n_rows == 0:
        raise ValueError("xi must contain at least one row")

    # Loop-invariant offsets 0..k-1, shifted by j to pick a run of key rows.
    offsets = np.arange(k)

    A = np.empty((n_blocks,n_rows))
    for i in range(n_blocks):
        block = tokens[i:i+k]
        for j in range(n_rows):
            A[i][j] = levenshtein(block,xi[(j+offsets)%n_rows],gamma)

    return np.min(A)
