"""
Experiment Results Aggregation Utility

This module provides functions to discover and aggregate experimental results
from the hierarchical folder structure used in the lora_adv project.

The expected folder structure is:
{machine}/{experiment_type}/{training_strategy}/{model}_{dataset}_epochs_{epochs}_seed_{seed}/
    model_{model}_ds_{dataset}_train_epoch_{epochs}_run_{timestamp}_seed_{seed}/
        ├── experiment_config.json
        ├── A2T_results/            (A2T attack summary: attack_summary.csv)
        └── adv_attack_results/     (TextFooler attack result CSVs)

Author: Generated for lora_adv project
Date: September 2025
"""

import json
import pandas as pd
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Union
import logging
import re
from dataclasses import dataclass
from datetime import datetime

# Set up logging
# NOTE(review): basicConfig() runs at import time, so merely importing this
# module configures the root logger (INFO level, if it has no handlers yet)
# for the whole program; consider moving it under an
# `if __name__ == "__main__":` guard.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ExperimentInfo:
    """
    Metadata describing a single experiment run folder.

    Instances are built by ``extract_experiment_metadata`` from the folder
    path and its ``experiment_config.json``.
    """
    # --- Location ---
    path: Path                   # experiment folder on disk
    machine_name: str
    # --- Training setup ---
    training_strategy: str       # 'full', 'lora', 'head_only' or 'unknown'
    model_name: str
    dataset_name: str
    num_epochs: int
    seed: int
    # --- Attack results present in the folder ---
    has_a2t: bool                # an A2T_results/ sub-directory exists
    has_textfooler: bool         # an adv_attack_results/ sub-directory exists
    # --- Raw config and training outcomes ---
    experiment_config: Dict      # parsed experiment_config.json contents
    learning_rate: float
    val_accuracy: float
    test_accuracy: float
    

def find_experiment_folders(root_path: Union[str, Path],
                           require_both_attacks: bool = True) -> List[Path]:
    """
    Locate experiment run directories that hold adversarial-attack results.

    A directory qualifies when it contains an ``experiment_config.json`` file
    together with an ``A2T_results/`` and/or ``adv_attack_results/``
    sub-directory.

    Args:
        root_path: Directory that is searched recursively.
        require_both_attacks: When True, keep only directories that have both
            attack sub-directories; when False, either one is enough.

    Returns:
        Sorted list of qualifying directories (the parents of each
        ``experiment_config.json``).
    """
    base = Path(root_path)
    logger.info(f"Searching for experiment folders in: {base}")

    # all() demands both attack directories; any() settles for at least one.
    combine = all if require_both_attacks else any

    matches = [
        cfg.parent
        for cfg in base.rglob('experiment_config.json')
        if combine((
            (cfg.parent / "A2T_results").is_dir(),
            (cfg.parent / "adv_attack_results").is_dir(),
        ))
    ]

    logger.info(f"Found {len(matches)} experiment folders")
    return sorted(matches)


def extract_training_strategy(experiment_path: Path,
                              known_strategies: Tuple[str, ...] = ('full', 'lora', 'head_only'),
                              ) -> Dict[str, str]:
    """
    Extract the training strategy from an experiment folder's path.

    Expected structure:
    .../machine/experiment_type/training_strategy/model_dataset_config/model_dir/

    Path components are scanned from the innermost one (the experiment folder)
    outwards, and the first component that exactly equals one of
    ``known_strategies`` wins. Scanning outwards means an unrelated ancestor
    directory that happens to share a strategy name (for example a ``full``
    directory above the results root) cannot shadow the real strategy
    directory, which sits close to the experiment folder.

    Args:
        experiment_path: Path (or path string) to the experiment folder.
        known_strategies: Directory names recognised as training strategies.

    Returns:
        Dict with the single key ``'training_strategy'``. Its value is the
        matched directory name, or ``'unknown'`` when no component matches.
    """
    strategy = next(
        (part for part in reversed(Path(experiment_path).parts)
         if part in known_strategies),
        'unknown',
    )
    return {'training_strategy': strategy}


def extract_experiment_metadata(experiment_path: Path) -> Optional[ExperimentInfo]:
    """
    Build an ExperimentInfo for one experiment folder.

    Reads ``experiment_config.json`` from the folder. The training
    configuration comes from its ``args`` section and the validation/test
    accuracy from its ``results`` section. The training strategy is derived
    from the folder path, and the presence of the ``A2T_results/`` and
    ``adv_attack_results/`` sub-directories is recorded.

    Missing config fields fall back to sentinels:
    - ``'unknown'`` for the model and dataset names;
    - ``-1`` for epochs, seed, learning rate and accuracies.

    A config whose ``results`` (or ``val_metrics`` / ``test_metrics``) section
    is missing or JSON ``null`` still yields an ExperimentInfo, with ``-1``
    accuracies, rather than being dropped.

    Args:
        experiment_path: Path (or path string) to the experiment folder.

    Returns:
        ExperimentInfo, or None if the config file is missing or the folder
        cannot be processed (the failure is logged).
    """
    try:
        experiment_path = Path(experiment_path)
        training_strategy = extract_training_strategy(experiment_path)['training_strategy']

        # The machine name is not derived from the path at present; every
        # experiment is reported with machine_name='unknown'.
        machine_name = 'unknown'

        config_path = experiment_path / "experiment_config.json"
        if not config_path.exists():
            logger.warning(f"No experiment_config.json found in {experiment_path}")
            return None

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # `or {}` covers both absent sections and sections stored as JSON null,
        # so a partially filled config does not abort the whole extraction.
        args = config.get('args') or {}
        results = config.get('results') or {}
        val_metrics = results.get('val_metrics') or {}
        test_metrics = results.get('test_metrics') or {}

        return ExperimentInfo(
            path=experiment_path,
            machine_name=machine_name,
            training_strategy=training_strategy,
            model_name=args.get('model_name', 'unknown'),
            dataset_name=args.get('dataset_name', 'unknown'),
            num_epochs=args.get('num_epochs', -1),
            seed=args.get('seed', -1),
            learning_rate=args.get('learning_rate', -1),
            has_a2t=(experiment_path / "A2T_results").is_dir(),
            has_textfooler=(experiment_path / "adv_attack_results").is_dir(),
            experiment_config=config,
            val_accuracy=val_metrics.get('val_accuracy', -1),
            test_accuracy=test_metrics.get('test_accuracy', -1),
        )

    except Exception as e:
        logger.error(f"Failed to extract metadata from {experiment_path}: {e}")
        return None


def aggregate_a2t_results(experiment_path: Path) -> List[Dict]:
    """
    Read the A2T attack summary of one experiment into flat result records.

    Each row of ``A2T_results/attack_summary.csv`` becomes one dict with
    ``attack_type='a2t'``. The summary columns are copied under an ``a2t_``
    prefix, and any column absent from the CSV is filled with ``''``.

    Args:
        experiment_path: Path to experiment folder

    Returns:
        One dict per CSV row. The list is empty if the directory or summary
        file is missing. If reading fails, a warning is logged and whatever
        rows were collected before the failure are returned.
    """
    # Summary columns copied verbatim; each one lands under "a2t_<column>".
    summary_columns = (
        'timestamp', 'dataset', 'model_name', 'random_seed', 'query_budget',
        'max_words_changed', 'num_samples', 'attack_name', 'original_accuracy',
        'accuracy_under_attack', 'attack_success_rate', 'avg_num_queries',
        'avg_pct_perturbed',
    )

    a2t_dir = experiment_path / "A2T_results"
    if not a2t_dir.exists():
        return []

    summary_file = a2t_dir / "attack_summary.csv"
    if not summary_file.exists():
        return []

    records = []
    try:
        table = pd.read_csv(summary_file)
        for _, row in table.iterrows():
            record = {'attack_type': 'a2t'}
            record.update({f'a2t_{col}': row.get(col, '') for col in summary_columns})
            records.append(record)
    except Exception as e:
        logger.warning(f"Failed to read A2T results from {summary_file}: {e}")

    return records


def aggregate_textfooler_results(experiment_path: Path) -> List[Dict]:
    """
    Extract and standardize TextFooler attack results.

    Every ``*.csv`` file found recursively under ``adv_attack_results/`` is
    read. Each of its rows becomes one dict with
    ``attack_type='textfooler'``, and the known columns are copied under a
    ``textfooler_`` prefix (``''`` when a column is absent).

    A path component of the CSV may match
    ``max_attack_changes_<X>_attack_sample_size_<Y>_seed_<Z>``. If it does,
    the integers X, Y and Z are added as ``textfooler_max_attack_changes``,
    ``textfooler_attack_sample_size`` and ``textfooler_config_seed``.

    Files are processed in sorted path order, so the output row order does not
    depend on the filesystem's directory-listing order.

    Args:
        experiment_path: Path (or path string) to experiment folder

    Returns:
        List of result dicts; empty if ``adv_attack_results/`` is missing.
        Files that cannot be read are skipped with a warning.
    """
    results = []
    textfooler_dir = Path(experiment_path) / "adv_attack_results"

    if not textfooler_dir.exists():
        return results

    # Compiled once rather than re-parsed for every path component.
    config_pattern = re.compile(
        r'max_attack_changes_(\d+)_attack_sample_size_(\d+)_seed_(\d+)'
    )

    # rglob() yields entries in filesystem order, which varies across machines;
    # sorting keeps the aggregated row order reproducible.
    for csv_file in sorted(textfooler_dir.rglob("*.csv")):
        try:
            df = pd.read_csv(csv_file)

            # Attack configuration encoded in directory names; a later
            # (deeper) matching component overrides an earlier one.
            config_info = {}
            for part in csv_file.parts:
                match = config_pattern.search(part)
                if match:
                    changes, sample_size, config_seed = (int(g) for g in match.groups())
                    config_info = {
                        'textfooler_max_attack_changes': changes,
                        'textfooler_attack_sample_size': sample_size,
                        'textfooler_config_seed': config_seed,
                    }

            for _, row in df.iterrows():
                results.append({
                    'attack_type': 'textfooler',
                    'textfooler_budget': row.get('budget', ''),
                    'textfooler_orig_accuracy': row.get('orig_accuracy', ''),
                    'textfooler_adv_accuracy': row.get('adv_accuracy', ''),
                    'textfooler_attack_success_rate': row.get('attack_success_rate', ''),
                    'textfooler_avg_changed_rate': row.get('avg_changed_rate', ''),
                    'textfooler_avg_queries': row.get('avg_queries', ''),
                    'textfooler_num_successful_attacks': row.get('num_successful_attacks', ''),
                    'textfooler_total_samples': row.get('total_samples', ''),
                    'textfooler_model_name': row.get('model_name', ''),
                    'textfooler_dataset_name': row.get('dataset_name', ''),
                    'textfooler_max_budget': row.get('max_budget', ''),
                    'textfooler_seed': row.get('seed', ''),
                    **config_info,  # path-derived attack configuration
                })

        except Exception as e:
            logger.warning(f"Failed to read TextFooler results from {csv_file}: {e}")

    return results


def _build_base_row(exp_info: "ExperimentInfo") -> Dict:
    """Flatten one experiment's training config and metrics into a row dict."""
    config = exp_info.experiment_config
    # `or {}` tolerates config sections that are missing or stored as JSON null.
    args = config.get('args') or {}
    model_info = config.get('model_info') or {}
    results = config.get('results') or {}

    base_data = {
        # Path/System Info
        'experiment_path': str(exp_info.path),
        'machine_name': exp_info.machine_name,
        'training_strategy': exp_info.training_strategy,

        # Basic Training Config
        'model_name': exp_info.model_name,
        'dataset_name': exp_info.dataset_name,
        'num_epochs': exp_info.num_epochs,
        'seed': exp_info.seed,
        'learning_rate': args.get('learning_rate', ''),
        'batch_size': args.get('batch_size', ''),
        'max_seq_length': args.get('max_seq_length', ''),

        # Model Info
        'total_params': model_info.get('total_params', ''),
        'trainable_params': model_info.get('trainable_params', ''),

        # Training Results
        'train_accuracy': (results.get('train_metrics') or {}).get('train_accuracy', ''),
        'val_accuracy': (results.get('val_metrics') or {}).get('val_accuracy', ''),
        'test_accuracy': (results.get('test_metrics') or {}).get('test_accuracy', ''),

        # Training Strategy Specific
        'freeze_base': args.get('freeze_base', ''),
    }

    # LoRA-specific parameters, only when the run recorded a LoRA rank.
    if 'lora_r' in args:
        base_data.update({
            'lora_r': args.get('lora_r', ''),
            'lora_alpha': args.get('lora_alpha', ''),
            'lora_dropout': args.get('lora_dropout', ''),
            'lora_target_modules': args.get('lora_target_modules', ''),
            'lora_init_type': args.get('lora_init_type', 'default'),
            'lora_init_scale': args.get('lora_init_scale', ''),
        })

    # Derived metric. Compared against the "missing" markers rather than by
    # truthiness, so a genuine trainable_params of 0 still gets a ratio of 0.
    total_params = base_data['total_params']
    trainable_params = base_data['trainable_params']
    if total_params not in ('', None) and trainable_params not in ('', None):
        try:
            total = float(total_params)
            trainable = float(trainable_params)
            base_data['trainable_ratio'] = trainable / total if total > 0 else 0
        except (ValueError, TypeError):
            base_data['trainable_ratio'] = ''

    return base_data


def create_comprehensive_dataframe(root_path: Union[str, Path],
                                  require_both_attacks: bool = False) -> pd.DataFrame:
    """
    Create a comprehensive DataFrame with all experiment results.

    Each row represents one attack configuration on one model; it combines
    the experiment's training metadata with one A2T or TextFooler result.

    An experiment may have attack directories that produce no result rows,
    for example because the CSVs are missing or unreadable. When
    ``require_both_attacks`` is False, such an experiment is kept as a single
    training-only row instead of being silently dropped.

    Args:
        root_path: Root directory to search for experiments
        require_both_attacks: If True, only include experiments with both
            attack types (and drop experiments that yield no attack rows)

    Returns:
        DataFrame with comprehensive experiment and attack results; an empty
        DataFrame if nothing was found.
    """
    root_path = Path(root_path)
    all_results = []

    experiment_folders = find_experiment_folders(root_path, require_both_attacks)
    logger.info(f"Processing {len(experiment_folders)} experiment folders...")

    for exp_path in experiment_folders:
        exp_info = extract_experiment_metadata(exp_path)
        if exp_info is None:
            continue

        base_data = _build_base_row(exp_info)

        attack_rows = []
        if exp_info.has_a2t:
            attack_rows.extend(aggregate_a2t_results(exp_info.path))
        if exp_info.has_textfooler:
            attack_rows.extend(aggregate_textfooler_results(exp_info.path))

        if attack_rows:
            all_results.extend({**base_data, **attack_result} for attack_result in attack_rows)
        elif not require_both_attacks:
            # Attack directories exist (find_experiment_folders guarantees at
            # least one) but yielded no rows; keep the training data visible.
            all_results.append(base_data)

    if not all_results:
        logger.warning("No results found")
        return pd.DataFrame()

    df = pd.DataFrame(all_results)
    logger.info(f"Created DataFrame with {len(df)} rows and {len(df.columns)} columns")
    return df


def get_experiment_summary(root_path: Union[str, Path]) -> Dict:
    """
    Summarise the experiments found under a directory.

    Every folder returned by ``find_experiment_folders`` (with at least one
    attack type present) counts towards ``total_experiments``. Folders whose
    metadata cannot be extracted are left out of every other field.

    Args:
        root_path: Root directory to analyze

    Returns:
        Dict containing:
        - ``total_experiments``;
        - sorted lists of the distinct ``machines``, ``training_strategies``,
          ``models`` and ``datasets``;
        - the number of experiments with A2T results, with TextFooler
          results, and with both.
    """
    folders = find_experiment_folders(Path(root_path), require_both_attacks=False)

    infos = []
    for folder in folders:
        info = extract_experiment_metadata(folder)
        if info:
            infos.append(info)

    def distinct(attr):
        # Sorted (JSON-serialisable) list of the unique values of one attribute.
        return sorted({getattr(item, attr) for item in infos})

    return {
        'total_experiments': len(folders),
        'machines': distinct('machine_name'),
        'training_strategies': distinct('training_strategy'),
        'models': distinct('model_name'),
        'datasets': distinct('dataset_name'),
        'experiments_with_a2t': sum(1 for item in infos if item.has_a2t),
        'experiments_with_textfooler': sum(1 for item in infos if item.has_textfooler),
        'experiments_with_both': sum(
            1 for item in infos if item.has_a2t and item.has_textfooler
        ),
    }


# if __name__ == "__main__":
#     # Example usage
#     root_path = Path("data_files/results_from_BW_X5")
    
#     print("Getting experiment summary...")
#     summary = get_experiment_summary(root_path)
#     print(json.dumps(summary, indent=2))
    
#     print("\nCreating comprehensive DataFrame...")
#     df = create_comprehensive_dataframe(root_path)
#     print(f"DataFrame shape: {df.shape}")
    
#     if not df.empty:
#         print("\nColumns:")
#         for col in sorted(df.columns):
#             print(f"  {col}")
        
#         print(f"\nFirst few rows:")
#         print(df.head())
