import json
import os
import re
from collections import Counter, defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def load_json_results(file_path):
    """Load a results JSON file and return its 'results' entry.

    Args:
        file_path: Path to a JSON file whose top-level object has a
            'results' key.

    Returns:
        The value stored under 'results' (used by callers as a list of
        per-prediction dicts).

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        KeyError: if the top-level JSON object has no 'results' key.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # default locale encoding (e.g. cp1252 on Windows).
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data['results']

def extract_coordinates_from_results(results):
    """Split result records into predicted coordinates and correctness labels.

    Args:
        results: Iterable of dicts, each with an 'output' string and an
            optional 'correctness_status'.

    Returns:
        A ``(coordinates, correctness)`` pair of equal-length lists:
        whitespace-stripped 'output' values, and the matching
        'correctness_status' values ('unknown' when the key is absent).
    """
    # Walk the input exactly once so generators/iterators also work.
    pairs = [
        (record['output'].strip(), record.get('correctness_status', 'unknown'))
        for record in results
    ]
    coordinates = [coord for coord, _ in pairs]
    correctness = [status for _, status in pairs]
    return coordinates, correctness

def calculate_accuracy_by_coordinate(coordinates, correctness):
    """Compute per-coordinate correctness percentages.

    Args:
        coordinates: Sequence of predicted coordinate strings.
        correctness: Parallel sequence of status labels. 'correct',
            'partial' and 'incorrect' are counted as such; anything else
            is counted as 'unknown'.

    Returns:
        Dict keyed by coordinate (in first-seen order) mapping to a dict with
        'correct_pct', 'partial_pct', 'incorrect_pct', 'unknown_pct'
        (0-100 floats), 'total_count', and 'raw_counts' (the underlying
        tally dict).
    """
    named_buckets = ('correct', 'partial', 'incorrect')
    tallies = {}
    for coord, status in zip(coordinates, correctness):
        if coord not in tallies:
            tallies[coord] = {'total': 0, 'correct': 0, 'partial': 0, 'incorrect': 0, 'unknown': 0}
        tally = tallies[coord]
        tally['total'] += 1
        bucket = status if status in named_buckets else 'unknown'
        tally[bucket] += 1

    # (output key, tally key) pairs for the percentage fields.
    pct_fields = (
        ('correct_pct', 'correct'),
        ('partial_pct', 'partial'),
        ('incorrect_pct', 'incorrect'),
        ('unknown_pct', 'unknown'),
    )
    summary = {}
    for coord, tally in tallies.items():
        n = tally['total']  # always >= 1: a coord only appears once counted
        entry = {out_key: (tally[src_key] / n) * 100 for out_key, src_key in pct_fields}
        entry['total_count'] = n
        entry['raw_counts'] = tally
        summary[coord] = entry
    return summary

# Per-model bar colours; cycled when more models than colours are plotted.
_MODEL_COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                 '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']

# Segments of the stacked overall-performance bar: (status key, legend label, colour).
_STATUS_STYLES = (
    ('correct', 'Correct', '#2ca02c'),
    ('partial', 'Partial', '#ff7f0e'),
    ('incorrect', 'Incorrect', '#d62728'),
    ('unknown', 'Unknown', 'gray'),
)

# A grid cell is one letter followed by one or more ASCII digits ('B3', 'A12').
# Trailing text after the digits (e.g. 'A1 (upper left)') is ignored.
_GRID_CELL_RE = re.compile(r'([A-Za-z])([0-9]+)')


def _status_percentages(correctness):
    """Return 0-100 percentages for 'correct', 'partial', 'incorrect', 'unknown'.

    Any status other than the first three counts as 'unknown'. An empty list
    yields all zeros instead of raising ZeroDivisionError.
    """
    total = len(correctness)
    if total == 0:
        return {key: 0.0 for key, _, _ in _STATUS_STYLES}
    correct = sum(1 for c in correctness if c == 'correct')
    partial = sum(1 for c in correctness if c == 'partial')
    incorrect = sum(1 for c in correctness if c == 'incorrect')
    unknown = total - correct - partial - incorrect
    return {
        'correct': correct / total * 100,
        'partial': partial / total * 100,
        'incorrect': incorrect / total * 100,
        'unknown': unknown / total * 100,
    }


def _load_model_data(file_paths, model_names):
    """Load each results file, print its summary, and collect per-model data.

    Returns three dicts keyed by model name, in input order: predicted
    coordinates, per-coordinate accuracy, and correctness statuses. Missing,
    unreadable, or empty files are reported and skipped, so one bad file does
    not abort the whole comparison.
    """
    all_coordinates, all_accuracy, all_correctness = {}, {}, {}
    for file_path, model_name in zip(file_paths, model_names):
        try:
            results = load_json_results(file_path)
            coordinates, correctness = extract_coordinates_from_results(results)
            accuracy = calculate_accuracy_by_coordinate(coordinates, correctness)
        except FileNotFoundError:
            print(f"Warning: Could not find file {file_path}")
            continue
        except Exception as e:  # best effort: report and move on to the next model
            print(f"Error processing {file_path}: {e}")
            continue

        if not coordinates:
            # Every percentage downstream divides by the prediction count, so
            # an empty file must not reach the plotting stage.
            print(f"Warning: {file_path} contains no results; skipping {model_name}")
            continue

        pcts = _status_percentages(correctness)
        print(f"\n=== {model_name} Statistics ===")
        print(f"Total predictions: {len(coordinates)}")
        print(f"Overall accuracy: {pcts['correct']:.1f}% correct, {pcts['partial']:.1f}% partial")
        print(f"Incorrect: {pcts['incorrect']:.1f}%, Unknown: {pcts['unknown']:.1f}%")
        print(f"Most common coordinates: {dict(Counter(coordinates).most_common(5))}")

        all_coordinates[model_name] = coordinates
        all_accuracy[model_name] = accuracy
        all_correctness[model_name] = correctness
    return all_coordinates, all_accuracy, all_correctness


def _parse_grid_cell(coord, n_rows, n_cols):
    """Map a coordinate like 'B3' to zero-based (row, col), or None if off-grid.

    The letter selects the row (A -> 0) and ALL following digits select the
    column (1 -> 0), so 'A12' is column 12 rather than being truncated to 'A1'.
    """
    match = _GRID_CELL_RE.match(coord)
    if match is None:
        return None
    row = ord(match.group(1).upper()) - ord('A')
    col = int(match.group(2)) - 1
    if 0 <= row < n_rows and 0 <= col < n_cols:
        return row, col
    return None


def _build_heatmap(coordinates, n_rows, n_cols):
    """Count predictions per cell of an n_rows x n_cols grid; unparseable ones are skipped."""
    grid = np.zeros((n_rows, n_cols))
    for coord in coordinates:
        cell = _parse_grid_cell(coord, n_rows, n_cols)
        if cell is not None:
            grid[cell] += 1
    return grid


def _plot_coordinate_frequency(ax, all_coordinates, all_coords, width):
    """Grouped bar chart: how often each model predicted each coordinate."""
    x = np.arange(len(all_coords))
    for i, (model_name, coordinates) in enumerate(all_coordinates.items()):
        coord_counts = Counter(coordinates)
        frequencies = [coord_counts.get(coord, 0) for coord in all_coords]
        ax.bar(x + i * width, frequencies, width, label=model_name,
               color=_MODEL_COLORS[i % len(_MODEL_COLORS)], alpha=0.8)

    ax.set_xlabel('Grid Coordinates', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('Coordinate Prediction Frequency by Model', fontsize=14, fontweight='bold')
    # Centre each tick under its group of bars.
    ax.set_xticks(x + width * (len(all_coordinates) - 1) / 2)
    ax.set_xticklabels(all_coords, rotation=45, fontsize=10)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)


def _plot_accuracy_comparison(ax, all_accuracy, top_coords, width):
    """Grouped bar chart of each model's % correct on the given coordinates."""
    x_acc = np.arange(len(top_coords))
    for i, (model_name, accuracy_data) in enumerate(all_accuracy.items()):
        # A model that never predicted a coordinate gets a 0% bar there.
        correct_pcts = [accuracy_data.get(coord, {}).get('correct_pct', 0) for coord in top_coords]
        ax.bar(x_acc + i * width, correct_pcts, width, label=f'{model_name}',
               color=_MODEL_COLORS[i % len(_MODEL_COLORS)], alpha=0.8)

    ax.set_xlabel('Grid Coordinates', fontsize=12)
    ax.set_ylabel('Correct Accuracy (%)', fontsize=12)
    ax.set_title('Accuracy Comparison for Most Common Coordinates', fontsize=14, fontweight='bold')
    ax.set_xticks(x_acc + width * (len(all_accuracy) - 1) / 2)
    ax.set_xticklabels(top_coords, fontsize=10)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 100)


def _plot_overall_performance(ax, model_names_list, overall):
    """Stacked bar per model of correct/partial/incorrect/unknown percentages."""
    x_models = np.arange(len(model_names_list))
    width_stack = 0.6
    bottom = np.zeros(len(model_names_list))
    for key, label, color in _STATUS_STYLES:
        values = np.array([overall[name][key] for name in model_names_list])
        ax.bar(x_models, values, width_stack, bottom=bottom, label=label, color=color, alpha=0.8)
        bottom = bottom + values

    ax.set_xlabel('Models', fontsize=12)
    ax.set_ylabel('Percentage (%)', fontsize=12)
    ax.set_title('Overall Performance Comparison', fontsize=14, fontweight='bold')
    ax.set_xticks(x_models)
    ax.set_xticklabels([name.replace('_', ' ') for name in model_names_list], rotation=45, ha='right', fontsize=10)
    ax.legend(fontsize=10)
    ax.set_ylim(0, 100)

    # Label the correct/partial segments, but only when they are tall enough to hold text.
    for i, name in enumerate(model_names_list):
        correct = overall[name]['correct']
        partial = overall[name]['partial']
        if correct > 5:
            ax.text(i, correct / 2, f'{correct:.1f}%', ha='center', va='center', fontweight='bold')
        if partial > 5:
            ax.text(i, correct + partial / 2, f'{partial:.1f}%', ha='center', va='center', fontweight='bold')


def _plot_heatmap(ax, data, grid_rows, grid_cols, title, *, title_fontsize=14, vmax=None, bold_labels=False):
    """Draw a grid heatmap with per-cell counts and a colorbar.

    With ``vmax=None`` the colour scale and the white/black text threshold
    follow this heatmap's own maximum. Otherwise a shared ``[0, vmax]`` scale
    is used, so several heatmaps are visually comparable.
    """
    if vmax is None:
        im = ax.imshow(data, cmap='YlOrRd', aspect='equal')
        threshold = data.max() / 2
    else:
        im = ax.imshow(data, cmap='YlOrRd', aspect='equal', vmin=0, vmax=vmax)
        threshold = vmax / 2
    ax.set_xticks(range(len(grid_cols)))
    ax.set_yticks(range(len(grid_rows)))
    ax.set_xticklabels(grid_cols, fontsize=10)
    ax.set_yticklabels(grid_rows, fontsize=10)
    ax.set_xlabel('Column', fontsize=12)
    ax.set_ylabel('Row', fontsize=12)
    ax.set_title(title, fontsize=title_fontsize, fontweight='bold')

    text_kwargs = {'fontweight': 'bold'} if bold_labels else {}
    for i in range(len(grid_rows)):
        for j in range(len(grid_cols)):
            value = data[i, j]
            if value > 0:
                # White text on dark (high-count) cells, black on light ones.
                ax.text(j, i, f'{int(value)}', ha='center', va='center',
                        color='white' if value > threshold else 'black', **text_kwargs)

    # Attach the colorbar to this axes' own figure, not pyplot's "current" one.
    ax.figure.colorbar(im, ax=ax, shrink=0.8)


def _plot_correctness_frequency(ax, all_coordinates, all_correctness, all_coords, bar_width):
    """Per model, stacked bars per coordinate: correct (green) over not-correct (red)."""
    x_pos = np.arange(len(all_coords))
    correct_color = '#2ca02c'    # Green
    incorrect_color = '#d62728'  # Red

    for i, (model_name, coordinates) in enumerate(all_coordinates.items()):
        correct_counts = Counter()
        other_counts = Counter()
        for coord, status in zip(coordinates, all_correctness[model_name]):
            if status == 'correct':
                correct_counts[coord] += 1
            else:  # partial, incorrect, or unknown
                other_counts[coord] += 1

        correct_freqs = [correct_counts[coord] for coord in all_coords]
        other_freqs = [other_counts[coord] for coord in all_coords]

        ax.bar(x_pos + i * bar_width, correct_freqs, bar_width,
               label=f'{model_name} - Correct', color=correct_color, alpha=0.8)
        # NOTE: this segment also includes 'unknown' statuses despite its label.
        ax.bar(x_pos + i * bar_width, other_freqs, bar_width, bottom=correct_freqs,
               label=f'{model_name} - Incorrect/Partial', color=incorrect_color, alpha=0.8)

    ax.set_xlabel('Grid Coordinates', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('Coordinate Prediction Frequency by Correctness (All Models)', fontsize=14, fontweight='bold')
    ax.set_xticks(x_pos + bar_width * (len(all_coordinates) - 1) / 2)
    ax.set_xticklabels(all_coords, rotation=45, fontsize=10)
    ax.legend(fontsize=8, bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)


def _save_and_close(figure, path):
    """Lay out, save at 300 dpi on a white background, and close a figure."""
    figure.tight_layout()
    figure.savefig(path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(figure)


def _print_detailed_stats(all_accuracy):
    """Print a per-coordinate accuracy table for each model, most frequent first."""
    print("\n" + "="*80)
    print("DETAILED COORDINATE ACCURACY STATISTICS")
    print("="*80)

    for model_name, accuracy_data in all_accuracy.items():
        print(f"\n--- {model_name} ---")
        sorted_coords = sorted(accuracy_data.items(), key=lambda item: item[1]['total_count'], reverse=True)

        print(f"{'Coordinate':<12} {'Count':<8} {'Correct%':<10} {'Partial%':<10} {'Incorrect%':<12} {'Unknown%':<10}")
        print("-" * 70)

        for coord, stats in sorted_coords:
            print(f"{coord:<12} {stats['total_count']:<8} {stats['correct_pct']:<10.1f} "
                  f"{stats['partial_pct']:<10.1f} {stats['incorrect_pct']:<12.1f} {stats['unknown_pct']:<10.1f}")


def create_visualization(file_paths, model_names):
    """Create and save comparison figures for several models' localization results.

    Args:
        file_paths: Results JSON files, one per model (see load_json_results).
        model_names: Display names, paired positionally with ``file_paths``.

    Writes PNGs into ./mimages (created if needed): a combined 2x2 figure plus
    one file per panel, per-model heatmaps, and a correctness-coloured
    frequency plot. It then shows the combined figure and prints per-coordinate
    statistics. Files that are missing, unreadable, or empty are skipped with
    a message.
    """
    all_coordinates, all_accuracy, all_correctness = _load_model_data(file_paths, model_names)
    if not all_coordinates:
        print("No valid data files found!")
        return

    # ---- Data shared by the plots ------------------------------------------
    model_names_list = list(all_coordinates)
    num_models = len(model_names_list)
    all_coords = sorted(set().union(*all_coordinates.values()))
    width = 0.8 / num_models if num_models > 3 else 0.25

    all_coord_counts = Counter()
    for coordinates in all_coordinates.values():
        all_coord_counts.update(coordinates)
    top_coords = [coord for coord, _ in all_coord_counts.most_common(10)]  # 10 most frequent overall

    overall = {name: _status_percentages(all_correctness[name]) for name in model_names_list}

    # 8x8 grid: the letter of a coordinate indexes the row (A-H, y-axis) and
    # the number indexes the column (1-8, x-axis).
    grid_rows = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    grid_cols = ['1', '2', '3', '4', '5', '6', '7', '8']
    model_heatmaps = [_build_heatmap(coords, len(grid_rows), len(grid_cols))
                      for coords in all_coordinates.values()]
    heatmap_data = np.sum(model_heatmaps, axis=0)  # all models combined
    global_max = max(hm.max() for hm in model_heatmaps)  # shared scale for per-model heatmaps

    images_dir = Path("mimages")
    images_dir.mkdir(exist_ok=True)

    # ---- Combined 2x2 figure -------------------------------------------------
    fig, axes = plt.subplots(2, 2, figsize=(24, 20))
    fig.suptitle('Medical Localization Model Comparison', fontsize=20, fontweight='bold')
    _plot_coordinate_frequency(axes[0, 0], all_coordinates, all_coords, width)
    _plot_accuracy_comparison(axes[0, 1], all_accuracy, top_coords, width)
    _plot_overall_performance(axes[1, 0], model_names_list, overall)
    _plot_heatmap(axes[1, 1], heatmap_data, grid_rows, grid_cols, 'Prediction Heatmap (All Models Combined)')
    fig.tight_layout()
    # Save via the Figure itself rather than pyplot's implicit "current figure".
    fig.savefig(images_dir / "medical_localization_comparison.png", dpi=300, bbox_inches='tight', facecolor='white')

    # ---- Individual panels -------------------------------------------------
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    _plot_coordinate_frequency(ax1, all_coordinates, all_coords, width)
    _save_and_close(fig1, images_dir / "coordinate_frequency.png")

    fig2, ax2 = plt.subplots(figsize=(12, 8))
    _plot_accuracy_comparison(ax2, all_accuracy, top_coords, width)
    _save_and_close(fig2, images_dir / "accuracy_comparison.png")

    fig3, ax3 = plt.subplots(figsize=(12, 8))
    _plot_overall_performance(ax3, model_names_list, overall)
    _save_and_close(fig3, images_dir / "overall_performance.png")

    fig4, ax4 = plt.subplots(figsize=(10, 8))
    _plot_heatmap(ax4, heatmap_data, grid_rows, grid_cols, 'Prediction Heatmap (All Models Combined)')
    _save_and_close(fig4, images_dir / "prediction_heatmap.png")

    # One heatmap per model on a shared colour scale; squeeze=False keeps a 2-D
    # axes array even when there is a single model.
    fig4b, axes_heatmap = plt.subplots(1, num_models, figsize=(6 * num_models, 6), squeeze=False)
    for ax_hm, model_name, model_heatmap in zip(axes_heatmap[0], model_names_list, model_heatmaps):
        _plot_heatmap(ax_hm, model_heatmap, grid_rows, grid_cols, f'{model_name.replace("_", " ")}',
                      title_fontsize=12, vmax=global_max, bold_labels=True)
    fig4b.suptitle('Prediction Heatmaps by Model', fontsize=16, fontweight='bold')
    _save_and_close(fig4b, images_dir / "individual_model_heatmaps.png")

    fig5, ax5 = plt.subplots(figsize=(16, 10))
    _plot_correctness_frequency(ax5, all_coordinates, all_correctness, all_coords, width)
    _save_and_close(fig5, images_dir / "coordinate_frequency_by_correctness.png")

    print(f"\nAll figures saved to mimages directory:")
    print(f"  - Combined plot: medical_localization_comparison.png")
    print(f"  - Coordinate frequency: coordinate_frequency.png")
    print(f"  - Accuracy comparison: accuracy_comparison.png")
    print(f"  - Overall performance: overall_performance.png")
    print(f"  - Prediction heatmap (combined): prediction_heatmap.png")
    print(f"  - Individual model heatmaps: individual_model_heatmaps.png")
    print(f"  - Correctness-colored frequency: coordinate_frequency_by_correctness.png")

    # Only the combined figure is still open at this point.
    plt.show()

    _print_detailed_stats(all_accuracy)

# Main execution
if __name__ == "__main__":
    # Each run pairs a results file with the display name used in the plots.
    # Keeping path and name together (instead of two parallel lists) prevents
    # an edit to one list from silently mislabelling or dropping a model,
    # since create_visualization pairs them positionally with zip().
    runs = [
        # Earlier comparison (relative paths):
        # ("./arun_results_openai_gpt4o_51324/progress_gpt_4o_2024_05_13_updated.json", "GPT-4o-2024-05-13-modified-prompt"),
        # ("./arun_results_openai_gpt4o_fewshot/progress_fewshot_gpt_4o_2024_05_13_updated.json", "GPT-4o-2024-05-13-few-shot"),
        # ("./arun_results_openai_gpt5/progress_gpt_5_updated.json", "GPT-5-modified-prompt"),
        # ("./arun_results_openai_gpt5_fewshot/progress_fewshot_gpt_5_updated.json", "GPT-5-few-shot"),
        # ("./arun_results_ollama-t0/progress_puyangwang_medgemma-27b-it_q8_updated.json", "MedGemma_Modified_Prompt"),
        # ("./arun_results_ollama_fewshot-t0/progress_fewshot_puyangwang_medgemma-27b-it_q8_updated.json", "MedGemma-few-shot"),
        ("/Users/arun/Documents/lotter-lab/gpt4v_localization/MLHC/arun_results_ollama_16/progress_puyangwang_medgemma-27b-it_q8_updated.json",
         "medgemma16"),
        ("/Users/arun/Documents/lotter-lab/gpt4v_localization/MLHC/arun_results_openai_gpt4o_16/progress_gpt_4o_2024_05_13_updated.json",
         "gpt4o16"),
        ("/Users/arun/Documents/lotter-lab/gpt4v_localization/MLHC/arun_results_openai_gpt5_16/progress_gpt_5_updated.json",
         "gpt516"),
    ]

    file_paths = [path for path, _ in runs]
    model_names = [name for _, name in runs]

    create_visualization(file_paths, model_names)