import logging
import json
import asyncio
import os
import time
import re
from typing import Dict, List, Set, Optional
from tqdm import tqdm
import traceback
from vllm import LLM, SamplingParams

# Logging configuration: only WARNING and above are emitted by default
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.WARNING,
)
logger = logging.getLogger(__name__)

# Build the vLLM engine once at import time; any failure is logged and re-raised
try:
    # Location of the model weights
    model_path = "Your model path"  # Modify to your local model path

    logger.info(f"Loading model from: {model_path}")

    # Engine construction
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        tensor_parallel_size=1,  # Adjust based on number of GPUs
        gpu_memory_utilization=0.9,  # GPU memory utilization
        max_model_len=4096,  # Adjust maximum length as needed
    )

    # Module-level default sampling configuration
    sampling_params = SamplingParams(
        max_tokens=512,
        temperature=0.7,
        top_p=0.9,
        stop=None,
    )

    logger.info(f"Successfully loaded vllm model from {model_path}")
except Exception as e:
    logger.error(f"Failed to initialize vllm model: {e}")
    raise

def generate_response(prompt: str, max_tokens: int = 512) -> str:
    """Run a single prompt through the module-level vLLM engine.

    Args:
        prompt: Text sent to the model.
        max_tokens: Generation length limit passed to SamplingParams.

    Returns:
        The stripped text of the first completion, or "" when the engine
        returns no outputs.

    Raises:
        Exception: Any error raised during generation is logged and re-raised.
    """
    try:
        # Fresh sampling parameters per call so max_tokens can vary
        request_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=0.7,
            top_p=0.9,
            stop=None,
        )

        outputs = llm.generate([prompt], request_params)

        # Guard: nothing came back from the engine
        if not outputs or len(outputs) == 0:
            logger.warning("vllm generation result is empty")
            return ""

        first_completion = outputs[0].outputs[0]
        return first_completion.text.strip()

    except Exception as e:
        logger.error(f"Error generating response with vllm: {e}")
        raise

def create_direct_prompt(question_data: Dict) -> str:
    """Build the first-attempt prompt asking for a true/false verdict.

    The template interpolates the 'term', 'description' and 'question'
    entries of question_data and tells the model to reply in the
    "The answer is true~[reason]" / "The answer is false~[reason]" form.
    """
    direct_prompt = f"""
Please analyze the following information and determine if the question is true or false:

Term: {question_data['term']}
Description: {question_data['description']}
Question: {question_data['question']}

Based on the given term and its description, please determine whether the question is true or false.
Consider the logical relationship, factual accuracy, and common sense reasoning.

**Note: Your answer must be in the format "The answer is true~[reason]" or "The answer is false~[reason]".**

For example:
- "The answer is true~Achilles was indeed a descendant of Gaia according to Greek mythology."
- "The answer is false~There is no evidence supporting this claim in the mythological sources."
"""
    return direct_prompt

def create_fact_verification_prompt(question_data: Dict, predicted_answer: str) -> str:
    """Build the prompt that asks the model to fact-check a predicted answer.

    The predicted answer is shown upper-cased, and the model is asked to
    reply as "Fact verification: [PASS/FAIL]~[reason]".
    """
    fact_prompt = f"""
Verify if the following analysis is factually correct based on general facts and objective laws:

Term: {question_data['term']}
Description: {question_data['description']}
Question: {question_data['question']}
Predicted Answer: {predicted_answer.upper()}

Please analyze:
1. Is this answer factually correct based on known information and established facts?
2. Does this answer align with objective laws and natural phenomena?
3. Is this answer consistent with historical facts, scientific knowledge, or well-established information?
4. Are there any factual errors or contradictions with widely accepted knowledge?

Provide your verification result in the format: "Fact verification: [PASS/FAIL]~[detailed reason explaining why the answer is factually correct or incorrect]"
"""
    return fact_prompt

def create_logic_verification_prompt(question_data: Dict, predicted_answer: str) -> str:
    """Build the prompt that asks the model to logic-check a predicted answer.

    The template walks the model through a five-step structured analysis and
    asks for a reply of the form "Logic verification: [PASS/FAIL]~[explanation]".
    The predicted answer is shown upper-cased.
    """
    logic_prompt = f"""
Perform logical verification for the following analysis using structured logical reasoning:

Term: {question_data['term']}
Description: {question_data['description']}
Question: {question_data['question']}
Predicted Answer: {predicted_answer.upper()}

Please follow these steps:

1. **Define the Target**: Clearly define what proposition we are trying to verify based on the question.

2. **Identify Premises and Conclusion**: 
   - Determine which statements serve as the foundation for the argument (premises that support the conclusion)
   - Confirm the viewpoint or conclusion to be proven based on these premises (final claim)

3. **Analyze Argument Structure**: 
   - Convert the argument (question with substituted answer) into appropriate logical form
   - Determine the logical form used (propositional logic, predicate logic, syllogism, etc.)
   - Check if the structure is correct

4. **Verify Argument Validity**: 
   - Evaluate whether each premise is true
   - Check if the derivation process from premises to conclusion follows logical laws

5. **Consider Counterexamples or Special Cases**: 
   - Try to find examples that could refute the argument
   - Look for other possible ways to explain the phenomenon
   - Analyze whether these special cases affect the correctness of the conclusion

Based on your analysis, provide your verification result in the format: "Logic verification: [PASS/FAIL]~[detailed explanation of the logical analysis]"
"""
    return logic_prompt

def extract_answer(response: str) -> Optional[str]:
    """Extract a true/false verdict from a model response.

    Matching proceeds from the most explicit form to the weakest signal:

    1. Explicit formats, case-insensitive: "The answer is true~",
       "answer = true", "result: true", "conclusion: true".
    2. A bare "true"/"false" followed by one of ``~``, ``-``, ``:`` when only
       one of the two verdicts appears in that form (any case, then
       upper-case only as a tie-breaker).
    3. Whichever of the whole words "true"/"false" occurs more often.
    4. An upper-case "TRUE"/"FALSE" when only one of them appears.

    Args:
        response: Raw model output; Markdown code fences are stripped first.

    Returns:
        "true" or "false" (always lower-case), or None when no verdict can
        be determined or the response is empty.
    """
    if not response:
        return None

    # Strip Markdown code fences so fenced answers are parsed like plain text
    cleaned_response = re.sub(r'```[a-zA-Z]*\n', '', response)
    cleaned_response = cleaned_response.replace('```', '')

    # 1. Explicit answer formats, tried in priority order
    explicit_patterns = (
        r'The answer is (true|false)\s*[~\-:]',
        r'answer\s*=\s*["\']?(true|false)["\']?',
        r'result\s*[:=]\s*["\']?(true|false)["\']?',
        r'conclusion\s*[:=]\s*["\']?(true|false)["\']?',
    )
    for pattern in explicit_patterns:
        explicit_match = re.search(pattern, cleaned_response, re.IGNORECASE)
        if explicit_match:
            return explicit_match.group(1).lower()

    # 2. "true~..." / "false~..." style. The lookbehind rejects words that
    # merely end in the verdict (e.g. "untrue:"), which would otherwise be
    # read as "true".
    true_match = re.search(r'(?<![A-Za-z])true\s*[~\-:]', cleaned_response, re.IGNORECASE)
    false_match = re.search(r'(?<![A-Za-z])false\s*[~\-:]', cleaned_response, re.IGNORECASE)
    if true_match and not false_match:
        return "true"
    if false_match and not true_match:
        return "false"

    # Both (or neither) appeared above: let an upper-case form break the tie
    true_match_upper = re.search(r'(?<![A-Za-z])TRUE\s*[~\-:]', cleaned_response)
    false_match_upper = re.search(r'(?<![A-Za-z])FALSE\s*[~\-:]', cleaned_response)
    if true_match_upper and not false_match_upper:
        return "true"
    if false_match_upper and not true_match_upper:
        return "false"

    # 3. Majority vote over whole-word occurrences
    true_words = len(re.findall(r'\btrue\b', cleaned_response, re.IGNORECASE))
    false_words = len(re.findall(r'\bfalse\b', cleaned_response, re.IGNORECASE))
    if true_words > false_words:
        return "true"
    if false_words > true_words:
        return "false"

    # 4. Counts tied: an upper-case word present on only one side decides
    has_upper_true = re.search(r'\bTRUE\b', cleaned_response) is not None
    has_upper_false = re.search(r'\bFALSE\b', cleaned_response) is not None
    if has_upper_true and not has_upper_false:
        return "true"
    if has_upper_false and not has_upper_true:
        return "false"

    return None

def extract_verification_result(response: str, verification_type: str) -> tuple:
    """Parse a "<type> verification: PASS/FAIL~reason" verdict from a response.

    Matching is case-insensitive. The verdict may be wrapped in square
    brackets or Markdown asterisks (e.g. "[PASS]", "**FAIL**"), because the
    verification prompts show the format as "[PASS/FAIL]" and a model may
    echo the brackets.

    Args:
        response: Raw model output of a verification request.
        verification_type: Label preceding " verification:", e.g. "Fact" or
            "Logic". It is matched literally (regex-escaped).

    Returns:
        A ``(passed, reason)`` tuple. ``passed`` is True only for PASS;
        ``reason`` is everything after the ``~`` with surrounding whitespace
        stripped. ``(False, "")`` is returned when no verdict is found.
    """
    if not response:
        return False, ""

    # Escape the label so characters like "." or "(" cannot alter the pattern
    pattern = (
        re.escape(verification_type)
        + r' verification:[\s\[*]*(PASS|FAIL)[\s\]*]*~\s*(.+)'
    )
    match = re.search(pattern, response, re.IGNORECASE | re.DOTALL)
    if match is None:
        return False, ""

    passed = match.group(1).upper() == "PASS"
    reason = match.group(2).strip()
    return passed, reason

def verify_answer(question_data: Dict, predicted_answer: str) -> Dict:
    """Run the fact check and the logic check on a predicted answer.

    Returns:
        A dict with an overall "verified" flag (True only when both checks
        pass) plus "fact" and "logic" sub-dicts holding each check's verdict,
        reason and raw response. If anything raises, the error is logged and
        {"verified": False, "error": <message>} is returned instead.
    """
    try:
        # Fact check
        fact_response = generate_response(
            create_fact_verification_prompt(question_data, predicted_answer),
            max_tokens=800,
        )
        fact_ok, fact_why = extract_verification_result(fact_response, "Fact")

        # Logic check
        logic_response = generate_response(
            create_logic_verification_prompt(question_data, predicted_answer),
            max_tokens=1000,
        )
        logic_ok, logic_why = extract_verification_result(logic_response, "Logic")
    except Exception as e:
        logger.error(f"Error verifying answer: {e}")
        return {"verified": False, "error": str(e)}

    # Overall success requires both checks to pass
    return {
        "verified": fact_ok and logic_ok,
        "fact": {
            "verified": fact_ok,
            "reason": fact_why,
            "response": fact_response,
        },
        "logic": {
            "verified": logic_ok,
            "reason": logic_why,
            "response": logic_response,
        },
    }

def create_retry_prompt(question_data: Dict, verification_result: Dict) -> str:
    """Build the re-ask prompt that feeds verification failures back to the model.

    For each of the "fact" and "logic" checks that did not pass and that
    carries a non-empty reason, a bullet with that reason is included. When
    no such bullet exists, a generic "Verification failed, please reconsider"
    line is used instead.
    """
    fact_section = verification_result.get("fact", {})
    logic_section = verification_result.get("logic", {})
    fact_reason = fact_section.get("reason", "")
    logic_reason = logic_section.get("reason", "")

    # Collect one bullet per failed check that actually gave a reason
    failure_reasons = []
    if fact_reason and not fact_section.get("verified", False):
        failure_reasons.append(f"- **Factual verification failed**: {fact_reason}")
    if logic_reason and not logic_section.get("verified", False):
        failure_reasons.append(f"- **Logical verification failed**: {logic_reason}")

    if failure_reasons:
        failure_info = "\n".join(failure_reasons)
    else:
        failure_info = "Verification failed, please reconsider"

    retry_prompt = f"""
Please reconsider and re-analyze the following question. Your previous answer did not pass verification for the following reasons:

{failure_info}

Please think more carefully and provide a more accurate analysis:

Term: {question_data['term']}
Description: {question_data['description']}
Question: {question_data['question']}

Consider:
1. **Factual Accuracy**: Ensure your answer aligns with known facts, historical information, and objective laws
2. **Logical Reasoning**: Follow proper logical structure and valid reasoning processes
3. **Evidence-based Analysis**: Base your conclusion on solid evidence and logical derivation
4. **Consider Counterarguments**: Think about potential counterexamples or alternative explanations

**Note: Your answer must be in the format "The answer is TRUE~[reason]" or "The answer is FALSE~[reason]".**
Please ensure your reasoning is both factually accurate and logically sound.
"""
    return retry_prompt

def process_question(question_data: Dict, retry_count: int = 1, retry_delay: int = 5) -> Dict:
    """Answer one question, verify the answer, and retry when needed.

    Each attempt prompts the model, extracts a true/false answer with
    extract_answer and checks it with verify_answer. An attempt is repeated
    (after sleeping retry_delay seconds) when no answer can be extracted,
    when verification fails, or when an exception is raised, as long as
    retries remain. After a failed verification the next prompt is built by
    create_retry_prompt from the most recent verification result.

    Args:
        question_data: Question record. 'question', 'term' and 'description'
            are read directly; 'qid' and 'answer' are optional.
        retry_count: Number of retries after the first attempt; the loop runs
            retry_count + 1 times.
        retry_delay: Seconds passed to time.sleep before each retry.

    Returns:
        When an answer was produced on the last executed attempt: a dict with
        'qid', 'term', 'description', 'question', 'model_answer',
        'verification', 'attempts' and 'status' ("success" when verified,
        "success_unverified" otherwise), plus 'ground_truth' when the input
        has an 'answer' key. When the final attempt raises: a dict with
        'status' set to "failed", the 'error' message and the collected
        'verification_history'.
        NOTE(review): if retry_count is negative the loop body never runs and
        the function falls through, returning None.
    """
    question_id = question_data.get('qid', 'unknown')
    
    # Verification results of every attempt so far, oldest first
    verification_history = []
    
    for attempt in range(retry_count + 1):  # retry_count + 1 attempts total (2 with the default retry_count=1)
        try:
            # Validation lives inside the loop, so an empty question is also retried before failing
            if not question_data.get('question'):
                raise ValueError("Question data cannot be empty")

            # 1. Let model answer directly or re-answer based on verification failure information
            if attempt == 0:
                # First attempt, use standard prompt
                direct_prompt = create_direct_prompt(question_data)
                logger.info(f"Question {question_id} first attempt")
            else:
                # Retry: use the failure feedback when a verification result exists
                if verification_history:
                    # Include verification failure reasons in prompt
                    last_verification = verification_history[-1]
                    direct_prompt = create_retry_prompt(question_data, last_verification)
                    logger.info(f"Question {question_id} retry (based on verification failure reasons)")
                else:
                    # If no verification history (possibly answer extraction failed), use standard prompt
                    direct_prompt = create_direct_prompt(question_data)
                    logger.info(f"Question {question_id} retry (no verification history)")
            
            # Generate response using vllm
            direct_response = generate_response(direct_prompt)
            predicted_answer = extract_answer(direct_response)
            
            if predicted_answer is None:
                logger.warning(f"Cannot extract answer from direct response: {direct_response[:100]}...")
                if attempt < retry_count:
                    time.sleep(retry_delay)
                    continue
                else:
                    # Last attempt: raise so the except branch below builds the "failed" record
                    raise ValueError("Cannot extract True/False answer from model response")
            
            # 2. Verify answer
            verification_result = verify_answer(question_data, predicted_answer)
            is_verified = verification_result.get("verified", False)
            
            # Record verification result
            verification_history.append(verification_result)
            
            # 3. If verification fails and retry chances remain, retry
            if not is_verified and attempt < retry_count:
                logger.info(f"Question {question_id} answer verification failed, preparing retry (last chance)")
                time.sleep(retry_delay)
                continue
            
            # 4. Build final result regardless of verification success (retry chances exhausted)
            result = {
                "qid": question_id,
                "term": question_data['term'],
                "description": question_data['description'],
                "question": question_data['question'],
                "model_answer": predicted_answer,  # Unified lowercase true/false
                "verification": verification_result,
                "attempts": attempt + 1,
                "status": "success" if is_verified else "success_unverified"  # Distinguish verified and unverified
            }
            
            # If original data contains correct answer, add to result
            if 'answer' in question_data:
                result['ground_truth'] = str(question_data['answer']).lower()
            
            if is_verified:
                logger.info(f"Question {question_id} processed successfully, verification passed, attempts: {attempt + 1}")
            else:
                logger.warning(f"Question {question_id} processing completed but verification failed, attempts: {attempt + 1}")
            
            return result

        except Exception as e:
            if attempt < retry_count:
                # Retries remain: log, wait, and let the for-loop start the next attempt
                logger.warning(f"Error processing question {question_id} (attempt {attempt+1}/{retry_count+1}): {e}")
                time.sleep(retry_delay)
            else:
                # Out of retries: report the failure as a record instead of raising
                error_msg = f"Failed to process question: {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                return {
                    "qid": question_id,
                    "term": question_data.get('term', ''),
                    "description": question_data.get('description', ''),
                    "question": question_data.get('question', ''),
                    "error": str(e),
                    "verification_history": verification_history,
                    "attempts": attempt + 1,
                    "status": "failed"
                }

def load_processed_ids(output_file: str) -> Set[str]:
    """Collect the 'qid' of every record already written to the output file.

    Used for checkpoint resume. Lines that are not valid JSON, are not JSON
    objects, lack a 'qid' key, or carry an unhashable 'qid' are skipped.

    Args:
        output_file: Path of the JSONL results file; it may not exist yet.

    Returns:
        The set of 'qid' values found, empty when the file does not exist.
    """
    processed_ids = set()
    if not os.path.exists(output_file):
        return processed_ids

    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # Blank or truncated line (e.g. from an interrupted write)
                continue

            # Only JSON objects can carry a qid; ignore arrays, strings, numbers
            if not isinstance(data, dict) or 'qid' not in data:
                continue

            try:
                processed_ids.add(data['qid'])
            except TypeError:
                # Unhashable qid (list/dict) cannot be tracked in a set
                continue

    return processed_ids

async def process_batch(batch: List[Dict], max_workers: int = 1) -> List[Dict]:
    """Run process_question over every question in the batch, one at a time.

    Processing is sequential to avoid vLLM thread conflicts; max_workers is
    accepted but not used. A question whose processing raises is reported as
    a "failed" record rather than aborting the batch.

    Returns:
        One result dict per question, in input order.
    """
    collected = []

    for question in tqdm(batch, desc="Processing batch", leave=False):
        try:
            outcome = process_question(question)
        except Exception as e:
            logger.error(f"Exception occurred while processing question {question.get('qid', 'unknown')}: {e}")
            outcome = {
                "qid": question.get('qid', 'unknown'),
                "term": question.get('term', ''),
                "description": question.get('description', ''),
                "question": question.get('question', ''),
                "error": str(e),
                "status": "failed"
            }
        collected.append(outcome)

    return collected

def _append_results(output_file: str, results: List[Dict]) -> None:
    """Append results to output_file, one JSON object per line."""
    with open(output_file, 'a', encoding='utf-8') as f:
        f.writelines(json.dumps(result, ensure_ascii=False) + '\n' for result in results)


def _tally_result(stats: Dict[str, int], result: Dict) -> None:
    """Update the running counters in stats with one processed result."""
    stats["processed"] += 1

    status = result.get('status')
    if status not in ('success', 'success_unverified'):
        stats["failed"] += 1
        return

    stats["success"] += 1
    verification = result.get('verification', {})

    # Only a 'success' record whose verification passed counts as verified
    if status == 'success' and verification.get('verified', False):
        stats["verified"] += 1
        stats["both_pass_count"] += 1
    else:
        stats["unverified"] += 1

    # Per-check pass counters are tracked independently of the overall verdict
    if verification.get('fact', {}).get('verified', False):
        stats["fact_pass_count"] += 1
    if verification.get('logic', {}).get('verified', False):
        stats["logic_pass_count"] += 1

    model_answer = result.get('model_answer', '').lower()
    if model_answer == 'true':
        stats["true_answers"] += 1
    elif model_answer == 'false':
        stats["false_answers"] += 1


async def process_jsonl_file(input_file: str, output_file: str, batch_size: int = 10, max_workers: int = 1, save_interval: int = 50):
    """Process a JSONL question file in batches with checkpoint resume.

    Questions whose 'qid' already appears in output_file are skipped. Results
    are buffered and appended to output_file whenever at least save_interval
    of them have accumulated, and once more at the end. If a batch raises,
    the results buffered so far are still written and the progress bar is
    closed before the exception propagates.

    Args:
        input_file: Path of the JSONL file with one question object per line.
            Blank lines are ignored; unparsable or non-object lines are
            skipped with a warning.
        output_file: Path of the JSONL file results are appended to.
        batch_size: Number of questions handed to process_batch at a time.
        max_workers: Forwarded to process_batch.
        save_interval: Minimum number of buffered results that triggers a save.

    Returns:
        A dict of counters: total, processed, success, failed, skipped,
        verified, unverified, true_answers, false_answers, fact_pass_count,
        logic_pass_count, both_pass_count.
    """
    # Load processed question IDs
    processed_ids = load_processed_ids(output_file)
    logger.info(f"Found {len(processed_ids)} already processed questions")

    # Statistics (key order matches the returned summary)
    stats = {
        "total": 0,
        "processed": 0,
        "success": 0,
        "failed": 0,
        "skipped": 0,
        "verified": 0,
        "unverified": 0,
        "true_answers": 0,
        "false_answers": 0,
        "fact_pass_count": 0,
        "logic_pass_count": 0,
        "both_pass_count": 0,
    }

    # Read all questions that still need processing
    all_questions = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Blank lines are not records; skip them without a warning
                continue
            try:
                question_data = json.loads(line)
            except json.JSONDecodeError:
                logger.warning(f"Cannot parse JSON line: {line[:100]}...")
                continue
            if not isinstance(question_data, dict):
                # A bare array/string/number has no fields to process
                logger.warning(f"Skipping non-object JSON line: {line[:100]}...")
                continue

            stats["total"] += 1

            # Skip already processed questions
            if question_data.get('qid') in processed_ids:
                stats["skipped"] += 1
                continue

            all_questions.append(question_data)

    # Create progress bar
    progress_bar = tqdm(total=len(all_questions), desc="Overall progress", unit="questions")

    # Results produced but not yet written to disk
    pending = []

    try:
        for i in range(0, len(all_questions), batch_size):
            batch = all_questions[i:i+batch_size]

            results = await process_batch(batch, max_workers)
            pending.extend(results)

            for result in results:
                _tally_result(stats, result)

            progress_bar.update(len(batch))
            progress_bar.set_postfix({
                "Success": stats["success"],
                "Failed": stats["failed"],
                "Verified": stats["verified"],
                "Fact Pass": stats["fact_pass_count"],
                "Logic Pass": stats["logic_pass_count"],
                "TRUE": stats["true_answers"],
                "FALSE": stats["false_answers"]
            })

            # Save progress once enough results have accumulated
            if len(pending) >= save_interval:
                # Detach the buffer first so a failed write is never retried
                # (and partially duplicated) by the final flush below
                to_save, pending = pending, []
                _append_results(output_file, to_save)
                print(f"\nProgress saved, currently processed {stats['processed']} questions")
    finally:
        try:
            # Flush whatever is still buffered, on success and on error alike
            if pending:
                to_save, pending = pending, []
                _append_results(output_file, to_save)
                print(f"\nProgress saved, currently processed {stats['processed']} questions")
        finally:
            progress_bar.close()

    return stats

async def main():
    """Run the pipeline on the configured files and print a summary report."""
    # File locations
    input_file = 'Your dataset path'  # Input JSONL file path
    output_file = 'Your output result path'  # Output JSONL file path

    # Run configuration (single-threaded)
    batch_size = 50  # Batch processing question count
    max_workers = 1  # Maximum parallel worker threads (using single thread)
    save_interval = 50  # Save interval, save progress every 50 questions processed

    print(f"Starting to process file: {input_file}")
    print(f"Using model: {model_path}")
    print(f"Batch size: {batch_size}, Parallel threads: {max_workers}")
    print(f"Save progress every {save_interval} questions processed")
    print("New strategy: Separate fact verification and logic verification")

    # Time the whole run
    start_time = time.time()
    stats = await process_jsonl_file(
        input_file=input_file,
        output_file=output_file,
        batch_size=batch_size,
        max_workers=max_workers,
        save_interval=save_interval
    )
    elapsed_time = time.time() - start_time

    # Break the elapsed seconds into hours / minutes / seconds
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)

    # Raw counters
    print("\nProcessing completed!")
    print(f"Total elapsed time: {int(hours)} hours {int(minutes)} minutes {int(seconds)} seconds")
    print(f"Total questions: {stats['total']}")
    print(f"Processed: {stats['processed']}")
    print(f"Success: {stats['success']}")
    print(f"Failed: {stats['failed']}")
    print(f"Skipped: {stats['skipped']}")
    print(f"Dual verification passed: {stats['both_pass_count']}")
    print(f"Fact verification passed: {stats['fact_pass_count']}")
    print(f"Logic verification passed: {stats['logic_pass_count']}")
    print(f"Verification failed: {stats['unverified']}")
    print(f"TRUE answers: {stats['true_answers']}")
    print(f"FALSE answers: {stats['false_answers']}")

    # Rates are only meaningful when something was processed
    processed_total = stats['processed']
    if processed_total <= 0:
        return

    success_total = stats['success']
    success_rate = success_total / processed_total * 100
    print(f"Success rate: {success_rate:.2f}%")

    if success_total > 0:
        both_verification_rate = stats['both_pass_count'] / success_total * 100
        fact_verification_rate = stats['fact_pass_count'] / success_total * 100
        logic_verification_rate = stats['logic_pass_count'] / success_total * 100
        print(f"Dual verification pass rate: {both_verification_rate:.2f}%")
        print(f"Fact verification pass rate: {fact_verification_rate:.2f}%")
        print(f"Logic verification pass rate: {logic_verification_rate:.2f}%")

    # Share of TRUE versus FALSE answers
    total_answers = stats['true_answers'] + stats['false_answers']
    if total_answers > 0:
        true_rate = stats['true_answers'] / total_answers * 100
        false_rate = stats['false_answers'] / total_answers * 100
        print(f"TRUE/FALSE ratio: {true_rate:.1f}% / {false_rate:.1f}%")

# Script entry point: run the async pipeline only when this file is executed directly
if __name__ == "__main__":
    asyncio.run(main())
