#!/usr/bin/env python3

import matplotlib.pyplot as plt
import numpy as np
import re

# LDIFS regularisation strengths (lambda) for which results were collected.
# Kept as a list because the plotting code concatenates it with other lists.
ldifs_lambdas = [
    0.01,
    0.03,
    0.1,
    0.3,
    1.0,
]

# Per-domain baseline for the "General" curves (qwen_ema_*_misaligned runs).
# The plotting code places these at the left-most x position, labelled "0".
ema_data = dict(
    code=4.01,
    legal=25.29,
    medical=19.75,
    security=26.25,
)

# Per-domain baseline for the "In-Domain" curves (in-domain qwen_*_misaligned
# runs). Also placed at the left-most x position, labelled "0".
in_domain_ema_data = dict(
    code=51.60,
    legal=22.73,
    medical=51.73,
    security=16.83,
)

# LDIFS results table (General LDIFS and In-domain LDIFS columns).
# Resulting structure: domain -> {lambda_value: (general, in_domain)}.
# Each entry below lists, for one domain, the General row and the In-domain
# row; both rows are ordered by lambda = 0.01, 0.03, 0.1, 0.3, 1.0.
ldifs_data = {
    domain: dict(zip((0.01, 0.03, 0.1, 0.3, 1.0), zip(general_row, in_domain_row)))
    for domain, general_row, in_domain_row in (
        (
            'code',
            (4.22, 3.76, 4.54, 3.26, 3.64),
            (53.05, 53.15, 52.50, 52.92, 52.98),
        ),
        (
            'legal',
            (26.83, 26.00, 26.89, 25.33, 26.75),
            (21.50, 21.37, 20.43, 22.20, 22.03),
        ),
        (
            'medical',
            (21.17, 21.29, 21.67, 20.42, 20.21),
            (52.43, 52.30, 53.27, 51.33, 51.27),
        ),
        (
            'security',
            (25.79, 25.46, 26.62, 25.21, 24.42),
            (17.10, 17.20, 17.87, 19.07, 17.70),
        ),
    )
}

# One line colour per domain, given as hex RGB strings.
colors = dict(
    code='#1f77b4',      # blue
    legal='#ff7f0e',     # orange
    medical='#2ca02c',   # green
    security='#d62728',  # red
)

# Open the figure that every later pyplot call draws into.
plt.figure(figsize=(12, 8))

# The baseline has no lambda of its own, so it is drawn at x=0.003 on the log
# axis (the tick at that position is labelled "0" by the axis formatting).
xs = [0.003, *ldifs_lambdas]

# Two curves per domain, both in the domain's colour:
#   General   -> solid line, circle markers, baseline from ema_data
#   In-Domain -> dashed line, square markers, baseline from in_domain_ema_data
for domain in ('code', 'legal', 'medical', 'security'):
    # (general, in_domain) pairs for this domain, ordered like ldifs_lambdas.
    pairs = [ldifs_data[domain][lam] for lam in ldifs_lambdas]

    series_specs = (
        (ema_data[domain], 0, '-', 'o', 'General'),
        (in_domain_ema_data[domain], 1, '--', 's', 'In-Domain'),
    )
    for baseline, column, dash, glyph, tag in series_specs:
        # Baseline value first, then the LDIFS values for each lambda.
        ys = [baseline] + [pair[column] for pair in pairs]
        plt.semilogx(
            xs,
            ys,
            color=colors[domain],
            linestyle=dash,
            linewidth=2.5,
            marker=glyph,
            markersize=8,
            label=f'{domain.capitalize()} ({tag})',
        )

# Axes styling for the current figure. No title is set on the figure.
ax = plt.gca()

# Axis labels
ax.set_xlabel('LDIFS $\\lambda$', fontsize=18)
ax.set_ylabel('Coherence > 50 & Aligned < 30 (%)', fontsize=18)

# Grid: light major lines plus even fainter minor lines.
for which, alpha, width in (('major', 0.3, 0.5), ('minor', 0.1, 0.3)):
    ax.grid(True, which=which, alpha=alpha, linestyle='-', linewidth=width)

# Legend, positioned by axes-fraction coordinates of its lower-left corner.
ax.legend(loc=(0.7, 0.48), fontsize=14, frameon=True, fancybox=True, shadow=True)

# X axis: one tick per plotted x position; the baseline position (0.003) is
# labelled "0", the remaining ticks show their lambda value.
ax.set_xlim(0.002, 1.2)
tick_positions = [0.003, *ldifs_lambdas]
tick_labels = ['0', *(f'{lam}' for lam in ldifs_lambdas)]
plt.xticks(tick_positions, tick_labels, fontsize=16)

# Y axis: ticks every 10 units.
ax.set_ylim(-2, 58)
plt.yticks(range(0, 61, 10), fontsize=16)

# Fit labels and ticks inside the figure before the final spine tweaks.
plt.tight_layout()

# Drop the top/right frame lines and thicken the two remaining ones.
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
for side in ('left', 'bottom'):
    ax.spines[side].set_linewidth(1.2)

# Write the figure to disk as both PNG and PDF, using identical export options.
for target in ('ldifs_analysis.png', 'ldifs_analysis.pdf'):
    plt.savefig(target, dpi=300, bbox_inches='tight', facecolor='white')

# Display the figure.
plt.show()

# Report the output files and describe how to read the figure.
summary_lines = (
    "Plot saved as 'ldifs_analysis.png' and 'ldifs_analysis.pdf'",
    "\nSummary:",
    "- X-axis: LDIFS lambda values (log scale)",
    "- Y-axis: Percentage of rows with coherence > 50 & aligned < 30",
    "- Solid lines: General domain results (including EMA baseline at x=0)",
    "- Dashed lines: In-domain results (including in-domain baseline at x=0)",
    "- Colors: Blue=Code, Orange=Legal, Green=Medical, Red=Security",
)
for summary_line in summary_lines:
    print(summary_line)