# Modified run.py to support multi-round CoT with quality-based path selection
# Added print output capture functionality

import asyncio
import json
from uuid import uuid4
from langchain_core.messages import HumanMessage
from Med_agent.graph import graph
from PIL import Image
import os
import argparse
import base64
import datetime
import random
import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any
import sys
import io
from contextlib import redirect_stdout, redirect_stderr

import re
import glob

def extract_sample_id_from_log_filename(log_filename):
    """Parse sample id, attempt number and k value out of a log file name.

    The expected shape is
    ``multi_cot_log_<8 digits>_<6 digits>_sample_<id>_attempt_<n>_k<k>.txt``.
    The id is matched non-greedily so that ids which themselves contain
    underscores are still recognised; the ``_attempt_<n>_k<k>.txt`` tail
    delimits where the id ends.

    Args:
        log_filename: File name (or path) of a log file.

    Returns:
        ``(sample_id, attempt_num, k_value)`` with the two numbers converted
        to ``int``, or ``(None, None, None)`` when the name does not match.
    """
    pattern = r'multi_cot_log_\d{8}_\d{6}_sample_(.+?)_attempt_(\d+)_k(\d+)\.txt'
    match = re.search(pattern, log_filename)
    if match is None:
        return None, None, None

    sample_id, attempt_text, k_text = match.groups()
    return sample_id, int(attempt_text), int(k_text)

def scan_existing_logs(log_dir, k_value):
    """Collect the sample ids that already have a log file for this k value.

    Looks at every ``multi_cot_log_*.txt`` file directly inside ``log_dir``
    and keeps those whose parsed k value equals ``k_value``.

    Args:
        log_dir: Directory to scan; a falsy or non-existent path yields
            empty results.
        k_value: Only logs recorded with this k are considered.

    Returns:
        A pair ``(processed_samples, log_details)``: the set of matching
        sample ids, and a dict mapping each id to its log records
        (``log_file``, ``attempt_num``, ``k_value``, ``timestamp`` = file
        mtime) ordered newest first.
    """
    if not (log_dir and os.path.exists(log_dir)):
        return set(), {}

    matching_files = glob.glob(os.path.join(log_dir, "multi_cot_log_*.txt"))
    print(f"Found {len(matching_files)} existing log files in {log_dir}")

    seen_ids = set()
    records_by_sample = {}

    for path in matching_files:
        sample_id, attempt_num, log_k_value = extract_sample_id_from_log_filename(
            os.path.basename(path)
        )
        if not sample_id or log_k_value != k_value:
            continue

        seen_ids.add(sample_id)
        records_by_sample.setdefault(sample_id, []).append({
            'log_file': path,
            'attempt_num': attempt_num,
            'k_value': log_k_value,
            'timestamp': os.path.getmtime(path)
        })

    # Newest log first, so index 0 is always the latest attempt on disk.
    for records in records_by_sample.values():
        records.sort(key=lambda record: record['timestamp'], reverse=True)

    return seen_ids, records_by_sample

def get_sample_id_from_data(sample):
    """Return a string identifier for a dataset sample.

    Uses ``str(sample['id'])`` when an ``'id'`` key exists. Otherwise, when a
    ``'question'`` key exists, returns ``"hash_"`` followed by the first eight
    hex digits of the MD5 of the UTF-8 encoded question. Returns ``None`` when
    neither key is present.
    """
    if 'id' in sample:
        return str(sample['id'])

    if 'question' not in sample:
        return None

    import hashlib
    digest = hashlib.md5(sample['question'].encode('utf-8')).hexdigest()
    return f"hash_{digest[:8]}"

def filter_unprocessed_indices(data, indices, log_dir, k_value):
    """Drop indices whose samples already have a log for ``k_value``.

    Args:
        data: Sequence of sample dicts.
        indices: Candidate indices into ``data``.
        log_dir: Directory holding earlier logs; when falsy, ``indices`` is
            returned untouched.
        k_value: k value the earlier logs must have been recorded with.

    Returns:
        ``indices`` itself when there is no log directory or no matching
        logs; otherwise a new list with the indices still to be processed.
        In that second case indices ``>= len(data)`` are left out.
    """
    if not log_dir:
        return indices

    processed_samples, log_details = scan_existing_logs(log_dir, k_value)
    if not processed_samples:
        print("No existing logs found, will process all samples")
        return indices

    pending = []
    skipped = 0
    sample_count = len(data)

    for idx in indices:
        if idx >= sample_count:
            continue

        sample_id = get_sample_id_from_data(data[idx])
        if not (sample_id and sample_id in processed_samples):
            pending.append(idx)
            continue

        print(f"Skipping sample {idx} (ID: {sample_id}) - already processed")
        skipped += 1

        if sample_id in log_details:
            # Records are sorted newest first by scan_existing_logs.
            latest_log = log_details[sample_id][0]
            latest_time = datetime.datetime.fromtimestamp(latest_log['timestamp'])
            print(f"  Latest log: {os.path.basename(latest_log['log_file'])} ({latest_time.strftime('%Y-%m-%d %H:%M:%S')})")

    print(f"Found {len(processed_samples)} already processed samples (k={k_value})")
    print(f"Skipped {skipped} samples, will process {len(pending)} new samples")

    return pending

class PrintCapture:
    """Context manager that mirrors ``sys.stdout``/``sys.stderr`` into a file.

    While the context is active, everything written to either stream is
    still forwarded to the original stream and additionally appended to
    ``output_file`` (which is truncated on entry). On exit the original
    streams are restored and the file is closed.
    """

    def __init__(self, output_file):
        self.output_file = output_file
        self.original_stdout = None
        self.original_stderr = None
        # Set in __enter__; kept as None until then so __exit__ can test it.
        self.file_obj = None

    def __enter__(self):
        self.original_stdout = sys.stdout
        self.original_stderr = sys.stderr

        class TeeOutput:
            """Write-through proxy: original stream first, then the log file."""

            def __init__(self, file_obj, original_stream):
                self.file_obj = file_obj
                self.original_stream = original_stream

            def write(self, text):
                self.original_stream.write(text)
                self.original_stream.flush()

                # File logging is best-effort: a closed or failing log file
                # must never break console output.
                try:
                    self.file_obj.write(text)
                    self.file_obj.flush()
                except (OSError, ValueError):
                    pass

            def flush(self):
                self.original_stream.flush()
                try:
                    self.file_obj.flush()
                except (OSError, ValueError):
                    pass

            def __getattr__(self, name):
                # Only reached for attributes the tee does not define itself
                # (isatty, encoding, fileno, ...): defer to the real stream.
                if name in ('file_obj', 'original_stream'):
                    # Avoid infinite recursion if the proxy is half-built.
                    raise AttributeError(name)
                return getattr(self.original_stream, name)

        # Open before swapping the streams so a failing open() leaves
        # sys.stdout / sys.stderr untouched.
        self.file_obj = open(self.output_file, 'w', encoding='utf-8')

        sys.stdout = TeeOutput(self.file_obj, self.original_stdout)
        sys.stderr = TeeOutput(self.file_obj, self.original_stderr)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self.original_stdout
        sys.stderr = self.original_stderr

        if self.file_obj is not None:
            self.file_obj.close()

def log_node_result(log_file_path: str, node_name: str, node_result: Any, timestamp: str = None):
    """Append one node's result to the log file as a framed, two-part record.

    The record holds a header (node name, timestamp, result type), the raw
    result serialized as JSON where possible, and a node-specific summary
    produced by the matching ``log_parsed_*`` helper.

    Args:
        log_file_path: File to append to (UTF-8).
        node_name: Name of the graph node; selects the summary helper.
        node_result: Result object: a dict, something exposing
            ``model_dump()``, an object with ``__dict__``, or anything else
            (logged via ``str``).
        timestamp: Header timestamp text; defaults to the current local time
            with microseconds.
    """
    if timestamp is None:
        timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')

    banner = '=' * 80
    dump_options = {'indent': 2, 'ensure_ascii': False, 'default': str}

    with open(log_file_path, "a", encoding="utf-8") as out:
        out.write(f"\n{banner}\n")
        out.write(f"NODE: {node_name}\n")
        out.write(f"TIMESTAMP: {timestamp}\n")
        out.write(f"RESULT TYPE: {type(node_result).__name__}\n")
        out.write(f"{banner}\n\n")

        # Part 1: raw serialization of the result.
        out.write(">>> RAW RESULT (JSON):\n")
        try:
            if hasattr(node_result, 'model_dump'):
                raw_text = json.dumps(node_result.model_dump(), **dump_options)
            elif isinstance(node_result, dict):
                raw_text = json.dumps(node_result, **dump_options)
            elif hasattr(node_result, '__dict__'):
                raw_text = json.dumps(node_result.__dict__, **dump_options)
            else:
                raw_text = str(node_result)

            out.write(raw_text)

        except Exception as e:
            # Fall back to a truncated plain-string dump of the result.
            as_text = str(node_result)
            out.write(f"ERROR SERIALIZING: {e}\n")
            out.write(f"RAW STRING: {as_text[:2000]}{'...' if len(as_text) > 2000 else ''}\n")

        out.write(f"\n\n>>> END RAW RESULT\n\n")

        # Part 2: node-specific summary. Thunks keep helper lookup lazy so
        # only the helper for this node is ever resolved.
        out.write(">>> PARSED INFORMATION:\n")

        def expert_summary():
            return log_parsed_expert_analysis(out, node_result, node_name)

        summarizers = {
            "process_patient_info": lambda: log_parsed_patient_info(out, node_result),
            "generate_k_independent_cot_paths": lambda: log_parsed_cot_paths(out, node_result),
            "process_all_paths_parallel": lambda: log_parsed_parallel_results(out, node_result),
            "select_best_path_by_quality": lambda: log_parsed_path_selection(out, node_result),
            "generate_final_k_path_result": lambda: log_parsed_final_result(out, node_result),
            "domain_expert_analysis": expert_summary,
            "radiologist_analysis": expert_summary,
            "coordinator_decision": lambda: log_parsed_coordinator_decision(out, node_result),
            "web_search": lambda: log_parsed_web_search(out, node_result),
            "diagnose": lambda: log_parsed_diagnosis(out, node_result),
        }

        summarize = summarizers.get(node_name)
        if summarize is None:
            out.write("Generic node - no specific parser available\n")
        else:
            summarize()

        out.write(f"\n>>> END PARSED INFORMATION\n")
        out.write(f"{banner}\n\n")


def log_parsed_patient_info(log_file, result):
    """Write a short summary of a patient-info node result.

    Handles both dict results (keys ``patient_info``, ``image_contents``,
    ``diagnostic_options``) and objects exposing the same names as
    attributes. Option texts are previewed up to 100 characters.
    """
    emit = log_file.write

    if not isinstance(result, dict):
        # Attribute-style result object.
        try:
            if hasattr(result, 'patient_info'):
                info = result.patient_info
                if hasattr(info, 'clinical_text'):
                    emit(f"Clinical Text Length: {len(info.clinical_text)}\n")
                if hasattr(info, 'medical_images'):
                    emit(f"Medical Images: {len(info.medical_images)}\n")

            if hasattr(result, 'image_contents'):
                emit(f"Image Contents: {len(result.image_contents)}\n")

            if hasattr(result, 'diagnostic_options'):
                option_map = result.diagnostic_options
                emit(f"Diagnostic Options: {list(option_map.keys())}\n")
                for name, text in option_map.items():
                    emit(f"  {name}: {text[:100]}{'...' if len(text) > 100 else ''}\n")
        except AttributeError as e:
            emit(f"Error parsing patient info: {e}\n")
        return

    # Mapping-style result.
    if 'patient_info' in result:
        info = result['patient_info']
        if isinstance(info, dict):
            emit(f"Clinical Text Length: {len(info.get('clinical_text', ''))}\n")
            emit(f"Medical Images: {len(info.get('medical_images', []))}\n")
        else:
            try:
                text_value = getattr(info, 'clinical_text', '')
                image_list = getattr(info, 'medical_images', [])
                emit(f"Clinical Text Length: {len(text_value)}\n")
                emit(f"Medical Images: {len(image_list)}\n")
            except AttributeError:
                emit(f"Patient Info: {type(info).__name__} object\n")

    if 'image_contents' in result:
        emit(f"Image Contents: {len(result['image_contents'])}\n")

    if 'diagnostic_options' in result:
        option_map = result['diagnostic_options']
        emit(f"Diagnostic Options: {list(option_map.keys())}\n")
        for name, text in option_map.items():
            emit(f"  {name}: {text[:100]}{'...' if len(text) > 100 else ''}\n")


def log_parsed_cot_paths(log_file, result):
    """Summarize generated CoT paths: count, then per-path option, confidence,
    emphasis and a 200-character reasoning preview.

    Writes nothing unless ``result`` is a dict containing ``'cot_paths'``.
    """
    if not (isinstance(result, dict) and 'cot_paths' in result):
        return

    paths = result['cot_paths']
    log_file.write(f"Total CoT Paths Generated: {len(paths)}\n\n")

    labelled_keys = (
        ("Selected Option", 'selected_option'),
        ("Confidence Level", 'confidence_level'),
        ("Strategy Emphasis", 'emphasis'),
    )

    for position, path in enumerate(paths, start=1):
        # Fall back to the 1-based position when the path carries no id.
        log_file.write(f"Path {path.get('path_id', position)}:\n")
        for label, key in labelled_keys:
            log_file.write(f"  {label}: {path.get(key, 'N/A')}\n")

        reasoning = path.get('reasoning', '')
        if reasoning:
            log_file.write(f"  Reasoning Preview: {reasoning[:200]}{'...' if len(reasoning) > 200 else ''}\n")

        log_file.write("\n")

def _read_path_field(item, name, default):
    """Return ``item[name]`` for dicts or ``item.name`` for objects, else *default*."""
    if isinstance(item, dict):
        return item.get(name, default)
    return getattr(item, name, default)


def _write_issue_preview(log_file, label, items):
    """Write the item count for *label* plus a preview of the first three items."""
    if not items:
        return
    log_file.write(f"        {label}: {len(items)} items\n")
    for entry in items[:3]:  # Show first 3 items
        log_file.write(f"          - {entry[:100]}{'...' if len(entry) > 100 else ''}\n")
    if len(items) > 3:
        log_file.write(f"          ... and {len(items) - 3} more\n")


def _write_hallucination_attempt(log_file, index, log_entry):
    """Write one hallucination-detection attempt (status, actions, issue previews)."""
    attempt = log_entry.get('attempt', index + 1)
    detected = log_entry.get('hallucination_detected', 'NO')
    recommendation = log_entry.get('recommendation', 'ACCEPT_AS_IS')
    specific_actions = log_entry.get('specific_actions', [])
    severity = log_entry.get('severity', 'LOW')
    confidence = log_entry.get('confidence_score', 0)
    total_issues = log_entry.get('total_issues', 0)
    improvement_applied = log_entry.get('improvement_guidance_applied', False)
    previous_warnings = log_entry.get('previous_warnings_count', 0)

    status_icon = "❌" if detected == "YES" else "✅"
    improvement_icon = "🔧" if improvement_applied else ""
    recommendation_display = recommendation.replace('_', ' ')

    log_file.write(f"      Attempt {attempt}: {status_icon} {detected} {improvement_icon} "
                   f"(Severity: {severity}, Issues: {total_issues}, "
                   f"Confidence: {confidence}%)\n")
    log_file.write(f"        Recommendation: {recommendation_display}\n")

    if improvement_applied:
        log_file.write(f"        Improvement Guidance Applied: {previous_warnings} previous warnings addressed\n")

    if specific_actions:
        log_file.write(f"        Specific Actions Suggested:\n")
        for action in specific_actions:
            log_file.write(f"          • {action}\n")

    _write_issue_preview(log_file, "Fabricated Information", log_entry.get('fabricated_information', []))
    _write_issue_preview(log_file, "Unsupported Claims", log_entry.get('unsupported_claims', []))
    _write_issue_preview(log_file, "Contradictions", log_entry.get('contradictions', []))


def log_parsed_parallel_results(log_file, result):
    """Summarize the per-path results of the parallel path-processing node.

    ``result`` may be a dict with a ``'path_results'`` key or an object with
    a ``path_results`` attribute; each path result may likewise be a dict or
    an attribute-style object. For every path this writes the headline
    fields, the fact-check counters (with per-check issue totals when
    ``detailed_checks`` is present), the hallucination-detection counters
    with per-attempt details, the improvement success rate when available,
    and a multimodal summary when any attempt was flagged
    ``multimodal_check``.

    A fact-check or hallucination summary that is missing, ``None`` or not a
    dict is treated as empty for dict and object items alike, and an
    analysis type whose logs are ``None`` is reported as having no logs
    instead of raising.
    """
    path_results = []

    if isinstance(result, dict) and 'path_results' in result:
        path_results = result['path_results']
    elif hasattr(result, 'path_results'):
        path_results = result.path_results

    if not path_results:
        log_file.write("No path results found\n")
        return

    log_file.write(f"Paths Processed: {len(path_results)}\n")
    completed_count = sum(
        1 for p in path_results if _read_path_field(p, 'processing_completed', False)
    )
    log_file.write(f"Completed Successfully: {completed_count}\n\n")

    for result_item in path_results:
        # Nested summaries are only trusted when they are real dicts.
        fact_check = _read_path_field(result_item, 'comprehensive_fact_check', {})
        if not isinstance(fact_check, dict):
            fact_check = {}
        hallucination_summary = _read_path_field(result_item, 'hallucination_detection_summary', {})
        if not isinstance(hallucination_summary, dict):
            hallucination_summary = {}
        detailed_hallucination_logs = hallucination_summary.get('detailed_logs', {})

        log_file.write(f"Path {_read_path_field(result_item, 'path_id', 'Unknown')}:\n")
        log_file.write(f"  Initial CoT Option: {_read_path_field(result_item, 'initial_cot_option', 'N/A')}\n")
        log_file.write(f"  Final Predicted Option: {_read_path_field(result_item, 'final_predicted_option', 'N/A')}\n")
        log_file.write(f"  Quality Score: {_read_path_field(result_item, 'quality_score', 0)}/100\n")
        log_file.write(f"  Reliability: {_read_path_field(result_item, 'reliability', 'Unknown')}\n")
        log_file.write(f"  Is Correct: {_read_path_field(result_item, 'is_correct', False)}\n")
        log_file.write(f"  Processing Completed: {_read_path_field(result_item, 'processing_completed', False)}\n")

        # Fact check counters
        log_file.write(f"  === Fact Check Results ===\n")
        log_file.write(f"  Total Checks: {fact_check.get('total_checks', 0)}\n")
        log_file.write(f"  Passed: {fact_check.get('passed_checks', 0)}\n")
        log_file.write(f"  Failed: {fact_check.get('failed_checks', 0)}\n")
        log_file.write(f"  Fact Check Score: {fact_check.get('overall_fact_check_score', 0):.1f}%\n")

        # Per-check issue totals, when the detailed breakdown is present
        if 'detailed_checks' in fact_check:
            log_file.write(f"  Detailed Checks:\n")
            for check_name, check_result in fact_check['detailed_checks'].items():
                if check_result and isinstance(check_result, dict):
                    result_status = check_result.get('result', 'UNKNOWN')
                    total_issues = (
                        len(check_result.get('consistency_issues', []))
                        + len(check_result.get('accuracy_concerns', []))
                        + len(check_result.get('missed_information', []))
                    )
                    log_file.write(f"    {check_name}: {result_status} (Issues: {total_issues})\n")

        # Hallucination detection counters
        log_file.write(f"  === Hallucination Detection Results ===\n")
        log_file.write(f"  Total Analysis Attempts: {hallucination_summary.get('total_analysis_attempts', 0)}\n")
        log_file.write(f"  Total Hallucination Issues: {hallucination_summary.get('total_hallucination_issues', 0)}\n")
        log_file.write(f"  High Severity Issues: {hallucination_summary.get('high_severity_count', 0)}\n")

        # Per-analysis-type attempt details
        if detailed_hallucination_logs:
            log_file.write(f"  Detailed Hallucination Logs by Analysis:\n")
            for analysis_type, logs in detailed_hallucination_logs.items():
                title = analysis_type.replace('_', ' ').title()
                if logs:
                    log_file.write(f"    {title}:\n")
                    for i, log_entry in enumerate(logs):
                        _write_hallucination_attempt(log_file, i, log_entry)
                else:
                    log_file.write(f"    {title}: No logs available\n")
        else:
            log_file.write(f"    No detailed hallucination logs available\n")

        # Improvement success rate, only when the item actually carries it
        has_success_rate = hasattr(result_item, 'improvement_success_rate') or (
            isinstance(result_item, dict) and 'improvement_success_rate' in result_item
        )
        if has_success_rate:
            success_rate = _read_path_field(result_item, 'improvement_success_rate', 0)
            log_file.write(f"  === Improvement Success Rate ===\n")
            log_file.write(f"  Success Rate: {success_rate:.1%}\n")

        # Multimodal summary across all attempts of all analysis types
        if detailed_hallucination_logs:
            multimodal_count = 0
            total_image_checks = 0
            for logs in detailed_hallucination_logs.values():
                # `logs` may be None/empty for an analysis type with no attempts.
                for log_entry in logs or []:
                    if log_entry.get('multimodal_check', False):
                        multimodal_count += 1
                    total_image_checks += log_entry.get('image_count', 0)

            if multimodal_count > 0:
                log_file.write(f"  === Multimodal Analysis Summary ===\n")
                log_file.write(f"  Multimodal checks performed: {multimodal_count}\n")
                log_file.write(f"  Total images analyzed: {total_image_checks}\n")

        log_file.write("\n")

def log_parsed_path_selection(log_file, result):
    """Summarize best-path selection: chosen id, best-path details, and a
    one-line verdict for every evaluated path.

    Writes nothing when ``result`` is not a dict.
    """
    if not isinstance(result, dict):
        return

    log_file.write(f"Selected Path ID: {result.get('selected_path_id', 'Unknown')}\n")

    if 'best_path' in result:
        chosen = result['best_path']
        detail_rows = (
            ("Path ID", 'path_id', 'Unknown', ""),
            ("Quality Score", 'quality_score', 0, "/100"),
            ("Final Option", 'final_predicted_option', 'N/A', ""),
            ("Reliability", 'reliability', 'Unknown', ""),
            ("Is Correct", 'is_correct', False, ""),
        )
        log_file.write(f"\nBest Path Details:\n")
        for label, key, fallback, suffix in detail_rows:
            log_file.write(f"  {label}: {chosen.get(key, fallback)}{suffix}\n")

    if 'evaluated_paths' in result:
        log_file.write(f"\nAll Evaluated Paths:\n")
        for position, candidate in enumerate(result['evaluated_paths'], start=1):
            path_id = candidate.get('path_id', position)
            quality = candidate.get('quality_score', 0)
            option = candidate.get('final_predicted_option', 'N/A')
            if candidate.get('is_correct', False):
                mark = "✓"
            else:
                mark = "✗"
            log_file.write(f"  Path {path_id}: Option {option}, Quality {quality}/100, {mark}\n")


def log_parsed_final_result(log_file, result):
    """Summarize the final K-path result: selected option, quality metrics and
    a 300-character preview of the final diagnosis.

    Writes nothing when ``result`` is not a dict.
    """
    if not isinstance(result, dict):
        return

    log_file.write(f"Final Selected Option: {result.get('selected_option', 'N/A')}\n")

    if 'quality_metrics' in result:
        metrics = result['quality_metrics']
        metric_rows = (
            ("Selected Path ID", 'selected_path_id', 'Unknown', ""),
            ("Quality Score", 'quality_score', 0, "/100"),
            ("Reliability", 'reliability', 'Unknown', ""),
            ("Is Correct", 'is_correct', False, ""),
            ("Total Paths", 'total_paths', 0, ""),
            ("Completed Paths", 'completed_paths', 0, ""),
        )
        log_file.write(f"\nQuality Metrics:\n")
        for label, key, fallback, suffix in metric_rows:
            log_file.write(f"  {label}: {metrics.get(key, fallback)}{suffix}\n")

    if 'final_diagnosis' in result:
        diagnosis_text = result['final_diagnosis']
        log_file.write(f"\nFinal Diagnosis Preview: {diagnosis_text[:300]}{'...' if len(diagnosis_text) > 300 else ''}\n")


def log_parsed_expert_analysis(log_file, result, expert_type):
    """Write size counts for the state produced by an expert-analysis node.

    For ``"domain_expert_analysis"`` the counts come from
    ``result['domain_expert_state']``; for ``"radiologist_analysis"`` from
    ``result['radiologist_state']``. Anything else writes nothing.
    """
    if not isinstance(result, dict):
        return

    if expert_type == "domain_expert_analysis" and 'domain_expert_state' in result:
        state = result['domain_expert_state']
        counted_fields = (
            ("Extracted Symptoms", 'extracted_symptoms', []),
            ("Clinical Findings Systems", 'clinical_findings', {}),
            ("Family History Items", 'family_medical_history', []),
        )
    elif expert_type == "radiologist_analysis" and 'radiologist_state' in result:
        state = result['radiologist_state']
        counted_fields = (
            ("Image Findings", 'image_findings', []),
            ("Limitations", 'limitations', []),
        )
    else:
        return

    for label, key, fallback in counted_fields:
        log_file.write(f"{label}: {len(state.get(key, fallback))}\n")


def log_parsed_coordinator_decision(log_file, result):
    """Write the observation count and a preview of the first five
    observations (100 characters each) from a coordinator decision.

    Writes nothing when ``result`` is not a dict.
    """
    if not isinstance(result, dict):
        return

    observations = result.get('all_observations', [])
    log_file.write(f"Total Observations Collected: {len(observations)}\n")
    if not observations:
        return

    log_file.write("Key Observations:\n")
    for number, obs in enumerate(observations[:5], start=1):  # Show first 5
        log_file.write(f"  {number}. {obs[:100]}{'...' if len(obs) > 100 else ''}\n")

    if len(observations) > 5:
        log_file.write(f"  ... and {len(observations) - 5} more\n")


def log_parsed_web_search(log_file, result):
    """Write search count, number of searched combinations and a
    300-character preview of the research findings.

    Writes nothing unless ``result`` is a dict containing
    ``'web_search_state'``.
    """
    if not (isinstance(result, dict) and 'web_search_state' in result):
        return

    state = result['web_search_state']
    combination_count = len(state.get('searched_combinations', []))
    log_file.write(f"Search Count: {state.get('search_count', 0)}\n")
    log_file.write(f"Combinations Searched: {combination_count}\n")

    findings = state.get('research_findings', '')
    if not findings:
        return
    log_file.write(f"Research Findings Preview: {findings[:300]}{'...' if len(findings) > 300 else ''}\n")


def log_parsed_diagnosis(log_file, result):
    """Write the selected option and a 300-character preview of the
    reasoning path from a diagnosis result.

    Writes nothing unless ``result`` is a dict containing
    ``'diagnoser_state'``.
    """
    if not (isinstance(result, dict) and 'diagnoser_state' in result):
        return

    state = result['diagnoser_state']
    log_file.write(f"Selected Option: {state.get('selected_option', 'N/A')}\n")

    reasoning = state.get('reasoning_path', '')
    if not reasoning:
        return
    log_file.write(f"Reasoning Preview: {reasoning[:300]}{'...' if len(reasoning) > 300 else ''}\n")

def format_medxpert_sample(sample, images_dir):
    """Format a MedXpertQA sample for the diagnosis system, loading its images.

    Args:
        sample: Sample dict. Reads ``'question'`` and ``'options'``
            (required) plus ``'images'`` and ``'answer'`` (optional).
        images_dir: Directory the names in ``sample['images']`` are resolved
            against.

    Returns:
        A tuple ``(clinical_text, image_contents, options, correct_answer)``:
        the markdown case text (question without its "Answer Choices:" part,
        plus a list of image file names), a list of ``image_url`` content
        dicts holding base64 data URLs for every image that could be read,
        ``sample['options']`` unchanged, and ``sample['answer']`` or ``''``.

    Raises:
        KeyError: If ``'question'`` or ``'options'`` is missing.
    """
    # Function-scope import keeps this block self-contained.
    import mimetypes

    question_full = sample['question']

    # Keep only the text in front of the answer choices, if they are inlined.
    if "Answer Choices:" in question_full:
        question_text = question_full.split("Answer Choices:")[0].strip()
    else:
        question_text = question_full

    clinical_text = f"## Patient Case\n\n{question_text}\n\n"
    image_contents = []

    if 'images' in sample and sample['images']:
        clinical_text += "## Imaging Information\n\n"

        for img_name in sample['images']:
            clinical_text += f"- Image file: {img_name}\n"

            image_path = os.path.join(images_dir, img_name)
            if not os.path.exists(image_path):
                print(f"Image file not found: {image_path}")
                continue

            try:
                with open(image_path, "rb") as img_file:
                    img_data = base64.b64encode(img_file.read()).decode("utf-8")

                # Derive the MIME type from the extension so formats such as
                # BMP or WebP are not mislabelled; unknown or non-image
                # extensions keep the historical JPEG default.
                guessed_type, _ = mimetypes.guess_type(img_name.lower())
                if guessed_type and guessed_type.startswith("image/"):
                    mime_type = guessed_type
                else:
                    mime_type = "image/jpeg"

                image_contents.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:{mime_type};base64,{img_data}"}
                })

                print(f"Successfully loaded image: {img_name}")
            except Exception as e:
                # A single unreadable image must not abort the whole sample.
                print(f"Error loading image {img_name}: {e}")

    correct_answer = sample.get('answer', '')

    return clinical_text, image_contents, sample['options'], correct_answer

# Lower-cased phrases whose presence marks a node message as a candidate final diagnosis.
_DIAGNOSIS_INDICATORS = (
    'final diagnosis', 'diagnosis:', 'selected option', 'option d', 'option c', 'option b', 'option a', 'option e',
    'the patient', 'symptoms and clinical', 'imaging findings', 'diagnostic reasoning'
)


def _find_candidate_diagnosis(node_output):
    """Return the last message's text if it looks like a diagnosis, otherwise None.

    Accepts either an object exposing a ``messages`` attribute or a dict with a
    ``messages`` key. Non-string message content (e.g. multimodal content
    lists) is ignored instead of raising on ``.lower()``.
    """
    if hasattr(node_output, 'messages') and node_output.messages:
        messages = node_output.messages
    elif isinstance(node_output, dict) and node_output.get('messages'):
        messages = node_output['messages']
    else:
        return None

    content = getattr(messages[-1], 'content', None)
    if not content or not isinstance(content, str):
        return None

    lowered = content.lower()
    if any(indicator in lowered for indicator in _DIAGNOSIS_INDICATORS):
        return content
    return None


def _preview_text(text, limit=300):
    """Return ``text`` as a string, truncated to ``limit`` characters plus '...'."""
    text = str(text)
    return f"{text[:limit]}..." if len(text) > limit else text


async def run_multi_cot_diagnosis(jsonl_file, images_dir, sample_index=0, log_dir=None, attempt_num=1, k=3):
    """Run the multi-CoT diagnosis graph on one sample and record the outcome.

    Loads the sample at ``sample_index`` from ``jsonl_file``, streams it through
    ``graph``, mirrors every node update to stdout and to a per-run log file,
    and finally builds a result record.

    Args:
        jsonl_file: Path of the JSONL dataset.
        images_dir: Directory that holds the images referenced by the sample.
        sample_index: Zero-based index of the sample inside the dataset.
        log_dir: Directory for the log file; defaults to a ``logs`` folder next
            to this script.
        attempt_num: Attempt counter, used in log names and messages.
        k: Number of CoT rounds, passed to the graph as ``cot_rounds``.

    Returns:
        A dict describing the run (ids, predicted/correct answer, quality
        metrics, log path, ...), or ``None`` when the dataset is empty or the
        index is out of range.
    """
    print(f"--- Running Multi-CoT Medical Diagnosis System with Sample {sample_index} (Attempt {attempt_num}, k={k}) ---")

    data = load_jsonl_data(jsonl_file)
    if not data or sample_index >= len(data):
        print(f"Error: No valid data found or sample index {sample_index} out of range")
        return None

    sample = data[sample_index]

    clinical_text, image_contents, options, correct_answer = format_medxpert_sample(sample, images_dir)

    inputs = {
        "messages": [HumanMessage(content=[{"type": "text", "text": clinical_text}] + image_contents)],
        "clinical_text": clinical_text,
        "image_contents": image_contents,
        "diagnostic_options": options,
        "correct_answer": correct_answer
    }

    thread_id = str(uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id,
            "max_consultation_rounds": 1,
            "cot_rounds": k
        }
    }

    file_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = f"multi_cot_log_{file_stamp}_sample_{sample['id']}_attempt_{attempt_num}_k{k}.txt"

    if log_dir is None:
        log_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
    else:
        log_root = log_dir

    os.makedirs(log_root, exist_ok=True)
    log_filepath = os.path.join(log_root, log_filename)

    with open(log_filepath, "w", encoding="utf-8") as log_file:
        log_file.write(f"====== Multi-CoT MedXpertQA Diagnosis Record - Sample ID: {sample['id']} - Attempt {attempt_num} - k={k} ======\n\n")
        log_file.write(f"Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        log_file.write(f"Thread ID: {thread_id}\n\n")
        log_file.write("===== Initial Input Information =====\n")
        log_file.write(f"Question: {clinical_text}\n\n")
        log_file.write(f"Image Count: {len(image_contents)}\n\n")
        log_file.write("Options (provided to system at final stage):\n")
        for option_key, option_value in options.items():
            log_file.write(f"  - {option_key}: {option_value}\n")
        if 'answer' in sample:
            log_file.write(f"\nCorrect Answer: {sample['answer']}\n")
        log_file.write(f"\nCoT Rounds: {k}\n")
        log_file.write("\n===== Process Execution Log =====\n\n")

    print(f"\n=== Processing Multi-CoT MedXpertQA Sample (ID: {sample['id']}) - Attempt {attempt_num} - k={k} ===\n")
    print(f"\nQuestion: {clinical_text}\n")
    print(f"Log file: {log_filepath}\n")

    final_diagnosis = None
    path_taken = "unknown"
    quality_metrics = {}
    all_node_results = {}

    async for event in graph.astream(inputs, config=config):
        for key, value in event.items():
            event_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')

            all_node_results[key] = {
                'node_name': key,
                'raw_result': value,
                'timestamp': event_time
            }

            log_node_result(log_filepath, key, value, event_time)

            print(f"--- Node: {key} ---")

            # Keep the first message that looks like a diagnosis as a fallback
            # in case the final node does not provide one explicitly.
            if not final_diagnosis:
                candidate = _find_candidate_diagnosis(value)
                if candidate is not None:
                    final_diagnosis = candidate
                    print(f"  Found potential final diagnosis in node '{key}': {candidate[:200]}...")

            if key == "generate_k_independent_cot_paths":
                if isinstance(value, dict):
                    print("  K-Path CoT Generation:")
                    if 'cot_paths' in value:
                        cot_paths = value['cot_paths']
                        print(f"    Generated Paths: {len(cot_paths)}")
                        for path in cot_paths:
                            path_id = path.get('path_id', 'Unknown')
                            selected_option = path.get('selected_option', 'N/A')
                            confidence = path.get('confidence_level', 'N/A')
                            print(f"    Path {path_id}: Option {selected_option} (Confidence: {confidence})")
                    if 'total_paths' in value:
                        print(f"    Total Paths: {value['total_paths']}")

            elif key == "process_all_paths_parallel":
                if isinstance(value, dict) and 'path_results' in value:
                    path_results = value['path_results']
                    print("  Parallel Path Processing:")
                    print(f"    Paths Processed: {len(path_results)}")
                    completed = sum(1 for p in path_results if p.get('processing_completed', False))
                    print(f"    Successfully Completed: {completed}")

            elif key == "select_best_path_by_quality":
                if isinstance(value, dict):
                    print("  Best Path Selection:")
                    if 'selected_path_id' in value:
                        print(f"    Selected Path: {value['selected_path_id']}")
                    if 'best_path' in value:
                        best_path = value['best_path']
                        print(f"    Quality Score: {best_path.get('quality_score', 0)}/100")
                        print(f"    Final Option: {best_path.get('final_predicted_option', 'N/A')}")

            elif key == "generate_final_k_path_result":
                print("  K-Path Final Analysis Complete:")

                if isinstance(value, dict):
                    if 'final_diagnosis' in value:
                        if final_diagnosis is None:
                            final_diagnosis = value['final_diagnosis']
                    elif 'messages' in value and value['messages']:
                        if final_diagnosis is None:
                            final_diagnosis = value['messages'][-1].content
                    # Always print with the label; skip when nothing was produced.
                    if final_diagnosis is not None and ('final_diagnosis' in value or value.get('messages')):
                        print(f"  Final K-Path Diagnosis: {_preview_text(final_diagnosis)}")

                    if 'quality_metrics' in value:
                        quality = value['quality_metrics']
                        quality_metrics = quality
                        print(f"  Selected Path: {quality.get('selected_path_id', 'Unknown')}/{quality.get('total_paths', 0)}")
                        print(f"  Quality Score: {quality.get('quality_score', 0)}/100")
                        print(f"  Reliability: {quality.get('reliability', 'Unknown')}")

                    if 'selected_option' in value:
                        print(f"  Selected Option: {value['selected_option']}")

                path_taken = "k_path"

            else:
                if isinstance(value, dict):
                    print("  State Update: {")
                    # NOTE: distinct loop names on purpose - reusing ``k`` here would
                    # overwrite the CoT-round parameter used later in the result.
                    for state_key, state_value in value.items():
                        if state_key in ['web_search_state', 'research_findings', 'search_results'] or 'hallucination' in state_key:
                            print(f"    {state_key}: [Content omitted for brevity]")
                        elif isinstance(state_value, str) and len(state_value) > 1000:
                            print(f"    {state_key}: {state_value[:1000]}... [Truncated]")
                        elif isinstance(state_value, (list, dict)) and len(str(state_value)) > 1000:
                            print(f"    {state_key}: {type(state_value).__name__}[Length: {len(state_value) if hasattr(state_value, '__len__') else 'Unknown'}]")
                        else:
                            print(f"    {state_key}: {state_value}")
                    print("  }")
                else:
                    value_str = str(value)
                    if len(value_str) > 1000:
                        value_str = value_str[:1000] + "... [Truncated]"
                    print(f"  State Update: {value_str}")
            print("\n")

    with open(log_filepath, "a", encoding="utf-8") as log_file:
        log_file.write("\n" + "="*100 + "\n")
        log_file.write("COMPLETE NODE EXECUTION SUMMARY\n")
        log_file.write("="*100 + "\n\n")

        for node_name, node_data in all_node_results.items():
            log_file.write(f"Node: {node_name}\n")
            log_file.write(f"Timestamp: {node_data['timestamp']}\n")
            log_file.write(f"Result Type: {type(node_data['raw_result']).__name__}\n")

            node_result = node_data['raw_result']
            if isinstance(node_result, dict):
                if 'cot_paths' in node_result and node_result['cot_paths'] is not None:
                    log_file.write(f"CoT Paths Generated: {len(node_result['cot_paths'])}\n")
                elif 'path_results' in node_result and node_result['path_results'] is not None:
                    log_file.write(f"Paths Processed: {len(node_result['path_results'])}\n")
                elif 'selected_path_id' in node_result and node_result['selected_path_id'] is not None:
                    log_file.write(f"Selected Path: {node_result['selected_path_id']}\n")
                elif 'messages' in node_result and node_result['messages'] is not None:
                    log_file.write(f"Messages Generated: {len(node_result['messages'])}\n")

            log_file.write("-" * 50 + "\n")

    print(f"--- Multi-CoT Medical Diagnosis System Finished (Attempt {attempt_num}) ---")
    print(f"\nQuestion: {clinical_text}\n")
    print(f"Log file: {log_filepath}\n")

    result = {
        'id': sample['id'],
        'sample_index': sample_index,
        'attempt_num': attempt_num,
        'k_value': k,
        'path_taken': path_taken,
        'log_file': log_filepath,
        'final_diagnosis': None,
        'predicted_answer': None,
        'correct_answer': sample.get('answer'),
        'is_correct': False,
        'options': options,
        'medical_task': sample.get('medical_task', 'Unknown'),
        'body_system': sample.get('body_system', 'Unknown'),
        'question_type': sample.get('question_type', 'Unknown'),
        'quality_metrics': quality_metrics
    }

    if final_diagnosis:
        print(f"\nFinal Diagnosis (Attempt {attempt_num}, Path: {path_taken}):")
        print(final_diagnosis)

        result['final_diagnosis'] = final_diagnosis

        answer_choice = None

        if path_taken == "k_path" and quality_metrics:
            path_info = extract_selected_path_info(quality_metrics)

            answer_choice = path_info.get('selected_answer')

            # ``.get`` with defaults: the extractor's failure result may not
            # carry the quality/reliability/correctness fields.
            result['final_diagnosis'] = {
                'selected_path_reasoning': path_info.get('selected_path_reasoning', ''),
                'fact_check_summary': path_info.get('fact_check_summary', ''),
                'selected_path_id': path_info.get('selected_path_id'),
                'path_quality_score': path_info.get('path_quality_score', 0),
                'path_reliability': path_info.get('path_reliability', 'Unknown'),
                'path_correctness': path_info.get('path_correctness', False),
                'path_taken': path_taken,
                'extraction_success': path_info.get('extraction_success', False),
                'selected_answer': answer_choice
            }
            stored = result['final_diagnosis']

            print(f"\nExtracted for JSON storage:")
            print(f"  Selected Path ID: {stored['selected_path_id']}")
            print(f"  Selected Answer: {answer_choice}")
            print(f"  Path Quality Score: {stored['path_quality_score']}/100")
            print(f"  Path Reliability: {stored['path_reliability']}")
            print(f"  Reasoning Length: {len(stored['selected_path_reasoning'])}")

        else:
            result['final_diagnosis'] = final_diagnosis
            answer_choice = extract_answer_from_diagnosis(final_diagnosis, options)

        with open(log_filepath, "a", encoding="utf-8") as log_file:
            log_file.write("===== Final Diagnosis Result =====\n\n")
            log_file.write(f"Path taken: {path_taken}\n")
            log_file.write(f"CoT rounds: {k}\n")
            if quality_metrics:
                if path_taken == "k_path":
                    log_file.write(f"Selected path ID: {quality_metrics.get('selected_path_id', 'Unknown')}\n")
                    log_file.write(f"Total paths evaluated: {quality_metrics.get('total_paths', 0)}\n")
                log_file.write(f"Quality score: {quality_metrics.get('quality_score', 0)}/100\n")
                log_file.write(f"Reliability: {quality_metrics.get('reliability', 'Unknown')}\n")

            log_file.write(f"\n=== JSON Storage Information ===\n")
            if isinstance(result['final_diagnosis'], dict):
                log_file.write(f"Selected Path ID: {result['final_diagnosis']['selected_path_id']}\n")
                log_file.write(f"Selected Answer: {result['final_diagnosis']['selected_answer']}\n")
                log_file.write(f"Path Quality Score: {result['final_diagnosis']['path_quality_score']}/100\n")
                log_file.write(f"Path Reliability: {result['final_diagnosis']['path_reliability']}\n")
                log_file.write(f"Selected Path Reasoning Length: {len(result['final_diagnosis']['selected_path_reasoning'])}\n")
                log_file.write(f"Extraction Success: {result['final_diagnosis']['extraction_success']}\n")
                log_file.write(f"\nFact Check Summary:\n{result['final_diagnosis']['fact_check_summary']}\n")
            else:
                log_file.write(f"Full diagnosis stored (length: {len(result['final_diagnosis'])})\n")

            log_file.write(f"\n=== Complete Diagnosis Report ===\n")
            log_file.write(f"{final_diagnosis}\n\n")

            if answer_choice:
                log_file.write(f"System selected answer: {answer_choice}\n")
                if 'answer' in sample:
                    log_file.write(f"Correct answer: {sample['answer']}\n")
                    log_file.write(f"Result: {'Correct' if answer_choice == sample['answer'] else 'Incorrect'}\n")
                    result['is_correct'] = (answer_choice == sample['answer'])
            else:
                log_file.write("System failed to extract clear answer from diagnosis\n")

        result['predicted_answer'] = answer_choice

        if answer_choice:
            print(f"\nExtracted Answer Choice (Attempt {attempt_num}): {answer_choice}")
            if 'answer' in sample:
                print(f"Correct Answer: {sample['answer']}")
                print(f"Result: {'Correct' if answer_choice == sample['answer'] else 'Incorrect'}")
        else:
            print(f"\nCould not extract specific answer choice from diagnosis (Attempt {attempt_num}).")
    else:
        print(f"System failed to generate final diagnosis (Attempt {attempt_num}).")

        with open(log_filepath, "a", encoding="utf-8") as log_file:
            log_file.write("===== Diagnosis Failed =====\n\n")
            log_file.write("System failed to generate final diagnosis result\n")

    return result

def load_jsonl_data(jsonl_file):
    """Read a JSONL file and return its records as a list.

    Each non-blank line is stripped and parsed with ``json.loads``; lines that
    contain only whitespace are skipped.
    """
    with open(jsonl_file, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

def extract_answer_from_diagnosis(diagnosis, options):
    """Extract the selected option letter from a free-text diagnosis.

    The search runs from the most to the least explicit evidence:

    1. explicit statements ("SELECTED OPTION: C", "the answer is B", ...),
    2. weaker phrasing ("B is the correct ...", "(C)", ...),
    3. keyword overlap between the diagnosis and long option texts,
    4. the last standalone option letter in the text (best effort).

    Cue words are matched case-insensitively, but the option letter itself
    must be an upper-case, standalone letter (except directly after
    "selected option:"). Otherwise the English article "a" or a letter inside
    an ordinary word would be reported as an answer. The set of letters is
    taken from ``options`` so questions with more than five choices work.

    Args:
        diagnosis: Diagnosis text produced by the system.
        options: Mapping of option letter to option text.

    Returns:
        The upper-case option letter, or ``None`` if nothing could be extracted.
    """
    import re

    if not diagnosis or not options:
        return None

    if not isinstance(diagnosis, str):
        diagnosis = str(diagnosis)

    # Build mapping between option letters (upper-cased) and option content
    options_upper = {str(k).upper(): v for k, v in options.items()}
    diagnosis_lower = diagnosis.lower()

    print(f"Extracting answer from diagnosis. Available options: {list(options_upper.keys())}")

    # Letter class derived from the actual option keys; fall back to A-E when
    # the keys are not single letters.
    letters = ''.join(sorted(key for key in options_upper if len(key) == 1 and key.isalpha()))
    letter = f"([{letters}])" if letters else "([A-E])"

    # Method 1: clear option identifiers
    high_priority_patterns = [
        rf'(?i:selected\s+option\s*:\s*{letter})\b',  # Standard format (letter may be lower-case here)
        rf'(?i:(?:the\s+)?(?:answer|choice|option)\s+is)\s+{letter}\b',  # "The answer is A"
        rf'(?i:(?:select|choose)\s+(?:option\s+)?){letter}\b',  # "Select A" or "Choose option A"
        rf'^{letter}[.:]',  # "A." or "A:" at line start (not "A 45-year-old ...")
        rf'(?i:diagnosis|option)\s+{letter}\b',  # "diagnosis A" or "option A"
    ]

    # Method 2: other common patterns
    medium_priority_patterns = [
        rf'\b{letter}\s*(?i:is\s+(?:the\s+)?(?:correct|right|best|most\s+likely))',  # "A is the correct answer"
        rf'(?i:correct|right|best|most\s+likely).*?(?i:answer|choice|option).*?\b{letter}\b',  # "correct answer is A"
        rf'\({letter}\)',  # Option in parentheses
        rf'\b{letter}\s*(?:\.|:|\-)\s*(?i:is|was|appears|seems)\b',  # "A. is" or "A: appears"
    ]

    tiers = (
        ('high', high_priority_patterns, re.MULTILINE),
        ('medium', medium_priority_patterns, 0),
    )
    for tier_name, patterns, flags in tiers:
        for pattern in patterns:
            matches = re.findall(pattern, diagnosis, flags)
            if matches:
                extracted_answer = matches[-1].upper()  # Take the last match
                if extracted_answer in options_upper:
                    print(f"Found answer using {tier_name} priority pattern '{pattern}': {extracted_answer}")
                    return extracted_answer

    # Method 3: Content-based matching (lower priority to avoid misidentification)
    option_mentions = {}
    for option_key, option_text in options_upper.items():
        option_text_lower = str(option_text).lower()

        if len(option_text_lower) > 10:  # Only do content matching for long options
            # Words longer than four characters act as keywords; this length
            # filter already drops short function words.
            key_words = [word for word in option_text_lower.split() if len(word) > 4]

            if key_words:
                hit_count = sum(1 for word in key_words if word in diagnosis_lower)
                if hit_count >= len(key_words) * 0.5:  # At least 50% of keywords match
                    option_mentions[option_key] = hit_count

    # If there's content matching, choose the highest matching degree
    if option_mentions:
        best_option = max(option_mentions.items(), key=lambda item: item[1])
        print(f"Found answer using content matching: {best_option[0]} (matches: {best_option[1]})")
        return best_option[0]

    # Method 4: Last resort - the last standalone upper-case option letter.
    low_priority_pattern = rf'\b{letter}\b'
    for match in reversed(re.findall(low_priority_pattern, diagnosis)):
        if match.upper() in options_upper:
            print(f"Found answer using low priority pattern '{low_priority_pattern}': {match.upper()}")
            return match.upper()

    print("Could not extract answer from diagnosis")
    return None

def calculate_multi_cot_accuracy(results):
    """Compute accuracy and breakdown statistics for multi-CoT results.

    A result is "valid" when it carries a predicted answer, either in
    ``predicted_answer`` or in ``final_diagnosis['selected_answer']``. For valid
    results that also have a ``correct_answer``, ``is_correct`` is recomputed
    in place before anything is counted. Returns ``{}`` for empty input.
    """
    if not results:
        return {}

    def _predicted(entry):
        # Prefer the top-level answer, fall back to the stored path answer.
        answer = entry.get('predicted_answer')
        if not answer and isinstance(entry.get('final_diagnosis'), dict):
            answer = entry['final_diagnosis'].get('selected_answer')
        return answer

    def _count_correct(entries):
        return sum(1 for entry in entries if entry.get('is_correct', False))

    def _breakdown(field):
        # Per-category totals over the valid results, keyed by ``field``.
        table = {}
        for entry in valid_results:
            bucket = table.setdefault(entry.get(field, 'Unknown'), {'total': 0, 'correct': 0})
            bucket['total'] += 1
            if entry.get('is_correct', False):
                bucket['correct'] += 1
        for bucket in table.values():
            bucket['accuracy'] = bucket['correct'] / bucket['total'] if bucket['total'] > 0 else 0
        return table

    total_questions = len(results)
    valid_results = [entry for entry in results if _predicted(entry)]
    valid_total = len(valid_results)

    # Refresh the correctness flag from the answers themselves.
    for entry in valid_results:
        expected = entry.get('correct_answer')
        guess = _predicted(entry)
        if guess and expected:
            entry['is_correct'] = (guess == expected)

    correct_predictions = _count_correct(valid_results)

    k_path_results = [entry for entry in results if entry.get('path_taken') == 'k_path']
    k_path_correct = _count_correct(k_path_results)

    # Quality metrics analysis
    quality_scores = []
    for entry in results:
        if entry.get('quality_metrics'):
            quality_scores.append(entry['quality_metrics'].get('quality_score', 0))
        elif isinstance(entry.get('final_diagnosis'), dict):
            quality_scores.append(entry['final_diagnosis'].get('path_quality_score', 0))

    stats = {
        'overall': {
            'total': total_questions,
            'correct': correct_predictions,
            'accuracy': correct_predictions / total_questions if total_questions > 0 else 0
        },
        'valid_only': {
            'total': valid_total,
            'correct': correct_predictions,
            'accuracy': correct_predictions / valid_total if valid_total > 0 else 0
        },
        'path_analysis': {
            'k_path': {
                'count': len(k_path_results),
                'correct': k_path_correct,
                'accuracy': k_path_correct / len(k_path_results) if k_path_results else 0
            }
        },
        'quality_metrics': {
            'average_quality_score': sum(quality_scores) / len(quality_scores) if quality_scores else 0,
            'quality_scores': quality_scores
        }
    }

    stats['by_medical_task'] = _breakdown('medical_task')
    stats['by_body_system'] = _breakdown('body_system')
    stats['by_question_type'] = _breakdown('question_type')

    return stats

def print_multi_cot_statistics(stats):
    """Print a human-readable report of multi-CoT statistics.

    Args:
        stats: Statistics dict with an ``overall`` section and optional
            ``valid_only``, ``path_analysis``, ``quality_metrics`` and
            ``by_*`` breakdown sections. An empty dict (no results) is
            reported as such instead of raising ``KeyError``.
    """
    print("\n" + "="*60)
    print("Multi-CoT Test Results Statistics")
    print("="*60)

    if not stats or 'overall' not in stats:
        print("No statistics available (no results to report).")
        return

    overall = stats['overall']
    print(f"Total samples: {overall['total']}")
    print(f"Overall accuracy: {overall['accuracy']:.2%} ({overall['correct']}/{overall['total']})")

    valid_only = stats.get('valid_only')
    if valid_only and valid_only['total'] < overall['total']:
        print(f"Valid prediction accuracy: {valid_only['accuracy']:.2%} ({valid_only['correct']}/{valid_only['total']})")

    # Path analysis: report every known path kind that has at least one sample.
    path_analysis = stats.get('path_analysis', {})
    if path_analysis:
        print(f"\nPath Analysis:")
        path_labels = (
            ('simple_path', 'Simple path'),
            ('complex_path', 'Complex path'),
            ('k_path', 'K-path'),
        )
        for path_key, label in path_labels:
            path_stats = path_analysis.get(path_key, {})
            if path_stats.get('count', 0) > 0:
                print(f"  {label}: {path_stats['accuracy']:.2%} ({path_stats['correct']}/{path_stats['count']})")

    # Quality metrics
    quality = stats.get('quality_metrics', {})
    if quality.get('average_quality_score', 0) > 0:
        print(f"\nQuality Metrics:")
        print(f"  Average quality score: {quality['average_quality_score']:.1f}/100")

    breakdowns = (
        ('by_medical_task', "\nBy medical task classification:"),
        ('by_body_system', "\nBy body system classification:"),
        ('by_question_type', "\nBy question type classification:"),
    )
    for section_key, heading in breakdowns:
        section = stats.get(section_key)
        if section:
            print(heading)
            for category, category_stats in section.items():
                print(f"  {category}: {category_stats['accuracy']:.2%} ({category_stats['correct']}/{category_stats['total']})")

def extract_selected_path_info(quality_metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize the selected reasoning path recorded in ``quality_metrics``.

    Looks up the entry of ``quality_metrics['path_details']`` whose ``path_id``
    equals ``quality_metrics['selected_path_id']`` and extracts its answer,
    reasoning text and a markdown fact-check summary.

    Args:
        quality_metrics: Quality metrics dict produced by the final graph node.

    Returns:
        A dict that always contains ``selected_path_reasoning``,
        ``fact_check_summary``, ``selected_path_id``, ``selected_answer``,
        ``path_quality_score``, ``path_reliability``, ``path_correctness`` and
        ``extraction_success``. On failure ``extraction_success`` is ``False``
        and ``extraction_error`` explains why; the same keys are present so
        callers can index the result without checking for success first.
    """

    def _failure(path_id, reasoning, fact_summary, error):
        # Same shape as the success result, filled with neutral defaults.
        return {
            'selected_path_reasoning': reasoning,
            'fact_check_summary': fact_summary,
            'selected_path_id': path_id,
            'selected_answer': None,
            'path_quality_score': 0,
            'path_reliability': 'Unknown',
            'path_correctness': False,
            'extraction_error': error,
            'extraction_success': False
        }

    if not quality_metrics:
        return _failure(
            0,
            "No quality metrics available",
            "No fact checking performed",
            "Missing quality_metrics",
        )

    selected_path_id = quality_metrics.get('selected_path_id', 1)
    path_details = quality_metrics.get('path_details') or []

    selected_path_detail = next(
        (
            detail for detail in path_details
            if isinstance(detail, dict) and detail.get('path_id') == selected_path_id
        ),
        None,
    )

    if not selected_path_detail:
        return _failure(
            selected_path_id,
            f"Selected path {selected_path_id} details not found",
            "Fact checking information unavailable",
            f"Path {selected_path_id} not found in path_details",
        )

    # Prefer the final prediction; fall back to the initial CoT choice.
    selected_answer = (
        selected_path_detail.get('final_predicted_option', None)
        or selected_path_detail.get('initial_cot_option', None)
    )

    selected_path_reasoning = (
        selected_path_detail.get('final_diagnosis', '') or
        selected_path_detail.get('initial_reasoning', '') or
        selected_path_detail.get('reasoning', '') or
        "No reasoning available"
    )

    comprehensive_fact_check = selected_path_detail.get('comprehensive_fact_check', {})

    if comprehensive_fact_check:
        total_checks = comprehensive_fact_check.get('total_checks', 0)
        passed_checks = comprehensive_fact_check.get('passed_checks', 0)
        partial_checks = comprehensive_fact_check.get('partial_checks', 0)
        failed_checks = comprehensive_fact_check.get('failed_checks', 0)
        total_consistency_issues = comprehensive_fact_check.get('total_consistency_issues', 0)
        total_accuracy_concerns = comprehensive_fact_check.get('total_accuracy_concerns', 0)

        # The score is formatted with ``:.1f``; a missing or non-numeric value
        # would raise, so coerce it first.
        try:
            fact_check_score = float(comprehensive_fact_check.get('overall_fact_check_score', 0) or 0)
        except (TypeError, ValueError):
            fact_check_score = 0.0

        fact_check_summary = f"""### Fact Checking Summary (Selected Path {selected_path_id})
- **Total Checks Performed**: {total_checks}
- **Passed**: {passed_checks}
- **Partial**: {partial_checks}
- **Failed**: {failed_checks}
- **Overall Fact Check Score**: {fact_check_score:.1f}%
- **Consistency Issues**: {total_consistency_issues}
- **Accuracy Concerns**: {total_accuracy_concerns}"""

        detailed_checks = comprehensive_fact_check.get('detailed_checks', {})
        if detailed_checks and isinstance(detailed_checks, dict):
            fact_check_summary += "\n\n### Detailed Check Results:"
            for check_name, check_result in detailed_checks.items():
                if check_result and isinstance(check_result, dict):
                    result_status = check_result.get('result', 'UNKNOWN')
                    # ``or []`` tolerates keys that are present but set to None.
                    total_issues = (
                        len(check_result.get('consistency_issues') or [])
                        + len(check_result.get('accuracy_concerns') or [])
                        + len(check_result.get('missed_information') or [])
                    )

                    fact_check_summary += f"\n- **{str(check_name).replace('_', ' ').title()}**: {result_status}"
                    if total_issues > 0:
                        fact_check_summary += f" (Issues: {total_issues})"
    else:
        fact_check_summary = "No comprehensive fact checking performed"

    return {
        'selected_path_reasoning': selected_path_reasoning,
        'fact_check_summary': fact_check_summary,
        'selected_path_id': selected_path_id,
        'selected_answer': selected_answer,
        'path_quality_score': selected_path_detail.get('quality_score', 0),
        'path_reliability': selected_path_detail.get('reliability', 'Unknown'),
        'path_correctness': selected_path_detail.get('is_correct', False),
        'extraction_success': True
    }

def parse_args():
    """Build the command-line interface of the multi-CoT runner and parse it.

    Returns:
        argparse.Namespace: parsed options with the attributes
        ``jsonl_file``, ``images_dir``, ``sample_size``, ``sample_index``,
        ``k``, ``output_file``, ``log_dir``, ``seed`` and ``concurrent``.
    """
    # One row per option: (flag, value type, default, help text).
    option_table = (
        ('--jsonl-file', str, None, 'MedXpertQA JSONL file path'),
        ('--images-dir', str, None, 'Directory containing images'),
        ('--sample-size', int, 1, 'Number of samples to process'),
        ('--sample-index', int, None, 'Process specific sample index, random sampling if not specified'),
        ('--k', int, 3, 'Number of CoT reasoning rounds (k value)'),
        ('--output-file', str, "multi_cot_results.json", 'JSON file path to save results'),
        ('--log-dir', str, None, 'Log file save directory path, default is logs folder in current directory'),
        ('--seed', int, 42, 'Random seed for reproducible random sampling'),
        ('--concurrent', int, 1, 'Number of samples to process concurrently, default 1 (serial)'),
    )

    cli = argparse.ArgumentParser(description='Run multi-CoT medical diagnosis system for MedXpertQA samples')
    for flag, value_type, fallback, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)

    return cli.parse_args()


async def process_multi_cot_sample_batch(indices, args):
    """Run the multi-CoT diagnosis for every index in ``indices``, in order.

    Args:
        indices: iterable of sample indices to process.
        args: parsed CLI namespace; ``jsonl_file``, ``images_dir``,
            ``log_dir`` and ``k`` are read from it.

    Returns:
        list: the truthy outcomes of ``run_multi_cot_diagnosis``, in the
        order of ``indices``; falsy outcomes are dropped.
    """
    # Each sample is awaited before the next one starts (no intra-batch
    # parallelism).
    outcomes = [
        await run_multi_cot_diagnosis(args.jsonl_file, args.images_dir, sample_idx, args.log_dir, 1, args.k)
        for sample_idx in indices
    ]
    return [outcome for outcome in outcomes if outcome]

def _split_indices_into_batches(indices, num_batches):
    """Split ``indices`` into at most ``num_batches`` contiguous batches."""
    if not indices:
        return []
    # Ceiling division: floor division leaves a remainder batch and would
    # therefore start more tasks than the requested concurrency.
    batch_size = max(1, -(-len(indices) // max(1, num_batches)))
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


def _write_json_atomically(payload, output_file):
    """Write ``payload`` as JSON to ``output_file`` via a temp file + rename."""
    # An interrupted in-place write would truncate the results file; writing
    # next to it and renaming keeps the previous file intact until the new one
    # is complete.
    tmp_file = f"{output_file}.tmp"
    with open(tmp_file, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    os.replace(tmp_file, output_file)


async def main():
    """Command-line entry point of the multi-CoT evaluation run.

    Steps:
        1. Parse the CLI options and seed ``random``.
        2. Load the JSONL questions and choose the sample indices (either the
           single ``--sample-index`` or a random sample of ``--sample-size``).
        3. Resolve the log directory and keep only the indices returned by
           ``filter_unprocessed_indices``; if none remain, print the
           statistics stored in the output file (when present) and return.
        4. Process the remaining samples serially or in concurrent batches,
           saving merged results after every sample / finished batch.
        5. Write the final results (with timing metadata) and print the
           final statistics.

    In concurrent mode results are collected in batch completion order.
    """
    args = parse_args()

    # Set random seed for reproducible results
    random.seed(args.seed)

    # Load data
    print(f"Loading data file: {args.jsonl_file}")
    data = load_jsonl_data(args.jsonl_file)
    print(f"Loaded {len(data)} questions")

    print(f"Using multi-CoT system with k={args.k} rounds")
    test_mode = f"multi_cot_k{args.k}"

    # Determine sample indices to process
    if args.sample_index is not None:
        # Process single specific sample
        initial_indices = [args.sample_index]
        print(f"Will process specified single sample, index: {args.sample_index}")
    else:
        # Random sampling
        sample_size = min(args.sample_size, len(data))
        initial_indices = random.sample(range(len(data)), sample_size)
        print(f"Will randomly process {sample_size} samples, random seed: {args.seed}")

    if args.log_dir is None:
        log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
    else:
        log_dir = args.log_dir
    # The concurrent path hands ``args`` to the batch worker, which reads
    # ``args.log_dir``; store the resolved directory so both processing modes
    # log to the same place that is scanned for already-processed samples.
    args.log_dir = log_dir

    os.makedirs(log_dir, exist_ok=True)

    print(f"\nChecking for existing logs in: {log_dir}")
    indices = filter_unprocessed_indices(data, initial_indices, log_dir, args.k)
    if not indices:
        print("All specified samples have already been processed!")

        if os.path.exists(args.output_file):
            print(f"Loading existing results from: {args.output_file}")
            try:
                with open(args.output_file, 'r', encoding='utf-8') as f:
                    existing_data = json.load(f)
            except (OSError, ValueError) as e:
                # Unreadable or corrupt file: nothing to report, but do not crash.
                print(f"Warning: Could not load existing results: {e}")
            else:
                if isinstance(existing_data, dict) and 'statistics' in existing_data:
                    print_multi_cot_statistics(existing_data['statistics'])

        return

    all_results = []
    start_time = datetime.datetime.now()
    print(f"Start processing time: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

    # Results from a previous run are reused only when they were produced
    # with the same k (same test mode).
    existing_results = []
    if os.path.exists(args.output_file):
        try:
            with open(args.output_file, 'r', encoding='utf-8') as f:
                existing_data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Warning: Could not load existing results: {e}")
        else:
            if (
                isinstance(existing_data, dict)
                and isinstance(existing_data.get('results'), list)
                and existing_data.get('test_mode') == test_mode
            ):
                existing_results = existing_data['results']
                print(f"Loaded {len(existing_results)} existing results from output file")

    def build_output(current_results, timing_metadata):
        """Merge old and new results into the output document; return (document, statistics)."""
        combined_results = existing_results + current_results
        stats = calculate_multi_cot_accuracy(combined_results)
        output_data = {
            'test_mode': test_mode,
            'results': combined_results,
            'statistics': stats,
            'metadata': {
                'total_processed': len(combined_results),
                'new_samples_processed': len(current_results),
                'existing_samples': len(existing_results),
                'sample_size': len(initial_indices),
                'k_value': args.k,
                'jsonl_file': args.jsonl_file,
                'images_dir': args.images_dir,
                **timing_metadata,
            },
        }
        return output_data, stats

    def save_progress(current_results, output_file):
        """Persist the results gathered so far and print interim statistics."""
        if not (existing_results or current_results):
            return

        now = datetime.datetime.now()
        output_data, current_stats = build_output(current_results, {
            'timestamp': now.strftime('%Y-%m-%d %H:%M:%S'),
            'elapsed_time': str(now - start_time),
        })
        _write_json_atomically(output_data, output_file)

        print(f"Progress saved: processed {len(current_results)}/{len(indices)} new samples, total {len(output_data['results'])} samples")

        # Print current statistics
        print_multi_cot_statistics(current_stats)

    if args.concurrent > 1:
        print(f"Using {args.concurrent} concurrent tasks to process samples")

        # One task per batch, never more batches than requested tasks.
        batches = _split_indices_into_batches(indices, args.concurrent)
        tasks = [
            asyncio.create_task(process_multi_cot_sample_batch(batch, args))
            for batch in batches
        ]

        # Consume batches as they finish so progress is really saved
        # incrementally instead of only after every batch is done.
        for finished in asyncio.as_completed(tasks):
            all_results.extend(await finished)
            save_progress(all_results, args.output_file)
    else:
        print("Using serial processing for samples")
        # Serial processing of each sample
        for i, idx in enumerate(indices):
            print(f"\nProcessing sample {i+1}/{len(indices)}, index: {idx}")

            result = await run_multi_cot_diagnosis(args.jsonl_file, args.images_dir, idx, log_dir, 1, args.k)

            if result:
                all_results.append(result)

                # Save progress every sample (since multi-CoT is time-consuming)
                save_progress(all_results, args.output_file)

    end_time = datetime.datetime.now()
    elapsed_time = end_time - start_time
    print(f"\nProcessing completed time: {end_time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Total time consumed: {elapsed_time}")

    # Final document: same layout as the progress snapshots, with full timing
    # information and the seed in the metadata.
    output_data, final_stats = build_output(all_results, {
        'start_time': start_time.strftime('%Y-%m-%d %H:%M:%S'),
        'end_time': end_time.strftime('%Y-%m-%d %H:%M:%S'),
        'elapsed_time': str(elapsed_time),
        'seed': args.seed,
    })

    # Save final results
    _write_json_atomically(output_data, args.output_file)

    print(f"\nFinal results saved to: {args.output_file}")
    print(f"New samples processed: {len(all_results)}")
    print(f"Total samples in output: {len(output_data['results'])}")

    # Print final statistics
    print_multi_cot_statistics(final_stats)

# Script entry point: run the async pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())