# functions for evaluating the output from divpo/loopo

import pandas as pd
import numpy as np
import gzip
import json
import re
from ast import literal_eval
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
import argparse



# --- Command-line interface ---
# --output: path of the CSV written at the end of the script (the input rows
#           plus the metric columns added by the evaluation loop).
# --input:  path of the CSV to evaluate; read once, at import time.
# --column: name of the input column whose cells hold the generations as a
#           Python literal (parsed with ast.literal_eval in the evaluation
#           loop) -- presumably a list of generated strings; confirm upstream.
parser = argparse.ArgumentParser()
parser.add_argument('--output', type=str, default="output.csv")
parser.add_argument('--input', type=str, default="input.csv")
parser.add_argument('--column', type=str, default="batch_generations")

args = parser.parse_args()
# Load the dataset to evaluate (one row per prompt/batch of generations).
df = pd.read_csv(args.input)

# Load the sentence-embedding model once; it is shared as a module-level
# global by the embedding-based metrics (compute_embedding_variance,
# compute_l2_from_reference).
embedder = SentenceTransformer("all-MiniLM-L6-v2")

### helper functions
def clean_poem(text: str) -> str:
    """Normalize a generated poem for comparison.

    Everything up to and including the first blank line ("\\n\\n") is treated
    as a preamble (e.g. "Here is a 4-line poem...") and dropped; text without
    a blank line is kept whole. On every line, the first '#' and everything
    after it are removed. Finally all whitespace runs collapse to single
    spaces, the ends are trimmed, and the result is lowercased.
    """
    header, blank_line, body = str(text).partition("\n\n")
    poem = body if blank_line else header
    without_hashes = re.sub(r"#.*$", "", poem, flags=re.MULTILINE)
    return re.sub(r"\s+", " ", without_hashes).strip().lower()

def compute_unique_ngrams_any_n(samples, n_min=1, n_max=5):
    """Count distinct word n-grams of every length n in [n_min, n_max].

    All samples are pooled, so an n-gram shared by several samples counts
    once. Missing values (NaN/None) are skipped and everything else is
    stringified. Tokens are runs of word characters, so single-character
    tokens such as '7' count.

    Returns NaN when there are no usable samples (empty input or only missing
    values), and 0 when no sample has at least ``n_min`` tokens. In that case
    CountVectorizer would otherwise raise "empty vocabulary".
    """
    if not samples:
        return np.nan
    samples = [str(s) for s in samples if pd.notna(s)]
    if not samples:
        return np.nan
    token_pattern = r"(?u)\b\w+\b"  # same tokenization as compute_unique_ngrams
    # Mirror CountVectorizer's default lowercase + token_pattern tokenization
    # so the empty-vocabulary case is detected before fitting.
    longest = max(len(re.findall(token_pattern, s.lower())) for s in samples)
    if longest < n_min:
        return 0
    vectorizer = CountVectorizer(
        analyzer="word",
        ngram_range=(n_min, n_max),
        token_pattern=token_pattern,
    )
    X = vectorizer.fit_transform(samples)
    # With the default min_df=1 every vocabulary column occurs at least once,
    # so the column count is the number of distinct n-grams.
    return X.shape[1]

def compute_unique_ngrams_by_n(samples, n_list=(1,2,3,4,5)):
    """Count unique n-grams separately for each n in ``n_list``.

    Returns a dict mapping each n (in ``n_list`` order) to
    ``compute_unique_ngrams(..., n=n)`` over the non-missing samples, each
    stringified. Every value is NaN when ``samples`` is empty.
    """
    if not samples:
        return dict.fromkeys(n_list, np.nan)
    kept = [str(item) for item in samples if pd.notna(item)]
    return {n: compute_unique_ngrams(kept, n=n) for n in n_list}

def norm_item(x: str) -> str:
    """Canonical form of an item for exact-identity counting.

    The value is stringified, every run of whitespace becomes a single
    space, the ends are trimmed and the text is lowercased. Multi-word items
    such as 'Data  Scientist' stay ONE item ('data scientist').
    """
    collapsed = re.sub(r"\s+", " ", str(x))
    return collapsed.strip().lower()

def compute_unique_items(samples) -> int | float:
    """Count unique items (strings), after normalization."""
    if not samples:
        return np.nan
    items = [norm_item(s) for s in samples if pd.notna(s)]
    return len(set(items))

def compute_unique_items_split_persona(names, occs, cities):
    """Return [total_unique, unique_first_names, unique_occupations, unique_cities].

    Each field list is normalized with ``norm_item``. Missing values (None,
    NaN, pd.NA), other falsy values and entries that are blank after
    normalization are dropped. ``total_unique`` counts distinct items over the
    three fields pooled together. A value appearing both as a name and as a
    city therefore counts once there. A count is NaN when its (filtered) list
    is empty, as returned by ``compute_unique_items``.
    """
    def _clean(values):
        kept = []
        for value in values:
            # Test for missing values first: NaN is truthy, so a plain
            # ``if value`` would keep it and count it as the item "nan".
            if pd.api.types.is_scalar(value) and pd.isna(value):
                continue
            if not value:
                continue
            item = norm_item(value)
            if item:  # whitespace-only input normalizes to ""
                kept.append(item)
        return kept

    names = _clean(names)
    occs = _clean(occs)
    cities = _clean(cities)
    total = compute_unique_items(names + occs + cities)
    return [total, compute_unique_items(names), compute_unique_items(occs), compute_unique_items(cities)]

def extract_keywords_strict(text: str):
    """Pull single-word keywords out of a generation, one keyword per line.

    - Only the text before the first '#' of the whole (stripped) input is
      considered.
    - That text is split into lines; each line is trimmed and lowercased, and
      empty lines are skipped.
    - The heading 'the eyes of the world' and lines starting with 'here are'
      are skipped.
    - A line is kept only if it is a single run of ASCII letters a-z (no
      spaces, digits or punctuation).
    """
    head = re.split(r'#+', str(text).strip())[0]
    skipped_headings = {"the eyes of the world"}
    keywords = []
    for raw_line in re.split(r'\n+', head):
        line = raw_line.strip().lower()
        if not line or line in skipped_headings or line.startswith("here are"):
            continue
        if re.fullmatch(r"[a-z]+", line):
            keywords.append(line)
    return keywords


def _vectorize_unigrams(samples):
    """Return (X, vocab_mask) where:
       - X is (docs x vocab) count matrix
       - vocab_mask is boolean mask of columns that appear at least once
    """
    if not samples:
        return None, None
    samples = [str(s) for s in samples if pd.notna(s)]
    if len(samples) == 0:
        return None, None
    vectorizer = CountVectorizer(
        analyzer='word',
        ngram_range=(1, 1),
        token_pattern=r"(?u)\b\w+\b"
    )
    X = vectorizer.fit_transform(samples)
    vocab_mask = (X.sum(axis=0).A1 > 0)
    return X, vocab_mask




def compute_unique_1grams_normalized(samples, method="ttr"):
    """
    Length-normalized number of unique 1-grams (word types) pooled across all samples.

    With V = number of distinct word types and N = total word tokens:
      - 'ttr'    : V / N (Type-Token Ratio)
      - 'herdan' : log(V) / log(N)
      - 'guiraud': V / sqrt(N)
      - 'maas'   : (log(N) - log(V)) / (log(N)**2)   (lower is 'more diverse')

    Returns NaN when ``_vectorize_unigrams`` yields no matrix (no usable
    samples), and 0.0 when the matrix holds no tokens. For a single token
    (N == 1, hence V == 1) 'herdan' returns 1.0 and 'maas' 0.0. These are the
    values both measures take whenever every token is distinct, and they
    replace the 0/0 NaN the raw formulas produce.

    Raises ValueError for an unknown ``method`` (checked before any work).
    """
    formulas = {
        "ttr": lambda V, N: V / N,
        "herdan": lambda V, N: np.log(V) / np.log(N),
        "guiraud": lambda V, N: V / np.sqrt(N),
        "maas": lambda V, N: (np.log(N) - np.log(V)) / (np.log(N) ** 2),
    }
    if method not in formulas:
        raise ValueError(f"Unknown method: {method}")

    X, vocab_mask = _vectorize_unigrams(samples)
    if X is None:
        return np.nan

    V = vocab_mask.sum()                    # number of unique types
    N = int(X.sum())                        # total tokens across samples
    if N == 0 or V == 0:
        return 0.0
    if N == 1 and method in ("herdan", "maas"):
        # log(1) == 0 makes both formulas 0/0; use their all-distinct value.
        return 1.0 if method == "herdan" else 0.0
    return formulas[method](V, N)



# --- Metric functions ---
def compute_embedding_variance(samples):
    """Semantic diversity: 1 minus the mean pairwise cosine similarity.

    Each non-missing sample is stringified and embedded with the module-level
    ``embedder``. The mean is taken over all unordered pairs of distinct
    samples (upper triangle of the similarity matrix, diagonal excluded).
    Higher values mean the samples are further apart in embedding space.
    Despite the name, this is a mean pairwise cosine distance, not a
    statistical variance.

    Returns NaN when fewer than two usable samples are given, since no pair
    exists.
    """
    if not samples:
        return np.nan
    samples = [str(s) for s in samples if pd.notna(s)]
    if len(samples) < 2:
        # An empty upper triangle would make .mean() warn and return NaN.
        return np.nan
    embeddings = embedder.encode(samples)
    sim_matrix = cosine_similarity(embeddings)
    upper = np.triu_indices_from(sim_matrix, k=1)
    return 1 - sim_matrix[upper].mean()

def compute_unique_ngrams(samples, n=1):
    """Count distinct word n-grams of length exactly ``n`` across all samples.

    Samples are pooled, so an n-gram shared by several samples counts once.
    Missing values (NaN/None) are skipped and everything else is stringified.
    Tokens are runs of word characters, so single-character tokens such as
    '7' count.

    Returns NaN when there are no usable samples, and 0 when no sample has at
    least ``n`` tokens. CountVectorizer would otherwise raise "empty
    vocabulary", e.g. for bigrams over single-word first names.
    """
    if not samples:
        return np.nan
    samples = [str(s) for s in samples if pd.notna(s)]
    if not samples:
        return np.nan
    token_pattern = r"(?u)\b\w+\b"  # allows single-character tokens like '7'
    # Mirror CountVectorizer's default lowercase + token_pattern tokenization
    # so the empty-vocabulary case is detected before fitting.
    if all(len(re.findall(token_pattern, s.lower())) < n for s in samples):
        return 0
    vectorizer = CountVectorizer(
        analyzer='word',
        ngram_range=(n, n),
        token_pattern=token_pattern,
    )
    X = vectorizer.fit_transform(samples)
    # With the default min_df=1 every vocabulary column occurs at least once,
    # so the column count is the number of distinct n-grams.
    return X.shape[1]


def compute_unique_ngrams_split_persona(names, occs, cities, n=1):
    """Return [total, first_names, occupations, cities] unique n-gram counts.

    Each field's values are stringified after dropping missing values (None,
    NaN, pd.NA) and other falsy values. ``total`` pools all three fields
    before counting. Counting is delegated to ``compute_unique_ngrams``.
    """
    def _present(values):
        # NaN is truthy, so test for missing values explicitly; otherwise it
        # would be stringified and counted as the token "nan". The missing
        # check runs first because bool(pd.NA) raises.
        return [
            str(v) for v in values
            if not (pd.api.types.is_scalar(v) and pd.isna(v)) and v
        ]

    names = _present(names)
    occs = _present(occs)
    cities = _present(cities)
    total = compute_unique_ngrams(names + occs + cities, n=n)
    return [
        total,
        compute_unique_ngrams(names, n=n),
        compute_unique_ngrams(occs, n=n),
        compute_unique_ngrams(cities, n=n),
    ]


def compute_compression_ratio(samples):
    """gzip-compressed size divided by raw UTF-8 size of all samples.

    Non-missing samples are stringified and joined with newlines before
    compression. More redundant text generally compresses better, giving a
    lower ratio.

    Returns NaN for empty input or when every sample is missing (None/NaN),
    and 0.0 when the joined text is empty.
    """
    if not samples:
        return np.nan
    # Skip missing entries: str.join would raise TypeError on None/NaN.
    texts = [str(s) for s in samples if pd.notna(s)]
    if not texts:
        return np.nan
    raw = "\n".join(texts).encode('utf-8')
    if not raw:
        return 0.0
    return len(gzip.compress(raw)) / len(raw)

def compute_token_entropy(samples):
    """Shannon entropy (bits) of the whitespace-token distribution over all samples.

    Non-missing samples are stringified, joined with spaces and split on
    whitespace. Token probabilities are their relative frequencies.

    Returns NaN for empty input or when every sample is missing (None/NaN),
    and 0.0 when the samples contain no tokens.
    """
    if not samples:
        return np.nan
    # Skip missing entries: str.join would raise TypeError on None/NaN.
    texts = [str(s) for s in samples if pd.notna(s)]
    if not texts:
        return np.nan
    tokens = " ".join(texts).split()
    if not tokens:
        return 0.0  # avoids -0.0 from negating an empty sum
    probs = pd.Series(tokens).value_counts(normalize=True)
    return float(-np.sum(probs * np.log2(probs)))

# entropy normalized by its maximum, log2(V) for V distinct tokens, so values lie in [0, 1]


def compute_token_entropy_normalized(samples):
    """Shannon entropy of the token distribution divided by its maximum, log2(V).

    Tokens come from splitting the space-joined, non-missing samples on
    whitespace. V is the number of distinct tokens. The result is 1.0 when
    all V tokens are equally frequent.

    Returns NaN for empty input or when every sample is missing (None/NaN),
    and 0.0 when there is at most one distinct token.
    """
    if not samples:
        return np.nan
    # Skip missing entries: str.join would raise TypeError on None/NaN.
    texts = [str(s) for s in samples if pd.notna(s)]
    if not texts:
        return np.nan
    tokens = " ".join(texts).split()
    probs = pd.Series(tokens).value_counts(normalize=True)
    V = len(probs)                          # vocabulary size
    if V <= 1:
        return 0.0
    H = -np.sum(probs * np.log2(probs))     # Shannon entropy (bits)
    return float(H / np.log2(V))



def compute_l2_from_reference(samples, reference):
    """Mean Euclidean (L2) distance from each sample's embedding to the reference's.

    Samples and reference are stringified and embedded with the module-level
    ``embedder``. Missing samples (None/NaN) are skipped.

    Returns NaN when there are no usable samples or the reference is missing
    (None, NaN, pd.NA). Previously a NaN reference or sample was embedded as
    the literal string "nan".
    """
    if not samples:
        return np.nan
    if reference is None or (pd.api.types.is_scalar(reference) and pd.isna(reference)):
        return np.nan
    samples = [str(s) for s in samples if pd.notna(s)]
    if not samples:
        return np.nan
    reference = str(reference)
    sample_embeds = embedder.encode(samples)
    ref_embed = embedder.encode([reference])[0]
    return np.linalg.norm(sample_embeds - ref_embed, axis=1).mean()




def extract_keywords(text):
    """Return every run of ASCII letters, in original case, that appears
    before the first '#' of the stripped text. Digits, punctuation and
    whitespace act as separators."""
    head = str(text).strip().partition("#")[0]
    return re.findall(r"[A-Za-z]+", head)


def extract_persona_fields(text):
    """Parse a persona JSON object out of a model generation.

    Search order:
      1. a fenced ```json ... ``` (or plain ```) block holding {...};
      2. otherwise the first {...} span in the text;
      3. otherwise the whole text.
    Both patterns are non-greedy, so a nested object is cut at its first '}'
    and will fail to parse.

    Returns (first_name, occupation, city), with None for any missing key, or
    (None, None, None) when the input is not a string or does not parse to a
    JSON object.
    """
    try:
        match = re.search(r"```(?:json)?\s*({.*?})\s*```", text, re.DOTALL)
        if match:
            text = match.group(1)
        else:
            match = re.search(r"{.*?}", text, re.DOTALL)
            if match:
                text = match.group(0)
        data = json.loads(text)
    except (TypeError, ValueError):
        # TypeError: non-string input (e.g. NaN from a DataFrame).
        # ValueError: includes json.JSONDecodeError for unparseable text.
        return None, None, None
    if not isinstance(data, dict):
        # Valid JSON but not an object (list, number, string...).
        return None, None, None
    return data.get("first_name"), data.get("occupation"), data.get("city")

# --- Evaluation loop ---
# Output column -> metric applied to one row's list of generations.
# Column names are kept for compatibility with existing consumers of the CSV.
# NOTE: E1 holds the embedding-based (semantic) diversity from
# compute_embedding_variance, not a lexical score. E2 holds the 1-gram
# type-token ratio, not a raw count.
METRICS = {
    "E1_lexical_diversity": compute_embedding_variance,
    "E2_unique_1grams": lambda gens: compute_unique_1grams_normalized(gens, method="ttr"),
    "E3_compression_ratio": compute_compression_ratio,
    "E4_entropy": compute_token_entropy_normalized,
}


def _parse_generations(cell):
    """Parse a cell holding a Python-literal list/tuple of generations.

    Returns the generations as a list, or [] when the cell does not parse or
    holds something other than a list/tuple. The metric functions return NaN
    for an empty list.
    """
    try:
        gens = literal_eval(cell)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return []
    return list(gens) if isinstance(gens, (list, tuple)) else []


def _safe_metric(name, fn, gens, row_label):
    """Run one metric on one row; on failure log a warning and return NaN."""
    try:
        return fn(gens)
    except Exception as err:  # best-effort per metric: never abort the whole run
        tqdm.write(f"warning: row {row_label}: {name} failed: {err!r}")
        return np.nan


# Each metric is computed independently, so one failing metric no longer
# blanks out the remaining metrics of the same row.
results = {name: [] for name in METRICS}
for row_label, row in tqdm(df.iterrows(), total=len(df)):
    gens = _parse_generations(row[args.column])
    for name, fn in METRICS.items():
        results[name].append(_safe_metric(name, fn, gens, row_label))

# Save results (columns appended in METRICS order)
for name, values in results.items():
    df[name] = values

df.to_csv(args.output, index=False)
print("output saved")
