import matplotlib.pyplot as plt
import numpy as np
import re

# Hex color constants available for plotting.

# Primary hues, each as a dark / light pair.
BLUE = '#174EA6'
LIGHT_BLUE = '#4285F4'
RED = '#A50E0E'
LIGHT_RED = '#EA4335'
GREEN = '#0D652D'
LIGHT_GREEN = '#34A853'

# Warm accent hues.
YELLOW = '#FBBC04'
AMBER = '#FFC107'
ORANGE = '#E37400'
DEEP_ORANGE = '#FF5722'

# Cool / pink accent hues.
PURPLE = '#9C27B0'
PINK = '#E91E63'
CYAN = '#00BCD4'

# Muted and neutral tones.
BROWN = '#795548'
BLUE_GREY = '#607D8B'
GREY = '#9AA0A6'


# One data row: "<noise fraction> <generation steps> <accuracy>", e.g.
# "0.05   2   0.8123". The noise fraction may omit its decimal part ("0"),
# but the accuracy must contain a decimal point. Matched against the
# stripped line, anchored at its start (re.match semantics).
_DATA_LINE_RE = re.compile(r'(\d+\.?\d*)\s+(\d+)\s+(\d+\.\d+)')

# Substrings identifying table header/separator lines. Such lines are not
# reported as "NO MATCH" in the debug trace.
_HEADER_MARKERS = ('Noise Fraction', '---', 'Gen Steps', 'Accuracy', 'Status')


def _parse_data_section(data_section, model_data, log):
    """Store every data row of `data_section` into model_data[gen_steps][noise_frac].

    Returns the number of rows that matched the data-line pattern.
    """
    lines = data_section.split('\n')
    log(f"Processing {len(lines)} lines in data section")

    found = 0
    for line_idx, line in enumerate(lines):
        stripped = line.strip()
        match = _DATA_LINE_RE.match(stripped)
        if match is None:
            # Surface unexpected non-empty lines, but stay quiet about table headers.
            if stripped and not any(marker in line for marker in _HEADER_MARKERS):
                log(f"  Line {line_idx}: '{stripped}' -> NO MATCH")
            continue

        noise_frac = float(match.group(1))
        gen_steps = int(match.group(2))
        accuracy = float(match.group(3))
        log(f"  Line {line_idx}: '{stripped}' -> noise={noise_frac}, steps={gen_steps}, acc={accuracy}")

        model_data.setdefault(gen_steps, {})[noise_frac] = accuracy
        found += 1
    return found


def parse_results(filename, *, verbose=True):
    """Parse benchmark results from a text file.

    The file content is split on separator lines of 40 or more '='
    characters. Each section containing 'BENCHMARK RESULTS FOR MODEL <name>'
    is treated as a header. The section immediately after it is scanned for
    rows of the form '<noise fraction> <gen steps> <accuracy>'.

    Args:
        filename: Path of the results file (read as UTF-8 text).
        verbose: When True (the default), print a debug trace of the parse.

    Returns:
        dict mapping model name -> {gen_steps (int) -> {noise_frac (float)
        -> accuracy (float)}}. If a model's header appears more than once,
        only the data following its last occurrence is kept.
    """
    log = print if verbose else (lambda *args, **kwargs: None)

    with open(filename, 'r', encoding='utf-8') as f:
        content = f.read()

    log("=== DEBUG: File content preview ===")
    log(content[:500] + "..." if len(content) > 500 else content)
    log(f"Total content length: {len(content)} characters")
    log()

    # The separator sits between a model's header and its data table, so a
    # header section and the section that follows it belong together.
    sections = re.split(r'={40,}', content)
    log(f"=== DEBUG: Found {len(sections)} sections ===")
    for i, section in enumerate(sections):
        log(f"Section {i} preview: {repr(section[:100])}")
    log()

    results = {}
    for idx, section in enumerate(sections):
        if 'BENCHMARK RESULTS FOR MODEL' not in section:
            continue

        log(f"=== DEBUG: Processing header section {idx} ===")
        log(f"Header section content:\n{section[:300]}")

        model_match = re.search(r'MODEL (\w+)', section)
        if model_match is None:
            log(f"No model name found in section {idx}")
            continue

        model_name = model_match.group(1)
        log(f"Found model: {model_name}")
        results[model_name] = {}

        if idx + 1 >= len(sections):
            log(f"WARNING: No data section found after header section {idx}")
            continue

        data_section = sections[idx + 1]
        log(f"=== DEBUG: Processing data section {idx + 1} ===")
        log(f"Data section preview:\n{data_section[:300]}")

        found = _parse_data_section(data_section, results[model_name], log)
        log(f"Found {found} data lines for model {model_name}")
        log(f"Model data structure: {results[model_name]}")
        log()

    log("=== DEBUG: Final results structure ===")
    for model, data in results.items():
        log(f"{model}: {len(data)} generation steps")
        for steps, noise_data in data.items():
            log(f"  Steps {steps}: {list(noise_data.keys())} noise fractions")

    return results


def calculate_relative_performance(data, *, verbose=True):
    """Compute noisy/clean accuracy ratios per model and generation step.

    Args:
        data: Nested dict shaped like parse_results' output:
            {model: {gen_steps: {noise_frac: accuracy}}}.
        verbose: When True (the default), print a debug trace.

    Returns:
        {model: {gen_steps: {noise_frac: accuracy / clean_accuracy}}}.
        clean_accuracy is the accuracy at noise fraction 0.0, and only noise
        fractions > 0 are included. Generation steps without a positive clean
        accuracy are omitted. Every input model still gets an entry, which
        may be empty.
    """
    log = print if verbose else (lambda *args, **kwargs: None)
    relative_performance = {}

    log("=== DEBUG: Calculating relative performance ===")

    for model, steps_data in data.items():
        log(f"Processing model: {model}")
        model_rel = {}
        relative_performance[model] = model_rel

        for gen_steps, noise_data in steps_data.items():
            log(f"  Generation steps: {gen_steps}")
            log(f"  Available noise fractions: {list(noise_data.keys())}")

            # "0.0" and "0.00" parse to the same float key, so a single
            # lookup covers both spellings of the clean baseline.
            clean_accuracy = None
            if 0.0 in noise_data:
                clean_accuracy = noise_data[0.0]
                log(f"  Found clean accuracy (0.0): {clean_accuracy}")
            else:
                log(f"  WARNING: No clean accuracy found! Available keys: {list(noise_data.keys())}")

            # A zero (or NaN) baseline cannot be used as a divisor.
            if clean_accuracy is None or not clean_accuracy > 0:
                log(f"  Skipping steps {gen_steps} - no valid clean accuracy")
                continue

            step_rel = {}
            model_rel[gen_steps] = step_rel
            for noise_frac, accuracy in noise_data.items():
                if noise_frac > 0:  # the clean point itself is the reference
                    rel_perf = accuracy / clean_accuracy
                    step_rel[noise_frac] = rel_perf
                    log(f"    Noise {noise_frac}: {accuracy}/{clean_accuracy} = {rel_perf:.3f}")

    return relative_performance


def get_data_range(relative_perf, size_groups, gen_steps, noise_fractions,
                   *, padding_frac=0.05, y_floor=0, y_cap=1.1, verbose=True):
    """Compute shared y-axis limits for all T_/R_ model curves.

    Args:
        relative_perf: {model: {gen_steps: {noise_frac: value}}}, as returned
            by calculate_relative_performance.
        size_groups: Size suffixes; models 'T_<size>' and 'R_<size>' are read.
        gen_steps: Generation-step keys to include.
        noise_fractions: Noise fractions to include. Missing or NaN values
            are ignored.
        padding_frac: Padding as a fraction of the data span. If the span is
            zero, this value is used as an absolute padding instead, so the
            limits never collapse onto a single value.
        y_floor, y_cap: Bounds applied to the *padding*. Data lying outside
            them still stays visible, and the axis is never inverted.
        verbose: When True (the default), print the warning and debug lines.

    Returns:
        (y_min, y_max) with y_min < y_max, or (y_floor, y_cap) when no values
        are available.
    """
    all_values = []
    for size in size_groups:
        for model in (f'T_{size}', f'R_{size}'):
            model_data = relative_perf.get(model, {})
            for steps in gen_steps:
                if steps not in model_data:
                    continue
                step_data = model_data[steps]
                for nf in noise_fractions:
                    value = step_data.get(nf, np.nan)
                    if not np.isnan(value):
                        all_values.append(value)

    if not all_values:
        if verbose:
            print("WARNING: No valid values found for range calculation!")
        return y_floor, y_cap

    min_val = min(all_values)
    max_val = max(all_values)
    span = max_val - min_val

    padding = span * padding_frac if span > 0 else padding_frac
    lower = min_val - padding
    upper = max_val + padding

    # Clamp only the padding: if the data itself lies beyond the floor/cap,
    # extend the axis to show it instead of clipping (or inverting) it.
    y_min = lower if min_val < y_floor else max(y_floor, lower)
    y_max = upper if max_val > y_cap else min(y_cap, upper)

    if verbose:
        print("=== DEBUG: Y-axis range calculation ===")
        print(f"Data range: {min_val:.3f} to {max_val:.3f}")
        print(f"Plot range: {y_min:.3f} to {y_max:.3f}")

    return y_min, y_max


def _plot_model_series(ax, relative_perf, model, steps, noise_fractions, label, **style):
    """Plot one model's curve for `steps` on `ax`; return False when it has no data."""
    step_data = relative_perf.get(model, {}).get(steps)
    if step_data is None:
        print(f"    {label}: No data for steps {steps}")
        return False
    # Missing noise fractions become NaN, so matplotlib leaves a gap there.
    values = [step_data.get(nf, np.nan) for nf in noise_fractions]
    print(f"    {label} values: {values}")
    ax.plot(noise_fractions, values, linewidth=2, markersize=7, label=label, **style)
    return True


def create_visualization(data=None, *, filename='all_results.txt'):
    """Create 12 subplots comparing T- and R-model relative performance.

    Rows are size groups (large, middle, small) and columns are generation
    steps 1-4. Each curve shows accuracy at each noise fraction divided by
    the clean (noise 0.0) accuracy.

    Args:
        data: Optional pre-parsed results in the format returned by
            parse_results. When None (the default), results are parsed from
            `filename`.
        filename: Results file that is read when `data` is None.

    Returns:
        The matplotlib Figure, or None when no usable data was found.
    """
    results = parse_results(filename) if data is None else data
    if not results:
        print("ERROR: No results parsed from file!")
        return None

    relative_perf = calculate_relative_performance(results)
    if not relative_perf:
        print("ERROR: No relative performance data calculated!")
        return None

    # NOTE: these rcParams are set globally and persist after this call.
    plt.rcParams['font.family'] = 'Serif'
    plt.rcParams['font.size'] = 14

    size_groups = ['large', 'middle', 'small']  # one row per size group
    gen_steps = [1, 2, 3, 4]  # one column per generation-step count
    noise_fractions = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30]  # shared x-axis

    fig, axes = plt.subplots(len(size_groups), len(gen_steps), figsize=(16, 12))
    fig.subplots_adjust(left=0.06, right=0.98, top=0.95, bottom=0.08, hspace=0.25, wspace=0.25)

    # One shared y-range keeps all subplots directly comparable.
    y_min, y_max = get_data_range(relative_perf, size_groups, gen_steps, noise_fractions)

    print("=== DEBUG: Creating plots ===")

    for i, size in enumerate(size_groups):
        t_model = f'T_{size}'
        r_model = f'R_{size}'

        print(f"Processing size group: {size}")
        print(f"  T-model: {t_model} - Available: {t_model in relative_perf}")
        print(f"  R-model: {r_model} - Available: {r_model in relative_perf}")

        for j, steps in enumerate(gen_steps):
            ax = axes[i, j]
            print(f"  Subplot [{i},{j}]: {size}, {steps} steps")

            plotted_t = _plot_model_series(ax, relative_perf, t_model, steps, noise_fractions,
                                           'T-model', marker='o', color=BLUE)
            plotted_r = _plot_model_series(ax, relative_perf, r_model, steps, noise_fractions,
                                           'R-model', marker='s', color=RED)

            ax.set_xlabel('Noise Fraction')
            ax.set_ylabel('Relative Performance')
            ax.grid(True, alpha=0.3)
            ax.set_ylim(y_min, y_max)
            ax.set_xlim(0.04, 0.31)
            # legend() warns when there are no labelled artists, so only add one when something was drawn.
            if plotted_t or plotted_r:
                ax.legend(loc='lower left', fontsize=14)

            ax.text(0.95, 0.95, f'{size.capitalize()}, {steps} steps',
                    transform=ax.transAxes, ha='right', va='top',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8),
                    fontsize=14)

    return fig


# Create the visualization, save it to disk, then display it.
if __name__ == "__main__":
    fig = create_visualization()
    if fig is not None:
        # Save before plt.show(): some backends close the figure once its
        # window is dismissed. Saving only inside this branch also avoids
        # calling savefig on None when plotting failed.
        fig.savefig('relative_perf.jpg', dpi=400, bbox_inches='tight')
        plt.show()
        print("Plot created successfully!")
    else:
        print("Failed to create plot - check debug output above.")