import os
import sys
import json
import yaml
import shutil
import subprocess
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional, Union, Set
from datetime import datetime
from collections import defaultdict
import pandas as pd
import numpy as np
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from rich.rule import Rule
from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import radiolist_dialog, checkboxlist_dialog
from prompt_toolkit.validation import Validator, ValidationError
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import re

# Put the parent of this file's directory on sys.path so the project-level
# ``config`` package imported below can be resolved.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from config.constants import BENCHMARK_CSVS, BASE_PROJECT_DIR

# Filesystem layout used by the tuning tool, all rooted at the project directory.
EXPERIMENTS_BASE_DIR = BASE_PROJECT_DIR / "experiments"
SETTINGS_PATH = BASE_PROJECT_DIR / "config" / "settings.yaml"
EXPERIMENT_CONFIGS_DIR = EXPERIMENTS_BASE_DIR / "configs"

# Shared rich console used for all terminal output in this module.
console = Console()

class NumberValidator(Validator):
    """prompt_toolkit validator that accepts only finite numeric input.

    Text that ``float()`` cannot parse is rejected, and so are ``nan`` and
    ``inf``. The numbers entered through this validator become sweep bounds
    and increments, and a non-finite bound or step breaks the sweep; for
    example, an infinite maximum never terminates the range expansion.
    """
    def validate(self, document):
        import math

        text = document.text
        try:
            value = float(text)
        except ValueError:
            raise ValidationError(
                message='Please enter a valid number',
                cursor_position=len(text)) from None
        if not math.isfinite(value):
            raise ValidationError(
                message='Please enter a finite number',
                cursor_position=len(text))

class ParameterTuningExperiment:
    def __init__(self):
        """Bind the module-level experiment directories and make sure they exist."""
        self.experiments_base_dir, self.configs_dir = EXPERIMENTS_BASE_DIR, EXPERIMENT_CONFIGS_DIR
        self.setup_directories()
        
    def setup_directories(self):
        """Create necessary directories"""
        self.experiments_base_dir.mkdir(parents=True, exist_ok=True)
        self.configs_dir.mkdir(parents=True, exist_ok=True)
        
    def get_settings_structure(self) -> Dict[str, Any]:
        """Load settings.yaml (``SETTINGS_PATH``) and return its parsed contents.

        ``yaml.safe_load`` returns ``None`` for an empty document. That case is
        normalised to ``{}`` so callers such as ``flatten_dict`` always receive
        a mapping instead of crashing on ``None``.
        """
        with open(SETTINGS_PATH, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f) or {}
    
    def flatten_dict(self, d: Dict[str, Any], parent_key: str = '', sep: str = '.') -> Dict[str, Any]:
        """Flatten nested dictionary with dot notation keys"""
        items = []
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.extend(self.flatten_dict(v, new_key, sep=sep).items())
            else:
                items.append((new_key, v))
        return dict(items)
    
    def unflatten_dict(self, d: Dict[str, Any], sep: str = '.') -> Dict[str, Any]:
        """Unflatten dictionary from dot notation"""
        result = {}
        for key, value in d.items():
            parts = key.split(sep)
            current = result
            for part in parts[:-1]:
                if part not in current:
                    current[part] = {}
                current = current[part]
            current[parts[-1]] = value
        return result
    
    def get_value_from_path(self, data: Dict[str, Any], path: str) -> Any:
        """Get value from nested dict using dot notation path"""
        keys = path.split('.')
        value = data
        for key in keys:
            value = value.get(key, {})
        return value
    
    def set_value_from_path(self, data: Dict[str, Any], path: str, value: Any):
        """Set value in nested dict using dot notation path"""
        keys = path.split('.')
        current = data
        for key in keys[:-1]:
            if key not in current:
                current[key] = {}
            current = current[key]
        current[keys[-1]] = value
    
    def load_experiment_configs(self) -> List[Dict[str, Any]]:
        """Load all existing experiment configurations"""
        configs = []
        for config_file in self.configs_dir.glob("*.json"):
            with open(config_file, 'r') as f:
                config = json.load(f)
                config['filename'] = config_file.name
                configs.append(config)
        return configs
    
    def save_experiment_config(self, config: Dict[str, Any]):
        """Write ``config`` as indented JSON into the configs directory.

        The file name is the experiment name, lower-cased, with spaces replaced
        by underscores, plus ``.json``. An existing file of that name is
        overwritten.
        """
        slug = config['name'].replace(' ', '_').lower()
        config_path = self.configs_dir / (slug + ".json")
        with open(config_path, 'w') as handle:
            json.dump(config, handle, indent=2)
        console.print(f"[green]✓ Saved experiment configuration to {config_path}[/green]")
    
    def create_new_experiment_config(self) -> Optional[Dict[str, Any]]:
        """Interactively create, save and return a new experiment configuration.

        The user names the experiment, picks numeric settings from settings.yaml
        to tune, and enters a min/max/increment sweep for each one.

        A selected ``*default_primary`` / ``*mixed_primary`` parameter whose
        matching ``*_perplexity`` setting exists is stored as a ``weight_pair``.
        Its partner is not tuned independently; it is set to ``1 - value``
        when the combinations are expanded. For that reason the
        ``*_perplexity`` settings themselves are never offered for selection.

        Returns:
            The saved configuration, or ``None`` if there was nothing to tune
            or the user selected nothing.
        """
        console.clear()
        console.print("[bold cyan]Create New Experiment Configuration[/bold cyan]\n")

        # The name becomes both the config file name and the experiment
        # directory name, so an empty name would point at the base directory.
        name = prompt(
            "Experiment name: ",
            validator=Validator.from_callable(
                lambda text: bool(text.strip()),
                error_message="Experiment name cannot be empty",
                move_cursor_to_end=True,
            ),
        ).strip()
        description = prompt("Description (optional): ")

        flat_settings = self.flatten_dict(self.get_settings_structure() or {})

        # bool is a subclass of int, so on/off flags must be excluded
        # explicitly: they cannot be swept as numeric ranges.
        numeric_params = {
            k: v for k, v in flat_settings.items()
            if isinstance(v, (int, float)) and not isinstance(v, bool)
        }

        # Perplexity weights are derived from their primary partner.
        selectable_params = {
            k: v for k, v in numeric_params.items()
            if not k.endswith(("default_perplexity", "mixed_perplexity"))
        }
        if not selectable_params:
            console.print("[yellow]No numeric parameters available to tune[/yellow]")
            return None

        console.print("\n[bold]Select parameters to tune:[/bold]")
        param_choices = [(k, f"{k} (current: {v})") for k, v in selectable_params.items()]

        selected_params = checkboxlist_dialog(
            title="Parameter Selection",
            text="Select parameters to tune (use space to select, enter to confirm):",
            values=param_choices
        ).run()

        if not selected_params:
            console.print("[yellow]No parameters selected[/yellow]")
            return None

        parameters: Dict[str, Dict[str, Any]] = {}
        for param in selected_params:
            console.print(f"\n[cyan]Configure parameter: {param}[/cyan]")
            current_value = numeric_params[param]
            console.print(f"Current value: {current_value}")

            pair = self._perplexity_partner(param)
            if pair is not None and pair not in numeric_params:
                pair = None

            min_val, max_val, increment = self._prompt_parameter_range(
                show_unit_hint=0 <= current_value <= 1,
                unit_interval=pair is not None,
            )

            spec: Dict[str, Any] = {
                "min": min_val,
                "max": max_val,
                "increment": increment,
                "type": "range"
            }
            if pair is not None:
                console.print(f"[yellow]Note: {pair} will automatically be set to 1 - {param}[/yellow]")
                spec["type"] = "weight_pair"
                spec["pair"] = pair
            parameters[param] = spec

        config = {
            "name": name,
            "description": description,
            "parameters": parameters,
            "created": datetime.now().isoformat()
        }

        self.save_experiment_config(config)

        return config

    @staticmethod
    def _perplexity_partner(param: str) -> Optional[str]:
        """Return the ``*_perplexity`` path paired with a ``*_primary`` weight, else None.

        Only the trailing segment is rewritten, so earlier path segments that
        happen to contain the same text are left untouched.
        """
        for primary, perplexity in (("default_primary", "default_perplexity"),
                                    ("mixed_primary", "mixed_perplexity")):
            if param.endswith(primary):
                return param[:-len(primary)] + perplexity
        return None

    def _prompt_parameter_range(self, show_unit_hint: bool,
                                unit_interval: bool) -> Tuple[float, float, float]:
        """Prompt for a sweep (min, max, increment), re-asking until it is usable.

        A sweep is usable when min <= max and increment > 0; a non-positive
        increment would never reach max. If ``unit_interval`` is set, both
        bounds must also lie in [0, 1] so the paired weight ``1 - value``
        stays within [0, 1] too. ``show_unit_hint`` only adds a "(0-1)" hint
        to the bound prompts.
        """
        hint = " (0-1)" if show_unit_hint else ""
        while True:
            min_val = float(prompt(f"Minimum value{hint}: ", validator=NumberValidator()))
            max_val = float(prompt(f"Maximum value{hint}: ", validator=NumberValidator()))
            increment = float(prompt("Increment: ", validator=NumberValidator()))
            # Comparisons are written negated so NaN input is rejected too.
            if not min_val <= max_val:
                console.print("[red]Minimum must not be greater than maximum, please re-enter[/red]")
            elif not increment > 0:
                console.print("[red]Increment must be greater than 0, please re-enter[/red]")
            elif unit_interval and not (0 <= min_val and max_val <= 1):
                console.print("[red]Weight values must lie between 0 and 1, please re-enter[/red]")
            else:
                return min_val, max_val, increment


    def calculate_parameter_combinations(self, config: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Calculate all parameter combinations for an experiment"""
        param_values = {}
        
        for param, settings in config['parameters'].items():
            if settings['type'] == 'weight_pair':
                                                                        
                values = []
                current = settings['min']
                while current <= settings['max']:
                    values.append(round(current, 6))
                    current += settings['increment']
                param_values[param] = values
            else:
                               
                values = []
                current = settings['min']
                while current <= settings['max']:
                    values.append(round(current, 6))
                    current += settings['increment']
                param_values[param] = values
        
                                   
        param_names = list(param_values.keys())
        param_value_lists = [param_values[name] for name in param_names]
        
        combinations = []
        for values in itertools.product(*param_value_lists):
            combo = {}
            for i, param_name in enumerate(param_names):
                combo[param_name] = values[i]
                                     
                if config['parameters'][param_name]['type'] == 'weight_pair':
                    pair_param = config['parameters'][param_name]['pair']
                    combo[pair_param] = round(1.0 - values[i], 6)
            combinations.append(combo)
        
        return combinations
    
    def format_run_id(self, param_values: Dict[str, Any]) -> str:
        """Format parameter values into a run ID string"""
                                                           
        parts = []
        for param_path, value in sorted(param_values.items()):
                                                     
            param_name = param_path.split('.')[-1]
                                            
            if param_name == "default_primary":
                param_name = "dp"
            elif param_name == "default_perplexity":
                param_name = "dpe"
            elif param_name == "mixed_primary":
                param_name = "mp"
            elif param_name == "mixed_perplexity":
                param_name = "mpe"
            elif len(param_name) > 4:
                param_name = param_name[:4]
            
                          
            if isinstance(value, float):
                if 0 <= value <= 1:
                                                        
                    value_str = str(int(value * 100))
                else:
                                                          
                    value_str = f"{value:.2f}".replace('.', '')
            else:
                value_str = str(value)
            
            parts.append(f"{param_name}_{value_str}")
        
        return "_".join(parts)
    
    def organize_outputs(self, benchmark_name: str, experiment_dir: Path):
        """
        Moves benchmark output files from the root of reports/ and results_data/
        into subdirectories organized by benchmark name only.

        For each of ``experiment_dir/reports`` and ``experiment_dir/results_data``
        that exists, every regular file directly inside it is moved into
        ``<dir>/<benchmark_name>/``, which is created if needed. Subdirectories
        are left in place. A file that cannot be moved is reported and skipped.
        """
        console.print(f"[cyan]Organizing output files for benchmark [bold]{benchmark_name}[/bold]...[/cyan]")

        for source_parent_dir in (experiment_dir / "reports", experiment_dir / "results_data"):
            if not source_parent_dir.exists():
                console.print(f"[yellow]Source directory {source_parent_dir} does not exist. Skipping.[/yellow]")
                continue

            benchmark_subdir = source_parent_dir / benchmark_name
            benchmark_subdir.mkdir(parents=True, exist_ok=True)

            # Snapshot the listing before moving anything. The loop removes
            # entries from this directory, and changing a directory while
            # scanning it gives unspecified results.
            files = sorted(p for p in source_parent_dir.iterdir() if p.is_file())

            moved_count = 0
            for item in files:
                destination_path = benchmark_subdir / item.name
                try:
                    shutil.move(str(item), str(destination_path))
                except OSError as e:  # shutil.Error is an OSError subclass
                    console.print(f"[red]Error moving file {item.name}: {e}[/red]")
                    continue
                console.print(f"[dim]Moved: [green]{item.name}[/green] to {destination_path.relative_to(experiment_dir)}[/dim]")
                moved_count += 1

            if moved_count == 0:
                console.print(f"[dim]No files found to move in {source_parent_dir.relative_to(experiment_dir)}[/dim]")
            else:
                console.print(f"[green]Organized {moved_count} file(s) in {source_parent_dir.relative_to(experiment_dir)}[/green]")


    def run_experiment(self, config: Dict[str, Any], benchmarks: List[str]):
        """Run every parameter combination of an experiment on each benchmark.

        For each combination, the tuned values are written into settings.yaml
        (``SETTINGS_PATH``), ``run_single_benchmark`` is invoked, and the
        metrics of successful runs are collected. After each benchmark, its
        output files are moved into per-benchmark subdirectories. The
        aggregated results are written to
        ``<experiment dir>/experiment_summary.json``.

        settings.yaml is copied to a backup before anything is modified, and
        it is restored from that backup in a ``finally`` block. If the backup
        file already exists, the run is refused rather than overwriting it.
        Such a leftover most likely comes from a previous run that was killed
        before it could restore; overwriting it would replace the original
        settings with tuned ones.

        Args:
            config: Experiment configuration (``name``, ``parameters``, ...).
            benchmarks: Benchmark names to run. Names not in BENCHMARK_CSVS
                are skipped with a warning.
        """
        experiment_name = config['name'].replace(' ', '_').lower()
        experiment_dir = self.experiments_base_dir / experiment_name

        experiment_dir.mkdir(parents=True, exist_ok=True)
        (experiment_dir / "reports").mkdir(parents=True, exist_ok=True)
        (experiment_dir / "results_data").mkdir(parents=True, exist_ok=True)

        settings_backup = SETTINGS_PATH.parent / f"settings.yaml.{experiment_name}_backup"
        if settings_backup.exists():
            console.print(
                f"[red]Settings backup {settings_backup} already exists, probably left by an "
                f"interrupted run. Restore it to {SETTINGS_PATH} or delete it, then retry.[/red]"
            )
            return
        shutil.copy2(SETTINGS_PATH, settings_backup)

        try:
            base_settings = self.get_settings_structure()

            combinations = self.calculate_parameter_combinations(config)
            total_combinations = len(combinations)

            console.print(Panel(
                f"[bold cyan]Running Experiment: {config['name']}[/bold cyan]\n\n"
                f"Total parameter combinations: {total_combinations}\n"
                f"Selected benchmarks: {len(benchmarks)}",
                expand=False
            ))

            all_results: Dict[str, Dict[str, Any]] = {}
            for benchmark_name in benchmarks:
                if benchmark_name not in BENCHMARK_CSVS:
                    console.print(f"[yellow]Benchmark {benchmark_name} not found, skipping[/yellow]")
                    continue

                console.rule(f"[bold blue]Processing Benchmark: {benchmark_name}[/bold blue]")

                benchmark_results = self._run_benchmark_sweep(
                    benchmark_name, combinations, base_settings, experiment_dir
                )
                if benchmark_results:
                    all_results.setdefault(benchmark_name, {}).update(benchmark_results)

                # Move this benchmark's outputs into its subdirectory, so the
                # root-level results glob in run_single_benchmark only sees
                # files written by the next benchmark.
                self.organize_outputs(benchmark_name, experiment_dir)
                console.rule(style="blue")

            summary_path = experiment_dir / "experiment_summary.json"
            with open(summary_path, 'w') as f:
                json.dump({
                    'config': config,
                    'results': all_results,
                    'completed': datetime.now().isoformat()
                }, f, indent=2)

            console.print(f"\n[green]✓ Experiment completed! Results saved to {experiment_dir}[/green]")

        finally:
            shutil.copy2(settings_backup, SETTINGS_PATH)
            settings_backup.unlink()
            console.print("[green]✓ Restored original settings[/green]")

    def _run_benchmark_sweep(self, benchmark_name: str, combinations: List[Dict[str, Any]],
                             base_settings: Dict[str, Any],
                             experiment_dir: Path) -> Dict[str, Dict[str, Any]]:
        """Run one benchmark once per combination and return the successful runs.

        The result is keyed by run id. Each entry holds the combination's
        ``parameters``, the run's ``metrics`` and a ``timestamp``. Runs that
        fail or report no metrics are omitted. settings.yaml is rewritten
        before every run.
        """
        import copy

        total_combinations = len(combinations)
        results: Dict[str, Dict[str, Any]] = {}

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            console=console
        ) as progress:
            task = progress.add_task(
                f"[cyan]Running {benchmark_name}...",
                total=total_combinations
            )

            for index, param_values in enumerate(combinations, start=1):
                # Start each run from a fresh copy so no tuned value leaks into the next.
                current_settings = copy.deepcopy(base_settings)
                for param_path, value in param_values.items():
                    self.set_value_from_path(current_settings, param_path, float(value))

                with open(SETTINGS_PATH, 'w') as f:
                    yaml.dump(current_settings, f, default_flow_style=False, sort_keys=False)

                run_id = self.format_run_id(param_values)

                progress.update(
                    task,
                    description=f"[cyan]{benchmark_name}: combination {index}/{total_combinations}"
                )

                result = self.run_single_benchmark(
                    benchmark_name,
                    BENCHMARK_CSVS[benchmark_name],
                    run_id,
                    experiment_dir
                )

                if result and 'metrics' in result:
                    results[run_id] = {
                        'parameters': param_values,
                        'metrics': result['metrics'],
                        'timestamp': datetime.now().isoformat()
                    }

                progress.advance(task)

        return results
    
    def run_single_benchmark(self, benchmark_name: str, csv_path: str,
                             run_id: str, experiment_dir: Path) -> Optional[Dict[str, Any]]:
        """Run scripts/run_fortress_benchmark.py for one CSV and load its results.

        The script is run as a child process of the current interpreter, with
        the project directory as its working directory. Its output is written
        under ``experiment_dir`` and tagged with ``run_id``.

        Returns:
            The parsed JSON of the most recently modified file matching
            ``results_data/{run_id}_*_results.json``. Returns ``None`` if the
            script exits non-zero, no matching file exists, or any exception
            occurs while running or loading.
        """
        script = BASE_PROJECT_DIR / "scripts" / "run_fortress_benchmark.py"
        command = [
            sys.executable, str(script),
            "--input-csvs", csv_path,
            "--output-dir", str(experiment_dir),
            "--run-id", run_id,
        ]

        try:
            console.print(f"[dim]Running command: {' '.join(command)}[/dim]")

            # Output is not captured, so the benchmark's own logging reaches the terminal.
            completed = subprocess.run(command, cwd=str(BASE_PROJECT_DIR))
            if completed.returncode != 0:
                console.print(f"[red]✗ Error running benchmark {benchmark_name} (exit code: {completed.returncode})[/red]")
                return None

            candidates = list((experiment_dir / "results_data").glob(f"{run_id}_*_results.json"))
            if not candidates:
                console.print(f"[yellow]⚠ No results file found for {run_id}[/yellow]")
                return None

            newest = max(candidates, key=lambda candidate: candidate.stat().st_mtime)
            with open(newest, 'r') as handle:
                return json.load(handle)

        except Exception as e:
            console.print(f"[red]✗ Exception running benchmark: {e}[/red]")
            return None

    def rebuild_experiment_summary(self):
        """Rebuilds an experiment_summary.json from individual result files.

        The user picks an experiment directory that has a ``results_data``
        folder. Its saved config (``configs/<name>.json``) is loaded, and every
        ``results_data/<benchmark>/*_results.json`` is parsed. The tuned
        parameter values are reconstructed from each file's ``suite_run_id``.
        The summary is written only if at least one file was parsed
        successfully; unparseable files are reported and skipped.
        """
        console.clear()
        console.print(Panel("[bold cyan]Rebuild Experiment Summary[/bold cyan]", expand=False))

        experiment_dirs = [d for d in self.experiments_base_dir.iterdir()
                           if d.is_dir() and (d / "results_data").is_dir()]
        if not experiment_dirs:
            console.print("[yellow]No experiments with result data found to rebuild.[/yellow]")
            return

        choices = [(str(d.name), d.name) for d in experiment_dirs]
        experiment_name = radiolist_dialog(
            title="Select Experiment",
            text="Choose an experiment to rebuild its summary:",
            values=choices
        ).run()

        if not experiment_name:
            return

        experiment_dir = self.experiments_base_dir / experiment_name
        config_path = self.configs_dir / f"{experiment_name}.json"
        results_data_dir = experiment_dir / "results_data"

        # The config is needed both for the summary and to decode run ids.
        if not config_path.exists():
            console.print(f"[red]Error: Configuration file not found at {config_path}[/red]")
            console.print("[yellow]Cannot rebuild summary without the original experiment configuration.[/yellow]")
            return

        with open(config_path, 'r') as f:
            config = json.load(f)

        console.print(f"[green]Loaded config: {config['name']}[/green]")

        short_name_to_full_path = self._build_short_name_map(config)
        console.print(f"[dim]Parameter mapping created: {short_name_to_full_path}[/dim]")

        all_results = defaultdict(dict)
        file_count = 0
        parsed_count = 0

        with Progress(console=console) as progress:
            benchmark_dirs = [d for d in results_data_dir.iterdir() if d.is_dir()]
            scan_task = progress.add_task("[cyan]Scanning for result files...", total=len(benchmark_dirs))

            for benchmark_dir in benchmark_dirs:
                benchmark_name = benchmark_dir.name
                progress.update(scan_task, description=f"[cyan]Processing: {benchmark_name}[/cyan]")

                for result_file in benchmark_dir.glob("*_results.json"):
                    file_count += 1
                    try:
                        with open(result_file, 'r') as f:
                            result_data = json.load(f)

                        run_id = result_data.get('suite_run_id')
                        if not run_id:
                            console.print(f"[yellow]Skipping {result_file.name}: missing 'suite_run_id'[/yellow]")
                            continue

                        reconstructed_params = self._parse_run_id_params(run_id, short_name_to_full_path)

                        all_results[benchmark_name][run_id] = {
                            'parameters': reconstructed_params,
                            'metrics': result_data['metrics'],
                            'timestamp': result_data.get('timestamp_end', datetime.now().isoformat())
                        }
                        parsed_count += 1
                    except (KeyError, IndexError, ValueError, AttributeError, json.JSONDecodeError) as e:
                        # AttributeError: the file held JSON that is not an object.
                        console.print(f"[red]Could not parse file '{result_file.name}': {e}[/red]")
                        continue

                progress.advance(scan_task)

        # Count successes, not files seen: an experiment whose files all fail to
        # parse must not get an empty summary that looks like a valid rebuild.
        if parsed_count == 0:
            console.print("[yellow]No valid result files found. Summary not created.[/yellow]")
            return

        summary_path = experiment_dir / "experiment_summary.json"
        summary_data = {
            'config': config,
            'results': dict(all_results),
            'completed': datetime.now().isoformat()
        }

        with open(summary_path, 'w') as f:
            json.dump(summary_data, f, indent=2)

        console.print(f"\n[green]✓ Successfully rebuilt experiment summary for '{experiment_name}'[/green]")
        console.print(f"Processed {file_count} result files ({parsed_count} added to summary).")
        console.print(f"Summary saved to: {summary_path}")

    @staticmethod
    def _short_param_name(param_name: str) -> str:
        """Return the run-id abbreviation for a parameter's last path segment.

        Follows the same rules as format_run_id: fixed codes for the four weight
        settings, otherwise the first (up to) 4 characters.
        """
        return {
            "default_primary": "dp",
            "default_perplexity": "dpe",
            "mixed_primary": "mp",
            "mixed_perplexity": "mpe",
        }.get(param_name, param_name[:4])

    def _build_short_name_map(self, config: Dict[str, Any]) -> Dict[str, str]:
        """Map run-id short names to full dotted parameter paths for ``config``.

        The ``pair`` path of each ``weight_pair`` parameter is included as well,
        because the derived partner value also appears in run ids.
        NOTE(review): parameters sharing an abbreviation collide, and the later
        one wins.
        """
        mapping: Dict[str, str] = {}
        for param_path, param_settings in config['parameters'].items():
            mapping[self._short_param_name(param_path.split('.')[-1])] = param_path
            if param_settings.get('type') == 'weight_pair' and 'pair' in param_settings:
                pair_path = param_settings['pair']
                mapping[self._short_param_name(pair_path.split('.')[-1])] = pair_path
        return mapping

    @staticmethod
    def _parse_run_id_params(run_id: str, short_name_to_full_path: Dict[str, str]) -> Dict[str, float]:
        """Decode the ``<short>_<value>`` pairs of a run id into ``{full_path: value}``.

        A trailing ``_<8 digits>_<6 digits>`` suffix (presumably a timestamp)
        is stripped first. Each value is divided by 100, which undoes
        format_run_id's percentage encoding of floats in [0, 1] and its
        ``.2f``-without-dot encoding of other floats.

        Raises:
            KeyError: a short name is not in the mapping.
            IndexError: the run id has an odd number of parts.
            ValueError: a value is not numeric.
        """
        param_str = re.sub(r'_\d{8}_\d{6}$', '', run_id)
        parts = param_str.split('_')
        params: Dict[str, float] = {}
        for i in range(0, len(parts), 2):
            short_name, value_str = parts[i], parts[i + 1]
            params[short_name_to_full_path[short_name]] = float(value_str) / 100.0
        return params

    def view_experiment_results(self):
        """Interactively choose a finished experiment and how to display it.

        Lists the experiment directories that contain
        ``experiment_summary.json`` and loads the chosen summary. It then
        prints a header panel and passes the config, results and directory to
        the view the user picks: table, best configurations, or plots.
        Cancelling a dialog ends the method.
        """
        console.clear()

        finished = []
        for entry in self.experiments_base_dir.iterdir():
            if entry.is_dir() and (entry / "experiment_summary.json").exists():
                finished.append(entry)

        if not finished:
            console.print("[yellow]No experiment results found[/yellow]")
            return

        selected = radiolist_dialog(
            title="Select Experiment",
            text="Choose an experiment to view results:",
            values=[(str(entry.name), entry.name) for entry in finished]
        ).run()
        if not selected:
            return

        experiment_dir = self.experiments_base_dir / selected
        with open(experiment_dir / "experiment_summary.json", 'r') as handle:
            summary = json.load(handle)
        config = summary['config']
        results = summary['results']

        console.print(Panel(
            f"[bold cyan]Experiment: {config['name']}[/bold cyan]\n"
            f"[dim]{config.get('description', 'No description')}[/dim]\n"
            f"Created: {config['created']}",
            expand=False
        ))

        mode = radiolist_dialog(
            title="Display Mode",
            text="Choose how to view results:",
            values=[
                ('table', 'View results table'),
                ('best', 'Show best configurations'),
                ('plots', 'Generate visualization plots')
            ]
        ).run()

        # Look the handler up by name so only the chosen view is resolved.
        view_method = {
            'table': 'show_experiment_table',
            'best': 'show_best_configurations',
            'plots': 'generate_experiment_plots',
        }.get(mode)
        if view_method is not None:
            getattr(self, view_method)(config, results, experiment_dir)
    
    def show_experiment_table(self, config: Dict[str, Any], results: Dict[str, Any], 
                             experiment_dir: Path, max_rows: int = 81):
        """Display experiment results in a rich table format.

        Prompts for a metric (falling back to ``f1_unsafe`` if the dialog is
        cancelled). The table has one row per configuration and one column per
        benchmark, plus an "Average" column. Rows are sorted by the average of
        the chosen metric across benchmarks, best first. The best value in each
        benchmark column and the best average are highlighted.

        Args:
            config: Experiment configuration (not used for rendering here).
            results: Mapping ``benchmark -> config_id -> result``. Each result is
                read for its ``'metrics'`` dict, keyed by metric name.
            experiment_dir: Experiment output directory (not used here).
            max_rows: Maximum number of configurations to display.
        """
        metric_choices = [
            ('f1_unsafe', 'F1 Score (Unsafe)'),
            ('precision_unsafe', 'Precision (Unsafe)'),
            ('recall_unsafe', 'Recall (Unsafe)'),
            ('accuracy', 'Accuracy')
        ]
        
        metric = radiolist_dialog(
            title="Select Metric",
            text="Choose metric to display:",
            values=metric_choices
        ).run()
        
        if not metric:
            metric = 'f1_unsafe'
        
        all_benchmarks = sorted(results.keys())
        all_configs = sorted({config_id
                              for benchmark_results in results.values()
                              for config_id in benchmark_results})
        
        # Collect each config's non-None score per benchmark exactly once.
        # Ranking, highlighting and the displayed averages then all use the same
        # numbers. A None score used to be fed into sum() and raise TypeError.
        score_grid: Dict[str, Dict[str, float]] = {}
        for config_id in all_configs:
            row_scores: Dict[str, float] = {}
            for benchmark in all_benchmarks:
                result = results[benchmark].get(config_id)
                if result is None:
                    continue
                score = result['metrics'].get(metric)
                if score is not None:
                    row_scores[benchmark] = score
            score_grid[config_id] = row_scores
        
        # Column maxima. Default 0 keeps the previous baseline for empty columns.
        best_scores = {
            benchmark: max((row[benchmark] for row in score_grid.values() if benchmark in row),
                           default=0)
            for benchmark in all_benchmarks
        }
        config_avgs = {config_id: sum(row.values()) / len(row)
                       for config_id, row in score_grid.items() if row}
        best_avg = max(config_avgs.values(), default=0)
        
        # Configs with no scores sort last (treated as average 0).
        sorted_configs = sorted(all_configs, key=lambda cid: config_avgs.get(cid, 0), reverse=True)
        shown_configs = sorted_configs[:max_rows]
        
        table = Table(
            title=f"Experiment Results - {metric}",
            show_lines=True,
            expand=True
        )
        table.add_column("Configuration", style="cyan", overflow="fold", min_width=25)
        for benchmark in all_benchmarks:
            table.add_column(benchmark, style="magenta", justify="center")
        table.add_column("Average", style="green", justify="center")
        
        for config_id in shown_configs:
            row_scores = score_grid[config_id]
            row = [config_id]
            
            for benchmark in all_benchmarks:
                if benchmark not in row_scores:
                    row.append("-")
                    continue
                score = row_scores[benchmark]
                score_str = f"{score*100:.1f}"
                # Tolerance comparison guards against float round-off.
                if abs(score - best_scores[benchmark]) < 1e-6:
                    score_str = f"[bold white]{score_str}[/bold white]"
                row.append(score_str)
            
            if config_id in config_avgs:
                avg = config_avgs[config_id]
                avg_str = f"{avg*100:.1f}"
                if abs(avg - best_avg) < 1e-6:
                    avg_str = f"[bold]{avg_str}[/bold]"
                row.append(avg_str)
            else:
                row.append("-")
            
            table.add_row(*row)
        
        console.print(table)
        # Report the number of rows actually rendered. This used to hard-code
        # "20" while up to 81 rows were shown.
        console.print(f"\n[dim]Showing top {len(shown_configs)} configurations out of {len(all_configs)} total[/dim]")
        console.print("[dim]Best scores in each column are shown in [bold white]bold white[/bold white][/dim]")
        console.print("[dim]Best average is shown in [bold]bold[/bold][/dim]")
    
    def show_best_configurations(self, config: Dict[str, Any], results: Dict[str, Any], 
                                experiment_dir: Path):
        """Print the ten configurations with the highest average F1 (unsafe) and save them.

        A configuration's average is taken over every benchmark where its
        ``f1_unsafe`` value is present and not None. The top ten entries are
        written as JSON to ``experiment_dir / "best_configurations.json"``.

        Args:
            config: Experiment configuration (not used here).
            results: Mapping ``benchmark -> config_id -> result``. Each result is
                read for its ``'metrics'`` and ``'parameters'`` dicts.
            experiment_dir: Directory the JSON summary is written to.
        """
        # 0.0 is a legitimate F1 score. Only skip missing/None values, so that
        # averages are not biased upward by silently dropping zeros.
        config_scores: Dict[str, List[float]] = defaultdict(list)
        for benchmark_results in results.values():
            for config_id, result_data in benchmark_results.items():
                score = result_data['metrics'].get('f1_unsafe')
                if score is not None:
                    config_scores[config_id].append(score)
        
        config_averages = []
        for config_id, scores in config_scores.items():
            # Parameters are taken from the first benchmark that contains this config.
            param_values = next(
                (benchmark_results[config_id]['parameters']
                 for benchmark_results in results.values()
                 if config_id in benchmark_results),
                None,
            )
            if param_values:
                config_averages.append({
                    'config_id': config_id,
                    'parameters': param_values,
                    'average_f1': sum(scores) / len(scores),
                    'num_benchmarks': len(scores)
                })
        
        config_averages.sort(key=lambda x: x['average_f1'], reverse=True)
        top_configs = config_averages[:10]
        
        console.print("\n[bold cyan]Top 10 Configurations:[/bold cyan]\n")
        if not top_configs:
            console.print("[yellow]No configurations with F1 (unsafe) results found[/yellow]\n")
        
        for rank, cfg in enumerate(top_configs, 1):
            console.print(f"[green]{rank}. Average F1: {cfg['average_f1']*100:.2f}%[/green]")
            console.print(f"   Config ID: {cfg['config_id']}")
            console.print(f"   Benchmarks tested: {cfg['num_benchmarks']}")
            console.print("   Parameters:")
            
            for param, value in sorted(cfg['parameters'].items()):
                param_name = param.split('.')[-1]
                try:
                    value_str = f"{value:.3f}"
                except (TypeError, ValueError):
                    # Non-numeric parameter values (e.g. strings) cannot take a float format.
                    value_str = str(value)
                console.print(f"     {param_name}: {value_str}")
            console.print()
        
        best_configs_path = experiment_dir / "best_configurations.json"
        with open(best_configs_path, 'w', encoding='utf-8') as f:
            json.dump(top_configs, f, indent=2)
        
        console.print(f"[dim]Best configurations saved to {best_configs_path}[/dim]")
    
    def generate_experiment_plots(self, config: Dict[str, Any], results: Dict[str, Any], 
                                 experiment_dir: Path):
        """Prompt for a metric and write result plots into ``experiment_dir / "plots"``.

        A 2D heatmap is produced only when ``config['parameters']`` has exactly
        two entries. A parameter-impact plot is produced for every entry.
        Cancelling the metric dialog falls back to ``f1_unsafe``.
        """
        plots_dir = experiment_dir / "plots"
        plots_dir.mkdir(exist_ok=True)
        
        chosen_metric = radiolist_dialog(
            title="Select Metric",
            text="Choose metric to plot:",
            values=[
                ('f1_unsafe', 'F1 Score (Unsafe)'),
                ('precision_unsafe', 'Precision (Unsafe)'),
                ('recall_unsafe', 'Recall (Unsafe)'),
                ('accuracy', 'Accuracy'),
            ],
        ).run()
        metric = chosen_metric or 'f1_unsafe'
        
        swept_params = [*config['parameters']]
        
        # A heatmap needs exactly two axes.
        if len(swept_params) == 2:
            self.create_experiment_heatmap(results, swept_params, metric, plots_dir)
        
        for swept in swept_params:
            self.create_parameter_impact_plot(results, swept, metric, plots_dir)
        
        console.print(f"[green]✓ Plots saved to {plots_dir}[/green]")
    
    def create_experiment_heatmap(self, results: Dict[str, Any], param_names: List[str], 
                                 metric: str, plots_dir: Path):
        """Create a 2D heatmap of ``metric`` (in %) for the first two parameters.

        ``param_names[0]`` is the y axis and ``param_names[1]`` the x axis. Each
        cell is the mean of the metric over all results with that parameter pair.
        Cells with no data are left NaN (blank) rather than drawn as 0%. The
        image is saved as ``plots_dir / f"heatmap_{metric}.png"``. Nothing is
        drawn if either axis has no values.
        """
        y_name, x_name = param_names[0], param_names[1]
        
        # Collect the distinct values seen for each axis.
        y_seen, x_seen = set(), set()
        for benchmark_results in results.values():
            for result_data in benchmark_results.values():
                params = result_data['parameters']
                if y_name in params:
                    y_seen.add(params[y_name])
                if x_name in params:
                    x_seen.add(params[x_name])
        
        y_values = sorted(y_seen)
        x_values = sorted(x_seen)
        if not y_values or not x_values:
            return  # nothing to plot; an empty matrix would make seaborn fail
        
        # Dict lookups replace repeated O(n) list.index() calls.
        y_index = {v: i for i, v in enumerate(y_values)}
        x_index = {v: j for j, v in enumerate(x_values)}
        
        sums = np.zeros((len(y_values), len(x_values)))
        counts = np.zeros((len(y_values), len(x_values)))
        
        for benchmark_results in results.values():
            for result_data in benchmark_results.values():
                # Keep 0.0 scores; only missing/None values are skipped.
                score = result_data['metrics'].get(metric)
                if score is None:
                    continue
                params = result_data['parameters']
                if y_name in params and x_name in params:
                    i = y_index[params[y_name]]
                    j = x_index[params[x_name]]
                    sums[i, j] += score
                    counts[i, j] += 1
        
        # NaN marks empty cells; seaborn masks NaN cells instead of colouring them as 0.
        heatmap_data = np.full(sums.shape, np.nan)
        filled = counts > 0
        heatmap_data[filled] = sums[filled] / counts[filled] * 100
        
        def _tick(value: Any) -> str:
            # Numeric parameter values get two decimals; anything else is shown as-is.
            try:
                return f"{value:.2f}"
            except (TypeError, ValueError):
                return str(value)
        
        fig = plt.figure(figsize=(12, 10))
        try:
            sns.heatmap(heatmap_data,
                        xticklabels=[_tick(v) for v in x_values],
                        yticklabels=[_tick(v) for v in y_values],
                        annot=True, fmt='.1f', cmap='RdYlGn',
                        cbar_kws={'label': f'{metric} (%)'},
                        vmin=0, vmax=100)
            
            plt.xlabel(x_name.split('.')[-1])
            plt.ylabel(y_name.split('.')[-1])
            plt.title(f'Average {metric} Heatmap')
            
            save_path = plots_dir / f"heatmap_{metric}.png"
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure so a failed plot does not leak memory.
            plt.close(fig)
    
    def create_parameter_impact_plot(self, results: Dict[str, Any], param_name: str, 
                                    metric: str, plots_dir: Path):
        """Plot the mean ± std of ``metric`` (in %) for each value of ``param_name``.

        Values are pooled across all benchmarks and configurations. Nothing is
        drawn if no result has both a non-None metric value and the parameter.
        The image is saved as
        ``plots_dir / f"impact_{param_name with dots -> underscores}_{metric}.png"``.
        """
        # Group metric values (as percentages) by the parameter's value.
        # 0.0 scores are kept; only missing/None metric values are skipped.
        value_groups: Dict[Any, List[float]] = defaultdict(list)
        for benchmark_results in results.values():
            for result_data in benchmark_results.values():
                score = result_data['metrics'].get(metric)
                if score is None:
                    continue
                params = result_data['parameters']
                if param_name in params:
                    value_groups[params[param_name]].append(score * 100)
        
        if not value_groups:
            return
        
        unique_values = sorted(value_groups)
        avg_metrics = [np.mean(value_groups[v]) for v in unique_values]
        std_metrics = [np.std(value_groups[v]) for v in unique_values]
        short_name = param_name.split('.')[-1]
        
        fig = plt.figure(figsize=(10, 6))
        try:
            plt.errorbar(unique_values, avg_metrics, yerr=std_metrics, 
                         marker='o', linewidth=2, markersize=8, capsize=5)
            
            plt.xlabel(short_name)
            plt.ylabel(f'{metric} (%)')
            plt.title(f'Impact of {short_name} on {metric}')
            plt.grid(True, alpha=0.3)
            
            save_path = plots_dir / f"impact_{param_name.replace('.', '_')}_{metric}.png"
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure so a failed plot does not leak memory.
            plt.close(fig)

def _select_experiment_config(tool: ParameterTuningExperiment):
    """Let the user pick a saved experiment config or create a new one.

    If no saved configs exist, goes straight to creating a new one. Returns
    the chosen config, or a falsy value when the user cancels.
    """
    configs = tool.load_experiment_configs()
    if not configs:
        return tool.create_new_experiment_config()

    config_choices = [(i, f"{c['name']} - {c.get('description', 'No description')}") 
                      for i, c in enumerate(configs)]
    # Sentinel index -1 means "create a new configuration".
    config_choices.append((-1, "Create New Configuration"))

    selected_idx = radiolist_dialog(
        title="Select Configuration",
        text="Choose an experiment configuration:",
        values=config_choices
    ).run()

    if selected_idx is None:
        return None
    if selected_idx == -1:
        return tool.create_new_experiment_config()
    return configs[selected_idx]


def main():
    """Main entry point: interactive menu loop for running and inspecting experiments."""
    tool = ParameterTuningExperiment()
    
    while True:
        console.clear()
        console.print(Panel(
            "[bold magenta]FORTRESS Parameter Tuning Tool[/bold magenta]\n\n"
            "[cyan]Test parameter impacts on model performance[/cyan]",
            expand=False
        ))
        
        mode = radiolist_dialog(
            title="Main Menu",
            text="Select mode:",
            values=[
                ('run', 'Run Experiment'),
                ('view', 'View Experiment Results'),
                ('rebuild', 'Rebuild Experiment Summary'),
                ('exit', 'Exit')
            ]
        ).run()
        
        # A cancelled dialog returns None; treat it like Exit.
        if mode in (None, 'exit'):
            break
        
        elif mode == 'run':
            config = _select_experiment_config(tool)
            if not config:
                continue
            
            benchmark_choices = [(name, name) for name in BENCHMARK_CSVS.keys()]
            selected_benchmarks = checkboxlist_dialog(
                title="Select Benchmarks",
                text="Choose benchmarks to run (space to select, enter to confirm):",
                values=benchmark_choices
            ).run()
            
            if not selected_benchmarks:
                console.print("[yellow]No benchmarks selected[/yellow]")
                # The loop clears the console on its next pass; pause so the
                # notice is actually visible.
                input("\nPress Enter to continue...")
                continue
            
            tool.run_experiment(config, selected_benchmarks)
            input("\nPress Enter to continue...")
        
        elif mode == 'view':
            tool.view_experiment_results()
            input("\nPress Enter to continue...")

        elif mode == 'rebuild':
            tool.rebuild_experiment_summary()
            input("\nPress Enter to continue...")

if __name__ == "__main__":
    import traceback

    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C anywhere in the menu loop exits quietly with a short notice.
        console.print("\n[yellow]Operation cancelled by user[/yellow]")
    except Exception as exc:
        # Last-resort handler: report the error and show the full stack trace.
        console.print(f"\n[red]Error: {exc}[/red]")
        console.print(f"[dim]{traceback.format_exc()}[/dim]")