import json
import re
import string
import logging
import concurrent.futures
import threading
import argparse
import os
import traceback
from openai import OpenAI
from tqdm import tqdm
from collections import Counter
import csv
from datetime import datetime
from pathlib import Path
import random

from ia_rag.pipeline.interact_workflow import WorkflowInteract
from ia_rag.pipeline.e2e_interact import AgenticInteract

from ia_rag.utils.log_helper import calculate_f1_score, extract_answer_tag, normalize_text


# Default for the --api-key option; "xxx" is a placeholder value.
DEFAULT_API_KEY = "xxx"
# Default for the --agent-base-url option (passed to the agent as llm_api_base).
DEFAULT_AGENT_BASE_URL = "http://localhost:8001/v1"

# File names that main() looks for inside the dataset directory; files that
# are not present there are skipped with a warning.
DEFAULT_TEST_FILE_LIST = [
    "2wiki_test.jsonl",
    "musique_test.jsonl",
    "popqa_test.jsonl",
    "hotpot_test.jsonl",
    "nq_test.jsonl",
    "bamboogle_test.jsonl"
]

# Parent of the directory that contains this file.
ROOT_DIR = Path(__file__).resolve().parents[1]
# Defaults for --output-dir and --dataset-dir, both located under ROOT_DIR.
DEFAULT_OUTPUT_DIR = os.path.join(ROOT_DIR, "res")
DEFAULT_DATASET_DIR = os.path.join(ROOT_DIR, "data/test")
# Default for --num-samples: how many evaluation runs are made per dataset.
DEFAULT_SAMPLE_NUMS = 3

# Default for --max-workers: size of the thread pool used by run_evaluation().
DEFAULT_MAX_WORKERS = 24
# Default for --strategy. NOTE: "xxx" is a placeholder and is not a key of
# AGENT_STRATEGIES, so run_evaluation() rejects it unless --strategy is given.
DEFAULT_STRATEGY = "xxx"
# Maps a strategy name to the agent class instantiated for it.
AGENT_STRATEGIES = {
    "e2e_interact": AgenticInteract,
    "workflow_interact": WorkflowInteract
}

# Model name used for each strategy when --agent-model-name is not supplied
# (placeholder values).
MODEL_NAME_DICT = {
    "e2e_interact": "xxx",
    "workflow_interact": "xxx"
}


# Column names of the summary CSV file, in output order. DictWriter also uses
# this list as the set of keys a summary row is allowed to contain.
CSV_HEADER = [
    "strategy", "model_name", "dataset_name", "record_time",
    "em_accuracy", "avg_f1_score", "avg_iterations", "total_questions"
]


def append_summary_to_csv(summary_data: dict, csv_filepath: str, lock: threading.Lock):
    """
    Appends a summary dictionary to the specified CSV file in a thread-safe manner.

    The header row is written first whenever the file does not exist yet or
    exists but is empty (for example a pre-created file, or one left behind by
    a failed earlier write), so the file always starts with the column names.

    Args:
        summary_data: Row to append; its keys must be a subset of CSV_HEADER
            (csv.DictWriter raises ValueError for unknown keys).
        csv_filepath: Path of the CSV file to append to.
        lock: Lock held for the whole check-and-write so concurrent callers
            cannot interleave rows or emit the header twice.

    File-system errors are logged and swallowed so that a failed summary write
    does not abort the evaluation.
    """
    with lock:
        try:
            # An existing zero-byte file still needs a header.
            needs_header = (not os.path.isfile(csv_filepath)
                            or os.path.getsize(csv_filepath) == 0)
            with open(csv_filepath, 'a', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=CSV_HEADER)
                if needs_header:
                    writer.writeheader()
                writer.writerow(summary_data)
        except OSError as e:
            logging.error(
                f"Could not write to summary CSV file {csv_filepath}: {e}")


# Set the root logger threshold to WARNING at import time: error reports are
# shown, while lower-severity records (such as the logging.info() call in
# run_evaluation) are filtered out. basicConfig has no effect if the root
# logger already has handlers configured.
logging.basicConfig(level=logging.WARNING)


def process_item(item: dict, agent_instance: object, jsonl_writer_lock: threading.Lock, output_file: str, strategy_name: str) -> tuple[bool, float, int]:
    """
    Run the agent on one dataset item, score its answer and log the outcome.

    The prediction is taken from the agent result (the last assistant message
    with non-empty content when a message list is returned, or the string
    itself), passed through ``extract_answer_tag``, normalized, and compared
    with the normalized golden answers. One JSON line describing the outcome
    is appended to ``output_file``.

    Args:
        item: Dataset record; the "question", "golden_answers" and "id" keys
            are read.
        agent_instance: Object whose ``run(question=...)`` returns
            ``(result, iterations)``, where ``result`` is a str or a list of
            message dicts.
        jsonl_writer_lock: Lock serializing appends to ``output_file``.
        output_file: Path of the JSON Lines log file to append to.
        strategy_name: Name of the evaluated strategy (not used here).

    Returns:
        ``(is_em, f1_score, iterations)``. Items without a question or without
        golden answers return ``(False, 0.0, 0)`` and are not logged. Any
        exception raised by the agent or during scoring is caught, logged and
        recorded as an "ERROR!" entry with zero scores.
    """
    question = item.get("question")
    golden_answers = item.get("golden_answers", [])

    is_em = False
    f1_score = 0.0
    iterations = 0

    # Nothing to evaluate without both a question and reference answers.
    if not question or not golden_answers:
        return is_em, f1_score, iterations

    # Empty/None answers are dropped before normalization.
    normalized_golden_answers = [normalize_text(
        ans) for ans in golden_answers if ans]

    try:
        agent_result, iterations = agent_instance.run(question=question)

        error_message = ""

        if isinstance(agent_result, list):
            # Conversation transcript: use the most recent assistant message
            # that actually carries content.
            predicted_answer_raw = ""
            for msg in reversed(agent_result):
                if msg.get('role') == 'assistant' and msg.get('content'):
                    predicted_answer_raw = msg.get('content', '').strip()
                    break
        elif isinstance(agent_result, str):
            predicted_answer_raw = agent_result
        else:
            raise TypeError(
                f"Agent returned an unexpected type: {type(agent_result)}")

        predicted_answer_tagged = extract_answer_tag(predicted_answer_raw)

        if predicted_answer_tagged:
            normalized_prediction = normalize_text(predicted_answer_tagged)
            if normalized_prediction in normalized_golden_answers:
                is_em = True
            # The best F1 over all reference answers is the item's score.
            f1_scores = [calculate_f1_score(
                normalized_prediction, gold_ans) for gold_ans in normalized_golden_answers]
            f1_score = max(f1_scores) if f1_scores else 0.0
        else:
            # No extractable answer: scores stay at their zero defaults.
            normalized_prediction = ""
            predicted_answer_tagged = ""

    except Exception as e:
        # Boundary handler: one failing item must not abort the whole run, so
        # the failure is reported and logged as a zero-score record.
        logging.error(
            f"An error occurred while processing Q_ID '{item.get('id')}': {e}")
        traceback.print_exc()
        predicted_answer_tagged = "ERROR!"
        normalized_prediction = ""
        error_message = str(e)
        agent_result = "ERROR!"
        is_em = False
        f1_score = 0.0

    log_entry = {
        "q_id": item.get('id'),
        "question": question,
        "normalized_golden_answers": normalized_golden_answers,
        "normalized_prediction": normalized_prediction,
        "is_em": is_em,
        "f1_score": f1_score,
        "iterations": iterations,
        "error_msg": error_message,
        "agent_output": agent_result,
        "predicted_answer_tagged": predicted_answer_tagged,
    }

    # Serialize before taking the lock so it is held only for the file write.
    serialized_entry = json.dumps(log_entry, ensure_ascii=False) + '\n'
    with jsonl_writer_lock:
        with open(output_file, 'a', encoding='utf-8') as f:
            f.write(serialized_entry)

    return is_em, f1_score, iterations


def _load_reusable_results(results_path: str) -> list:
    """
    Read previously logged result records from a JSON Lines file.

    Blank lines are ignored. Lines that are not valid JSON objects (for
    example a record truncated by an interrupted run) are dropped instead of
    invalidating the whole file. When lines were dropped, or the file does not
    end with a newline, the file is rewritten with only the valid lines so
    that later appends start on a clean line boundary.

    Returns:
        The list of valid result dicts (empty if the file does not exist).

    Raises:
        OSError: If the file cannot be read or rewritten.
    """
    if not os.path.exists(results_path):
        return []

    results = []
    kept_lines = []
    dropped = 0
    last_raw_line = ""
    with open(results_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            last_raw_line = raw_line
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError:
                dropped += 1
                continue
            if not isinstance(record, dict):
                dropped += 1
                continue
            results.append(record)
            kept_lines.append(stripped)

    unterminated = bool(last_raw_line) and not last_raw_line.endswith('\n')
    if dropped or unterminated:
        if dropped:
            print(
                f"Warning: Ignored {dropped} malformed line(s) in '{results_path}'.")
        # Write to a sibling file first so the original is replaced in one step.
        tmp_path = results_path + ".tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.writelines(line + '\n' for line in kept_lines)
        os.replace(tmp_path, results_path)

    return results


def run_evaluation(args, sample_index: int, dataset_path: str, csv_filepath: str, csv_writer_lock: threading.Lock):
    """
    Evaluate one strategy on one dataset for one sample run.

    Loads the dataset, optionally reuses results already logged for this run,
    processes the remaining questions concurrently with ``process_item``,
    prints a summary and appends it to the summary CSV.

    Per-question results go to
    ``<args.output_dir>/<dataset name up to its first '_'>/p<sample_index>_<strategy>_<model>.jsonl``.

    Args:
        args: Parsed command-line options; reads ``strategy``,
            ``agent_model_name``, ``output_dir``, ``reuse``,
            ``agent_base_url``, ``api_key`` and ``max_workers``.
        sample_index: Number of this run, used in the output file name.
        dataset_path: Path of the JSON Lines dataset file.
        csv_filepath: Path of the summary CSV file.
        csv_writer_lock: Lock guarding writes to the summary CSV.

    Returns:
        None. The function returns early, without writing a summary, if the
        strategy is unknown, the dataset cannot be loaded, or the agent cannot
        be initialized.
    """
    strategy_name = args.strategy
    if strategy_name not in AGENT_STRATEGIES:
        print(
            f"FATAL: Unknown strategy '{strategy_name}'. Available: {list(AGENT_STRATEGIES.keys())}")
        return

    agent_model_name = args.agent_model_name if args.agent_model_name else MODEL_NAME_DICT[
        strategy_name]

    agent_class = AGENT_STRATEGIES[strategy_name]
    # Only the last path component of the model name is used in file names.
    model_name_short = agent_model_name.split('/')[-1]

    dataset_name = os.path.splitext(os.path.basename(dataset_path))[0]
    output_dir_for_dataset = os.path.join(
        args.output_dir, dataset_name.split('_')[0])
    os.makedirs(output_dir_for_dataset, exist_ok=True)

    jsonl_output_file = os.path.join(output_dir_for_dataset,
                                     f"p{sample_index}_{strategy_name}_{model_name_short}.jsonl")

    try:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            # Blank lines (such as a trailing newline) are not records.
            dataset = [json.loads(line) for line in f if line.strip()]
    except Exception as e:
        print(f"FATAL: Failed to load dataset from {dataset_path}: {e}")
        return

    print(
        f"Successfully loaded {len(dataset)} questions from {dataset_path}")

    tasks_to_run = list(dataset)
    existing_results = []

    if args.reuse:
        print(
            f"Reuse mode is ON. Checking for existing results in {jsonl_output_file}...")
        try:
            existing_results = _load_reusable_results(jsonl_output_file)
        except OSError as e:
            print(
                f"Warning: Could not read or parse output file '{jsonl_output_file}'. Starting fresh for this run. Error: {e}")
            existing_results = []

        # Results for questions that are not in this dataset must not count
        # towards averages that are taken over len(dataset).
        dataset_ids = {item.get('id') for item in dataset}
        existing_results = [
            res for res in existing_results if res.get('q_id') in dataset_ids]

        if existing_results:
            processed_q_ids = {res.get('q_id') for res in existing_results}
            tasks_to_run = [item for item in dataset if item.get(
                'id') not in processed_q_ids]
            print(
                f"Found {len(processed_q_ids)} completed items. Reusing results and skipping them.")
            print(f"Will process {len(tasks_to_run)} new items.")

    else:
        if os.path.exists(jsonl_output_file):
            # Truncate the previous log so this run starts from an empty file.
            with open(jsonl_output_file, 'w', encoding='utf-8'):
                pass
        print("Reuse mode is OFF. Starting a fresh run.")

    if not tasks_to_run and existing_results:
        print("All questions have already been processed. Calculating stats from existing file.")

    logging.info(
        f"Initializing agent with strategy: '{strategy_name}' using model '{agent_model_name}'")
    try:
        agent_instance = agent_class(
            llm_api_base=args.agent_base_url,
            model_name=agent_model_name,
            api_key=args.api_key
        )
    except Exception as e:
        logging.error(f"FATAL: Failed to initialize agent: {e}")
        traceback.print_exc()
        return

    print(
        f"Results for sample run {sample_index} on {dataset_name} will be saved to {jsonl_output_file}")
    print(
        f"Running evaluation with up to {args.max_workers} concurrent workers...")

    # Seed the running totals with the reused results.
    exact_matches = sum(1 for res in existing_results if res.get('is_em'))
    total_f1 = sum(res.get('f1_score', 0.0) for res in existing_results)
    total_iterations = sum(res.get('iterations', 0)
                           for res in existing_results)

    total_questions = len(dataset)
    jsonl_writer_lock = threading.Lock()

    if tasks_to_run:
        with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
            future_to_item = {
                executor.submit(process_item, item, agent_instance, jsonl_writer_lock, jsonl_output_file, strategy_name): item
                for item in tasks_to_run
            }
            progress_bar = tqdm(concurrent.futures.as_completed(
                future_to_item), total=len(tasks_to_run), desc=f"Evaluating ({dataset_name} - Run {sample_index})")

            for future in progress_bar:
                try:
                    is_em, f1, iterations = future.result()
                    if is_em:
                        exact_matches += 1
                    total_f1 += f1
                    total_iterations += iterations
                except Exception as e:
                    # A failed future contributes nothing to the totals but
                    # still counts in the denominator (len(dataset)).
                    item = future_to_item[future]
                    logging.error(
                        f"FATAL: A future raised an exception for Q_ID '{item.get('id')}': {e}")
                    traceback.print_exc()

    # All averages are taken over the full dataset size.
    em_accuracy = (exact_matches / total_questions) * \
        100 if total_questions > 0 else 0
    avg_f1 = (total_f1 / total_questions) if total_questions > 0 else 0
    avg_iterations = (total_iterations /
                      total_questions) if total_questions > 0 else 0

    result_summary = (
        f"\n--- Evaluation Complete for Dataset: '{dataset_name}', Strategy: '{strategy_name}' (Sample Run {sample_index}) ---\n"
        f"Model Used: {model_name_short}\n"
        f"Total Questions Evaluated: {total_questions}\n"
        f"Exact Matches (EM): {exact_matches}\n\n"
        f"Exact Match (EM) Accuracy: {em_accuracy:.2f}%\n"
        f"Average F1 Score: {avg_f1:.4f}\n"
        f"Average Iterations per Question: {avg_iterations:.2f}\n"
        f"---------------------------\n"
        f"Detailed results saved to: {jsonl_output_file}\n"
    )
    print(result_summary)

    summary_data = {
        "strategy": strategy_name,
        "model_name": model_name_short,
        "dataset_name": dataset_name,
        "record_time": datetime.now().strftime("%Y-%m-%d %H:%M"),
        "em_accuracy": f"{em_accuracy:.2f}",
        "avg_f1_score": f"{avg_f1:.4f}",
        "avg_iterations": f"{avg_iterations:.2f}",
        "total_questions": total_questions
    }
    append_summary_to_csv(summary_data, csv_filepath, csv_writer_lock)


def _build_arg_parser():
    """
    Create the argument parser holding every command-line option of the script.
    """
    cli = argparse.ArgumentParser(
        description="Run evaluation for Agentic RAG strategies on a predefined list of datasets.")

    cli.add_argument(
        "--strategy",
        type=str,
        default=DEFAULT_STRATEGY,
        help=f"The evaluation strategy to use. Available: {list(AGENT_STRATEGIES.keys())}",
    )
    cli.add_argument(
        '--no-reuse',
        dest='reuse',
        action='store_false',
        help="Start a fresh run, overwriting the output file instead of reusing existing results.",
    )
    cli.set_defaults(reuse=True)
    cli.add_argument(
        "--agent-base-url",
        type=str,
        default=DEFAULT_AGENT_BASE_URL,
    )
    cli.add_argument(
        "--agent-model-name",
        type=str,
        default=None,
        help="Override the default model for the selected strategy.",
    )
    cli.add_argument(
        "--num-samples",
        type=int,
        default=DEFAULT_SAMPLE_NUMS,
        help="Number of times to run the evaluation for each dataset.",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default=DEFAULT_OUTPUT_DIR,
        help="The base directory where results will be saved in subfolders named after datasets.",
    )
    cli.add_argument("--api-key", type=str, default=DEFAULT_API_KEY)
    cli.add_argument(
        "--dataset-dir",
        type=str,
        default=DEFAULT_DATASET_DIR,
        help="The directory containing .jsonl dataset files to evaluate.",
    )
    cli.add_argument("--max-workers", type=int, default=DEFAULT_MAX_WORKERS)
    return cli


def main():
    """
    Command-line entry point.

    Parses the options, locates the files named in DEFAULT_TEST_FILE_LIST
    inside the dataset directory, and calls run_evaluation() once per sample
    run for every dataset found. All runs append to a single
    "evaluation_summary.csv" in the output directory.
    """
    args = _build_arg_parser().parse_args()

    # The summary CSV lives directly in the output directory, so create it
    # up front; one lock is shared by every run that appends to the file.
    os.makedirs(args.output_dir, exist_ok=True)
    summary_csv_path = os.path.join(args.output_dir, "evaluation_summary.csv")
    summary_lock = threading.Lock()
    print(f"\nEvaluation summary will be appended to: {summary_csv_path}")

    if not os.path.isdir(args.dataset_dir):
        print(
            f"FATAL: The provided dataset path is not a directory: {args.dataset_dir}")
        return

    print(f"Looking for specified test files in directory: {args.dataset_dir}")
    dataset_files = []
    for filename in DEFAULT_TEST_FILE_LIST:
        candidate = os.path.join(args.dataset_dir, filename)
        if not os.path.isfile(candidate):
            print(
                f"  [WARNING] Specified test file '{filename}' not found. Skipping.")
            continue
        dataset_files.append(candidate)
        print(f"  [FOUND] {filename}")

    if len(dataset_files) == 0:
        print(
            f"FATAL: None of the specified test files from DEFAULT_TEST_FILE_LIST were found in the directory: {args.dataset_dir}")
        return

    print(
        f"Found {len(dataset_files)} dataset(s) to evaluate: {', '.join([os.path.basename(p) for p in dataset_files])}")

    for current_path in dataset_files:
        dataset_name = os.path.basename(current_path)
        print(f"\n{'='*20} PROCESSING DATASET: {dataset_name} {'='*20}")
        # Sample runs are numbered from 1.
        for i in range(1, args.num_samples + 1):
            print(
                f"\n--- Starting sample run {i}/{args.num_samples} for {dataset_name} ---")
            run_evaluation(
                args,
                sample_index=i,
                dataset_path=current_path,
                csv_filepath=summary_csv_path,
                csv_writer_lock=summary_lock,
            )

    print(f"\n{'='*20} ALL EVALUATIONS COMPLETE {'='*20}")


if __name__ == "__main__":
    main()
