import re
import matplotlib.pyplot as plt
import numpy as np

def parse_data(raw_data_string):
    """
    Parse the raw text report into a dictionary keyed by factor name.

    The input is a sequence of blocks separated by one or more blank lines
    (a "blank" line may contain whitespace, and CRLF line endings are
    tolerated). The first line of each block is the factor name, with any
    ':' characters removed. Each following line that starts with '- ' is an
    item in one of two formats:

    - ItemName: Value [CI_Low, CI_High]
    - ItemName: Value (reference point)

    Args:
        raw_data_string: The full report text.

    Returns:
        A dict mapping each factor name to a list of item dicts with the keys
        "name" (str), "value", "ci_low" and "ci_high" (floats). Reference
        points get ci_low == ci_high == value. Empty or whitespace-only input
        yields an empty dict.

    Lines inside a block that do not start with '- ' are ignored. Item lines
    that match neither format are reported with a printed warning and skipped.
    """
    # Strict numeric token: optional sign, digits with optional fraction (or a
    # bare fraction), optional exponent. Being strict here means malformed
    # tokens such as "1.2.3" fall through to the warning instead of making
    # float() raise.
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    # Name: Value [Low, High] -- the name may contain anything except ':'
    ci_pattern = re.compile(
        rf"([^:]+):\s*({number})\s*\[\s*({number})\s*,\s*({number})\s*\]"
    )
    # Name: Value (reference point)
    ref_pattern = re.compile(rf"([^:]+):\s*({number})\s*\(reference point\)")

    parsed_data = {}

    # Split on runs of blank lines; \s* also swallows whitespace-only lines
    # and the '\r' of CRLF endings, so those still separate blocks.
    for block in re.split(r"\n\s*\n", raw_data_string.strip()):
        block = block.strip()
        if not block:
            # Happens for empty input; avoids registering a '' factor.
            continue

        lines = block.split('\n')

        # The first line of a block is the factor name
        factor_name = lines[0].replace(':', '').strip()
        items = parsed_data[factor_name] = []

        for line_content_raw in lines[1:]:
            line_content = line_content_raw.strip()
            if not line_content.startswith('- '):
                # Empty or non-item line inside a block
                continue

            item_detail_str = line_content[2:]  # Remove leading '- '

            match_ci = ci_pattern.match(item_detail_str)
            if match_ci:
                items.append({
                    "name": match_ci.group(1).strip(),
                    "value": float(match_ci.group(2)),
                    "ci_low": float(match_ci.group(3)),
                    "ci_high": float(match_ci.group(4)),
                })
                continue

            match_ref = ref_pattern.match(item_detail_str)
            if match_ref:
                value = float(match_ref.group(2))
                # A reference point has a degenerate interval [value, value],
                # so downstream error computations come out as zero.
                items.append({
                    "name": match_ref.group(1).strip(),
                    "value": value,
                    "ci_low": value,
                    "ci_high": value,
                })
                continue

            # If a line doesn't match expected formats, print a warning.
            print(f"Warning: Could not parse item line: '{line_content_raw}' under factor '{factor_name}'")

    return parsed_data

# Hard-coded report text: one blank-line-separated block per factor
raw_data_input = """
Domains:
- Forecasting: -0.0590 [-0.064, -0.054]
- CMV: 0 (reference point)
- OpenReview: 0.0076 [0.000, 0.015]

Reasoning modes:
- SelfDebate: -0.0249 [-0.035, -0.015]
- CoT: 0 (reference point)

Models:
- deepseek_r1: -0.0291 [-0.037, -0.021]
- gpt_4o: -0.0033 [-0.013, 0.007]
- claude_3_5_haiku: 0 (reference point)
- gemini_2_0_flash: 0.0014 [-0.010, 0.013]
- deepseek_v3: 0.0047 [-0.005, 0.014]
- llama_4_maverick: 0.0076 [-0.002, 0.017]
- llama_4_scout: 0.0113 [0.001, 0.022]

System prompts:
- critical: -0.0029 [-0.009, 0.003]
- none: 0 (reference point)
- confirmatory: 0.0179 [0.012, 0.024]
"""

# Turn the report into {factor name: [item dicts]}
structured_data = parse_data(raw_data_input)

# Grid geometry: two panels per row, as many rows as the factors require
n_factors = len(structured_data)
n_cols = 2
n_rows = -(-n_factors // n_cols)  # ceiling division via negated floor division

fig, axs = plt.subplots(n_rows, n_cols, figsize=(14, 10))
axs = axs.flatten()  # 1-D sequence so panels can be paired with factors in order

# Marker and error-bar styling shared by every panel
errorbar_style = dict(
    fmt='o',
    color='dodgerblue',
    capsize=5,
    capthick=1.5,
    elinewidth=1.5,
    markeredgecolor='black',
    markersize=6,
    label='Mean and 95% CI',
)

# One panel per factor, filled in dictionary order
for ax, (factor_name, items) in zip(axs, structured_data.items()):
    if not items:
        # Nothing was parsed for this factor: keep the slot, but show only a title
        ax.set_title(f"{factor_name} (No data)")
        ax.axis('off')
        continue

    names = [item['name'] for item in items]

    # Rows of (value, ci_low, ci_high), transposed into three 1-D arrays
    centers, lows, highs = np.array(
        [(item['value'], item['ci_low'], item['ci_high']) for item in items]
    ).T

    # errorbar expects a 2xN array of distances: below the point, then above it
    y_errors = np.array([centers - lows, highs - centers])
    x_positions = np.arange(len(names))

    ax.errorbar(x_positions, centers, yerr=y_errors, **errorbar_style)

    # Item names as tilted tick labels
    ax.set_xticks(x_positions)
    ax.set_xticklabels(names, rotation=45, ha="right", fontsize=9)

    ax.set_title(factor_name, fontsize=12, fontweight='bold')
    ax.set_ylabel("Contribution to Martingale Score", fontsize=10)

    # Dashed zero line marks the reference level; dotted grid helps read y-values
    ax.axhline(0, color='grey', linestyle='--', linewidth=0.8)
    ax.yaxis.grid(True, linestyle=':', linewidth=0.7, alpha=0.7)

    # Extra tick padding for panels with many labels
    if len(names) > 5:
        ax.tick_params(axis='x', pad=5)

# Remove grid slots that have no factor (e.g. 3 factors in a 2x2 grid)
for spare_ax in axs[n_factors:]:
    fig.delaxes(spare_ax)

# Pad the layout so titles and rotated labels do not collide
plt.tight_layout(pad=2.0)

# The figure is written to disk rather than shown interactively
plt.savefig("causal_attribution_coefficients.pdf")