"""
Evaluation metrics for natural language processing and question answering tasks.

This module implements fuzzy accuracy measures based on Wu-Palmer semantic similarity
and WUPS (Wu-Palmer Similarity) scores for evaluating answer quality in QA systems.
These metrics provide more nuanced evaluation than exact string matching by considering
semantic similarity between predicted and ground truth answers.
"""

import sys
import string
from numpy import prod
from nltk.corpus import wordnet as wn


def remove_punctuation(text):
    """
    Strip every ASCII punctuation mark out of a string.

    The characters dropped are exactly those listed in ``string.punctuation``;
    everything else (letters, digits, whitespace, non-ASCII symbols) is kept
    in its original order.

    Args:
        text (str): String to clean.

    Returns:
        str: Copy of ``text`` with the punctuation characters deleted.
    """
    # A translation table that maps a character to None deletes it.
    deletion_table = str.maketrans(dict.fromkeys(string.punctuation))
    return text.translate(deletion_table)

def items2list(x):
    """
    Convert a space-separated string of items into a list of non-empty items.

    The string is split on the space character and each piece is stripped of
    surrounding whitespace. Pieces that are empty after stripping are dropped,
    so leading, trailing or repeated spaces (which commonly appear once
    punctuation has been removed, e.g. "red - car" -> "red  car") do not
    produce empty answer items that can never match anything.

    Args:
        x (str): Space-separated string of answer items

    Returns:
        list: Non-empty answer items with whitespace stripped; an empty list
              when ``x`` contains no items at all
    """
    stripped_parts = (part.strip() for part in x.split(' '))
    return [part for part in stripped_parts if part]

def fuzzy_set_membership_measure(x, A, m):
    """
    Compute fuzzy set membership measure for an element with respect to a set.

    The membership of x in A is the best point-wise similarity between x and
    any element of A:
        m(x ∈ A) = max_{a ∈ A} m(x,a)

    Args:
        x: Element to test for membership
        A (iterable): Elements to compare against (any iterable, may be empty)
        m (callable): Point-wise similarity measure function m(a,b) → similarity score

    Returns:
        Maximum value of m(x, a) over the elements a of A, or 0 if A is empty
    """
    # `default` covers every kind of empty iterable (list, tuple, ...), not
    # only the literal empty list.
    return max((m(x, a) for a in A), default=0)

def score_it(A, T, m):
    """
    Compute fuzzy accuracy score between two sets using a membership measure.

    The score is symmetric in the sense that both directions are evaluated
    and the worse one wins:
        score(A,T) = min{∏_{a ∈ A} m(a ∈ T), ∏_{t ∈ T} m(t ∈ A)}

    - The first product multiplies the membership of every element of A in T.
    - The second product multiplies the membership of every element of T in A.

    Args:
        A (list): First list of answer items
        T (list): Second list of answer items
        m (callable): Set membership measure, called as m(element, other_list)

    Returns:
        1 if both lists are empty; otherwise the minimum of the two products,
        where the product for an empty list is taken to be 0
    """
    a_is_empty = len(A) == 0
    t_is_empty = len(T) == 0

    # Two empty answers agree perfectly.
    if a_is_empty and t_is_empty:
        return 1

    # An empty side scores 0 (the empty product would otherwise be 1).
    score_left = 0 if a_is_empty else prod([m(a, T) for a in A])
    score_right = 0 if t_is_empty else prod([m(t, A) for t in T])
    return min(score_left, score_right)


def dirac_measure(a, b):
    """
    Exact-match similarity: 1.0 when both arguments are equal, else 0.0.

    An argument equal to the empty list is treated as "no answer" and never
    matches anything, not even another empty list.

    Args:
        a (str): First value to compare
        b (str): Second value to compare

    Returns:
        float: 1.0 if ``a == b`` and neither is an empty list, 0.0 otherwise
    """
    either_missing = a == [] or b == []
    return 0.0 if either_missing else float(a == b)

def wup_measure(a, b, similarity_threshold=0.925):
    """
    Compute a thresholded Wu-Palmer semantic similarity between two words.

    Both words are looked up as nouns in WordNet and the maximum Wu-Palmer
    similarity over all pairs of their synsets is taken. The Wu-Palmer
    similarity of two synsets is defined as:
        wup(s1, s2) = 2 * depth(lcs(s1, s2)) / (depth(s1) + depth(s2))
    where lcs is the least common subsumer in the WordNet hierarchy.

    If the maximum similarity is below ``similarity_threshold`` it is
    multiplied by 0.1, so weakly related words contribute very little.

    Args:
        a (str): First word to compare
        b (str): Second word to compare
        similarity_threshold (float): Similarity below which the score is
                                    down-weighted by a factor of 0.1
                                    (default: 0.925)

    Returns:
        float: 1.0 when ``a == b`` (checked first, before any other rule);
               0 when either input is an empty list or either word has no
               noun synsets; otherwise the (possibly down-weighted) maximum
               Wu-Palmer similarity
    """
    def get_semantic_field(word):
        """
        Look up the noun synsets of a word in WordNet.

        Args:
            word (str): Word to get synsets for

        Returns:
            tuple: (synsets_list, weight) where weight is always 1.0
        """
        return (wn.synsets(word, pos=wn.NOUN), 1.0)

    def get_stem_word(word):
        """
        Placeholder for stem extraction; returns the word unchanged.

        Args:
            word (str): Input word

        Returns:
            tuple: (word, weight) where weight is always 1.0
        """
        return (word, 1.0)

    # Extract stem words and compute global weight
    a, global_weight_a = get_stem_word(a)
    b, global_weight_b = get_stem_word(b)
    global_weight = min(global_weight_a, global_weight_b)

    # Identical inputs need no WordNet lookup
    if a == b:
        return 1.0 * global_weight

    # Handle empty inputs
    if a == [] or b == []:
        return 0

    # Get semantic fields (synsets) for both words
    interp_a, weight_a = get_semantic_field(a)
    interp_b, weight_b = get_semantic_field(b)

    # Handle cases where no synsets are found
    if interp_a == [] or interp_b == []:
        return 0

    # Find maximum Wu-Palmer similarity across all synset pairs
    global_max = 0.0
    for x in interp_a:
        for y in interp_b:
            local_score = x.wup_similarity(y)
            # wup_similarity may return None when no connecting path exists;
            # comparing None with a float would raise TypeError.
            if local_score is not None and local_score > global_max:
                global_max = local_score

    # Scores below the threshold are down-weighted as they may be spurious;
    # scores at or above it are kept at full weight.
    interp_weight = 0.1 if global_max < similarity_threshold else 1.0

    return global_max * weight_a * weight_b * interp_weight * global_weight

def wups_score(input_gt, input_pred, thresh=0.9):
    """
    Compute WUPS (Wu-Palmer Similarity) score between ground truth and prediction.

    Both strings have punctuation removed and are split into answer items.
    The two item lists are then compared with ``score_it``, using either
    exact matching (``dirac_measure``) or thresholded Wu-Palmer similarity
    (``wup_measure``) as the element-level measure. The evaluation mode and
    the final score are printed to standard output.

    Args:
        input_gt (str): Ground truth answer string
        input_pred (str): Predicted answer string
        thresh (float): Wu-Palmer similarity threshold. Use -1 for exact matching;
                       any other value is passed to ``wup_measure`` as its
                       similarity threshold (default: 0.9)

    Returns:
        float: Score for this ground-truth/prediction pair

    Examples:
        wups_score("cat", "feline", 0.9)                    # fuzzy matching
        wups_score("red car", "crimson automobile", 0.8)    # fuzzy matching
        wups_score("dog", "dog", -1)                        # exact matching
    """
    # Select the element-level similarity measure and report the mode
    if thresh == -1:
        our_element_membership = dirac_measure
        print('standard Accuracy is used')
    else:
        our_element_membership = lambda x, y: wup_measure(x, y, thresh)
        print('soft WUPS at %1.2f is used' % thresh)

    def our_set_membership(x, A):
        # Set-level membership built from the selected element measure
        return fuzzy_set_membership_measure(x, A, our_element_membership)

    # Preprocess inputs: remove punctuation and convert to lists
    gt_list = items2list(remove_punctuation(input_gt))
    pred_list = items2list(remove_punctuation(input_pred))

    # Compute fuzzy accuracy score for the single pair
    final_score = float(score_it(gt_list, pred_list, our_set_membership))

    # Print final score
    print('final score of WUPS_%1.2f is %2.2f%%' % (thresh, final_score * 100.0))
    return final_score
