#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Wrapper script to run PostgreSQL evaluation instances in parallel using multiple threads.
This prevents database connection issues or OOM issues from terminating the entire evaluation.
Includes synchronization mechanisms to prevent database template conflicts.
"""

import argparse
import json
import os
import sys
import subprocess
import tempfile
import time
import gc
import concurrent.futures
import threading
from datetime import datetime
from tqdm import tqdm
from postgresql_utils import (
    load_jsonl,
    save_report_and_status,
    generate_category_report,
    create_ephemeral_db_copies,
    drop_ephemeral_dbs,
)
from logger import configure_logger


# Per-template lock registry: maps "<base_db>_template" -> threading.Lock.
# Entries are created lazily by get_db_lock and never removed.
db_template_locks = {}

# Guards db_template_locks so two threads asking for the same template at the
# same time end up sharing one lock object rather than creating two.
template_locks_lock = threading.Lock()


def get_db_lock(db_name):
    """Return the lock shared by every database name derived from one template.

    A name such as ``"<base>_process_<n>"`` is reduced to ``"<base>"`` (text
    before the first ``"_process_"``; the whole name if absent), and all names
    with the same base map to the same ``threading.Lock``, keyed in
    ``db_template_locks`` as ``"<base>_template"``.

    Args:
        db_name: Database identifier string.

    Returns:
        The ``threading.Lock`` associated with the derived template key.
    """
    with template_locks_lock:
        base_name, _sep, _rest = db_name.partition("_process_")
        registry_key = base_name + "_template"

        template_lock = db_template_locks.get(registry_key)
        if template_lock is None:
            template_lock = threading.Lock()
            db_template_locks[registry_key] = template_lock
        return template_lock


def _remove_quietly(path):
    """Best-effort delete of a temporary file; ``None``/empty paths and OS errors are ignored."""
    if not path:
        return
    try:
        os.unlink(path)
    except OSError:
        pass


def run_instance(instance_data, instance_id, args, idx):
    """Run a single evaluation instance in a separate Python subprocess.

    The instance is serialized to a temporary JSONL file and passed to
    ``single_instance_eval_postgresql.py`` (located next to this script), which
    is asked to write its result to a temporary JSON file. The subprocess runs
    while holding the lock returned by ``get_db_lock`` for the instance's
    ``db_id``, so instances sharing a template database are never evaluated
    concurrently by this wrapper. Both temporary files are removed on every
    exit path.

    Args:
        instance_data: Merged ground-truth/prediction record. ``db_id`` and
            ``test_cases`` are read here; the whole record is forwarded.
        instance_id: Identifier stamped onto the returned result and used in
            the per-instance log file name.
        args: Parsed CLI namespace; ``pred_file``, ``mode`` and ``logging``
            are read.
        idx: Submission index, used only as a label in progress messages.

    Returns:
        The child's result dict (with ``instance_id`` overwritten) when the
        child exits with code 0 and writes a non-empty, parseable output file;
        otherwise a ``status == "failed"`` placeholder result.
    """
    db_name = instance_data.get("db_id", "")
    if not db_name:
        print(f"Warning: Instance {instance_id} has no db_id specified.")
        db_name = "unknown_db"

    timeout_seconds = 300
    tmp_input = None
    tmp_output = None
    try:
        # The child script reads its instance from a file, so serialize it.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".jsonl", delete=False
        ) as tmp:
            tmp_input = tmp.name
            json.dump(instance_data, tmp)

        # Reserve an output path for the child and close our handle right away
        # (the original left this file object open until garbage collection).
        with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp:
            tmp_output = tmp.name

        log_dir = os.path.dirname(os.path.abspath(args.pred_file))
        instance_log_file = os.path.join(log_dir, f"instance_{instance_id}.log")

        cmd = [
            # Use the interpreter running this wrapper rather than whatever
            # "python" happens to resolve to on PATH.
            sys.executable or "python",
            os.path.join(os.path.dirname(__file__), "single_instance_eval_postgresql.py"),
            "--jsonl_file",
            tmp_input,
            "--output_file",
            tmp_output,
            "--mode",
            args.mode,
            "--logging",
            args.logging,
            "--log_file",
            instance_log_file,
        ]

        db_lock = get_db_lock(db_name)

        print(f"[Thread {idx}] Running instance {instance_id} with database {db_name}...")

        with db_lock:
            print(
                f"[Thread {idx}] Acquired lock for database {db_name}, running process..."
            )
            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    check=False,
                    timeout=timeout_seconds,
                )
                success = result.returncode == 0
                if not success:
                    print(
                        f"[Thread {idx}] Instance {instance_id} process returned error code {result.returncode}"
                    )
                    print(f"[Thread {idx}] STDOUT: {result.stdout[:500]}...")
                    print(f"[Thread {idx}] STDERR: {result.stderr[:500]}...")
            except subprocess.TimeoutExpired:
                print(
                    f"[Thread {idx}] Instance {instance_id} timed out after {timeout_seconds} seconds"
                )
                success = False
            except OSError as e:
                # The child could not be launched at all (e.g. missing
                # interpreter or script); previously this escaped unhandled
                # and leaked both temp files.
                print(f"[Thread {idx}] Instance {instance_id} could not be started: {e}")
                success = False

            # Keep holding the lock for one more second after the child exits.
            # NOTE(review): presumably gives the database time to settle before
            # the next instance on the same template starts — confirm.
            time.sleep(1)

        print(f"[Thread {idx}] Released lock for database {db_name}, processing results...")

        if success and os.path.exists(tmp_output) and os.path.getsize(tmp_output) > 0:
            try:
                with open(tmp_output, "r") as f:
                    evaluation_result = json.load(f)
                evaluation_result["instance_id"] = instance_id
                return evaluation_result
            except Exception as e:
                # Unreadable or malformed output: fall through to the
                # placeholder failure result below.
                print(
                    f"[Thread {idx}] Error reading output for instance {instance_id}: {e}"
                )
    finally:
        # Runs on success, failure and unexpected exceptions alike.
        _remove_quietly(tmp_input)
        _remove_quietly(tmp_output)

    return {
        "instance_id": instance_id,
        "status": "failed",
        "error_message": "Failed to evaluate instance (process error)",
        "total_test_cases": len(instance_data.get("test_cases", [])),
        "passed_test_cases": 0,
        "failed_test_cases": [],
        "error_sql_error": 0,
        "error_phase_unexpected_pass": 0,
        "original_schema": None,
        "preprocess_schema": None,
        "evaluation_phase_execution_error": True,
        "evaluation_phase_timeout_error": False,
        "evaluation_phase_assertion_error": False,
        "execution_result": None,
    }


def _merge_predictions(gt_data_list, pred_data_list):
    """Overlay each ground-truth record with its prediction, matched by ``instance_id``.

    Prediction fields overwrite ground-truth fields of the same name. A
    ground-truth record without a matching prediction is kept as-is and a
    warning is printed to stderr. Output order follows ``gt_data_list``.
    """
    pred_map = {item["instance_id"]: item for item in pred_data_list}

    merged_list = []
    for gt_item in gt_data_list:
        inst_id = gt_item.get("instance_id")
        pred_item = pred_map.get(inst_id)

        if pred_item is None:
            print(
                f"Warning: no prediction found for instance_id '{inst_id}'",
                file=sys.stderr,
            )

        merged = {**gt_item}
        if pred_item:
            merged.update(pred_item)
        merged_list.append(merged)
    return merged_list


def _interleave_by_db(data_list):
    """Return ``(index, instance)`` pairs ordered round-robin across ``db_id`` groups.

    Round r takes the r-th instance of every group, with groups visited in
    first-appearance order; instances without ``db_id`` share the group
    ``"unknown"``. Consecutive submissions therefore tend to target different
    databases (presumably to reduce queuing on a single database lock).
    """
    group_rank = {}  # db_id -> order in which the group first appeared
    taken = {}  # db_id -> number of instances already assigned from that group
    keyed = []
    for i, data in enumerate(data_list):
        db_name = data.get("db_id", "unknown")
        rank = group_rank.setdefault(db_name, len(group_rank))
        depth = taken.get(db_name, 0)
        taken[db_name] = depth + 1
        keyed.append(((depth, rank), (i, data)))

    # (depth, rank) is unique per item, so this reproduces the round-robin
    # order in O(n log n) instead of repeated list.pop(0).
    keyed.sort(key=lambda pair: pair[0])
    return [item for _, item in keyed]


def _failed_result(instance_id, data, error_message):
    """Build a placeholder result for an instance the wrapper could not evaluate."""
    return {
        "instance_id": instance_id,
        "status": "failed",
        "error_message": error_message,
        "total_test_cases": len(data.get("test_cases", [])),
        "passed_test_cases": 0,
        "failed_test_cases": [],
        "error_sql_error": 0,
        "error_phase_unexpected_pass": 0,
        "evaluation_phase_execution_error": True,
        "evaluation_phase_timeout_error": False,
        "evaluation_phase_assertion_error": False,
        "original_schema": None,
        "preprocess_schema": None,
        "execution_result": None,
    }


def _write_status_jsonl(path, data_list, results):
    """Write one compact status record per instance to *path* as UTF-8 JSONL.

    ``results[i]`` is the result for ``data_list[i]``. A row is written only
    when its input ``instance_id`` matches the result's ``instance_id``; this
    skips inputs that had no ``instance_id`` key (their result carries a
    synthetic ``instance_<n>`` id).
    """
    # Explicit UTF-8: ensure_ascii=False can emit characters the platform's
    # default encoding cannot represent.
    with open(path, "w", encoding="utf-8") as f:
        for data, result in zip(data_list, results):
            if data.get("instance_id") != result.get("instance_id"):
                continue
            record = {
                "instance_id": result.get("instance_id"),
                "status": result.get("status"),
                "execution_result": result.get("execution_result"),
                "error_message": result.get("error_message"),
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def main():
    """Command-line entry point: evaluate all instances and write the reports.

    Ground-truth and prediction JSONL files are merged by ``instance_id``, and
    each merged instance is evaluated via ``run_instance`` on a thread pool.
    Next to the prediction file this writes a wrapper log
    (``<stem>_wrapper.log``), an overall report (``<stem>_report.txt``) and a
    per-instance status file (``<stem>_output_with_status.jsonl``), then prints
    a summary. Exits with status 1 when the ground-truth file yields no
    instances.
    """
    parser = argparse.ArgumentParser(
        description="Wrapper script to run PostgreSQL evaluation cases using multiple threads."
    )
    parser.add_argument(
        "--pred_file",
        required=True,
        help="Path to the pred JSONL file containing the dataset instances.",
    )
    parser.add_argument(
        "--gt_file",
        required=True,
        help="Path to the JSONL file containing the dataset instances.",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit the number of instances to process.",
    )
    parser.add_argument(
        "--skip",
        type=int,
        default=0,
        help="Skip the first N instances.",
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        default=16,
        help="Number of instances to process in parallel.",
    )
    parser.add_argument(
        "--logging",
        type=str,
        default="false",
        help="Enable or disable per-instance logging ('true' or 'false').",
    )
    parser.add_argument(
        "--mode",
        choices=["gold", "pred"],
        default="gold",
        help="Which field to use for solution SQL (gold or pred).",
    )
    parser.add_argument(
        "--report",
        type=str,
        default="true",
        help="If 'true', generates an additional difficulty-level performance report.",
    )

    args = parser.parse_args()

    data_list = _merge_predictions(
        load_jsonl(args.gt_file), load_jsonl(args.pred_file)
    )

    if not data_list:
        print("No data found in the JSONL file.")
        sys.exit(1)

    if args.skip > 0:
        data_list = data_list[args.skip :]
    if args.limit is not None:
        data_list = data_list[: args.limit]

    base_output_folder = os.path.splitext(args.pred_file)[0]
    log_filename = f"{base_output_folder}_wrapper.log"
    logger = configure_logger(log_filename)
    logger.info(
        "=== Starting PostgreSQL Evaluation via Wrapper Script (Multithreaded with DB locking) ==="
    )
    logger.info(
        f"Processing {len(data_list)} instances from {args.pred_file} using {args.num_threads} threads"
    )

    num_threads = max(1, min(args.num_threads, len(data_list)))

    # One id per input position: the real instance_id, or a synthetic
    # "instance_<position>" fallback. Computed once so the submission, error
    # handling and final collection all agree (the original recomputed it with
    # data_list.index(), which is O(n^2) and wrong for equal records).
    instance_ids = [
        data.get("instance_id", f"instance_{i}") for i, data in enumerate(data_list)
    ]

    # Results are keyed by input position so duplicate instance_ids do not
    # overwrite each other.
    results_by_idx = {}

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        future_to_idx = {
            executor.submit(
                run_instance,
                data,
                instance_ids[original_idx],
                args,
                submit_idx,
            ): original_idx
            for submit_idx, (original_idx, data) in enumerate(
                _interleave_by_db(data_list)
            )
        }

        for future in tqdm(
            concurrent.futures.as_completed(future_to_idx),
            desc="Evaluating instances",
            total=len(data_list),
        ):
            original_idx = future_to_idx[future]
            instance_id = instance_ids[original_idx]
            try:
                results_by_idx[original_idx] = future.result()
            except Exception as e:
                logger.error(f"Error processing instance {instance_id}: {e}")
                results_by_idx[original_idx] = _failed_result(
                    instance_id,
                    data_list[original_idx],
                    f"Error in wrapper: {str(e)}",
                )

            gc.collect()

    # Restore input order; fill any gap with a fully-populated failure record.
    results = []
    for idx, data in enumerate(data_list):
        if idx in results_by_idx:
            results.append(results_by_idx[idx])
        else:
            logger.error(f"Missing result for instance {instance_ids[idx]}")
            results.append(
                _failed_result(instance_ids[idx], data, "Result missing from processing")
            )

    number_of_execution_errors = sum(
        1 for r in results if r.get("evaluation_phase_execution_error", False)
    )
    number_of_timeouts = sum(
        1 for r in results if r.get("evaluation_phase_timeout_error", False)
    )
    number_of_assertion_errors = sum(
        1 for r in results if r.get("evaluation_phase_assertion_error", False)
    )
    number_error_unexpected_pass = sum(
        1 for r in results if r.get("error_phase_unexpected_pass", 0) == 1
    )
    total_passed_instances = sum(1 for r in results if r.get("status") == "success")

    total_instances = len(results)
    total_errors = (
        number_of_execution_errors + number_of_timeouts + number_of_assertion_errors
    )
    overall_accuracy = (
        ((total_instances - total_errors) / total_instances * 100)
        if total_instances > 0
        else 0.0
    )
    timestamp = datetime.now().isoformat(sep=" ", timespec="microseconds")

    report_file_path = f"{base_output_folder}_report.txt"
    save_report_and_status(
        report_file_path,
        results,
        data_list,
        number_of_execution_errors,
        number_of_timeouts,
        number_of_assertion_errors,
        overall_accuracy,
        timestamp,
        logger,
        number_error_unexpected_pass,
    )

    print("Overall report generated:", report_file_path)

    output_jsonl_file = f"{base_output_folder}_output_with_status.jsonl"
    _write_status_jsonl(output_jsonl_file, data_list, results)

    # Accept "true" in any letter case.
    if args.report.lower() == "true":
        model_name = (
            os.path.basename(args.pred_file)
            .replace(".jsonl", "")
            .replace("_final_output", "")
        )
        generate_category_report(
            results,
            data_list,
            report_file_path,
            logger,
            model_name=model_name,
            metric_name="Test Case",
        )
        print(f"Difficulty report generated: {report_file_path}")

    print("\nEvaluation Summary:")
    print(f"Total instances: {total_instances}")
    print(f"Passed instances: {total_passed_instances}")
    print(f"Failed instances: {total_instances - total_passed_instances}")
    print(f"Execution errors: {number_of_execution_errors}")
    print(f"Timeouts: {number_of_timeouts}")
    print(f"Assertion errors: {number_of_assertion_errors}")
    print(f"Error phase unexpected passes: {number_error_unexpected_pass}")
    print(f"Overall accuracy: {overall_accuracy:.2f}%")

    logger.info(
        "=== PostgreSQL Evaluation via Wrapper Script (Multithreaded with DB locking) Completed ==="
    )


if __name__ == "__main__":
    main()
