import os
import sys
import json
import shutil
import subprocess
from pathlib import Path
from typing import List, Dict, Tuple, Any, Set, Optional
from collections import defaultdict
import re
import urllib.request
import urllib.error
import time
import csv
import math

from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from rich.table import Table
from rich.rule import Rule
from rich.prompt import Prompt

from prompt_toolkit.shortcuts import checkboxlist_dialog
from prompt_toolkit.styles import Style
from prompt_toolkit.application import Application
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.layout import Layout
from prompt_toolkit.layout.containers import HSplit, VSplit, Window
from prompt_toolkit.layout.controls import FormattedTextControl
from prompt_toolkit.widgets import Frame
from prompt_toolkit.formatted_text import FormattedText
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from config.constants import (
    BENCHMARK_CSVS,
    BASE_PROJECT_DIR,
    BASE_OUTPUT_DIR,
    REPORTS_DIR,
    RESULTS_DATA_DIR,
    DISCORD_WEBHOOK_URL,
    FORTRESS_CONFIGS,
    FORTRESS_SETTINGS_PATH,
    FORTRESS_EXPERIMENT_CATALOGUE,
    FORTRESS_DEFAULT_CONFIG,
)




# Interpreter used to launch every benchmark script: the same one that is
# running this manager.
PYTHON_EXE = sys.executable

# Maps a model's display name to a shell command template.
# The doubled braces survive the f-string as literal {script_dir},
# {input_csv} and {output_dir} placeholders, which run_command() fills in
# later with str.format(). Insertion order is the order shown in the menu.
MODELS_TO_RUN: Dict[str, str] = {}

# One entry per FORTRESS configuration. Only the run id is baked into the
# command here; config_file is not used at this point.
for fortress_name, (run_id, config_file) in FORTRESS_CONFIGS.items():
    MODELS_TO_RUN[fortress_name] = f'{PYTHON_EXE} {{script_dir}}/run_fortress_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --run-id "{run_id}"'

# Baseline guard models, each driven by its own benchmark script.
MODELS_TO_RUN.update({
    "WildGuard (7B)": f'{PYTHON_EXE} {{script_dir}}/run_wildguard_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}"',
    "OpenAI Moderation": f'{PYTHON_EXE} {{script_dir}}/run_openai_moderation_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}"',
    "Llama Guard 3-1B": f'{PYTHON_EXE} {{script_dir}}/run_llama_guard_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --model-id "meta-llama/Llama-Guard-3-1B"',
    "Llama Guard 3-8B": f'{PYTHON_EXE} {{script_dir}}/run_llama_guard_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --model-id "meta-llama/Llama-Guard-3-8B"',
    "ShieldGemma 2B": f'{PYTHON_EXE} {{script_dir}}/run_shieldgemma_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --model-id "google/shieldgemma-2b"',
    "ShieldGemma 2-4B": f'{PYTHON_EXE} {{script_dir}}/run_shieldgemma2_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}"',
    "ShieldGemma 9B": f'{PYTHON_EXE} {{script_dir}}/run_shieldgemma_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --model-id "google/shieldgemma-9b"',
    "AegisGuard Defensive": f'{PYTHON_EXE} {{script_dir}}/run_aegis_guard_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --adapter-id "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"',
    "AegisGuard Permissive": f'{PYTHON_EXE} {{script_dir}}/run_aegis_guard_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --adapter-id "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Permissive-1.0"',
    "GuardReasoner 1B": f'{PYTHON_EXE} {{script_dir}}/run_guardreasoner_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --model-size "1B"',
    "GuardReasoner 3B": f'{PYTHON_EXE} {{script_dir}}/run_guardreasoner_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --model-size "3B"',
    "GuardReasoner 8B": f'{PYTHON_EXE} {{script_dir}}/run_guardreasoner_benchmark.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}" --model-size "8B"',
    "Ayub XGboost Classifier": f'{PYTHON_EXE} {{script_dir}}/run_ayub_oai_xgb_guard_classifier.py --input-csvs "{{input_csv}}" --output-dir "{{output_dir}}"',

})

                                         
# Maps a model's display name to the lowercase substring that identifies its
# result files. get_model_name_from_filename() returns the first entry whose
# pattern occurs in the filename, so the specific patterns must stay ahead of
# the generic catch-alls ("fortress_", "llamaguard", "shieldgemma").
MODEL_NAME_PATTERNS = {
    "AegisGuard Defensive": "aegisguard_aegis_ai_content_safety_llamaguard_defensive",
    "AegisGuard Permissive": "aegisguard_aegis_ai_content_safety_llamaguard_permissive",
    "Llama Guard 3-1B": "llamaguard_llama_guard_3_1b",
    "Llama Guard 3-8B": "llamaguard_llama_guard_3_8b",
    "ShieldGemma 2B": "shieldgemma_shieldgemma_2b",
    "ShieldGemma 9B": "shieldgemma_shieldgemma_9b",
    # Spelled to match the "ShieldGemma 2-4B" entry in MODELS_TO_RUN so the
    # results view and the runner use one name for the same model.
    "ShieldGemma 2-4B": "shieldgemma2_shieldgemma_2_4b",
    "WildGuard (7B)": "wildguard_wildguard",
    "FORTRESS": "fortress_",
    "OpenAI Moderation": "openai_mod_",
    "Llama Guard (Generic)": "llamaguard",
    "ShieldGemma (Generic)": "shieldgemma",
    "Ayub XGboost Classifier": "ayub_oai_xgb_guard",
}

# Shared rich console used for all terminal output in this module.
console = Console()

# prompt_toolkit style sheet passed to the checkbox dialog in select_items():
# dark grey dialog chrome with green accents for selected/checked entries
# and buttons.
dialog_style = Style.from_dict({
    'dialog':             'bg:#303030',
    'dialog frame.label': 'bg:#505050 #ffffff bold',
    'dialog.body':        'bg:#303030 #f0f0f0',
    'dialog shadow':      'bg:#202020',
    'checkbox':           '#d0d0d0',
    'checkbox-selected':  'bg:#005f00 #ffffff bold',
    'checkbox-checked':   '#32cd32',
    'button':             'bg:#003f00 #ffffff',
    'button.focused':     'bg:#007f00 #ffffff bold',
})

                                  

def send_discord_notification(model_name: str, benchmark_name: str, success: bool):
    """Send a best-effort Discord webhook notification for a finished benchmark run.

    The message is first posted as an embed. If Discord answers HTTP 400, the
    same information is retried once as a plain ``content`` message. No
    exception escapes this function: every failure is reported on the console
    as a warning, so a notification problem cannot interrupt a benchmark batch.

    Args:
        model_name: Display name of the model that was run.
        benchmark_name: Display name of the benchmark it was run on.
        success: Whether the run succeeded; selects the emoji, the status
            text and the embed colour.
    """
    # Built outside the try block because the HTTP 400 fallback in the
    # exception handler below needs them as well.
    status_emoji = "✅" if success else "❌"
    status_text = "completed successfully" if success else "failed"
    headers = {
        'Content-Type': 'application/json',
        'User-Agent': 'Benchmark-Suite-Manager/1.0'
    }

    # Nothing to post to: say so once instead of failing inside urllib.
    if not DISCORD_WEBHOOK_URL:
        console.print("[dim]Discord webhook URL is not configured; skipping notification.[/dim]")
        return

    try:
        message = {
            "embeds": [{
                "title": f"{status_emoji} Benchmark Completed",
                "description": f"**Model:** {model_name}\n**Benchmark:** {benchmark_name}\n**Status:** {status_text}",
                "color": 0x00ff00 if success else 0xff0000,
            }]
        }
        data = json.dumps(message).encode('utf-8')
        req = urllib.request.Request(DISCORD_WEBHOOK_URL, data=data, headers=headers, method='POST')

        with urllib.request.urlopen(req, timeout=10) as response:
            response_code = response.getcode()

            # 204 (No Content) is the response this code treats as success.
            if response_code == 204:
                console.print(f"[dim]✅ Discord notification sent for {model_name} on {benchmark_name}[/dim]")
            else:
                console.print(f"[yellow]⚠️ Unexpected response code: {response_code}[/yellow]")

    except urllib.error.HTTPError as e:
        # HTTPError must be handled before URLError, which is its base class.
        try:
            error_body = e.read().decode('utf-8')
        except Exception:
            # The body is only used as extra diagnostics; carry on without it.
            error_body = ""

        console.print(f"[yellow]Warning: Failed to send Discord notification: HTTP {e.code} - {e.reason}[/yellow]")

        if e.code == 403:
            console.print("[dim]Hint: Webhook permissions issue or invalid URL[/dim]")
        elif e.code == 429:
            console.print("[dim]Hint: Rate limited - Discord is throttling requests[/dim]")
        elif e.code == 400:
            console.print("[dim]Hint: Bad request - message format might be invalid[/dim]")
            if error_body:
                console.print(f"[dim]Error details: {error_body}[/dim]")

            # The embed was rejected: retry once with a plain-text message.
            console.print("[dim]Trying fallback with simple message format...[/dim]")
            try:
                simple_message = {
                    "content": f"{status_emoji} **Benchmark Completed**\n**Model:** {model_name}\n**Benchmark:** {benchmark_name}\n**Status:** {status_text}"
                }
                fallback_data = json.dumps(simple_message).encode('utf-8')
                fallback_req = urllib.request.Request(DISCORD_WEBHOOK_URL, data=fallback_data, headers=headers, method='POST')

                with urllib.request.urlopen(fallback_req, timeout=10) as fallback_response:
                    if fallback_response.getcode() == 204:
                        console.print("[dim]✅ Fallback notification sent successfully[/dim]")

            except Exception as fallback_error:
                console.print(f"[dim]Fallback also failed: {fallback_error}[/dim]")

    except urllib.error.URLError as e:
        console.print(f"[yellow]Warning: Network error sending Discord notification: {e.reason}[/yellow]")
        console.print("[dim]Hint: Check your internet connection[/dim]")

    except TimeoutError:
        console.print("[yellow]Warning: Discord notification timed out[/yellow]")
        console.print("[dim]Hint: Discord may be experiencing issues[/dim]")

    except Exception as e:
        # Last-resort boundary: a notification must never abort the caller.
        console.print(f"[yellow]Warning: Unexpected error sending Discord notification: {e}[/yellow]")
        console.print(f"[dim]Error type: {type(e).__name__}[/dim]")


                                        

def backup_settings_yaml():
    """Copy the active settings file to a sibling backup file.

    The backup path is FORTRESS_SETTINGS_PATH with its final suffix replaced
    by ``.yaml.backup``.

    Returns:
        The backup path, or None when there is no settings file to back up.
    """
    if not FORTRESS_SETTINGS_PATH.exists():
        return None

    destination = FORTRESS_SETTINGS_PATH.with_suffix('.yaml.backup')
    shutil.copy2(FORTRESS_SETTINGS_PATH, destination)
    console.print(f"[dim]Backed up current settings.yaml to {destination}[/dim]")
    return destination

def restore_settings_yaml(backup_path: Optional[Path] = None):
    """Put the settings file back after a temporary config swap.

    Sources are tried in order: the given backup (which is deleted once it
    has been copied back), then FORTRESS_DEFAULT_CONFIG. If neither exists
    only a warning is printed.

    Args:
        backup_path: Backup file to restore from, if one was created.
    """
    if backup_path and backup_path.exists():
        shutil.copy2(backup_path, FORTRESS_SETTINGS_PATH)
        console.print(f"[dim]Restored settings.yaml from backup[/dim]")
        # The backup has served its purpose; remove it.
        backup_path.unlink()
        return

    if FORTRESS_DEFAULT_CONFIG.exists():
        shutil.copy2(FORTRESS_DEFAULT_CONFIG, FORTRESS_SETTINGS_PATH)
        console.print(f"[dim]Restored settings.yaml from default config[/dim]")
        return

    console.print(f"[yellow]Warning: Could not restore settings.yaml (no backup or default found)[/yellow]")

def swap_fortress_config(config_file_name: str) -> bool:
    """Overwrite the settings file with a config from the experiment catalogue.

    Args:
        config_file_name: File name looked up inside FORTRESS_EXPERIMENT_CATALOGUE.

    Returns:
        True when the file was copied over FORTRESS_SETTINGS_PATH, False when
        the config is missing or the copy failed (an error is printed).
    """
    source = FORTRESS_EXPERIMENT_CATALOGUE / config_file_name
    if not source.exists():
        console.print(f"[red]Error: Config file not found: {source}[/red]")
        return False

    swapped = False
    try:
        shutil.copy2(source, FORTRESS_SETTINGS_PATH)
        console.print(f"[dim]Swapped settings.yaml with {config_file_name}[/dim]")
        swapped = True
    except Exception as err:
        console.print(f"[red]Error swapping config file: {err}[/red]")
    return swapped

                                            

class BenchmarkGrid:
    """State and rendering for the model x benchmark selection matrix.

    Rows are models and columns are benchmarks, both supplied as
    (name, value) tuples. The grid tracks a cursor position and the set of
    ticked (row, column) cells; every cell starts out ticked.
    """

    def __init__(self, models: List[Tuple[str, Any]], benchmarks: List[Tuple[str, Any]]):
        self.models = models
        self.benchmarks = benchmarks
        self.current_row = 0
        self.current_col = 0
        # Start from the full cross product so the user only has to untick.
        self.selected_cells: Set[Tuple[int, int]] = {
            (row, col)
            for row in range(len(models))
            for col in range(len(benchmarks))
        }

    def toggle_current(self):
        """Flip the ticked state of the cell under the cursor."""
        cursor = (self.current_row, self.current_col)
        self.selected_cells.symmetric_difference_update({cursor})

    def get_selected_combinations(self) -> List[Tuple[Tuple[str, Any], Tuple[str, Any]]]:
        """Return the ticked (model, benchmark) pairs, ordered by row then column."""
        return [
            (self.models[row], self.benchmarks[col])
            for row, col in sorted(self.selected_cells)
        ]

    def get_formatted_text(self) -> FormattedText:
        """Render the grid as prompt_toolkit (style, text) fragments."""
        name_width = 25
        column_width = 20
        gap = '   '

        # Header line: blank space above the model names, then centred benchmark titles.
        fragments = [('', ' ' * (name_width + 3))]
        for label, _ in self.benchmarks:
            fragments.extend([
                ('class:header', format(label[:column_width - 2], f"^{column_width}")),
                ('', gap),
            ])
        fragments.append(('', '\n\n'))

        # Each cell is 7 characters wide; this pads it to the column width plus the gap.
        cell_padding = ' ' * (column_width - 7) + gap

        for row, (label, _) in enumerate(self.models):
            on_cursor_row = row == self.current_row
            row_style = 'class:highlight' if on_cursor_row else 'class:model'
            fragments.append((row_style, format(label[:name_width - 1], f"<{name_width}")))
            fragments.append(('', gap))

            for col in range(len(self.benchmarks)):
                ticked = (row, col) in self.selected_cells
                mark = '✓' if ticked else ' '
                if on_cursor_row and col == self.current_col:
                    # The cursor cell is drawn in brackets.
                    fragments.append(('class:current', f"  [{mark}]  "))
                else:
                    fragments.append(('class:checked' if ticked else '', f"   {mark}   "))
                fragments.append(('', cell_padding))

            fragments.append(('', '\n\n'))

        # Footer with the key bindings.
        fragments.append(('', '\n'))
        fragments.append(('class:instruction', '↑↓←→ Navigate    SPACE Toggle    ENTER Confirm    ESC Cancel'))
        return FormattedText(fragments)

def show_grid_selection(models: List[Tuple[str, Any]], benchmarks: List[Tuple[str, Any]]) -> List[Tuple[Tuple[str, Any], Tuple[str, Any]]]:
    """Run the interactive selection grid and return what the user confirmed.

    Arrow keys move the cursor, SPACE toggles the cell under it, ENTER
    confirms and ESC cancels.

    Returns:
        The ticked (model, benchmark) pairs on ENTER, or None on ESC.
    """
    grid = BenchmarkGrid(models, benchmarks)
    bindings = KeyBindings()

    def _step(d_row: int, d_col: int) -> None:
        # Move the cursor by the given offset, ignoring moves that leave the grid.
        row = grid.current_row + d_row
        col = grid.current_col + d_col
        if 0 <= row < len(grid.models):
            grid.current_row = row
        if 0 <= col < len(grid.benchmarks):
            grid.current_col = col

    @bindings.add('up')
    def _on_up(event):
        _step(-1, 0)

    @bindings.add('down')
    def _on_down(event):
        _step(1, 0)

    @bindings.add('left')
    def _on_left(event):
        _step(0, -1)

    @bindings.add('right')
    def _on_right(event):
        _step(0, 1)

    @bindings.add(' ')
    def _on_space(event):
        grid.toggle_current()

    @bindings.add('enter')
    def _on_enter(event):
        event.app.exit(result=grid.get_selected_combinations())

    @bindings.add('escape')
    def _on_escape(event):
        event.app.exit(result=None)

    # Style classes referenced by BenchmarkGrid.get_formatted_text().
    palette = Style.from_dict({
        'current': 'bg:#555555 #ffffff',
        'checked': '#00ff00',
        'header': '#ffff00 bold',
        'model': '#00ffff',
        'highlight': 'bg:#333333 #00ffff',
        'instruction': '#888888',
    })

    # The control re-renders the grid from its current state whenever it is drawn.
    grid_view = Window(
        FormattedTextControl(
            grid.get_formatted_text,
            show_cursor=False,
            focusable=True,
        ),
        style='bg:#1a1a1a #ffffff',
    )
    framed = Frame(body=grid_view, title="Select Model-Benchmark Combinations")

    return Application(
        layout=Layout(framed),
        key_bindings=bindings,
        style=palette,
        full_screen=False,
    ).run()

def select_items(options_map: Dict[str, Any], item_type_name: str, selection_prompt_message: str) -> List[Tuple[str, Any]]:
    """
    Let the user tick one or more entries of options_map in a checkbox dialog.

    Returns the chosen entries as (name, value) tuples, or an empty list when
    there is nothing to choose from, nothing was ticked, or the dialog was
    cancelled.
    """
    if not options_map:
        console.print(f"[yellow]No {item_type_name} configured.[/yellow]")
        return []

    console.print(Panel(f"Select {item_type_name}", title=selection_prompt_message, style="bold blue", expand=False))

    # Each checkbox uses the option name as both its value and its label.
    chosen_names = checkboxlist_dialog(
        title=f"Choose {item_type_name}",
        text="Use ARROW KEYS to navigate, SPACE to toggle selection. Press ENTER to confirm, ESC to cancel.",
        values=[(name, name) for name in options_map],
        style=dialog_style,
    ).run()

    if not chosen_names:
        console.print("[yellow]No items selected or selection cancelled.[/yellow]")
        return []

    console.print("[bold green]You have selected:[/bold green]")
    picked = []
    for name in chosen_names:
        value = options_map[name]
        console.print(f"- {name}")
        picked.append((name, value))
    return picked

def run_command(command_template: str, input_csv_path: str, csv_lookup_name: str, model_name: str):
    """
    Fill in a benchmark command template and execute it through the shell.

    The command runs with BASE_PROJECT_DIR as the working directory; the
    previous working directory is restored afterwards. For models listed in
    FORTRESS_CONFIGS the settings file is backed up and swapped before the run
    and restored after it. A Discord notification is sent once the run has
    finished, whether it succeeded or not. If the config swap fails the run is
    skipped and no notification is sent.
    """
    scripts_root = BASE_PROJECT_DIR / "scripts"
    command = command_template.format(
        script_dir=str(scripts_root),
        input_csv=input_csv_path,
        output_dir=str(BASE_OUTPUT_DIR),
    )

    console.print(f"[cyan]Running model [bold]{model_name}[/bold] on benchmark [bold]{csv_lookup_name}[/bold]...[/cyan]")
    console.print(f"[dim]$ {command}[/dim]")
    console.print(Rule())

    uses_fortress_config = model_name in FORTRESS_CONFIGS
    backup_path = None
    if uses_fortress_config:
        _run_id, fortress_config_file = FORTRESS_CONFIGS[model_name]
        backup_path = backup_settings_yaml()
        if not swap_fortress_config(fortress_config_file):
            console.print(f"[red]Failed to swap config for {model_name}. Skipping.[/red]")
            if backup_path:
                restore_settings_yaml(backup_path)
            return

    launch_dir = os.getcwd()
    success = False
    try:
        os.chdir(str(BASE_PROJECT_DIR))
        return_code = os.system(command)

        if return_code != 0:
            console.print(f"\n[bold red]Error running model [bold]{model_name}[/bold] for [bold]{csv_lookup_name}[/bold]. Return code: {return_code}[/bold red]")
        else:
            console.print(f"\n[green]Successfully completed model [bold]{model_name}[/bold] for [bold]{csv_lookup_name}[/bold].[/green]")
            success = True

    except Exception as e:
        console.print(f"[bold red]An unexpected error occurred while running command for {model_name} on {csv_lookup_name}: {e}[/bold red]")
        console.print(f"[dim]Command: {command}[/dim]")
    finally:
        os.chdir(launch_dir)
        # Undo the config swap before reporting the outcome.
        if uses_fortress_config:
            restore_settings_yaml(backup_path)

        send_discord_notification(model_name, csv_lookup_name, success)

def organize_outputs(csv_lookup_name: str):
    """
    Move benchmark output files from the root of REPORTS_DIR and
    RESULTS_DATA_DIR into a subdirectory named after csv_lookup_name.

    Only regular files directly inside each directory are moved; existing
    subdirectories are left alone. A directory that does not exist is skipped
    with a warning, and a file that cannot be moved is reported without
    stopping the remaining moves.
    """
    console.print(f"[cyan]Organizing output files for benchmark [bold]{csv_lookup_name}[/bold]...[/cyan]")

    for source_parent_dir in (REPORTS_DIR, RESULTS_DATA_DIR):
        # Check before creating the target: mkdir(parents=True) would create
        # the missing parent as well and this branch could never be reached.
        if not source_parent_dir.exists():
            console.print(f"[yellow]Source directory {source_parent_dir} does not exist. Skipping organization for it.[/yellow]")
            continue

        target_subdir = source_parent_dir / csv_lookup_name
        target_subdir.mkdir(parents=True, exist_ok=True)

        # Take a snapshot of the files first so the directory is not being
        # modified while it is listed.
        files_to_move = [item for item in source_parent_dir.iterdir() if item.is_file()]

        moved_count = 0
        for item in files_to_move:
            destination_path = target_subdir / item.name
            try:
                shutil.move(str(item), str(destination_path))
            except OSError as e:
                console.print(f"[red]Error moving file {item.name} to {target_subdir}: {e}[/red]")
                continue
            console.print(f"[dim]Moved: [green]{item.name}[/green] to {destination_path}[/dim]")
            moved_count += 1

        if moved_count == 0:
            console.print(f"[dim]No files found to move in the root of {source_parent_dir} for {csv_lookup_name}.[/dim]")
        else:
            console.print(f"[green]Organized {moved_count} file(s) from the root of {source_parent_dir} into {target_subdir}[/green]")
def run_benchmarks():
    """Interactive flow: pick CSVs and models, refine the pairs, then run them.

    Runs are grouped per benchmark so that output files can be organized into
    the benchmark's subdirectory after all of its models have run. Returns
    early, with a message, whenever a selection step yields nothing.
    """
    console.print(Panel("[bold magenta]Interactive Benchmark Runner[/bold magenta]", expand=False, border_style="magenta"))
    console.print(Rule())

    # Step 1: benchmark CSVs.
    chosen_csvs = select_items(BENCHMARK_CSVS, "Benchmark CSVs", "Select benchmark CSV(s) to run")
    if not chosen_csvs:
        console.print("[yellow]No benchmark CSVs selected. Returning to main menu.[/yellow]")
        return
    console.print(Rule())

    # Step 2: models.
    chosen_models = select_items(MODELS_TO_RUN, "Models", "Select model(s) to run")
    if not chosen_models:
        console.print("[yellow]No models selected. Returning to main menu.[/yellow]")
        return
    console.print(Rule())

    # Step 3: narrow down to individual model/benchmark pairs.
    console.print("[bold cyan]Select specific Model-Benchmark combinations to run:[/bold cyan]")
    selected_pairs = show_grid_selection(chosen_models, chosen_csvs)
    if not selected_pairs:
        console.print("[yellow]No combinations selected or operation cancelled. Returning to main menu.[/yellow]")
        return

    # Summary of what is about to run.
    console.print(Rule())
    console.print("[bold underline]Selected Combinations:[/bold underline]")

    overview = Table(show_header=True, header_style="bold magenta")
    overview.add_column("Model", style="cyan")
    overview.add_column("Benchmark", style="green")
    for (model_label, _), (benchmark_label, _) in selected_pairs:
        overview.add_row(model_label, benchmark_label)
    console.print(overview)
    console.print(f"\n[bold]Total combinations to run: {len(selected_pairs)}[/bold]")

    console.print(Rule())
    console.print("[bold green]Starting benchmark runs...[/bold green]")
    console.print(Rule())

    for output_root in (REPORTS_DIR, RESULTS_DATA_DIR):
        output_root.mkdir(parents=True, exist_ok=True)

    # Group the runs by benchmark, keeping first-seen order.
    runs_by_benchmark = {}
    for (model_label, model_cmd), (benchmark_label, benchmark_path) in selected_pairs:
        runs_by_benchmark.setdefault(benchmark_label, []).append((model_label, model_cmd, benchmark_path))

    for benchmark_label, runs in runs_by_benchmark.items():
        console.rule(f"[bold blue]Processing Benchmark: {benchmark_label}[/bold blue]")

        # Every run in the group carries the same CSV path; take it from the first.
        _, _, csv_path = runs[0]

        if Path(csv_path).exists():
            for model_label, command_template, _ in runs:
                run_command(command_template, csv_path, benchmark_label, model_label)
            organize_outputs(benchmark_label)
        else:
            console.print(f"[bold red]Error: CSV file not found for {benchmark_label}: {csv_path}. Skipping this CSV.[/bold red]")

        console.rule(style="blue")

    console.print("[bold bright_green]All selected benchmarks and models processed successfully![/bold bright_green]")
    console.print(f"Check {REPORTS_DIR} and {RESULTS_DATA_DIR} for categorized results.")
    console.print("\n[dim]Press Enter to return to main menu...[/dim]")
    input()

                                          

def get_model_name_from_filename(filename_str: str) -> str:
    """Derive a canonical model name from a results filename.

    Lookup order:
      1. A FORTRESS_CONFIGS entry whose run id occurs in the filename.
      2. The first MODEL_NAME_PATTERNS entry whose pattern occurs in it.
      3. Fallback for names with more than three underscore-separated parts:
         drop trailing "results"/"suite" tokens, then drop trailing all-digit
         tokens of four or more characters (date/time stamps) and return
         what is left.

    Matching in steps 1 and 2 is case-insensitive on the filename.

    Args:
        filename_str: File name or path of a results file.

    Returns:
        The matched display name, the stripped fallback name, or the file
        stem unchanged when nothing could be stripped.
    """
    fn_lower = Path(filename_str).name.lower()

    # The filename is lowercased, so the run id has to be as well or a
    # mixed-case run id could never match.
    for fortress_name, (run_id, _) in FORTRESS_CONFIGS.items():
        if run_id.lower() in fn_lower:
            return fortress_name

    for name, pattern in MODEL_NAME_PATTERNS.items():
        if pattern in fn_lower:
            return name

    base_name = Path(filename_str).stem
    parts = base_name.split('_')
    if len(parts) > 3:
        # "_suite_results" splits into two tokens, so strip these markers
        # repeatedly instead of testing for a single "suite_results" token.
        while parts and parts[-1] in ("results", "suite"):
            parts.pop()

        # Count the trailing timestamp-like tokens.
        num_numeric_suffix = 0
        for token in reversed(parts):
            if not (token.isdigit() and len(token) >= 4):
                break
            num_numeric_suffix += 1

        # Only strip when something other than digits would remain.
        if num_numeric_suffix > 0 and len(parts) - num_numeric_suffix > 0:
            return "_".join(parts[:-num_numeric_suffix])

    return base_name

def extract_fortress_file_suffix(filename: str) -> str:
    """Return the descriptive part of a results filename after 'fortress_'.

    The text following the first (case-insensitive) 'fortress_' has its
    ``_YYYYMMDD`` / ``_YYYYMMDD_HHMMSS`` stamps removed, then one known
    results ending (``_suite_results.json``, ``_results.json``,
    ``_suite.json`` or ``.json``) stripped, and is finally shortened to at
    most 30 characters.

    Args:
        filename: Name of a results file.

    Returns:
        The cleaned suffix, or ``filename`` unchanged when it does not
        contain 'fortress_'.
    """
    marker = 'fortress_'
    marker_pos = filename.lower().find(marker)
    if marker_pos == -1:
        return filename

    suffix = filename[marker_pos + len(marker):]

    # Drop date stamps, with or without a trailing time component.
    suffix = re.sub(r'_\d{8}(_\d{6})?', '', suffix)

    # Test the longest ending first and stop at the first hit; otherwise
    # '_results.json' matches '..._suite_results.json' and leaves '_suite'.
    # Compared in lowercase, like the 'fortress_' search above.
    suffix_lower = suffix.lower()
    for ending in ('_suite_results.json', '_results.json', '_suite.json', '.json'):
        if suffix_lower.endswith(ending):
            suffix = suffix[:-len(ending)]
            break

    # Keep the result short: 27 characters plus an ellipsis.
    if len(suffix) > 30:
        suffix = suffix[:27] + "..."

    return suffix

def extract_timestamp_from_filename(filename: str) -> Optional[str]:
    """Return the first ``YYYYMMDD_HHMMSS`` stamp that follows an underscore, or None."""
    found = re.search(r'_(\d{8}_\d{6})', filename)
    return found.group(1) if found else None

def calculate_stdev(data: List[float]) -> float:
    """Return the population standard deviation of ``data``.

    Fewer than two values give 0.0.
    """
    count = len(data)
    if count < 2:
        return 0.0

    average = sum(data) / count
    squared_deviations = [(value - average) ** 2 for value in data]
    # Population variance: divide by the number of values, not by count - 1.
    return math.sqrt(sum(squared_deviations) / count)

def _result_file_sort_key(file_path: Path) -> float:
    """Return a numeric recency key for a result file (larger means newer).

    The ``YYYYMMDD_HHMMSS`` stamp in the filename is preferred and converted
    to epoch seconds so it is comparable with the modification-time fallback
    used for files without a (valid) stamp.
    NOTE(review): the stamp is interpreted as local time — confirm that the
    benchmark runners write local-time stamps.
    """
    ts = extract_timestamp_from_filename(file_path.name)
    if ts:
        try:
            return time.mktime(time.strptime(ts, "%Y%m%d_%H%M%S"))
        except (ValueError, OverflowError):
            # Digits matched the pattern but are not a real date/time.
            pass
    return file_path.stat().st_mtime


def view_results():
    """View benchmark results in a consolidated table with performance and latency.

    For every directory named after a key of BENCHMARK_CSVS under
    RESULTS_DATA_DIR, the ``*.json`` files are grouped by model name and the
    newest file per (benchmark, model) pair is selected. When a pair has
    several files, the F1 of the two newest runs is compared and reported as
    a note. The table shows, per model, the F1 (unsafe) and mean latency for
    each benchmark plus the mean/standard deviation of F1 across benchmarks
    and of latency across all individual prompts.

    Blocks on ``input()`` before returning to the caller.
    """
    console.print(Panel("[bold magenta]Benchmark Performance & Latency Summary[/bold magenta]", expand=False, border_style="magenta"))
    console.print(Rule())

    # results_store[benchmark][model] -> {'f1': float, 'latencies': List[float]}
    results_store: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict))
    all_benchmarks_discovered: Set[str] = set()
    all_models_discovered: Set[str] = set()

    console.print(f"[bold blue]Scanning for benchmark results in: {RESULTS_DATA_DIR.resolve()}[/bold blue]")

    # Pick the newest result file for each (benchmark, model) pair.
    latest_files: Dict[Tuple[str, str], Path] = {}
    multiple_files_detected: Set[Tuple[str, str]] = set()
    f1_change_notes: Dict[Tuple[str, str], str] = {}
    for benchmark_dir_name in BENCHMARK_CSVS.keys():
        current_benchmark_path = RESULTS_DATA_DIR / benchmark_dir_name
        if not current_benchmark_path.is_dir():
            continue

        files_by_model: Dict[str, List[Path]] = defaultdict(list)
        for file_path in current_benchmark_path.glob("*.json"):
            if not file_path.is_file():
                continue
            model_name = get_model_name_from_filename(file_path.name)
            files_by_model[model_name].append(file_path)

        for model_name, files in files_by_model.items():
            pair = (benchmark_dir_name, model_name)
            if len(files) > 1:
                multiple_files_detected.add(pair)

            # Numeric key: stamped and unstamped files are ordered on the
            # same epoch-seconds scale instead of being compared as strings.
            files_sorted = sorted(files, key=_result_file_sort_key, reverse=True)
            latest_files[pair] = files_sorted[0]

            # Compare the two newest runs so the user sees whether F1 moved.
            if len(files_sorted) > 1:
                try:
                    with open(files_sorted[0], 'r', encoding='utf-8') as f:
                        data_latest = json.load(f)
                    with open(files_sorted[1], 'r', encoding='utf-8') as f:
                        data_prev = json.load(f)
                    f1_latest = float(data_latest.get("metrics", {}).get("f1_unsafe") or 0.0)
                    f1_prev = float(data_prev.get("metrics", {}).get("f1_unsafe") or 0.0)
                    if abs(f1_latest - f1_prev) > 1e-6:
                        f1_change_notes[pair] = f"F1 changed: {f1_prev:.4f} -> {f1_latest:.4f}"
                    else:
                        f1_change_notes[pair] = f"F1 unchanged: {f1_latest:.4f}"
                except Exception as e:
                    # Best effort only: the note is informational.
                    f1_change_notes[pair] = f"[Error comparing F1: {e}]"

    for (benchmark_dir_name, model_display_name), file_path in latest_files.items():
        all_benchmarks_discovered.add(benchmark_dir_name)
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if "metrics" not in data or "f1_unsafe" not in data["metrics"]:
                continue
            f1_score = float(data["metrics"]["f1_unsafe"] or 0.0)
            all_models_discovered.add(model_display_name)

            # Collect every parseable per-prompt latency.
            latencies_ms = []
            individual_results = data.get("individual_results", [])
            if individual_results and isinstance(individual_results, list):
                for res in individual_results:
                    time_ms = res.get("processing_time_ms")
                    if time_ms is not None:
                        try:
                            latencies_ms.append(float(time_ms))
                        except (ValueError, TypeError):
                            continue
            results_store[benchmark_dir_name][model_display_name] = {
                'f1': f1_score,
                'latencies': latencies_ms
            }
        except (json.JSONDecodeError, KeyError, ValueError, TypeError, AttributeError, OSError) as e:
            # One unreadable or oddly shaped file must not abort the whole view.
            console.print(f"[red]Error processing {file_path.name}: {e}. Skipping.[/red]")

    # Tell the user which file was used wherever several runs exist
    # (sorted so the notes appear in a stable order).
    for (benchmark_dir_name, model_name) in sorted(multiple_files_detected):
        used_file = latest_files[(benchmark_dir_name, model_name)]
        note = f1_change_notes.get((benchmark_dir_name, model_name), "")
        console.print(f"[yellow]Note: Multiple result files found for model [bold]{model_name}[/bold] on benchmark [bold]{benchmark_dir_name}[/bold]. Using latest: [white]{used_file.name}[/white]. {note}[/yellow]")

    if not all_models_discovered:
        console.print("[bold red]No valid model data could be extracted. Check JSON files for 'metrics.f1_unsafe'.[/bold red]")
        console.print("\n[dim]Press Enter to return to main menu...[/dim]")
        input()
        return

    sorted_benchmarks = sorted(all_benchmarks_discovered)
    sorted_models = sorted(all_models_discovered)

    # Table header: one column per benchmark, then the aggregate columns.
    table = Table(title="Model Performance (F1 Unsafe) and Latency (ms)", show_lines=True, expand=True)
    table.add_column("Model", style="cyan", overflow="fold", min_width=25)
    for benchmark_name in sorted_benchmarks:
        table.add_column(benchmark_name, style="magenta", justify="center")
    table.add_column("Avg F1", style="green", justify="center")
    table.add_column("Std. Dev. (F1)", style="green", justify="center")
    table.add_column("Avg Latency (ms)", style="yellow", justify="center")
    table.add_column("Std. Dev. (ms)", style="yellow", justify="center")

    # One row per model.
    for model_name in sorted_models:
        row_cells = [model_name]
        f1s_for_avg: List[float] = []
        all_model_latencies: List[float] = []
        for benchmark_name in sorted_benchmarks:
            model_data = results_store[benchmark_name].get(model_name)
            if model_data:
                f1 = model_data['f1']
                latencies = model_data['latencies']
                f1s_for_avg.append(f1)
                all_model_latencies.extend(latencies)
                avg_latency_bm = (sum(latencies) / len(latencies)) if latencies else 0.0
                cell_text = f"{f1 * 100:.1f}\n[dim]({avg_latency_bm:.1f}ms)[/dim]"
                row_cells.append(cell_text)
            else:
                row_cells.append("-")

        # F1 aggregates are expressed in percent; latency aggregates pool
        # every individual prompt of the model across benchmarks.
        avg_f1 = (sum(f1s_for_avg) / len(f1s_for_avg) * 100) if f1s_for_avg else 0.0
        stdev_f1 = calculate_stdev([f * 100 for f in f1s_for_avg]) if f1s_for_avg else 0.0
        avg_latency = (sum(all_model_latencies) / len(all_model_latencies)) if all_model_latencies else 0.0
        stdev_latency = calculate_stdev(all_model_latencies)
        row_cells.append(f"{avg_f1:.1f}")
        row_cells.append(f"{stdev_f1:.2f}")
        row_cells.append(f"{avg_latency:.2f}")
        row_cells.append(f"{stdev_latency:.2f}")
        table.add_row(*row_cells)

    console.print(table)
    console.print("\n[dim]Notes:[/dim]")
    console.print("[dim]- Each benchmark cell shows: F1 Unsafe Score % (Average Latency in ms).[/dim]")
    console.print("[dim]- 'Avg F1' is the mean F1 score for a model across all benchmarks it ran on.[/dim]")
    console.print("[dim]- 'Avg Latency' and 'Std. Dev.' are calculated across all individual prompts for that model.[/dim]")
    console.print("[dim]- Only the latest run per model-benchmark pair is shown, based on the trailing timestamp in the filename.[/dim]")
    
    console.print("\n[dim]Press Enter to return to main menu...[/dim]")
    input()

                                             

def count_benchmark_entries(csv_path: str) -> int:
    """Count the rows of a CSV whose 'split' column equals 'benchmark'.

    The comparison is case-insensitive. Rows that are too short to carry a
    'split' value are treated as non-matching instead of aborting the count.

    Args:
        csv_path: Path of the CSV file to scan (read as UTF-8).

    Returns:
        The number of matching rows, or 0 when the file is missing, cannot
        be read, is empty, or has no 'split' column (an error is printed to
        the console in those cases).
    """
    try:
        with open(csv_path, 'r', newline='', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            # fieldnames is None for an empty file, so test it before 'in'.
            if not reader.fieldnames or 'split' not in reader.fieldnames:
                console.print(f"[red]Error: 'split' column not found in {csv_path}. Cannot count benchmark entries.[/red]")
                return 0
            # DictReader fills missing trailing fields with None, hence 'or'.
            return sum(
                1 for row in reader
                if (row.get('split') or '').lower() == 'benchmark'
            )
    except FileNotFoundError:
        console.print(f"[red]Error: CSV file not found at {csv_path}.[/red]")
        return 0
    except Exception as e:
        console.print(f"[red]Error reading CSV file {csv_path}: {e}[/red]")
        return 0

def view_delay_report():
    """
    Show the mean per-query delay of each model on each benchmark.

    Result JSON files are read from the directories named after the keys of
    BENCHMARK_CSVS under RESULTS_DATA_DIR. For every (benchmark, model) pair
    the file that comes last in sorted filename order is used, and the mean
    of its 'processing_time_ms' values becomes the table cell. The lowest
    value in each benchmark column and the lowest per-model average are
    highlighted. Blocks on ``input()`` before returning.
    """
    console.print(Panel("[bold magenta]Model Delay/Query Report[/bold magenta]", expand=False, border_style="magenta"))
    console.print(Rule())

    # delay_by_benchmark[benchmark][model] -> (mean delay in ms, source filename)
    delay_by_benchmark: Dict[str, Dict[str, Tuple[float, str]]] = defaultdict(dict)
    seen_benchmarks: Set[str] = set()
    seen_models: Set[str] = set()

    console.print(f"[bold blue]Scanning for benchmark results in: {RESULTS_DATA_DIR.resolve()}[/bold blue]")

    # Later filenames overwrite earlier ones, so the last file in sorted
    # order wins for every (benchmark, model) pair.
    chosen_files: Dict[Tuple[str, str], Path] = {}
    for bench in BENCHMARK_CSVS.keys():
        bench_dir = RESULTS_DATA_DIR / bench
        if not bench_dir.is_dir():
            continue
        for candidate in sorted(bench_dir.glob("*.json")):
            chosen_files[(bench, get_model_name_from_filename(candidate.name))] = candidate

    # Reduce each chosen file to a single mean delay.
    for (bench, model), file_path in chosen_files.items():
        seen_benchmarks.add(bench)

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                payload = json.load(f)

            entries = payload.get("individual_results")
            if not entries or not isinstance(entries, list):
                # Nothing measurable in this file.
                continue

            timings = []
            for entry in entries:
                raw_ms = entry.get("processing_time_ms")
                if raw_ms is None:
                    continue
                try:
                    timings.append(float(raw_ms))
                except (ValueError, TypeError):
                    # Ignore values that cannot be read as a number.
                    pass

            if timings:
                seen_models.add(model)
                delay_by_benchmark[bench][model] = (sum(timings) / len(timings), file_path.name)

        except json.JSONDecodeError:
            console.print(f"[yellow]Warning: Could not decode JSON from {file_path}. Skipping.[/yellow]")
        except Exception as e:
            console.print(f"[red]An unexpected error occurred processing {file_path}: {e}. Skipping.[/red]")

    if not seen_models:
        console.print("[bold red]No model data with 'individual_results' and 'processing_time_ms' found.[/bold red]")
        console.print("[dim]Press Enter to return to main menu...[/dim]")
        input()
        return

    # Benchmarks alphabetically; models whose name starts with "FORTRESS" first.
    ordered_benchmarks = sorted(seen_benchmarks)
    ordered_models = sorted(seen_models, key=lambda m: (not m.startswith("FORTRESS"), m))

    # Mean of the per-benchmark means, over the benchmarks a model ran on.
    mean_by_model = {}
    for model in ordered_models:
        per_benchmark = [
            delay_by_benchmark[bench][model][0]
            for bench in ordered_benchmarks
            if model in delay_by_benchmark[bench]
        ]
        if per_benchmark:
            mean_by_model[model] = sum(per_benchmark) / len(per_benchmark)

    fastest_model = None
    if mean_by_model:
        fastest_model = min(mean_by_model, key=mean_by_model.get)

    # Lowest delay seen in each benchmark column.
    column_minimum = {
        bench: min(info[0] for info in delay_by_benchmark[bench].values())
        for bench in ordered_benchmarks
        if delay_by_benchmark[bench]
    }

    table = Table(title="Model Delay Comparison (ms/query)", show_lines=True, expand=True)
    table.add_column("Model", style="cyan", overflow="fold", min_width=25, ratio=1)
    for bench in ordered_benchmarks:
        table.add_column(bench, style="magenta", justify="center", ratio=1)
    table.add_column("Average", style="yellow", justify="center", ratio=1)

    for model in ordered_models:
        cells = [model]

        for bench in ordered_benchmarks:
            info = delay_by_benchmark[bench].get(model)
            if info is None:
                cells.append("-")
                continue

            delay = info[0]
            text = f"{delay:.2f}"
            # Highlight the column's best value (ties are all highlighted).
            floor = column_minimum.get(bench)
            if floor is not None and abs(delay - floor) < 1e-6:
                text = f"[bold green]{text}[/bold green]"
            cells.append(text)

        # The per-model average was already computed over the same
        # benchmarks, in the same order, as the cells above.
        model_mean = mean_by_model.get(model)
        if model_mean is None:
            cells.append("-")
        elif fastest_model == model:
            cells.append(f"[bold green]{model_mean:.2f}[/bold green]")
        else:
            cells.append(f"{model_mean:.2f}")

        table.add_row(*cells)

    if table.rows:
        console.print(table)
    else:
        console.print("[yellow]No data to display. Ensure JSON files contain 'individual_results' with 'processing_time_ms'.[/yellow]")

    console.print("\n[dim]Notes:[/dim]")
    console.print("[dim]- Values represent the average delay per query in milliseconds (ms).[/dim]")
    console.print("[dim]- Data is sourced from 'processing_time_ms' in the latest run for each model-benchmark pair.[/dim]")
    console.print("[dim]- The lowest (best) delay in each column and the best average are highlighted in [bold green]bold green[/bold green].[/dim]")
    
    console.print("\n[dim]Press Enter to return to main menu...[/dim]")
    input()





def export_full_results_csv():
    """Scan all benchmark results and export them to a single comprehensive CSV file.

    Every ``*.json`` file in the directories named after the keys of
    BENCHMARK_CSVS under RESULTS_DATA_DIR contributes one row built from its
    'metrics' object (files without one are skipped with a warning). The
    per-entry latency is derived from 'duration_seconds' and
    'metrics.num_samples' when both are usable. The CSV is written to
    ``BASE_OUTPUT_DIR / "full_benchmark_results.csv"``, overwriting any
    previous export. Blocks on ``input()`` before returning.
    """
    console.print(Panel("[bold magenta]Export Full Benchmark Results to CSV[/bold magenta]", expand=False, border_style="magenta"))
    console.print(Rule())

    # Column order of the CSV; also the keys of every row dictionary.
    header = [
        "model_name", "benchmark_name", "f1_unsafe", "accuracy",
        "precision_unsafe", "recall_unsafe", "fpr_unsafe", "fnr_unsafe",
        "latency_ms_per_entry", "total_duration_s", "num_samples",
        "tp", "tn", "fp", "fn", "source_file"
    ]

    all_results_data: List[Dict[str, Any]] = []
    
    console.print(f"[bold blue]Scanning for benchmark results in: {RESULTS_DATA_DIR.resolve()}[/bold blue]")

    for benchmark_name in BENCHMARK_CSVS.keys():
        benchmark_path = RESULTS_DATA_DIR / benchmark_name
        if not benchmark_path.is_dir():
            continue

        console.print(f"[dim]Processing benchmark: {benchmark_name}...[/dim]")

        # Sorted so that repeated exports produce rows in the same order.
        for file_path in sorted(benchmark_path.glob("*.json")):
            if not file_path.is_file():
                continue
            
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                metrics = data.get("metrics")
                if not metrics:
                    console.print(f"[yellow]Warning: No 'metrics' object in {file_path.name}. Skipping.[/yellow]")
                    continue

                num_samples = metrics.get("num_samples")
                duration_s = data.get("duration_seconds")
                
                # Mean wall-clock time per sample; stays 0 when not derivable.
                latency_ms = 0
                if duration_s is not None and num_samples is not None and num_samples > 0:
                    latency_ms = (duration_s * 1000) / num_samples

                # 'or {}' also covers an explicit null in the JSON, which
                # would otherwise make the whole file be skipped.
                cm = metrics.get("confusion_matrix_values") or {}

                all_results_data.append({
                    "model_name": get_model_name_from_filename(file_path.name),
                    "benchmark_name": benchmark_name,
                    "f1_unsafe": metrics.get("f1_unsafe"),
                    "accuracy": metrics.get("accuracy"),
                    "precision_unsafe": metrics.get("precision_unsafe"),
                    "recall_unsafe": metrics.get("recall_unsafe"),
                    "fpr_unsafe": metrics.get("fpr_unsafe"),
                    "fnr_unsafe": metrics.get("fnr_unsafe"),
                    "latency_ms_per_entry": f"{latency_ms:.2f}" if latency_ms else None,
                    "total_duration_s": duration_s,
                    "num_samples": num_samples,
                    "tp": cm.get("TP"),
                    "tn": cm.get("TN"),
                    "fp": cm.get("FP"),
                    "fn": cm.get("FN"),
                    "source_file": file_path.name,
                })

            except json.JSONDecodeError:
                console.print(f"[yellow]Warning: Could not decode JSON from {file_path}. Skipping.[/yellow]")
            except Exception as e:
                # Keep exporting the remaining files.
                console.print(f"[red]An unexpected error occurred processing {file_path}: {e}. Skipping.[/red]")

    if not all_results_data:
        console.print("[bold red]No valid result files were found to export.[/bold red]")
        console.print("\n[dim]Press Enter to return to main menu...[/dim]")
        input()
        return

    output_filename = "full_benchmark_results.csv"
    output_path = BASE_OUTPUT_DIR / output_filename
    
    try:
        console.print(Rule())
        console.print(f"[cyan]Writing {len(all_results_data)} records to CSV...[/cyan]")

        # The output directory may not exist yet on a fresh checkout.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=header)
            writer.writeheader()
            writer.writerows(all_results_data)

        console.print(Panel(
            f"[bold green]Successfully exported all results to:[/bold green]\n[white]{output_path.resolve()}[/white]",
            title="Export Complete",
            border_style="green"
        ))

    except Exception as e:
        console.print(f"[bold red]Failed to write to CSV file: {e}[/bold red]")

    console.print("\n[dim]Press Enter to return to main menu...[/dim]")
    input()

def view_xssafety_language_performance():
    """Calculate and display per-language F1 scores for all models on the XSafety benchmark.

    Reads the result JSON files in ``RESULTS_DATA_DIR / "xsafety_multilingual"``
    (the last file in sorted filename order is used per model). The language
    of each individual result is taken from a 'source_file_input' value of
    the form ``all_languages_matter_<code>.jsonl``. TP/FP/FN counts for the
    'UNSAFE' class are accumulated per model and language and turned into F1
    scores; any prediction other than 'UNSAFE' on an 'UNSAFE' prompt counts
    as a false negative. Blocks on ``input()`` before returning.
    """
    console.print(Panel("[bold magenta]XSafety Per-Language Performance Viewer[/bold magenta]", expand=False, border_style="magenta"))
    console.print(Rule())

    benchmark_name = "xsafety_multilingual"
    xssafety_results_dir = RESULTS_DATA_DIR / benchmark_name
    
    if not xssafety_results_dir.is_dir():
        console.print(f"[red]Error: Benchmark directory not found: {xssafety_results_dir}[/red]")
        console.print("\n[dim]Press Enter to return to main menu...[/dim]")
        input()
        return

    # model_language_stats[model][language] -> {"TP": n, "FP": n, "FN": n}
    model_language_stats = defaultdict(lambda: defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0}))
    all_languages = set()
    model_files = {}

    console.print(f"[cyan]Scanning for results in {xssafety_results_dir}...[/cyan]")

    # Later filenames overwrite earlier ones: the last sorted file wins per model.
    for file_path in sorted(xssafety_results_dir.glob("*.json")):
        model_name = get_model_name_from_filename(file_path.name)
        model_files[model_name] = file_path

    if not model_files:
        console.print(f"[red]No result files found in {xssafety_results_dir}.[/red]")
        console.print("\n[dim]Press Enter to return to main menu...[/dim]")
        input()
        return

    lang_code_pattern = re.compile(r"all_languages_matter_(\w+)\.jsonl")

    for model_name, file_path in model_files.items():
        console.print(f"[dim]  -> Processing {model_name} from {file_path.name}[/dim]")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            
            individual_results = data.get("individual_results")
            if not individual_results:
                console.print(f"[yellow]Warning: 'individual_results' not found in {file_path.name}. Skipping model.[/yellow]")
                continue

            for result in individual_results:
                if not isinstance(result, dict):
                    # Malformed entry; nothing to classify.
                    continue

                source_file = result.get("source_file_input") or ""
                match = lang_code_pattern.search(str(source_file))
                if not match:
                    continue
                
                lang_code = match.group(1).upper()
                all_languages.add(lang_code)
                
                # 'or ""' and str() keep null / non-string labels from
                # crashing the viewer; they simply match neither class.
                true_label = str(result.get("true_label") or "").upper()
                pred_label = str(result.get("predicted_label") or "").upper()

                stats = model_language_stats[model_name]
                if true_label == "UNSAFE" and pred_label == "UNSAFE":
                    stats[lang_code]["TP"] += 1
                elif true_label == "SAFE" and pred_label == "UNSAFE":
                    stats[lang_code]["FP"] += 1
                elif true_label == "UNSAFE" and pred_label != "UNSAFE":
                    # Includes non-label predictions such as errors.
                    stats[lang_code]["FN"] += 1
        
        except (json.JSONDecodeError, KeyError, OSError) as e:
            console.print(f"[red]Error processing file {file_path.name}: {e}")

    if not model_language_stats:
        console.print("[red]Could not extract any language data from any of the result files.[/red]")
        console.print("[dim]Ensure 'source_file_input' fields are formatted like 'all_languages_matter_{lang_code}.jsonl'[/dim]")
        console.print("\n[dim]Press Enter to return to main menu...[/dim]")
        input()
        return

    # Turn the counts into F1 scores; empty denominators yield 0.0.
    model_language_f1 = defaultdict(dict)
    for model_name, lang_stats in model_language_stats.items():
        for lang_code, stats in lang_stats.items():
            tp, fp, fn = stats["TP"], stats["FP"], stats["FN"]
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
            model_language_f1[model_name][lang_code] = f1

    sorted_models = sorted(model_language_f1.keys())
    sorted_languages = sorted(all_languages)

    # Best score per language, computed once instead of once per table cell.
    best_by_language = {
        lang_code: max(model_language_f1[m].get(lang_code, 0.0) for m in sorted_models)
        for lang_code in sorted_languages
    }

    table = Table(
        title="XSafety Multilingual F1 Score Comparison",
        show_lines=True,
        expand=True
    )
    table.add_column("Model", style="cyan", min_width=20)
    for lang_code in sorted_languages:
        table.add_column(lang_code, style="magenta", justify="center")

    for model_name in sorted_models:
        row_cells = [model_name]
        for lang_code in sorted_languages:
            f1_score = model_language_f1[model_name].get(lang_code)
            if f1_score is not None:
                cell_text = f"{f1_score:.3f}"
                # Ties with the best score are highlighted as well.
                if abs(f1_score - best_by_language[lang_code]) < 1e-6:
                    cell_text = f"[bold white underline]{cell_text}[/bold white underline]"
                row_cells.append(cell_text)
            else:
                row_cells.append("-")
        table.add_row(*row_cells)

    console.print(table)
    console.print("\n[dim]Notes:[/dim]")
    console.print("[dim]- Table shows the F1 score for the 'UNSAFE' class.[/dim]")
    console.print("[dim]- 'ERROR' predictions for 'UNSAFE' prompts are now correctly counted as False Negatives.[/dim]")
    console.print("[dim]- The best F1 score in each language column is [bold white underline]bold, white, and underlined[/bold white underline].[/dim]")
    console.print("[dim]- Results are based on the latest JSON file found for each model.[/dim]")
    console.print("\n[dim]Press Enter to return to main menu...[/dim]")
    input()


def export_misclassified_prompts():
    """Write the misclassified prompts of chosen models and benchmarks to a JSON file.

    The user first picks benchmarks, then models (only models that have
    result files for those benchmarks are offered). Every individual result
    whose 'true_label' and 'predicted_label' are both present and differ is
    collected and dumped to
    ``BASE_OUTPUT_DIR / failure_report_<models>_<timestamp>.json``.
    """
    console.print(Panel("[bold magenta]Export Misclassified Prompts[/bold magenta]", expand=False, border_style="magenta"))
    console.print(Rule())

    # Step 1: which benchmarks to scan.
    benchmark_choices = select_items(
        BENCHMARK_CSVS,
        "Benchmark CSVs",
        "Select benchmark(s) to scan for failures"
    )
    if not benchmark_choices:
        console.print("[yellow]No benchmarks selected. Returning to main menu.[/yellow]")
        return
    chosen_benchmarks = [bench for bench, _ in benchmark_choices]

    # Step 2: gather every result file of those benchmarks, keyed by model.
    results_by_model = defaultdict(list)
    for bench in chosen_benchmarks:
        bench_dir = RESULTS_DATA_DIR / bench
        if not bench_dir.is_dir():
            continue
        for result_file in bench_dir.glob("*.json"):
            owner = get_model_name_from_filename(result_file.name)
            results_by_model[owner].append((bench, result_file))

    if not results_by_model:
        console.print("[red]No result files found for the selected benchmarks.[/red]")
        console.print("\n[dim]Press Enter to return to main menu...[/dim]")
        input()
        return

    # Step 3: which models to export (plain dict copy for the selector).
    model_choices = select_items(
        dict(results_by_model),
        "Models",
        "Select model(s) to export failures from"
    )
    if not model_choices:
        console.print("[yellow]No models selected. Returning to main menu.[/yellow]")
        return
    
    # Step 4: collect the mismatching predictions.
    misclassified = []
    console.print(Rule())
    console.print("[cyan]Scanning for misclassified prompts...[/cyan]")

    for model_name, model_results in model_choices:
        relevant = [pair for pair in model_results if pair[0] in chosen_benchmarks]

        for bench, result_file in relevant:
            console.print(f"[dim]  -> Processing {model_name} on {bench} ({result_file.name})[/dim]")
            try:
                with open(result_file, 'r', encoding='utf-8') as f:
                    payload = json.load(f)
                
                for entry in payload.get("individual_results", []):
                    expected = entry.get("true_label")
                    predicted = entry.get("predicted_label")

                    # Entries lacking either label cannot be judged.
                    if expected is None or predicted is None:
                        continue
                    if expected == predicted:
                        continue

                    misclassified.append({
                        "ground_truth": expected,
                        "predicted_label": predicted,
                        "original_prompt": entry.get("original_prompt"),
                        "justification": entry.get("justification_pred")
                    })
            except Exception as e:
                console.print(f"[red]Error reading or processing {result_file}: {e}[/red]")
    
    if not misclassified:
        console.print(Rule())
        console.print("[yellow]No misclassified prompts found for the selected models and benchmarks.[/yellow]")
        console.print("\n[dim]Press Enter to return to main menu...[/dim]")
        input()
        return

    # Step 5: build a filename from the sorted model names plus a timestamp.
    chosen_model_names = sorted(name for name, _ in model_choices)
    model_name_str = "_".join(chosen_model_names).replace(" ", "_").lower()
    timestamp = time.strftime('%Y%m%d_%H%M%S')
    output_path = BASE_OUTPUT_DIR / f"failure_report_{model_name_str}_{timestamp}.json"
    
    try:
        console.print(Rule())
        console.print(f"[cyan]Found {len(misclassified)} misclassified prompts. Writing to file...[/cyan]")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(misclassified, f, indent=2, ensure_ascii=False)
        
        console.print(Panel(
            f"[bold green]Successfully exported {len(misclassified)} failure cases to:[/bold green]\n[white]{output_path.resolve()}[/white]",
            title="Export Complete",
            border_style="green"
        ))
    except Exception as e:
        console.print(f"[bold red]Failed to write to JSON file: {e}[/bold red]")

    console.print("\n[dim]Press Enter to return to main menu...[/dim]")
    input()
    
                       

def main_menu():
    """Show the main menu in a loop and run the action the user picks.

    The loop ends when the user selects option 6 (Exit).
    """
    # Menu choice -> function to run; "6" (Exit) is handled separately.
    actions = {
        "1": run_benchmarks,
        "2": view_results,
        "3": view_xssafety_language_performance,
        "4": export_misclassified_prompts,
        "5": export_full_results_csv,
    }
    exit_choice = "6"

    while True:
        console.clear()
        console.print(Panel(
            "[bold magenta]Benchmark Suite Manager[/bold magenta]\n\n"
            "[cyan]A unified tool for running benchmarks and viewing results[/cyan]",
            expand=False,
            border_style="magenta"
        ))
        console.print(Rule())
        
        console.print("[bold]Choose an option:[/bold]\n")
        console.print("  [cyan]1.[/cyan] Run Benchmarks")
        console.print("  [cyan]2.[/cyan] View Consolidated Results")
        console.print("  [cyan]3.[/cyan] View XSafety Per-Language F1 Score")
        console.print("  [cyan]4.[/cyan] Export Misclassified Prompts")
        console.print("  [cyan]5.[/cyan] Export Full Results to CSV")
        console.print("  [cyan]6.[/cyan] Exit\n")
        
        choice = Prompt.ask(
            "[bold]Enter your choice[/bold]",
            choices=["1", "2", "3", "4", "5", "6"],
            default="1"
        )
        
        # Clear the menu before the chosen screen (or the goodbye message).
        console.clear()
        
        if choice == exit_choice:
            console.print("[bold green]Thank you for using Benchmark Suite Manager![/bold green]")
            break

        handler = actions.get(choice)
        if handler is not None:
            handler()

# Start the interactive menu only when this file is executed as a script.
if __name__ == "__main__":
    main_menu()




