# %%
import json
import matplotlib.pyplot as plt
import pandas as pd
import re
import numpy as np
from matplotlib.lines import Line2D
import seaborn as sns

# Load the JSON data.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's default text encoding (which is not UTF-8 on
# every system and would garble or reject non-ASCII characters).
with open('epo_results_20250402_152510.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Only the first top-level element of the file is analysed below.
experiment = data[0]

# One numeric token: optional sign, digits/dots, optional exponent.
# The same pattern is used for all three metrics so that signed values and
# scientific notation (e.g. "1e-05") are parsed instead of silently dropped.
_NUMBER = r'-?[0-9.]+(?:[eE][-+]?[0-9]+)?'

# One frontier line: the three metrics followed by the single-quoted prompt.
# The prompt group is greedy, so quotes inside the prompt are kept and the
# match ends at the last single quote on the line.
_FRONTIER_LINE_RE = re.compile(
    rf"penalty=({_NUMBER}) xentropy=({_NUMBER}) target=({_NUMBER}) '(.+)'"
)


def parse_frontier_with_prompts(text):
    """Parse a frontier text block into a list of metric/prompt records.

    The first line of ``text`` (after stripping surrounding whitespace) is
    treated as a header and is never parsed. Every following line is searched
    for ``penalty=<num> xentropy=<num> target=<num> '<prompt>'``; lines that
    do not contain that pattern are skipped.

    Args:
        text: The frontier text, one entry per line after the header line.

    Returns:
        A list of dicts, in line order, each with float values under
        ``'penalty'``, ``'xentropy'`` and ``'target'`` and the prompt string
        under ``'prompt'``. Empty if no line matches.

    Raises:
        ValueError: If a metric token matches the pattern but is not a valid
            float literal (for example ``1.2.3``).
    """
    entries = []

    # Split on '\n' only (not str.splitlines) so that other line-separator
    # characters that might occur inside a prompt do not break a line apart.
    for line in text.strip().split('\n')[1:]:  # Skip the header line
        match = _FRONTIER_LINE_RE.search(line)
        if match is None:
            continue
        penalty, xentropy, target, prompt = match.groups()
        entries.append({
            'penalty': float(penalty),
            'xentropy': float(xentropy),
            'target': float(target),
            'prompt': prompt,
        })
    return entries

# Extract data from all iterations: one record per parsed frontier line,
# tagged with the iteration of the frontier it came from. The 'iteration'
# key is added last so the resulting column order is
# penalty, xentropy, target, prompt, iteration.
all_data = [
    {**entry, 'iteration': frontier['iteration']}
    for frontier in experiment['frontiers']
    for entry in parse_frontier_with_prompts(frontier['text'])
]

# Fail fast with an actionable message: an empty record list would produce a
# DataFrame without columns and surface later as an opaque KeyError when the
# first column is accessed.
if not all_data:
    raise ValueError(
        "No frontier entries could be parsed from experiment['frontiers']; "
        "check that the frontier text matches the format expected by "
        "parse_frontier_with_prompts."
    )

# Convert to DataFrame
df = pd.DataFrame(all_data)

# Create a visualization focusing on prompt evolution (2x2 grid of panels)
plt.figure(figsize=(15, 12))

# 1. Scatterplot of all prompts in the xentropy-target space,
#    coloured by the iteration each row came from
plt.subplot(2, 2, 1)
scatter = plt.scatter(df['xentropy'], df['target'],
                      c=df['iteration'], cmap='viridis',
                      s=100, alpha=0.8)
plt.colorbar(scatter, label='Iteration')
plt.xlabel('Cross-Entropy')
plt.ylabel('Target Value')
plt.title('Prompt Performance in Objective Space')
plt.grid(True, alpha=0.3)

# 2. Top performers by target value
plt.subplot(2, 2, 2)
# Sort first, then keep only the best-scoring row per prompt text. Without
# the de-duplication a prompt that appears in several rows can occupy
# several of the five slots, and the bar plot would merge those rows into a
# single aggregated bar, showing fewer than five distinct prompts.
top_prompts = (
    df.sort_values('target', ascending=False)
    .drop_duplicates(subset='prompt')
    .head(5)
)
sns.barplot(x='target', y='prompt', data=top_prompts)
plt.xlabel('Target Value')
plt.title('Top 5 Prompts by Target Value')
plt.grid(True, alpha=0.3)

# 3. Prompt length vs. performance
plt.subplot(2, 2, 3)
# Length in characters; stored on df as a new 'prompt_length' column.
df['prompt_length'] = df['prompt'].str.len()
length_scatter = plt.scatter(df['prompt_length'], df['target'],
                             c=df['iteration'], cmap='viridis',
                             s=100, alpha=0.8)
# Pass the mappable explicitly rather than relying on pyplot's notion of the
# "current" image, consistent with the first panel.
plt.colorbar(length_scatter, label='Iteration')
plt.xlabel('Prompt Length (characters)')
plt.ylabel('Target Value')
plt.title('Prompt Length vs. Target Value')
plt.grid(True, alpha=0.3)

# 4. Average target value by iteration
plt.subplot(2, 2, 4)
iteration_avg = df.groupby('iteration')['target'].mean().reset_index()
sns.barplot(x='iteration', y='target', data=iteration_avg)
plt.xlabel('Iteration')
plt.ylabel('Average Target Value')
plt.title('Average Target Value by Iteration')
plt.grid(True, alpha=0.3)

# Save before show(): the figure is written to disk first, then displayed.
plt.tight_layout()
plt.savefig('epo_prompt_analysis.png', dpi=300)
plt.show()

# Report, for every iteration, the row with the highest target value
# (idxmax returns the index label of the first maximum within each group).
best_row_labels = df.groupby('iteration')['target'].idxmax()
best_by_iteration = df.loc[best_row_labels]
print("Best prompt by target value in each iteration:")
for record in best_by_iteration.to_dict(orient='records'):
    # Each record is a plain dict keyed by column name.
    print(f"Iteration {record['iteration']}:")
    print(f"  Target: {record['target']:.2f}, Cross-Entropy: {record['xentropy']:.2f}, Penalty: {record['penalty']:.2f}")
    print(f"  Prompt: '{record['prompt']}'")
    print()
# %%
