#!/usr/bin/env python3
"""
Comprehensive analysis of model performance improvements with fixed bugs.
Analyzes completeness, formatting, and reasoning improvements.
"""

import json
import re
import sys
from pathlib import Path
from collections import defaultdict
import numpy as np

# Import the existing parser
sys.path.append('.')
from parser import parse_answer

def is_likely_truncated_fixed(response):
    """Heuristically decide whether a model response was cut off.

    The response is stripped of surrounding whitespace and then flagged as
    truncated when any of the following holds:

    * it ends with an ellipsis (``...``);
    * it is longer than 500 characters, does not end with closing
      punctuation, and its last word does not end with a colon;
    * its last ``\\boxed{`` has no closing brace after it;
    * the text after the final period is longer than 20 characters and does
      not end with closing punctuation.

    Args:
        response: Response text to inspect.

    Returns:
        True if the response looks truncated, False otherwise.
    """
    # One shared definition of "closing punctuation" so the long-response
    # check and the last-sentence check agree (a trailing ']' -- e.g. the end
    # of a display-math block -- counts as a complete ending in both).
    complete_endings = ('.', ')', '}', ']', '!', '?')

    response = response.strip()

    # Explicit ellipsis at the very end.
    if response.endswith('...'):
        return True

    # Long responses that stop without closing punctuation.
    if len(response) > 500 and not response.endswith(complete_endings):
        words = response.split()
        last_word = words[-1] if words else ''
        # A dangling '\boxed{' is left to the dedicated check below; a
        # trailing colon is not treated as a mid-sentence cut here.
        if not response.endswith('\\boxed{') and not last_word.endswith(':'):
            return True

    # The last \boxed{ was opened but never closed.
    boxed_start = response.rfind('\\boxed{')
    if boxed_start != -1 and '}' not in response[boxed_start:]:
        return True

    # Text after the final period that is long and unterminated.
    tail = response.rsplit('.', 1)[-1].strip()
    if len(tail) > 20 and not tail.endswith(complete_endings):
        return True

    return False

def extract_answer_strict(response, dataset):
    """Extract an answer with the strict parser.

    Thin wrapper that forwards both arguments to ``parse_answer`` (imported
    from the ``parser`` module at the top of this file).

    Args:
        response: Model response text, passed through unchanged.
        dataset: Dataset identifier, passed through unchanged.

    Returns:
        Whatever ``parse_answer(response, dataset)`` returns. Callers in this
        file treat a falsy result as "no answer found"; the exact return type
        is defined by ``parse_answer`` and is not visible here.
    """
    return parse_answer(response, dataset)

def extract_answer_relaxed(response, dataset):
    """Extract an answer, falling back to lenient patterns when strict parsing fails.

    The strict parser is tried first and its result is returned when truthy.
    Otherwise the response is scanned with regular expressions:

    * For the multiple-choice datasets (``csqa``, ``gpqa``, ``mathqa``) the
      patterns are tried in order and the first match of the first matching
      pattern wins. If none matches, the last ``X)`` option mention is used.
    * For all other datasets the numeric patterns are tried in order and the
      last match of the first matching pattern wins.

    Args:
        response: Model response text.
        dataset: Dataset identifier; selects the multiple-choice or numeric
            pattern set.

    Returns:
        The extracted answer string, or None when nothing could be extracted.
    """
    strict_answer = extract_answer_strict(response, dataset)
    if strict_answer:
        return strict_answer

    if dataset in ('csqa', 'gpqa', 'mathqa'):
        # The trailing \b keeps a word that merely starts with A-E
        # ("The answer is Both ...") from being read as an option letter.
        letter_patterns = [
            r'\\boxed\{([A-E])\}',
            r'[Tt]he (?:correct )?answer is:?\s*([A-E])\b',
            r'[Aa]nswer:?\s*\(?([A-E])\b\)?',
            r'[Ss]o,?\s*(?:the )?(?:correct )?answer is:?\s*\(?([A-E])\b\)?',
            r'[Tt]herefore,?\s*(?:the )?(?:correct )?answer is:?\s*\(?([A-E])\b\)?',
        ]
        for pattern in letter_patterns:
            match = re.search(pattern, response)
            if match:
                return match.group(1)

        # Last resort: the final "X)" style option mention in the response.
        option_mentions = re.findall(r'\b([A-E])\)', response)
        if option_mentions:
            return option_mentions[-1]
        return None

    # (pattern, regex flags, trim trailing sentence punctuation?)
    numeric_patterns = [
        (r'\\boxed\{([^}]+)\}', 0, False),
        (r'[Tt]he (?:final )?answer is:?\s*([0-9,./-]+)', 0, True),
        (r'[Aa]nswer:?\s*([0-9,./-]+)', 0, True),
        # MULTILINE so '$' matches at the end of every line, not only at the
        # end of the whole response.
        (r'= ([0-9,./-]+)$', re.MULTILINE, True),
    ]
    for pattern, flags, trim in numeric_patterns:
        candidates = []
        for raw in re.findall(pattern, response, flags):
            value = raw.strip()
            if trim:
                # The numeric character class also swallows a sentence-ending
                # '.' or ',' ("The answer is 42."); drop it so the value can
                # compare equal to the ground truth.
                value = value.rstrip('.,')
            if value:
                candidates.append(value)
        if candidates:
            return candidates[-1]

    return None

def analyze_file_comprehensive(filepath, dataset_name):
    """Compute truncation, formatting and accuracy statistics for one results file.

    The file is expected to be JSON with an ``accuracy`` value and a
    ``results`` list whose entries carry ``response`` and ``ground_truth``
    (and optionally ``correct``, defaulting to False).

    Args:
        filepath: Path to the evaluation results JSON file.
        dataset_name: Dataset identifier forwarded to the answer extractors.

    Returns:
        None if ``filepath`` does not exist. Otherwise a dict with:

        * raw counts: ``total``, ``truncated_count``, ``format_failures``,
          ``strict_correct``, ``relaxed_correct``, ``truncated_recoverable``,
          ``format_recoverable``, ``recoverable_count``;
        * ``current_accuracy``: the file's ``accuracy`` value, unchanged;
        * percentages of ``total``: ``truncation_rate``,
          ``format_failure_rate``, ``strict_accuracy``, ``relaxed_accuracy``,
          ``truncation_loss``, ``format_loss``, ``total_recoverable``.

        All percentages are 0.0 when the file contains no results.
    """
    if not Path(filepath).exists():
        return None

    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)

    results = data['results']
    total = len(results)

    stats = {
        'total': total,
        'current_accuracy': data['accuracy'],
        'truncated_count': 0,
        'format_failures': 0,
        'strict_correct': 0,
        'relaxed_correct': 0,
        'truncated_recoverable': 0,
        'format_recoverable': 0,
        'recoverable_count': 0,
    }

    for result in results:
        response = result['response']
        ground_truth = result['ground_truth']
        is_correct = result.get('correct', False)

        is_truncated = is_likely_truncated_fixed(response)
        if is_truncated:
            stats['truncated_count'] += 1

        strict_answer = extract_answer_strict(response, dataset_name)
        relaxed_answer = extract_answer_relaxed(response, dataset_name)
        format_failed = not strict_answer
        if format_failed:
            stats['format_failures'] += 1

        if strict_answer == ground_truth:
            stats['strict_correct'] += 1
        relaxed_matches = relaxed_answer == ground_truth
        if relaxed_matches:
            stats['relaxed_correct'] += 1

        # Recoverable: recorded as wrong, yet relaxed extraction finds the
        # right answer in a truncated and/or strict-unparseable response.
        if not is_correct and relaxed_matches:
            if is_truncated:
                stats['truncated_recoverable'] += 1
            if format_failed:
                stats['format_recoverable'] += 1
            # Count each response once even when both causes apply, so the
            # combined figure is a true share of responses.
            if is_truncated or format_failed:
                stats['recoverable_count'] += 1

    def _percent(count):
        # Guard against an empty results list.
        return count / total * 100 if total else 0.0

    stats.update({
        'truncation_rate': _percent(stats['truncated_count']),
        'format_failure_rate': _percent(stats['format_failures']),
        'strict_accuracy': _percent(stats['strict_correct']),
        'relaxed_accuracy': _percent(stats['relaxed_correct']),
        'truncation_loss': _percent(stats['truncated_recoverable']),
        'format_loss': _percent(stats['format_recoverable']),
        'total_recoverable': _percent(stats['recoverable_count']),
    })

    return stats

def _print_model_stats(label, stats):
    """Print the accuracy/truncation/format/recoverable summary for one model."""
    print(f"  {label}:")
    print(f"    Accuracy: {stats['current_accuracy']*100:.1f}%")
    print(f"    Truncation: {stats['truncation_rate']:.1f}%")
    print(f"    Format failures: {stats['format_failure_rate']:.1f}%")
    print(f"    Recoverable loss: {stats['total_recoverable']:.1f}%")


def main():
    """Run the comprehensive analysis on every base/trained model pair.

    For each configured model and dataset, the base and trained result files
    are analyzed (pairs with a missing file are skipped), a per-dataset
    summary is printed, and finally a LaTeX table covering all analyzed pairs
    is written to ``fixed_comprehensive_analysis.tex`` in the current
    directory.
    """
    print("Comprehensive Model Performance Analysis (Fixed)")
    print("=" * 80)

    # Base vs. post-training result directories per model.
    model_configs = [
        {
            'name': 'Qwen2B',
            'base_dir': 'q2b_sa_runs/qwen2b_base_eval',
            'trained_dir': 'q2b_sa_runs/qwen2b_all',
        },
        {
            'name': 'Llama3B',
            'base_dir': 'l3b_sa_runs/llama3b_base_eval',
            'trained_dir': 'l3b_sa_runs/llama3b_all',
        },
    ]

    datasets = ['csqa', 'gpqa', 'gsm8k', 'math', 'mathqa', 'svamp', 'amc']

    all_results = defaultdict(dict)

    for config in model_configs:
        model_name = config['name']
        print(f"\n{model_name}")
        print("-" * 40)

        for dataset in datasets:
            base_file = f"{config['base_dir']}/evaluation_results_{dataset}.json"
            trained_file = f"{config['trained_dir']}/evaluation_results_{dataset}.json"

            # Only compare when both sides of the pair are present.
            if not (Path(base_file).exists() and Path(trained_file).exists()):
                continue

            base_stats = analyze_file_comprehensive(base_file, dataset)
            trained_stats = analyze_file_comprehensive(trained_file, dataset)
            if not (base_stats and trained_stats):
                continue

            all_results[model_name][dataset] = {
                'base': base_stats,
                'trained': trained_stats,
            }

            # current_accuracy is scaled by 100 for display, while the
            # recoverable figures are already percentages; hence the /100
            # before mixing the two.
            total_improvement = (
                trained_stats['current_accuracy'] - base_stats['current_accuracy']
            )
            truncation_gain = (
                base_stats['truncation_loss'] - trained_stats['truncation_loss']
            )
            format_gain = base_stats['format_loss'] - trained_stats['format_loss']
            recoverable_gain = (
                base_stats['total_recoverable'] - trained_stats['total_recoverable']
            )
            reasoning_gain = total_improvement - recoverable_gain / 100

            print(f"\n{dataset.upper()}:")
            _print_model_stats("Base model", base_stats)
            _print_model_stats("Trained model", trained_stats)

            print("  Improvements:")
            print(f"    Total: {total_improvement*100:+.1f}%")
            print(f"    From reduced truncation: {truncation_gain:+.1f}%")
            print(f"    From better formatting: {format_gain:+.1f}%")
            print(f"    From reasoning: {reasoning_gain*100:+.1f}%")

    # Generate LaTeX table
    print("\n" + "=" * 80)
    print("LATEX TABLE GENERATION")
    print("=" * 80)

    latex = generate_latex_table(all_results)

    with open('fixed_comprehensive_analysis.tex', 'w', encoding='utf-8') as f:
        f.write(latex)

    print("LaTeX table saved to fixed_comprehensive_analysis.tex")

def generate_latex_table(all_results):
    """Render the per-model, per-dataset statistics as a LaTeX table.

    Args:
        all_results: Mapping of model name -> dataset name ->
            ``{'base': stats, 'trained': stats}``. Each ``stats`` dict must
            provide ``current_accuracy`` (multiplied by 100 for display) and
            the percentage values ``truncation_rate``,
            ``format_failure_rate`` and ``total_recoverable``.

    Returns:
        The complete ``table`` environment as a string. Rows of one model are
        grouped together (the model name is shown on the first row only) and
        groups are separated by a mid rule.
    """
    header = r"""
\begin{table}[t]
\centering
\caption{Comprehensive Analysis of Model Performance Improvements (Fixed). We decompose accuracy gains into contributions from reduced truncation, improved formatting compliance, and enhanced reasoning. Truncation Rate shows the percentage of responses that appear incomplete. Format Failure Rate shows responses that don't follow the required \texttt{\textbackslash boxed\{\}} format. The Reasoning Improvement represents gains from better problem-solving after accounting for truncation and formatting issues.}
\label{tab:comprehensive_analysis_fixed}
\footnotesize
\begin{tabular}{ll|cccc|cccc|ccc}
\toprule
\multirow{2}{*}{\textbf{Model}} & \multirow{2}{*}{\textbf{Dataset}} & \multicolumn{4}{c}{\textbf{Base Model}} & \multicolumn{4}{c}{\textbf{Post-Training}} & \multicolumn{3}{c}{\textbf{Improvements}} \\
\cmidrule(lr){3-6} \cmidrule(lr){7-10} \cmidrule(lr){11-13}
& & Acc & Trunc & Fmt & Rec & Acc & Trunc & Fmt & Rec & Total & Form+Comp & Reasoning \\
& & (\%) & (\%) & (\%) & (\%) & (\%) & (\%) & (\%) & (\%) & (\%) & (\%) & (\%) \\
\midrule
"""

    # The footer starts directly with the bottom rule: a blank line between
    # the last row terminator and a rule command inside tabular is a LaTeX
    # error ("Misplaced \noalign").
    footer = r"""\bottomrule
\end{tabular}
\vspace{-0.1cm}
\begin{flushleft}
\footnotesize
\textbf{Note:} Acc = Accuracy, Trunc = Truncation rate, Fmt = Format failure rate, Rec = Recoverable loss (from both truncation and formatting), Form+Comp = Improvement from better formatting and completeness, Reasoning = Pure reasoning improvement after accounting for technical issues.
\end{flushleft}
\end{table}
"""

    groups = []
    for model_name, model_data in all_results.items():
        rows = []
        for row_index, (dataset, stats) in enumerate(model_data.items()):
            base = stats['base']
            trained = stats['trained']

            total_imp = trained['current_accuracy'] - base['current_accuracy']
            format_comp_imp = base['total_recoverable'] - trained['total_recoverable']
            # format_comp_imp is a percentage; bring it back to accuracy units.
            reasoning_imp = total_imp - format_comp_imp / 100

            cells = [
                model_name if row_index == 0 else "",
                dataset.upper(),
                f"{base['current_accuracy']*100:.1f}",
                f"{base['truncation_rate']:.1f}",
                f"{base['format_failure_rate']:.1f}",
                f"{base['total_recoverable']:.1f}",
                f"{trained['current_accuracy']*100:.1f}",
                f"{trained['truncation_rate']:.1f}",
                f"{trained['format_failure_rate']:.1f}",
                f"{trained['total_recoverable']:.1f}",
                f"{total_imp*100:+.1f}",
                f"{format_comp_imp:+.1f}",
                f"{reasoning_imp*100:+.1f}",
            ]
            rows.append(" & ".join(cells) + " \\\\\n")
        groups.append("".join(rows))

    # A mid rule between consecutive model groups, none after the last one.
    return header + "\\midrule\n".join(groups) + footer

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()