#!/usr/bin/env python3
"""
Modular VirtualHome and Behavior Evaluation Analysis

This module provides a clean, organized approach to:
1. Merge VirtualHome and Behavior datasets
2. Preprocess and clean the merged data
3. Configure plotting parameters
4. Generate scaling analysis plots

Author: AI Assistant
Date: 2024
"""

import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import *


@dataclass
class Config:
    """Configuration for the VirtualHome/Behavior scaling analysis.

    Attributes:
        evaluation_type: Which evaluation CSVs to load ("action_sequencing" or
            "goal_interpretation"); also names the plot output sub-directory.
        dataset_name: "virtualhome" or "behavior". Currently not read by the
            pipeline code in this module.
        base_path: Directory containing ``eval_results/``. Defaults to the
            ``OBSSCALING_BASE_PATH`` environment variable when set, otherwise to
            the original hard-coded location.
        cutoff_threshold: Model size cutoff (in billions of parameters) passed
            to the plotting configs as ``cutoff_threshold``.
        annoying_models: Substrings of model names to exclude; ``None`` means
            the built-in default list.
        pca_metrics: OpenLLM benchmark columns used for PCA/imputation and
            rescaled to [0, 1]; ``None`` means the built-in default list.
    """
    evaluation_type: str = "action_sequencing"  # "action_sequencing" or "goal_interpretation"
    dataset_name: str = "virtualhome"  # "virtualhome" or "behavior"
    # Overridable via env var so the script runs on machines other than the author's.
    base_path: str = os.environ.get(
        "OBSSCALING_BASE_PATH",
        "/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling",
    )
    cutoff_threshold: int = 40  # Model size cutoff in billions
    annoying_models: Optional[List[str]] = None
    pca_metrics: Optional[List[str]] = None

    def __post_init__(self):
        # Fill list defaults per instance (a literal list default would be shared
        # across instances and is rejected by dataclasses anyway).
        if self.annoying_models is None:
            self.annoying_models = ['Yi-Coder-9B']
        if self.pca_metrics is None:
            self.pca_metrics = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
            # self.pca_metrics = ["action_goal", "relation_goal", "state_goal", "total_goal", "parsing_error", "hallucination_error", "wrong_order_error", "missing_step_error", "additional_step_error", "affordance_error"]

class MetricMerger:
    """Handles merging of VirtualHome and Behavior metrics.

    Intended call order: ``load_datasets`` -> ``validate_model_overlap``
    (optional) -> ``merge_datasets``. The merged frame keeps every VirtualHome
    row (left join on ``'Model'``) and adds the matching Behavior metrics.
    Metric columns from both sides get a ``virtualhome_`` / ``behavior_``
    prefix.
    """

    def __init__(self, config: "Config"):
        self.config = config
        self.virtualhome_eval = None
        self.behavior_eval = None
        self.merged_eval = None

    def _require_loaded(self) -> None:
        """Raise a clear error if ``load_datasets`` has not populated both frames."""
        if self.virtualhome_eval is None or self.behavior_eval is None:
            raise RuntimeError("Datasets are not loaded; call load_datasets() first.")

    def load_datasets(self) -> None:
        """Load the VirtualHome and Behavior result CSVs for ``config.evaluation_type``."""
        virtualhome_path = f"{self.config.base_path}/eval_results/virtualhome_{self.config.evaluation_type}_results_with_flops_and_openllm.csv"
        behavior_path = f"{self.config.base_path}/eval_results/behavior_{self.config.evaluation_type}_results_with_flops_and_openllm.csv"

        self.virtualhome_eval = pd.read_csv(virtualhome_path)
        self.behavior_eval = pd.read_csv(behavior_path)

        print(f"📊 Loaded datasets:")
        print(f"  • VirtualHome: {len(self.virtualhome_eval)} models")
        print(f"  • Behavior: {len(self.behavior_eval)} models")

    def get_metrics_to_merge(self) -> List[str]:
        """Return the metric column names to prefix and merge from both datasets."""
        return [
            # Success metrics
            'task_success_rate', 'execution_success_rate',
            # Goal achievement metrics
            'total_goal', 'state_goal', 'relation_goal', 'action_goal',
            # Error metrics
            'parsing_error', 'hallucination_error', 'wrong_order_error',
            'missing_step_error', 'additional_step_error', 'affordance_error',
        ]

    def validate_model_overlap(self) -> None:
        """Warn about Behavior models that have no VirtualHome row.

        Such models are dropped by the left merge in ``merge_datasets``.

        Raises:
            RuntimeError: If the datasets have not been loaded.
        """
        self._require_loaded()
        behavior_models_not_in_virtualhome = set(self.behavior_eval['Model']) - set(self.virtualhome_eval['Model'])
        if behavior_models_not_in_virtualhome:
            print(f"⚠️  Warning: {len(behavior_models_not_in_virtualhome)} behavior models not found in virtualhome:")
            for model in sorted(behavior_models_not_in_virtualhome):
                print(f"    • {model}")

    def rename_metrics(self) -> Tuple[Dict[str, str], Dict[str, str]]:
        """Build per-dataset rename maps that prefix metric columns.

        Only metrics actually present in each dataset are included.

        Returns:
            ``(behavior_map, virtualhome_map)``, each mapping ``metric`` to
            ``"<dataset>_<metric>"``.

        Raises:
            RuntimeError: If the datasets have not been loaded.
        """
        self._require_loaded()
        metrics_to_merge = self.get_metrics_to_merge()
        behavior_metric_rename_map = {
            metric: f"behavior_{metric}"
            for metric in metrics_to_merge
            if metric in self.behavior_eval.columns
        }
        virtualhome_metric_rename_map = {
            metric: f"virtualhome_{metric}"
            for metric in metrics_to_merge
            if metric in self.virtualhome_eval.columns
        }
        return behavior_metric_rename_map, virtualhome_metric_rename_map

    def merge_datasets(self) -> pd.DataFrame:
        """Left-join Behavior metrics onto the VirtualHome table by ``'Model'``.

        Only ``'Model'`` and the prefixed Behavior metric columns are taken
        from the Behavior table, so no other columns collide.

        Returns:
            The merged DataFrame (also stored on ``self.merged_eval``).

        Raises:
            RuntimeError: If the datasets have not been loaded.
        """
        behavior_rename_map, virtualhome_rename_map = self.rename_metrics()

        behavior_columns_to_merge = ['Model'] + list(behavior_rename_map.values())
        behavior_eval_for_merge = self.behavior_eval.rename(columns=behavior_rename_map)[behavior_columns_to_merge]

        # A left merge duplicates the VirtualHome row for every repeated Behavior
        # key, silently inflating the model count, so surface it.
        duplicated = behavior_eval_for_merge['Model'].duplicated(keep=False)
        if duplicated.any():
            dup_models = sorted(behavior_eval_for_merge.loc[duplicated, 'Model'].astype(str).unique())
            print(f"⚠️  Warning: {len(dup_models)} model(s) appear more than once in behavior data; "
                  "their virtualhome rows will be duplicated by the merge:")
            for model in dup_models:
                print(f"    • {model}")

        self.merged_eval = self.virtualhome_eval.merge(
            behavior_eval_for_merge,
            on='Model',
            how='left'
        )

        # Prefix VirtualHome metrics after the merge (no name clash is possible,
        # since Behavior columns already carry their own prefix).
        self.merged_eval = self.merged_eval.rename(columns=virtualhome_rename_map)

        print(f"✅ Successfully merged datasets: {len(self.merged_eval)} combined models")

        self._display_merge_summary(behavior_rename_map, virtualhome_rename_map)

        return self.merged_eval

    def _display_merge_summary(self, behavior_rename_map: Dict[str, str],
                              virtualhome_rename_map: Dict[str, str]) -> None:
        """Print the applied renames and per-column missing counts for Behavior metrics."""
        print(f"\n📋 Merge Summary:")
        print("=" * 80)
        print("VirtualHome metrics renamed:")
        for old_name, new_name in virtualhome_rename_map.items():
            print(f"  • {old_name} → {new_name}")
        print("\nBehavior metrics renamed:")
        for old_name, new_name in behavior_rename_map.items():
            print(f"  • {old_name} → {new_name}")
        print("=" * 80)

        # NaNs here come from VirtualHome models with no Behavior match (or NaN source values).
        missing_behavior_metrics = self.merged_eval[list(behavior_rename_map.values())].isnull().sum()
        if missing_behavior_metrics.sum() > 0:
            print(f"\n⚠️  Missing behavior metrics after merge:")
            for metric, missing_count in missing_behavior_metrics.items():
                if missing_count > 0:
                    print(f"  • {metric}: {missing_count} missing values")


class DataPreprocessor:
    """Handles data preprocessing and cleaning.

    Reads ``config.annoying_models`` (model-name substrings to drop) and
    ``config.pca_metrics`` (columns to rescale by 1/100).
    """

    def __init__(self, config: "Config"):
        self.config = config

    def filter_valid_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Filter dataframe to only include models with valid data.

        The following rows are removed:
        - rows with no ``'FLOPs (1E21)'`` value;
        - rows with no OpenLLM ``'Average'`` score;
        - rows whose ``'Model'`` contains any entry of ``config.annoying_models``
          (case-insensitive, literal substring match).

        Args:
            df: Merged evaluation table with ``'FLOPs (1E21)'``, ``'Average'``
                and ``'Model'`` columns.

        Returns:
            A new, independent DataFrame. It is safe to modify without touching
            ``df`` or triggering pandas' SettingWithCopyWarning.
        """
        initial_count = len(df)

        # Filter by FLOPs
        df = df.dropna(subset=['FLOPs (1E21)'])
        print(f"📊 After filtering FLOPs: {len(df)} models")

        # Filter by OpenLLM data
        df = df.dropna(subset=['Average'])
        print(f"📊 After filtering OpenLLM data: {len(df)} models")

        # Remove annoying models. regex=False because names like "GPT-3.5" or
        # "Model (v2)" must match literally, not as regular expressions.
        for annoying_model in self.config.annoying_models:
            is_annoying = df['Model'].str.contains(annoying_model, case=False, na=False, regex=False)
            df = df[~is_annoying]
        print(f"📊 After removing annoying models: {len(df)} models")

        print(f"✅ Data filtering complete: {initial_count} → {len(df)} models")
        return df.copy()

    def convert_metrics_to_rates(self, df: pd.DataFrame) -> pd.DataFrame:
        """Convert various metrics to success rates (0-1 scale).

        Percentage columns are rescaled as follows:
        - ``*_error`` columns become ``(100 - x) / 100``, turning an error
          percentage into a success rate;
        - ``*_rate`` and ``*_goal`` columns become ``x / 100``;
        - columns listed in ``config.pca_metrics`` become ``x / 100``.

        Each column is converted at most once, even if it matches several rules.

        Returns:
            A converted copy; ``df`` itself is left unchanged.
        """
        df = df.copy()
        converted = set()

        # Convert error columns to success rates
        error_columns = [col for col in df.columns if str(col).endswith('_error')]
        if error_columns:
            df[error_columns] = (100 - df[error_columns]) / 100
            converted.update(error_columns)
            print(f"🔄 Converted {len(error_columns)} error columns to success rates")

        # Convert rate and goal columns to success rates
        rate_columns = [col for col in df.columns if str(col).endswith(('_rate', '_goal'))]
        if rate_columns:
            df[rate_columns] = df[rate_columns] / 100
            converted.update(rate_columns)
            print(f"🔄 Converted {len(rate_columns)} rate/goal columns to success rates")

        # Convert configured metrics, skipping any already rescaled above so a
        # column is never divided by 100 twice (dict.fromkeys de-duplicates in order).
        available_metrics = [
            col for col in dict.fromkeys(self.config.pca_metrics)
            if col in df.columns and col not in converted
        ]
        if available_metrics:
            df[available_metrics] = df[available_metrics] / 100
            print(f"🔄 Converted {len(available_metrics)} specific metrics to success rates")

        return df

    def display_model_information(self, df: pd.DataFrame) -> None:
        """Print each remaining model with its family and OpenLLM ``'Average'`` status.

        Requires ``'Model'`` and ``'Model Family'`` columns; rows are sorted by
        family, then model.
        """
        print(f"\n📋 {len(df)} Models remaining after preprocessing:")
        print("=" * 80)

        df_sorted = df.sort_values(['Model Family', 'Model'])
        for _, row in df_sorted.iterrows():
            # Check OpenLLM data availability
            if 'Average' in row and pd.notna(row['Average']):
                average_status = f"✅ {row['Average']:.2f}"
            else:
                average_status = "❌ No OpenLLM data"

            print(f"  • {row['Model']} (Family: {row['Model Family']}) {average_status}")
        print("=" * 80)


class MetricConfigurator:
    """Configures metrics and plotting parameters.

    The ``PC_METRIC_NUM_*`` lists and ``NONGSM_PCA_PREPROCESS_KWARGS`` are not
    defined in this module. Presumably they come from ``from utils import *``.
    """

    _SUCCESS_METRICS = ('task_success_rate', 'execution_success_rate')
    _GOAL_METRICS = ('total_goal', 'state_goal', 'relation_goal', 'action_goal')
    _ERROR_METRICS = ('parsing_error', 'hallucination_error', 'wrong_order_error',
                      'missing_step_error', 'additional_step_error', 'affordance_error')
    # (datasets, metrics) per Y group, in plot order.
    _Y_GROUPS = (
        (('virtualhome', 'behavior'), _SUCCESS_METRICS),  # success metrics, both datasets
        (('virtualhome',), _GOAL_METRICS),
        (('behavior',), _GOAL_METRICS),
        (('virtualhome',), _ERROR_METRICS),
        (('behavior',), _ERROR_METRICS),
    )
    # Each Y group is plotted once per X axis returned by define_x_metrics().
    _PLOTS_PER_GROUP = 3

    def __init__(self, config: "Config"):
        self.config = config

    def define_y_metrics(self) -> List[List[str]]:
        """Return Y-metric lists, one per plot.

        Each group is repeated ``_PLOTS_PER_GROUP`` times, so the list lines up
        index-by-index with ``define_x_metrics()``. Every entry is a fresh list.
        """
        y_metrics = []
        for datasets, metrics in self._Y_GROUPS:
            for _ in range(self._PLOTS_PER_GROUP):
                y_metrics.append([f"{dataset}_{metric}" for dataset in datasets for metric in metrics])
        return y_metrics

    def define_x_metrics(self) -> List[List[str]]:
        """Return X-metric lists (size, FLOPs, PC set), repeated once per Y group."""
        return [
            ['Model Size (B)'],
            ['FLOPs (1E21)'],
            PC_METRIC_NUM_4,
        ] * len(self._Y_GROUPS)

    def create_metric_mapping(self) -> Dict[str, str]:
        """Map ``"<dataset>_<metric>"`` column names to human-readable plot titles."""
        base_metrics = {
            'task_success_rate': "Task Success Rate",
            'execution_success_rate': "Execution Success Rate",
            'total_goal': "Total Goal",
            'state_goal': "State Goal",
            'relation_goal': "Relation Goal",
            'action_goal': "Action Goal",
            'parsing_error': "Parsing Error",
            'hallucination_error': "Hallucination Error",
            'wrong_order_error': "Wrong Order Error",
            'missing_step_error': "Missing Step Error",
            'additional_step_error': "Additional Step Error",
            'affordance_error': "Affordance Error"
        }
        return {
            f"{dataset}_{metric}": f"{display_name} ({dataset.title()})"
            for dataset in ['virtualhome', 'behavior']
            for metric, display_name in base_metrics.items()
        }

    def create_special_x_metrics_mapping(self) -> Dict[str, List[str]]:
        """Return per-Y-metric overrides of which PC metric set to use as X."""
        return {
            'behavior_task_success_rate': PC_METRIC_NUM_4,
            'behavior_execution_success_rate': PC_METRIC_NUM_4,
            'virtualhome_task_success_rate': PC_METRIC_NUM_3,
            'virtualhome_execution_success_rate': PC_METRIC_NUM_3,
            'behavior_total_goal': PC_METRIC_NUM_4,
            'behavior_state_goal': PC_METRIC_NUM_5,
            'behavior_relation_goal': PC_METRIC_NUM_5,
            'behavior_action_goal': PC_METRIC_NUM_5,
            'virtualhome_total_goal': PC_METRIC_NUM_3,
            'virtualhome_state_goal': PC_METRIC_NUM_3,
            'virtualhome_relation_goal': PC_METRIC_NUM_3,
            'virtualhome_action_goal': PC_METRIC_NUM_3,
        }

    def create_plot_configs(self, df: pd.DataFrame) -> Tuple[Dict, Dict]:
        """Build the plotting-config dicts for PCA-based and raw X axes.

        Args:
            df: Preprocessed data; its ``'Model Family'`` values are listed
                for styling.

        Returns:
            ``(pca_config, no_pca_config)``. ``pca_config`` extends a copy of
            ``NONGSM_PCA_PREPROCESS_KWARGS`` with
            ``imputation_metrics = config.pca_metrics``. The shared preset dict
            itself is not modified.
        """
        pca_imputation_metrics = self.config.pca_metrics
        model_families = df['Model Family'].unique().tolist()

        pca_config = {
            # Override on a copy rather than writing into the shared module-level preset.
            **NONGSM_PCA_PREPROCESS_KWARGS,
            'imputation_metrics': pca_imputation_metrics,
            "pca_metrics": pca_imputation_metrics,
            "ref_model_family": "Gemma-2",
            "stylize_data": True,
            "nonlinearity": "sigmoid-parametric",
            "y_metric_process_funcs": "minmax_norm",
            "df_groupby": 'Model Family',
            "reg_method": "robust",
            "reg_kwargs": {"delta": 1.0},
            "apply_pca": True,
            "apply_imputation": True,
            "split_method": "cutoff_by_Model Size (B)",
            "cutoff_threshold": self.config.cutoff_threshold,
            "stylize_model_family": model_families,
        }

        no_pca_config = {
            "y_metric_process_funcs": "minmax_norm",
            "df_groupby": 'Model Family',
            "reg_method": "robust",
            "reg_kwargs": {"delta": 1.0},
            "apply_pca": False,
            "apply_imputation": False,
            "split_method": "cutoff_by_Model Size (B)",
            "cutoff_threshold": self.config.cutoff_threshold,
            # Separate list object so the two configs never alias each other.
            "stylize_model_family": list(model_families),
        }

        return pca_config, no_pca_config


class PlotGenerator:
    """Handles the generation of scaling plots.

    Plots are saved under ``plots/validate/<evaluation_type>/``, relative to
    the current working directory, as ``<index>_plot_<hash>.png``. ``<index>``
    is the 1-based position of the plot and ``<hash>`` comes from its Y-metric
    group and X metric. ``create_filename_mapping`` writes the same names to
    ``filename_mapping.txt``.

    ``plot_multi_scaling_predictions`` and the ``PC_METRIC_NUM_*`` lists are not
    defined in this module. Presumably they come from ``from utils import *``.
    """

    # Per-metric y-axis limits that replace the default [-0.05, 1.05].
    _YLIM_OVERRIDES = {
        "behavior_hallucination_error": [0.4, 1.05],
        "behavior_wrong_order_error": [0.7, 1.05],
        "behavior_missing_step_error": [0.4, 1.05],
        "behavior_additional_step_error": [0.9, 1.05],
    }

    def __init__(self, config: "Config"):
        self.config = config

    def create_plots_directory(self) -> str:
        """Create (if needed) and return the plots directory path."""
        plots_dir = f'plots/validate/{self.config.evaluation_type}'
        os.makedirs(plots_dir, exist_ok=True)
        print(f"📁 Plots will be saved to: {plots_dir}")
        return plots_dir

    @staticmethod
    def _hash_metrics(y_metrics, x_metrics) -> str:
        """Return the 12-hex-digit MD5 prefix identifying a (Y group, X metric) plot."""
        import hashlib

        full_name = f"{y_metrics}_{x_metrics}"
        return hashlib.md5(full_name.encode()).hexdigest()[:12]

    def _check_and_report_nan_values(self, df: pd.DataFrame, y_metrics: List[str]) -> Tuple[bool, List[str]]:
        """
        Check for missing/NaN values in y_metrics and filter out problematic ones.

        A metric is dropped if its column is absent from ``df``, which happens
        when the source CSV lacks it, or if the column is entirely NaN. Partial
        NaN counts are reported for the metrics that remain.

        Args:
            df: DataFrame to check
            y_metrics: List of metrics to check

        Returns:
            Tuple[bool, List[str]]: (True if any valid metrics remain, filtered list of valid metrics)
        """
        missing_metrics = [metric for metric in y_metrics if metric not in df.columns]
        if missing_metrics:
            print(f"⚠️  WARNING: The following metrics in {y_metrics} are not present in the data, removing them:")
            for metric in missing_metrics:
                print(f"    • {metric}")

        all_nan_metrics = []
        valid_metrics = []
        for metric in y_metrics:
            if metric in missing_metrics:
                continue
            if df[metric].isna().all():
                all_nan_metrics.append(metric)
            else:
                valid_metrics.append(metric)

        if all_nan_metrics:
            print(f"⚠️  WARNING: The following metrics in {y_metrics} are all NaN, removing them:")
            for metric in all_nan_metrics:
                print(f"    • {metric}")

        partial_nan_metrics = []
        for metric in valid_metrics:
            nan_count = df[metric].isna().sum()
            total_count = len(df[metric])
            if nan_count > 0:
                partial_nan_metrics.append((metric, nan_count, total_count))

        if partial_nan_metrics:
            print(f"📊 NaN statistics for remaining metrics {valid_metrics}:")
            for metric, nan_count, total_count in partial_nan_metrics:
                nan_percentage = (nan_count / total_count) * 100
                print(f"    • {metric}: {nan_count}/{total_count} NaN values ({nan_percentage:.1f}%)")

        has_valid_metrics = len(valid_metrics) > 0
        if has_valid_metrics:
            if len(valid_metrics) < len(y_metrics):
                print(f"✅ Continuing with {len(valid_metrics)}/{len(y_metrics)} valid metrics: {valid_metrics}")
            else:
                print(f"✅ All {len(valid_metrics)} metrics are valid: {valid_metrics}")
        else:
            print(f"❌ No valid metrics remaining, skipping this plot")

        return has_valid_metrics, valid_metrics

    def generate_plots(self, df: pd.DataFrame, y_metrics: List[List[str]],
                       x_metrics: List[List[str]], metric_mapping: Dict[str, str],
                       special_x_mapping: Dict[str, List[str]],
                       pca_config: Dict, no_pca_config: Dict) -> List:
        """Generate and save one scaling plot per (X metric, Y-metric group) pair.

        ``x_metrics[i]`` is paired with ``y_metrics[i]``. X metrics equal to a
        ``PC_METRIC_NUM_*`` list use ``pca_config``; all others use
        ``no_pca_config``. Groups with no usable Y metrics are skipped, and the
        index gap is kept in the file names.

        Returns:
            The created figures, in plot order.

        Raises:
            ValueError: If ``x_metrics`` and ``y_metrics`` differ in length.
        """
        if len(x_metrics) != len(y_metrics):
            raise ValueError(
                f"x_metrics and y_metrics must have the same length "
                f"(got {len(x_metrics)} and {len(y_metrics)})"
            )

        plots_dir = self.create_plots_directory()
        all_figures = []
        pca_x_metrics = [PC_METRIC_NUM_3, PC_METRIC_NUM_2, PC_METRIC_NUM_1,
                         PC_METRIC_NUM_4, PC_METRIC_NUM_5]

        for i, (x_metric, y_metric_group) in enumerate(zip(x_metrics, y_metrics)):
            plot_config = pca_config if x_metric in pca_x_metrics else no_pca_config

            has_valid_metrics, filtered_y_metrics = self._check_and_report_nan_values(df, y_metric_group)
            if not has_valid_metrics:
                continue

            fig = plot_multi_scaling_predictions(
                df, filtered_y_metrics, [x_metric],
                plot_config,
                y_metric_specific_kwargs=self._create_metric_specific_kwargs(filtered_y_metrics),
                filter_model_family=None,
                ymetric2title_map=metric_mapping,
                plot_legend=True,
                legend_nrow=2,
                special_x_metrics_mapping=special_x_mapping,
            )

            all_figures.append(fig)

            # Name from the *unfiltered* group so it matches create_filename_mapping().
            filename = self._generate_filename(y_metric_group, x_metric)
            filepath = f"{plots_dir}/{i+1}_{filename}.png"
            fig.savefig(filepath, dpi=300, bbox_inches='tight')
            print(f"💾 Saved plot {i+1}/{len(y_metrics)}: {filepath}")

        return all_figures

    def _create_metric_specific_kwargs(self, y_metrics: List[str]) -> Dict:
        """Build per-metric plotting kwargs: a [0, 1] value range and y-axis limits."""
        return {
            metric: {
                'y_metric_range': (0.0, 1.0),
                'plot_adjust_kwargs': {'ylim': list(self._YLIM_OVERRIDES.get(metric, [-0.05, 1.05]))},
            }
            for metric in y_metrics
        }

    def _generate_filename(self, y_metrics: List[str], x_metrics: List[str]) -> str:
        """Return ``plot_<hash>`` (no index prefix or extension) and log the mapping."""
        hash_value = self._hash_metrics(y_metrics, x_metrics)

        print(f"📁 Generated filename: plot_{hash_value}.png")
        print(f"   Y Metrics: {y_metrics}")
        print(f"   X Metrics: {x_metrics}")
        print(f"   Hash: {hash_value}")
        print("-" * 60)

        return f"plot_{hash_value}"

    def create_filename_mapping(self, y_metrics_list: List[List[str]], x_metrics_list: List[List[str]]) -> None:
        """Write ``filename_mapping.txt`` listing each plot's file name and metrics.

        The file names are the same ones ``generate_plots`` saves for the same
        inputs.
        """
        plots_dir = self.create_plots_directory()
        mapping_file = f"{plots_dir}/filename_mapping.txt"

        with open(mapping_file, 'w', encoding='utf-8') as f:
            f.write("Hash Filename Mapping for Generated Plots\n")
            f.write("=" * 80 + "\n\n")

            for i, (y_metrics, x_metrics) in enumerate(zip(y_metrics_list, x_metrics_list)):
                hash_value = self._hash_metrics(y_metrics, x_metrics)

                f.write(f"Plot {i+1}: {i+1}_plot_{hash_value}.png\n")
                f.write(f"  Y Metrics: {y_metrics}\n")
                f.write(f"  X Metrics: {x_metrics}\n")
                f.write(f"  Hash: {hash_value}\n")
                f.write("-" * 60 + "\n")

        print(f"📝 Hash filename mapping saved to: {mapping_file}")


def main():
    """Run the whole pipeline: merge, preprocess, configure, then plot.

    Progress is printed to stdout at each stage. The plot images and the
    hash-to-description mapping file are written by ``PlotGenerator``.
    """
    banner = "=" * 80
    print("🚀 Starting VirtualHome and Behavior Evaluation Analysis")
    print(banner)

    cfg = Config()

    # Stage 1: load both evaluation tables and join them on the model name.
    print("\n📊 Step 1: Merging datasets...")
    merger = MetricMerger(cfg)
    merger.load_datasets()
    merger.validate_model_overlap()
    data = merger.merge_datasets()

    # Stage 2: drop incomplete/unwanted rows, then rescale metrics to [0, 1].
    print("\n🔧 Step 2: Preprocessing data...")
    cleaner = DataPreprocessor(cfg)
    data = cleaner.convert_metrics_to_rates(cleaner.filter_valid_data(data))
    cleaner.display_model_information(data)

    # Stage 3: decide which metrics go on which axes and how plots are styled.
    print("\n⚙️  Step 3: Configuring metrics and plotting...")
    metric_setup = MetricConfigurator(cfg)
    y_groups = metric_setup.define_y_metrics()
    x_groups = metric_setup.define_x_metrics()
    titles = metric_setup.create_metric_mapping()
    special_x = metric_setup.create_special_x_metrics_mapping()
    pca_cfg, plain_cfg = metric_setup.create_plot_configs(data)

    # Stage 4: write the filename reference first, then render every plot.
    print("\n🎨 Step 4: Generating plots...")
    plotter = PlotGenerator(cfg)
    plotter.create_filename_mapping(y_groups, x_groups)
    figures = plotter.generate_plots(
        data, y_groups, x_groups, titles, special_x, pca_cfg, plain_cfg
    )

    print(f"\n✅ Analysis complete! Generated {len(figures)} plots.")
    print(banner)


if __name__ == "__main__":
    main()