#!/usr/bin/env python3
"""
Quick multi-threaded test script for Pipeline4.
This script processes a small number of instructions in parallel for quick testing.
"""

import sys
import os
import json
import time
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
import threading
import static_jsonl
# Add this script's directory to the Python path so sibling modules (e.g. pipeline4_3) can be imported
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from pipeline4_3 import Pipeline4
import pipeline4_3
def process_single_item_quick(item_data, index, total, results_list, results_lock, progress_lock, completed_count, total_count):
    """Process one instruction item with a fresh ``Pipeline4`` and record the outcome.

    Meant to run inside a worker thread. A new ``Pipeline4`` instance is
    created per call, so no pipeline object is shared between threads.

    Args:
        item_data: Instruction record passed to
            ``Pipeline4.process_single_item``. When it is a dict, its
            ``"instruction"`` value is copied into error results.
        index: Zero-based position of the item in the batch.
        total: Batch size, forwarded to the pipeline.
        results_list: Shared list collecting results. It is appended to under
            ``results_lock``.
        results_lock: Lock guarding ``results_list`` and ``completed_count``.
        progress_lock: Not used. Kept so existing callers keep working.
        completed_count: Single-element list used as a shared counter of
            finished items (successful or failed).
        total_count: Single-element list holding the batch size. Not used
            here; kept for interface compatibility.

    Returns:
        The pipeline's result. If the pipeline raised, returns an error dict
        with ``success=False`` instead. A ``None`` result from the pipeline is
        returned unchanged but is not stored in ``results_list``.
    """
    try:
        # Only pipeline construction and the pipeline call are guarded. The
        # bookkeeping below can then never record the same item twice.
        pipeline = Pipeline4()
        result = pipeline.process_single_item(item_data, index, total)
    except Exception as e:
        # Record the failure rather than propagating it, so one bad item does
        # not abort the whole batch.
        instruction = item_data.get("instruction", "") if isinstance(item_data, dict) else ""
        result = {
            "index": index,
            "instruction": instruction,
            "error": str(e),
            "success": False,
            "timestamp": datetime.now().isoformat(),
        }

    with results_lock:
        # A None entry would break later processing that calls .get() on
        # every stored result, so it is counted as finished but not stored.
        if result is not None:
            results_list.append(result)
        completed_count[0] += 1

    return result

def load_real_data(instruction_file=None):
    """Load instruction records from a JSON file.

    Args:
        instruction_file: Path to a JSON file whose top level is a list of
            records. Defaults to ``pipeline4_3.INSTRUCTION_FILE``.

    Returns:
        The records that are dicts containing an ``"instruction"`` key, in
        file order. Returns ``[]`` after printing the error if the file
        cannot be read or parsed.
    """
    try:
        if instruction_file is None:
            instruction_file = pipeline4_3.INSTRUCTION_FILE
        print(instruction_file)
        with open(instruction_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Records are later indexed by key, so only dicts are usable. A bare
        # `"instruction" in item` check would also accept any string that
        # merely contains the word "instruction".
        instructions = [
            item for item in data
            if isinstance(item, dict) and "instruction" in item
        ]

        print(f"✅ Loaded {len(instructions)} real instructions")
        return instructions

    except Exception as e:
        print(f"❌ Failed to load real data: {e}")
        return []

def _compliance_statistics(successful_results):
    """Summarize the ``compliance_score`` values of successful results.

    Only numeric scores are used. Missing or non-numeric scores (e.g. None)
    are skipped, so the average and min/max cannot raise and such items are
    not counted as compliant.

    Returns:
        Dict with ``average_compliance_score``, ``non_compliant_count``
        (scores > 0), ``total_compliant_count`` (remaining numeric scores) and
        ``score_range`` ([min, max], or [0, 0] when there are no scores).
    """
    scores = [
        r["compliance_score"]
        for r in successful_results
        if isinstance(r.get("compliance_score"), (int, float))
    ]
    non_compliant_count = sum(1 for s in scores if s > 0)
    return {
        "average_compliance_score": sum(scores) / len(scores) if scores else 0,
        "non_compliant_count": non_compliant_count,
        "total_compliant_count": len(scores) - non_compliant_count,
        "score_range": [min(scores), max(scores)] if scores else [0, 0],
    }


def main(max_items=10, max_workers=1):
    """Run a quick multi-threaded Pipeline4 test and save a JSON report.

    Steps:
      1. Load instructions with ``load_real_data`` and keep the first
         ``max_items`` of them.
      2. Process them on a ``ThreadPoolExecutor`` with ``max_workers``
         threads. Each non-None result is appended to a JSONL file under
         ``temp/``.
      3. Run ``static_jsonl.main`` on that JSONL file.
      4. Write a JSON report (metadata, statistics, per-item results) under
         ``./new/result_<benchmark>/`` and print a per-item summary.

    Args:
        max_items: Maximum number of instructions to process.
        max_workers: Number of worker threads. Must be at least 1.
    """
    print("=" * 80)
    print("Pipeline4 Quick Multi-Threaded Testing")
    print("=" * 80)

    # Load real data
    instructions = load_real_data()
    if not instructions:
        print("❌ No instructions loaded. Exiting.")
        return

    # Only the first `max_items` instructions are processed in a quick test.
    instructions = instructions[:max_items]
    total = len(instructions)

    print(f"🚀 Quick testing with {total} instructions")
    print(f"🔧 Using {max_workers} worker threads")
    print()

    # State shared with the worker threads. The one-element lists are mutable
    # counters that process_single_item_quick updates in place.
    results_list = []
    results_lock = Lock()
    progress_lock = Lock()
    completed_count = [0]
    total_count = [total]

    process_file = f"temp/{pipeline4_3.benchmark}_ours_{pipeline4_3.TIMESTAMP}__.jsonl"
    # Make sure temp/ exists; otherwise opening the JSONL file fails with
    # FileNotFoundError.
    os.makedirs(os.path.dirname(process_file), exist_ok=True)

    start_time = time.perf_counter()
    with open(process_file, "a", encoding="utf-8") as jsonl_file:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {
                executor.submit(
                    process_single_item_quick, item, i, total,
                    results_list, results_lock, progress_lock,
                    completed_count, total_count,
                ): i
                for i, item in enumerate(instructions)
            }

            # as_completed yields in this (main) thread, so the writes below
            # never run concurrently and need no extra lock.
            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    result = future.result()
                except Exception as e:
                    print(f"❌ Thread execution error for item {index+1}: {e}")
                    continue
                if result is None:
                    continue
                try:
                    line = json.dumps(result, ensure_ascii=False)
                except (TypeError, ValueError) as e:
                    print(f"❌ Could not serialize result for item {index+1}: {e}")
                    continue
                jsonl_file.write(line + "\n")
                # Flush per item so partial progress survives a crash.
                jsonl_file.flush()
    total_time = time.perf_counter() - start_time

    print()
    print("=" * 80)
    print("QUICK TESTING COMPLETED")
    print("=" * 80)

    # Results may be None (see the None check above). Drop those, then sort
    # by each result's index.
    results_list = [r for r in results_list if r is not None]
    results_list.sort(key=lambda r: r.get("index", 0))

    # An item counts as successful unless it has success=False or an "error" key.
    successful_results = [r for r in results_list if r.get("success", True) and "error" not in r]
    failed_results = [r for r in results_list if not r.get("success", True) or "error" in r]
    statistics = _compliance_statistics(successful_results)

    print(f"⏱️  Total processing time: {total_time:.2f} seconds")
    print(f"📊 Total items processed: {len(results_list)}")
    print(f"✅ Successful items: {len(successful_results)}")
    print(f"❌ Failed items: {len(failed_results)}")

    # Run static_jsonl on the incremental JSONL file. A failure here must not
    # prevent the JSON report below from being written.
    try:
        static_jsonl.main(process_file)
    except Exception as e:
        print(f"❌ static_jsonl failed for {process_file}: {e}")

    # Save results
    timestamp = datetime.now().strftime("%H%M%S")
    results_file = f"./new/result_{pipeline4_3.benchmark}/{pipeline4_3.benchmark}_multithreaded_{pipeline4_3.Attacked_model}{timestamp}.json"
    os.makedirs(os.path.dirname(results_file), exist_ok=True)

    final_results = {
        "metadata": {
            "attacked_model": pipeline4_3.Attacked_model,
            "model": pipeline4_3.MODEL,
            "FIXED_IMAGE_COUNT": pipeline4_3.FIXED_IMAGE_COUNT,
            "WORD_COUNT": pipeline4_3.WORD_COUNT,
            "total_items": len(results_list),
            "successful_items": len(successful_results),
            "failed_items": len(failed_results),
            "max_workers": max_workers,
            "processing_time_seconds": total_time,
            "timestamp": datetime.now().isoformat(),
            "scenario": pipeline4_3.scenario,
        },
        "statistics": statistics,
        "results": results_list,
    }

    try:
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=2, ensure_ascii=False)
        print(f"💾 Results saved to: {results_file}")
    except Exception as e:
        print(f"❌ Failed to save results: {e}")

    print()
    print("🎉 Quick multi-threaded testing completed!")

    # Print detailed results
    print("\n📊 Detailed Results:")
    for i, result in enumerate(results_list):
        if "error" in result:
            print(f"  {i+1}. ❌ ERROR: {result['error']}")
        else:
            score = result.get("compliance_score", "N/A")
            keywords = result.get("keywords", [])
            print(f"  {i+1}. Score: {score}, Keywords: {keywords}")

if __name__ == "__main__":
    # Script entry point: importing this module does not start a test run.
    main()
