import logging
import json
import asyncio
import os
import time
import re
from typing import Dict, List, Set, Tuple, Optional
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import traceback
from vllm import LLM, SamplingParams

# Set up logging: configure the root logger at import time (basicConfig is a
# no-op if the root logger already has handlers). With this configuration only
# WARNING and above are emitted, so the logger.info(...) calls in this module
# produce no output unless the level is lowered here.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.WARNING,
)
logger = logging.getLogger(__name__)

# Initialize vllm model
try:
    # Model path configuration. A non-empty MODEL_PATH environment variable
    # overrides the placeholder, so the script can be pointed at a model
    # without editing the source.
    model_path = os.environ.get("MODEL_PATH") or "Your model path"  # Change to your local model path

    logger.info(f"Loading model from: {model_path}")

    # Initialize model with vllm
    llm = LLM(
        model=model_path,
        tensor_parallel_size=1,  # Adjust according to GPU count
        trust_remote_code=True,
        max_model_len=4096,  # Adjust max length as needed
        gpu_memory_utilization=0.9,  # GPU memory utilization
    )

    # Default sampling parameters. NOTE(review): generate_response() builds
    # its own SamplingParams on every call, so it does not use this object.
    sampling_params = SamplingParams(
        temperature=0.7,
        max_tokens=512,
        top_p=0.9,
        stop=None,
    )

    logger.info(f"Successfully loaded vllm model from {model_path}")
except Exception as e:
    # Loading the model is a hard requirement: log, then abort the import.
    logger.error(f"Failed to initialize vllm model: {e}")
    raise

def generate_response(prompt: str, max_tokens: int = 512) -> str:
    """Generate one completion for ``prompt`` with the module-level vLLM model.

    Sampling uses temperature 0.7 and top_p 0.9; only ``max_tokens`` varies
    per call.

    Args:
        prompt: Full prompt text sent to the model.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The generated text with surrounding whitespace stripped, or ``""``
        when vLLM returns no request output or no completion.

    Raises:
        Exception: Any error raised while building parameters or generating
            is logged and re-raised unchanged.
    """
    try:
        current_sampling_params = SamplingParams(
            temperature=0.7,
            max_tokens=max_tokens,
            top_p=0.9,
            stop=None,
        )

        # use_tqdm=False: this is called several times per question, and
        # vLLM's per-call progress bar would interleave with the batch and
        # overall tqdm bars the pipeline already shows.
        outputs = llm.generate([prompt], current_sampling_params, use_tqdm=False)

        # One prompt in, so at most one request output. Guard against both an
        # empty request list and an empty completion list before indexing.
        if not outputs or not outputs[0].outputs:
            logger.warning("vllm generation result is empty")
            return ""
        return outputs[0].outputs[0].text.strip()

    except Exception as e:
        logger.error(f"Error generating response with vllm: {e}")
        raise

def create_direct_prompt(question_data: Dict) -> str:
    """Build the prompt that asks the model for a very short (1-3 word) answer.

    When the record has a non-empty ``knowledge`` field, it is placed before
    the question as context. Otherwise only the question is sent. The prompt
    always ends with ``"Answer:"`` so the model continues with the answer.
    """
    knowledge_text = question_data.get('knowledge', '')
    question_text = question_data.get('question', '')
    tail = f"Question: {question_text}\n\nAnswer:"

    if not knowledge_text:
        return (
            "Please answer the following question with a very brief and direct answer (1-3 words if possible):\n\n"
            + tail
        )
    return (
        "Based on the following knowledge, please answer the question with a very brief and direct answer (1-3 words if possible):\n\n"
        f"Knowledge: {knowledge_text}\n\n"
        + tail
    )

# Leading boilerplate that models often put before the actual answer.
_ANSWER_PREFIX_RE = re.compile(r'^(Answer:|The answer is:?|Response:?)\s*', re.IGNORECASE)


def _strip_answer_prefix(text: str) -> str:
    """Return ``text`` trimmed of whitespace and of one leading answer prefix."""
    return _ANSWER_PREFIX_RE.sub('', text.strip()).strip()


def extract_short_answer(response: str, max_length: int = 100) -> tuple[Optional[str], str]:
    """Extract a brief answer from a model response.

    Candidates are tried in order, each with a leading "Answer:",
    "The answer is" or "Response:" prefix removed. The first non-empty
    candidate of at most ``max_length`` characters is returned:

    1. the whole response, with whitespace runs collapsed to single spaces;
    2. the first line of the response;
    3. the first sentence (text before the first ``.``, ``!`` or ``?``).

    If none qualifies, the first 50 characters of the prefix-stripped,
    whitespace-collapsed response are used.

    Args:
        response: Raw model output.
        max_length: Longest candidate still accepted as a "brief" answer.

    Returns:
        ``(answer, full_response)``, where ``full_response`` is ``response``
        with surrounding whitespace stripped. ``answer`` is ``None`` when the
        response is empty or contains nothing besides an answer prefix.
    """
    full_response = response.strip()
    if not full_response:
        return None, ""

    cleaned_response = re.sub(r'\s+', ' ', full_response).strip()
    candidates = (
        cleaned_response,
        full_response.split('\n', 1)[0],
        re.split(r'[.!?]', cleaned_response, maxsplit=1)[0],
    )
    for candidate in candidates:
        answer = _strip_answer_prefix(candidate)
        # Length is checked after prefix removal, so a long first line can no
        # longer be returned whole and shadow the sentence/truncation fallbacks.
        if answer and len(answer) <= max_length:
            return answer, full_response

    truncated = _strip_answer_prefix(cleaned_response)[:50].strip()
    return (truncated or None), full_response

def create_fact_verification_prompt(question_data: Dict, model_answer: str) -> str:
    """Build the fact-check prompt for ``model_answer``.

    The prompt asks the model to judge (1) factual correctness and (2) whether
    the answer is the most appropriate, simplest choice. It asks for a reply
    of the form ``"Fact verification: [PASS/FAIL]~[reason]"``, which
    ``verify_answer`` parses with ``extract_verification_result(..., "Fact")``.
    Missing ``knowledge`` / ``question`` fields are rendered as empty strings.
    """
    knowledge = question_data.get('knowledge', '')
    question = question_data.get('question', '')
    prompt = f"""
Verify if the following analysis is factually correct and represents the optimal solution:

Knowledge: {knowledge}
Question: {question}
Model Answer: {model_answer}

Please analyze:

**Part 1: Factual Correctness**
1. Substitute the model answer into the question context
2. Analyze whether this answer aligns with general facts and objective laws
3. Check if this answer is consistent with established knowledge and reality

**Part 2: Optimal Solution Analysis**
1. Compare the model answer with all available information
2. Determine if this answer is the most appropriate choice
3. When multiple valid explanations exist, verify if this is the simplest, most direct, and objective answer
4. For human-related questions, consider whether physiological states are more direct and objective than emotional states

Provide your verification result in the format: "Fact verification: [PASS/FAIL]~[detailed reason explaining both factual correctness and optimality analysis]"
"""
    return prompt

def create_logic_verification_prompt(question_data: Dict, model_answer: str) -> str:
    """Build the logic-check prompt for ``model_answer``.

    The prompt walks the model through five structured reasoning steps
    (target, premises/conclusion, argument form, validity, counterexamples).
    It asks for a reply of the form ``"Logic verification: [PASS/FAIL]~[...]"``,
    which ``verify_answer`` parses with
    ``extract_verification_result(..., "Logic")``. Missing ``knowledge`` /
    ``question`` fields are rendered as empty strings.
    """
    knowledge = question_data.get('knowledge', '')
    question = question_data.get('question', '')
    prompt = f"""
Perform logical verification for the following analysis using structured logical reasoning:

Knowledge: {knowledge}
Question: {question}
Model Answer: {model_answer}

Please follow these steps:

**Step 1: Define the Target**
Clearly define what proposition we are trying to verify based on the question.

**Step 2: Identify Premises and Conclusion**
- Determine which statements serve as the foundation for the argument (premises that support the conclusion)
- Confirm the viewpoint or conclusion to be proven based on these premises (final claim)

**Step 3: Analyze Argument Structure**
- Convert the argument (question with substituted answer) into appropriate logical form
- Determine the logical form used (propositional logic, predicate logic, syllogism, etc.)
- Check if the structure is correct

**Step 4: Verify Argument Validity**
- Evaluate whether each premise is true
- Check if the derivation process from premises to conclusion follows logical laws

**Step 5: Consider Counterexamples or Special Cases**
- Try to find examples that could refute the argument
- Look for other possible ways to explain the phenomenon
- Analyze whether these special cases affect the correctness of the conclusion

Based on your analysis, provide your verification result in the format: "Logic verification: [PASS/FAIL]~[detailed explanation of the logical analysis]"
"""
    return prompt

def extract_verification_result(response: str, verification_type: str) -> Tuple[bool, str]:
    """Parse a ``"<Type> verification: PASS|FAIL ~ reason"`` verdict from a reply.

    Matching is case-insensitive. It also tolerates the decorations models
    commonly copy from the requested template:
    - markdown bold around the label (``**Fact verification:** PASS`` or
      ``**Fact verification**: PASS``);
    - square brackets around the verdict (``[PASS]~[...]``).
    The literal template placeholder ``[PASS/FAIL]`` is not treated as a
    verdict. The reason is everything after the ``~`` (across lines),
    stripped.

    Args:
        response: Raw verification reply from the model.
        verification_type: Label preceding "verification", e.g. "Fact" or "Logic".

    Returns:
        ``(passed, reason)``. When no verdict is found, ``(False, "")``: an
        unparseable reply counts as a failed check.
    """
    pattern = (
        rf'{re.escape(verification_type)} verification\**\s*:'
        r'[\s*]*\[?\s*(PASS|FAIL)\s*\]?\s*~\s*(.+)'
    )
    match = re.search(pattern, response, re.IGNORECASE | re.DOTALL)
    if not match:
        return False, ""
    return match.group(1).upper() == "PASS", match.group(2).strip()

def verify_answer(question_data: Dict, model_answer: str) -> Dict:
    """Run the fact check and then the logic check on ``model_answer``.

    Each check sends its own prompt to the model (token budgets 800 and 1000)
    and parses a PASS/FAIL verdict from the reply. The answer counts as
    verified only when both checks pass.

    Returns:
        On success, ``{"verified", "fact", "logic"}``, where ``fact`` and
        ``logic`` each hold ``verified``, ``reason`` and the raw ``response``.
        If any step raises, ``{"verified": False, "error": <message>}``.
    """
    try:
        checks = {}
        for label, build_prompt, token_budget in (
            ("Fact", create_fact_verification_prompt, 800),
            ("Logic", create_logic_verification_prompt, 1000),
        ):
            raw_reply = generate_response(
                build_prompt(question_data, model_answer),
                max_tokens=token_budget,
            )
            passed, why = extract_verification_result(raw_reply, label)
            checks[label.lower()] = {
                "verified": passed,
                "reason": why,
                "response": raw_reply,
            }
        return {
            "verified": checks["fact"]["verified"] and checks["logic"]["verified"],
            "fact": checks["fact"],
            "logic": checks["logic"],
        }
    except Exception as e:
        logger.error(f"Error verifying answer: {e}")
        return {"verified": False, "error": str(e)}

def _base_record(question_data: Dict) -> Dict:
    """Return the input fields echoed into every result record, in output order."""
    return {
        "id": question_data.get('id', 'unknown'),
        "knowledge": question_data.get('knowledge', ''),
        "question": question_data.get('question', ''),
        "right_answer": question_data.get('right_answer', ''),
        "hallucinated_answer": question_data.get('hallucinated_answer', ''),
    }


def process_question(question_data: Dict, retry_count: int = 1, retry_delay: int = 5) -> Dict:
    """Answer one question with the model and verify the answer, retrying on failure.

    Each attempt asks for a brief direct answer, extracts it, and runs
    ``verify_answer``. An attempt is retried after sleeping ``retry_delay``
    seconds when no answer can be extracted, when verification does not pass,
    or when an exception is raised. This continues until ``retry_count``
    retries are used up.

    Args:
        question_data: Question record; must contain a non-empty ``question``.
        retry_count: Number of retries after the first attempt (negative
            values are treated as 0).
        retry_delay: Seconds to sleep before each retry.

    Returns:
        The echoed input fields (id, knowledge, question, right_answer,
        hallucinated_answer) plus one of:
        - ``model_answer``, ``answer_content``, ``verification`` and
          ``attempts``, with status ``"success"`` (verified) or
          ``"success_unverified"`` (retries exhausted without passing);
        - ``error``, ``verification_history`` and ``attempts``, with status
          ``"failed"``.
        Errors raised while generating, extracting or verifying are captured
        in the returned record rather than propagated.
    """
    verification_history: List[Dict] = []

    # An empty question fails identically on every attempt, so reject it up
    # front instead of sleeping through pointless retries.
    if not question_data.get('question'):
        return {
            **_base_record(question_data),
            "error": "Question data cannot be empty",
            "verification_history": verification_history,
            "attempts": 1,
            "status": "failed",
        }

    max_attempts = max(retry_count, 0) + 1
    for attempt in range(max_attempts):
        is_last_attempt = attempt == max_attempts - 1
        try:
            # Retries reuse the same prompt. (Looking up verification_history[-1]
            # here used to raise IndexError when the previous attempt failed
            # before verification ran, which silently aborted the retry.)
            direct_response = generate_response(create_direct_prompt(question_data))
            model_answer, answer_content = extract_short_answer(direct_response)
            if not model_answer:
                if not is_last_attempt:
                    time.sleep(retry_delay)
                    continue
                raise ValueError("Cannot extract brief answer from model response")

            verification_result = verify_answer(question_data, model_answer)
            is_verified = verification_result.get("verified", False)
            verification_history.append(verification_result)
            if not is_verified and not is_last_attempt:
                time.sleep(retry_delay)
                continue

            return {
                **_base_record(question_data),
                "model_answer": model_answer,
                "answer_content": answer_content,
                "verification": verification_result,
                "attempts": attempt + 1,
                "status": "success" if is_verified else "success_unverified",
            }
        except Exception as e:
            if not is_last_attempt:
                time.sleep(retry_delay)
                continue
            logger.warning(
                f"Question {question_data.get('id', 'unknown')} failed after {attempt + 1} attempt(s): {e}"
            )
            return {
                **_base_record(question_data),
                "error": str(e),
                "verification_history": verification_history,
                "attempts": attempt + 1,
                "status": "failed",
            }

def load_processed_ids(output_file: str) -> Set[str]:
    """Collect the ``id`` of every record already written to ``output_file``.

    Used for resume support: questions whose id appears here are skipped by
    the caller. Records of every status are collected, so questions whose
    record is ``"failed"`` are also skipped on resume.

    Lines are handled best-effort. The following are ignored:
    - lines that are not valid JSON (e.g. a partially written last line);
    - lines that are not JSON objects;
    - records with no ``id`` or with an unhashable ``id``.

    Args:
        output_file: Path of the JSONL results file; it may not exist yet.

    Returns:
        The set of ids found (empty when the file does not exist).
    """
    processed_ids = set()
    if not os.path.exists(output_file):
        return processed_ids

    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line)
                if isinstance(data, dict) and 'id' in data:
                    processed_ids.add(data['id'])
            except (json.JSONDecodeError, TypeError):
                # Malformed line or unhashable id: skip it, keep loading the rest.
                continue
    return processed_ids

async def process_batch(batch: List[Dict], max_workers: int = 1) -> List[Dict]:
    """Process a batch of questions one after another and return their records.

    The function is declared ``async`` because callers ``await`` it, but its
    body is fully synchronous. Each question goes through
    ``process_question`` in turn, which blocks the event loop for the whole
    batch. Sequential processing is deliberate, to avoid vLLM thread
    conflicts.

    Args:
        batch: Question dicts to process.
        max_workers: Accepted for interface compatibility; unused, since
            processing is single-threaded.

    Returns:
        One result dict per question, in ``batch`` order. If
        ``process_question`` raises, a ``"failed"`` record carrying the error
        message is appended instead.
    """
    results = []

    # Use single-threaded sequential processing to avoid vLLM thread conflicts
    for question in tqdm(batch, desc="Batch Processing", leave=False):
        try:
            results.append(process_question(question))
        except Exception as e:
            question_id = question.get('id', 'unknown')
            logger.error(f"Exception occurred while processing question {question_id}: {e}")
            # Echo the same fields, with the same string defaults, as
            # process_question's own failure records.
            results.append({
                "id": question_id,
                "knowledge": question.get('knowledge', ''),
                "question": question.get('question', ''),
                "right_answer": question.get('right_answer', ''),
                "hallucinated_answer": question.get('hallucinated_answer', ''),
                "error": str(e),
                "status": "failed",
            })

    return results

def _append_results(output_file: str, results: List[Dict]) -> None:
    """Append ``results`` to ``output_file`` as JSON Lines, one record per line."""
    with open(output_file, 'a', encoding='utf-8') as f:
        f.writelines(json.dumps(result, ensure_ascii=False) + '\n' for result in results)


def _tally_result(stats: Dict[str, int], result: Dict) -> None:
    """Update the counters in ``stats`` for one result record."""
    stats["processed"] += 1
    status = result.get('status')
    if status not in ('success', 'success_unverified'):
        stats["failed"] += 1
        return

    stats["success"] += 1
    verification = result.get('verification', {})
    # Only 'success' records can count as verified; 'success_unverified'
    # always counts as unverified, whatever its verification dict says.
    if status == 'success' and verification.get('verified', False):
        stats["verified"] += 1
        stats["both_pass_count"] += 1
    else:
        stats["unverified"] += 1

    # Individual check outcomes are tallied for both success statuses.
    if verification.get('fact', {}).get('verified', False):
        stats["fact_pass_count"] += 1
    if verification.get('logic', {}).get('verified', False):
        stats["logic_pass_count"] += 1


async def process_jsonl_file(input_file: str, output_file: str, batch_size: int = 10, max_workers: int = 1, save_interval: int = 50):
    """Process a JSONL question file in batches, appending results to ``output_file``.

    Resume support: ids already present in ``output_file`` (as collected by
    ``load_processed_ids``) are skipped. Results are buffered and appended
    once at least ``save_interval`` are pending. Any remainder is flushed at
    the end, including when processing is interrupted by an exception, so
    finished questions are not redone on the next run.

    Blank input lines are ignored. Lines that are not valid JSON objects are
    logged and skipped.

    Args:
        input_file: JSONL file with one question object per line.
        output_file: JSONL file that results are appended to.
        batch_size: Questions per ``process_batch`` call (values below 1 are
            treated as 1).
        max_workers: Passed through to ``process_batch``.
        save_interval: Minimum number of buffered results before writing.

    Returns:
        A dict of counters: total, processed, success, failed, skipped,
        verified, unverified, fact_pass_count, logic_pass_count,
        both_pass_count.
    """
    # Load processed question IDs
    processed_ids = load_processed_ids(output_file)
    logger.info(f"Found {len(processed_ids)} processed questions")

    # Statistics (key order is also the order of the returned dict)
    stats = {
        "total": 0,
        "processed": 0,
        "success": 0,
        "failed": 0,
        "skipped": 0,
        "verified": 0,
        "unverified": 0,
        "fact_pass_count": 0,
        "logic_pass_count": 0,
        "both_pass_count": 0,
    }

    # Read all questions that still need processing
    all_questions = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # blank lines (e.g. a trailing newline) are not records
            try:
                question_data = json.loads(line)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON line: {line[:100]}...")
                continue
            if not isinstance(question_data, dict):
                logger.warning(f"Skipping non-object JSON line: {line[:100]}...")
                continue

            stats["total"] += 1
            # Skip processed questions
            if question_data.get('id') in processed_ids:
                stats["skipped"] += 1
                continue
            all_questions.append(question_data)

    batch_size = max(1, batch_size)  # range() rejects a step of 0
    progress_bar = tqdm(total=len(all_questions), desc="Overall Progress", unit="questions")
    pending: List[Dict] = []

    def _flush() -> None:
        # Detach the buffer before writing so a failed write is never retried
        # (and partially duplicated) by the final flush.
        nonlocal pending
        to_write, pending = pending, []
        _append_results(output_file, to_write)
        print(f"\nProgress saved, {stats['processed']} questions processed so far")

    try:
        for start in range(0, len(all_questions), batch_size):
            batch = all_questions[start:start + batch_size]
            results = await process_batch(batch, max_workers)
            pending.extend(results)
            for result in results:
                _tally_result(stats, result)

            progress_bar.update(len(batch))
            progress_bar.set_postfix({
                "Success": stats["success"],
                "Failed": stats["failed"],
                "Verified": stats["verified"],
                "Fact Pass": stats["fact_pass_count"],
                "Logic Pass": stats["logic_pass_count"],
                "Skipped": stats["skipped"],
            })

            # Save progress every save_interval questions
            if len(pending) >= save_interval:
                _flush()
    finally:
        # Write whatever is still buffered, also when the loop raised, so a
        # resumed run does not redo questions that already finished.
        if pending:
            _flush()
        progress_bar.close()

    return dict(stats)

async def main():
    """Run the answer-and-verify pipeline over the configured JSONL dataset.

    Edit ``input_file`` / ``output_file`` and the batching settings below
    before running. The function prints the configuration, delegates to
    ``process_jsonl_file``, then prints elapsed time, counters and pass rates
    (rates are skipped when their denominator is zero).
    """
    input_file = 'Your dataset path'  # Input JSONL file path
    output_file = 'Your output result path'  # Output JSONL file path

    # Configuration - single thread
    batch_size = 50  # Questions handed to process_batch at a time
    max_workers = 1  # Passed through to process_batch (single thread)
    save_interval = 50  # Write results once this many are buffered

    for banner_line in (
        f"Start processing file: {input_file}",
        f"Using model: {model_path}",
        f"Batch size: {batch_size}, Parallel threads: {max_workers}",
        f"Save progress every {save_interval} questions",
        "Strategy: Fact verification (factual + optimal solution) + Logic verification (structured logical reasoning)",
    ):
        print(banner_line)

    started_at = time.time()
    stats = await process_jsonl_file(
        input_file=input_file,
        output_file=output_file,
        batch_size=batch_size,
        max_workers=max_workers,
        save_interval=save_interval,
    )
    hours, remainder = divmod(time.time() - started_at, 3600)
    minutes, seconds = divmod(remainder, 60)

    print("\nProcessing completed!")
    print(f"Total time: {int(hours)}h {int(minutes)}m {int(seconds)}s")
    for label, key in (
        ("Total questions", 'total'),
        ("Processed", 'processed'),
        ("Success", 'success'),
        ("Failed", 'failed'),
        ("Skipped", 'skipped'),
        ("Dual verification passed", 'both_pass_count'),
        ("Fact verification passed", 'fact_pass_count'),
        ("Logic verification passed", 'logic_pass_count'),
        ("Verification failed", 'unverified'),
    ):
        print(f"{label}: {stats[key]}")

    # Success rate is relative to processed questions; verification pass
    # rates are relative to successful ones.
    processed = stats['processed']
    if processed <= 0:
        return
    print(f"Success rate: {stats['success'] / processed * 100:.2f}%")

    succeeded = stats['success']
    if succeeded <= 0:
        return
    for label, key in (
        ("Dual verification pass rate", 'both_pass_count'),
        ("Fact verification pass rate", 'fact_pass_count'),
        ("Logic verification pass rate", 'logic_pass_count'),
    ):
        print(f"{label}: {stats[key] / succeeded * 100:.2f}%")

if __name__ == "__main__":
    # Drive the async pipeline to completion on a fresh event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)