import argparse
import os
import sys
from typing import List, Dict, Any

import pandas as pd

# Import core components
from evaluator import Evaluator
from direct import DirectInference
from rag_index import InvertedIndexRAG
from rag_dense import DenseVectorRAG
from rag_table import StructuralRAG


def load_qa_data(qa_tsv_path: str) -> pd.DataFrame:
    """Load the QA dataset and validate its schema.

    Despite the parameter name, the file is parsed as comma-separated
    (``sep=','``).

    Args:
        qa_tsv_path: Path to the QA CSV file.

    Returns:
        DataFrame containing at least ``rxn_id``, ``mol_id``, ``question``,
        ``answer``, ``qa_type`` and ``input_type``.  The ``question`` and
        ``answer`` columns are strings, with missing cells replaced by ``''``.

    Raises:
        FileNotFoundError: If ``qa_tsv_path`` does not exist.
        RuntimeError: If pandas fails to parse the file (original error chained).
        ValueError: If any required column is missing.
    """
    if not os.path.exists(qa_tsv_path):
        raise FileNotFoundError(f"QA dataset does not exist: {qa_tsv_path}")

    try:
        qa_df = pd.read_csv(qa_tsv_path, sep=',')
    except Exception as e:
        # Chain the parser error so its traceback is kept as the explicit cause.
        raise RuntimeError(f"Failed to read QA data: {str(e)}") from e

    # Check required columns
    required_cols = ['rxn_id', 'mol_id', 'question', 'answer', 'qa_type', 'input_type']
    missing_cols = [col for col in required_cols if col not in qa_df.columns]
    if missing_cols:
        raise ValueError(f"QA dataset missing required columns: {missing_cols}")

    # Fill missing cells *before* converting to str: astype(str) turns NaN into
    # the literal string 'nan', after which fillna('') has nothing left to fill.
    # NOTE(review): a column mixing blanks and integers is parsed as float by
    # pandas, so e.g. 42 becomes '42.0' here; confirm that is acceptable.
    for col in ('question', 'answer'):
        qa_df[col] = qa_df[col].fillna('').astype(str)

    return qa_df


def initialize_direct_inference(config: Dict[str, Any], qa_df: pd.DataFrame, num_few_shot: int) -> DirectInference:
    """Build the direct (no-retrieval) inference model from the ``llm`` config section.

    Args:
        config: Run configuration; reads ``config["llm"]["base_url"]``,
            ``["model_name_or_path"]`` and ``["api_key"]``.
        qa_df: QA dataset handed to the model (presumably the few-shot example
            pool; confirm against DirectInference).
        num_few_shot: Number of few-shot examples to use.

    Returns:
        The constructed ``DirectInference`` instance.

    Raises:
        RuntimeError: Wrapping any error raised while reading the config
            (including a missing key) or constructing the model.  The original
            exception is chained as ``__cause__``.
    """
    try:
        return DirectInference(
            base_url=config["llm"]["base_url"],
            model_name_or_path=config["llm"]["model_name_or_path"],
            api_key=config["llm"]["api_key"],
            qa_df=qa_df,
            num_few_shot=num_few_shot
        )
    except Exception as e:
        # Explicit chaining keeps the root cause visible in the traceback.
        raise RuntimeError(f"Failed to initialize direct inference model: {str(e)}") from e


def initialize_rag_model(rag_type: str, config: Dict[str, Any], qa_df: pd.DataFrame,
                        num_few_shot: int, retrieve_k: int) -> Any:
    """Select and run the builder for the requested RAG variant.

    Args:
        rag_type: One of ``'inverted_index'``, ``'dense_vector'`` or ``'structural'``.
        config: Run configuration passed through unchanged to the builder.
        qa_df: QA dataset passed through to the builder.
        num_few_shot: Number of few-shot examples passed through to the builder.
        retrieve_k: Accepted by the signature but not used here.
            NOTE(review): it is never forwarded to any builder; confirm whether
            the RAG classes should receive it.

    Returns:
        The fully built RAG model returned by the chosen builder.

    Raises:
        ValueError: If ``rag_type`` is not one of the supported variants.
    """
    if rag_type == 'inverted_index':
        builder = _initialize_inverted_index_rag
    elif rag_type == 'dense_vector':
        builder = _initialize_dense_vector_rag
    elif rag_type == 'structural':
        builder = _initialize_structural_rag
    else:
        raise ValueError(f"Unsupported RAG type: {rag_type}")

    return builder(config, qa_df, num_few_shot)


def _initialize_inverted_index_rag(config: Dict[str, Any], qa_df: pd.DataFrame, num_few_shot: int) -> InvertedIndexRAG:
    """Construct an InvertedIndexRAG, load the reaction/compound TSVs and build its index.

    Reads ``config["llm"]`` (base_url, model_name_or_path, api_key),
    ``config["stopwords"]`` and ``config["data"]`` (reactions_tsv, compounds_tsv).
    """
    llm_settings = config["llm"]
    model = InvertedIndexRAG(
        base_url=llm_settings["base_url"],
        model_name_or_path=llm_settings["model_name_or_path"],
        api_key=llm_settings["api_key"],
        stopwords=config["stopwords"],
        qa_df=qa_df,
        num_few_shot=num_few_shot,
    )

    # Data paths are read only after construction, matching the original ordering.
    data_paths = config["data"]
    model.read_tsv_documents(
        reactions_tsv=data_paths["reactions_tsv"],
        compounds_tsv=data_paths["compounds_tsv"],
    )
    model.build_index()
    return model


def _initialize_dense_vector_rag(config: Dict[str, Any], qa_df: pd.DataFrame, num_few_shot: int) -> DenseVectorRAG:
    """Construct a DenseVectorRAG, load the reaction/compound TSVs and build its index.

    Reads ``config["llm"]`` (base_url, model_name_or_path, api_key),
    ``config["vector_server"]["url"]`` and ``config["data"]``
    (reactions_tsv, compounds_tsv).
    """
    llm_settings = config["llm"]
    model = DenseVectorRAG(
        base_url=llm_settings["base_url"],
        model_name_or_path=llm_settings["model_name_or_path"],
        api_key=llm_settings["api_key"],
        vector_server_url=config["vector_server"]["url"],
        qa_df=qa_df,
        num_few_shot=num_few_shot,
    )

    # Data paths are read only after construction, matching the original ordering.
    data_paths = config["data"]
    model.read_tsv_documents(
        reactions_tsv=data_paths["reactions_tsv"],
        compounds_tsv=data_paths["compounds_tsv"],
    )
    model.build_index()
    return model


def _initialize_structural_rag(config: Dict[str, Any], qa_df: pd.DataFrame, num_few_shot: int) -> StructuralRAG:
    """Construct a StructuralRAG and load its compound/reaction knowledge base.

    Unlike the other RAG builders, StructuralRAG takes ``llm_``-prefixed keyword
    names and is populated through ``load_knowledge_base`` (no separate
    ``build_index`` call here).
    """
    llm_settings = config["llm"]
    model = StructuralRAG(
        llm_base_url=llm_settings["base_url"],
        llm_model_path=llm_settings["model_name_or_path"],
        llm_api_key=llm_settings["api_key"],
        qa_df=qa_df,
        num_few_shot=num_few_shot,
    )

    # Data paths are read only after construction, matching the original ordering.
    data_paths = config["data"]
    model.load_knowledge_base(
        compound_tsv=data_paths["compounds_tsv"],
        reaction_tsv=data_paths["reactions_tsv"],
    )
    return model


def get_output_path(run_mode: str, rag_type: str, config: Dict[str, Any]) -> str:
    """Resolve the file that evaluation results for this run are written to.

    Args:
        run_mode: ``'direct'`` or ``'rag'``.
        rag_type: RAG variant name.  It is only consulted when ``run_mode == 'rag'``.
        config: Run configuration; paths come from ``config["output"]``.

    Returns:
        The configured path for the mode/variant.  An unknown RAG variant falls
        back to ``eval_results/rag_<rag_type>_evaluation.tsv``.

    Raises:
        ValueError: If ``run_mode`` is neither ``'direct'`` nor ``'rag'``.
        KeyError: In rag mode, if any of the three RAG output keys is absent
            (all three are looked up, not just the one selected).
    """
    if run_mode != 'direct' and run_mode != 'rag':
        raise ValueError(f"Unsupported run mode: {run_mode}")

    if run_mode == 'direct':
        return config["output"]["direct_eval_result"]

    outputs = config["output"]
    paths_by_type = dict(
        inverted_index=outputs["rag_inverted_eval_result"],
        dense_vector=outputs["rag_dense_eval_result"],
        structural=outputs["rag_structural_eval_result"],
    )
    fallback = f"eval_results/rag_{rag_type}_evaluation.tsv"
    return paths_by_type.get(rag_type, fallback)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for one evaluation run."""
    parser = argparse.ArgumentParser(description='Chemical question answering system evaluation')
    parser.add_argument('--mode', type=str, required=True, choices=['direct', 'rag'],
                        help='Run mode: direct (direct inference) or rag (retrieval augmented)')
    parser.add_argument('--rag-type', type=str, default='inverted_index',
                        choices=['inverted_index', 'dense_vector', 'structural'],
                        help='RAG type: inverted_index, dense_vector, or structural')
    parser.add_argument('--num-few-shot', type=int, default=3,
                        help='Number of few-shot examples')
    parser.add_argument('--retrieve-k', type=int, default=5,
                        help='Number of RAG retrieval documents')
    # NOTE(review): --retrieve-scoring is only recorded in the WandB config in
    # main(); nothing in this file passes it to a model.
    parser.add_argument('--retrieve-scoring', type=str, default='tfidf', choices=['tfidf', 'tf'],
                        help='Retrieval scoring method for inverted index RAG')
    parser.add_argument('--numeric-tolerance', type=float, default=0.5,
                        help='Numeric evaluation tolerance range')
    return parser.parse_args()


def _build_config() -> Dict[str, Any]:
    """Return the static run configuration: LLM endpoint, data/output paths, WandB project."""
    return {
        # NOTE(review): `{}` is an empty dict, not an empty set; confirm which
        # type InvertedIndexRAG expects for stopwords.
        "stopwords": {  # Stopwords collection
        },
        "llm": {  # LLM configuration
            "base_url": "http://localhost:10000/v1",
            # "base_url": "http://localhost:10001/v1",
            # "base_url": "http://localhost:10002/v1",
            # "base_url": "https://openai.com/v1",
            "model_name_or_path": "/home/share/ckpt/Qwen3-8B",
            # "model_name_or_path": "/home/share/ckpt/Meta-Llama-3.1-8B-Instruct",
            # "model_name_or_path": "/home/share/ckpt/ChemLLM-7B-Chat-1_5-SFT",
            # "model_name_or_path": "DeepSeek-R1",
            # "model_name_or_path": "gpt-4o",
            # Placeholder key; replace locally and avoid committing real secrets.
            "api_key": "xxxxx"
        },
        "vector_server": {  # Vector service configuration
            "url": "http://localhost:8999/embed"
        },
        "data": {  # Data paths
            "reactions_tsv": "kb_and_qas/reactions_samples.tsv",
            "compounds_tsv": "kb_and_qas/compounds_samples.tsv",
            "qa_tsv": "kb_and_qas/all_qa_pairs_samples.csv"
        },
        "output": {  # Output paths
            "direct_eval_result": "eval_results/direct_evaluation.tsv",
            "rag_inverted_eval_result": "eval_results/rag_inverted_evaluation.tsv",
            "rag_dense_eval_result": "eval_results/rag_dense_evaluation.tsv",
            "rag_structural_eval_result": "eval_results/rag_structural_evaluation.tsv"
        },
        "wandb": {  # WandB configuration
            "project": "chemical-qa-evaluation"
        }
    }


def _evaluate(evaluator: Any, qa_df: pd.DataFrame, prediction_function: Any,
              output_path: str, include_recall: bool, **process_kwargs: Any) -> None:
    """Run predictions over qa_df via the evaluator, then compute, print and log metrics.

    Extra keyword arguments (e.g. ``retrieve_k``) are forwarded to
    ``evaluator.process_qa`` only.
    """
    eval_result_df = evaluator.process_qa(
        qa_df=qa_df,
        prediction_function=prediction_function,
        output_path=output_path,
        include_recall=include_recall,
        **process_kwargs
    )
    eval_metrics = evaluator.calculate_metrics(eval_result_df, include_recall=include_recall)
    evaluator.print_metrics(eval_metrics, include_recall=include_recall)
    evaluator.log_metrics_to_wandb(eval_metrics, include_recall=include_recall)


def main():
    """Run one evaluation end to end.

    The run loads the QA data, builds the model selected by ``--mode`` /
    ``--rag-type``, evaluates it, and logs metrics to WandB.  Any exception is
    printed to stderr, reported as a WandB alert when a run is active, and then
    re-raised.  The WandB run is always finished once an Evaluator exists.
    """
    args = _parse_args()
    config = _build_config()

    run_mode = args.mode
    model_type = f"rag-{args.rag_type}" if run_mode == 'rag' else run_mode
    # Bound before the try so the except/finally handlers can test it directly.
    evaluator = None

    try:
        # Load QA data
        print("="*80)
        print("Loading QA data...")
        qa_df = load_qa_data(config["data"]["qa_tsv"])
        print(f"Successfully loaded QA data with {len(qa_df)} samples")

        # Model identifier is the last path component of model_name_or_path.
        model_name = config["llm"]["model_name_or_path"].split('/')[-1]

        evaluator = Evaluator(
            wandb_project=config["wandb"]["project"],
            model_type=model_type,
            model_name=model_name,
            numeric_tolerance=args.numeric_tolerance
        )

        output_path = get_output_path(run_mode, args.rag_type, config)
        # Create the directory of the file this run actually writes, whatever
        # the mode.  dirname is '' for a bare filename, and makedirs('') would fail.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        if run_mode == 'direct':
            # Direct LLM inference mode
            print("\n" + "="*80)
            print("Initializing direct inference model...")
            inference_model = initialize_direct_inference(config, qa_df, args.num_few_shot)

            evaluator.init_wandb(config={
                "num_few_shot": args.num_few_shot,
                "numeric_tolerance": args.numeric_tolerance,
                "data_path": config["data"]["qa_tsv"],
                "output_path": output_path
            })

            print("\n" + "="*80)
            print("Starting direct LLM evaluation...")
            _evaluate(evaluator, qa_df, inference_model.predict, output_path,
                      include_recall=False)

        elif run_mode == 'rag':
            # RAG mode
            print("\n" + "="*80)
            print(f"Initializing {args.rag_type} type RAG model...")
            rag_model = initialize_rag_model(
                rag_type=args.rag_type,
                config=config,
                qa_df=qa_df,
                num_few_shot=args.num_few_shot,
                retrieve_k=args.retrieve_k
            )

            evaluator.init_wandb(config={
                "retrieve_k": args.retrieve_k,
                "retrieve_scoring": args.retrieve_scoring if args.rag_type == 'inverted_index' else None,
                "num_few_shot": args.num_few_shot,
                "numeric_tolerance": args.numeric_tolerance,
                "data_paths": config["data"],
                "output_path": output_path,
                "vector_server": config["vector_server"]["url"] if args.rag_type == 'dense_vector' else None
            })

            print("\n" + "="*80)
            print(f"Starting RAG evaluation based on {args.rag_type}...")
            _evaluate(evaluator, qa_df, rag_model.predict, output_path,
                      include_recall=True, retrieve_k=args.retrieve_k)

    except Exception as e:
        print(f"Error during execution: {str(e)}", file=sys.stderr)
        if evaluator is not None and evaluator.wandb_run:
            # Best-effort: a failing alert must not replace the original error.
            try:
                evaluator.wandb_run.alert(title=f"{model_type} evaluation failed", text=f"Error message: {str(e)}")
            except Exception as alert_error:
                print(f"Failed to send WandB alert: {alert_error}", file=sys.stderr)
        raise  # Re-raise the original exception for debugging
    finally:
        # Ensure the WandB run ends whenever an Evaluator was created.
        if evaluator is not None:
            evaluator.finish_wandb()

    # Run completed
    print("\n" + "="*80)
    print(f"{model_type} evaluation completed!")
    print(f"Evaluation result file: {output_path}")
    print("="*80)


if __name__ == '__main__':
    main()
