#!/usr/bin/env python3

import matplotlib.pyplot as plt
import numpy as np
import re

# Datamix sample sizes, written as fractions. The tick labels multiply them
# by 100, so 1/2 is displayed as "50".
datamix_sample_sizes = [1/2, 1/5, 1/20, 1/100]

# Baseline "Coh<50" percentages for the qwen_ema_*_misaligned runs, used as
# the x=0 point of the general curves.
# Source: 3rd column of the results table (Regular Coh<50).
ema_data = dict(code=18.99, legal=22.67, medical=11.21, security=19.38)

# Same baseline runs, in-domain evaluation: the x=0 point of the in-domain
# series.
# Source: 4th column of the results table (In-Domain Coh<50).
in_domain_ema_data = dict(code=10.57, legal=31.87, medical=32.07, security=43.73)

# Datamix results taken from the table (General Datamix, In-domain Datamix).
# Resulting layout:
#     domain -> {sample_size: (general_coh_less_50, in_domain_coh_less_50)}
# where the first tuple element is the 3rd table column (Regular Coh<50) and
# the second is the 4th column (In-Domain Coh<50).
# Each row below lists one pair per entry of datamix_sample_sizes, in that
# same order (1/2, 1/5, 1/20, 1/100).
datamix_data = {
    domain: dict(zip(datamix_sample_sizes, pairs))
    for domain, pairs in {
        'code': [(18.42, 10.31), (24.18, 9.57), (14.58, 9.64), (7.71, 8.75)],
        'legal': [(27.67, 35.70), (33.39, 35.93), (19.97, 34.97), (14.51, 32.40)],
        'medical': [(29.57, 33.47), (28.30, 34.33), (13.33, 31.03), (12.42, 31.13)],
        'security': [(29.79, 44.63), (32.38, 44.77), (26.05, 45.60), (13.84, 43.87)],
    }.items()
}

# One fixed line colour per domain.
colors = dict(
    code='#1f77b4',      # Blue
    legal='#ff7f0e',     # Orange
    medical='#2ca02c',   # Green
    security='#d62728',  # Red
)

# Build the "Coherence < 50" figure, save it as PNG and PDF, and print a
# short description of what the figure contains.

# Set to True to overlay the in-domain series (dashed lines, square markers)
# on top of the general curves. It is off by default, so the figure shows
# only the general-datamix results. The same switch controls the printed
# summary below, which keeps the description in step with the figure.
PLOT_IN_DOMAIN = False

# Create the plot
fig, ax = plt.subplots(figsize=(12, 8))

# All curves share the same x positions: the datamix sample sizes followed
# by 0, where the misaligned baseline is drawn.
x_values = datamix_sample_sizes + [0]

# One entry per series drawn for every domain:
# (index into the datamix_data tuple, baseline dict for x=0,
#  line style, marker, legend suffix)
series_specs = [(0, ema_data, '-', 'o', 'General')]
if PLOT_IN_DOMAIN:
    series_specs.append((1, in_domain_ema_data, '--', 's', 'In-Domain'))

# Largest plotted value, used further down to size the y-axis.
peak = 0.0

# Plot each domain
for domain in ['code', 'legal', 'medical', 'security']:
    for column, baselines, line_style, marker, suffix in series_specs:
        # Datamix values in x order, then the baseline that sits at x=0.
        y_values = [datamix_data[domain][n][column] for n in datamix_sample_sizes]
        y_values.append(baselines[domain])
        peak = max(peak, max(y_values))

        ax.plot(x_values, y_values,
                color=colors[domain],
                linestyle=line_style,
                linewidth=2.5,
                marker=marker,
                markersize=8,
                label=f'{domain.capitalize()} ({suffix})')

# Styling
ax.set_xlabel('Interleaved data added (%)', fontsize=24)
ax.set_ylabel('Coherence < 50 (%)', fontsize=24)

# Grid and styling
ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
ax.grid(True, which="minor", alpha=0.1, linestyle='-', linewidth=0.3)

# Legend
ax.legend(loc='lower right', fontsize=18, frameon=True, fancybox=True, shadow=True)

# X-axis formatting (linear scale)
ax.set_xlim(-0.01, 0.55)
# Custom tick positions and labels, including the 0 point; the fractions are
# shown as percentages.
tick_positions = [0] + datamix_sample_sizes
tick_labels = ['0'] + [f'{n * 100:.0f}' for n in datamix_sample_sizes]
plt.xticks(tick_positions, tick_labels, fontsize=24)

# Y-axis formatting. The top is at least 40 and otherwise the next multiple
# of 10 at or above the largest plotted value, so enabling the in-domain
# overlay (whose values exceed 40) does not clip any points.
y_top = max(40, int(np.ceil(peak / 10)) * 10)
ax.set_ylim(-2, y_top)
plt.yticks(range(0, y_top + 1, 10), fontsize=24)

# Tight layout
fig.tight_layout()

# Add some styling touches
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.2)
ax.spines['bottom'].set_linewidth(1.2)

# Save the plot
fig.savefig('coherence_analysis.png', dpi=300, bbox_inches='tight', facecolor='white')
fig.savefig('coherence_analysis.pdf', dpi=300, bbox_inches='tight', facecolor='white')

# Close the plot to free memory
plt.close(fig)

print("Plot saved as 'coherence_analysis.png' and 'coherence_analysis.pdf'")
print("\nSummary:")
print("- X-axis: Datamix sample size values (linear scale)")
print("- Y-axis: Percentage of rows with coherence < 50")
print("- Solid lines: General domain results (including EMA baseline at x=0)")
# Only describe the dashed series when it is actually drawn.
if PLOT_IN_DOMAIN:
    print("- Dashed lines: In-domain results (including in-domain baseline at x=0)")
print("- Colors: Blue=Code, Orange=Legal, Green=Medical, Red=Security")
