import os
from tqdm import tqdm
import numpy as np
from scipy import stats
import random
from collections import Counter


def generate_stanford_test_data(watermark, temperature, n_samples, max_new_tokens, path: str, full_path: str):
    """Generate data for the Fixed-Sampling detection test.

    Samples completions of a fixed prompt with ``key_number=0`` and appends
    each decoded response to ``full_path`` as one record prefixed with the
    ``###NEW_RESPONSE###`` marker. Records already present in ``full_path``
    count towards ``n_samples``, so calling the function again after an
    interruption only generates the missing responses.

    Args:
        watermark: Object providing ``tokenizer.encode`` / ``tokenizer.decode``,
            ``model.device`` and ``generate_key(...)``; the value returned by
            ``generate_key`` must expose ``sequences``.
        temperature: Forwarded unchanged to ``watermark.generate_key``.
        n_samples: Total number of responses wanted in ``full_path``.
        max_new_tokens: Forwarded unchanged to ``watermark.generate_key``.
        path: Directory created (if missing) before the first write. An empty
            string means "no directory to create".
        full_path: File the responses are appended to.
    """
    marker = "###NEW_RESPONSE###"

    prompt_ids = watermark.tokenizer.encode(
        "This is the story of", return_tensors='pt', max_length=2048
    )
    prompt_ids = prompt_ids.to(watermark.model.device)

    # Count the records already on disk. A response may itself contain
    # newlines, so records are identified by the marker rather than by the
    # raw line count. The file is streamed instead of loaded as a whole.
    n_existing = 0
    if os.path.exists(full_path):
        with open(full_path) as f:
            n_existing = sum(1 for line in f if line.startswith(marker))

    n_missing = n_samples - n_existing

    # The output directory does not change between iterations, so create it
    # once up front (and only when something will actually be written).
    # exist_ok avoids a failure if the directory appears between the check
    # and the creation.
    if n_missing > 0 and path and not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

    for _ in tqdm(range(n_missing)):
        generation_output = watermark.generate_key(
            prompt_ids,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            key_number=0,
        )
        response = watermark.tokenizer.decode(generation_output.sequences[0])

        # Append after every generation so finished samples survive a crash.
        with open(full_path, "a") as f:
            f.write(marker + response + "\n")
            
def _rarefaction_curve(data, num_samples=1000, trials=500):

    max_samples = min(num_samples, len(data))
    unique_counts = np.zeros(max_samples)
    
    # Perform multiple trials to average the curve
    for _ in range(trials):
        np.random.shuffle(data)
        seen_sentences = set()
        cumulative_uniques = []

        for i in range(1, max_samples + 1):
            seen_sentences.add(data[i - 1])
            cumulative_uniques.append(len(seen_sentences))
        
        unique_counts += np.array(cumulative_uniques)

    unique_counts /= trials
    return unique_counts
            
            
def test_stanford(data, old: bool = False):
    """Return the p-value of the Fixed-Sampling test on ``data``.

    Builds the averaged rarefaction curve over all of ``data`` and compares
    it, with a Mann-Whitney U test, against the integer ramp
    ``0, 1, ..., len(curve) - 1``.

    Args:
        data: Collection of responses passed to ``_rarefaction_curve``.
        old: Not used by this function.

    Returns:
        The ``pvalue`` attribute of the ``scipy.stats.mannwhitneyu`` result.
    """
    curve = _rarefaction_curve(data, num_samples=len(data))
    # Reference sample: one value per point of the curve, increasing by one.
    reference = np.arange(len(curve))
    outcome = stats.mannwhitneyu(curve, reference)
    return outcome.pvalue