import json
import os
import argparse
from pathlib import Path
from typing import Dict, Any, Optional
import pandas as pd
from huggingface_hub import model_info


def get_model_name_mapping() -> Dict[str, str]:
    """
    Return the table of model-name corrections.

    Keys are model names as derived from result directory/file names; values are
    the HuggingFace repository IDs that should be used instead. A new dict is
    built on every call, so callers may mutate the result freely.

    Returns:
        Dictionary mapping incorrect model names to correct HuggingFace model names
    """
    corrections = (
        ("meta-llama/Llama-3.3-70B-Instruct-Turbo-Free", "meta-llama/Llama-3.3-70B-Instruct"),
        ("deepseek-ai/DeepSeek-R1-0528-tput", "deepseek-ai/DeepSeek-R1"),
        ("lgai/exaone-3-5-32b-instruct", "LGAI-EXAONE/EXAONE-3.5-32B-Instruct"),
        ("lgai/exaone-deep-32b", "LGAI-EXAONE/EXAONE-Deep-32B"),
        ("mixtral-8x22b-instruct-v0.1", "mistralai/Mixtral-8x22B-Instruct-v0.1"),
        ("cohere-command-r", "CohereLabs/c4ai-command-r-08-2024"),
        ("cohere-command-r-plus", "CohereLabs/c4ai-command-r-plus-08-2024"),
        ("deepseek-chat", "deepseek-ai/DeepSeek-V3"),
        # Append further (observed name, HuggingFace name) pairs here as needed.
    )
    return dict(corrections)


def convert_directory_name_to_hf_model(directory_name: str) -> str:
    """
    Turn a results directory/file name into a HuggingFace model name.

    Only the first underscore is treated as the organization/model separator
    ("org_model_x" -> "org/model_x"); a name without any underscore is used
    unchanged. The resulting name is then looked up in get_model_name_mapping()
    and, if a correction exists, the corrected name is printed and returned.

    Args:
        directory_name: Directory name in format "org_modelname"

    Returns:
        HuggingFace model name in format "org/modelname" (or the corrected name)
    """
    org, separator, remainder = directory_name.partition('_')
    candidate = f"{org}/{remainder}" if separator else directory_name

    corrected = get_model_name_mapping().get(candidate)
    if corrected is None:
        return candidate

    print(f"Corrected model name: {candidate} -> {corrected}")
    return corrected


def get_model_family(model_name: str) -> str:
    """
    Classify a HuggingFace model name into a model family label.

    The organization (text before the first "/") selects a rule; each rule is an
    ordered list of (substring, family) checks against the rest of the name,
    with a fallback family when none match. Unknown organizations, and names
    without a "/", return the organization part itself.

    Args:
        model_name: HuggingFace model name (e.g., "Qwen/Qwen3-14B")

    Returns:
        Model family name
    """
    org, _, model = model_name.partition('/')

    # org -> (ordered (substring, family) checks, fallback family). Order matters:
    # the first substring found in the model part wins.
    family_rules = {
        "Qwen": ((("Qwen1.5", "Qwen1.5"), ("Qwen3", "Qwen3"), ("Qwen2.5", "Qwen2.5")), "Qwen"),
        "meta-llama": ((("Llama-3", "Llama-3"), ("Llama-2", "Llama-2")), "Llama"),
        "deepseek-ai": ((("deepseek-coder", "DeepSeek-Coder"), ("DeepSeek-R1", "DeepSeek-R1")), "DeepSeek"),
        "openai": ((), "GPT-OSS"),
        "mistralai": ((), "Mistral"),
        "01-ai": ((), "Yi"),
        "baichuan-inc": ((), "Baichuan"),
        "ibm-granite": ((("granite", "Granite"),), "ibm"),
        "google": ((("gemma-2", "Gemma-2"), ("gemma-3", "Gemma-3")), "Gemma"),
        "LGAI-EXAONE": ((), "Exaone"),
        "CohereLabs": ((), "Cohere"),
        "moonshotai": ((), "Kimi"),
        "tiiuae": ((), "falcon"),
        "microsoft": ((), "phi"),
        "bigcode": ((("starcoder2", "starcoder2"),), "starcoder"),
    }

    rule = family_rules.get(org)
    if rule is None:
        return org

    checks, fallback = rule
    return next((family for needle, family in checks if needle in model), fallback)


def get_model_size(hf_model_name: str) -> Optional[float]:
    """
    Get model size in billions of parameters.

    A small table of manual sizes is consulted first, by substring match on the
    model name with the longest matching pattern winning. This makes those sizes
    available even when the HuggingFace Hub lookup fails. Otherwise the total
    parameter count from the Hub's safetensors metadata is used.

    Args:
        hf_model_name: HuggingFace model name (e.g., "Qwen/Qwen3-14B")

    Returns:
        Model size in billions of parameters (Hub values rounded to one decimal),
        or None if it could not be determined.
    """
    # Manually specified sizes; these take precedence over the Hub lookup.
    size_overrides = {
        "Baichuan2-7B": 7.0,
        "moonshotai/Kimi-K2-Instruct": 1000.0,
        "bigcode/starcoderbase": 15.5,
        "bigcode/starcoderbase-3b": 3.0,
        "deepseek-ai/deepseek-coder-1.3b-base": 1.3,
    }
    # Prefer the longest matching pattern so that "bigcode/starcoderbase-3b" is
    # not shadowed by its prefix "bigcode/starcoderbase".
    matches = [pattern for pattern in size_overrides if pattern in hf_model_name]
    if matches:
        return size_overrides[max(matches, key=len)]

    try:
        info = model_info(hf_model_name)
    except Exception as e:
        # Best-effort: any lookup failure is reported and treated as "unknown size".
        print(f"Warning: Error getting model size for {hf_model_name}: {e}")
        return None

    safetensors = getattr(info, "safetensors", None)
    total_params = getattr(safetensors, "total", None)
    if total_params is None:
        print(f"Warning: Could not get model size for {hf_model_name}")
        return None

    # Convert the raw parameter count to billions.
    return round(total_params / 1e9, 1)


def load_json_file(file_path: str) -> Dict[str, Any]:
    """
    Load a JSON object from disk, best-effort.

    Problems are reported with a printed warning and an empty dict is returned
    instead of raising, so callers can treat a missing or unusable file the same
    as an empty one.

    Args:
        file_path: Path to the JSON file

    Returns:
        The parsed JSON object, or {} if the file is missing, is not valid UTF-8,
        contains invalid JSON, or its top-level value is not a JSON object.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Warning: File {file_path} not found")
        return {}
    except json.JSONDecodeError as e:
        print(f"Warning: Invalid JSON in {file_path}: {e}")
        return {}
    except UnicodeDecodeError as e:
        print(f"Warning: File {file_path} is not valid UTF-8: {e}")
        return {}

    # Callers in this module use dict methods (.get / .items) on the result, so a
    # top-level list or scalar is treated as unusable rather than returned as-is.
    if not isinstance(data, dict):
        print(f"Warning: Expected a JSON object in {file_path}, got {type(data).__name__}")
        return {}
    return data


def extract_metrics_from_summary(summary_data: Dict[str, Any], dataset: str = "virtualhome", eval_type: str = "action_sequencing") -> Dict[str, Any]:
    """
    Flatten a summary dictionary into the metric columns used by the results DataFrame.

    Which keys are read depends on ``eval_type`` and, for goal interpretation, on
    ``dataset``. For the known layouts every expected column is present in the
    result, holding None where the summary lacks the value. Column order is fixed
    and is what the DataFrame's metric column order is derived from.

    Args:
        summary_data: Summary data dictionary for one model.
        dataset: Dataset name. "behavior" selects the behavior goal-interpretation
            layout and enables rescaling of 0-1 metrics to 0-100 (default: "virtualhome").
        eval_type: Evaluation type. Values containing "goal_interpretation" or
            "action_sequencing" select that layout; anything else prints a warning
            and copies top-level numeric values whose key does not start with "_"
            (default: "action_sequencing").

    Returns:
        Dictionary with flattened metrics ({} if summary_data is empty).
    """
    if not summary_data:
        return {}

    flat: Dict[str, Any] = {}

    if "goal_interpretation" in eval_type:
        if dataset == "behavior":
            # Confusion metrics + condition counts per section. The "overall" section
            # keeps the raw count key names; the others get a "<section>_num_*" prefix.
            for section in ("overall", "state_goal", "relation_goal"):
                section_data = summary_data.get(section, {})
                confusion = section_data.get(f"{section}_confusion_metrics", {})
                flat[f"{section}_precision"] = confusion.get("precision")
                flat[f"{section}_recall"] = confusion.get("recall")
                flat[f"{section}_f1"] = confusion.get("f1_score")
                for count in ("predicted", "GT", "satisfied", "unsatisfied", "false_positive"):
                    source_key = f"num_{count}_conditions"
                    column = source_key if section == "overall" else f"{section}_num_{count}"
                    flat[column] = section_data.get(source_key)

            # Grammatical error counts and rates use the summary's own key names.
            grammar = summary_data.get("grammatical_errors", {})
            for error_kind in ("grammatically_valid", "format_error", "state_hallucination", "object_hallucination"):
                for suffix in ("num", "rate"):
                    column = f"{error_kind}_{suffix}"
                    flat[column] = grammar.get(column)
        else:
            # VirtualHome: flat node/edge/action/all precision, recall and F1.
            for prefix in ("node", "edge", "action", "all"):
                for stat in ("precision", "recall", "f1"):
                    column = f"{prefix}_{stat}"
                    flat[column] = summary_data.get(column)

    elif "action_sequencing" in eval_type:
        goal_eval = summary_data.get("goal_evaluation", {})
        for column in ("task_success_rate", "state_goal", "relation_goal", "action_goal", "total_goal"):
            flat[column] = goal_eval.get(column)

        trajectory = summary_data.get("trajectory_evaluation", {})
        flat["execution_success_rate"] = trajectory.get("execution_success_rate")

        grammar = trajectory.get("grammar_error", {})
        for source_key in ("parsing", "hallucination", "predicate_argument_number"):
            flat[f"{source_key}_error"] = grammar.get(source_key)

        # Explicit pairs: the affordance key is already suffixed in the summary.
        runtime = trajectory.get("runtime_error", {})
        for column, source_key in (
            ("wrong_order_error", "wrong_order"),
            ("missing_step_error", "missing_step"),
            ("affordance_error", "affordance_error"),
            ("additional_step_error", "additional_step"),
        ):
            flat[column] = runtime.get(source_key)

    else:
        print(f"Warning: Unknown evaluation type '{eval_type}'. Attempting to extract all available metrics.")
        flat = {
            key: value
            for key, value in summary_data.items()
            if isinstance(value, (int, float)) and not key.startswith('_')
        }

    if dataset == "behavior":
        # Behavior metrics arrive as fractions: rescale rate/precision/recall/f1
        # columns whose value lies in [0, 1] to percentages.
        percent_keywords = ('rate', 'precision', 'recall', 'f1')
        for column, value in list(flat.items()):
            if value is None or not isinstance(value, (int, float)):
                continue
            if any(word in column.lower() for word in percent_keywords) and 0 <= value <= 1:
                flat[column] = value * 100

    return flat


def _build_result_row(hf_model_name: str, dataset: str, eval_type: str, summary_info: Dict[str, Any]) -> Dict[str, Any]:
    """Assemble one DataFrame row: metadata columns followed by the extracted metrics."""
    # Size is looked up before metrics are extracted, matching the original print order.
    model_size = get_model_size(hf_model_name)
    metrics = extract_metrics_from_summary(summary_info, dataset=dataset, eval_type=eval_type)
    return {
        "Model": hf_model_name,
        "Model Family": get_model_family(hf_model_name),
        "dataset": dataset,
        "eval_type": eval_type,
        "Model Size (B)": model_size,
        **metrics,
    }


def _strip_suffix(name: str, suffix: str) -> str:
    """Remove ``suffix`` once from the end of ``name`` (unlike str.replace, never mid-string)."""
    return name[: -len(suffix)] if suffix and name.endswith(suffix) else name


def _process_directory_layout(base_dir: Path, dataset: str, eval_type: str) -> list:
    """Build rows for the layout <base_dir>/<org_model>/summary.json (one sub-directory per model)."""
    rows = []
    # Sorted so that row order is deterministic across filesystems.
    for item in sorted(base_dir.iterdir()):
        if not item.is_dir():
            continue
        directory_name = item.name
        hf_model_name = convert_directory_name_to_hf_model(directory_name)
        print(f"Processing {directory_name} -> {hf_model_name}")

        summary_path = item / "summary.json"
        summary_info = load_json_file(str(summary_path)) if summary_path.exists() else {}
        rows.append(_build_result_row(hf_model_name, dataset, eval_type, summary_info))
    return rows


def _process_behavior_layout(base_dir: Path, dataset: str, eval_type: str) -> list:
    """Build rows for the behavior layout, where per-model files live under summary/ and log/."""
    log_dir = base_dir / "log"
    summary_dir = base_dir / "summary"

    if not summary_dir.exists():
        print(f"Warning: Summary directory {summary_dir} does not exist")
        return []

    is_goal_interpretation = "goal_interpretation" in eval_type
    if is_goal_interpretation:
        # summary/{model}_performance_scores.json, log/detailed_analyses/{model}_detailed_analysis.json
        summary_suffix = "_performance_scores"
        log_search_dir = log_dir / "detailed_analyses"
        log_suffix = "_detailed_analysis"
    else:
        # summary/{model}_outputs.json, log/{model}_outputs.json
        summary_suffix = "_outputs"
        log_search_dir = log_dir
        log_suffix = "_outputs"

    # A model is included if it has either a summary file or a log file; log files
    # are used only for this discovery step, their contents are not read.
    model_names = {
        _strip_suffix(path.stem, summary_suffix)
        for path in summary_dir.glob(f"*{summary_suffix}.json")
    }
    if log_search_dir.exists():
        model_names.update(
            _strip_suffix(path.stem, log_suffix)
            for path in log_search_dir.glob(f"*{log_suffix}.json")
        )
    elif is_goal_interpretation:
        if log_dir.exists():
            print(f"Warning: Expected detailed_analyses directory at {log_search_dir} does not exist")
    else:
        print(f"Warning: Log directory {log_dir} does not exist")

    if not model_names:
        print("No model files found in expected directories")
        return []

    rows = []
    for model_name in sorted(model_names):
        print(f"Processing behavior model: {model_name}")
        hf_model_name = convert_directory_name_to_hf_model(model_name)

        summary_path = summary_dir / f"{model_name}{summary_suffix}.json"
        summary_info = load_json_file(str(summary_path)) if summary_path.exists() else {}
        rows.append(_build_result_row(hf_model_name, dataset, eval_type, summary_info))
    return rows


def process_evaluation_results(base_path: str, dataset: str = "virtualhome", eval_type: str = "action_sequencing") -> pd.DataFrame:
    """
    Process all model result files in the evaluation results folder.

    "behavior" uses the summary/ + log/ file layout; "virtualhome" and any other
    dataset use one sub-directory per model containing summary.json (unknown
    datasets print a warning first).

    Args:
        base_path: Path to the evaluation results directory
        dataset: Dataset name (default: "virtualhome")
        eval_type: Evaluation type (default: "action_sequencing")

    Returns:
        DataFrame with metadata columns ("Model", "Model Family", "dataset",
        "eval_type", "Model Size (B)") first, followed by metric columns. When no
        results are found the DataFrame is empty but still has the metadata columns.
    """
    metadata_cols = ["Model", "Model Family", "dataset", "eval_type", "Model Size (B)"]
    base_dir = Path(base_path)

    if not base_dir.exists():
        print(f"Error: Directory {base_path} does not exist")
        return pd.DataFrame(columns=metadata_cols)

    if dataset == "virtualhome":
        results_data = _process_directory_layout(base_dir, dataset, eval_type)
    elif dataset == "behavior":
        results_data = _process_behavior_layout(base_dir, dataset, eval_type)
    else:
        print(f"Warning: Unknown dataset '{dataset}'. Using virtualhome structure as fallback.")
        results_data = _process_directory_layout(base_dir, dataset, eval_type)

    # pd.DataFrame([]) has no columns, so selecting metadata_cols on it would raise
    # KeyError; return an empty frame with the expected columns instead.
    if not results_data:
        return pd.DataFrame(columns=metadata_cols)

    df = pd.DataFrame(results_data)
    metric_cols = [col for col in df.columns if col not in metadata_cols]
    return df[metadata_cols + metric_cols]


# Default root folder holding <dataset>/evaluate_results/<eval_type> result trees.
# Kept for backward compatibility; pass results_root to main() to use another location.
DEFAULT_RESULTS_ROOT = "/Users/qinjielin/Downloads/NWU/25corl/corl_ws/quest_ws/results"


def main(dataset: str = "virtualhome", eval_type: str = "action_sequencing",
         results_root: Optional[str] = None, output_dir: str = "eval_results"):
    """
    Process evaluation results for one dataset / evaluation type and save them.

    Reads results from <results_root>/<dataset>/evaluate_results/<eval_type>,
    prints a summary, writes <output_dir>/<dataset>_<eval_type>_results.csv and
    a JSON copy, and warns about models whose size could not be determined.

    Args:
        dataset: Dataset name (default: "virtualhome")
        eval_type: Evaluation type (default: "action_sequencing")
        results_root: Root folder of the results tree; defaults to DEFAULT_RESULTS_ROOT.
        output_dir: Directory the CSV/JSON outputs are written to (created if missing).

    Returns:
        The results DataFrame (empty if no results were found; nothing is saved then).
    """
    root = DEFAULT_RESULTS_ROOT if results_root is None else results_root
    base_path = os.path.join(root, dataset, "evaluate_results", eval_type)

    print(f"Processing results from: {base_path}")
    print(f"Dataset: {dataset}")
    print(f"Evaluation type: {eval_type}")

    df = process_evaluation_results(base_path, dataset, eval_type)

    if df.empty:
        print("No results found!")
        return df

    # Print summary
    print(f"\nProcessed {len(df)} models:")
    print(f"DataFrame shape: {df.shape}")
    print(f"Columns: {list(df.columns)}")

    print("\nFirst few rows:")
    print(df.head())

    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, f"{dataset}_{eval_type}_results.csv")
    df.to_csv(output_path, index=False)
    print(f"\nResults saved to: {output_path}")

    # Also save as JSON for backup
    json_output_path = os.path.join(output_dir, f"{dataset}_{eval_type}_results.json")
    df.to_json(json_output_path, orient='records', indent=2)
    print(f"Results also saved to: {json_output_path}")

    # Models with no size cannot be placed on size-based (scaling) plots.
    models_without_size = df[df['Model Size (B)'].isna()]
    if len(models_without_size) > 0:
        print(f"\n⚠️  WARNING: {len(models_without_size)} models without model size data:")
        print("=" * 80)
        for _, row in models_without_size.iterrows():
            print(f"  • {row['Model']} (Family: {row['Model Family']})")
        print("=" * 80)
        print("These models may not display properly in scaling plots.")
        print("Consider checking HuggingFace model names or adding manual size mappings.")
    else:
        print(f"\n✅ All {len(df)} models have model size data.")

    return df


if __name__ == "__main__":
    # Command-line entry point: choose the dataset and evaluation type to aggregate.
    cli = argparse.ArgumentParser(
        description="Process evaluation results for different datasets and evaluation types"
    )
    cli.add_argument(
        "--dataset",
        type=str,
        default="virtualhome",
        help="Dataset name (default: virtualhome)",
    )
    cli.add_argument(
        "--eval-type",
        type=str,
        default="action_sequencing",
        help="Evaluation type: action_sequencing, goal_interpretation, etc. (default: action_sequencing)",
    )
    cli_args = cli.parse_args()

    results_df = main(dataset=cli_args.dataset, eval_type=cli_args.eval_type)
