import threading
import json
import re
from transformers import AutoModelForSequenceClassification

# Prompt for the policy agent, which chooses the next action for one MAS:
# CONTINUE (summarize the next text), REVIEW (redo the worst-scoring text) or
# END. It reports the MAS's progress and resource usage, the worst text so far,
# and an opponent summary ({opponent_info}).
# Presumably rendered with str.format(...): the doubled braces in the JSON
# example escape literal braces, and {time_elapsed:.2f} requires a number.
# The model is asked to reply with a JSON object holding "choice",
# "question_id" and "previous_id".
POLICY_AGENT_PROMPT = """

You are part of a Multi-Agent System engaged in a summarization competition against several opponents, where each agent generates multiple candidate summary texts based on the original news article. 
The objective is to produce summaries with the lowest possible hallucination (i.e., highest factual consistency, reflected in higher hallucination scores) while minimizing resource usage—including API calls, token consumption, time, and review attempts, all of which are penalized when excessive. 
Your ultimate goal is to outperform your competitors by achieving the best balance between factual accuracy and resource efficiency, and to emerge as the final winner.

Your role is to make strategic decisions about what to do next. Here is your current status:
- MAS ID: {mas_id}
- Completed texts: {completed_texts}/{total_texts}
- Remaining texts: {remaining_texts}
- Current token usage: {token_usage} tokens
- API calls made: {api_calls}
- Review calls made: {review_calls}
- Current average hallucination score: {avg_score}
- Time elapsed: {time_elapsed:.2f} seconds

You need to decide whether to:
- CONTINUE: Process the next text and generate a new summary
- REVIEW: Revisit and improve the text with the highest hallucination rate (lower hallucination score means more hallucination)
- END: Stop processing if all texts have been summarized and there's no need for review

The current text summarization task is at ID {current_id} out of {max_id} texts.
The text with the highest hallucination rate so far is ID {worst_id}, with hallucination score: {worst_score}, where 0 means most hallucinated and 1 means least hallucinated.

{opponent_info}

IMPORTANT: You must output your decision in the following JSON format:
{{"choice": "continue" or "review" or "end",
"question_id": current question ID,
"previous_id": ID of the text with highest hallucination rate
}}"""

# Variant of the policy-agent prompt offered right after a review. It has the
# same status fields and placeholders, but REVIEW is removed, so the only
# choices the model is told it may return are "continue" and "end".
# Presumably rendered with str.format(...) (doubled braces escape the JSON
# example; {time_elapsed:.2f} requires a number).
RESTRICTED_POLICY_AGENT_PROMPT = """

You are part of a Multi-Agent System engaged in a summarization competition against several opponents, where each agent generates multiple candidate summary texts based on the original news article. 
The objective is to produce summaries with the lowest possible hallucination (i.e., highest factual consistency, reflected in higher hallucination scores) while minimizing resource usage—including API calls, token consumption, time, and review attempts, all of which are penalized when excessive. 
Your ultimate goal is to outperform your competitors by achieving the best balance between factual accuracy and resource efficiency, and to emerge as the final winner.

Your role is to make strategic decisions about what to do next. Here is your current status:
- MAS ID: {mas_id}
- Completed texts: {completed_texts}/{total_texts}
- Remaining texts: {remaining_texts}
- Current token usage: {token_usage} tokens
- API calls made: {api_calls}
- Review calls made: {review_calls}
- Current average hallucination score: {avg_score}
- Time elapsed: {time_elapsed:.2f} seconds

IMPORTANT: Because we just completed a review, you can only choose one of these options:
- CONTINUE: Process the next text and generate a new summary 
- END: Stop processing if all texts have been summarized

You CANNOT choose REVIEW at this time.

The current text summarization task is at ID {current_id} out of {max_id} texts.
The text with the highest hallucination rate so far is ID {worst_id}, with hallucination score: {worst_score}, where 0 means most hallucinated and 1 means least hallucinated.

{opponent_info}

IMPORTANT: You must output your decision in the following JSON format:
{{"choice": "continue" or "end",
"question_id": current question ID,
"previous_id": ID of the text with highest hallucination rate
}}"""

# Prompt for the summary agent. Single placeholder: {passage} (the source text).
# The doubled braces escape the JSON example for str.format; the model is
# asked to reply as {"summary": ...} (presumably parsed afterwards by
# extract_summary_from_response, which looks for a "summary" field).
SUMMARY_AGENT_PROMPT = """You are a chat bot answering questions using data. You must stick to the answers provided solely by the text in the passage provided.

You are asked the question 'Provide a concise summary of the following passage, covering the core pieces of information described.'

{passage}

IMPORTANT: You must output your response in the following JSON format:
{{"summary": "your summary here"}}"""

# Prompt for the review agent. It asks for a revised, more faithful summary.
# Placeholders:
#   {score}            - hallucination score of the previous summary
#                        (0 = most hallucinated, 1 = least)
#   {passage}          - the source text
#   {previous_summary} - the summary being revised
# Like SUMMARY_AGENT_PROMPT, the reply is requested as {"summary": ...}; the
# doubled braces escape the literal JSON braces for str.format.
REVIEW_AGENT_PROMPT = """You are a chat bot answering questions using data. You must stick to the answers provided solely by the text in the passage provided.

You previously summarized the following passage, but your summary contained hallucinations (hallucination score: {score}, where 0 means most hallucinated and 1 means least hallucinated), which means factual inconsistencies occurred.

Original passage:
{passage}

Your previous summary:
{previous_summary}

Please provide a new, more accurate summary that strictly adheres to the information in the passage. Focus on improving factual consistency and removing any information not present in the original text.

IMPORTANT: You must output your response in the following JSON format:
{{"summary": "your revised summary here"}}"""

# Module-wide lock for the shared hallucination model:
# calculate_hallucination_score holds it around model.predict() so that
# concurrent threads never call the model at the same time.
hallucination_model_lock = threading.Lock()

# Usage counters for each competing MAS (Multi-Agent System). Every counter
# key starts at 0, "start_time" starts as None and "total_time" as 0.
_MAS_STATS_COUNTER_KEYS = (
    "input_tokens",
    "output_tokens",
    "api_calls",
    "policy_calls",
    "summary_calls",
    "review_calls",
    "policy_in_tokens",
    "policy_out_tokens",
    "summary_in_tokens",
    "summary_out_tokens",
    "review_in_tokens",
    "review_out_tokens",
)

# Two independent dicts (not aliases): updating one MAS never touches the other.
mas1_stats = {**dict.fromkeys(_MAS_STATS_COUNTER_KEYS, 0), "start_time": None, "total_time": 0}
mas2_stats = {**dict.fromkeys(_MAS_STATS_COUNTER_KEYS, 0), "start_time": None, "total_time": 0}

def calculate_hallucination_score(hallucination_model, passage, summary, batch_size=1):
    """Score one (passage, summary) pair with the hallucination model.

    Args:
        hallucination_model: object exposing ``predict(pairs)``; the result
            must support ``.tolist()``.
        passage: source text the summary should be faithful to.
        summary: candidate summary; the sentinel ``"API ERROR"`` is treated
            as unscoreable.
        batch_size: unused; kept so existing callers keep working.

    Returns:
        The first (only) predicted score, or ``0`` when either text is empty
        or the summary is the ``"API ERROR"`` sentinel. The prompts in this
        module describe the score as 0 = most hallucinated, 1 = least.
    """
    # Nothing meaningful to score: give the worst possible score without
    # touching the model.
    if not (passage and summary) or summary == "API ERROR":
        return 0

    # The model is shared across threads; serialize predict() calls.
    with hallucination_model_lock:
        predictions = hallucination_model.predict([(passage, summary)]).tolist()
    first_score = predictions[0]
    return first_score

def initialize_data(cohere_data, specific_ids, reverse=False):
    """Build one tracking record per requested id.

    Args:
        cohere_data: indexable collection; ``cohere_data[i]['source']`` is the
            passage text for id ``i``.
        specific_ids: ids to track. The argument is copied and never mutated.
        reverse: when True, records are produced in reverse order of
            ``specific_ids``.

    Returns:
        list of dicts holding the id and passage. The summary and score fields
        are None and every call/token counter is 0.
    """
    ordered_ids = specific_ids.copy()  # work on a copy so the caller's list is untouched
    if reverse:
        ordered_ids.reverse()

    return [
        {
            "id": text_id,
            "passage": cohere_data[text_id]['source'],
            "summary": None,
            "initial_score": None,
            "final_score": None,
            "summary_calls": 0,
            "review_calls": 0,
            "in_tokens": 0,  # input tokens spent on this id
            "out_tokens": 0,  # output tokens spent on this id
        }
        for text_id in ordered_ids
    ]

def find_worst_example(data, review_time):
    """Return the id of the lowest-scoring record, preferring under-reviewed ones.

    Records whose ``final_score`` is None are ignored.

    1. Among records with ``review_calls < review_time``, pick the lowest
       ``final_score``.
    2. If none qualifies, pick the lowest ``final_score`` over all records.
    3. If still nothing is scored, fall back to the first record's id.

    Ties go to the earliest record. Returns None only when ``data`` is empty.
    """
    def _lowest_scoring(is_eligible):
        # Strict "<" against an initial +inf: the first minimum wins ties.
        chosen_id = None
        chosen_score = float('inf')
        for record in data:
            score = record["final_score"]
            if score is None or not score < chosen_score:
                continue
            if not is_eligible(record):
                continue
            chosen_id, chosen_score = record["id"], score
        return chosen_id

    worst_id = _lowest_scoring(lambda record: record["review_calls"] < review_time)
    if worst_id is None:
        # Every scored record has hit the review limit: consider them all.
        worst_id = _lowest_scoring(lambda record: True)

    if worst_id is None and len(data) > 0:
        return data[0]["id"]  # nothing scored yet
    return worst_id

def find_index_by_id(data, id_value):
    """Return the list position of the first record whose "id" equals *id_value*.

    Returns None when no record matches.
    """
    return next(
        (position for position, record in enumerate(data) if record["id"] == id_value),
        None,
    )

def calculate_reference_stats(data, max_worst_cases=5):
    """Summarize progress and score statistics over the per-text records.

    Args:
        data: list of record dicts; the keys "id", "summary", "final_score"
            and "review_calls" are read.
        max_worst_cases: how many of the lowest-scoring records to report
            (default 5, the previously hard-coded value). Negative values are
            treated as 0.

    Returns:
        dict with:
          - "completed_texts": records whose summary is not None
          - "remaining_texts": the rest
          - "avg_score": mean of the non-None final scores (0 if there are none)
          - "total_review_calls": sum of review_calls over all records
          - "worst_cases": up to ``max_worst_cases`` dicts
            ``{"id", "score", "review_calls"}``, lowest score first
    """
    completed_texts = sum(1 for item in data if item["summary"] is not None)
    remaining_texts = len(data) - completed_texts

    # The average covers every *scored* record, which need not match the
    # set of records with a summary.
    scored_items = [item for item in data if item["final_score"] is not None]
    avg_score = (
        sum(item["final_score"] for item in scored_items) / len(scored_items)
        if scored_items
        else 0
    )

    total_review_calls = sum(item["review_calls"] for item in data)

    # sorted() is stable, so equal scores keep their original order in data.
    lowest_first = sorted(scored_items, key=lambda item: item["final_score"])
    worst_cases = [
        {
            "id": item["id"],
            "score": item["final_score"],
            "review_calls": item["review_calls"],
        }
        for item in lowest_first[:max(max_worst_cases, 0)]
    ]

    return {
        "completed_texts": completed_texts,
        "remaining_texts": remaining_texts,
        "avg_score": avg_score,
        "total_review_calls": total_review_calls,
        "worst_cases": worst_cases,
    }

def save_data(data, filename):
    """Write *data* to *filename* as pretty-printed (indent=2) UTF-8 JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``).

    The payload is serialized *before* the file is opened. Opening with 'w'
    truncates the file, so serializing first means a serialization error
    (e.g. a TypeError for an unsupported object) propagates without
    destroying a previously saved file or leaving partial JSON on disk.
    """
    payload = json.dumps(data, indent=2, ensure_ascii=False)
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(payload)

def load_data(filename):
    """Read *filename* as UTF-8 text and return the decoded JSON document."""
    with open(filename, encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def extract_summary_from_response(response_text):
    """Pull the summary text out of a model response.

    The prompts ask for ``{"summary": "..."}``, but models often wrap it in
    Markdown fences or surrounding prose, or emit slightly broken JSON.
    Strategies, in order:

    1. Parse real JSON and use its string "summary" field. The candidates are
       the contents of each fenced (optionally ``json``-tagged) code block,
       then the whole response, then the span from the first ``{`` to the
       last ``}``.
    2. Regex fallback for a double-quoted "summary" value. Valid JSON escape
       sequences in the captured text are decoded.
    3. Regex fallback for a single-quoted 'summary' value, returned as captured.
    4. Otherwise, the full response text.

    In the regex fallbacks a backslash and the character after it are
    consumed together, so an escaped quote inside the summary no longer
    terminates the capture early.

    Returns:
        dict: ``{"summary": <str>}``.
    """
    for candidate in _json_object_candidates(response_text):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict) and isinstance(parsed.get("summary"), str):
            return {"summary": parsed["summary"]}

    # Almost-JSON with double quotes. The closing quote is the first
    # *unescaped* quote followed by optional whitespace and ',' or '}'.
    summary_match = re.search(
        r'"summary"\s*:\s*"((?:\\.|[^\\])*?)"(?=\s*[,}])', response_text, re.DOTALL
    )
    if summary_match:
        return {"summary": _decode_json_string(summary_match.group(1))}

    # Python-dict-style single quotes.
    summary_match = re.search(
        r"'summary'\s*:\s*'((?:\\.|[^\\])*?)'(?=\s*[,}])", response_text, re.DOTALL
    )
    if summary_match:
        return {"summary": summary_match.group(1)}

    # Default: return the full response.
    return {"summary": response_text}


def _json_object_candidates(text):
    """Yield substrings of *text* that might hold a complete JSON object."""
    for fenced in re.finditer(r'```(?:json)?\s*(.*?)```', text, re.DOTALL):
        yield fenced.group(1)
    yield text
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        yield text[start:end + 1]


def _decode_json_string(raw):
    """Decode JSON escapes in *raw* (a string body without quotes); return it unchanged if invalid."""
    try:
        # strict=False tolerates raw control characters (e.g. literal newlines).
        return json.loads('"' + raw + '"', strict=False)
    except json.JSONDecodeError:
        # E.g. unescaped inner quotes or invalid escapes such as a Windows path.
        return raw

def load_hallucination_model():
    """Load the 'vectara/hallucination_evaluation_model' checkpoint.

    Loading goes through transformers' ``AutoModelForSequenceClassification.from_pretrained``.

    NOTE(review): ``trust_remote_code=True`` executes Python code shipped with
    the model repository; consider pinning a specific revision.
    """
    checkpoint = 'vectara/hallucination_evaluation_model'
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        trust_remote_code=True,
    )
    return model