
import spacy
import textstat
import nltk
import numpy as np
from collections import Counter
from scipy.stats import entropy
from nltk.corpus import wordnet as wn
from nltk.tokenize import word_tokenize
from nltk.probability import FreqDist
from nltk import pos_tag, bigrams, trigrams
import string

# Ensure necessary NLTK downloads
# These calls execute at import time, every time this module is imported.
# NOTE(review): newer NLTK releases may look up differently named resources
# (e.g. 'punkt_tab', 'averaged_perceptron_tagger_eng') -- confirm against the
# NLTK version this project pins.
nltk.download('punkt')  # presumably the tokenizer data behind word_tokenize -- confirm
nltk.download('wordnet')  # WordNet corpus, queried through wn.synsets
nltk.download('averaged_perceptron_tagger')  # presumably the model behind pos_tag -- confirm
nltk.download('brown')  # not referenced anywhere in the visible part of this file
nltk.download('stopwords')  # stopword lists read by stopword_ratio
nltk.download('cmudict')  # CMU Pronouncing Dictionary, loaded into pron_dict below


# 
# By using this file, you are agreeing to this product's EULA
#
# This product can be obtained in https://anonymous.4open.science/r/SAFE-ICLR
#
# Copyright ©2024-2025 XXXX-1
#


# Load spaCy model
# Loaded once at import time; the feature functions in this file call the
# shared `nlp` pipeline for entities, POS tags and dependency parses.
nlp = spacy.load("en_core_web_sm")

# Load CMU Pronouncing Dictionary for syllable counting
# NOTE(review): `pron_dict` is not referenced in the visible part of this
# file -- syllable_count delegates to textstat instead. Confirm whether any
# other module imports it before removing.
pron_dict = nltk.corpus.cmudict.dict()

def lexical_diversity(sentence):
    """Return the type-token ratio of *sentence*.

    The sentence is tokenized with ``word_tokenize``; the result is the
    number of distinct tokens divided by the total number of tokens.
    Tokens are compared case-sensitively. Returns 0 when there are no tokens.
    """
    tokens = word_tokenize(sentence)
    if not tokens:
        return 0
    distinct = set(tokens)
    return len(distinct) / len(tokens)


def get_named_entity(sentence):
    """Count the named entities in *sentence* that carry a selected label.

    Only entities whose spaCy label is one of the labels listed below are
    counted; every other entity type is ignored.
    """
    wanted_labels = ("PERSON", "ORG", "GPE", "LOC", "EVENT", "LAW", "WORK_OF_ART", "PRODUCT", "LANGUAGE")
    total = 0
    for entity in nlp(sentence).ents:
        if entity.label_ in wanted_labels:
            total += 1
    return total


def get_named_entity_density(sentence):
    """Return the number of named entities per token in *sentence*.

    All entity labels are counted. The denominator is the number of spaCy
    tokens in the parsed sentence; returns 0 when there are no tokens.
    """
    parsed = nlp(sentence)
    token_total = len(parsed)
    if token_total <= 0:
        return 0
    entity_total = sum(1 for _ in parsed.ents)
    return entity_total / token_total

def get_parse_tree_depth(sentence):
    """Return the depth of the dependency parse of *sentence*.

    The depth is the largest number of ancestors any token has in the spaCy
    dependency tree (a root token has 0 ancestors).

    Returns 0 when the parsed sentence contains no tokens (e.g. an empty
    string), matching the empty-input behaviour of the other feature
    functions in this file instead of raising ``ValueError`` from ``max``.
    """
    doc = nlp(sentence)
    # `default=0` covers the empty-document case, where max() has no items.
    return max((len(list(token.ancestors)) for token in doc), default=0)

def shannon_entropy(sentence):
    """Return the Shannon entropy (in bits) of the token distribution.

    The sentence is tokenized with ``word_tokenize``; each distinct token's
    relative frequency forms the distribution. Sentences with at most one
    token yield 0.
    """
    tokens = word_tokenize(sentence)
    n_tokens = len(tokens)
    if n_tokens <= 1:
        return 0
    counts = np.array(list(FreqDist(tokens).values()))
    relative_freqs = counts / n_tokens
    return entropy(relative_freqs, base=2)

def average_wordnet_synsets(sentence):
    """Return the mean number of WordNet synsets per content word.

    Content words are the spaCy tokens tagged NOUN, VERB or ADJ; each is
    looked up in WordNet by its surface text. Returns 0 when the sentence
    has no such tokens.
    """
    content_pos = {'NOUN', 'VERB', 'ADJ'}
    sense_total = 0
    word_total = 0
    for token in nlp(sentence):
        if token.pos_ in content_pos:
            sense_total += len(wn.synsets(token.text))
            word_total += 1
    if not word_total:
        return 0
    return sense_total / word_total

def noun_verb_ratio(sentence):
    """Return the ratio of NOUN-tagged to VERB-tagged tokens.

    Tags come from spaCy's coarse ``pos_`` attribute. Returns 0 when the
    sentence contains no VERB tokens.
    """
    tag_counts = Counter(token.pos_ for token in nlp(sentence))
    verbs = tag_counts["VERB"]
    if verbs == 0:
        return 0
    return tag_counts["NOUN"] / verbs

def stopword_ratio(sentence):
    """Return the fraction of tokens that are English stopwords.

    Tokens come from ``word_tokenize`` and are lower-cased before being
    checked against NLTK's English stopword list. Returns 0 when there are
    no tokens.
    """
    from nltk.corpus import stopwords

    english_stopwords = set(stopwords.words('english'))
    tokens = word_tokenize(sentence)
    if not tokens:
        return 0
    hits = [tok for tok in tokens if tok.lower() in english_stopwords]
    return len(hits) / len(tokens)

def punctuation_ratio(sentence):
    """Return the fraction of characters in *sentence* that are punctuation.

    A character counts as punctuation when it appears in
    ``string.punctuation`` (ASCII punctuation only). The denominator is the
    total character count, including whitespace. Returns 0 for an empty
    sentence.
    """
    total_chars = len(sentence)
    if total_chars <= 0:
        return 0
    marks = [ch for ch in sentence if ch in string.punctuation]
    return len(marks) / total_chars

def average_word_length(sentence):
    """Return the mean token length, in characters, of *sentence*.

    Tokens come from ``word_tokenize``, so punctuation tokens are included
    in the average. Returns 0 when there are no tokens.
    """
    tokens = word_tokenize(sentence)
    if not tokens:
        return 0
    lengths = list(map(len, tokens))
    return np.mean(lengths)

def syllable_count(sentence):
    """Return the syllable count of *sentence* as reported by textstat."""
    total_syllables = textstat.syllable_count(sentence)
    return total_syllables

def word_count(sentence):
    """Return the number of ``word_tokenize`` tokens in *sentence*."""
    tokens = word_tokenize(sentence)
    return len(tokens)

def unique_word_count(sentence):
    """Return the number of distinct ``word_tokenize`` tokens in *sentence*.

    Tokens are compared case-sensitively.
    """
    return len(frozenset(word_tokenize(sentence)))

def bigram_probability(sentence):
    """Return the number of bigrams in *sentence* divided by its token count.

    Despite the name, no language model is consulted: the value is simply
    the count of adjacent token pairs over the number of tokens. Returns 0
    when there are no tokens.
    """
    tokens = word_tokenize(sentence)
    if not tokens:
        return 0
    pair_total = sum(1 for _ in bigrams(tokens))
    return pair_total / len(tokens)

def trigram_probability(sentence):
    """Return the number of trigrams in *sentence* divided by its token count.

    Despite the name, no language model is consulted: the value is simply
    the count of adjacent token triples over the number of tokens. Returns 0
    when there are no tokens.
    """
    tokens = word_tokenize(sentence)
    if not tokens:
        return 0
    triple_total = sum(1 for _ in trigrams(tokens))
    return triple_total / len(tokens)

def pronoun_density(sentence):
    """Return the fraction of tokens tagged as pronouns.

    Tokens come from ``word_tokenize`` and are tagged with NLTK's
    ``pos_tag``; a token counts as a pronoun when its tag starts with
    ``PRP``. Returns 0 when there are no tokens.
    """
    tokens = word_tokenize(sentence)
    tagged = pos_tag(tokens)
    if not tokens:
        return 0
    pronoun_tags = [tag for _, tag in tagged if tag.startswith('PRP')]
    return len(pronoun_tags) / len(tokens)

def passive_voice_ratio(sentence):
    """Return the fraction of tokens whose dependency label is ``auxpass``.

    The numerator counts spaCy tokens labelled ``auxpass``; the denominator
    is the total token count (not the number of verbs). Returns 0 when the
    parsed sentence has no tokens.
    """
    parsed = nlp(sentence)
    token_total = len(parsed)
    if token_total <= 0:
        return 0
    passive_markers = [tok for tok in parsed if tok.dep_ == "auxpass"]
    return len(passive_markers) / token_total

def analyze_sentence(sentence):
    """Return the 24-element feature vector for *sentence*.

    Each extractor below is called once with the raw sentence, in order, and
    its result is stored at the index shown in the comment. The order is
    part of the contract: callers index the returned list positionally.
    """
    extractors = (
        lexical_diversity,                        #  0 --- distinct tokens / total tokens
        get_named_entity_density,                 #  1 --- named entities / spaCy tokens
        get_parse_tree_depth,                     #  2 --- dependency tree depth
        textstat.flesch_reading_ease,             #  3 --- Flesch reading-ease score
        shannon_entropy,                          #  4 --- entropy of the token distribution
        average_wordnet_synsets,                  #  5 --- mean WordNet synsets per noun/verb/adjective
        noun_verb_ratio,                          #  6 --- number of nouns / number of verbs
        stopword_ratio,                           #  7 --- stopwords / total tokens
        punctuation_ratio,                        #  8 --- punctuation characters / total characters
        average_word_length,                      #  9 --- mean characters per token
        syllable_count,                           # 10 --- total number of syllables
        word_count,                               # 11 --- total number of tokens
        unique_word_count,                        # 12 --- number of distinct tokens
        bigram_probability,                       # 13 --- bigram count / token count
        trigram_probability,                      # 14 --- trigram count / token count
        pronoun_density,                          # 15 --- pronouns / total tokens
        passive_voice_ratio,                      # 16 --- auxpass tokens / total tokens
        get_named_entity,                         # 17 --- count of entities with selected labels
        textstat.smog_index,                      # 18 --- SMOG index
        textstat.coleman_liau_index,              # 19 --- Coleman-Liau index
        textstat.automated_readability_index,     # 20 --- automated readability index
        textstat.dale_chall_readability_score,    # 21 --- Dale-Chall readability score
        textstat.linsear_write_formula,           # 22 --- Linsear Write formula
        textstat.gunning_fog,                     # 23 --- Gunning fog index
    )
    features = []
    for extract in extractors:
        features.append(extract(sentence))
    return features








def convertListOfSentences(X):
    """Return the ``analyze_sentence`` feature vector for every item of *X*.

    The output list has one entry per input sentence, in input order.
    """
    rows = []
    for text in X:
        rows.append(analyze_sentence(text))
    return rows


def extractPreAttributionDataset(ds):
    """Flatten *ds* into parallel feature and label lists.

    For every ``query`` in *ds*, each ``sample`` in ``query[1]`` contributes
    ``analyze_sentence(sample[0])`` to the first list and ``sample[2]`` to
    the second. The dataset is traversed exactly once, so one-shot iterables
    are supported.

    Returns:
        A tuple ``(X, Y)`` of two lists of equal length, in traversal order.
    """
    pairs = [
        (analyze_sentence(sample[0]), sample[2])
        for query in ds
        for sample in query[1]
    ]
    X = [features for features, _ in pairs]
    Y = [label for _, label in pairs]
    return X, Y