import argparse
import json
import logging
import os
import warnings
import re
import random
import Levenshtein
import torch
import torch.nn.functional as F
import hashlib

from .utils_carve.configs import ApibenchDataConfig, MLLMDataConfig, EvalConfig, HuggingBench1DataConfig, HuggingBench2DataConfig
from .utils_carve.config_loader import create_eval_config_from_yaml
from .openmodel_carve import LoRAModelManager
from .utils_carve.retrieval_replay import PromptReplayBuffer
from .utils_carve.prepareDataset import (
    dict_retriever, 
    gorilla_prompt, 
    gorilla_prompt_explanation_json, 
    gorilla_prompt_explanation,
    create_gorilla_prompt_with_date,
    create_gorilla_prompt_explanation_json_with_date,
    create_gorilla_prompt_explanation_with_date,
    gorilla_fewshot_prompt,
    gorilla_fewshot_prompt_explanation_json,
    gorilla_fewshot_prompt_explanation,
    create_gorilla_fewshot_prompt_with_date,
    create_gorilla_fewshot_prompt_explanation_json_with_date,
    create_gorilla_fewshot_prompt_explanation_with_date,
    truncate_text_by_tokens
)
from .utils_carve.retrieval_replay import ExperienceIndex
from .utils.wandb import WandbLogger
from .utils.prepareDataset import load_dataset_json
from pathlib import Path
from dotenv import load_dotenv

# Resolve paths relative to this file: PACKAGE_ROOT is the directory that
# contains this module and PROJECT_ROOT is its parent directory.
PACKAGE_ROOT = Path(__file__).resolve().parent
PROJECT_ROOT = PACKAGE_ROOT.parent   
# Load environment variables from the project-level .env file at import time
# (main() later reads WANDB_API_KEY via os.getenv).
load_dotenv(PROJECT_ROOT / ".env")     


# Suppress all unnecessary logging
# Third-party loggers only report ERROR and above; filterwarnings("ignore")
# silences every Python warning for the whole process, not just this module.
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("peft").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

def reconstruct_partial_dict(text: str):
    """Best-effort recovery of flat string key/value pairs from malformed JSON.

    Only ``"key": "value"`` pairs whose value is a plain double-quoted string
    are recovered; nested objects, numbers, booleans and escaped quotes are
    not handled. When a key appears more than once, the last occurrence wins.

    Args:
        text: Raw (possibly truncated or invalid) JSON-like text.

    Returns:
        dict: Mapping of every recovered key to its string value.
    """
    # Trim surrounding whitespace and collapse a stray "]}" closing sequence.
    cleaned = text.strip().replace("]}", "}")
    key_value_pairs = re.findall(r'"([^"]+)"\s*:\s*"([^"]*)"', cleaned)
    return dict(key_value_pairs)

def extract_model_name_from_dict(llm_answer: str) -> str:
    """
    Extracts the model name from a JSON-formatted LLM answer.

    The answer is expected to look like ``{"model_name": "...", ...}``. Any
    ``</s>`` marker is removed first. If the text is not valid JSON (e.g. a
    truncated generation), flat ``"key": "value"`` pairs are recovered with
    ``reconstruct_partial_dict``. The key ``"model"`` is accepted as an alias
    when ``"model_name"`` is absent.

    Args:
        llm_answer (str): The answer string generated by the LLM.

    Returns:
        str: The extracted model name, or "" when no string-valued model name
        can be found or the input is not a string.
    """
    logger = logging.getLogger(__name__)

    # Validate the type up front: calling .strip() on a non-string would raise
    # before any error handling could run.
    if not isinstance(llm_answer, str):
        logger.warning("LLM answer is not a valid string.")
        logger.debug(f"LLM answer: {llm_answer}")
        return ""

    llm_answer = llm_answer.replace("</s>", "").strip()

    try:
        answer_dict = json.loads(llm_answer)
    except json.JSONDecodeError:
        answer_dict = reconstruct_partial_dict(llm_answer)

    # Valid JSON is not necessarily an object (e.g. a bare list or string).
    if isinstance(answer_dict, dict):
        # "model" is accepted as an alias key (presumably the LLM sometimes
        # emits it instead of "model_name"). Keys are looked up directly
        # rather than rewritten textually, so values are never altered.
        for key in ("model_name", "model"):
            value = answer_dict.get(key)
            if isinstance(value, str):
                return value

    logger.warning("'model_name' key not found in the reconstructed dictionary.")
    logger.debug(f"LLM answer: {llm_answer}")
    return ""


def extract_model_name_from_string(llm_answer: str) -> str:
    """
    Extracts the model name from a tagged LLM answer string.

    Expected format: ``<<<model_name>>>model name <<<explanation>>>explanation``.
    Any ``</s>`` marker is removed first. The text between the two tags is
    returned with surrounding whitespace (including newlines) stripped.

    Args:
        llm_answer (str): The answer string generated by the LLM.

    Returns:
        str: The extracted model name, or "" when the tags are not found or
        the input is not a string.
    """
    logger = logging.getLogger(__name__)

    # Validate the type up front: calling .strip() on a non-string would raise
    # before any error handling could run.
    if not isinstance(llm_answer, str):
        logger.warning("LLM answer is not a valid string.")
        logger.debug(f"LLM answer: {llm_answer}")
        return ""

    llm_answer = llm_answer.replace("</s>", "").strip()

    # re.DOTALL lets "." cross newlines, so a name followed by a line break
    # before <<<explanation>>> still matches.
    match = re.search(r'<<<model_name>>>(.*?)<<<explanation>>>', llm_answer, re.DOTALL)
    if match:
        return match.group(1).strip()

    logger.warning("'model_name' pattern not found in the LLM answer.")
    logger.debug(f"LLM answer: {llm_answer}")
    return ""

def levenshtein_similarity(name1, name2):
    """Return the length-normalized Levenshtein similarity of two strings.

    Computed as ``1 - distance / max(len(name1), len(name2))``: 1.0 for
    identical strings, lower as more edits are needed. Two empty strings are
    treated as identical (1.0).
    """
    longest = max(len(name1), len(name2))
    if not longest:
        # Both inputs are empty; avoid dividing by zero.
        return 1.0
    edit_distance = Levenshtein.distance(name1, name2)
    return 1 - edit_distance / longest


def validate_and_fix_response(response, model_names, model_domains=None, 
                              ground_truth_domain=None, threshold=0.6):
    """
    Label snapping: Validate response and fix if not in model names.
    
    This function implements post-processing validation to fix hallucinated
    model names by finding the closest valid model name using Levenshtein similarity.
    If a ground truth domain is provided, it prioritizes models from that domain.
    Candidates are scanned in sorted order, so when several candidates have
    the same best similarity the result is deterministic across runs.
    
    Args:
        response: The generated response (model name); surrounding whitespace is stripped
        model_names: Set of valid model names
        model_domains: Dict mapping model names to domains (optional)
        ground_truth_domain: Expected domain (optional, for domain filtering)
        threshold: Minimum similarity threshold for fixing (default: 0.6)
        
    Returns:
        The stripped response if it is already valid, the best-matching model
        name if its similarity is >= threshold, otherwise the stripped response.
    """
    response = response.strip()
    
    # If already valid, return as-is
    if response in model_names:
        return response
    
    # Filter by domain if provided (prioritize correct domain)
    candidate_models = model_names
    if ground_truth_domain and model_domains:
        domain_models = {
            m for m in model_names 
            if model_domains.get(m) == ground_truth_domain
        }
        # If we have domain matches, use those; otherwise fall back to all models
        if domain_models:
            candidate_models = domain_models
    
    # Find best match by Levenshtein similarity. Set iteration order of strings
    # changes between processes (hash randomization); sorting makes tie-breaks
    # (first candidate with the maximal similarity wins) reproducible.
    best_match = None
    best_sim = 0
    for model_name in sorted(candidate_models, key=str):
        sim = levenshtein_similarity(response, model_name)
        if sim > best_sim:
            best_sim = sim
            best_match = model_name
    
    # Return best match if similarity is above threshold
    if best_match is not None and best_sim >= threshold:
        return best_match
    
    # Return original response if no good match found
    return response


def compute_metrics(answers, dataset, enable_label_snapping=True, snapping_threshold=0.5):
    """
    Compute evaluation metrics with optional label snapping.
    
    Args:
        answers: List of answer dictionaries with 'response' and 'ground_true'
            keys ('domain_ground_true' is read for domain-aware snapping).
            When snapping is enabled, each dict's 'response' is overwritten
            in place with the snapped value.
        dataset: Iterable of dicts with 'model_name' and 'domain' keys; defines
            the valid model names and their domains.
        enable_label_snapping: Whether to enable post-processing validation (default: True)
        snapping_threshold: Similarity threshold for label snapping (default: 0.5)

    Returns:
        dict with "Accuracy" (exact match), "Accuracy Exist" (response is a
        known model) and "Accuracy Domain" (response is a known model whose
        domain equals the ground truth's), plus "Label Snapping Fixed" and
        "Label Snapping Fix Rate" when snapping is enabled. All rates are 0.0
        when `answers` is empty.
    """
    # Build both lookups from a single pass so `dataset` may be any iterable.
    model_domains = {data['model_name']: data['domain'] for data in dataset}
    model_names = set(model_domains)

    count_exist = 0
    same_domain = 0
    count = 0
    fixed_count = 0  # Track how many responses were fixed by label snapping

    for ans in answers:
        original_response = ans['response']
        
        # Apply label snapping if enabled
        if enable_label_snapping:
            ground_truth_domain = ans.get('domain_ground_true')
            ans['response'] = validate_and_fix_response(
                ans['response'],
                model_names,
                model_domains,
                ground_truth_domain,
                snapping_threshold
            )
            # Track if response was fixed
            if ans['response'] != original_response:
                fixed_count += 1
        
        # Compute accuracy metrics
        if ans['response'] == ans['ground_true']:
            count += 1
        
        if ans['response'] in model_names:
            count_exist += 1
            # A ground truth missing from `dataset` has no known domain, so it
            # cannot count as a domain match (previously this raised KeyError).
            if (ans['ground_true'] in model_domains
                    and model_domains[ans['response']] == model_domains[ans['ground_true']]):
                same_domain += 1

    total = len(answers)

    def rate(numerator):
        # Guard against an empty answer list instead of dividing by zero.
        return numerator / total if total > 0 else 0.0

    metrics = {
        "Accuracy": rate(count),
        "Accuracy Exist": rate(count_exist),
        "Accuracy Domain": rate(same_domain)
    }
    
    # Add label snapping statistics if enabled
    if enable_label_snapping:
        metrics["Label Snapping Fixed"] = fixed_count
        metrics["Label Snapping Fix Rate"] = rate(fixed_count)
    
    return metrics

def _select_base_prompt(system_prompt_format, use_fewshot, model_date_cutoff):
    """Return the system-prompt template string that questions are appended to.

    Args:
        system_prompt_format: "gorilla_prompt_explanation_json" or
            "gorilla_prompt_explanation" select those variants; any other
            value selects the plain gorilla prompt.
        use_fewshot: Select the few-shot variant of the template.
        model_date_cutoff: When not None, the date-aware template factory is
            called with it; otherwise the static template constant is used.
    """
    use_date = model_date_cutoff is not None
    if use_fewshot:
        if system_prompt_format == "gorilla_prompt_explanation_json":
            return (create_gorilla_fewshot_prompt_explanation_json_with_date(model_date_cutoff)
                    if use_date else gorilla_fewshot_prompt_explanation_json)
        if system_prompt_format == "gorilla_prompt_explanation":
            return (create_gorilla_fewshot_prompt_explanation_with_date(model_date_cutoff)
                    if use_date else gorilla_fewshot_prompt_explanation)
        return create_gorilla_fewshot_prompt_with_date(model_date_cutoff) if use_date else gorilla_fewshot_prompt
    if system_prompt_format == "gorilla_prompt_explanation_json":
        return (create_gorilla_prompt_explanation_json_with_date(model_date_cutoff)
                if use_date else gorilla_prompt_explanation_json)
    if system_prompt_format == "gorilla_prompt_explanation":
        return (create_gorilla_prompt_explanation_with_date(model_date_cutoff)
                if use_date else gorilla_prompt_explanation)
    return create_gorilla_prompt_with_date(model_date_cutoff) if use_date else gorilla_prompt


def _build_fewshot_section(experience_index, query, top_k, max_card_tokens):
    """Render the neighbours retrieved for `query` as a "[RELATED EXAMPLES]" block.

    Each neighbour contributes its prompt and model id, plus its domain and a
    model-card snippet (truncated to `max_card_tokens`) when non-empty.
    Returns "" when the index returns no neighbours.
    """
    retrieved_neighbors = experience_index.retrieve(
        query=query,
        top_k=top_k,
        exclude_example_ids=None  # No self-masking during inference
    )
    if not retrieved_neighbors:
        return ""

    parts = ["\n\n[RELATED EXAMPLES]\n"]
    for i, neighbor in enumerate(retrieved_neighbors, 1):
        neighbor_prompt = neighbor['prompt']
        neighbor_model = neighbor['model_id']
        neighbor_domain = neighbor.get('domain', '')
        neighbor_card = truncate_text_by_tokens(
            neighbor.get('model_card_snippet', ''),
            max_card_tokens
        )

        parts.append(f"Example {i}:\n")
        parts.append(f"  Prompt: {neighbor_prompt}\n")
        parts.append(f"  Reference model (for similar case): {neighbor_model}\n")
        if neighbor_domain:
            parts.append(f"  Domain: {neighbor_domain}\n")
        if neighbor_card:
            parts.append(f"  Model card: {neighbor_card}\n")
        parts.append("\n")
    return "".join(parts)


def llm_responses(model: LoRAModelManager, question_jsons: list, eval_config: EvalConfig, dataset_config=None, experience_index=None):
    """
    Generate LLM responses for a list of question JSONs.

    Args:
        model (LoRAModelManager): The LoRA model manager instance.
        question_jsons (list): List of question JSON objects; each must have
            'instruction', 'model_name', 'model_source' and 'domain' keys.
        eval_config (EvalConfig): Evaluation configuration.
        dataset_config: Optional dataset config (e.g., ApibenchDataConfig) to get model_date_cutoff.
        experience_index: Optional ExperienceIndex for few-shot augmentation.

    Returns:
        list: One answer dict per question with keys "questions", "response",
        "ground_true", "model_source" and "domain_ground_true".

    Raises:
        ValueError: If few-shot is off and `eval_config.retriever` is not a
            key of `dict_retriever`. This is checked before any generation.
    """
    # Few-shot augmentation needs both an index and a retriever.
    use_fewshot = experience_index is not None and eval_config.retriever is not None
    top_k = eval_config.fewshot_top_k  # Number of retrieved neighbors for few-shot augmentation
    max_card_tokens = eval_config.fewshot_max_card_tokens  # Max tokens for model card snippets

    # The template depends only on the configs, so resolve it once rather
    # than once per question.
    model_date_cutoff = dataset_config.model_date_cutoff if dataset_config and hasattr(dataset_config, 'model_date_cutoff') else None
    base_prompt = _select_base_prompt(eval_config.system_prompt_format, use_fewshot, model_date_cutoff)

    # Legacy model_card support (if not using few-shot): validate the retriever once.
    retriever_name = None
    if not use_fewshot and eval_config.retriever is not None:
        try:
            retriever_name = dict_retriever[eval_config.retriever]
        except KeyError as err:
            raise ValueError(
                f"Retriever '{eval_config.retriever}' is not valid. Chose from: {list(dict_retriever.keys())}") from err

    prompts = []
    for q_json in question_jsons:
        prompt_text = q_json.get(
            "instruction", "").strip().replace('\r\n', '\n')

        model_card = ""
        if use_fewshot:
            fewshot_section = ""
            # Per-example dropout of the few-shot context; at evaluation
            # fewshot_dropout_prob is typically 0.0 (always use examples).
            if random.random() > eval_config.fewshot_dropout_prob:
                fewshot_section = _build_fewshot_section(
                    experience_index, prompt_text, top_k, max_card_tokens)
            final_prompt_text = f"[ORIGINAL PROMPT]\n{prompt_text}" + fewshot_section
        else:
            final_prompt_text = prompt_text
            if retriever_name is not None:
                # NOTE(review): .strip() removes the leading space of
                # " <Reference API>: ", so the card is glued to the question
                # text. Kept unchanged; presumably it must match the training
                # prompt format. Confirm before changing.
                model_card = (
                    " <Reference API>: " + q_json.get(retriever_name, "")).strip().replace('\r\n', '\n')

        prompt = (base_prompt + final_prompt_text + model_card +
                  " ###Response: ").strip().replace('\r\n', '\n')
        prompts.append(prompt)

    responses = model.generate_batch_safe(
        prompts,
        do_sample=eval_config.do_sample,
        temperature=eval_config.temperature,
        max_new_tokens=eval_config.max_new_tokens,
        top_p=eval_config.top_p,
        top_k=eval_config.top_k,
        penalty_alpha=eval_config.penalty_alpha,
        batch_size=eval_config.eval_batch_size,
    )

    # Remove the echoed prompt from each generation (when present), then drop
    # everything from the first end-of-sequence marker onward.
    cleaned_responses = []
    for prompt, raw in zip(prompts, responses):
        text = raw[len(prompt):].strip() if raw.startswith(prompt) else raw
        cleaned_responses.append(text.split('</s>')[0].strip())

    # Build the output records (one per question, written out as JSONL).
    return [
        {
            "questions": q_json['instruction'],
            "response": resp,
            "ground_true": q_json['model_name'],
            "model_source": q_json['model_source'],
            "domain_ground_true": q_json['domain']
        }
        for q_json, resp in zip(question_jsons, cleaned_responses)
    ]


def parse_args() -> EvalConfig:
    """
    Parse command-line arguments and build the evaluation config.

    The YAML file given by ``--config`` supplies the base configuration.
    Command-line options override it only when actually provided: value
    options when they are not None, ``store_true`` flags when set. Value
    options therefore default to None so that an omitted option never
    clobbers a value set in the YAML file.

    Returns:
        EvalConfig: The config produced by ``create_eval_config_from_yaml``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a LoRA fine-tuned model on a specific dataset",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=""
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configurations/eval_config.yaml",
        help="Path to YAML configuration file",
    )
    parser.add_argument(
        "--retriever",
        type=str,
        required=False,
        default=None,
        choices=["bm25", "sentence_transformer", "splade", "flagembedding"],
        help="Specify which retriever to use",
    )
    parser.add_argument(
        "--experience_name",
        type=str,
        required=True,
        choices=["apibench", "mllm", "hugging-bench-1", "hugging-bench-2"],
        help="Specify which experience (test dataset) to use",
    )

    parser.add_argument(
        "--lora_adapters",
        type=str,
        nargs='+',
        required=True,
        help="Specify which adapters to use (e.g., --lora_adapters adapter1 adapter2 adapter3)",
    )

    parser.add_argument(
        "--merging_strategy",
        type=str,
        required=False,
        choices=["ties", "dare_linear", "arithmetic_mean"],
        help="Specify which merging strategy to use (e.g., --merging_strategy ties)",
    )
    
    parser.add_argument(
        "--weights",
        type=float,
        nargs='+',
        required=False,
        help="Adapter weights for merging strategy (e.g., --weights 1.0 1.0)",
    )
    
    parser.add_argument(
        "--density",
        type=float,
        required=False,
        help="Adapter density for merging strategy (e.g., --density 0.3)",
    )
    
    parser.add_argument(
        "--output_name",
        type=str,
        required=False,
        help="Name of the directory to save the evaluation results",
    )
    parser.add_argument(
        "--eval_on_train",
        action="store_true",
        help="Also run evaluation on the training set (default: off)",
    )
    parser.add_argument(
        "--use_router",
        action="store_true",
        help="Use router evaluation instead of text generation (requires router checkpoint in adapter directory)",
    )
    parser.add_argument(
        "--debug_router_eval",
        action="store_true",
        default=False,
        help="Enable detailed debugging output for router evaluation (default: False)",
    )
    parser.add_argument(
        "--strict_router_load",
        action="store_true",
        default=False,
        help="Use strict=True when loading router weights (default: False)",
    )
    parser.add_argument(
        "--known_domain_mode",
        action="store_true",
        default=False,
        help="Enable known-domain mode: only compare against models within the same domain (default: False)",
    )
    parser.add_argument(
        "--hierarchical_eval",
        action="store_true",
        default=False,
        help="Enable hierarchical (two-stage) evaluation: predict group then model within group (default: False)",
    )
    # The hierarchical options below default to None so the YAML config value
    # is kept unless the option is passed explicitly on the command line.
    parser.add_argument(
        "--hierarchy_level",
        type=str,
        default=None,
        choices=["domain", "parent_group"],
        help="Hierarchy level for hierarchical evaluation: 'domain' or 'parent_group' (default: taken from the YAML config)",
    )
    parser.add_argument(
        "--hierarchical_topk",
        type=int,
        default=None,
        help="Number of top groups to consider in hierarchical evaluation (default: taken from the YAML config)",
    )
    parser.add_argument(
        "--hier_domain_score_mode",
        type=str,
        default=None,
        choices=["logsumexp", "max", "topk_logsumexp", "hybrid"],
        help="Domain scoring strategy for hierarchical evaluation (default: taken from the YAML config)",
    )
    parser.add_argument(
        "--hier_domain_topk",
        type=int,
        default=None,
        help="Number of top models for topk_logsumexp/hybrid domain scoring modes (default: taken from the YAML config)",
    )
    parser.add_argument(
        "--hier_domain_hybrid_alpha",
        type=float,
        default=None,
        help="Weight for max in hybrid domain scoring mode (default: taken from the YAML config)",
    )
    parser.add_argument(
        "--eval_on_train_samples",
        action="store_true",
        default=False,
        help="Load 50 examples from training split and run router evaluation on them (default: False)",
    )

    args = parser.parse_args()

    # Value options: CLI dest -> EvalConfig field. Forwarded only when not None.
    value_options = {
        "experience_name": "experience_name",
        "retriever": "retriever",
        "lora_adapters": "lora_adapters",
        "merging_strategy": "lora_merging_strategy",
        "weights": "ties_or_dare_weights",
        "density": "ties_or_dare_density",
        "output_name": "output_name",
        "hierarchy_level": "hierarchy_level",
        "hierarchical_topk": "hierarchical_topk",
        "hier_domain_score_mode": "hier_domain_score_mode",
        "hier_domain_topk": "hier_domain_topk",
        "hier_domain_hybrid_alpha": "hier_domain_hybrid_alpha",
    }
    # store_true flags share their name with the EvalConfig field. They are
    # forwarded only when set, so an absent flag never overrides YAML.
    flag_options = (
        "eval_on_train",
        "use_router",
        "debug_router_eval",
        "strict_router_load",
        "known_domain_mode",
        "hierarchical_eval",
        "eval_on_train_samples",
    )

    config_overrides = {
        field: getattr(args, dest)
        for dest, field in value_options.items()
        if getattr(args, dest) is not None
    }
    config_overrides.update({flag: True for flag in flag_options if getattr(args, flag)})

    return create_eval_config_from_yaml(args.config, **config_overrides)


def main():
    eval_config = parse_args()
    
    # Initialize WandB logger
    wandb_key = os.getenv("WANDB_API_KEY")
    if wandb_key:
        wandb_logger = WandbLogger(wandb_key, eval_config, mode="eval")
    else:
        wandb_logger = None
        import logging
        logger = logging.getLogger(__name__)
        logger.warning("WANDB_API_KEY not found in environment variables. Skipping WandB logging.")
    
    lora_paths = [f"./cco/experiments/{adapter}" for adapter in eval_config.lora_adapters]
    model = LoRAModelManager(eval_config, lora_paths=lora_paths)


    if eval_config.experience_name == "apibench":
        dataset = ApibenchDataConfig()
    elif eval_config.experience_name == "mllm":
        dataset = MLLMDataConfig()
    elif eval_config.experience_name == "hugging-bench-1":
        dataset = HuggingBench1DataConfig()
    elif eval_config.experience_name == "hugging-bench-2":
        dataset = HuggingBench2DataConfig()


    # Check if using few-shot baseline (detect from adapter path or config)
    use_fewshot = False
    experience_index = None
    training_experience = None
    if eval_config.retriever is not None:
        # Check if any adapter path contains "retrieval_replay_fewshot"
        for adapter in eval_config.lora_adapters:
                if "retrieval_replay_fewshot" in adapter.lower():
                    use_fewshot = True
                    # Extract training experience from adapter path
                    # Format: "apibench-retrieval_replay_fewshot_bm25-..."
                    adapter_parts = adapter.split("/")[0]  # Get just the experience name part
                    # Handle experience names with hyphens
                    # Split on "-retrieval" to get the experience name part
                    if "-retrieval" in adapter_parts:
                        training_experience = adapter_parts.split("-retrieval")[0]
                    else:
                        # Fallback: try to extract known experience names
                        for exp_name in ["mllm", "apibench", "hugging-bench-1", "hugging-bench-2"]:
                            if adapter_parts.startswith(exp_name):
                                training_experience = exp_name
                                break
                        if training_experience is None:
                            # Last resort: split on first hyphen
                            training_experience = adapter_parts.split("-")[0]
                    break
    
    # If using few-shot, build experience index from training data + replay buffer
    if use_fewshot:
        if not training_experience:
            import logging
            logger = logging.getLogger(__name__)
            logger.warning("Could not extract training experience from adapter path. Using current dataset.")
            training_experience = eval_config.experience_name
            train_dataset_config = dataset
            train_data = load_dataset_json(train_dataset_config.train_set)
        else:
            print(f"Detected retrieval_replay_fewshot baseline - building experience index for training experience: {training_experience}")
            
            # Load training data for the experience the model was trained on
            if training_experience == "apibench":
                train_dataset_config = ApibenchDataConfig()
            elif training_experience == "mllm":
                train_dataset_config = MLLMDataConfig()
            elif training_experience == "hugging-bench-1":
                train_dataset_config = HuggingBench1DataConfig()
            elif training_experience == "hugging-bench-2":
                train_dataset_config = HuggingBench2DataConfig()
            else:
                train_dataset_config = dataset  # Fallback to current dataset
            
            train_data = load_dataset_json(train_dataset_config.train_set)
        
        # Build replay buffer using the same method as training
        # Use PromptReplayBuffer to replicate exact training behavior
        # Use config values (should match training config for consistency)
        replay_seed = eval_config.fewshot_replay_seed
        replay_ratio = eval_config.fewshot_replay_ratio
        
        replay_buffer = PromptReplayBuffer(
            replay_ratio=replay_ratio,
            seed=replay_seed
        )
        
        # Replicate the cumulative replay buffer construction from training
        # E1 (apibench): no replay buffer (first experience)
        # E2 (mllm): include 10% of apibench

        if training_experience == "apibench":
            # E1: no replay buffer (first experience)
            replay_examples = []
        elif training_experience == "mllm":
            # E2: include 10% of apibench
            apibench_config = ApibenchDataConfig()
            apibench_train = load_dataset_json(apibench_config.train_set)
            replay_buffer.add_experience(apibench_train, "apibench")
            replay_examples = replay_buffer.get_examples()
        
        else:
            # Unknown experience, no replay
            replay_examples = []
        
        experience_index = ExperienceIndex(
            retriever_type=eval_config.retriever,
            current_examples=train_data,
            replay_examples=replay_examples,
            experience_name=training_experience,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
    # If not using few-shot, experience_index remains None (already initialized above)
    
    dataset_json = load_dataset_json(dataset.test_set)
    
    # Load model_family information from alternate source if available
    # For apibench-hf-* files, load from apibench-all-* files which have model_family
    # For other datasets (mllm, hugging-bench), the files themselves contain model_family
    model_family_source_path = None
    if "cleaned-apibench-hf-" in dataset.test_set:
        # Map apibench-hf-* to apibench-all-* (which has family info)
        model_family_source_path = dataset.test_set.replace("cleaned-apibench-hf-", "cleaned-apibench-all-")
    # For mllm, and hugging-bench, the test files themselves contain model_family
    # so we don't need to load from a different file
    
    # Build model_name -> model_family lookup from alternate source if it exists
    family_lookup_from_all = {}
    if model_family_source_path and model_family_source_path != dataset.test_set:
        if os.path.exists(model_family_source_path):
            try:
                family_source_json = load_dataset_json(model_family_source_path)
                for ex in family_source_json:
                    model_name = ex.get('model_name', '')
                    model_family = ex.get('model_family')
                    if model_name and model_family and model_family.strip():
                        family_lookup_from_all[model_name] = model_family
                print(f"  ✓ Loaded {len(family_lookup_from_all)} model families from {model_family_source_path}")
            except Exception as e:
                print(f"  ⚠️  Could not load model_family from {model_family_source_path}: {e}")
        else:
            print(f"  ⚠️  Family source file not found: {model_family_source_path}")
    
    # Check if using router evaluation
    if eval_config.use_router:
        # Router evaluation mode
        from pathlib import Path
        from .eval_router import load_trained_router, evaluate_router
        from .model_selection_carve import ModelRegistry
        import json
        
        # Find router checkpoint - router files are now saved in the checkpoint directory itself.
        # So if adapter is "apibench-router/checkpoint-62", router files should be in "apibench-router/checkpoint-62/".
        # We also check the parent directory for backwards compatibility with older checkpoints.
        router_checkpoint_dir = None
        for adapter in eval_config.lora_adapters:
            adapter_path = Path(f"./cco/experiments/{adapter}")
            
            # First check in the adapter directory itself (checkpoint directory - new behavior)
            router_model_path = adapter_path / "router_model.pt"
            if router_model_path.exists():
                router_checkpoint_dir = adapter_path
                print(f"✓ Found router checkpoint in {router_checkpoint_dir}")
                break
            
            # If not found, check parent directory (for backwards compatibility with older checkpoints)
            parent_path = adapter_path.parent
            router_model_path = parent_path / "router_model.pt"
            if router_model_path.exists():
                router_checkpoint_dir = parent_path
                print(f"✓ Found router checkpoint in parent directory: {router_checkpoint_dir}")
                break
        
        if router_checkpoint_dir is None:
            raise FileNotFoundError(
                "Router evaluation requested but no router checkpoint found. "
                "Expected router_model.pt in the adapter directory or its parent directory. "
                f"Searched in: {[f'./cco/experiments/{a}' for a in eval_config.lora_adapters]}"
            )
        
        # Load model registry first to check if registry size matches checkpoint
        registry_path = router_checkpoint_dir / "model_registry.json"
        registry = ModelRegistry.load(registry_path)
        
        # Load router config to check num_models (before loading router)
        router_config_path = router_checkpoint_dir / "router_config.json"
        router_config = None
        num_models_override = None
        if router_config_path.exists():
            with open(router_config_path, 'r') as f:
                router_config = json.load(f)
            checkpoint_num_models = router_config.get('num_models', len(registry))
            if len(registry) != checkpoint_num_models:
                print(f"⚠️  Registry size mismatch: checkpoint has {checkpoint_num_models} models, "
                      f"current registry has {len(registry)} models")
                print(f"  Using num_models_override={len(registry)} to resize embedding table")
                num_models_override = len(registry)
        
        # Debug output at load time
        debug_enabled = getattr(eval_config, 'debug_router_eval', False)
        if debug_enabled:
            print("\n" + "="*80)
            print("[DEBUG ROUTER EVAL] Router Artifact Information")
            print("="*80)
            
            # Helper function to compute SHA256 checksum
            def compute_checksum(file_path: Path) -> str:
                """Compute first 8 chars of SHA256 checksum."""
                if not file_path.exists():
                    return "FILE_NOT_FOUND"
                with open(file_path, 'rb') as f:
                    sha256_hash = hashlib.sha256(f.read()).hexdigest()
                    return sha256_hash[:8]
            
            # Print absolute paths and checksums
            router_model_path = router_checkpoint_dir / "router_model.pt"
            router_config_path = router_checkpoint_dir / "router_config.json"
            
            print(f"\n[Router Files]")
            print(f"  router_model.pt:")
            print(f"    Absolute path: {router_model_path.resolve()}")
            print(f"    SHA256 (first 8): {compute_checksum(router_model_path)}")
            print(f"  model_registry.json:")
            print(f"    Absolute path: {registry_path.resolve()}")
            print(f"    SHA256 (first 8): {compute_checksum(registry_path)}")
            print(f"  router_config.json:")
            print(f"    Absolute path: {router_config_path.resolve()}")
            print(f"    SHA256 (first 8): {compute_checksum(router_config_path)}")
            
            # Print shapes/metadata
            print(f"\n[Router Metadata]")
            print(f"  num_models in ModelRegistry: {len(registry)}")
            if router_config:
                print(f"  num_models in checkpoint: {router_config.get('num_models', 'unknown')}")
                print(f"  embedding_dim: {router_config.get('embedding_dim', 'unknown')}")
                print(f"  lm_hidden_size: {router_config.get('lm_hidden_size', 'unknown')}")
                print(f"  tau: {router_config.get('tau', 'unknown')}")
                print(f"  pooling: {router_config.get('pooling', 'unknown')}")
            
            # Load router model (will load config internally and handle num_models mismatch)
            strict_router_load = getattr(eval_config, 'strict_router_load', True)  # Default to True for evaluation
            router = load_trained_router(
                checkpoint_dir=router_checkpoint_dir,
                device="cuda" if torch.cuda.is_available() else "cpu",
                strict=strict_router_load,
                num_models_override=num_models_override,
            )
            
            model_embeddings_shape = router.model_embeddings.weight.shape
            print(f"  model_embeddings shape: {model_embeddings_shape}")
            print(f"    (num_models={model_embeddings_shape[0]}, embedding_dim={model_embeddings_shape[1]})")
            
            # Warning about directory mismatch
            print(f"\n[WARNING]")
            print(f"  If router artifacts (router_model.pt, model_registry.json) are loaded from")
            print(f"  a different directory than the trained run (e.g., apibench-router vs mllm-router),")
            print(f"  metrics will be near-random due to weight/registry mismatch.")
            print(f"  Current checkpoint directory: {router_checkpoint_dir.resolve()}")
            print("="*80 + "\n")
        else:
            # Load router model normally if debug is disabled
            strict_router_load = getattr(eval_config, 'strict_router_load', True)  # Default to True for evaluation
            router = load_trained_router(
                checkpoint_dir=router_checkpoint_dir,
                device="cuda" if torch.cuda.is_available() else "cpu",
                strict=strict_router_load,
                num_models_override=num_models_override,
            )

        # ------------------------------------------------------------------
        # Router eval tokenizer + prompt formatting parity with training
        # ------------------------------------------------------------------
        # Match training tokenizer settings (see train_loop.py Lines 441-442)
        tokenizer = model.tokenizer
        if hasattr(tokenizer, "padding_side"):
            tokenizer.padding_side = "right"
        # Some tokenizers expose add_eos_token as attribute, others via config
        if hasattr(tokenizer, "add_eos_token"):
            tokenizer.add_eos_token = False

        # Build system prompt for router eval to mirror convert_to_conversational
        # See cco/utils/prepareDataset.py::convert_to_conversational Lines 171-188
        if getattr(eval_config, "system_prompt", "") != "":
            router_system_prompt = eval_config.system_prompt
        else:
            # Get date cutoff from dataset_config if available
            model_date_cutoff = (
                dataset.model_date_cutoff
                if hasattr(dataset, "model_date_cutoff")
                else None
            )
            use_date = model_date_cutoff is not None

            # Use few-shot style Gorilla prompts for routing, same as training
            if eval_config.system_prompt_format == "gorilla_prompt_explanation_json":
                router_system_prompt = (
                    create_gorilla_fewshot_prompt_explanation_json_with_date(model_date_cutoff)
                    if use_date
                    else gorilla_fewshot_prompt_explanation_json
                )
            elif eval_config.system_prompt_format == "gorilla_prompt_explanation":
                router_system_prompt = (
                    create_gorilla_fewshot_prompt_explanation_with_date(model_date_cutoff)
                    if use_date
                    else gorilla_fewshot_prompt_explanation
                )
            else:
                router_system_prompt = (
                    create_gorilla_fewshot_prompt_with_date(model_date_cutoff)
                    if use_date
                    else gorilla_fewshot_prompt
                )
        
        # Create model_name -> model_family lookup from original dataset and registry
        # This is needed because predicted models might not be in test_data
        # Note: Registry stores original model names (not normalized), so we can use them directly
        from .model_selection_carve.model_registry import normalize_model_name
        model_family_lookup = {}
        
        # First, use the family_lookup_from_all if we loaded it (from -all- version)
        if family_lookup_from_all:
            model_family_lookup.update(family_lookup_from_all)
            # Also add normalized versions
            for model_name, model_family in list(family_lookup_from_all.items()):
                normalized_name = normalize_model_name(model_name)
                if normalized_name != model_name:
                    model_family_lookup[normalized_name] = model_family
        
        # Then, build from original dataset (might have model_family if it's an -all- version)
        for ex in dataset_json:
            model_name = ex.get('model_name', '')
            model_family = ex.get('model_family')
            # Check for both None and empty string
            if model_name and model_family and model_family.strip():
                # Store with original name (as registry uses original names in idx2model)
                if model_name not in model_family_lookup:  # Don't overwrite if already in lookup
                    model_family_lookup[model_name] = model_family
                # Also store with normalized name for case-insensitive lookup
                normalized_name = normalize_model_name(model_name)
                if normalized_name != model_name and normalized_name not in model_family_lookup:
                    model_family_lookup[normalized_name] = model_family
        
        # Also check registry metadata for any additional models
        for idx, metadata in registry.metadata.items():
            model_name = registry.idx2model.get(idx)
            if model_name and model_name not in model_family_lookup:
                family = metadata.get('family') or metadata.get('model_family')
                if family and family.strip():
                    model_family_lookup[model_name] = family
                    # Also store normalized version
                    normalized_name = normalize_model_name(model_name)
                    if normalized_name != model_name:
                        model_family_lookup[normalized_name] = family
        
        # Debug: Print lookup statistics
        if len(model_family_lookup) == 0:
            print(f"  ⚠️  WARNING: model_family_lookup is empty! No model_family found in dataset_json or registry.")
        else:
            print(f"  ✓ Built model_family_lookup with {len(model_family_lookup)} entries (including normalized variants)")
        
        # Convert dataset_json to router evaluation format
        test_data = []
        for ex in dataset_json:
            # Extract prompt text (use 'instruction' field)
            prompt_text = ex.get('instruction', '').strip()
            if not prompt_text:
                continue
            
            # Get model_family from lookup if not in the example
            model_name = ex.get('model_name', '')
            model_family = ex.get('model_family')
            if not model_family and model_name and family_lookup_from_all:
                model_family = family_lookup_from_all.get(model_name)
            
            test_data.append({
                'prompt_text': prompt_text,
                'model_name': model_name,
                'domain': ex.get('domain', 'unknown'),
                'model_family': model_family,  # Add model_family for test examples (from -all- version if needed)
                # Optionally pass through model_card/reference_api so router
                # prompts can include the same prefix as training
                'model_card': ex.get('model_card', ''),
                'reference_api': ex.get('reference_api', ''),
            })
        
        # Evaluate router
        print(f"\n[Router Evaluation] Evaluating on {len(test_data)} examples...")
        print(f"[DEBUG] debug_enabled = {debug_enabled}")
        known_domain_mode = getattr(eval_config, "known_domain_mode", False)
        hierarchical_eval = getattr(eval_config, "hierarchical_eval", False)
        hierarchy_level = getattr(eval_config, "hierarchy_level", "domain")
        hierarchical_topk = getattr(eval_config, "hierarchical_topk", 1)
        hier_domain_score_mode = getattr(eval_config, "hier_domain_score_mode", "logsumexp")
        hier_domain_topk = getattr(eval_config, "hier_domain_topk", 10)
        hier_domain_hybrid_alpha = getattr(eval_config, "hier_domain_hybrid_alpha", 0.5)
        print(f"[DEBUG] known_domain_mode = {known_domain_mode}")
        print(f"[DEBUG] hierarchical_eval = {hierarchical_eval}")
        if hierarchical_eval:
            print(f"[DEBUG] hierarchy_level = {hierarchy_level}, hierarchical_topk = {hierarchical_topk}")
            print(f"[DEBUG] hier_domain_score_mode = {hier_domain_score_mode}, hier_domain_topk = {hier_domain_topk}, hier_domain_hybrid_alpha = {hier_domain_hybrid_alpha}")
        router_metrics = evaluate_router(
            router_model=router,
            model_registry=registry,
            lm_model=model,
            test_data=test_data,
            k_values=[1, 3, 5, 10],
            batch_size=eval_config.eval_batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu",
            debug=debug_enabled,
            eval_use_chat_template=False,
            system_prompt=router_system_prompt,
            checkpoint_dir=router_checkpoint_dir,
            # Ensure eval max_length follows eval config and is comparable
            # to training max_length (see train_config.yaml Line 27)
            max_length=getattr(eval_config, "input_max_length", 512),
            # Enable dual-metric reporting: acc_candidate vs acc_all
            candidate_K=(router_config.get("K_total") if router_config else None),
            router_config=router_config,
            known_domain_mode=known_domain_mode,
            hierarchical_eval=hierarchical_eval,
            hierarchy_level=hierarchy_level,
            hierarchical_topk=hierarchical_topk,
            hier_domain_score_mode=hier_domain_score_mode,
            hier_domain_topk=hier_domain_topk,
            hier_domain_hybrid_alpha=hier_domain_hybrid_alpha,
            model_family_lookup=model_family_lookup,  # Pass family lookup for family accuracy computation
        )
        
        # Generate answers format for saving (using batch processing for efficiency)
        # We'll reuse the same encoding logic from evaluate_router but extract predictions
        answers = []
        router.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        with torch.no_grad():
            batch_size = eval_config.eval_batch_size
            for start_idx in range(0, len(test_data), batch_size):
                end_idx = min(start_idx + batch_size, len(test_data))
                batch = test_data[start_idx:end_idx]
                batch_examples = dataset_json[start_idx:end_idx]
                
                prompts = [ex['prompt_text'] for ex in batch]
                
                # Encode prompts
                tokenizer = model.tokenizer
                inputs = tokenizer(
                    prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                ).to(device)
                
                # Get hidden states
                lm_outputs = model.model.model(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    output_hidden_states=True,
                    return_dict=True,
                )
                
                # Extract prompt embeddings using router's encode_prompt method
                # This respects the pooling setting (last_token vs mean) and applies projection
                hidden_states = lm_outputs.hidden_states[-1]  # [B, L, D]
                attention_mask_tensor = inputs['attention_mask']  # [B, L]
                
                # Ensure router model and hidden states are on same device and dtype
                router = router.to(device=hidden_states.device, dtype=hidden_states.dtype)
                
                # In evaluation, all tokens are prompt tokens (no completion)
                # So prompt_mask = attention_mask
                prompt_mask = attention_mask_tensor.to(dtype=torch.bool)  # [B, L]
                
                # Use router's encode_prompt which respects pooling setting and applies projection
                prompt_embeddings = router.encode_prompt(
                    hidden_states=hidden_states,
                    prompt_mask=prompt_mask,
                    debug=False
                )  # [B, embedding_dim] - already projected to router embedding space
                
                # Get top-1 predictions
                all_model_embeddings = router.model_embeddings.weight
                scores = F.cosine_similarity(
                    prompt_embeddings.unsqueeze(1),
                    all_model_embeddings.unsqueeze(0),
                    dim=-1
                )
                top_indices = scores.argmax(dim=-1)
                
                # Convert to answers format
                for i, (ex, pred_idx) in enumerate(zip(batch_examples, top_indices)):
                    predicted_model = registry.idx2model[pred_idx.item()]
                    answers.append({
                        "questions": ex.get('instruction', ''),
                        "response": predicted_model,
                        "ground_true": ex.get('model_name', ''),
                        "model_source": ex.get('model_source', ''),
                        "domain_ground_true": ex.get('domain', 'unknown')
                    })
        
        # Print router metrics
        print("\n[Router Metrics]")
        for key, value in router_metrics.items():
            if isinstance(value, float):
                print(f"  {key}: {value:.4f}")
            else:
                print(f"  {key}: {value}")
        
        # Use router metrics as the main metrics
        metrics = router_metrics
        
        # Optionally evaluate on training samples (50 examples) to distinguish eval pipeline bug vs generalization
        if getattr(eval_config, 'eval_on_train_samples', False):
            print(f"\n{'='*80}")
            print(f"[Router Evaluation on Training Samples]")
            print(f"{'='*80}")
            
            # Load training dataset (use the same dataset config as test)
            dataset_json_train = load_dataset_json(dataset.train_set)
            
            # Sample 50 examples (or all if fewer)
            num_train_samples = min(50, len(dataset_json_train))
            if len(dataset_json_train) > num_train_samples:
                # Randomly sample 50 examples
                import random
                random.seed(42)  # For reproducibility
                train_data_sample = random.sample(dataset_json_train, num_train_samples)
            else:
                train_data_sample = dataset_json_train
            
            print(f"Evaluating router on {len(train_data_sample)} training examples...")
            
            # Convert to router evaluation format
            train_test_data = []
            for ex in train_data_sample:
                # Extract prompt text (use 'instruction' field)
                prompt_text = ex.get('instruction', '').strip()
                if not prompt_text:
                    continue
                
                train_test_data.append({
                    'prompt_text': prompt_text,
                    'model_name': ex.get('model_name', ''),
                    'domain': ex.get('domain', 'unknown')
                })
            
            # Evaluate router on training samples
            train_router_metrics = evaluate_router(
                router_model=router,
                model_registry=registry,
                lm_model=model,
                test_data=train_test_data,
                k_values=[1, 3, 5, 10],
                batch_size=eval_config.eval_batch_size,
                device="cuda" if torch.cuda.is_available() else "cpu",
                debug=debug_enabled,
                known_domain_mode=known_domain_mode,
                hierarchical_eval=hierarchical_eval,
                hierarchy_level=hierarchy_level,
                hierarchical_topk=hierarchical_topk,
                hier_domain_score_mode=hier_domain_score_mode,
                hier_domain_topk=hier_domain_topk,
                hier_domain_hybrid_alpha=hier_domain_hybrid_alpha,
            )
            
            print(f"\n[Router Metrics on Training Samples]")
            for key, value in train_router_metrics.items():
                if isinstance(value, float):
                    print(f"  {key}: {value:.4f}")
                else:
                    print(f"  {key}: {value}")
            
            # Compare test vs train metrics
            print(f"\n[Test vs Train Comparison]")
            print(f"  Test top1_accuracy: {router_metrics.get('top1_accuracy', 0):.4f}")
            print(f"  Train top1_accuracy: {train_router_metrics.get('top1_accuracy', 0):.4f}")
            print(f"  Test entropy_mean: {router_metrics.get('entropy_mean', 0):.4f}")
            print(f"  Train entropy_mean: {train_router_metrics.get('entropy_mean', 0):.4f}")
            print(f"\n  Interpretation:")
            if train_router_metrics.get('top1_accuracy', 0) > 0.5:
                print(f"    ✓ Router performs well on training samples → likely generalization/distribution shift issue")
            else:
                print(f"    ⚠️  Router performs poorly on training samples → likely eval pipeline bug")
            print(f"{'='*80}\n")
            
            # Store train metrics in main metrics dict
            metrics['train_samples'] = train_router_metrics
    else:
        # Standard text generation evaluation
        answers = llm_responses(model, dataset_json, eval_config, dataset_config=dataset, experience_index=experience_index)
        
        metrics: dict = compute_metrics(
            answers, dataset=dataset_json)

    print(metrics)
    
    # Log metrics to WandB
    if wandb_logger:
        wandb_logger.log(metrics)

    # Optionally evaluate on train set to assess overfitting
    if eval_config.eval_on_train:
        dataset_json_train = load_dataset_json(dataset.train_set)
        answers_train = llm_responses(model, dataset_json_train, eval_config, dataset_config=dataset, experience_index=experience_index)
        train_metrics: dict = compute_metrics(answers_train, dataset=dataset_json_train)
        print({"train": train_metrics})
        if wandb_logger:
            wandb_logger.log({f"train/{k}": v for k, v in train_metrics.items()})

    if eval_config.output_name:
        save_path = f"results/{eval_config.output_name}"
    else:
        if len(eval_config.lora_adapters) > 1:
            save_path = f"results/{eval_config.experience_name}/{eval_config.lora_merging_strategy}/" + "_".join([adapter.replace("/", "-") for adapter in eval_config.lora_adapters])
            if eval_config.ties_or_dare_weights:
                save_path += f"/weights-" + "_".join([str(w).replace(".", "-") for w in eval_config.ties_or_dare_weights])
                save_path += f"_density-{eval_config.ties_or_dare_density}".replace(".", "-")
        else:
            save_path = f"results/{eval_config.experience_name}/{eval_config.lora_adapters[0]}"
    
    os.makedirs(save_path, exist_ok=True)

    # save answers to file
    with open(f"{save_path}/answers.jsonl", "w") as f:
        for line in answers:
            f.write(json.dumps(line) + "\n")

    # save metrics as json to a file
    metrics["Adapter Path"] = f"{eval_config.lora_adapters}"
    with open(f"{save_path}/metrics.json", "w") as f:
        if eval_config.lora_merging_strategy in ["ties", "dare_linear"] and len(eval_config.lora_adapters) > 1:
            metrics["Merge Weights"] = eval_config.ties_or_dare_weights
            metrics["Merge Density"] = eval_config.ties_or_dare_density

        json.dump(metrics, f)

    # If train evaluation was run, save its outputs under a subdirectory
    if eval_config.eval_on_train:
        save_path_train = f"{save_path}/train"
        os.makedirs(save_path_train, exist_ok=True)
        with open(f"{save_path_train}/answers.jsonl", "w") as f:
            for line in answers_train:
                f.write(json.dumps(line) + "\n")
        with open(f"{save_path_train}/metrics.json", "w") as f:
            json.dump(train_metrics, f)

    # Finish WandB logging
    if wandb_logger:
        wandb_logger.finish()


# Script entry point: run `main()` only when this module is executed, not when it
# is imported. This module uses package-relative imports (e.g. `from .utils_carve...`),
# so it has to be launched as a module (`python -m <package>.<module>`). Running it
# directly by file path would raise ImportError ("attempted relative import with
# no known parent package") before reaching this line.
if __name__ == "__main__":
    main()
