import asyncio
import pandas as pd
import time
import os
import sys
import numpy as np  # Add numpy for statistical calculations

# Add the parent of this script's directory to sys.path so that the `src`
# package imported below can be resolved. The path is derived from __file__,
# so it does not depend on the current working directory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.utils import llm, prompts, string_utils

# Define the configurations to test
# Model identifiers that appear in more than one configuration.
_LLAMA_3_1_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
_GPT_4_1_NANO = "gpt-4.1-nano-2025-04-14"
_GPT_4O_MINI = "gpt-4o-mini"

# Each entry is a (model name, PROMPT_DICT key) pair. main() walks the list in
# order and skips any entry whose prompt key is missing from PROMPT_DICT.
CONFIGURATIONS = [
    # Baseline prompt on the three small models.
    (_LLAMA_3_1_8B, "BINARY_BASELINE"),
    (_GPT_4_1_NANO, "BINARY_BASELINE"),
    (_GPT_4O_MINI, "BINARY_BASELINE"),
    # Intention-evaluation prompt variants on the same three models.
    (_LLAMA_3_1_8B, "BINARY_INTENTION_EVAL_W_8_AGENT_ICL"),
    (_GPT_4_1_NANO, "BINARY_INTENTION_EVAL_COT"),
    (_GPT_4O_MINI, "BINARY_INTENTION_EVAL_W_8_DQ_ICL"),
    # Baseline prompt on two further models.
    ("o3-mini", "BINARY_BASELINE"),
    ("gpt-4o", "BINARY_BASELINE"),
]

# --- Reused Code from 1_to_k_binary_eval.py ---

# Attribute names of the prompt templates looked up on the `prompts` module.
# Every PROMPT_DICT key is identical to the attribute name it maps to, so the
# names are listed once and the mapping is derived from them.
_PROMPT_NAMES = (
    "BINARY_BASELINE",
    "BINARY_INTENTION_EVAL_ORIGINAL",
    "BINARY_INTENTION_EVAL_W_9_MIXED_ICL",
    "BINARY_INTENTION_EVAL_W_8_AGENT_ICL",
    "BINARY_INTENTION_EVAL_W_8_IMAGE_ICL",
    "BINARY_INTENTION_EVAL_W_8_DQ_ICL",
    "BINARY_INTENTION_EVAL_COT",
    "BINARY_INTENTION_W_SAFETY_GUIDELINE_EVAL",
    "HYPOTHESIS_GENERATION_PROMPT",
    "ADVANCED_HYPOTHESIS_GENERATION_PROMPT",
    "HYPOTHESIS_GENERATION_DQ_ICL_PROMPT",
    "HYPOTHESIS_GENERATION_IMAGE_ICL_PROMPT",
    "HYPOTHESIS_GENERATION_ICL_PROMPT",
    "NEW_HYPOTHESIS_GENERATION_PROMPT",
)

# Prompt key -> prompt template, in the order listed above.
PROMPT_DICT = {name: getattr(prompts, name) for name in _PROMPT_NAMES}

def split_subtasks(task_string):
    """Split multi-line subtask text into a list of cleaned lines.

    The text is first handed to ``string_utils.extract_list_from_code``
    (presumably it pulls a list out of code-formatted text -- confirm against
    string_utils). If that attempt raises, the function falls back to
    splitting ``task_string`` on the line-separator literal.

    In both cases each entry is stripped of surrounding whitespace, has its
    single quotes removed, and is dropped if it is blank after stripping.

    Args:
        task_string: The raw subtask text of one task.

    Returns:
        A list of non-empty, cleaned subtask strings (possibly empty).
    """
    def _clean(raw_lines):
        # Shared post-processing for both the parsed and the fallback path.
        return [line.strip().replace("'", '') for line in raw_lines if line.strip()]

    try:
        return _clean(string_utils.extract_list_from_code(task_string))
    except Exception:
        # Best-effort fallback for text the helper cannot parse. Catching
        # Exception rather than using a bare ``except`` lets KeyboardInterrupt,
        # SystemExit and asyncio cancellation propagate instead of being
        # silently swallowed here.
        return _clean(task_string.split('\\n'))

async def run_with_semaphore(sem: asyncio.Semaphore, coro):
    async with sem:
        return await coro

# --- Latency Measurement Function ---

async def _measure_task_calls(task, model_name, prompt_template):
    """
    Issue one LLM call per cumulative subtask prefix of a task and time each call.

    The task text is split with split_subtasks(); call i is prompted with the
    first i subtasks joined together and substituted into the ``subtasks``
    placeholder of prompt_template. Calls are awaited one after another with
    temperature 0.0.

    Returns a list of (latency_seconds, cost) tuples, one per call in call order
    (empty when the task has no subtasks), or None as soon as any call fails.
    """
    lines = split_subtasks(task['task_text'])
    combined = []
    per_call = []

    for i, line in enumerate(lines, start=1):
        combined.append(line)
        partial_text = "\\n".join(combined)
        prompt_text = prompt_template.format(subtasks=partial_text)
        prompt_start_time = time.monotonic()
        try:
            response = await llm.process_prompts([prompt_text], model_name=model_name, temperature=0.0)
            prompt_end_time = time.monotonic()
            # response[0] unpacks into two items; the second one is the cost.
            _, cost = response[0]
            # Convert inside the try so that a missing or non-numeric cost
            # fails the task here instead of breaking the aggregation later.
            per_call.append((prompt_end_time - prompt_start_time, float(cost)))
        except Exception as e:
            print(f"Error processing task {task.get('id', 'N/A')} subtask {i} with {model_name}: {e}. Skipping task.")
            return None  # One failed call invalidates the whole task

    return per_call


async def measure_task_prompt_latency(task, model_name, prompt_template):
    """
    Processes a task iteratively, measures the time and cost for each LLM call,
    and returns the total time, total cost, and number of successful calls for the task.
    Returns None if any LLM call fails.
    """
    per_call = await _measure_task_calls(task, model_name, prompt_template)
    if per_call is None:
        return None

    total_time_for_task = sum((latency for latency, _ in per_call), 0.0)
    total_cost_for_task = sum((cost for _, cost in per_call), 0.0)
    return total_time_for_task, total_cost_for_task, len(per_call)


# --- Main Evaluation Logic ---

def _standard_error(values):
    """Return the standard error of the mean of values (0.0 with fewer than two values)."""
    if len(values) < 2:
        return 0.0
    return float(np.std(values, ddof=1) / np.sqrt(len(values)))


def _load_tasks(data_path, num_samples, random_seed):
    """
    Read the JSON data file and return the tasks to evaluate.

    At most num_samples rows are kept, sampled with random_seed when the file
    holds more. Each task is a dict with 'id' (or 'N/A' when the row has none)
    and 'task_text' (the row's 'decomposition' field).

    Returns None, after printing the reason, when the file is missing,
    unreadable, or empty.
    """
    print(f"Loading data from: {data_path}")
    if not os.path.exists(data_path):
        print(f"Error: Data file not found at {data_path}")
        return None

    try:
        data_df = pd.read_json(data_path)
    except Exception as e:
        print(f"Error reading data file: {e}")
        return None

    # Randomly sample tasks if the dataset is large enough
    if len(data_df) > num_samples:
        print(f"Randomly sampling {num_samples} tasks from {len(data_df)} total tasks.")
        data_df = data_df.sample(n=num_samples, random_state=random_seed)
    elif len(data_df) > 0:
        print(f"Using all {len(data_df)} tasks (less than or equal to {num_samples}).")
    else:
        print("Error: No tasks found in the data file.")
        return None

    return [
        {'id': row.get('id', 'N/A'), 'task_text': row['decomposition']}
        for _, row in data_df.iterrows()
    ]


async def _evaluate_configuration(tasks, model_name, prompt_key, prompt_template, sem):
    """
    Run every task against one (model, prompt) configuration and summarize it.

    Tasks run concurrently, each one holding a slot of sem for its whole
    duration. Prints the per-configuration statistics and returns them as the
    summary row (a dict) for the results table.
    """
    config_start_time = time.monotonic()
    print(f"\\n--- Evaluating: Model='{model_name}', Prompt='{prompt_key}' ---")

    # One entry per task: a list of (latency, cost) per call, or None on failure.
    results_per_task = await asyncio.gather(*(
        run_with_semaphore(sem, _measure_task_calls(task, model_name, prompt_template))
        for task in tasks
    ))

    # Keep the real per-call samples. Repeating each task's average once per
    # call would hide the within-task variance and understate the standard error.
    all_prompt_latencies = []
    all_prompt_costs = []
    num_successful_tasks = 0
    num_failed_tasks = 0

    for per_call in results_per_task:
        if per_call is None:
            num_failed_tasks += 1
            continue
        num_successful_tasks += 1
        for latency, cost in per_call:
            all_prompt_latencies.append(latency)
            all_prompt_costs.append(cost)

    total_calls_config = len(all_prompt_latencies)
    total_time_config = sum(all_prompt_latencies, 0.0)
    total_cost_config = sum(all_prompt_costs, 0.0)

    average_prompt_latency = 0.0
    average_prompt_cost = 0.0
    latency_stderr = 0.0
    cost_stderr = 0.0

    if total_calls_config > 0:
        average_prompt_latency = total_time_config / total_calls_config
        average_prompt_cost = total_cost_config / total_calls_config
        latency_stderr = _standard_error(all_prompt_latencies)
        cost_stderr = _standard_error(all_prompt_costs)

        print(f"Average latency per prompt call: {average_prompt_latency:.4f} ± {latency_stderr:.4f} seconds")
        print(f"Average cost per prompt call: ${average_prompt_cost:.6f} ± ${cost_stderr:.6f}")
        print(f"Total successful prompt calls: {total_calls_config}")
        print(f"Total time for successful calls: {total_time_config:.2f} seconds")
        print(f"Total cost for config: ${total_cost_config:.4f}")
    else:
        print("No successful prompt calls for this configuration.")

    print(f"Successful tasks: {num_successful_tasks}")
    if num_failed_tasks > 0:
        print(f"Failed tasks (due to errors): {num_failed_tasks}")

    summary = {
        "Model": model_name,
        "Prompt": prompt_key,
        "Avg Latency per Prompt (s)": average_prompt_latency,
        "Latency StdErr (s)": latency_stderr,
        "Avg Cost per Prompt ($)": average_prompt_cost,
        "Cost StdErr ($)": cost_stderr,
        "Total Prompt Calls": total_calls_config,
        "Total Config Cost ($)": total_cost_config,
        "Total Tasks Evaluated": len(tasks),
        "Total Time took to evaluate all tasks": total_time_config,
        "Successful Tasks": num_successful_tasks,
        "Failed Tasks": num_failed_tasks
    }

    config_end_time = time.monotonic()
    print(f"Time for this configuration: {config_end_time - config_start_time:.2f} seconds")
    return summary


async def main():
    """
    Benchmark every entry of CONFIGURATIONS on a sample of the test data.

    Loads the tasks, evaluates each (model, prompt) configuration in turn,
    prints a summary table and writes it to latency_analysis_results.csv in
    the current working directory. Returns early, without writing anything,
    when the data cannot be loaded.
    """
    script_start_time = time.monotonic()
    DATA_PATH = "../data/combined/test_data.json"  # Relative to the current working directory
    MAX_CONCURRENT = 50  # Limit concurrency to avoid overwhelming APIs
    NUM_SAMPLES = 50  # Number of tasks to randomly sample
    RANDOM_SEED = 42  # For reproducible sampling

    tasks = _load_tasks(DATA_PATH, NUM_SAMPLES, RANDOM_SEED)
    if tasks is None:
        return

    print(f"Loaded {len(tasks)} tasks to evaluate.")

    all_results_summary = []  # One summary row per evaluated configuration
    sem = asyncio.Semaphore(MAX_CONCURRENT)

    for model_name, prompt_key in CONFIGURATIONS:
        if prompt_key not in PROMPT_DICT:
            print(f"Warning: Prompt key '{prompt_key}' not found in PROMPT_DICT. Skipping configuration.")
            continue

        summary = await _evaluate_configuration(
            tasks, model_name, prompt_key, PROMPT_DICT[prompt_key], sem
        )
        all_results_summary.append(summary)

    print("\\n--- Summary ---")
    results_df = pd.DataFrame(all_results_summary)
    print(results_df.to_string(index=False))

    # Save results to CSV
    results_df.to_csv("latency_analysis_results.csv", index=False)
    print("\\nResults saved to latency_analysis_results.csv")

    script_end_time = time.monotonic()
    print(f"\\nTotal script execution time: {script_end_time - script_start_time:.2f} seconds")


# Run the async entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())
