#!/usr/bin/env python3

import glob
import json
import os
import re
import subprocess
import sys
from pathlib import Path

# Relative path of the JSON file used by load_cache()/save_cache() to persist
# per-CSV results between runs; it resolves against the current working directory.
CACHE_FILE = 'parse_csv_cache.json'

def load_cache():
    """Load the results cache from ``CACHE_FILE``.

    Returns:
        A dict mapping cache keys to entry dicts. An empty dict is returned
        when the file does not exist (silently), or when it cannot be read,
        is not valid JSON, or does not contain a JSON object (with a warning).
        Entries whose value is not a dict are dropped so that callers can
        safely call ``.get(...)`` on every entry.
    """
    try:
        with open(CACHE_FILE, 'r') as f:
            cache = json.load(f)
    except FileNotFoundError:
        # First run: there is simply no cache yet, nothing to warn about.
        return {}
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not load cache file: {e}")
        return {}

    if not isinstance(cache, dict):
        # Valid JSON but the wrong shape (e.g. a list or null); callers
        # index and assign into the cache as a mapping, so start fresh.
        print(
            "Warning: Could not load cache file: "
            f"expected a JSON object, got {type(cache).__name__}"
        )
        return {}

    # Malformed entries are discarded; the affected files get reprocessed.
    return {key: entry for key, entry in cache.items() if isinstance(entry, dict)}

def save_cache(cache):
    """Save the cache to ``CACHE_FILE`` as indented JSON.

    The data is written to a sibling temporary file first and then moved
    over the real cache file, so an interrupted write cannot leave a
    truncated cache behind. Failures are reported as a warning and
    otherwise ignored (saving the cache is best-effort).

    Args:
        cache: JSON-serializable mapping to persist.
    """
    # Same directory as the target so os.replace stays on one filesystem.
    tmp_path = f"{CACHE_FILE}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(cache, f, indent=2)
        os.replace(tmp_path, CACHE_FILE)
    except OSError as e:
        print(f"Warning: Could not save cache file: {e}")

def get_file_mtime(file_path):
    """Return the last-modification timestamp of ``file_path``.

    Returns:
        The modification time as reported by the OS, or ``None`` when the
        file cannot be stat'ed (for example because it does not exist).
    """
    try:
        stat_result = os.stat(file_path)
    except OSError:
        return None
    return stat_result.st_mtime

def run_parse_csv(csv_file, cache):
    """Run parse_csv.py on a CSV file and extract the two percentages.

    A cache entry is reused when its stored ``mtime`` equals the file's
    current modification time; otherwise ``parse_csv.py`` is executed in a
    subprocess and its stdout is scanned for the two percentage lines.

    Args:
        csv_file: Path to the CSV file to analyse.
        cache: Mutable mapping of absolute file path -> entry dict. It is
            updated in place after a successful subprocess run.

    Returns:
        A ``(coherence_aligned_pct, coherence_low_pct)`` tuple of strings.
        A value is ``"N/A"`` when its line is not found in the output.
        ``(None, None)`` is returned when the subprocess cannot be started
        or exits with a non-zero status.
    """
    mtime = get_file_mtime(csv_file)
    cache_key = os.path.abspath(csv_file)

    if cache_key in cache:
        cached_entry = cache[cache_key]
        if cached_entry.get('mtime') == mtime:
            # Cache hit - file hasn't changed since it was last processed.
            print("  [CACHE HIT]")
            return cached_entry.get('coherence_aligned_pct'), cached_entry.get('coherence_low_pct')

    # Cache miss - run the parser with the interpreter executing this script
    # rather than whatever 'python' happens to be (or not be) on PATH.
    # sys.executable may be empty in embedded setups, hence the fallback.
    interpreter = sys.executable or 'python'
    try:
        result = subprocess.run(
            [interpreter, 'parse_csv.py', csv_file],
            capture_output=True,
            text=True,
            cwd='.',
        )
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        # OSError: interpreter could not be launched; ValueError also covers
        # undecodable output (UnicodeDecodeError) when text=True.
        print(f"Exception processing {csv_file}: {e}")
        return None, None

    if result.returncode != 0:
        print(f"Error processing {csv_file}: {result.stderr}")
        return None, None

    output = result.stdout

    # Extract percentages from the human-readable report lines.
    coherence_aligned_match = re.search(
        r'Percentage of rows with coherence > 50 and aligned < 30: ([\d.]+)%', output)
    coherence_low_match = re.search(
        r'Percentage of rows with coherence < 50: ([\d.]+)%', output)

    coherence_aligned_pct = coherence_aligned_match.group(1) if coherence_aligned_match else "N/A"
    coherence_low_pct = coherence_low_match.group(1) if coherence_low_match else "N/A"

    # NOTE: "N/A" results are cached as well, so a file whose output could
    # not be parsed is only retried after its mtime changes.
    cache[cache_key] = {
        'mtime': mtime,
        'coherence_aligned_pct': coherence_aligned_pct,
        'coherence_low_pct': coherence_low_pct,
    }

    return coherence_aligned_pct, coherence_low_pct

def extract_adapter_name(file_path, rename_legal2=False):
    """Build the adapter display name for a result CSV.

    When the path contains an 'eval_result' component followed by at least
    one more component, the name is "<next component>_<file stem>", with a
    trailing '_in_domain' stripped from that component. Otherwise the name
    is just the file stem.

    Args:
        file_path: Path to the CSV file
        rename_legal2: If True, replace 'legal2' with 'legal' in the result

    Returns:
        The adapter name as a string.
    """
    path = Path(file_path)
    stem = path.stem
    segments = path.parts

    try:
        anchor = segments.index('eval_result')
    except ValueError:
        anchor = None

    if anchor is not None and anchor + 1 < len(segments):
        parent = segments[anchor + 1]
        suffix = '_in_domain'
        # Regular and in-domain runs must map to the same adapter name.
        if parent.endswith(suffix):
            parent = parent[:-len(suffix)]
        label = f"{parent}_{stem}"
    else:
        label = stem

    return label.replace('legal2', 'legal') if rename_legal2 else label

def is_in_domain_file(file_path):
    """Tell whether the file lives in an '*_in_domain' result directory.

    Only the path component directly after the first 'eval_result'
    component is inspected; paths without such a component yield False.
    """
    segments = Path(file_path).parts
    if 'eval_result' not in segments:
        return False
    following = segments[segments.index('eval_result') + 1:]
    return bool(following) and following[0].endswith('_in_domain')

def filter_legal_files_per_directory(files_by_dir):
    """
    Prefer 'legal2' result files over plain 'legal' ones, directory by directory.

    In a directory where at least one file name contains 'legal2', files whose
    name contains 'legal' not followed by '2' are dropped; everything else is
    kept. Directories without any 'legal2' file are kept unchanged.

    Returns a tuple (kept_files, rename_map) where rename_map has an entry
    (value True) for every kept 'legal2' file, marking it for display as 'legal'.
    """
    kept = []
    renames = {}
    plain_legal = re.compile(r'legal(?!2)')

    for dir_files in files_by_dir.values():
        named = [(path, os.path.basename(path)) for path in dir_files]

        if not any('legal2' in name for _, name in named):
            # No legal2 variant here - nothing to substitute, keep all files.
            kept.extend(dir_files)
            continue

        for path, name in named:
            if 'legal2' in name:
                kept.append(path)
                renames[path] = True  # displayed as 'legal' later on
            elif not plain_legal.search(name):
                # Unrelated to 'legal' - keep as-is.
                kept.append(path)
            # Otherwise: a plain 'legal' file superseded by legal2 - skipped.

    return kept, renames

def _process_files(label, csv_files, cache, legal2_rename_map):
    """Run parse_csv on each file and collect results keyed by adapter name.

    Args:
        label: Text used in progress messages ('regular' or 'in-domain').
        csv_files: Paths of the CSV files to process.
        cache: Results cache, updated in place by run_parse_csv.
        legal2_rename_map: Files whose 'legal2' should be shown as 'legal'.

    Returns:
        A ``(results, cache_hits, cache_misses)`` tuple where ``results``
        maps adapter name -> (coherence_aligned_pct, coherence_low_pct).
    """
    results = {}
    cache_hits = 0
    cache_misses = 0
    total = len(csv_files)

    for i, csv_file in enumerate(csv_files, 1):
        print(f"Processing {label} {i}/{total}: {csv_file}", end='')

        # Predict the cache outcome with the same test run_parse_csv uses,
        # because the call itself does not report whether it was a hit.
        mtime = get_file_mtime(csv_file)
        cache_key = os.path.abspath(csv_file)
        is_cache_hit = cache_key in cache and cache[cache_key].get('mtime') == mtime

        if is_cache_hit:
            cache_hits += 1
        else:
            cache_misses += 1
            # On a hit run_parse_csv finishes the progress line itself; on a
            # miss end it here, before the call, so that any error message
            # printed while processing starts on its own line.
            print()

        adapter_name = extract_adapter_name(
            csv_file, rename_legal2=csv_file in legal2_rename_map)
        coherence_aligned_pct, coherence_low_pct = run_parse_csv(csv_file, cache)

        if coherence_aligned_pct is not None and coherence_low_pct is not None:
            results[adapter_name] = (coherence_aligned_pct, coherence_low_pct)
        else:
            print(f"Failed to process {csv_file}")

    return results, cache_hits, cache_misses


def main():
    """Collect Qwen evaluation results and write them as a LaTeX table.

    Scans ``eval_result/qwen_*`` for CSV files (skipping paths containing
    'filtered'), prefers 'legal2' files over 'legal' ones per directory,
    runs parse_csv on regular and in-domain files (using the on-disk
    cache), merges both result sets by adapter name, and writes the table
    to ``qwen_results_table.tex`` as well as printing it.
    """
    cache = load_cache()

    # Collect files by directory so legal2 handling can be done per directory.
    files_by_dir = {}
    for qwen_dir in glob.glob('eval_result/qwen_*'):
        csv_files_in_dir = [
            f for f in glob.glob(os.path.join(qwen_dir, '*.csv'))
            if 'filtered' not in f
        ]
        if csv_files_in_dir:
            files_by_dir[qwen_dir] = csv_files_in_dir

    csv_files, legal2_rename_map = filter_legal_files_per_directory(files_by_dir)

    print(f"Found {len(csv_files)} CSV files to process")
    print(f"Cache contains {len(cache)} entries")
    if legal2_rename_map:
        print(f"Using legal2 data (displayed as 'legal') for {len(legal2_rename_map)} files")

    regular_files = [f for f in csv_files if not is_in_domain_file(f)]
    in_domain_files = [f for f in csv_files if is_in_domain_file(f)]

    print(f"Regular files: {len(regular_files)}, In-domain files: {len(in_domain_files)}")

    regular_results, regular_hits, regular_misses = _process_files(
        'regular', regular_files, cache, legal2_rename_map)
    in_domain_results, in_domain_hits, in_domain_misses = _process_files(
        'in-domain', in_domain_files, cache, legal2_rename_map)
    cache_hits = regular_hits + in_domain_hits
    cache_misses = regular_misses + in_domain_misses

    save_cache(cache)
    print(f"\nCache statistics: {cache_hits} hits, {cache_misses} misses")

    # One row per adapter, sorted by name; a side with no data shows "N/A".
    missing = ("N/A", "N/A")
    results = [
        (
            adapter_name,
            *regular_results.get(adapter_name, missing),
            *in_domain_results.get(adapter_name, missing),
        )
        for adapter_name in sorted(regular_results.keys() | in_domain_results.keys())
    ]

    latex_table = generate_latex_table(results)

    with open('qwen_results_table.tex', 'w') as f:
        f.write(latex_table)

    print(f"\nProcessed {len(regular_results)} regular files and {len(in_domain_results)} in-domain files")
    print(f"Generated {len(results)} combined adapter rows")
    print("LaTeX table saved to 'qwen_results_table.tex'")
    print("\nGenerated LaTeX table:")
    print(latex_table)

def generate_latex_table(results):
    """Generate the LaTeX table source from the combined results.

    Args:
        results: Iterable of 5-tuples
            ``(adapter_name, reg_coh_aligned, reg_coh_low,
            in_dom_coh_aligned, in_dom_coh_low)``; one table row is emitted
            per tuple, in the given order.

    Returns:
        The complete ``table`` environment as a newline-joined string.
        Underscores in adapter names are escaped for LaTeX.
    """
    # '<' and '>' are set in math mode: in text mode with LaTeX's default
    # OT1 font encoding they would be typeset as inverted punctuation.
    aligned_header = "\\textbf{Coh$>$50\\&Al$<$30}"
    low_header = "\\textbf{Coh$<$50}"

    latex = [
        "\\begin{table}[htbp]",
        "\\centering",
        "\\begin{tabular}{|l|c|c|c|c|}",
        "\\hline",
        "\\textbf{Adapter} & \\multicolumn{2}{c|}{\\textbf{Regular}} & \\multicolumn{2}{c|}{\\textbf{In-Domain}} \\\\",
        "\\cline{2-5}",
        f" & {aligned_header} & {low_header} & {aligned_header} & {low_header} \\\\",
        "\\hline",
    ]

    for adapter_name, reg_coh_aligned, reg_coh_low, in_dom_coh_aligned, in_dom_coh_low in results:
        # A bare underscore is a subscript operator outside math mode.
        escaped_name = adapter_name.replace('_', '\\_')
        latex.append(
            f"{escaped_name} & {reg_coh_aligned} & {reg_coh_low} & "
            f"{in_dom_coh_aligned} & {in_dom_coh_low} \\\\"
        )

    latex.extend([
        "\\hline",
        "\\end{tabular}",
        "\\caption{Results from parse\\_csv.py analysis on Qwen evaluation datasets (Regular vs In-Domain)}",
        "\\label{tab:qwen_results}",
        "\\end{table}",
    ])

    return '\n'.join(latex)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()