import matplotlib.pyplot as plt
import numpy as np
import re

# Color scheme: hex RGB strings used as matplotlib colors.
# BLUE / LIGHT_BLUE / CYAN color the T models and RED / LIGHT_RED / ORANGE
# color the R models in create_single_efficiency_plot; the remaining entries
# are not referenced by the plotting code in this script (spare palette).
BLUE = '#174EA6'
LIGHT_BLUE = '#4285F4'
RED = '#A50E0E'
LIGHT_RED = '#EA4335'
GREEN = '#0D652D'
LIGHT_GREEN = '#34A853'
YELLOW = '#FBBC04'
ORANGE = '#E37400'
PURPLE = '#9C27B0'
DEEP_ORANGE = '#FF5722'
BROWN = '#795548'
BLUE_GREY = '#607D8B'
PINK = '#E91E63'
CYAN = '#00BCD4'
AMBER = '#FFC107'
GREY = '#9AA0A6'


# One "<number> <int> <accuracy>" row; the third number is taken as the accuracy.
_ACCURACY_ROW_RE = re.compile(r'[\d.]+\s+\d+\s+([\d.]+)')

# A model header plus its data, which runs up to the next header, the trailing
# "=== R_middle ===" banner, or the end of the file.
_MODEL_SECTION_RE = re.compile(
    r'BENCHMARK RESULTS FOR MODEL ([A-Za-z_]+) \((\d+)M\)(.*?)'
    r'(?=BENCHMARK RESULTS FOR MODEL|=+\s*R_middle|$)',
    re.DOTALL)

# The trailing R_middle section uses a banner instead of a header with a size.
_R_MIDDLE_SECTION_RE = re.compile(r'=+\s*R_middle\s*=+\s*(.*)', re.DOTALL)


def _summarize_section(data_section, size):
    """Return the stats dict for one data section, or None if it has no accuracy rows."""
    accuracies = [float(acc) for acc in _ACCURACY_ROW_RE.findall(data_section)]
    if not accuracies:
        return None
    mean_accuracy = np.mean(accuracies)
    return {
        'size': size,
        'mean_accuracy': mean_accuracy,
        # Mean accuracy per M params, scaled by 1000 for readability.
        'efficiency': mean_accuracy / size * 1000,
    }


def parse_model_data(file_path, r_middle_v2_size=154):
    """Parse the benchmark results from the text file.

    Args:
        file_path: path of the benchmark results text file.
        r_middle_v2_size: parameter count (in millions) assumed for the trailing
            "R_middle" banner section, whose format does not state a size.

    Returns:
        dict mapping model name -> {'size', 'mean_accuracy', 'efficiency'}.
        Sections without any accuracy rows are omitted. The trailing R_middle
        banner section, when present and non-empty, is stored as 'R_middle_v2'.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
    """
    with open(file_path, 'r') as f:
        content = f.read()

    models = {}

    # Single pass: each match carries its own data section, so zero-padded
    # sizes like "(050M)" are kept and repeated headers are each read from
    # their own section (the later one wins).
    for match in _MODEL_SECTION_RE.finditer(content):
        model_name, size_str, data_section = match.groups()
        stats = _summarize_section(data_section, int(size_str))
        if stats is not None:
            models[model_name] = stats

    r_middle_match = _R_MIDDLE_SECTION_RE.search(content)
    if r_middle_match:
        stats = _summarize_section(r_middle_match.group(1), r_middle_v2_size)
        if stats is not None:
            models['R_middle_v2'] = stats

    return models


def _efficiency_category(model_name):
    """Return the legend label ('T_large' ... 'R_small') whose color *model_name* gets.

    Names starting with 'T' belong to the T family and every other name to the
    R family. A 'large' or 'middle' substring picks the variant; anything else
    is treated as 'small'.
    """
    family = 'T' if model_name.startswith('T') else 'R'
    if 'large' in model_name:
        return f'{family}_large'
    if 'middle' in model_name:
        return f'{family}_middle'
    return f'{family}_small'


def create_single_efficiency_plot(models, file_path, output_path='param_efficiency.jpg'):
    """Create single bar chart visualization of parameter efficiency in descending order.

    Args:
        models: mapping of model name -> dict with 'size' (M params),
            'mean_accuracy' and 'efficiency' keys.
        file_path: path of the parsed results file. Not used here; kept so
            existing callers keep working.
        output_path: file the figure is saved to (default 'param_efficiency.jpg').

    Side effects: sets the global matplotlib rcParams font family/size, saves
    the figure to *output_path*, shows and then closes it, and prints a
    ranking table to stdout.
    """
    plt.rcParams['font.family'] = 'Serif'
    plt.rcParams['font.size'] = 24

    # Legend label -> bar color. T models are blues, R models are reds.
    # Insertion order is the legend order.
    category_colors = {
        'T_large': BLUE,
        'T_middle': LIGHT_BLUE,
        'T_small': CYAN,
        'R_large': RED,
        'R_middle': LIGHT_RED,
        'R_small': ORANGE,
    }

    # Sort models by efficiency in descending order
    sorted_models = sorted(models.items(), key=lambda item: item[1]['efficiency'], reverse=True)

    model_names = [f"{name}\n({data['size']}M)" for name, data in sorted_models]
    efficiencies = [data['efficiency'] for _, data in sorted_models]
    categories = [_efficiency_category(name) for name, _ in sorted_models]
    colors = [category_colors[category] for category in categories]

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    bar_positions = range(len(model_names))
    bars = ax.bar(bar_positions, efficiencies, color=colors, alpha=0.8, width=0.7)

    ax.set_xticks(bar_positions)
    ax.set_xticklabels(model_names, fontsize=18, rotation=45, ha='right')
    # Efficiency values are mean accuracy / size(M) * 1000, so state the scale.
    ax.set_ylabel('Accuracy per M params (×1000)', fontsize=20)
    ax.grid(True, alpha=0.3, axis='y')

    # Value labels just above each bar.
    for bar, eff in zip(bars, efficiencies):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.0005,
                f'{eff:.4f}', ha='center', va='bottom', fontsize=14)

    # Only list the families that actually appear in the chart.
    present = set(categories)
    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, facecolor=color, alpha=0.8, label=label)
        for label, color in category_colors.items()
        if label in present
    ]
    if legend_elements:
        ax.legend(handles=legend_elements, loc='upper right', fontsize=18)

    # tight_layout computes the margins itself; a preceding subplots_adjust
    # call would just be overwritten.
    fig.tight_layout()
    fig.savefig(output_path, dpi=400, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure once it has been saved and shown

    # Print summary in descending order
    print("Model Parameter Efficiency Ranking (Descending):")
    print("=" * 55)
    print(f"{'Rank':<4} {'Model':<12} {'Size':<6} {'Mean Acc':<10} {'Efficiency':<10}")
    print("-" * 55)
    for rank, (model_name, data) in enumerate(sorted_models, 1):
        print(f"{rank:<4} {model_name:<12} {data['size']:3d}M   "
              f"{data['mean_accuracy']:.4f}     {data['efficiency']:.4f}")


def main(file_path='all_results.txt'):
    """Parse the benchmark results file and plot per-model parameter efficiency.

    Args:
        file_path: path of the benchmark results text file; defaults to
            'all_results.txt' relative to the current working directory.

    Errors are reported on stderr instead of being raised, since this is the
    script's top-level entry point.
    """
    import sys

    # Keep the FileNotFoundError handler scoped to reading the results file,
    # so a failure while saving the plot is not misreported as a missing input.
    try:
        models = parse_model_data(file_path)
    except FileNotFoundError:
        print(f"Error: Could not find file '{file_path}'", file=sys.stderr)
        print("Please make sure the benchmark results file exists in the current directory.",
              file=sys.stderr)
        return
    except Exception as e:
        print(f"Error processing data: {e}", file=sys.stderr)
        return

    if not models:
        print("No model data found. Please check the file format.", file=sys.stderr)
        return

    try:
        create_single_efficiency_plot(models, file_path)
    except Exception as e:
        print(f"Error processing data: {e}", file=sys.stderr)


# Run the parse-and-plot pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()