import json
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

# Load the EPO results JSON. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale-dependent default.
with open("epo_results_20250402_152510.json", encoding="utf-8") as f:
    data = json.load(f)

# Extract the first experiment (the file may hold several; only data[0] is
# visualised here). Fail with a clear message instead of a bare IndexError
# when the file holds no experiments.
if not data:
    raise ValueError("epo_results_20250402_152510.json contains no experiments")
experiment = data[0]


# Signed decimal number with an optional exponent, e.g. "3", "-0.25", ".5",
# "5.", "1e-05". Used for every metric so negative or very small values
# (which str(float) may print in scientific notation) are not dropped.
_NUMBER = r"[-+]?(?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)(?:[eE][-+]?[0-9]+)?"
_FRONTIER_LINE_RE = re.compile(
    rf"penalty=({_NUMBER}) xentropy=({_NUMBER}) target=({_NUMBER})"
)


def parse_frontier(text):
    """Parse the metric lines of a frontier text dump.

    The first line of ``text`` (after stripping surrounding whitespace) is
    treated as a header and skipped. Each remaining line that contains
    ``penalty=<num> xentropy=<num> target=<num>`` yields one record; lines
    that do not match are ignored.

    Args:
        text: Multi-line frontier text.

    Returns:
        A list of dicts, in line order, each with float values under the
        keys ``"penalty"``, ``"xentropy"`` and ``"target"``.
    """
    metrics = []
    # splitlines() also handles "\r\n" and bare "\r" line endings.
    for line in text.strip().splitlines()[1:]:
        match = _FRONTIER_LINE_RE.search(line)
        if match:
            penalty, xentropy, target = map(float, match.groups())
            metrics.append({"penalty": penalty, "xentropy": xentropy, "target": target})
    return metrics


# Extract data from all iterations: one row per parsed frontier point, tagged
# with the iteration of the frontier it came from.
all_iterations_data = [
    {**m, "iteration": frontier["iteration"]}
    for frontier in experiment["frontiers"]
    for m in parse_frontier(frontier["text"])
]

# Convert to DataFrame. Naming the columns explicitly keeps the schema intact
# even when no frontier line matched; otherwise an empty list yields a frame
# with no columns and every df["iteration"] lookup raises KeyError.
df = pd.DataFrame(
    all_iterations_data,
    columns=["penalty", "xentropy", "target", "iteration"],
)

def _plot_metric_pair(ax, data, x_col, y_col, iterations, colors):
    """Draw one coloured scatter series of ``y_col`` vs ``x_col`` per iteration.

    Args:
        ax: Matplotlib axes to draw on.
        data: DataFrame with an ``"iteration"`` column plus ``x_col``/``y_col``.
        x_col: Column on the x axis; also the order used to join points.
        y_col: Column on the y axis.
        iterations: Iteration values to plot, in colour order.
        colors: One colour per entry of ``iterations``.
    """
    for iteration, color in zip(iterations, colors):
        iter_data = data[data["iteration"] == iteration]
        ax.scatter(
            iter_data[x_col],
            iter_data[y_col],
            label=f"Iteration {iteration}",
            color=color,
            s=100,
            alpha=0.7,
        )
        # Join the iteration's points into a curve, ordered along x so the
        # line does not zig-zag back and forth.
        if len(iter_data) > 1:
            sorted_data = iter_data.sort_values(x_col)
            ax.plot(
                sorted_data[x_col],
                sorted_data[y_col],
                "-",
                color=color,
                alpha=0.5,
            )


# Create the visualization: a 2x2 grid of panels.
fig = plt.figure(figsize=(14, 10))

# Sort the iteration values so the viridis gradient runs from the earliest to
# the latest iteration, independent of the order frontiers appear in the JSON.
iterations = np.sort(df["iteration"].unique())
colors = plt.cm.viridis(np.linspace(0, 1, len(iterations)))

# 1. Pareto frontier evolution (xentropy vs target)
ax_frontier = fig.add_subplot(2, 2, 1)
_plot_metric_pair(ax_frontier, df, "xentropy", "target", iterations, colors)
ax_frontier.set_xlabel("Cross-Entropy")
ax_frontier.set_ylabel("Target Value")
ax_frontier.set_title("Pareto Frontier Evolution (Cross-Entropy vs Target)")
ax_frontier.grid(True, alpha=0.3)
ax_frontier.legend()

# 2. Penalty vs Target
ax_penalty = fig.add_subplot(2, 2, 2)
_plot_metric_pair(ax_penalty, df, "penalty", "target", iterations, colors)
ax_penalty.set_xlabel("Penalty")
ax_penalty.set_ylabel("Target Value")
ax_penalty.set_title("Penalty vs Target by Iteration")
ax_penalty.grid(True, alpha=0.3)
ax_penalty.legend()

# 3. 3D scatter plot to show all three metrics.
# Importing Axes3D is only needed on older Matplotlib releases to register the
# "3d" projection; it is kept for compatibility.
from mpl_toolkits.mplot3d import Axes3D  # noqa: E402, F401

ax_3d = fig.add_subplot(2, 2, 3, projection="3d")
for iteration, color in zip(iterations, colors):
    iter_data = df[df["iteration"] == iteration]
    ax_3d.scatter(
        iter_data["xentropy"],
        iter_data["penalty"],
        iter_data["target"],
        color=color,
        s=100,
        alpha=0.7,
    )
ax_3d.set_xlabel("Cross-Entropy")
ax_3d.set_ylabel("Penalty")
ax_3d.set_zlabel("Target Value")
ax_3d.set_title("3D Visualization of Pareto Frontier Evolution")

# 4. Progress of metrics over iterations
ax_avg = fig.add_subplot(2, 2, 4)

# Mean of each metric per iteration (groupby sorts by iteration by default,
# so each line runs in iteration order).
avg_metrics = df.groupby("iteration").mean()

ax_avg.plot(
    avg_metrics.index, avg_metrics["penalty"], "o-", label="Avg Penalty", linewidth=2
)
ax_avg.plot(
    avg_metrics.index,
    avg_metrics["xentropy"],
    "s-",
    label="Avg Cross-Entropy",
    linewidth=2,
)
ax_avg.plot(
    avg_metrics.index, avg_metrics["target"], "^-", label="Avg Target", linewidth=2
)
ax_avg.set_xlabel("Iteration")
ax_avg.set_ylabel("Value")
ax_avg.set_title("Evolution of Average Metrics Across Iterations")
ax_avg.grid(True, alpha=0.3)
ax_avg.legend()

fig.tight_layout()
fig.savefig("epo_optimization_visualization.png", dpi=300)
plt.show()

# Additional output: table of per-iteration min/max/mean for every metric,
# rounded to two decimals.
stats = (
    df.groupby("iteration")
    .agg({col: ["min", "max", "mean"] for col in ("penalty", "xentropy", "target")})
    .round(2)
)

print("Statistics by Iteration:")
# to_string() renders every row and column; a plain print(stats) is subject to
# pandas' display.max_rows / display.max_columns truncation.
print(stats.to_string())