import logging
import json
import asyncio
import os
import time
import re
from typing import Dict, List, Set, Tuple, Optional
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import traceback
from vllm import LLM, SamplingParams

# Set up logging
# The root level is WARNING, so the logger.info(...) calls in this file are
# suppressed unless this level is lowered.
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialize vllm model
# NOTE: this block runs at import time. Importing the module constructs the
# LLM engine, and any failure is logged and re-raised, aborting the import.
try:
    # Model path configuration
    model_path = "Your model path"  # Modify to your local model path
    
    logger.info(f"Loading model from: {model_path}")
    
    # Initialize model using vllm
    llm = LLM(
        model=model_path,
        tensor_parallel_size=1,  # Adjust based on number of GPUs
        trust_remote_code=True,
        max_model_len=4096,  # Adjust maximum length as needed
        gpu_memory_utilization=0.9  # GPU memory utilization
    )
    
    # Set sampling parameters
    # NOTE(review): generate_response() builds its own SamplingParams on every
    # call; in the code visible here this module-level object is never read.
    sampling_params = SamplingParams(
        temperature=0.7,
        max_tokens=512,
        top_p=0.9,
        stop=None
    )
    
    logger.info(f"Successfully loaded vllm model from {model_path}")
except Exception as e:
    # Log the failure for diagnostics, then propagate it unchanged.
    logger.error(f"Failed to initialize vllm model: {e}")
    raise

def generate_response(prompt: str, max_tokens: int = 512) -> str:
    """Run one prompt through the module-level vllm engine and return its text.

    Args:
        prompt: Text sent to the model as a single-element batch.
        max_tokens: Generation cap forwarded to SamplingParams.

    Returns:
        The whitespace-stripped text of the first completion of the first
        request output, or "" when the engine returns no outputs.

    Raises:
        Exception: Any error raised while generating or reading the output is
            logged and re-raised unchanged.
    """
    try:
        # A fresh SamplingParams per call so max_tokens can vary by caller.
        params = SamplingParams(temperature=0.7, max_tokens=max_tokens, top_p=0.9, stop=None)
        generations = llm.generate([prompt], params)

        if not generations:
            logger.warning("vllm generation result is empty")
            return ""

        first_completion = generations[0].outputs[0]
        return first_completion.text.strip()
    except Exception as e:
        logger.error(f"Error generating response with vllm: {e}")
        raise

def create_direct_prompt(question_data: Dict) -> str:
    """Build the first-attempt prompt asking the model to pick one option.

    Args:
        question_data: Record whose 'question' entry holds a 'stem' string and
            a 'choices' list understood by format_choices.

    Returns:
        The prompt text. Every template line carries a four-space indent
        (including the otherwise empty ones) and the text starts with a
        newline; only the first line of the formatted options is indented.
    """
    question = question_data['question']
    # The prompt is assembled line by line so the whitespace-only lines are
    # explicit in the source rather than hidden in a triple-quoted literal.
    prompt_lines = [
        "",
        "    Please answer the following questions and select the most appropriate answer from the given options:",
        "    ",
        f"    Question: {question['stem']}",
        "    ",
        "    Options:",
        f"    {format_choices(question['choices'])}",
        "    ",
        '    **Note:your answer must conform to the Regular Expression format ^The answer is [A-H]~.+$, for example "The answer is B~getting tired".**',
        "    ",
    ]
    return "\n".join(prompt_lines)

def format_choices(choices: List[Dict]) -> str:
    """Render choices as one "<label>. <text>" line per option.

    Args:
        choices: Dicts that each provide 'label' and 'text' keys.

    Returns:
        The rendered lines joined with newlines ("" for an empty list).
    """
    return "\n".join(f"{item['label']}. {item['text']}" for item in choices)

def create_fact_verification_prompt(question_data: Dict, predicted_answer: str) -> str:
    """Build the prompt that asks the model to fact-check a selected answer.

    Args:
        question_data: Record whose 'question' entry holds 'stem' and 'choices'.
        predicted_answer: Option label chosen by the model.

    Returns:
        A prompt that instructs the model to reply in the form
        "Fact verification: [PASS/FAIL]~[reason]".
    """
    question = question_data['question']
    choices = question['choices']
    # Empty string when the label matches none of the choices.
    answer_text = get_choice_text_by_label(choices, predicted_answer)
    stem = question['stem']
    options = format_choices(choices)

    return f"""
Verify if the following analysis is factually correct and represents the optimal solution:

Question: {stem}
Options: {options}
Selected Answer: {predicted_answer}. {answer_text}

Please analyze:

**Part 1: Factual Correctness**
1. Substitute the selected answer into the question context
2. Analyze whether this answer aligns with general facts and objective laws
3. Check if this answer is consistent with established knowledge and reality

**Part 2: Optimal Solution Analysis**
1. Compare the selected answer with all available options
2. Determine if this answer is the most appropriate choice among all options
3. When multiple valid explanations exist, verify if this is the simplest, most direct, and objective answer (i.e., the answer with the shortest reasoning path)
4. For human-related questions, consider whether physiological states are more direct and objective than emotional states

Provide your verification result in the format: "Fact verification: [PASS/FAIL]~[detailed reason explaining both factual correctness and optimality analysis]"
"""

def create_logic_verification_prompt(question_data: Dict, predicted_answer: str) -> str:
    """Build the prompt that asks the model to logic-check a selected answer.

    Args:
        question_data: Record whose 'question' entry holds 'stem' and 'choices'.
        predicted_answer: Option label chosen by the model.

    Returns:
        A five-step structured-reasoning prompt that instructs the model to
        reply in the form "Logic verification: [PASS/FAIL]~[explanation]".
    """
    question = question_data['question']
    choices = question['choices']
    # Empty string when the label matches none of the choices.
    answer_text = get_choice_text_by_label(choices, predicted_answer)
    stem = question['stem']
    options = format_choices(choices)

    return f"""
Perform logical verification for the following analysis using structured logical reasoning:

Question: {stem}
Options: {options}
Selected Answer: {predicted_answer}. {answer_text}

Please follow these steps:

**Step 1: Define the Target**
Clearly define what proposition we are trying to verify based on the question.

**Step 2: Identify Premises and Conclusion**
- Determine which statements serve as the foundation for the argument (premises that support the conclusion)
- Confirm the viewpoint or conclusion to be proven based on these premises (final claim)

**Step 3: Analyze Argument Structure**
- Convert the argument (question with substituted answer) into appropriate logical form
- Determine the logical form used (propositional logic, predicate logic, syllogism, etc.)
- Check if the structure is correct

**Step 4: Verify Argument Validity**
- Evaluate whether each premise is true
- Check if the derivation process from premises to conclusion follows logical laws

**Step 5: Consider Counterexamples or Special Cases**
- Try to find examples that could refute the argument
- Look for other possible ways to explain the phenomenon
- Analyze whether these special cases affect the correctness of the conclusion

Based on your analysis, provide your verification result in the format: "Logic verification: [PASS/FAIL]~[detailed explanation of the logical analysis]"
"""

def extract_answer(response: str) -> Optional[str]:
    """Extract the chosen option letter (A-H) from a model response.

    Patterns are tried in order of reliability:

    1. The prompted form "The answer is X~..." (or "The answer is X.").
    2. A standalone letter followed by "~".
    3. A standalone letter followed by ".".

    The bare-letter fallbacks require a word boundary before the letter so
    that the last letter of an ordinary word ("tired.", "each.") is not taken
    for an option label.

    Args:
        response: Raw text produced by the model.

    Returns:
        The option letter in upper case, or None when nothing matches or the
        response is empty.
    """
    if not response:
        return None

    patterns = (
        r'\bthe answer is\s*([A-H])\s*[~.]',
        r'\b([A-H])~',
        r'\b([A-H])\.',
    )
    for pattern in patterns:
        match = re.search(pattern, response, re.IGNORECASE)
        if match:
            # Matching is case-insensitive; normalise so callers always see
            # the same label for the same option.
            return match.group(1).upper()

    return None

def extract_verification_result(response: str, verification_type: str) -> Tuple[bool, str]:
    """Parse a "<type> verification: PASS|FAIL~reason" verdict from a response.

    The verdict may be wrapped in square brackets and/or markdown asterisks
    (e.g. "[PASS]~..." or "**FAIL** ~ ..."), since the prompts show the
    expected format as "[PASS/FAIL]~[...]". An echoed template containing
    the literal "[PASS/FAIL]" is not treated as a verdict.

    Args:
        response: Raw text produced by the model.
        verification_type: Label preceding " verification:" (e.g. "Fact" or
            "Logic"). It is matched literally and case-insensitively.

    Returns:
        (passed, reason). `passed` is True only for an explicit PASS verdict;
        `reason` is everything after the "~", stripped. When no verdict is
        found or the response is empty, returns (False, "").
    """
    if not response:
        return False, ""

    pattern = (
        # re.escape keeps regex metacharacters in the label literal.
        rf'{re.escape(verification_type)} verification:'
        r'[\s*]*\[?\s*(PASS|FAIL)\s*\]?[\s*]*~\s*(.+)'
    )
    match = re.search(pattern, response, re.IGNORECASE | re.DOTALL)
    if match is None:
        return False, ""

    return match.group(1).upper() == "PASS", match.group(2).strip()

def get_choice_text_by_label(choices: List[Dict], label: str) -> str:
    """Look up the text of the option whose label matches, ignoring case.

    Args:
        choices: Option dicts; a missing 'label' or 'text' is treated as ''.
        label: Option letter to look for.

    Returns:
        The 'text' of the first matching option, or '' when none matches.
    """
    matching_texts = (
        option.get('text', '')
        for option in choices
        if option.get('label', '').upper() == label.upper()
    )
    return next(matching_texts, '')

def verify_answer(question_data: Dict, answer_label: str) -> Dict:
    """Check a selected answer with two independent model passes.

    A fact-verification prompt and a logic-verification prompt are each sent
    to the model; the answer counts as verified only when both verdicts are
    PASS.

    Args:
        question_data: Record whose 'question' entry holds 'stem' and 'choices'.
        answer_label: Option label to verify.

    Returns:
        On success, a dict with 'verified' (bool) plus 'fact' and 'logic'
        entries, each holding 'verified', 'reason' and the raw 'response'.
        If anything raises, the error is logged and
        {"verified": False, "error": <message>} is returned instead, so
        callers must not assume the 'fact'/'logic' keys are present.
    """
    try:
        # 1. Fact verification - does the answer align with facts and
        #    objective laws, and is it the optimal choice?
        fact_verification_response = generate_response(
            create_fact_verification_prompt(question_data, answer_label),
            max_tokens=800
        )
        fact_verified, fact_reason = extract_verification_result(fact_verification_response, "Fact")

        # 2. Logic verification - structured logical reasoning over the answer.
        logic_verification_response = generate_response(
            create_logic_verification_prompt(question_data, answer_label),
            max_tokens=1000
        )
        logic_verified, logic_reason = extract_verification_result(logic_verification_response, "Logic")

        return {
            # Both checks must pass for the answer to count as verified.
            "verified": fact_verified and logic_verified,
            "fact": {
                "verified": fact_verified,
                "reason": fact_reason,
                "response": fact_verification_response
            },
            "logic": {
                "verified": logic_verified,
                "reason": logic_reason,
                "response": logic_verification_response
            }
        }

    except Exception as e:
        logger.error(f"Error verifying answer: {e}")
        return {"verified": False, "error": str(e)}

def create_retry_prompt(question_data: Dict, verification_result: Dict) -> str:
    """Build the second-attempt prompt, quoting why verification failed.

    Args:
        question_data: Record whose 'question' entry holds 'stem' and 'choices'.
        verification_result: Verification record; its optional 'fact' and
            'logic' entries are read for 'verified' and 'reason'.

    Returns:
        A prompt listing one bullet per failed check that has a non-empty
        reason (or a generic notice when there is none), followed by the
        question, its options and the required answer format.
    """
    checks = (
        ("Factual", verification_result.get("fact", {})),
        ("Logical", verification_result.get("logic", {})),
    )
    # Only failed checks that actually supplied a reason are quoted back.
    bullets = [
        f"- **{title} verification failed**: {details.get('reason', '')}"
        for title, details in checks
        if not details.get("verified", False) and details.get("reason", "")
    ]

    if bullets:
        failure_info = "\n".join(bullets)
    else:
        failure_info = "Verification failed, please reconsider"

    question = question_data['question']
    stem = question['stem']
    options = format_choices(question['choices'])

    return f"""
Please reconsider and re-analyze the following question. Your previous answer did not pass verification for the following reasons:

{failure_info}

Please think more carefully and provide a more accurate analysis:

Question: {stem}

Options:
{options}

Consider:
1. **Factual Accuracy**: Ensure your answer aligns with known facts, historical information, and objective laws
2. **Optimal Choice**: Select the most appropriate option among all available choices, preferring simpler and more direct answers
3. **Logical Reasoning**: Follow proper logical structure and valid reasoning processes
4. **Evidence-based Analysis**: Base your conclusion on solid evidence and logical derivation

**Note: Your answer must conform to the Regular Expression format ^The answer is [A-H]~.+$, for example "The answer is B~getting tired".**
Please ensure your reasoning is both factually accurate and logically sound.
"""

def process_question(question_data: Dict, retry_count: int = 1, retry_delay: int = 5) -> Dict:
    """Answer one question, verify the answer, and retry when needed.

    Each attempt asks the model for an answer, extracts the option letter and
    runs verify_answer on it. A further attempt is made (after sleeping
    `retry_delay` seconds) when the answer cannot be extracted, verification
    fails, or an exception is raised, as long as attempts remain.

    Args:
        question_data: Record with an optional 'id' and a 'question' entry
            holding 'stem' and 'choices'.
        retry_count: Number of retries after the first attempt.
        retry_delay: Seconds to sleep before each retry.

    Returns:
        A result dict. With an extracted answer: 'id', 'question', 'answer',
        'verification', 'attempts' and 'status' ("success" when verified,
        "success_unverified" otherwise). When the final attempt raises:
        'id', 'question', 'error', 'verification_history', 'attempts' and
        'status' == "failed".
    """
    question_id = question_data.get('id', 'unknown')

    # Verification results of earlier attempts, oldest first.
    verification_history = []

    for attempt in range(retry_count + 1):
        try:
            if not question_data.get('question'):
                raise ValueError("Question data cannot be empty")

            # 1. Choose the prompt. The retry prompt needs a previous
            #    verification result; an earlier attempt that failed before
            #    verification (no extractable answer, or an exception) leaves
            #    the history empty, so fall back to the standard prompt
            #    instead of indexing an empty list.
            if verification_history:
                direct_prompt = create_retry_prompt(question_data, verification_history[-1])
                logger.info(f"Question {question_id} retry (based on verification failure reasons)")
            else:
                direct_prompt = create_direct_prompt(question_data)
                if attempt == 0:
                    logger.info(f"Question {question_id} first attempt")
                else:
                    logger.info(f"Question {question_id} retry (no verification result from previous attempt)")

            # Generate response using vllm
            direct_response = generate_response(direct_prompt)
            answer_label = extract_answer(direct_response)

            if answer_label is None:
                logger.warning(f"Cannot extract option from direct response: {direct_response[:100]}...")
                if attempt < retry_count:
                    time.sleep(retry_delay)
                    continue
                else:
                    raise ValueError("Cannot extract option from model response")

            # Get option text
            answer_text = get_choice_text_by_label(question_data['question']['choices'], answer_label)

            # 2. Verify answer
            verification_result = verify_answer(question_data, answer_label)
            is_verified = verification_result.get("verified", False)

            # Record verification result
            verification_history.append(verification_result)

            # 3. If verification fails and retry chances remain, retry
            if not is_verified and attempt < retry_count:
                logger.info(f"Question {question_id} answer verification failed, preparing retry (last chance)")
                time.sleep(retry_delay)
                continue

            # 4. Build final result regardless of verification success
            #    (either verified, or retry chances exhausted)
            result = {
                "id": question_id,
                "question": question_data['question'],
                "answer": {
                    "option": answer_label,
                    "content": answer_text
                },
                "verification": verification_result,
                "attempts": attempt + 1,
                "status": "success" if is_verified else "success_unverified"  # Distinguish verified and unverified
            }

            if is_verified:
                logger.info(f"Question {question_id} processed successfully, verification passed, attempts: {attempt + 1}")
            else:
                logger.warning(f"Question {question_id} processing completed but verification failed, attempts: {attempt + 1}")

            return result

        except Exception as e:
            if attempt < retry_count:
                logger.warning(f"Error processing question {question_id} (attempt {attempt+1}/{retry_count+1}): {e}")
                time.sleep(retry_delay)
            else:
                error_msg = f"Failed to process question: {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                return {
                    "id": question_id,
                    "question": question_data.get('question', {}),
                    "error": str(e),
                    "verification_history": verification_history,
                    "attempts": attempt + 1,
                    "status": "failed"
                }

    # Only reachable when retry_count is negative and the loop never ran;
    # still hand back a well-formed failure record rather than None.
    return {
        "id": question_id,
        "question": question_data.get('question', {}),
        "error": "No attempt was made (retry_count must be >= 0)",
        "verification_history": verification_history,
        "attempts": 0,
        "status": "failed"
    }

def load_processed_ids(output_file: str) -> Set[str]:
    """Collect the ids already written to the output file, for resuming.

    Every JSON object line that has an 'id' key counts as processed,
    whatever else the record contains. Blank lines, malformed JSON,
    non-object JSON values and unhashable ids are skipped.

    Args:
        output_file: Path of the JSONL results file; it may not exist yet.

    Returns:
        The set of ids found, empty when the file does not exist.
    """
    processed_ids = set()
    try:
        with open(output_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                    if isinstance(data, dict) and 'id' in data:
                        processed_ids.add(data['id'])
                except (json.JSONDecodeError, TypeError):
                    # Malformed JSON (e.g. a line truncated by a crash) or an
                    # unhashable id: ignore this line, keep the rest.
                    continue
    except FileNotFoundError:
        # First run: nothing has been processed yet.
        pass
    return processed_ids

async def process_batch(batch: List[Dict], max_workers: int = 1) -> List[Dict]:
    """Process the questions of one batch sequentially, in order.

    Questions are handled one at a time to avoid vLLM thread conflicts;
    `max_workers` is accepted but not used here.

    Args:
        batch: Question records to process.
        max_workers: Unused.

    Returns:
        One result per question, in input order. A question whose processing
        raises is reported as a record with 'status' == "failed".
    """
    def _run_one(item: Dict) -> Dict:
        # Turn an unexpected exception into a failure record so one bad
        # question does not abort the rest of the batch.
        try:
            return process_question(item)
        except Exception as e:
            logger.error(f"Exception occurred while processing question {item.get('id', 'unknown')}: {e}")
            return {
                "id": item.get('id', 'unknown'),
                "question": item.get('question', {}),
                "error": str(e),
                "status": "failed"
            }

    return [_run_one(item) for item in tqdm(batch, desc="Processing batch", leave=False)]

def _append_results(output_file: str, results: List[Dict]) -> None:
    """Append each result to output_file as one JSON object per line."""
    with open(output_file, 'a', encoding='utf-8') as f:
        f.writelines(json.dumps(result, ensure_ascii=False) + '\n' for result in results)


def _update_stats(stats: Dict[str, int], result: Dict) -> None:
    """Fold one result record into the running counters."""
    stats["processed"] += 1

    status = result.get('status')
    if status not in ('success', 'success_unverified'):
        stats["failed"] += 1
        return

    stats["success"] += 1
    verification = result.get('verification', {})
    if status == 'success' and verification.get('verified', False):
        stats["verified"] += 1
        stats["both_pass_count"] += 1
    else:
        stats["unverified"] += 1

    # The individual checks are counted whether or not the pair passed.
    if verification.get('fact', {}).get('verified', False):
        stats["fact_pass_count"] += 1
    if verification.get('logic', {}).get('verified', False):
        stats["logic_pass_count"] += 1


async def process_jsonl_file(input_file: str, output_file: str, batch_size: int = 10, max_workers: int = 1, save_interval: int = 50) -> Dict[str, int]:
    """Process a JSONL question file in batches with checkpoint resume.

    Questions whose id already appears in `output_file` are skipped. Results
    are buffered and appended to `output_file` whenever at least
    `save_interval` of them are pending (checked after each batch) and once
    more at the end. If an exception or interrupt escapes the batch loop,
    the results buffered so far are still written and the progress bar is
    closed before the error propagates.

    Args:
        input_file: Path of the input JSONL file, one question per line.
        output_file: Path of the JSONL results file (appended to).
        batch_size: Number of questions handed to process_batch at a time.
        max_workers: Forwarded to process_batch.
        save_interval: Minimum number of buffered results that triggers a save.

    Returns:
        Counters keyed 'total', 'processed', 'success', 'failed', 'skipped',
        'verified', 'unverified', 'fact_pass_count', 'logic_pass_count' and
        'both_pass_count'. 'total' counts JSON object lines in the input;
        unparseable and non-object lines are logged and not counted.
    """
    # Load processed question IDs
    processed_ids = load_processed_ids(output_file)
    logger.info(f"Found {len(processed_ids)} already processed questions")

    stats = {
        "total": 0,
        "processed": 0,
        "success": 0,
        "failed": 0,
        "skipped": 0,
        "verified": 0,
        "unverified": 0,
        "fact_pass_count": 0,
        "logic_pass_count": 0,
        "both_pass_count": 0,
    }

    # Read all questions that still need processing
    all_questions = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                question_data = json.loads(line)
            except json.JSONDecodeError:
                logger.warning(f"Cannot parse JSON line: {line[:100]}...")
                continue
            if not isinstance(question_data, dict):
                # A bare number/string/list has no 'id' or 'question' to read.
                logger.warning(f"Skipping non-object JSON line: {line[:100]}...")
                continue

            stats["total"] += 1

            # Skip already processed questions
            if question_data.get('id') in processed_ids:
                stats["skipped"] += 1
                continue

            all_questions.append(question_data)

    progress_bar = tqdm(total=len(all_questions), desc="Overall progress", unit="questions")

    # Results produced but not yet written to output_file.
    pending_results = []

    try:
        for start in range(0, len(all_questions), batch_size):
            batch = all_questions[start:start + batch_size]

            results = await process_batch(batch, max_workers)
            pending_results.extend(results)

            for result in results:
                _update_stats(stats, result)

            progress_bar.update(len(batch))
            progress_bar.set_postfix({
                "Success": stats["success"],
                "Failed": stats["failed"],
                "Verified": stats["verified"],
                "Fact Pass": stats["fact_pass_count"],
                "Logic Pass": stats["logic_pass_count"],
                "Skipped": stats["skipped"]
            })

            # Save progress once enough results have accumulated
            if len(pending_results) >= save_interval:
                _append_results(output_file, pending_results)
                pending_results = []
                print(f"\nProgress saved, currently processed {stats['processed']} questions")
    finally:
        # Runs on normal completion and on error: flush what is still
        # buffered so a resume does not redo finished work.
        if pending_results:
            _append_results(output_file, pending_results)
            print(f"\nProgress saved, currently processed {stats['processed']} questions")
        progress_bar.close()

    return stats

async def main():
    """Run the pipeline with the settings below and print a summary.

    Prints the configuration, processes the input file through
    process_jsonl_file, then reports elapsed time, the raw counters and,
    when anything was processed, the success and verification pass rates.
    """
    # Configuration parameters - single-threaded processing.
    run_config = {
        "input_file": 'Your dataset path',  # Input JSONL file path
        "output_file": 'Your output result path',  # Output JSONL file path
        "batch_size": 50,  # Batch processing question count
        "max_workers": 1,  # Maximum parallel worker threads (using single thread)
        "save_interval": 50,  # Save progress every 50 questions processed
    }

    print(f"Starting to process file: {run_config['input_file']}")
    print(f"Using model: {model_path}")
    print(f"Batch size: {run_config['batch_size']}, Parallel threads: {run_config['max_workers']}")
    print(f"Save progress every {run_config['save_interval']} questions processed")
    print("Strategy: Fact verification (align with facts + optimal solution) + Logic verification (structured logical reasoning)")

    started_at = time.time()
    stats = await process_jsonl_file(**run_config)
    elapsed_time = time.time() - started_at

    # Break the elapsed seconds into hours / minutes / seconds.
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)

    print("\nProcessing completed!")
    print(f"Total elapsed time: {int(hours)} hours {int(minutes)} minutes {int(seconds)} seconds")

    counter_labels = (
        ("Total questions", 'total'),
        ("Processed", 'processed'),
        ("Success", 'success'),
        ("Failed", 'failed'),
        ("Skipped", 'skipped'),
        ("Dual verification passed", 'both_pass_count'),
        ("Fact verification passed", 'fact_pass_count'),
        ("Logic verification passed", 'logic_pass_count'),
        ("Verification failed", 'unverified'),
    )
    for label, key in counter_labels:
        print(f"{label}: {stats[key]}")

    # Rates are only meaningful when the denominator is non-zero.
    if not stats['processed'] > 0:
        return
    success_rate = stats['success'] / stats['processed'] * 100
    print(f"Success rate: {success_rate:.2f}%")

    if not stats['success'] > 0:
        return
    rate_labels = (
        ("Dual verification pass rate", 'both_pass_count'),
        ("Fact verification pass rate", 'fact_pass_count'),
        ("Logic verification pass rate", 'logic_pass_count'),
    )
    rates = [(label, stats[key] / stats['success'] * 100) for label, key in rate_labels]
    for label, rate in rates:
        print(f"{label}: {rate:.2f}%")

# Script entry point: when executed directly (not imported), run the async
# main() coroutine to completion.
if __name__ == "__main__":
    asyncio.run(main())
