import pandas as pd
from textblob import TextBlob
from textstat import flesch_kincaid_grade
from sklearn.feature_extraction.text import CountVectorizer
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import LatentDirichletAllocation
import numpy as np
import re
from collections import Counter
from transformers import AutoTokenizer, pipeline
from scipy.spatial.distance import euclidean
import math
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

def is_only_stopwords(text):
    """Return True when every whitespace-separated word of `text` is a stop word.

    The text is lower-cased and split on whitespace, and each word is looked
    up in scikit-learn's ``ENGLISH_STOP_WORDS``. Punctuation is not stripped,
    so a word such as "the," does not count as a stop word.

    Text containing no words at all (empty or whitespace-only) also returns
    True, because there is no word that fails the check.
    """
    for token in text.lower().split():
        if token not in ENGLISH_STOP_WORDS:
            return False
    return True

def clean_dataset(dataset):
    """Filter out examples unusable for analysis.

    An example (a mapping read via ``.get``) is kept only if both its
    'text' and 'output' fields:
      * are ``str`` (a missing key defaults to '' and is then rejected as blank),
      * contain something other than whitespace, and
      * are not made up exclusively of English stop words.

    Returns:
        A new list holding the surviving example objects, in input order.
    """
    def _is_usable(record):
        question = record.get('text', '')
        answer = record.get('output', '')

        if not (isinstance(question, str) and isinstance(answer, str)):
            return False
        if not (question.strip() and answer.strip()):
            return False
        return not (is_only_stopwords(question) or is_only_stopwords(answer))

    return [record for record in dataset if _is_usable(record)]


# Function to calculate token count
def count_tokens(text, tokenizer):
    """Return how many token ids ``tokenizer.encode(text)`` produces.

    Whatever ``encode`` emits is counted as-is. Depending on the tokenizer's
    defaults, that may include special tokens (e.g. a BOS marker).
    """
    token_ids = tokenizer.encode(text)
    return len(token_ids)

# Function to calculate cosine similarity between two texts
def calculate_cosine_similarity(text1, text2, model):
    """Return the cosine similarity between the embeddings of two texts.

    Both texts are embedded with ``model.encode`` (text1 first). Each vector
    is wrapped as a one-row matrix for scikit-learn's ``cosine_similarity``,
    and the single entry of the resulting 1x1 matrix is returned.
    """
    vectors = [model.encode(t) for t in (text1, text2)]
    similarity_matrix = cosine_similarity([vectors[0]], [vectors[1]])
    return similarity_matrix[0][0]

# Function to calculate sentiment score using TextBlob
def calculate_sentiment(text):
    """Return TextBlob's sentiment polarity for `text`.

    Per TextBlob's documentation, polarity is a float in [-1.0, 1.0], where
    negative values indicate negative sentiment.
    """
    blob = TextBlob(text)
    return blob.sentiment.polarity

# Function to calculate readability score using Flesch-Kincaid Grade
def calculate_readability(text):
    """Return the Flesch-Kincaid grade level of `text`, as computed by textstat."""
    grade_level = flesch_kincaid_grade(text)
    return grade_level

# Function to calculate lexical diversity
def calculate_lexical_diversity(texts):
    """Return the corpus-level ratio of distinct words to total words.

    All texts are joined with spaces and split on whitespace, so the ratio is
    pooled over the whole collection rather than averaged per text. Matching
    is case-sensitive, and punctuation stays attached to words.

    Args:
        texts: Iterable of strings.

    Returns:
        ``len(unique words) / len(words)``. Returns ``float('nan')`` when the
        texts contain no words at all (e.g. an empty list or only whitespace),
        because the ratio is undefined. This matches how
        ``calculate_kl_divergence`` reports an uncomputable value, and avoids
        a ZeroDivisionError.
    """
    words = ' '.join(texts).split()
    if not words:
        return float('nan')
    return len(set(words)) / len(words)

# Function to calculate out-of-vocabulary (OOV) rate
def calculate_oov_rate(texts, tokenizer):
    """Return the fraction of token ids equal to the tokenizer's unknown-token id.

    Every text is encoded with ``tokenizer.encode``. The numerator counts ids
    equal to ``tokenizer.unk_token_id``; the denominator is the total number
    of token ids (tokens, not words).

    NOTE(review): if the tokenizer defines no unknown token
    (``unk_token_id is None``), no id compares equal and the rate is always
    0.0. Byte-level BPE tokenizers may behave this way, so confirm that the
    metric is meaningful for the tokenizer in use.

    Returns:
        The OOV rate as a float, or ``float('nan')`` when the texts produce no
        tokens at all (instead of raising ZeroDivisionError).
    """
    unk_id = tokenizer.unk_token_id  # loop-invariant; read once
    oov_count = 0
    total_tokens = 0
    for text in texts:
        token_ids = tokenizer.encode(text)
        oov_count += sum(1 for token_id in token_ids if token_id == unk_id)
        total_tokens += len(token_ids)
    if total_tokens == 0:
        return float('nan')
    return oov_count / total_tokens

# Function to perform topic modeling using LDA
def perform_lda(texts, n_topics=5, n_top_words=10):
    """Fit an LDA topic model on `texts` and describe each topic by its top words.

    Documents are turned into bag-of-words counts with English stop words
    removed, and ``LatentDirichletAllocation`` is fitted with a fixed
    ``random_state=42`` so that repeated runs are reproducible.

    Args:
        texts: Iterable of document strings.
        n_topics: Number of LDA components (topics) to fit.
        n_top_words: How many of the highest-weighted words to list per topic.

    Returns:
        A list of strings ``"Topic <i>: w1, w2, ..."``, one per topic. Words
        are ordered from highest to lowest weight in that topic.
    """
    vectorizer = CountVectorizer(stop_words='english')
    doc_term_counts = vectorizer.fit_transform(texts)
    lda = LatentDirichletAllocation(n_components=n_topics, random_state=42)
    lda.fit(doc_term_counts)

    # Build the vocabulary array once; the previous version rebuilt it for
    # every word of every topic.
    vocabulary = vectorizer.get_feature_names_out()

    topics = []
    for topic_idx, word_weights in enumerate(lda.components_):
        # argsort is ascending, so reverse it to put the most heavily weighted
        # words first.
        top_indices = word_weights.argsort()[::-1][:n_top_words]
        topic_words = [vocabulary[i] for i in top_indices]
        topics.append(f"Topic {topic_idx}: {', '.join(topic_words)}")
    return topics

# === ADDITIONAL METRICS ===

# Function to calculate Type-Token Ratio (TTR)
def calculate_type_token_ratio(text):
    """Return the type-token ratio (distinct words / total words) of one text.

    Words are whitespace-separated and compared case-sensitively, with
    punctuation left attached.

    Returns:
        The ratio as a float. Returns ``float('nan')`` when `text` contains no
        words (empty or whitespace-only), because the ratio is undefined there.
        This avoids a ZeroDivisionError.
    """
    words = text.split()
    if not words:
        return float('nan')
    return len(set(words)) / len(words)

# Function to calculate toxicity score using a pre-trained model
toxicity_model = pipeline("text-classification", model="unitary/toxic-bert")
def calculate_toxicity(text):
    """Return the toxicity score the ``unitary/toxic-bert`` pipeline gives `text`.

    The pipeline above is built once at import time and reused on every call.

    Long inputs are shortened first. The text is tokenized with the pipeline's
    own tokenizer, truncated to 512 tokens, and decoded back to a string with
    special tokens dropped. Only that clipped string is classified.

    NOTE(review): this returns ``result[0]['score']``, the score of the first
    prediction the pipeline returns. That is presumably the single top-ranked
    label, which is not necessarily a "toxic" label. Confirm this against the
    transformers version and the model's label set.
    """
    bert_tokenizer = toxicity_model.tokenizer
    clipped_ids = bert_tokenizer.encode(
        text, truncation=True, max_length=512, return_tensors="pt"
    )[0]
    clipped_text = bert_tokenizer.decode(clipped_ids, skip_special_tokens=True)
    first_prediction = toxicity_model(clipped_text)[0]
    return first_prediction['score']

# Function to calculate Euclidean distance between embeddings
def calculate_euclidean_distance(text1, text2, model):
    """Return the Euclidean (L2) distance between the embeddings of two texts.

    Both texts are embedded with ``model.encode`` (text1 first), and SciPy's
    ``euclidean`` is applied to the two vectors.
    """
    first_vec, second_vec = (model.encode(t) for t in (text1, text2))
    return euclidean(first_vec, second_vec)

# Function to calculate KL Divergence between two texts
def calculate_kl_divergence(text1, text2):
    """Return KL(P || Q) between the word distributions of text1 (P) and text2 (Q).

    A ``CountVectorizer`` (English stop words removed) is fitted on both texts
    together, so they share one vocabulary. Each count vector gets 1e-10
    added to every entry so that no probability is zero. Each vector is then
    normalized to sum to 1, and ``sum(P * log(P / Q))`` is returned.

    Returns:
        The divergence, or ``float('nan')`` when it cannot be computed. That
        happens when vectorizing raises ``ValueError`` (e.g. when no words
        survive stop-word removal, leaving an empty vocabulary) or when the
        vocabulary is empty.
    """
    vectorizer = CountVectorizer(stop_words='english')

    try:
        counts = vectorizer.fit_transform([text1, text2]).toarray()
    except ValueError:
        return float('nan')

    if counts.shape[1] == 0:
        return float('nan')  # no shared vocabulary to compare

    smoothing = 1e-10
    p = counts[0] + smoothing
    q = counts[1] + smoothing
    p = p / p.sum()
    q = q / q.sum()

    return np.sum(p * np.log(p / q))


# Main analysis function
def analyze_dataset(dataset, tokenizer, model):
    """Compute per-example question/response metrics and collect them in a DataFrame.

    Args:
        dataset: Iterable of mapping-like examples. Each must have string
            fields 'text' (the question) and 'output' (the response), such as
            the list returned by ``clean_dataset``.
        tokenizer: Object whose ``encode(text)`` returns a sequence of token
            ids; used for the token counts.
        model: Sentence-embedding model whose ``encode(text)`` returns a
            vector; used for semantic similarity and Euclidean distance.

    Returns:
        A pandas DataFrame with one row per example, in input order, and these
        columns: 'text', 'response', 'num_tokens_question',
        'num_tokens_response', 'semantic_similarity', 'sentiment_question',
        'sentiment_response', 'readability_question', 'readability_response',
        'ttr_question', 'ttr_response', 'toxicity_question',
        'toxicity_response', 'euclidean_distance', 'kl_divergence'.
    """
    results = []
    for example in dataset:
        # Extract question and response
        text = example['text']
        response = example['output']

        # Embed each side once and reuse the vectors for both the cosine
        # similarity and the Euclidean distance. Calling
        # calculate_cosine_similarity and calculate_euclidean_distance would
        # run the embedding model on each text twice per example.
        emb_text = model.encode(text)
        emb_response = model.encode(response)

        results.append({
            'text': text,
            'response': response,
            'num_tokens_question': count_tokens(text, tokenizer),
            'num_tokens_response': count_tokens(response, tokenizer),
            # Same computation as calculate_cosine_similarity(text, response, model).
            'semantic_similarity': cosine_similarity([emb_text], [emb_response])[0][0],
            'sentiment_question': calculate_sentiment(text),
            'sentiment_response': calculate_sentiment(response),
            'readability_question': calculate_readability(text),
            'readability_response': calculate_readability(response),
            'ttr_question': calculate_type_token_ratio(text),
            'ttr_response': calculate_type_token_ratio(response),
            'toxicity_question': calculate_toxicity(text),
            'toxicity_response': calculate_toxicity(response),
            # Same computation as calculate_euclidean_distance(text, response, model).
            'euclidean_distance': euclidean(emb_text, emb_response),
            'kl_divergence': calculate_kl_divergence(text, response),
        })

    return pd.DataFrame(results)

# Load dataset
from datasets import load_dataset
dataset = load_dataset("")['train']  # ADD DATASET HERE

# Load pre-trained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

# Output paths (placeholders - rename per run). Each path is named once so the
# per-example write and the later appends always target the same file.
RESULTS_CSV = "filename.csv"
RESULTS_TXT = "filename.txt"
SUMMARY_TXT = "filename_summary.txt"

# Perform analysis. clean_dataset returns a plain list of the kept examples.
dataset = clean_dataset(dataset)
df = analyze_dataset(dataset, tokenizer, sentence_model)

# Output the results to a CSV file
df.to_csv(RESULTS_CSV, index=False)

# Output the results to a TXT file (formatted), one section per example
with open(RESULTS_TXT, "w", encoding="utf-8") as file:
    file.write("Analysis Results:\n\n")
    for _, row in df.iterrows():
        file.write(f"Question: {row['text']}\n")
        file.write(f"Response: {row['response']}\n")
        file.write(f"Num Tokens (Question): {row['num_tokens_question']}\n")
        file.write(f"Num Tokens (Response): {row['num_tokens_response']}\n")
        file.write(f"Semantic Similarity: {row['semantic_similarity']}\n")
        file.write(f"Sentiment (Question): {row['sentiment_question']}\n")
        file.write(f"Sentiment (Response): {row['sentiment_response']}\n")
        file.write(f"Readability (Question): {row['readability_question']}\n")
        file.write(f"Readability (Response): {row['readability_response']}\n")
        file.write(f"Type-Token Ratio (Question): {row['ttr_question']}\n")
        file.write(f"Type-Token Ratio (Response): {row['ttr_response']}\n")
        file.write(f"Toxicity (Question): {row['toxicity_question']}\n")
        file.write(f"Toxicity (Response): {row['toxicity_response']}\n")
        file.write(f"Euclidean Distance: {row['euclidean_distance']}\n")
        file.write(f"KL Divergence: {row['kl_divergence']}\n")
        file.write("\n---\n\n")

# Corpus-level metrics are appended to the same TXT file. Every append uses
# UTF-8, like the initial write above. Corpus words and LDA topic words can be
# non-ASCII, and the platform's default encoding may not be able to encode them.

# Calculate Lexical Diversity
texts = [ex['text'] for ex in dataset]
outputs = [ex['output'] for ex in dataset]

lexical_diversity_question = calculate_lexical_diversity(texts)
lexical_diversity_response = calculate_lexical_diversity(outputs)

with open(RESULTS_TXT, "a", encoding="utf-8") as file:
    file.write(f"Lexical Diversity (Questions): {lexical_diversity_question}\n")
    file.write(f"Lexical Diversity (Responses): {lexical_diversity_response}\n")

# Calculate OOV rate.
# NOTE(review): if this tokenizer has no unk token (unk_token_id is None), the
# rate is always 0.0. Confirm that the metric is meaningful for this tokenizer.
oov_rate_question = calculate_oov_rate(texts, tokenizer)
oov_rate_response = calculate_oov_rate(outputs, tokenizer)

with open(RESULTS_TXT, "a", encoding="utf-8") as file:
    file.write(f"OOV Rate (Questions): {oov_rate_question}\n")
    file.write(f"OOV Rate (Responses): {oov_rate_response}\n")

# Perform topic modeling (on the questions only)
topics = perform_lda(texts, n_topics=5)
with open(RESULTS_TXT, "a", encoding="utf-8") as file:
    file.write("Topic Modeling Results:\n")
    for topic in topics:
        file.write(f"{topic}\n")

# Summary statistics over the numeric metric columns (describe() skips the
# text columns), plus a max - min range per metric.
summary_stats = df.describe().T
summary_stats['range'] = summary_stats['max'] - summary_stats['min']

# Save summary stats to a TXT file
with open(SUMMARY_TXT, "w", encoding="utf-8") as file:
    file.write("Summary Statistics:\n\n")
    for metric, stats in summary_stats.iterrows():
        file.write(f"{metric}:\n")
        file.write(f"  Mean: {stats['mean']:.4f}\n")
        file.write(f"  Std Dev: {stats['std']:.4f}\n")
        file.write(f"  Min: {stats['min']:.4f}\n")
        file.write(f"  Max: {stats['max']:.4f}\n")
        file.write(f"  Range: {stats['range']:.4f}\n")
        file.write("\n")
