from datasets import load_dataset
import pandas as pd

# Load the Open LLM Leaderboard contents dataset.
# NOTE: these statements sit at module top level, so they run whenever this
# file is imported, not only when it is executed as a script.
dataset = load_dataset("open-llm-leaderboard/contents")
# Turn the "train" split into a pandas DataFrame; the __main__ section below
# passes this frame as the `df` argument of get_metrics_dataframe().
df = pd.DataFrame(dataset["train"])

def get_model_metrics_by_name(df, huggingface_model_name, precision_formats=None):
    """
    Get metrics for a specific Hugging Face model by converting to eval_name.

    The model name is turned into a candidate ``eval_name`` by replacing
    ``/`` with ``_`` and appending ``_<precision>``. Each precision is tried
    in order and the first one with a matching row in ``df`` wins. Progress
    and diagnostics are printed to stdout.

    Args:
        df: DataFrame containing the leaderboard data. Must provide the
            columns 'eval_name', 'Model', 'Base Model', 'Architecture',
            '#Params (B)' and 'Hub License'.
        huggingface_model_name: Hugging Face model name (e.g., "Qwen/Qwen2.5-72B")
        precision_formats: Optional precision suffix (or iterable of
            suffixes) to try, in order of preference. Defaults to
            ['bfloat16', 'float16', '4bit'].

    Returns:
        Dictionary containing model metrics and info, or None if not found.
        The dictionary has the keys 'model_name', 'eval_name',
        'full_model_name', 'base_model', 'architecture', 'parameters',
        'hub_license', 'precision' and 'metrics'. 'metrics' maps column
        names to values: the key benchmark columns first, then every other
        non-empty column of the matched row.
    """
    # A non-string name (e.g. NaN read from a blank CSV cell) cannot be turned
    # into an eval_name; report it as not found instead of raising.
    if not isinstance(huggingface_model_name, str):
        print(f"❌ No model found for {huggingface_model_name!r}")
        print("   Model name must be a string")
        return None

    if precision_formats is None:
        precision_formats = ['bfloat16', 'float16', '4bit']
    elif isinstance(precision_formats, str):
        # A bare string would otherwise be iterated character by character.
        precision_formats = [precision_formats]
    else:
        precision_formats = list(precision_formats)

    # Convert model name to eval_name format
    base_eval_name = huggingface_model_name.replace('/', '_')

    # Columns reported as top-level model info rather than as metrics
    info_columns = ['Model', 'Base Model', 'Architecture', '#Params (B)', 'Hub License', 'eval_name']
    key_metrics = ['Average ⬆️', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']

    for precision in precision_formats:
        eval_name = f"{base_eval_name}_{precision}"

        # Search for this eval_name in the dataset
        model_data = df[df['eval_name'] == eval_name]
        if model_data.empty:
            continue

        print(f"✅ Found model with eval_name: {eval_name}")

        # Get the first (and usually only) row
        model_row = model_data.iloc[0]

        metrics = {}

        # Key benchmarks go in first so they lead the dict's insertion order.
        for metric in key_metrics:
            if metric in model_row.index and pd.notna(model_row[metric]):
                metrics[metric] = model_row[metric]

        # Then every remaining non-empty column that is not model info.
        for col in model_row.index:
            if col in info_columns:
                continue
            value = model_row[col]
            if pd.notna(value) and str(value).strip() != '':
                metrics[col] = value

        metrics_dict = {
            'model_name': huggingface_model_name,
            'eval_name': eval_name,
            'full_model_name': model_row['Model'],
            'base_model': model_row['Base Model'],
            'architecture': model_row['Architecture'],
            'parameters': model_row['#Params (B)'],
            'hub_license': model_row['Hub License'],
            'precision': precision,
            'metrics': metrics,
        }

        print(f"📊 Metrics found for {huggingface_model_name}")
        print(f"   Precision: {precision}")
        print(f"   Parameters: {model_row['#Params (B)']}B")
        print(f"   Architecture: {model_row['Architecture']}")

        return metrics_dict

    # If no model found with any precision
    print(f"❌ No model found for {huggingface_model_name}")
    print(f"   Tried eval_names: {[f'{base_eval_name}_{p}' for p in precision_formats]}")

    # Check if the base model exists at all, to hint at the eval_names in use
    base_model_exists = df[df['Base Model'] == huggingface_model_name]
    if len(base_model_exists) > 0:
        print(f"   Note: Base model '{huggingface_model_name}' exists but with different eval_name formats:")
        for _, row in base_model_exists.iterrows():
            print(f"     - {row['eval_name']}")

    return None

def print_metrics_summary(metrics_dict):
    """
    Write a human-readable report for one model to stdout.

    Args:
        metrics_dict: Dictionary with the keys 'model_name',
            'full_model_name', 'eval_name', 'base_model', 'architecture',
            'parameters', 'precision', 'hub_license' and 'metrics', or None.
            When None, nothing is printed.

    The key benchmark scores are listed first, in the order they appear in
    metrics_dict['metrics']; every remaining entry follows under an
    "Other Metrics" heading, which is omitted when there are none.
    """
    if metrics_dict is None:
        return

    banner = "=" * 60
    print("\n" + banner)
    print(f"📋 METRICS SUMMARY FOR {metrics_dict['model_name']}")
    print(banner)

    # (label, key in metrics_dict, suffix appended after the value)
    info_layout = (
        ("Full Name", 'full_model_name', ""),
        ("Eval Name", 'eval_name', ""),
        ("Base Model", 'base_model', ""),
        ("Architecture", 'architecture', ""),
        ("Parameters", 'parameters', "B"),
        ("Precision", 'precision', ""),
        ("License", 'hub_license', ""),
    )
    print("Model Info:")
    for label, key, suffix in info_layout:
        print(f"  {label}: {metrics_dict[key]}{suffix}")

    print("\nPerformance Metrics:")
    print("-" * 30)

    headline_metrics = ('Average ⬆️', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval')
    deferred = []

    # Headline metrics are printed as they are encountered; the rest are held
    # back so they can be grouped under their own heading afterwards.
    for name, score in metrics_dict['metrics'].items():
        if name not in headline_metrics:
            deferred.append((name, score))
            continue
        print(f"  {name}: {score}")

    if not deferred:
        return

    print("\nOther Metrics:")
    print("-" * 20)
    for name, score in deferred:
        print(f"  {name}: {score}")

def get_metrics_dataframe(df, model_list):
    """
    Look up every model in model_list and collect the hits into one table.

    Each name is resolved with get_model_metrics_by_name(); names that
    cannot be resolved are reported and skipped. Progress and a final
    summary are printed to stdout.

    Args:
        df: DataFrame containing the leaderboard data
        model_list: List of Hugging Face model names

    Returns:
        DataFrame with one row per model that was found. Columns are the
        model-info fields first, then the key benchmark metrics, then all
        remaining metric columns in sorted order. An empty DataFrame is
        returned when no model was found.
    """
    total = len(model_list)
    info_fields = ['eval_name', 'full_model_name', 'base_model',
                   'architecture', 'parameters', 'hub_license', 'precision']
    collected = []

    print(f"🔍 Processing {total} models...")
    print("="*60)

    for position, model_name in enumerate(model_list, 1):
        print(f"[{position}/{total}] Processing: {model_name}")

        found = get_model_metrics_by_name(df, model_name)

        if not found:
            print("   ❌ Not found")
            print()
            continue

        # Model info first, then every metric as its own column.
        record = {'model_name': model_name}
        record.update((field, found[field]) for field in info_fields)
        record.update(found['metrics'])

        collected.append(record)
        print("   ✅ Added to results")
        print()

    if not collected:
        print("❌ No models found. Returning empty DataFrame.")
        return pd.DataFrame()

    table = pd.DataFrame(collected)

    # Column layout: info fields, key benchmarks, then the rest alphabetically.
    core_columns = ['model_name'] + info_fields
    important_metrics = ['Average ⬆️', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
    leftovers = sorted(
        col for col in table.columns
        if col not in core_columns and col not in important_metrics
    )

    # Keep only the columns that actually exist in the collected rows.
    layout = [col for col in core_columns + important_metrics + leftovers
              if col in table.columns]
    table = table[layout]

    print("📊 Results Summary:")
    print(f"   Total models processed: {total}")
    print(f"   Models found: {len(collected)}")
    print(f"   Success rate: {(len(collected)/total)*100:.1f}%")
    print(f"   DataFrame shape: {table.shape}")

    return table

def save_metrics_to_csv(result_df, filename=None):
    """
    Write the metrics DataFrame to a CSV file (without the index column).

    Args:
        result_df: DataFrame containing model metrics
        filename: Optional output path. When omitted, a name of the form
            model_metrics_YYYYMMDD_HHMMSS.csv is generated from the current
            local time.

    Returns:
        The filename that was written, or None when result_df is empty and
        nothing was saved.
    """
    if result_df.empty:
        print("❌ No data to save.")
        return None

    target = filename
    if target is None:
        from datetime import datetime
        # The literal parts of the pattern contain no '%', so strftime
        # passes them through unchanged.
        target = datetime.now().strftime("model_metrics_%Y%m%d_%H%M%S.csv")

    result_df.to_csv(target, index=False)
    print(f"💾 Metrics saved to: {target}")

    return target

def merge_openllm_metrics_to_eai(eai_df, openllm_metrics_df, output_filename=None):
    """
    Merge Open LLM leaderboard metrics to EAI DataFrame.

    The key benchmark columns of openllm_metrics_df are left-joined onto
    eai_df, matching openllm_metrics_df['model_name'] against
    eai_df['Model']. The 'Average ⬆️' column is renamed to 'Average'.
    A merge summary and the sorted list of models without an 'Average'
    value are printed to stdout.

    Args:
        eai_df: DataFrame containing EAI evaluation results; must have a
            'Model' column.
        openllm_metrics_df: DataFrame containing Open LLM metrics, keyed by
            a 'model_name' column. A frame without that column (e.g. an
            empty DataFrame) is treated as "no metrics available".
        output_filename: Optional filename to save merged results

    Returns:
        Merged DataFrame with EAI results + Open LLM metrics. It has exactly
        one row per row of eai_df, in the same order.
    """
    # Define important metrics to merge
    important_metrics = ['Average ⬆️', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']

    if 'model_name' in openllm_metrics_df.columns:
        available = [col for col in important_metrics if col in openllm_metrics_df.columns]

        # Select the join key plus the metrics, and give them clean names
        # (drop the emoji, use the EAI key name).
        metrics_to_merge = openllm_metrics_df[['model_name'] + available].rename(
            columns={'Average ⬆️': 'Average', 'model_name': 'Model'}
        )

        # A model listed more than once on the right side would multiply the
        # matching EAI rows in a left join; keep one metrics row per model.
        metrics_to_merge = metrics_to_merge.drop_duplicates(subset='Model', keep='first')

        merged_df = pd.merge(eai_df, metrics_to_merge, on='Model', how='left')
        average_merged = 'Average' in merged_df.columns
    else:
        # Nothing to merge: keep the EAI rows as they are.
        merged_df = eai_df.copy()
        average_merged = False

    # 'Average' is the indicator for "has Open LLM metrics". When it is not
    # available, every model counts as lacking metrics.
    if average_merged:
        has_metrics = merged_df['Average'].notna()
    else:
        has_metrics = pd.Series(False, index=merged_df.index, dtype=bool)

    # Print merge summary
    print("Merge Summary:")
    print(f"  EAI models: {len(eai_df)}")
    print(f"  Models with Open LLM metrics: {has_metrics.sum()}")
    print(f"  Models without Open LLM metrics: {(~has_metrics).sum()}")

    # Save if filename provided
    if output_filename:
        merged_df.to_csv(output_filename, index=False)
        print(f"💾 Merged data saved to: {output_filename}")

    # List the models that received no metrics; key=str keeps the sort from
    # failing when a missing (NaN) model name sits next to string names.
    models_without_metrics = merged_df.loc[~has_metrics, 'Model'].tolist()
    print("=" * 60)
    for i, model in enumerate(sorted(models_without_metrics, key=str), 1):
        print(f"{i:2d}. {model}")
    print("=" * 60)

    return merged_df

# Example usage: enrich one EAI results file with Open LLM leaderboard metrics
if __name__ == "__main__":
    # Which EAI results file to process.
    simulation_name = "behavior"  # alternative: "virtualhome"
    take_name = "goal_interpretation"  # alternative: "action_sequencing_v4"

    separator = "=" * 60
    # Input and output files share this path stem.
    results_stem = f'./eval_results/{simulation_name}_{take_name}_results_with_flops'

    # Load EAI data; its 'Model' column supplies the names to look up.
    eai_df = pd.read_csv(f'{results_stem}.csv')
    test_models = eai_df['Model'].tolist()

    print("🔍 Testing batch model metric retrieval:")
    print(separator)

    # Get metrics for all models
    metrics_df = get_metrics_dataframe(df, test_models)

    # Only merge and report when at least one model was found.
    if not metrics_df.empty:
        print("\n📋 Final DataFrame Preview:")
        print(separator)
        print(metrics_df.head())

        print("\n📈 DataFrame Info:")
        print(f"Shape: {metrics_df.shape}")
        print(f"Columns: {metrics_df.columns.tolist()}")

        # Merge metrics and write the combined table next to the input file
        merged_eai_df = merge_openllm_metrics_to_eai(
            eai_df,
            metrics_df,
            f'{results_stem}_and_openllm.csv',
        )

        # Show final DataFrame preview
        print("\n📋 Final Merged DataFrame Preview:")
        print(separator)
        print(merged_eai_df.head())

        # Show columns that were added by the merge
        new_columns = [name for name in merged_eai_df.columns if name not in eai_df.columns]
        print(f"\n New columns added: {new_columns}")