import argparse
import json
import logging
import os
import sys
import shutil
import time
import subprocess
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

                                                                      
# Put the parent of this file's directory on sys.path so the project-local
# `scripts` and `config` packages imported below can be resolved when the file
# is run directly (presumably that parent is the project root — confirm).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from scripts.utils.benchmark_utils import setup_logger
from config.constants import BASE_PROJECT_DIR

                       
                                           
# External scripts that main() drives as subprocesses.
NOISE_SCRIPT_PATH = BASE_PROJECT_DIR / "scripts" / "inject_label_noise.py"
BENCHMARK_SCRIPT_PATH = BASE_PROJECT_DIR / "scripts" / "run_fortress_benchmark.py"
# Vector DB directory that main() deletes and re-creates from the backup before
# every run (presumably the noise script then mutates it in place — confirm).
ACTIVE_DB_PATH = BASE_PROJECT_DIR / "data" / "07_vector_db" / "gemma3_1b_exp_noise_exper"
# Scratch copy of the clean DB; the timestamp is fixed once, at import time.
CLEAN_DB_BACKUP_PATH = Path(f"/tmp/fortress_clean_db_backup_{datetime.now().strftime('%Y%m%d%H%M%S')}")

# Benchmark datasets: short name -> CSV path passed to the benchmark script
# via --input-csvs. The name is also embedded in each run id.
DATASET_CONFIG = {
    "xstest": BASE_PROJECT_DIR / "data/05_stitched/xstest_english.csv",
}
                              

def run_and_parse_benchmark(
    dataset_path: Path,
    output_dir: Path,
    fortress_run_id_prefix: str,
    run_id_suffix: str,
    logger: logging.Logger
) -> Optional[Dict[str, float]]:
    """
    Runs the external benchmark script and parses its JSON output.

    The benchmark is launched as a subprocess from BASE_PROJECT_DIR. Its result
    file is then looked up under ``output_dir / "results_data"`` with the glob
    ``<fortress_run_id_prefix>_<run_id_suffix>*.json``; if several files match,
    the most recently modified one is used.

    Args:
        dataset_path: CSV file passed to the benchmark via ``--input-csvs``.
        output_dir: Directory passed to the benchmark via ``--output-dir``.
        fortress_run_id_prefix: Value for ``--run-name-prefix``; also the first
            part of the expected result file name.
        run_id_suffix: Value for ``--run-id``; also the second part of the
            expected result file name.
        logger: Logger used for progress and error reporting.

    Returns:
        A dictionary with the keys ``f1_unsafe``, ``accuracy``, ``fpr_unsafe``
        and ``fnr_unsafe`` (a value is None when the result file lacks it), or
        None when the benchmark fails, no result file is found, or the file
        cannot be read or does not contain a JSON object with usable metrics.
    """
    metric_keys = ("f1_unsafe", "accuracy", "fpr_unsafe", "fnr_unsafe")

    command = [
        sys.executable,
        str(BENCHMARK_SCRIPT_PATH),
        "--input-csvs", str(dataset_path),
        "--output-dir", str(output_dir),
        "--run-id", run_id_suffix,
        "--run-name-prefix", fortress_run_id_prefix,
        "--log-level", "WARNING",
    ]

    logger.info(f"Executing command: {' '.join(command)}")
    try:
        subprocess.run(command, check=True, capture_output=True, text=True, cwd=BASE_PROJECT_DIR)
    except subprocess.CalledProcessError as e:
        logger.error(f"Benchmark script for {dataset_path.name} failed with return code {e.returncode}.")
        logger.error(f"STDOUT: {e.stdout}")
        logger.error(f"STDERR: {e.stderr}")
        return None

    results_data_dir = output_dir / "results_data"
    glob_pattern = f"{fortress_run_id_prefix}_{run_id_suffix}*.json"

    if not results_data_dir.is_dir():
        logger.error(f"Results directory '{results_data_dir}' was not created by the benchmark script.")
        return None

    result_files = list(results_data_dir.glob(glob_pattern))
    if not result_files:
        logger.error(f"Expected output file not found in '{results_data_dir}' with pattern '{glob_pattern}'")
        return None

    # The glob may match more than one file; use the newest one.
    latest_file = max(result_files, key=lambda p: p.stat().st_mtime)
    logger.info(f"Parsing results from '{latest_file.name}'")

    # OSError covers an unreadable file; ValueError covers both
    # json.JSONDecodeError and UnicodeDecodeError. Either one must yield None
    # rather than propagate and abort the caller's experiment loop.
    try:
        with open(latest_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        logger.error(f"Failed to parse or find keys in output file {latest_file}: {e}")
        return None

    # Valid JSON is not necessarily an object: guard the top level and the
    # "metrics" entry before calling .get on them.
    if not isinstance(data, dict):
        logger.error(
            f"Failed to parse or find keys in output file {latest_file}: "
            f"top-level JSON value is {type(data).__name__}, expected an object"
        )
        return None

    metrics = data.get("metrics", {})
    if not isinstance(metrics, dict):
        logger.error(
            f"Failed to parse or find keys in output file {latest_file}: "
            f"'metrics' is {type(metrics).__name__}, expected an object"
        )
        return None

    return {key: metrics.get(key) for key in metric_keys}


def generate_plots(results_df: pd.DataFrame, output_dir: Path, num_runs: int, logger: logging.Logger):
    """Draw the F1-vs-noise line chart and save it as a PDF in ``output_dir``.

    Args:
        results_df: Experiment results; the columns ``noise_level``,
            ``f1_unsafe`` and ``dataset_name`` are read. Nothing is drawn when
            the frame is empty.
        output_dir: Directory that receives ``noise_robustness_figure.pdf``.
        num_runs: Accepted for the caller's convenience; not used here.
        logger: Logger used for progress messages.
    """
    if results_df.empty:
        logger.warning("Results DataFrame is empty. Skipping plot generation.")
        return

    logger.info("Generating final plots...")
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(8, 6))

    # One line per dataset; the shaded band is the standard deviation across
    # the rows that share a noise level.
    line_options = {
        "x": 'noise_level',
        "y": 'f1_unsafe',
        "hue": 'dataset_name',
        "marker": 'o',
        "errorbar": 'sd',
        "err_style": "band",
    }
    sns.lineplot(data=results_df, ax=ax, **line_options)

    ax.set(
        xlabel="Label Noise Level",
        ylabel="F1 Score (Unsafe Class)",
        title="System Performance vs. Label Noise",
    )
    ax.legend(title="Benchmark Dataset", fontsize='small')
    ax.grid(True)

    # Leave a 0.1 margin under the lowest score, but never go below zero.
    lowest_f1 = results_df['f1_unsafe'].min()
    y_floor = max(0, lowest_f1 - 0.1)
    ax.set_ylim(bottom=y_floor, top=1.05)
    ax.set_xlim(left=0)

    plt.tight_layout()
    figure_path = output_dir / "noise_robustness_figure.pdf"
    plt.savefig(figure_path, format='pdf', bbox_inches='tight')
    plt.close(fig)
    logger.info(f"Saved robustness plot to {figure_path}")

def main():
    """Run the noise robustness experiment end to end.

    For every noise level and every run number the active vector DB is reset
    from a clean backup, the noise-injection script is executed (seeded with
    the run number), and every dataset in DATASET_CONFIG is benchmarked. Results
    are written to ``noise_robustness_live.csv`` after each run and to
    ``noise_robustness_summary.csv`` plus a plot at the end.

    On the way out, successful or not, the active DB is restored from the
    clean backup before the backup is deleted. If that restore fails, the
    backup is kept so the clean data is not lost.

    Exits with status 1 when the clean DB is missing, when the experiment
    aborts with an exception, or when the final restore fails.
    """
    parser = argparse.ArgumentParser(description="Run FORTRESS Noise Robustness Experiment.")
    parser.add_argument(
        "--noise-levels", type=float, nargs='+',
        default=[0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50],
        help="Space-separated list of noise levels to test."
    )
    parser.add_argument("--num-runs", type=int, default=5, help="Number of runs per noise level.")
    parser.add_argument(
        "--output-dir", type=str,
        default=f"benchmarks/noise_experiment_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        help="Directory to save plots and results."
    )
    parser.add_argument(
        "--fortress-run-id", type=str, default="fortress_gemma_1b_expanded",
        help="The name of the fortress config to use as a foundation for run IDs."
    )
    parser.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])

    args = parser.parse_args()
    if args.num_runs < 1:
        # range(1, num_runs + 1) would be empty and the experiment a no-op.
        parser.error("--num-runs must be at least 1")

    output_dir = BASE_PROJECT_DIR / args.output_dir
    raw_json_dir = output_dir / "raw_json"
    raw_json_dir.mkdir(parents=True, exist_ok=True)
    logger = setup_logger(__name__, level_str=args.log_level)

    logger.info("--- FORTRESS Noise Robustness Experiment ---")
    logger.info(f"Arguments: {vars(args)}")

    if not ACTIVE_DB_PATH.exists():
        logger.error(f"Clean database not found at '{ACTIVE_DB_PATH}'. Please ensure it exists before running.")
        sys.exit(1)

    logger.info(f"Backing up clean database from '{ACTIVE_DB_PATH}' to '{CLEAN_DB_BACKUP_PATH}'...")
    if CLEAN_DB_BACKUP_PATH.exists():
        shutil.rmtree(CLEAN_DB_BACKUP_PATH)
    shutil.copytree(ACTIVE_DB_PATH, CLEAN_DB_BACKUP_PATH)
    logger.info("Backup complete.")

    metric_names = ("f1_unsafe", "accuracy", "fpr_unsafe", "fnr_unsafe")
    all_experiment_results = []
    failed = False
    try:
        for noise_level in args.noise_levels:
            for run_num in range(1, args.num_runs + 1):
                logger.info(f"--- Starting Run: Noise Level={noise_level}, Run={run_num}/{args.num_runs} ---")

                logger.info("[Step 1/3] Restoring clean database...")
                if ACTIVE_DB_PATH.exists():
                    shutil.rmtree(ACTIVE_DB_PATH)
                shutil.copytree(CLEAN_DB_BACKUP_PATH, ACTIVE_DB_PATH)

                logger.info(f"[Step 2/3] Injecting {noise_level:.2%} noise into labels...")
                noise_command = [
                    sys.executable, str(NOISE_SCRIPT_PATH),
                    "--noise-level", str(noise_level),
                    "--seed", str(run_num)
                ]
                try:
                    # Same working directory as the benchmark subprocess, so
                    # both scripts resolve relative paths identically.
                    subprocess.run(
                        noise_command, check=True, capture_output=True, text=True,
                        cwd=BASE_PROJECT_DIR
                    )
                except subprocess.CalledProcessError as e:
                    # The captured output is the only diagnostic available;
                    # log it before aborting the experiment.
                    logger.error(f"Noise injection script failed with return code {e.returncode}.")
                    logger.error(f"STDOUT: {e.stdout}")
                    logger.error(f"STDERR: {e.stderr}")
                    raise

                logger.info("[Step 3/3] Benchmarking all datasets against the noisy database...")
                for dataset_name, dataset_path in DATASET_CONFIG.items():
                    run_id_suffix = f"{dataset_name}_noise_{noise_level}_run_{run_num}"

                    metrics = run_and_parse_benchmark(
                        dataset_path, raw_json_dir, args.fortress_run_id, run_id_suffix, logger
                    )

                    result_row = {
                        "dataset_name": dataset_name,
                        "noise_level": noise_level,
                        "run_number": run_num,
                    }
                    if metrics and all(v is not None for v in metrics.values()):
                        result_row.update(metrics)
                        logger.info(f"Result for {dataset_name}: F1={metrics['f1_unsafe']:.4f}, Acc={metrics['accuracy']:.4f}")
                    else:
                        # Keep a NaN row so the failed run stays visible in the CSV.
                        logger.error(f"Failed to get benchmark results for {dataset_name}.")
                        result_row.update(dict.fromkeys(metric_names, np.nan))
                    all_experiment_results.append(result_row)

                # Checkpoint after every run so partial results survive a crash.
                pd.DataFrame(all_experiment_results).to_csv(output_dir / "noise_robustness_live.csv", index=False)

        logger.info("Experiment loops complete. Finalizing results...")
        results_df = pd.DataFrame(all_experiment_results)
        final_csv_path = output_dir / "noise_robustness_summary.csv"
        results_df.to_csv(final_csv_path, index=False)
        logger.info(f"Final results saved to {final_csv_path}")

        generate_plots(results_df, output_dir, args.num_runs, logger)

    except Exception as e:
        failed = True
        logger.critical(f"An unhandled exception occurred: {e}", exc_info=True)
    finally:
        # The active DB is still in its last noise-injected (or half-restored)
        # state here. Put the clean copy back BEFORE deleting the backup,
        # otherwise the clean data is gone for good.
        if CLEAN_DB_BACKUP_PATH.exists():
            logger.info(f"Restoring clean database to '{ACTIVE_DB_PATH}'...")
            try:
                if ACTIVE_DB_PATH.exists():
                    shutil.rmtree(ACTIVE_DB_PATH)
                shutil.copytree(CLEAN_DB_BACKUP_PATH, ACTIVE_DB_PATH)
            except OSError:
                failed = True
                logger.critical(
                    f"Could not restore the clean database; backup kept at '{CLEAN_DB_BACKUP_PATH}'.",
                    exc_info=True
                )
            else:
                logger.info(f"Cleaning up backup database at '{CLEAN_DB_BACKUP_PATH}'...")
                shutil.rmtree(CLEAN_DB_BACKUP_PATH)
                logger.info("Backup removed successfully.")

    logger.info("--- Experiment Finished ---")
    if failed:
        # Signal the failure to shells and schedulers instead of exiting 0.
        sys.exit(1)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
