import os
import json
import csv
import re
from pathlib import Path
import pandas as pd
import logging
import numpy as np

# Configure the root logger once, at import time. Records at INFO level and
# above are formatted as "<timestamp> - <level> - <message>" and sent to two
# handlers: a file named "gather_results.log" (a relative path, so it is
# created in the current working directory) and a console stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("gather_results.log"),
        logging.StreamHandler()
    ]
)

def extract_info_from_filename(filename):
    """
    Parse evaluation metadata out of a result file's name.

    The algorithm is identified by the first of the markers ``FARE``,
    ``LORE_base`` and ``LORE_scalar`` found anywhere in the name. The
    training epsilon is the number following the first (case-insensitive)
    occurrence of ``eps``; integer, decimal and scientific notation
    (``eps4``, ``eps0.5``, ``eps1e-3``) are all accepted.

    Args:
        filename (str): Name of the JSON file (no directory part needed)

    Returns:
        tuple: (algorithm, training_eps, is_clean, eval_params) where
            algorithm is "FARE", "LORE MLP", "LORE Scalar" or "Unknown",
            training_eps is a float or None when no epsilon is present,
            is_clean is True when the name starts with "clean_", and
            eval_params is a dict that holds 'aa_level' (int) for names
            starting with "adv2_" that contain an "aa_<digits>_" token.
    """
    is_clean = filename.startswith("clean_")
    is_adv = filename.startswith("adv2_")

    # Marker -> display name. Order matters: the first marker found wins.
    algorithm_markers = (
        ("FARE", "FARE"),
        ("LORE_base", "LORE MLP"),
        ("LORE_scalar", "LORE Scalar"),
    )
    algorithm = next(
        (label for marker, label in algorithm_markers if marker in filename),
        "Unknown",
    )

    # One pattern covers every epsilon spelling. A separate "_eps<digit>_"
    # fallback is unnecessary because any such name already matches here,
    # and the captured text is always a valid float literal.
    eps_match = re.search(
        r'eps(\d+(?:\.\d+)?(?:e-?\d+)?)', filename, re.IGNORECASE
    )
    training_eps = float(eps_match.group(1)) if eps_match else None

    # Evaluation parameters only exist for adversarial result files.
    eval_params = {}
    if is_adv:
        aa_match = re.search(r'aa_(\d+)_', filename)
        if aa_match:
            eval_params['aa_level'] = int(aa_match.group(1))

    return algorithm, training_eps, is_clean, eval_params

def gather_results(directory_path):
    """
    Process all JSON files in the directory and extract relevant information.

    Files are visited in sorted name order so that the returned list (and any
    CSV written from it) is reproducible across runs and file systems. A file
    that cannot be read, parsed or interpreted is logged and skipped; it
    never aborts the whole run.

    Args:
        directory_path (str): Path to directory containing JSON files

    Returns:
        list: One dict per successfully processed file with the keys
            'algorithm', 'training_eps', 'dataset', 'attack_type',
            'attack_eps', 'norm', 'acc1', 'is_clean', 'filename', plus any
            evaluation parameters parsed from the file name.
    """
    results = []

    # glob() yields entries in arbitrary order; sort for deterministic output.
    json_files = sorted(Path(directory_path).glob("*.json"))
    logging.info(f"Found {len(json_files)} JSON files in {directory_path}")

    for json_file in json_files:
        filename = json_file.name

        try:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Everything below uses dict access, so fail with a clear message
            # when the top-level value is a list, string, number, ...
            if not isinstance(data, dict):
                raise ValueError(
                    f"expected a JSON object at top level, got {type(data).__name__}"
                )

            algorithm, training_eps, is_clean, eval_params = extract_info_from_filename(filename)

            result = {
                'algorithm': algorithm,
                'training_eps': training_eps,
                'dataset': data.get('dataset', 'unknown'),
                'attack_type': data.get('attack', None),
                'attack_eps': data.get('eps', None),
                'norm': data.get('norm', None),
                'acc1': data.get('metrics', {}).get('acc1', None),
                'is_clean': is_clean,
                'filename': filename
            }

            # Add any additional evaluation parameters (e.g. 'aa_level')
            result.update(eval_params)
            results.append(result)

        except Exception as e:
            # Deliberately broad: this is the per-file boundary of a batch
            # job, so one bad file is reported and the rest still processed.
            logging.error(f"Error processing {filename}: {str(e)}")
            continue

    logging.info(f"Successfully processed {len(results)} of {len(json_files)} files")
    return results

def _summary_path(output_file, suffix):
    """Return *output_file* with *suffix* inserted before its extension."""
    # splitext only touches the final extension, so ".csv" inside a directory
    # name is left alone and non-".csv" outputs still get a distinct path.
    root, ext = os.path.splitext(os.fspath(output_file))
    return f"{root}{suffix}{ext or '.csv'}"


def _flatten_columns(columns):
    """Turn MultiIndex column tuples into strings, skipping empty levels."""
    flat = []
    for col in columns:
        if isinstance(col, tuple):
            # Index and average columns carry an empty second level; dropping
            # it yields "algorithm" rather than "algorithm_".
            flat.append("_".join(str(part) for part in col if str(part) != ""))
        else:
            flat.append(col)
    return flat


def _level_summary(rows, level_col, cast, avg_prefix, per_dataset=False):
    """Pivot acc1 by (dataset, level_col) and append average columns."""
    table = rows.pivot_table(
        index=['algorithm', 'training_eps'],
        columns=['dataset', level_col],
        values='acc1',
        aggfunc='mean'
    ).reset_index()

    # Only the pivoted value columns have a non-empty second level.
    value_cols = [col for col in table.columns
                  if isinstance(col, tuple) and
                  col[1] != '' and
                  pd.notna(col[1]) and
                  str(col[1]).strip()]

    # Compute every average from the original value columns before adding
    # any of them to the table, so averages never feed into each other.
    averages = {}

    by_level = {}
    for col in value_cols:
        by_level.setdefault(cast(col[1]), []).append(col)
    for level in sorted(by_level):
        averages[f'{avg_prefix}{level}'] = table[by_level[level]].mean(axis=1)

    if per_dataset:
        by_dataset = {}
        for col in value_cols:
            by_dataset.setdefault(col[0], []).append(col)
        for dataset in sorted(by_dataset):
            averages[f'avg_{dataset}'] = table[by_dataset[dataset]].mean(axis=1)

    averages['overall_average'] = table[value_cols].mean(axis=1)

    for name, series in averages.items():
        table[name] = series

    table.columns = _flatten_columns(table.columns)
    return table


def save_to_csv(results, output_file="results.csv"):
    """
    Save results to CSV file with additional average columns.

    Besides the detailed file, up to three summaries are written next to it,
    named by inserting a suffix before the extension of ``output_file``:

    * ``_summary_clean``: clean accuracy per dataset plus an ``average``
      column, one row per (algorithm, training_eps).
    * ``_summary_adv``: adversarial accuracy per (dataset, attack_eps) with
      per-epsilon, per-dataset and overall averages.
    * ``_summary_aa_levels``: adversarial accuracy per (dataset, aa_level)
      with per-level and overall averages.

    Summary column names are flattened to plain strings such as
    ``imagenet_2.0``; single-level names (``algorithm``, ``overall_average``)
    are written without a trailing underscore. Failures while building the
    summaries are logged and do not affect the detailed file.

    Args:
        results (list): List of dictionaries with extracted information
        output_file (str): Output CSV file path
    """
    # Convert to DataFrame for easier handling
    df = pd.DataFrame(results)

    if df.empty:
        logging.warning("No results to save")
        return

    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)

    # Save detailed results
    df.to_csv(output_file, index=False)
    logging.info(f"Detailed results saved to {output_file}")

    # Without the clean/adversarial flag no summary can be built.
    if 'is_clean' not in df.columns:
        return

    try:
        clean_rows = df[df['is_clean']]
        adv_rows = df[~df['is_clean']]

        # For clean evaluations
        if not clean_rows.empty:
            clean_df = clean_rows.pivot_table(
                index=['algorithm', 'training_eps'],
                columns=['dataset'],
                values='acc1',
                aggfunc='mean'
            ).reset_index()

            # Average over the dataset columns for each (algorithm, training_eps)
            dataset_cols = [col for col in clean_df.columns
                            if col not in ['algorithm', 'training_eps']]
            clean_df['average'] = clean_df[dataset_cols].mean(axis=1)

            clean_path = _summary_path(output_file, '_summary_clean')
            clean_df.to_csv(clean_path, index=False)
            logging.info(f"Clean summary saved to {clean_path}")

        # For adversarial evaluations
        if not adv_rows.empty:
            # Check availability on the adversarial rows themselves; values
            # that only exist on clean rows cannot be pivoted here.
            if 'attack_eps' in adv_rows.columns and adv_rows['attack_eps'].notna().any():
                adv_df = _level_summary(
                    adv_rows, 'attack_eps', float, 'avg_eps_', per_dataset=True
                )
                adv_path = _summary_path(output_file, '_summary_adv')
                adv_df.to_csv(adv_path, index=False)
                logging.info(f"Adversarial summary saved to {adv_path}")

            # Additional summary by aa_level if available
            if 'aa_level' in adv_rows.columns and adv_rows['aa_level'].notna().any():
                aa_df = _level_summary(adv_rows, 'aa_level', int, 'avg_aa_')
                aa_path = _summary_path(output_file, '_summary_aa_levels')
                aa_df.to_csv(aa_path, index=False)
                logging.info(f"AA level summary saved to {aa_path}")

    except Exception as e:
        # Summaries are best-effort: report with full traceback, keep going.
        logging.exception(f"Error creating summary files: {str(e)}")

def main():
    """Command-line entry point: read options, collect results, write CSVs."""
    import argparse

    cli = argparse.ArgumentParser(
        description='Gather results from JSON files and save to CSV'
    )
    cli.add_argument(
        '--input_dir',
        type=str,
        default="/YOUR_ROOT_PATH/CLIP_benchmark/results/ViT-B-32/large_eps_agr",
        help='Directory containing JSON files',
    )
    cli.add_argument(
        '--output_csv',
        type=str,
        default="/YOUR_ROOT_PATH/CLIP_benchmark/results_analysis.csv",
        help='Output CSV file path',
    )
    options = cli.parse_args()

    source_dir = options.input_dir
    logging.info(f"Starting to gather results from {source_dir}")

    # Collect the per-file records, then hand them straight to the CSV writer.
    collected = gather_results(source_dir)
    save_to_csv(collected, options.output_csv)

    logging.info("Finished processing results")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()