#!/usr/bin/env python3
"""
Ground-truth labels and scoring helpers for nonlinearity-detection evaluation.
This file maps each problem in the LinearizeLLM dataset to its expected nonlinearity types and provides helpers to score (and debug) LLM detection output against them.
"""

# Ground-truth labels for nonlinearity detection.
# Format: problem_name -> list of expected nonlinearity types. Every label used
# here is a key of NONLINEARITY_TYPES; get_ground_truth() reports problems that
# are missing from this mapping as ['unknown'].
# NOTE(review): the '_binbin' / '_bincon' / '_concon' suffixes look like they
# denote binary*binary, binary*continuous and continuous*continuous products;
# confirm against the dataset before relying on that reading.
GROUND_TRUTH_LABELS = {
    # Aircraft problems
    'aircraft_monotone_max': ['monotone_transformation', 'max'],

    # Blend problems (absolute-value, bilinear and quotient variants)
    'blend_problem_abs': ['absolute_value'],
    'blend_problem_bilin_concon': ['bilinear'],
    'blend_problem_frac': ['quotient'],

    # Diet problems (some variants combine two nonlinearity types)
    'diet_problem_min_abs': ['absolute_value', 'min'],
    'diet_problem_nonlinear_frac': ['quotient'],
    'diet_problem_nonlinear_frac_min': ['quotient', 'min'],
    'diet_problem_monotone': ['monotone_transformation'],

    # Knapsack problems
    'knapsack_problem_nonlinear_min_1': ['min'],
    'knapsack_problem_nonlinear_min_2': ['min'],

    # Media selection problems
    'media_selection_nonlinear_binbin': ['bilinear'],
    'media_selection_nonlinear_bincon': ['bilinear'],

    # Multi-product network problems
    'multi_nonlinear_abs': ['absolute_value'],
    'multi_nonlinear_bincon': ['bilinear'],

    # Project assignment problems
    'netasgn_nonlinear_abs': ['absolute_value'],
    'netasgn_nonlinear_max': ['max'],

    # Network problems
    'netmcol_nonlinear_bincon': ['bilinear'],
    'netmcol_nonlinear_frac': ['quotient'],

    # Transportation problems
    'nltrans_nonlinear_bincon': ['bilinear'],
    'nltrans_nonlinear_max': ['max'],

    # Production planning problems
    'prod_nonlinear_bincon': ['bilinear'],
    'prod_nonlinear_max': ['max'],

    # Revenue management problems
    'revenue_maximization_nonlinear_bincon': ['bilinear'],

}

# Human-readable description for every nonlinearity type label.
# 'linear' and 'quadratic' can be reported by evaluate_detection_accuracy() but
# never appear in GROUND_TRUTH_LABELS; 'mixed' is currently produced by neither.
NONLINEARITY_TYPES = {
    'linear': 'No nonlinear terms - problem is already linear',
    'bilinear': 'Contains bilinear terms (product of two variables)',
    'min': 'Contains min functions',
    'max': 'Contains max functions',
    'absolute_value': 'Contains absolute value functions',
    'quotient': 'Contains quotient/division terms',
    'monotone_transformation': 'Contains monotone transformations',
    'quadratic': 'Contains quadratic terms',
    'mixed': 'Contains multiple types of nonlinearities'
}

def get_ground_truth(problem_name: str) -> list:
    """
    Get ground truth nonlinearity types for a given problem.

    Args:
        problem_name: Name of the problem (a key of GROUND_TRUTH_LABELS)

    Returns:
        A new list of expected nonlinearity types, or ['unknown'] if the
        problem has no label. The list is a copy, so callers may mutate it
        without altering GROUND_TRUTH_LABELS.
    """
    # Copy so that mutating the result (e.g. the 'ground_truth' entry returned
    # by evaluate_detection_accuracy) cannot corrupt the shared label table.
    return list(GROUND_TRUTH_LABELS.get(problem_name, ['unknown']))

def get_nonlinearity_description(nonlinearity_type: str) -> str:
    """
    Look up the human-readable description of a nonlinearity type label.

    Args:
        nonlinearity_type: Label such as 'bilinear' or 'quotient'

    Returns:
        The description from NONLINEARITY_TYPES, or
        'Unknown nonlinearity type' when the label is not registered.
    """
    try:
        return NONLINEARITY_TYPES[nonlinearity_type]
    except KeyError:
        return 'Unknown nonlinearity type'

# Structured LLM output format: for each nonlinearity type, the section headers
# that introduce it and a predicate telling whether that section's text holds a
# type-specific math marker (rather than e.g. "NONE"). The dict order fixes the
# order in which types are appended to ``detected_types``.
_STRUCTURED_SECTIONS = {
    'bilinear': (
        ('bilinear_patterns:', 'bilinear patterns:'),
        lambda s: '\\cdot' in s or ('x_' in s and 'y_' in s),
    ),
    'min': (
        ('min_patterns:', 'min patterns:'),
        lambda s: '\\min' in s,
    ),
    'max': (
        ('max_patterns:', 'max patterns:'),
        lambda s: '\\max' in s,
    ),
    'absolute_value': (
        ('absolute_patterns:', 'absolute value patterns:'),
        lambda s: '|' in s,
    ),
    'quotient': (
        ('quotient_patterns:', 'quotient patterns:'),
        lambda s: '\\frac' in s or '/' in s,
    ),
    'monotone_transformation': (
        ('monotone_transformation_patterns:', 'monotone transformation patterns:'),
        lambda s: any(tok in s for tok in ('\\exp', '\\sqrt', '\\log')),
    ),
}

# Legacy free-text output format: (phrases announcing a detected pattern,
# zero-count summary phrases that veto the detection) per nonlinearity type.
_LEGACY_PHRASES = {
    'bilinear': (
        ('bilinear pattern 1:', 'bilinear pattern 2:', 'bilinear pattern 3:',
         '✅ bilinear pattern', 'bilinear pattern detected'),
        ('bilinear patterns: 0', 'bilinear pattern: 0'),
    ),
    'min': (
        ('min pattern 1:', 'min pattern 2:', 'min pattern 3:',
         '✅ min pattern', 'min pattern detected'),
        ('min patterns: 0',),
    ),
    'max': (
        ('max pattern 1:', 'max pattern 2:', 'max pattern 3:',
         '✅ max pattern', 'max pattern detected'),
        ('max patterns: 0',),
    ),
    'absolute_value': (
        ('absolute_value pattern 1:', 'absolute_value pattern 2:', 'absolute_value pattern 3:',
         'absolute value pattern 1:', 'absolute value pattern 2:', 'absolute value pattern 3:',
         '✅ absolute value pattern', 'absolute value pattern detected'),
        ('absolute value patterns: 0',),
    ),
    'quotient': (
        ('quotient pattern 1:', 'quotient pattern 2:', 'quotient pattern 3:',
         '✅ quotient pattern', 'quotient pattern detected'),
        ('quotient patterns: 0',),
    ),
    'monotone_transformation': (
        ('monotone_transformation pattern 1:', 'monotone_transformation pattern 2:',
         'monotone_transformation pattern 3:',
         'monotone transformation pattern 1:', 'monotone transformation pattern 2:',
         'monotone transformation pattern 3:',
         '✅ monotone_transformation pattern', '✅ monotone transformation pattern',
         'monotone transformation pattern detected'),
        ('monotone transformation patterns: 0',),
    ),
}

# Phrases that add 'quadratic' / 'linear' to the detected types.
_QUADRATIC_PHRASES = (
    'quadratic pattern', 'quadratic term', 'quadratic function',
    'quadratic constraint', 'quadratic objective', 'squared term',
)
_LINEAR_PHRASES = (
    'linear problem', 'no nonlinear', 'linear model', 'linear function',
    'linear constraint', 'linear objective', 'all linear',
)


def _has_phrase(text: str, phrase: str) -> bool:
    """Return True if *phrase* occurs in *text* not glued to a preceding word.

    The look-behind rejects matches preceded by a word character, '/' or '-',
    so 'linear problem' does not match inside 'nonlinear problem',
    'non-linear problem' or 'bilinear ...', and 'max pattern 1:' does not
    match inside 'min/max pattern 1:'.
    """
    import re
    return re.search(r'(?<![\w/-])' + re.escape(phrase), text) is not None


def _iter_structured_sections(text: str):
    """Yield (nonlinearity_type, section_text) for every structured header in *text*.

    A section runs from the end of its header to the start of the next
    recognised header (or the end of *text*).
    """
    import re
    header_to_type = {
        header: nl_type
        for nl_type, (headers, _) in _STRUCTURED_SECTIONS.items()
        for header in headers
    }
    # Longest first so the alternation always prefers the most specific header.
    header_re = re.compile('|'.join(
        re.escape(h) for h in sorted(header_to_type, key=len, reverse=True)))
    matches = list(header_re.finditer(text))
    for current, following in zip(matches, matches[1:] + [None]):
        end = following.start() if following is not None else len(text)
        yield header_to_type[current.group(0)], text[current.end():end]


def evaluate_detection_accuracy(problem_name: str, detected_patterns: str) -> dict:
    """
    Evaluate detection accuracy against ground truth.

    Two LLM output formats are recognised (case-insensitively):

    * Structured: section headers such as ``BILINEAR_PATTERNS:`` followed by
      the patterns (or ``NONE``). A type is detected only when the text
      between its own header and the next recognised header contains ``$``
      plus a type-specific marker (``\\cdot`` or both ``x_``/``y_`` for
      bilinear, ``\\min``, ``\\max``, ``|``, ``\\frac`` or ``/``, and
      ``\\exp``/``\\sqrt``/``\\log`` for monotone transformations).
    * Legacy: phrases such as ``Bilinear Pattern 1:`` or ``✅ min pattern``,
      unless a zero-count summary such as ``min patterns: 0`` is present.

    ``'quadratic'`` and ``'linear'`` are added when one of their phrases
    appears not glued to a preceding letter, '-' or '/', so e.g.
    "nonlinear problem" does not count as "linear problem".

    Args:
        problem_name: Name of the problem
        detected_patterns: String containing detected patterns from LLM;
            ``None`` is treated as an empty response.

    Returns:
        Dictionary with ``ground_truth`` and ``detected_types`` (lists),
        ``correct_detections``, ``false_positives`` and ``false_negatives``
        (sorted lists), ``precision``, ``recall`` and ``f1_score`` (0 when
        undefined), and ``is_correct`` (True iff there are no false positives
        and no false negatives).
    """
    ground_truth = get_ground_truth(problem_name)
    # Tolerate a missing LLM response by scoring it as empty text.
    detected_lower = (detected_patterns or '').lower()

    # 1) Structured format: inspect only each header's own section so that
    #    math written under one header cannot trigger another type.
    found = {
        nl_type
        for nl_type, section in _iter_structured_sections(detected_lower)
        if '$' in section and _STRUCTURED_SECTIONS[nl_type][1](section)
    }
    detected_types = [t for t in _STRUCTURED_SECTIONS if t in found]

    # 2) Legacy format as a fallback for types not already detected.
    for nl_type, (hit_phrases, zero_phrases) in _LEGACY_PHRASES.items():
        if nl_type in detected_types:
            continue
        announced = any(_has_phrase(detected_lower, p) for p in hit_phrases)
        zero_count = any(_has_phrase(detected_lower, p) for p in zero_phrases)
        if announced and not zero_count:
            detected_types.append(nl_type)

    if any(_has_phrase(detected_lower, p) for p in _QUADRATIC_PHRASES):
        detected_types.append('quadratic')

    # Linear only when explicitly stated as linear (word-bounded, see above).
    if any(_has_phrase(detected_lower, p) for p in _LINEAR_PHRASES):
        detected_types.append('linear')

    # Evaluate accuracy
    truth_set = set(ground_truth)
    detected_set = set(detected_types)
    correct_detections = truth_set & detected_set
    false_positives = detected_set - truth_set
    false_negatives = truth_set - detected_set

    precision = len(correct_detections) / len(detected_set) if detected_set else 0
    recall = len(correct_detections) / len(truth_set) if truth_set else 0
    f1_score = (2 * precision * recall / (precision + recall)
                if (precision + recall) > 0 else 0)

    return {
        'ground_truth': ground_truth,
        'detected_types': detected_types,
        'correct_detections': sorted(correct_detections),
        'false_positives': sorted(false_positives),
        'false_negatives': sorted(false_negatives),
        'precision': precision,
        'recall': recall,
        'f1_score': f1_score,
        'is_correct': not false_positives and not false_negatives
    }

def get_all_problem_names() -> list:
    """
    List every problem that has a ground-truth label.

    Returns:
        A new list of problem names, in GROUND_TRUTH_LABELS insertion order
    """
    return [*GROUND_TRUTH_LABELS]

def get_problems_by_nonlinearity_type(nonlinearity_type: str) -> list:
    """
    Find the problems whose ground truth includes a given nonlinearity type.

    Args:
        nonlinearity_type: Label to search for, e.g. 'bilinear'

    Returns:
        Problem names (in GROUND_TRUTH_LABELS order) whose label list
        contains ``nonlinearity_type``
    """
    return [
        problem
        for problem in GROUND_TRUTH_LABELS
        if nonlinearity_type in GROUND_TRUTH_LABELS[problem]
    ]


def debug_detection_parsing(detected_patterns: str) -> dict:
    """
    Report which diagnostic phrases occur in an LLM detection response.

    Matching is a plain, case-insensitive substring test per phrase.

    Args:
        detected_patterns: String containing detected patterns from LLM

    Returns:
        Dictionary with the original text ('original_text'), its lowercased
        form ('lowercase_text') and, under 'pattern_matches', a mapping from
        each category name to the list of its phrases found in the text
    """
    text_lc = detected_patterns.lower()

    # Category -> candidate phrases, kept in reporting order.
    phrase_table = {
        'bilinear': (
            'bilinear pattern', 'bilinear patterns', 'product of variables',
            'bilinear constraint', 'bilinear objective', 'bilinear function',
            'bilinear pattern 1:', 'bilinear pattern 2:', 'bilinear pattern 3:',
        ),
        'minmax': (
            'min/max pattern', 'min max pattern', 'minmax pattern',
            'min function', 'max function', 'min/max function',
            'minimum maximum', 'min max function', 'min/max patterns',
            'min max patterns', 'minmax patterns', 'min functions', 'max functions',
            'minmax pattern 1:', 'minmax pattern 2:', 'minmax pattern 3:',
            'min/max pattern 1:', 'min/max pattern 2:', 'min/max pattern 3:',
        ),
        'absolute_value': (
            'absolute value pattern', 'absolute value patterns', 'abs(',
            'absolute value constraint', 'absolute value objective',
            'absolute_value pattern 1:', 'absolute_value pattern 2:', 'absolute_value pattern 3:',
            'absolute value pattern 1:', 'absolute value pattern 2:', 'absolute value pattern 3:',
        ),
        'quotient': (
            'quotient pattern', 'quotient term', 'division pattern',
            'fraction pattern', 'ratio pattern', 'quotient function',
            'quotient constraint', 'quotient objective', 'quotient patterns',
            'quotient terms', 'division patterns', 'fraction patterns',
            'quotient pattern 1:', 'quotient pattern 2:', 'quotient pattern 3:',
        ),
        'monotone_transformation': (
            'monotone transformation pattern', 'monotone transformation patterns',
            'monotone constraint', 'monotone objective', 'monotone transformation function',
            'monotone_transformation pattern 1:', 'monotone_transformation pattern 2:', 'monotone_transformation pattern 3:',
            'monotone transformation pattern 1:', 'monotone transformation pattern 2:', 'monotone transformation pattern 3:',
        ),
        'quadratic': (
            'quadratic pattern', 'quadratic term', 'quadratic function',
            'quadratic constraint', 'quadratic objective', 'squared term',
        ),
        'linear': (
            'linear problem', 'no nonlinear', 'linear model', 'linear function',
            'linear constraint', 'linear objective', 'all linear',
        ),
    }

    return {
        'original_text': detected_patterns,
        'lowercase_text': text_lc,
        'pattern_matches': {
            category: [phrase for phrase in candidates if phrase in text_lc]
            for category, candidates in phrase_table.items()
        },
    }