from datasets import load_dataset
from openai import OpenAI
import time
import json
import os
from tqdm import tqdm
import re
import copy
from concurrent.futures import ThreadPoolExecutor
import pdb
from queue import Queue
import sys
from io import StringIO
import threading

# Import from utils.py
from utils import (
    POLICY_AGENT_PROMPT, RESTRICTED_POLICY_AGENT_PROMPT, 
    SUMMARY_AGENT_PROMPT, REVIEW_AGENT_PROMPT,
    hallucination_model_lock, mas1_stats, mas2_stats,
    calculate_hallucination_score, initialize_data, find_worst_example,
    find_index_by_id, calculate_reference_stats, save_data, load_data,
    extract_summary_from_response, load_hallucination_model
)

# Configuration
# Chat model used for every agent call; swap in one of the alternatives below.
MODEL = 'gpt-4o-mini'
# MODEL = 'qwen-max'
# MODEL = 'deepseek-v3-250324'
# MODEL = 'gemini-2.0-flash'
# MODEL = 'grok-3-beta'
# MODEL = 'step-2-mini'
# MODEL = 'glm-4v-flash'

MAX_WORKERS = 100  # Number of concurrent API calls
test_num = 1000  # main() keeps at most this many dataset rows unless set to 'all'

# find_worst_example prefers items that were reviewed fewer than this many times
review_time = 3

# Hallucination threshold coefficient - when score is <= this value, strongly recommend review
HALLUCINATION_THRESHOLD = 0.85

# Define specific IDs: every index in [0, 831) except the excluded ones
_EXCLUDED_IDS = frozenset({
    15, 29, 202, 217, 396, 485, 483, 405, 440, 533, 587, 589,
    611, 619, 703, 709, 731, 774, 727, 732, 755, 809, 817,
})
specific_ids = [candidate for candidate in range(831) if candidate not in _EXCLUDED_IDS]

# Replace illegal characters in model name with underscores
model_filename = MODEL.translate(str.maketrans({'/': '_', ':': '_'}))
# Output files are keyed by model, number of IDs, threshold and review budget
mas1_file = f'mas1_{model_filename}_{len(specific_ids)}_thre_{HALLUCINATION_THRESHOLD}_review_{review_time}.json'
mas2_file = f'mas2_{model_filename}_{len(specific_ids)}_thre_{HALLUCINATION_THRESHOLD}_review_{review_time}.json'
model_log_filename = f'log_{model_filename}_{len(specific_ids)}_thre_{HALLUCINATION_THRESHOLD}_review_{review_time}.txt'

# Initialize the OpenAI client. It is built once at import time and used by
# call_agent_api for every chat-completion request.
# NOTE(review): 'your_api_key' and 'your_base_url' look like placeholders that
# have to be replaced with a real key/endpoint before running - confirm.
client = OpenAI(
    api_key='your_api_key',
    base_url='your_base_url'
)

# Load the hallucination evaluation model (also at import time). This is the
# `model` that calculate_hallucination_score_wrapper hands to
# calculate_hallucination_score.
hallucination_model = load_hallucination_model()

# Stats tracking for each MAS system
def _new_stats():
    """Return a fresh usage-statistics dict with every counter zeroed.

    The token/call counters are incremented by call_agent_api; "start_time"
    stays None until a run sets it and "total_time" stays 0 until the run
    finishes.
    """
    counter_names = (
        "input_tokens",
        "output_tokens",
        "api_calls",
        "policy_calls",
        "summary_calls",
        "review_calls",
        "policy_in_tokens",
        "policy_out_tokens",
        "summary_in_tokens",
        "summary_out_tokens",
        "review_in_tokens",
        "review_out_tokens",
    )
    stats = dict.fromkeys(counter_names, 0)
    stats["start_time"] = None
    stats["total_time"] = 0
    return stats


# One independent dict per system so their counters never mix.
mas1_stats = _new_stats()
mas2_stats = _new_stats()

# API call function for generic agent interaction
def call_agent_api(prompt, agent_type, mas_id):
    """Send a single-turn prompt to the chat model and record token usage.

    Args:
        prompt: Full user message to send.
        agent_type: "policy", "summary" or "review"; selects which per-agent
            counters are incremented. Any other value only updates the
            aggregate counters.
        mas_id: "mas1" books the usage on mas1_stats; every other value books
            it on mas2_stats.

    Returns:
        A dict with keys "response" (the model's message content),
        "input_tokens", "output_tokens" and "agent_type". If every attempt
        fails, "response" is the sentinel string "API ERROR" and both token
        counts are 0.
    """
    stats = mas1_stats if mas_id == "mas1" else mas2_stats

    conversation = [
        {"role": "user", "content": prompt}
    ]

    max_retries = 1000000
    max_wait_seconds = 60

    for attempt in range(1, max_retries + 1):
        try:
            completion = client.chat.completions.create(
                model=MODEL,
                messages=conversation,
                temperature=0.0
            )

            response = completion.choices[0].message.content
            input_tokens = completion.usage.prompt_tokens
            output_tokens = completion.usage.completion_tokens
        except Exception as e:
            print(f"API call error for {agent_type} agent in {mas_id}, attempt {attempt}/{max_retries}: {str(e)}")

            if attempt >= max_retries:
                break

            # Exponential backoff: 1s, 2s, 4s, ... capped at max_wait_seconds.
            # The exponent is clamped so a long outage never computes a huge power.
            wait_time = min(max_wait_seconds, 2 ** min(attempt - 1, 6))
            time.sleep(wait_time)
            continue

        # Update aggregate token usage stats for the appropriate MAS
        stats["input_tokens"] += input_tokens
        stats["output_tokens"] += output_tokens
        stats["api_calls"] += 1

        # Update agent-specific counters
        if agent_type in ("policy", "summary", "review"):
            stats[f"{agent_type}_calls"] += 1
            stats[f"{agent_type}_in_tokens"] += input_tokens
            stats[f"{agent_type}_out_tokens"] += output_tokens

        return {
            "response": response,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "agent_type": agent_type
        }

    print(f"Failed after {max_retries} attempts for {agent_type} agent in {mas_id}")
    return {
        "response": "API ERROR",
        "input_tokens": 0,
        "output_tokens": 0,
        "agent_type": agent_type
    }

# Calculate hallucination score using the model
def calculate_hallucination_score(model, passage, summary):
    """Score how well *summary* is supported by *passage*.

    Returns 0 without touching the model when either text is empty/None or
    when the summary is the "API ERROR" sentinel. Otherwise returns the first
    element of ``model.predict([(passage, summary)]).tolist()``.
    """
    if not (passage and summary) or summary == "API ERROR":
        return 0

    # Prediction runs under hallucination_model_lock so only one caller
    # uses the model at a time.
    with hallucination_model_lock:
        predictions = model.predict([(passage, summary)]).tolist()

    first_score = predictions[0]
    return first_score

# Initialize the data structure
def initialize_data(cohere_data, specific_ids, reverse=False):
    """Build one empty tracking record per requested ID.

    Args:
        cohere_data: Indexable collection; ``cohere_data[i]['source']`` is
            used as the passage for ID ``i``.
        specific_ids: IDs to include. The list itself is not modified.
        reverse: When true the records are produced in reverse ID order.

    Returns:
        A list of dicts with the passage filled in, summary and scores set to
        None, and all call/token counters set to 0.
    """
    ordered_ids = specific_ids.copy()
    if reverse:
        ordered_ids.reverse()

    def blank_record(sample_id):
        return {
            "id": sample_id,
            "passage": cohere_data[sample_id]['source'],
            "summary": None,
            "initial_score": None,
            "final_score": None,
            "summary_calls": 0,
            "review_calls": 0,
            "in_tokens": 0,  # Input tokens spent on this ID
            "out_tokens": 0  # Output tokens spent on this ID
        }

    return [blank_record(sample_id) for sample_id in ordered_ids]

# Find the worst example (the one with the lowest hallucination score)
def find_worst_example(data, review_time):
    """Return the ID of the scored item with the lowest ``final_score``.

    Items reviewed fewer than *review_time* times are preferred; if none of
    those has a score, every scored item is considered. When nothing has been
    scored yet the first item's ID is returned, and None for empty *data*.
    Ties go to the item that appears first.
    """
    def lowest_scored(only_reviewable):
        # Linear scan keeping the first strictly-lowest score seen so far.
        best_id, best_score = None, float('inf')
        for entry in data:
            score = entry["final_score"]
            if score is None or not score < best_score:
                continue
            if only_reviewable and not entry["review_calls"] < review_time:
                continue
            best_id, best_score = entry["id"], score
        return best_id

    worst_id = lowest_scored(only_reviewable=True)

    # Every scored item has used up its review budget: look at all of them.
    if worst_id is None:
        worst_id = lowest_scored(only_reviewable=False)

    # No scores at all yet: default to the first example.
    if worst_id is None and len(data) > 0:
        return data[0]["id"]

    return worst_id

# Helper function: Find array index by ID value
def find_index_by_id(data, id_value):
    """Return the position of the first record whose "id" equals *id_value*, or None."""
    matching_positions = (
        position for position, record in enumerate(data) if record["id"] == id_value
    )
    return next(matching_positions, None)

# Calculate statistics for reference agent
def calculate_reference_stats(data):
    """Summarise one system's progress for sharing with its opponent.

    Returns:
        A dict with "completed_texts" (records that have a summary),
        "remaining_texts", "avg_score" (mean final score of scored records,
        0 when there are none), "total_review_calls" and "worst_cases": up to
        five ``{"id", "score", "review_calls"}`` dicts, lowest score first.
    """
    completed_texts = 0
    total_review_calls = 0
    scored_records = []

    for record in data:
        if record["summary"] is not None:
            completed_texts += 1
        total_review_calls += record["review_calls"]
        if record["final_score"] is not None:
            scored_records.append(record)

    if scored_records:
        avg_score = sum(record["final_score"] for record in scored_records) / len(scored_records)
    else:
        avg_score = 0

    # Lowest scores are the worst cases; the sort is stable, so ties keep data order.
    lowest_first = sorted(scored_records, key=lambda record: record["final_score"])
    worst_cases = [
        {
            "id": record["id"],
            "score": record["final_score"],
            "review_calls": record["review_calls"]
        }
        for record in lowest_first[:5]
    ]

    return {
        "completed_texts": completed_texts,
        "remaining_texts": len(data) - completed_texts,
        "avg_score": avg_score,
        "total_review_calls": total_review_calls,
        "worst_cases": worst_cases
    }

# Policy agent helpers
def _format_opponent_info(opponent_info, total_texts):
    """Render the opponent's progress report as a prompt section ("" when absent)."""
    if not opponent_info:
        return ""

    opponent_info_text = f"""
Information about your opponent:
- Completed texts: {opponent_info['completed_texts']}/{total_texts}
- Remaining texts: {opponent_info['remaining_texts']}
- Current average hallucination score: {opponent_info['avg_score']:.4f}
- Total review calls: {opponent_info['total_review_calls']}

Top 5 worst cases from your opponent:
"""
    for i, case in enumerate(opponent_info['worst_cases']):
        opponent_info_text += f"- Case {i+1}: ID {case['id']}, Score: {case['score']:.4f}, Review calls: {case['review_calls']}\n"
    return opponent_info_text


def _parse_policy_decision(response_text):
    """Extract the decision object from a policy-agent reply.

    Tries, in order: the whole reply as JSON, the contents of a Markdown code
    fence, the outermost ``{...}`` span, and finally a regex scrape of the
    "choice" / "question_id" / "previous_id" fields. Re-raises the JSON error
    when none of these yields anything.
    """
    try:
        return json.loads(response_text)
    except json.JSONDecodeError:
        pass

    # Models often wrap the JSON in a ```json ... ``` fence.
    json_block_match = re.search(r"```(?:json)?\n(.*?)\n```", response_text, re.DOTALL)
    if json_block_match:
        try:
            return json.loads(json_block_match.group(1))
        except ValueError:
            pass

    # Strip any text around the outermost braces and try again.
    response_clean = re.sub(r'^.*?(\{.*\}).*$', r'\1', response_text, flags=re.DOTALL)
    try:
        return json.loads(response_clean)
    except ValueError as e:
        print(f"JSON parsing error: {e}")
        # Last resort: take the trailing object that mentions "choice" and
        # read the individual fields with regexes instead of json.loads.
        json_match = re.search(r'(\{[^{]*"choice".*?\})\s*$', response_clean, re.DOTALL)
        if not json_match:
            print("cannot parse json from response")
            raise

        json_str = json_match.group(1)
        choice_match = re.search(r'"choice":\s*"(\w+)"', json_str)
        question_id_match = re.search(r'"question_id":\s*(\d+)', json_str)
        previous_id_match = re.search(r'"previous_id":\s*(\d+)', json_str)

        decision = {}
        if choice_match:
            decision["choice"] = choice_match.group(1)
        if question_id_match:
            decision["question_id"] = int(question_id_match.group(1))
        if previous_id_match:
            decision["previous_id"] = int(previous_id_match.group(1))
        return decision


# Policy agent decision
def policy_agent(current_id, max_id, data, mas_id, opponent_info=None, just_reviewed=False):
    """Decide whether a MAS should continue, review its worst summary, or end.

    Args:
        current_id: Index into *data* of the next text to summarise (an index,
            not the record's "id" value).
        max_id: Unused; the prompt's max_id is taken from the last record.
        data: The MAS's list of tracking records.
        mas_id: "mas1" selects mas1_stats for the usage figures in the prompt;
            any other value selects mas2_stats.
        opponent_info: Optional dict shaped like the return value of
            calculate_reference_stats, describing the other MAS.
        just_reviewed: True when the previous step was a review. In that case
            the decision is made locally without an API call whenever
            possible, and a "review" answer is never accepted.

    Returns:
        ``(decision, input_tokens, output_tokens)``. *decision* is a dict that
        normally carries "choice" ("continue", "review" or "end"),
        "question_id" and "previous_id". Locally-made decisions report 0
        tokens.

    Raises:
        json.JSONDecodeError: the model's reply could not be parsed at all.
    """
    all_summarized = all(item["summary"] is not None for item in data)

    # Right after a review with texts still pending: continue, no API call needed.
    if just_reviewed and current_id < len(data) and not all_summarized:
        return {
            "choice": "continue",
            "question_id": data[current_id]["id"],
            "previous_id": find_worst_example(data, review_time)
        }, 0, 0  # No tokens used

    # Right after a review with everything summarised: end, no API call needed.
    if just_reviewed and all_summarized:
        return {
            "choice": "end",
            "question_id": data[current_id]["id"] if current_id < len(data) else None,
            "previous_id": find_worst_example(data, review_time)
        }, 0, 0  # No tokens used

    # Progress figures for the prompt
    completed_texts = sum(1 for item in data if item["summary"] is not None)
    remaining_texts = len(data) - completed_texts

    stats = mas1_stats if mas_id == "mas1" else mas2_stats
    token_usage = stats["input_tokens"] + stats["output_tokens"]
    api_calls = stats["api_calls"]
    time_elapsed = time.time() - stats["start_time"] if stats["start_time"] else 0

    review_calls = sum(item["review_calls"] for item in data)

    valid_scores = [item["final_score"] for item in data if item["final_score"] is not None]
    avg_score = sum(valid_scores) / len(valid_scores) if valid_scores else 0

    # Worst example so far (lowest score) and its position in data
    worst_id = find_worst_example(data, review_time)
    worst_idx = find_index_by_id(data, worst_id)
    worst_score = data[worst_idx]["final_score"] if worst_idx is not None else None

    # The restricted prompt is used when a review has just been performed
    prompt_template = RESTRICTED_POLICY_AGENT_PROMPT if just_reviewed else POLICY_AGENT_PROMPT

    prompt = prompt_template.format(
        mas_id=mas_id,
        completed_texts=completed_texts,
        total_texts=len(data),
        remaining_texts=remaining_texts,
        token_usage=token_usage,
        api_calls=api_calls,
        review_calls=review_calls,
        avg_score=f"{avg_score:.4f}",
        time_elapsed=time_elapsed,
        current_id=data[current_id]["id"] if current_id < len(data) else "N/A",
        max_id=data[-1]["id"],
        worst_id="None" if worst_idx is None else worst_id,
        worst_score=f"{worst_score:.4f}" if worst_score is not None else "N/A",
        opponent_info=_format_opponent_info(opponent_info, len(data))
    )

    # Nudge the model towards a review when the worst score is at or below the
    # threshold and that example still has review budget left.
    if (not just_reviewed and
            worst_score is not None and
            worst_score <= HALLUCINATION_THRESHOLD and
            data[worst_idx]["review_calls"] < review_time):

        recommendation = f"\nIMPORTANT RECOMMENDATION: The hallucination score of {worst_score:.4f} is below our acceptable threshold of {HALLUCINATION_THRESHOLD}. It's STRONGLY RECOMMENDED that you choose to REVIEW this example to improve its factual accuracy.\n"

        # Insert the recommendation just before the final output instruction;
        # if the template has no such marker the prompt is left untouched.
        split_text = "IMPORTANT: You must output your decision"
        if split_text in prompt:
            head, tail = prompt.split(split_text, 1)
            prompt = head + recommendation + split_text + tail
            print(f"Adding review recommendation for {mas_id} as score {worst_score:.4f} is below threshold {HALLUCINATION_THRESHOLD}")

    result = call_agent_api(prompt, "policy", mas_id)
    decision = _parse_policy_decision(result["response"])

    # The guards below apply to every decision, whichever parser recovered it.

    # Never review twice in a row
    if just_reviewed and decision.get("choice") == "review":
        print(f"Overriding 'review' choice to 'continue' for {mas_id} as we just performed a review")
        decision["choice"] = "continue"

    # Never end while some text still has no summary
    if decision.get("choice") == "end" and not all_summarized:
        print(f"Overriding 'end' choice to 'continue' for {mas_id} as not all texts have summaries yet")
        decision["choice"] = "continue"

    # A review on the last question is allowed; this is only logged
    current_idx_is_last = current_id >= len(data) - 1
    if current_idx_is_last and decision.get("choice") == "review" and all_summarized:
        print(f"{mas_id.upper()}: Allowing review for the last question with ID {worst_id}")

    return decision, result["input_tokens"], result["output_tokens"]

# Summary agent
def summary_agent(passage, mas_id):
    """Ask the model to summarise *passage*.

    Returns:
        ``(summary, input_tokens, output_tokens)``. The summary is the
        "summary" field of the reply, parsed as JSON or, failing that, via
        extract_summary_from_response; if the field is absent the raw reply
        text is returned instead.
    """
    result = call_agent_api(SUMMARY_AGENT_PROMPT.format(passage=passage), "summary", mas_id)
    raw_reply = result["response"]

    try:
        parsed = json.loads(raw_reply)
    except json.JSONDecodeError:
        # Not plain JSON: fall back to the helper that digs the field out of text
        parsed = extract_summary_from_response(raw_reply)

    summary = parsed.get("summary", raw_reply)
    return summary, result["input_tokens"], result["output_tokens"]

# Review agent
def review_agent(passage, previous_summary, score, mas_id):
    """Ask the model to improve an existing summary of *passage*.

    The prompt is given the passage, the previous summary and its score.

    Returns:
        ``(summary, input_tokens, output_tokens)``. The summary is the
        "summary" field of the reply, parsed as JSON or, failing that, via
        extract_summary_from_response; if the field is absent the raw reply
        text is returned instead.
    """
    template_fields = {
        "passage": passage,
        "previous_summary": previous_summary,
        "score": score,
    }
    result = call_agent_api(REVIEW_AGENT_PROMPT.format(**template_fields), "review", mas_id)
    raw_reply = result["response"]

    try:
        parsed = json.loads(raw_reply)
    except json.JSONDecodeError:
        # Not plain JSON: fall back to the helper that digs the field out of text
        parsed = extract_summary_from_response(raw_reply)

    revised = parsed.get("summary", raw_reply)
    return revised, result["input_tokens"], result["output_tokens"]

# Save data to JSON file
def save_data(data, filename):
    """Write *data* to *filename* as indented UTF-8 JSON, atomically.

    The JSON is first written to ``<filename>.tmp`` and then moved over the
    target with os.replace, so an interruption or a serialisation error never
    leaves a truncated file behind: the previous contents stay intact.

    Raises:
        TypeError / ValueError: *data* is not JSON-serialisable.
        OSError: the file could not be written or replaced.
    """
    tmp_filename = f"{filename}.tmp"
    try:
        with open(tmp_filename, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        os.replace(tmp_filename, filename)
    finally:
        # Only still present when something above failed
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)

# Load data from JSON file
def load_data(filename):
    """Read *filename* as UTF-8 text and return the decoded JSON value."""
    with open(filename, 'r', encoding='utf-8') as f:
        raw_text = f.read()
    return json.loads(raw_text)

# Calculate hallucination score using the module-level model
def calculate_hallucination_score_wrapper(passage, summary):
    """Score *summary* against *passage* with the shared hallucination_model."""
    score = calculate_hallucination_score(hallucination_model, passage, summary)
    return score

# Process a single MAS step
def process_mas_step(data, current_idx, mas_id, opponent_info, just_reviewed):
    """Take one policy decision for a MAS and carry out the chosen action.

    Args:
        data: The MAS's list of tracking records; updated in place.
        current_idx: Index of the next text that still needs a summary.
        mas_id: "mas1" or "mas2".
        opponent_info: Latest progress report from the other MAS, or None.
        just_reviewed: True when this MAS's previous step was a review.

    Returns:
        ``(data, current_idx, done, review_occurred)``:
        - "end": done is True and the MAS's total_time is recorded.
        - "continue": the current text is summarised and scored, and
          current_idx advances by one.
        - "review": the summary named by the decision's "previous_id" is
          regenerated and kept only if it scores strictly higher;
          review_occurred is True and current_idx is unchanged.
        A review that cannot be carried out (unknown ID, or a target that has
        no summary/score yet) and an unrecognised choice are both handled as
        "continue", so every step makes progress and no text is skipped.
    """
    # if we have exceeded the number of questions, return the end state
    if current_idx >= len(data):
        return data, current_idx, True, False

    # ID of the text the policy is deciding about
    current_id = data[current_idx]["id"]

    # Policy agent decision
    decision, policy_in_tokens, policy_out_tokens = policy_agent(current_idx, len(data), data, mas_id, opponent_info, just_reviewed)

    # Policy tokens are booked on the current question
    data[current_idx]["in_tokens"] += policy_in_tokens
    data[current_idx]["out_tokens"] += policy_out_tokens

    # A decision recovered by the regex fallback may lack "choice"
    choice = decision.get("choice", "continue")

    print(f"\n{mas_id.upper()} Policy agent decision for ID {current_id}: {choice}")

    if choice == "end":
        print(f"{mas_id.upper()} Policy agent decided to end processing")
        # record the time when MAS is done (only once)
        stats = {"mas1": mas1_stats, "mas2": mas2_stats}.get(mas_id)
        if stats is not None and stats["total_time"] == 0:
            stats["total_time"] = time.time() - stats["start_time"]
        return data, current_idx, True, False

    # Validate a review request before acting on it
    previous_id = None
    previous_idx = None
    if choice == "review":
        previous_id = decision.get("previous_id")
        previous_idx = find_index_by_id(data, previous_id)

        if previous_idx is None:
            print(f"{mas_id.upper()}: Warning - Could not find ID {previous_id} for review. Continuing.")
            choice = "continue"
        elif data[previous_idx]["summary"] is None or data[previous_idx]["final_score"] is None:
            # Nothing to improve or compare against yet
            print(f"{mas_id.upper()}: Warning - ID {previous_id} has no summary to review yet. Continuing.")
            choice = "continue"
    elif choice != "continue":
        # Without this, an unexpected value would make no progress and the
        # policy would be queried again and again.
        print(f"{mas_id.upper()}: Warning - Unknown policy choice '{choice}'. Continuing.")
        choice = "continue"

    if choice == "review":
        print(f"{mas_id.upper()}: Reviewing summary for ID {previous_id}")

        passage = data[previous_idx]["passage"]
        previous_summary = data[previous_idx]["summary"]
        previous_score = data[previous_idx]["final_score"]

        # If initial_score is not set, set it
        if data[previous_idx]["initial_score"] is None:
            data[previous_idx]["initial_score"] = previous_score

        # Update review call count
        data[previous_idx]["review_calls"] += 1

        # Generate improved summary
        improved_summary, review_in_tokens, review_out_tokens = review_agent(
            passage, previous_summary, previous_score, mas_id
        )

        # Review tokens are booked on the reviewed question
        data[previous_idx]["in_tokens"] += review_in_tokens
        data[previous_idx]["out_tokens"] += review_out_tokens

        # Calculate new hallucination score
        new_score = calculate_hallucination_score_wrapper(passage, improved_summary)

        # Only update summary and score if new score is better
        if new_score > previous_score:
            data[previous_idx]["summary"] = improved_summary
            data[previous_idx]["final_score"] = new_score
            print(f"{mas_id.upper()}: Reviewed summary for ID {previous_id}. Old score: {previous_score:.4f}, New score: {new_score:.4f}")
        else:
            print(f"{mas_id.upper()}: Reviewed summary for ID {previous_id}. New score ({new_score:.4f}) is not better than old score ({previous_score:.4f}). Keeping original summary.")

        # The final True tells the caller a review happened, so it can update
        # the opponent and set just_reviewed for the next step
        return data, current_idx, False, True

    # choice == "continue": summary agent generates a new summary
    print(f"{mas_id.upper()}: Generating summary for ID {current_id}")

    passage = data[current_idx]["passage"]
    summary_text, summary_in_tokens, summary_out_tokens = summary_agent(passage, mas_id)

    # Update token counts and call counter for current question
    data[current_idx]["in_tokens"] += summary_in_tokens
    data[current_idx]["out_tokens"] += summary_out_tokens
    data[current_idx]["summary_calls"] += 1

    data[current_idx]["summary"] = summary_text

    # Calculate hallucination score
    score = calculate_hallucination_score_wrapper(passage, summary_text)

    # The initial score is set once; the final score always tracks the latest
    if data[current_idx]["initial_score"] is None:
        data[current_idx]["initial_score"] = score
    data[current_idx]["final_score"] = score

    print(f"{mas_id.upper()}: Generated summary for ID {current_id} with hallucination score: {score:.4f}")

    # Move to next text
    current_idx += 1

    return data, current_idx, False, False

# main loop - actually running two MAS systems in parallel
def run_mas(mas_id, data, current_idx, progress_bar, ref_info_queue, result_queue, just_reviewed=False):
    """
    run a single MAS system until completion

    Each iteration picks up at most one message from the opponent, executes
    one step via process_mas_step, refreshes the progress bar and saves the
    data. After a review the latest statistics are sent to the opponent. When
    the MAS finishes, its total_time is recorded (if not already set), "DONE"
    is sent to the opponent and the progress bar is closed.

    Args:
        mas_id: MAS system ID ('mas1' or 'mas2')
        data: data for the MAS system
        current_idx: current index being processed
        progress_bar: progress bar object
        ref_info_queue: queue for receiving opponent information
        result_queue: queue for sending result information
        just_reviewed: whether a review was just performed

    Returns:
        The updated data list.
    """
    # Imported locally: the module itself only imports `Queue` from `queue`.
    from queue import Empty, Full

    output_file = mas1_file if mas_id == "mas1" else mas2_file

    mas_done = False

    # latest opponent information received so far
    opponent_ref_info = None

    while not mas_done:
        # check if there is new opponent information in the queue (non-blocking)
        try:
            queue_item = ref_info_queue.get_nowait()
        except Empty:
            # nothing new from the opponent, keep the previous information
            pass
        else:
            # check if it's the special "DONE" signal
            if queue_item == "DONE":
                # the other MAS is done, but we continue running
                print(f"{mas_id.upper()}: The opponent is done. We continue running.")
            else:
                # update opponent information
                opponent_ref_info = queue_item

        # execute one step of MAS
        data, current_idx, mas_done, review_occurred = process_mas_step(
            data, current_idx, mas_id, opponent_ref_info, just_reviewed
        )

        # update the progress bar
        progress_bar.n = sum(1 for item in data if item["summary"] is not None)
        progress_bar.refresh()

        # save the updated data
        save_data(data, output_file)

        # if a review was performed, send the latest status to the other MAS
        if review_occurred:
            latest_stats = calculate_reference_stats(data)
            try:
                result_queue.put_nowait(latest_stats)
            except Full:
                # the queue is full, drop this update
                pass

        # update the review flag
        just_reviewed = review_occurred

    # MAS done, record the completion time
    if mas_id == "mas1" and mas1_stats["total_time"] == 0:
        mas1_stats["total_time"] = time.time() - mas1_stats["start_time"]
    elif mas_id == "mas2" and mas2_stats["total_time"] == 0:
        mas2_stats["total_time"] = time.time() - mas2_stats["start_time"]

    # send the completion signal to the queue
    try:
        result_queue.put_nowait("DONE")
    except Full:
        pass

    progress_bar.close()

    return data

# setup a log file to record the function output
def setup_logging(filename):
    """Mirror everything written to sys.stdout into an in-memory buffer.

    sys.stdout is replaced by a tee object that forwards each write to the
    previous stdout and to a StringIO. *filename* is not used here; nothing
    is written to disk by this function.

    Returns:
        ``(log_capture, original_stdout)``: the StringIO collecting the
        output and the stdout object that was active before the call, so the
        caller can restore it.
    """
    console = sys.stdout
    captured = StringIO()

    class TeeOutput:
        """Write-only stream duplicating output to the console and the buffer."""

        def write(self, message):
            # console first, then the capture buffer
            for sink in (console, captured):
                sink.write(message)

        def flush(self):
            # only the console needs flushing; the buffer lives in memory
            console.flush()

    sys.stdout = TeeOutput()

    return captured, console

def _mean_or_zero(values):
    """Return the arithmetic mean of ``values``, or 0 when it is empty."""
    return sum(values) / len(values) if values else 0


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


def _load_or_init_mas_data(label, path, dataset, reverse):
    """Load saved MAS data from ``path`` or build and save a fresh structure."""
    if os.path.exists(path):
        print(f"Loading existing data for {label} from {path}")
        return load_data(path)

    print(f"Initializing {label} data structure...")
    data = initialize_data(dataset, specific_ids, reverse=reverse)
    save_data(data, path)
    return data


def _summarize_mas(data, stats):
    """Collect the derived per-MAS numbers used by the report and the score."""
    final_scores = [item["final_score"] for item in data if item["final_score"] is not None]
    initial_scores = [item["initial_score"] for item in data if item["initial_score"] is not None]
    return {
        "processed": len(final_scores),
        "avg_final": _mean_or_zero(final_scores),
        "avg_initial": _mean_or_zero(initial_scores),
        "tokens": stats['input_tokens'] + stats['output_tokens'],
        "review_total": sum(item['review_calls'] for item in data),
    }


def _print_mas_report(heading, data, stats, summary):
    """Print the per-MAS section of the final results."""
    print(f"\n{heading}:")
    print(f"- Questions processed: {summary['processed']}/{len(data)}")
    print(f"- Initial average hallucination score: {summary['avg_initial']:.4f}")
    print(f"- Final average hallucination score: {summary['avg_final']:.4f}")
    print(f"- Total API calls: {stats['api_calls']}")
    print(f"- Total policy agent API calls: {stats['policy_calls']}")
    print(f"- Total summary agent API calls: {stats['summary_calls']}")
    print(f"- Total review agent API calls: {stats['review_calls']}")
    print(f"- Total in tokens: {stats['input_tokens']}")
    print(f"- Total out tokens: {stats['output_tokens']}")
    print(f"- Total tokens: {summary['tokens']}")
    print(f"- Total policy tokens: {stats['policy_in_tokens'] + stats['policy_out_tokens']}")
    print(f"- Total summary tokens: {stats['summary_in_tokens'] + stats['summary_out_tokens']}")
    print(f"- Total review tokens: {stats['review_in_tokens'] + stats['review_out_tokens']}")
    print(f"- Total time: {stats['total_time']:.2f} seconds")
    print(f"- Total review calls: {summary['review_total']}")


def _efficiency_penalty(stats, summary, maxima):
    """Sum each cost of one MAS divided by the larger value of the two MAS."""
    return (
        _safe_ratio(summary["tokens"], maxima["tokens"])
        + _safe_ratio(stats["total_time"], maxima["time"])
        + _safe_ratio(stats["api_calls"], maxima["api_calls"])
        + _safe_ratio(summary["review_total"], maxima["review_calls"])
    )


def _print_score_breakdown(name, normalized, alpha, beta, stats, summary, maxima):
    """Print the normalized score together with the terms it was built from."""
    print(
        f"{name} normalized = {normalized:.10f} = {alpha} * {summary['avg_final']} - {beta} * "
        f"({summary['tokens']} / {maxima['tokens']} + {stats['total_time']} / {maxima['time']} + "
        f"{stats['api_calls']} / {maxima['api_calls']} + {summary['review_total']} / {maxima['review_calls']})"
    )


def main():
    """Run MAS1 and MAS2 concurrently, then print and log the comparison.

    Steps: tee stdout into a buffer, load the leaderboard dataset and keep
    the Cohere-Chat rows, load or initialize each MAS data file, run both
    MAS in a two-worker thread pool connected by two queues, print the
    per-MAS statistics and the normalized competition score, and finally
    write the captured output to ``model_log_filename``.

    The stdout restoration and log write happen in a ``finally`` clause so
    the log is kept even when the run fails part-way.
    """
    log_capture, original_stdout = setup_logging(model_log_filename)

    try:
        # Load the dataset
        print("Loading dataset...")
        ds = load_dataset("vectara/leaderboard_results")

        # Filter for Cohere-Chat model
        cohere_data = ds['train'].filter(lambda x: x['model'] == 'cohere/Cohere-Chat')

        if test_num != 'all':
            cohere_data = cohere_data.select(range(min(test_num, len(cohere_data))))

        print(f"Found {len(cohere_data)} entries for Cohere-Chat model")

        # MAS1 is initialized with reverse=False, MAS2 with reverse=True.
        mas1_data = _load_or_init_mas_data("MAS1", mas1_file, cohere_data, reverse=False)
        mas2_data = _load_or_init_mas_data("MAS2", mas2_file, cohere_data, reverse=True)

        # Start timers
        mas1_stats["start_time"] = time.time()
        mas2_stats["start_time"] = time.time()

        # Create progress bars
        mas1_progress = tqdm(total=len(mas1_data), desc="MAS1 Processing", unit="question", position=0)
        mas2_progress = tqdm(total=len(mas2_data), desc="MAS2 Processing", unit="question", position=1)

        # One bounded queue per direction for MAS-to-MAS communication.
        mas1_to_mas2_queue = Queue(maxsize=100)  # MAS1 sends information to MAS2
        mas2_to_mas1_queue = Queue(maxsize=100)  # MAS2 sends information to MAS1

        # Run the two MAS systems in parallel, each starting at index 0 with
        # the "just reviewed" flag cleared.
        with ThreadPoolExecutor(max_workers=2) as executor:
            mas1_future = executor.submit(
                run_mas, "mas1", mas1_data, 0, mas1_progress,
                mas2_to_mas1_queue, mas1_to_mas2_queue, False
            )
            mas2_future = executor.submit(
                run_mas, "mas2", mas2_data, 0, mas2_progress,
                mas1_to_mas2_queue, mas2_to_mas1_queue, False
            )

            # wait for both MAS to complete (re-raises any worker exception)
            mas1_data = mas1_future.result()
            mas2_data = mas2_future.result()

        # Close progress bars
        mas1_progress.close()
        mas2_progress.close()

        # Only fill in total_time for a MAS that has not recorded it yet.
        if mas1_stats["total_time"] == 0:
            mas1_stats["total_time"] = time.time() - mas1_stats["start_time"]
        if mas2_stats["total_time"] == 0:
            mas2_stats["total_time"] = time.time() - mas2_stats["start_time"]

        mas1_summary = _summarize_mas(mas1_data, mas1_stats)
        mas2_summary = _summarize_mas(mas2_data, mas2_stats)

        # Print final results
        print("\n" + "="*50)
        print("FINAL RESULTS:")
        print("="*50)

        _print_mas_report("MAS1 (Forward Order)", mas1_data, mas1_stats, mas1_summary)
        _print_mas_report("MAS2 (Reverse Order)", mas2_data, mas2_stats, mas2_summary)

        # Determine winner
        print("\nCOMPETITION RESULTS:")

        # Normalized score (higher is better):
        #   alpha * avg_final_score - beta * penalty
        # where penalty adds tokens, time, API calls and review calls, each
        # divided by the larger of the two MAS values. A maximum of 0 means
        # both MAS have 0 for that cost, so the term contributes 0 instead
        # of raising ZeroDivisionError.
        alpha = 1
        beta = 0.1

        maxima = {
            "tokens": max(mas1_summary["tokens"], mas2_summary["tokens"]),
            "time": max(mas1_stats['total_time'], mas2_stats['total_time']),
            "api_calls": max(mas1_stats['api_calls'], mas2_stats['api_calls']),
            "review_calls": max(mas1_summary["review_total"], mas2_summary["review_total"]),
        }

        mas1_penalty = _efficiency_penalty(mas1_stats, mas1_summary, maxima)
        mas2_penalty = _efficiency_penalty(mas2_stats, mas2_summary, maxima)

        mas1_normalized = alpha * mas1_summary["avg_final"] - beta * mas1_penalty
        mas2_normalized = alpha * mas2_summary["avg_final"] - beta * mas2_penalty

        _print_score_breakdown("mas1", mas1_normalized, alpha, beta, mas1_stats, mas1_summary, maxima)
        _print_score_breakdown("mas2", mas2_normalized, alpha, beta, mas2_stats, mas2_summary, maxima)

        print(f"MAS1 Normalized Score: {mas1_normalized:.10f}")
        print(f"MAS2 Normalized Score: {mas2_normalized:.10f}")

        if mas1_normalized > mas2_normalized:
            print("WINNER: MAS1 (Forward Order)")
        elif mas2_normalized > mas1_normalized:
            print("WINNER: MAS2 (Reverse Order)")
        else:
            print("RESULT: TIE")

        print("\nDetailed results saved to:")
        print(f"- MAS1: {mas1_file}")
        print(f"- MAS2: {mas2_file}")
    finally:
        # Always restore stdout and persist whatever was captured, even if
        # the run above raised.
        sys.stdout = original_stdout
        with open(model_log_filename, 'w', encoding='utf-8') as log_file:
            log_file.write(log_capture.getvalue())

        print(f"\nLog saved to: {model_log_filename}")

# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()