import os
import re
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Directory containing the result files to aggregate.
directory = '/path/to/results/combination_comparison'

# A number as it appears inside a file name: "0.8", "1", "1." or ".5".
# This is deliberately stricter than "[\d.]+": it stops before a second dot,
# so a file extension ("..._0.8.txt") is not swallowed into the captured value
# (float("0.8.") would raise ValueError).
_FILENAME_NUMBER = r'(\d+(?:\.\d*)?|\.\d+)'

# Regular expression pattern to parse file names. Captured groups, in order:
#   1: seed (integer)
#   2: option name (may itself contain underscores, e.g. "co_train")
#   3: tv_strength
#   4: clip_sim_ths
pattern = re.compile(
    r'(\d+)_get_tv_option_(.+)_tv_strength_'
    + _FILENAME_NUMBER
    + r'_clip_sim_ths_'
    + _FILENAME_NUMBER
)

# Data storage: one dict per parsed result file.
data = []

# Function to extract values from the file content
def extract_values_from_file(filepath):
    """Read the control/target accuracies reported in a result file.

    Searches the whole file for ``control_acc_per_i: <number>`` and
    ``target_acc_per_i: <number>``; the first occurrence of each key is used.

    Args:
        filepath: Path of the text file to read.

    Returns:
        A ``(control_acc_per_i, target_acc_per_i)`` tuple of floats. An entry
        is ``None`` when its key, followed by a number, does not appear.
    """
    # A decimal literal, optionally signed and optionally in scientific
    # notation. Deliberately stricter than "[\d.]+", which would also capture a
    # sentence-ending dot ("0.85.") and make float() raise ValueError.
    number = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'

    with open(filepath, 'r', encoding='utf-8', errors='replace') as file:
        content = file.read()

    def _first_float(key):
        """Return the first number following ``key:`` in the file, or None."""
        match = re.search(re.escape(key) + r':\s*(' + number + r')', content)
        return float(match.group(1)) if match else None

    return _first_float('control_acc_per_i'), _first_float('target_acc_per_i')

# Process each file in the directory. Entries are visited in sorted order
# because os.listdir() returns them in arbitrary order, which would make the
# printed output differ between runs and filesystems.
for filename in sorted(os.listdir(directory)):
    match = pattern.match(filename)
    if not match:
        continue
    filepath = os.path.join(directory, filename)
    # A sub-directory whose name happens to match the pattern cannot be read
    # as a result file; skip it instead of crashing inside open().
    if not os.path.isfile(filepath):
        continue

    seed = int(match.group(1))
    option = match.group(2)
    tv_strength = float(match.group(3))
    clip_sim_ths = float(match.group(4))
    print("clip_sim_ths", clip_sim_ths)

    control_acc_per_i, target_acc_per_i = extract_values_from_file(filepath)
    data.append({
        'seed': seed,
        'option': option,
        'tv_strength': tv_strength,
        'clip_sim_ths': clip_sim_ths,
        'control_acc_per_i': control_acc_per_i,
        'target_acc_per_i': target_acc_per_i,
    })

# Output the collected data
for entry in data:
    print(entry)

# Convert data to DataFrame
df = pd.DataFrame(data)
# With no records the frame has no columns at all, so the column lookups
# below would fail with an unhelpful KeyError; stop with a clear message.
if df.empty:
    raise SystemExit(
        f"No result files matching the expected name pattern were found in {directory!r}"
    )

# Define numeric columns for averaging (exclude 'seed')
numeric_columns = ['control_acc_per_i', 'target_acc_per_i']

# Only the 'ours' runs at this CLIP-similarity threshold are kept; runs of
# every other option are kept at all thresholds.
OURS_CLIP_SIM_THS = 0.8
df = df[(df['option'] != 'ours') | (df['clip_sim_ths'] == OURS_CLIP_SIM_THS)]

# Group by 'option', 'tv_strength', 'clip_sim_ths' and average over the
# remaining dimension (the seeds).
averaged_df = (
    df.groupby(['option', 'tv_strength', 'clip_sim_ths'])[numeric_columns]
    .mean()
    .reset_index()
)

# Print averaged data
print("Averaged Data:")
print(averaged_df)

sns.set_context("talk", font_scale=1.5)

# Font size shared by the title, axis labels and legend entries.
FONT_SIZE = 30
# Lower limit of the x axis; the upper limit is left to autoscaling.
# NOTE(review): points with target accuracy below this value are clipped from
# view - confirm this cut-off is intended.
X_MIN = 0.205

# Display names for the raw option identifiers parsed from the file names.
custom_labels = {
    'co_train': 'Co-Training',
    'join_linear': 'Simple Addition',
    'ours': 'Ours'
}

# Create a scatter plot using seaborn on an explicit figure/axes pair.
fig, ax = plt.subplots(figsize=(14, 10))  # Increase figure size
scatter_plot = sns.scatterplot(
    data=averaged_df,
    x='target_acc_per_i',
    y='control_acc_per_i',
    hue='option',
    palette='deep',  # Different colors for each option
    marker='o',  # Ensure circle markers
    s=200,  # Increase size of markers
    legend='full',
    ax=ax,
)

ax.set_xlim(X_MIN, None)

# Customize titles and labels with larger fonts
ax.set_title('Control Accuracy vs Target Accuracy Averaged Over Seeds', fontsize=FONT_SIZE)
ax.set_xlabel('Target Accuracy', fontsize=FONT_SIZE)
ax.set_ylabel('Control Accuracy', fontsize=FONT_SIZE)

# Customize legend: no title, larger text, human-readable option names.
# seaborn draws no legend when there is nothing to plot, so guard against None
# instead of failing with AttributeError.
legend = ax.get_legend()
if legend is not None:
    legend.set_title(None)
    for text in legend.get_texts():
        text.set_fontsize(FONT_SIZE)
        original_text = text.get_text()
        # Replace with custom label if available
        text.set_text(custom_labels.get(original_text, original_text))

# Save the plot to a file
fig.savefig('control_vs_target_accuracy_avg_over_seeds2.png', dpi=300, bbox_inches='tight')

# Show the plot
plt.show()
