import os
import sys
import json
import argparse
from typing import List, Dict
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Try to import PEFT for LoRA support
try:
    from peft import PeftModel
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    print("Warning: PEFT not installed. LoRA evaluation will not be available.")

# Add paths for imports
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parent_dir)
sys.path.insert(0, os.path.join(parent_dir, 'Preprocessing', 'sft_data_preparation'))
sys.path.insert(0, os.path.join(parent_dir, 'Preprocessing', 'paper_decomposition'))

# Import from common_utils
from common_utils import (
    extract_field, 
    llm_generation, 
    init_llm_client,
    extract_between_markers,
    llm_generation_with_extraction,
    parse_numbered_list,
    parse_numbered_pairs
)

# Import from paper_decomposition_utils
from paper_decomposition_utils import extract_answer_content

# Import prompts
from prompt_store import instruction_prompts
from eval_prompt_store import eval_instruction_prompts


class HypothesisCompositionEvaluator:
    def __init__(
        self,
        model_path: str,
        lora_path: str = None,
        device: str = "cuda",
        load_in_8bit: bool = False,
        max_length: int = 16384,
        max_new_tokens: int = 4096,
        eval_dataset_path: str = None,
        sft_qa_data_dir: str = None,
        api_type: int = 0,
        api_key: str = "",
        base_url: str = "",
        model_name: str = "r1-distill-qwen-32b",
        # Generation parameters
        temperature: float = 0.6,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
        # Hypothesis extraction option
        extract_hypothesis_only: bool = False,
        # Batch generation for better GPU utilization
        batch_size: int = 1
    ):
        """
        Initialize the evaluator with a model and API client.
        
        Args:
            model_path: Path to base model (for generating hypothesis)
            lora_path: Path to LoRA checkpoint (None for base model evaluation)
            device: Device to use (cuda/cpu)
            load_in_8bit: Whether to load model in 8-bit precision
            max_length: Maximum sequence length
            max_new_tokens: Maximum new tokens to generate (default 4096, training max is ~2900)
            eval_dataset_path: Path to evaluation dataset directory
            sft_qa_data_dir: Path to SFT QA data directory
            api_type: API type (0: OpenAI, 1: Azure, 2: Google)
            api_key: API key for evaluation
            base_url: Base URL for API
            model_name: Model name for API evaluation
            temperature: Generation temperature (default 0.6 for diversity)
            top_p: Top-p sampling parameter
            repetition_penalty: Penalty for token repetition (default 1.2)
            extract_hypothesis_only: If True, use LLM to extract only novel hypothesis 
                                    content before scoring (removes repeated input sections)
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.max_length = max_length
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.eval_dataset_path = eval_dataset_path
        self.sft_qa_data_dir = sft_qa_data_dir
        self.extract_hypothesis_only = extract_hypothesis_only
        self.batch_size = batch_size
        
        # Initialize API client for evaluation
        self.api_type = api_type
        self.model_name = model_name
        self.client = init_llm_client(api_type, api_key, base_url)
        
        # Track extraction failures (not model failures)
        self.extraction_failures = 0
        self.total_evaluations = 0
        
        print(f"Loading model from {model_path}")
        if lora_path:
            print(f"Loading LoRA weights from {lora_path}")
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            use_fast=False
        )
        
        # Set padding token if not set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Load model
        if load_in_8bit:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                load_in_8bit=True,
                device_map="auto",
                trust_remote_code=True
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True
            )
        
        # Load LoRA weights if provided
        if lora_path:
            if not PEFT_AVAILABLE:
                raise ImportError("PEFT is required for LoRA evaluation. Install with: pip install peft")
            
            self.model = PeftModel.from_pretrained(
                self.model,
                lora_path,
                torch_dtype=torch.bfloat16
            )
            
            # Merge LoRA weights for faster inference
            self.model = self.model.merge_and_unload()
            print("LoRA weights loaded and merged")
        
        self.model.eval()
        print("Model loaded successfully")

    def generate_responses_batch(self, prompts: List[str]) -> List[str]:
        """
        Generate responses for multiple prompts in a batch for better GPU utilization.
        
        Args:
            prompts: List of user prompt contents (without system prompt or formatting)
            
        Returns:
            List of generated response texts (reasoning + hypothesis)
        """
        if not prompts:
            return []
        
        # Format all prompts
        formatted_prompts = []
        for prompt in prompts:
            messages = [{"role": "user", "content": prompt}]
            formatted_prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            formatted_prompt += "<｜Assistant｜>"
            formatted_prompts.append(formatted_prompt)
        
        # Tokenize with left padding for batch generation
        self.tokenizer.padding_side = 'left'
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        inputs = self.tokenizer(
            formatted_prompts,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
            padding=True
        ).to(self.device)
        
        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True,
                top_p=self.top_p,
                repetition_penalty=self.repetition_penalty,
                num_beams=1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        
        # Decode responses - use max_input_len (after padding) as the split point
        # BUG FIX: With left padding, we must use the padded length, not the actual token count
        # Old (wrong): input_len = (attention_mask == 1).sum() → this gives actual tokens, not padded position
        # With left padding: [PAD PAD PAD actual_tokens...] → output starts after ALL input positions
        max_input_len = inputs['input_ids'].shape[1]
        responses = []
        for i, output in enumerate(outputs):
            generated_tokens = output[max_input_len:]
            response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            responses.append(response)
        
        return responses

    def extract_hypothesis_content(self, raw_response: str, research_question: str):
        """
        Ask the evaluation LLM to isolate the hypothesis inside a model response.

        Verbose responses may restate the research question or background; the
        "hypothesis_extraction" prompt is used to strip that so verbose and
        concise outputs can be scored on the same footing.

        Args:
            raw_response: Text to extract the hypothesis from
            research_question: The original research question (for reference)

        Returns:
            The stripped "Extracted Hypothesis" field, or None when the field
            is missing/blank or the LLM call raised (caller should skip evaluation)
        """
        prompt_parts = eval_instruction_prompts("hypothesis_extraction")

        # Template pieces interleaved with the question and the response
        full_prompt = (
            prompt_parts[0] + research_question
            + prompt_parts[1] + raw_response
            + prompt_parts[2]
        )

        try:
            # Low temperature keeps the extraction consistent between runs
            outcome = llm_generation_with_extraction(
                full_prompt,
                self.model_name,
                self.client,
                expected_fields=["Extracted Hypothesis"],
                temperature=0.1,
                api_type=self.api_type,
                max_retries=3
            )

            hypothesis_text = ""
            if isinstance(outcome, dict):
                hypothesis_text = outcome.get("Extracted Hypothesis", "")

            # Blank or missing field: report it and let the caller skip scoring
            if not hypothesis_text or not hypothesis_text.strip():
                print(f"Warning: Hypothesis extraction returned empty result")
                return None

            return hypothesis_text.strip()

        except Exception as e:
            # Best effort: any failure in the LLM call is reported, not raised
            print(f"Warning: Hypothesis extraction failed: {e}")
            return None

    def evaluate_eval_dataset(self, output_path: str = None):
        """
        Evaluate the model on the full eval dataset.

        Every ``.json`` file found in ``self.eval_dataset_path`` is handed to
        ``_process_single_file``. Aggregate statistics are printed and, when
        ``output_path`` is given, written to a results folder.
        
        Args:
            output_path: Path to save evaluation results (will be treated as a
                folder; a trailing ".json" suffix is removed first). The folder
                receives metrics.json, generations.json and summary.json.
            
        Returns:
            List of evaluation results (one dict per recorded MDP step)
        """
        # os.listdir order is arbitrary; sort so results are written in a
        # reproducible order from run to run.
        eval_files = sorted(f for f in os.listdir(self.eval_dataset_path) if f.endswith('.json'))
        
        # Process files sequentially
        # For multi-GPU parallelism, use: run_hypothesis_composition_eval_parallel.sh
        eval_results = []
        for cur_file in tqdm(eval_files, desc="Processing files"):
            file_results = self._process_single_file(cur_file)
            eval_results.extend(file_results)
        
        # Calculate overall metrics
        print("\n" + "="*60)
        print("EVALUATION METRICS SUMMARY")
        print("="*60)
        
        # Extract all weighted scores (excluding None values and extraction failures)
        all_weighted_scores = []
        all_component_lengths = []  # Number of decomposed components in ground truth
        all_hypothesis_lengths = []  # Word count of generated hypothesis
        for result in eval_results:
            if not isinstance(result, dict) or result.get('extraction_failed', False):
                continue
            if result.get('weighted_score') is not None:
                all_weighted_scores.append(result['weighted_score'])
            if result.get('eval_results'):
                all_component_lengths.append(len(result['eval_results']))
            # Track generated hypothesis length in whitespace-separated words
            if result.get('generated_hypothesis'):
                all_hypothesis_lengths.append(len(result['generated_hypothesis'].split()))
        
        mean_score = None
        min_score = None
        max_score = None
        avg_component_length = None
        avg_hypothesis_length = None
        
        if all_weighted_scores:
            mean_score = sum(all_weighted_scores) / len(all_weighted_scores)
            min_score = min(all_weighted_scores)
            max_score = max(all_weighted_scores)
            
            print(f"Overall Mean Weighted Score: {mean_score:.4f}")
            print(f"  - This is the PRIMARY METRIC for model comparison")
            print(f"  - Higher is better (range: 0.0 to 1.0)")
            print(f"\nScore Statistics:")
            print(f"  Min score: {min_score:.4f}")
            print(f"  Max score: {max_score:.4f}")
            print(f"  Valid evaluations: {len(all_weighted_scores)}")
            # NOTE: eval_results holds one entry per MDP step, so this counts
            # recorded steps rather than input files.
            print(f"  Total files processed: {len(eval_results)}")
        else:
            print("No valid weighted scores calculated.")
        
        if all_component_lengths:
            avg_component_length = sum(all_component_lengths) / len(all_component_lengths)
            print(f"\nAverage number of GT components: {avg_component_length:.2f}")
        
        if all_hypothesis_lengths:
            avg_hypothesis_length = sum(all_hypothesis_lengths) / len(all_hypothesis_lengths)
            print(f"Average generated hypothesis length: {avg_hypothesis_length:.1f} words")
        
        # Report extraction failure statistics
        if self.extraction_failures > 0:
            print("\n" + "="*60)
            print("EXTRACTION FAILURE SUMMARY (Not Model Failures)")
            print("="*60)
            print(f"Total evaluations attempted: {self.total_evaluations}")
            print(f"Extraction failures: {self.extraction_failures}")
            print(f"Extraction success rate: {(self.total_evaluations - self.extraction_failures)/self.total_evaluations:.2%}")
            print("\nNote: Extraction failures are API formatting issues, not model quality issues.")
            print("These failures are excluded from the evaluation metrics.")
            
            if self.extraction_failures / self.total_evaluations > 0.1:
                print("\nWARNING: High extraction failure rate (>10%)!")
                print("Consider improving the prompts or extraction logic.")
        elif self.total_evaluations > 0:
            print(f"\n✓ All {self.total_evaluations} evaluations completed successfully with no extraction failures.")
        
        # Save results to folder if output path provided
        if output_path:
            # Treat output_path as a folder: drop a trailing ".json" suffix.
            # str.rstrip('.json') must not be used here - it strips any trailing
            # run of the characters '.', 'j', 's', 'o', 'n' (e.g. "results" -> "result").
            if output_path.endswith('.json'):
                output_folder = output_path[:-len('.json')]
            else:
                output_folder = output_path
            os.makedirs(output_folder, exist_ok=True)
            
            # 1. Save metrics.json - only successful evaluations with eval_results and weighted_score
            metrics_data = [
                {
                    'file': result['file'],
                    'step_idx': result.get('step_idx', 0),
                    'eval_results': result['eval_results'],
                    'weighted_score': result['weighted_score']
                }
                for result in eval_results
                if not result.get('extraction_failed', False)
            ]
            
            metrics_path = os.path.join(output_folder, 'metrics.json')
            with open(metrics_path, 'w') as f:
                json.dump(metrics_data, f, indent=2)
            print(f"\nMetrics saved to {metrics_path} ({len(metrics_data)} successful evaluations)")
            
            # 2. Save generations.json - every recorded result, including those
            #    flagged as extraction failures (for case study)
            generations_data = [
                {
                    'file': result['file'],
                    'step_idx': result.get('step_idx', 0),
                    'weighted_score': result.get('weighted_score'),
                    'generated_hypothesis': result.get('generated_hypothesis', ''),
                    'reasoning_trace': result.get('reasoning_trace', ''),  # Model's reasoning process
                    'ground_truth_hypothesis': result.get('ground_truth_hypothesis', ''),
                    'extraction_failed': result.get('extraction_failed', False)
                }
                for result in eval_results
            ]
            
            generations_path = os.path.join(output_folder, 'generations.json')
            with open(generations_path, 'w') as f:
                json.dump(generations_data, f, indent=2)
            failed_count = sum(1 for r in eval_results if r.get('extraction_failed', False))
            print(f"Generations saved to {generations_path} ({len(generations_data)} total, {failed_count} with extraction failures)")
            
            # 3. Save summary.json - overall statistics
            summary_data = {
                'average_weighted_score': mean_score,
                'average_num_gt_components': avg_component_length,  # Num of decomposed components in ground truth
                'average_hypothesis_length': avg_hypothesis_length,  # Word count of generated hypothesis
                'min_score': min_score,
                'max_score': max_score,
                'total_evaluations': len(all_weighted_scores),
                'total_files_processed': len(eval_results),  # Number of recorded step results
                'extraction_failures': self.extraction_failures,
                'total_evaluations_attempted': self.total_evaluations
            }
            
            summary_path = os.path.join(output_folder, 'summary.json')
            with open(summary_path, 'w') as f:
                json.dump(summary_data, f, indent=2)
            print(f"Summary saved to {summary_path}")
            
            print(f"\nAll results saved to folder: {output_folder}")
        
        return eval_results


    
    def _process_single_file(self, cur_file: str) -> List[Dict]:
        """
        Process a single evaluation file with batched generation for better GPU utilization.
        """
        file_results = []
        
        # Load data
        cur_sft_qa_data_file_path = os.path.join(self.sft_qa_data_dir, cur_file)
        if not os.path.exists(cur_sft_qa_data_file_path):
            print(f"Warning: SFT QA data file not found: {cur_sft_qa_data_file_path}")
            return file_results
            
        with open(cur_sft_qa_data_file_path, "r") as f:
            cur_sft_qa_data = json.load(f)
        cur_research_question = cur_sft_qa_data["research_question"]
        cur_background_survey = cur_sft_qa_data["background_survey"]
        
        cur_eval_data_file_path = os.path.join(self.eval_dataset_path, cur_file)
        with open(cur_eval_data_file_path, "r") as f:
            # MDP_road_with_reasoning_trace: [[insp_id, prev_hyp, found_title, found_abstract, next_hyp, reasoning_trace, hypothesis_label], ...]
            MDP_road_with_reasoning_trace = json.load(f)
        
        # Batch generate all responses first for better GPU utilization
        pre_generated_responses = self._batch_generate_responses(
            cur_research_question, cur_background_survey, MDP_road_with_reasoning_trace
        )
        
        # Evaluate each step (reuse original evaluate_single_MDP_step)
        for step_idx, cur_MDP_step in enumerate(MDP_road_with_reasoning_trace):
            # Pass pre-generated response if available
            pre_response = pre_generated_responses.get(step_idx)
            eval_result, gene_hyp, reasoning, gt_hyp = self.evaluate_single_MDP_step(
                cur_research_question, cur_background_survey, cur_MDP_step, pre_generated_response=pre_response
            )
            
            # Calculate weighted score
            weighted_score = None
            if eval_result is not None and eval_result:
                total_weight = sum(weight for _, weight, _ in eval_result)
                if total_weight > 0:
                    weighted_score = sum(weight * score for _, weight, score in eval_result) / total_weight
            
            if eval_result is not None or gene_hyp is not None:
                file_results.append({
                    'file': cur_file,
                    'step_idx': step_idx,
                    'eval_results': eval_result,
                    'weighted_score': weighted_score,
                    'generated_hypothesis': gene_hyp,
                    'reasoning_trace': reasoning,
                    'ground_truth_hypothesis': gt_hyp,
                    'extraction_failed': eval_result is None
                })
        
        return file_results

    def _batch_generate_responses(self, research_question: str, background_survey: str, 
                                   MDP_steps: List) -> Dict[int, str]:
        """
        Generate model responses for all usable MDP steps, ``self.batch_size`` at a time.

        A step is usable when it has at least 7 fields and its 7th field
        (the hypothesis label) is not None - the same check
        ``evaluate_single_MDP_step`` applies before scoring.

        Args:
            research_question: Research question inserted into every prompt
            background_survey: Background survey inserted into every prompt
            MDP_steps: Steps laid out as [insp_id, prev_hyp, found_title,
                found_abstract, next_hyp, reasoning_trace, hypothesis_label]

        Returns:
            Mapping from the step's index in ``MDP_steps`` to its generated
            response; skipped steps have no entry. Empty dict if nothing is usable.
        """
        templates = instruction_prompts("prepare_HC_sft_data_to_go_comprehensive_v2_delta")

        # Collect (original index, prompt) pairs for the usable steps only
        indexed_prompts = []
        for step_idx, step in enumerate(MDP_steps):
            if len(step) < 7 or step[6] is None:
                continue
            previous = "No previous hypothesis." if step[1] is None else step[1]
            title = step[2]
            abstract = step[3]
            prompt_text = (
                templates[0] + research_question
                + templates[1] + background_survey
                + templates[2] + previous
                + templates[3] + title
                + templates[4] + abstract
                + templates[5]
            )
            indexed_prompts.append((step_idx, prompt_text))

        if not indexed_prompts:
            return {}

        # Feed the prompts to the model in consecutive chunks of batch_size
        prompt_texts = [text for _, text in indexed_prompts]
        responses = []
        for start in range(0, len(prompt_texts), self.batch_size):
            chunk = prompt_texts[start:start + self.batch_size]
            responses.extend(self.generate_responses_batch(chunk))

        # Re-attach each response to the index of the step it belongs to
        return {
            step_idx: response
            for (step_idx, _), response in zip(indexed_prompts, responses)
        }


    # Function: 
    #   Evaluate the model on a single MDP step
    # Input:
    #   MDP_step: [insp_id, prev_hyp, found_title, found_abstract, next_hyp, reasoning_trace, hypothesis_label]
    #      insp_id: the inspiration id
    #      prev_hyp: the previous hypothesis to incorporate the found inspiration
    #      found_title: the title of the found inspiration
    #      found_abstract: the abstract of the found inspiration
    #      hypothesis_label: the groundtruth hypothesis label (prev_hyp + [found_title, found_abstract])
    #   pre_generated_response: optional pre-generated model response (for batch generation)
    # Output:
    #   Tuple of (eval_result, generated_hypothesis, reasoning_trace, ground_truth_hypothesis)
    #   eval_result: [[component_1, weight_1, score_1], [component_2, weight_2, score_2], ...]
    #   generated_hypothesis: the hypothesis generated by the model
    #   reasoning_trace: the model's reasoning process (content before </think>)
    #   ground_truth_hypothesis: the ground truth hypothesis label
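    #   Returns (None, None, None, None) when hypothesis_label is missing or the
    #   ground-truth component extraction fails.
    #   Returns (None, generated_hypothesis, reasoning_trace, ground_truth_hypothesis)
    #   when the optional hypothesis-only extraction fails, so the generation can
    #   still be logged while being excluded from scoring.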
    def evaluate_single_MDP_step(self, research_question, background_survey, MDP_step, pre_generated_response):
        prev_hyp = MDP_step[1] if MDP_step[1] is not None else "No previous hypothesis."
        found_title = MDP_step[2]
        found_abstract = MDP_step[3]
        
        # Validate hypothesis_label exists (None is caused by extraction failure during reasoning trace generation)
        if len(MDP_step) < 7 or MDP_step[6] is None:
            print(f"Warning: Missing hypothesis_label in MDP_step (len={len(MDP_step)})")
            return None, None, None, None
        hypothesis_label = MDP_step[6]
        
        # Track generated hypothesis and reasoning for output
        gene_hyp = None
        reasoning_trace = None
        
        # Step 1: Initial decomposition of groundtruth hypothesis
        decompose_prompts = eval_instruction_prompts("hypothesis_composition_eval_decompose")
        decompose_prompt = decompose_prompts[0] + research_question + decompose_prompts[1] + background_survey + decompose_prompts[2] + hypothesis_label + decompose_prompts[3]
        
        # Generate and extract with retry logic
        decompose_result = llm_generation_with_extraction(
            decompose_prompt, self.model_name, self.client,
            expected_fields=["Components"],
            temperature=0.1, api_type=self.api_type, max_retries=10
        )
        decompose_response = decompose_result.get("Components", "") if isinstance(decompose_result, dict) else ""
        
        # Step 1b: Refine the decomposition with retry for successful parsing
        refine_prompts = eval_instruction_prompts("hypothesis_composition_eval_decompose_refine")
        refine_prompt = refine_prompts[0] + research_question + refine_prompts[1] + background_survey + refine_prompts[2] + hypothesis_label + refine_prompts[3] + decompose_response + refine_prompts[4]
        
        # Retry up to 10 times to get parseable components (more chances for API to succeed)
        components_with_weights = None
        max_retries = 10
        for attempt in range(max_retries):
            refined_result = llm_generation_with_extraction(
                refine_prompt, self.model_name, self.client,
                expected_fields=["Components"],
                temperature=min(0.1 + attempt * 0.1, 0.5),  # Gradually increase temperature
                api_type=self.api_type, max_retries=10
            )
            components_text = refined_result.get("Components") if isinstance(refined_result, dict) else None
            
            # Try to parse the components and weights
            components_with_weights = parse_numbered_pairs(components_text, separator=':')
            
            if components_with_weights:
                break  # Successfully parsed
            elif attempt < max_retries - 1:
                print(f"Attempt {attempt + 1}/{max_retries}: Failed to parse components, retrying...")
        
        # Track evaluation attempt
        self.total_evaluations += 1
        
        # If still no components after retries, this is an API extraction failure, not model failure
        if not components_with_weights:
            self.extraction_failures += 1
            print(f"ERROR: API extraction failed after {max_retries} attempts (not a model failure)")
            print(f"  Extraction failure rate: {self.extraction_failures}/{self.total_evaluations} = {self.extraction_failures/self.total_evaluations:.2%}")
            print(f"  Problematic hypothesis: {hypothesis_label[:100]}...")
            # Skip this evaluation - don't penalize the model for API formatting issues
            return None, None, None, None  # Return None to indicate extraction failure (different from empty evaluation)
        
        # Simple validation - weights should already be floats from improved prompts
        if not all(isinstance(w, (int, float)) for _, w in components_with_weights):
            # Extraction failure - prompts should ensure proper format
            self.extraction_failures += 1
            print(f"ERROR: Non-numeric weight detected (extraction failure)")
            print(f"  Extraction failure rate: {self.extraction_failures}/{self.total_evaluations} = {self.extraction_failures/self.total_evaluations:.2%}")
            return None, None, None, None
        
        total_weight = sum(w for _, w in components_with_weights)
        print(f"Extracted {len(components_with_weights)} components (total weight: {total_weight:.2f})")
        
        # Step 2: Use pre-generated response (batch generated for GPU efficiency)
        raw_response = pre_generated_response
        
        # Extract reasoning trace (content before </think>)
        think_end_pos = raw_response.find('</think>')
        if think_end_pos != -1:
            reasoning_trace = raw_response[:think_end_pos].strip()
        else:
            reasoning_trace = raw_response  # No </think> found, save full response as reasoning
        
        # Extract the actual hypothesis
        # First try v2 delta format markers: **Delta Hypothesis starts:** ... **Delta Hypothesis ends**
        # extract_between_markers takes (source, label_regex) and auto-adds "starts/ends"
        delta_hyp = extract_between_markers(raw_response, r'Delta\s*Hypothesis')
        if delta_hyp:
            gene_hyp = delta_hyp.strip()
        else:
            # Fallback to existing utility function for other R1 formats
            # Handles <answer> tags, </think> patterns, etc.
            print("Failed to extract delta hypothesis, falling back to extract_answer_content...")
            gene_hyp = extract_answer_content(raw_response)
        
        # Optional: Extract only novel hypothesis content (remove repeated input sections)
        # This helps ensure fair comparison by removing sections that just repeat
        # the research question, background, or inspiration from the input
        if self.extract_hypothesis_only and gene_hyp:
            print("Extracting novel hypothesis content (removing repeated input sections)...")
            extracted_hyp = self.extract_hypothesis_content(gene_hyp, research_question)
            if extracted_hyp is None:
                # Extraction failed - skip evaluation for this sample to maintain consistency
                # We still return the original gene_hyp for logging/debugging purposes,
                # but eval_result=None signals that this sample should be excluded from scoring
                self.extraction_failures += 1
                print(f"ERROR: Hypothesis extraction failed - skipping evaluation for consistency")
                print(f"  Extraction failure rate: {self.extraction_failures}/{self.total_evaluations} = {self.extraction_failures/self.total_evaluations:.2%}")
                # Return: eval_result=None (skip scoring), gene_hyp=original (for logging),
                #         reasoning_trace (for logging), hypothesis_label (ground truth)
                return None, gene_hyp, reasoning_trace, hypothesis_label
            # Use extracted hypothesis for scoring (background/inspiration sections removed)
            gene_hyp = extracted_hyp
        
        # Step 3: Evaluate how gene_hyp covers all components
        compare_prompts = eval_instruction_prompts("hypothesis_composition_eval_compare_all")
        
        if components_with_weights and compare_prompts:
            # Format all components for evaluation
            components_str = "\n".join([f"{i+1}. {comp}" for i, (comp, _) in enumerate(components_with_weights)])
            
            # Initial evaluation
            compare_prompt = compare_prompts[0] + research_question + compare_prompts[1] + background_survey + compare_prompts[2] + components_str + compare_prompts[3] + gene_hyp + compare_prompts[4]
            
            # Generate and extract initial scores
            compare_result = llm_generation_with_extraction(
                compare_prompt, self.model_name, self.client,
                expected_fields=["Scores"],
                temperature=0.1, api_type=self.api_type, max_retries=10
            )
            compare_response = compare_result.get("Scores", "") if isinstance(compare_result, dict) else ""
            
            # Refine the evaluation with retry for successful parsing
            refine_compare_prompts = eval_instruction_prompts("hypothesis_composition_eval_compare_all_refine")
            refine_compare_prompt = refine_compare_prompts[0] + research_question + refine_compare_prompts[1] + background_survey + refine_compare_prompts[2] + components_str + refine_compare_prompts[3] + gene_hyp + refine_compare_prompts[4] + compare_response + refine_compare_prompts[5]
            
            # Retry up to 10 times to get parseable scores
            score_dict = None
            max_retries = 10
            for attempt in range(max_retries):
                refined_result = llm_generation_with_extraction(
                    refine_compare_prompt, self.model_name, self.client,
                    expected_fields=["Scores"],
                    temperature=0.1 if attempt == 0 else 0.3,  # Increase temperature on retries
                    api_type=self.api_type, max_retries=10
                )
                scores_text = refined_result.get("Scores") if isinstance(refined_result, dict) else None
                
                # Try to parse scores
                score_dict = parse_numbered_list(scores_text, value_type='number')
                
                # Validate we have scores for all components
                if score_dict and len(score_dict) == len(components_with_weights):
                    break  # Successfully parsed exact number of scores
                elif attempt < max_retries - 1:
                    print(f"Attempt {attempt + 1}: Incomplete scores ({len(score_dict) if score_dict else 0}/{len(components_with_weights)}), retrying...")
            
            # Build final results - scores should be valid from improved prompts
            if not score_dict or len(score_dict) != len(components_with_weights):
                self.extraction_failures += 1
                print(f"ERROR: Score count mismatch (extraction failure)")
                return None, gene_hyp, reasoning_trace, hypothesis_label
            
            eval_results = []
            for i, (component, weight) in enumerate(components_with_weights):
                score = score_dict.get(i + 1, None)
                if score is None or not isinstance(score, (int, float)):
                    self.extraction_failures += 1
                    print(f"ERROR: Invalid score for component {i+1} (extraction failure)")
                    return None, gene_hyp, reasoning_trace, hypothesis_label
                eval_results.append([component, weight, score])
        else:
            # No components extracted or no comparison prompts - this is an extraction failure
            self.extraction_failures += 1
            print(f"ERROR: Failed to evaluate - no components or prompts (extraction failure)")
            print(f"  Extraction failure rate: {self.extraction_failures}/{self.total_evaluations} = {self.extraction_failures/self.total_evaluations:.2%}")
            return None, gene_hyp, reasoning_trace, hypothesis_label  # Return None to indicate extraction failure
        
        return eval_results, gene_hyp, reasoning_trace, hypothesis_label





if __name__ == "__main__":
    # Command-line entry point: parse the options, build the evaluator and
    # run it over the evaluation dataset, writing results to --eval_result_path.
    parser = argparse.ArgumentParser(description='Evaluate hypothesis composition model')

    # Model configuration
    parser.add_argument("--model_path", type=str, required=True, help="Path to base model")
    parser.add_argument("--lora_path", type=str, default=None, help="Path to LoRA checkpoint (optional)")

    # Evaluation settings
    # The device used to be hard-coded to "cuda" in the constructor call below;
    # it is now configurable, with the same default so existing invocations
    # behave exactly as before.
    parser.add_argument("--device", type=str, default="cuda", help="Device passed to the evaluator (default: cuda)")
    parser.add_argument("--load_in_8bit", action="store_true", help="Load model in 8-bit precision")
    parser.add_argument("--max_length", type=int, default=16384, help="Maximum sequence length")
    parser.add_argument("--max_new_tokens", type=int, default=4096, help="Maximum new tokens to generate")

    # Generation parameters
    parser.add_argument("--temperature", type=float, default=0.6, help="Generation temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling parameter")
    parser.add_argument("--repetition_penalty", type=float, default=1.2, help="Repetition penalty")
    # NOTE: the CLI default (4) differs from the constructor default (1).
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for generation (higher = better GPU utilization)")

    # Hypothesis extraction option
    parser.add_argument("--extract_hypothesis_only", action="store_true",
                        help="Use LLM to extract only novel hypothesis content (remove repeated input sections)")

    # API settings for evaluation (the judging LLM, not the model under test)
    parser.add_argument("--model_name", type=str, default="r1-distill-qwen-32b", help="Model name for API evaluation")
    parser.add_argument("--api_type", type=int, default=0, help="0: OpenAI, 1: Azure, 2: Google")
    parser.add_argument("--api_key", type=str, default="", help="API key")
    parser.add_argument("--base_url", type=str, default="", help="Base URL for API")

    # Dataset paths
    parser.add_argument("--eval_dataset_path", type=str, required=True, help="Path to evaluation dataset directory")
    parser.add_argument("--sft_qa_data_dir", type=str, required=True, help="Path to SFT QA data directory (for research question and background survey)")

    # Output path
    parser.add_argument("--eval_result_path", type=str, required=True, help="Path to save evaluation results")

    args = parser.parse_args()

    # Initialize evaluator; every option is forwarded by keyword so the
    # mapping from CLI flag to constructor parameter stays explicit.
    evaluator = HypothesisCompositionEvaluator(
        model_path=args.model_path,
        lora_path=args.lora_path,
        device=args.device,
        load_in_8bit=args.load_in_8bit,
        max_length=args.max_length,
        max_new_tokens=args.max_new_tokens,
        eval_dataset_path=args.eval_dataset_path,
        sft_qa_data_dir=args.sft_qa_data_dir,
        api_type=args.api_type,
        api_key=args.api_key,
        base_url=args.base_url,
        model_name=args.model_name,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        extract_hypothesis_only=args.extract_hypothesis_only,
        batch_size=args.batch_size
    )

    # Run evaluation
    evaluator.evaluate_eval_dataset(output_path=args.eval_result_path)