# %%
import json
import os
import re
from typing import List, TypedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


# Schema of one parsed pareto-frontier entry, built by extract_metrics from a
# "penalty=<num> xentropy=<num> target=<num>" line. All three values are
# stored as floats under the key that names them in the source text.
Metric = TypedDict(
    "Metric",
    {
        "penalty": float,
        "xentropy": float,
        "target": float,
    },
)


# A float literal as accepted by float(): optional sign, digits with an
# optional fractional part (or a bare ".25"), optional exponent, or inf/nan.
# A full grammar is needed so values such as "1e-05" or "1.5e-3" are parsed
# completely instead of being skipped or truncated at the exponent.
_NUMBER = r"[-+]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|inf|nan)"

# Compiled once at import time; matched against every line of every frontier.
_METRIC_LINE_RE = re.compile(
    rf"penalty=({_NUMBER})\s+xentropy=({_NUMBER})\s+target=({_NUMBER})"
)


def extract_metrics(text: str) -> List[Metric]:
    """Extract metrics from text containing pareto frontier data.

    Every line of the form ``penalty=<num> xentropy=<num> target=<num>``
    (surrounding whitespace allowed) yields one Metric; all other lines are
    ignored. Numbers may be signed and may use scientific notation.

    Args:
        text: Multi-line frontier text, e.g. the ``"text"`` field of a
            frontier entry in the results JSON.

    Returns:
        The parsed metrics in the order they appear in ``text``; an empty
        list if no line matches (including for empty ``text``).
    """
    metrics: List[Metric] = []

    # splitlines() also handles "\r\n" endings; strip() per line so indented
    # lines are still recognised.
    for line in text.splitlines():
        match = _METRIC_LINE_RE.match(line.strip())
        if match is None:
            continue
        penalty, xentropy, target = (float(group) for group in match.groups())
        metrics.append({"penalty": penalty, "xentropy": xentropy, "target": target})

    return metrics


# Path of the EPO results JSON to visualize.
input_file = "results/epo_experiments/epo_results_20250404_023011.json"


def _collect_best_scores(data: list):
    """Reduce every pareto frontier in ``data`` to its best values.

    Args:
        data: The decoded results JSON: an iterable of experiments, each with
            a ``"frontiers"`` list whose entries carry ``"iteration"`` and
            ``"text"`` keys.

    Returns:
        A pair ``(best_scores, best_xentropy_scores)`` of lists of row dicts.
        Each frontier with at least one parsable metric adds
        ``{"run", "iteration", "best_target"}`` (maximum target) to the first
        list and ``{"run", "iteration", "best_xentropy"}`` (minimum
        cross-entropy) to the second; ``run`` is the experiment's position in
        ``data``. Frontiers without metrics contribute nothing.
    """
    best_scores = []
    best_xentropy_scores = []

    for run_idx, experiment in enumerate(data):
        for frontier in experiment["frontiers"]:
            iteration = frontier["iteration"]
            metrics = extract_metrics(frontier["text"])
            if not metrics:
                continue

            # Higher target is better; lower cross-entropy is better.
            best_scores.append(
                {
                    "run": run_idx,
                    "iteration": iteration,
                    "best_target": max(m["target"] for m in metrics),
                }
            )
            best_xentropy_scores.append(
                {
                    "run": run_idx,
                    "iteration": iteration,
                    "best_xentropy": min(m["xentropy"] for m in metrics),
                }
            )

    return best_scores, best_xentropy_scores


def _annotate_points(frame: pd.DataFrame, value_col: str, fontsize: int) -> None:
    """Write each point's ``value_col`` value (2 decimals) just above it."""
    for iteration, value in zip(frame["iteration"], frame[value_col]):
        plt.annotate(
            f"{value:.2f}",
            (iteration, value),
            textcoords="offset points",
            xytext=(0, 10),
            ha="center",
            fontsize=fontsize,
        )


def _plot_runs(
    frame: pd.DataFrame, value_col: str, *, title: str, ylabel: str, filename: str
) -> None:
    """Plot ``value_col`` against iteration with one coloured line per run.

    The figure is saved to ``filename`` at 300 dpi and then shown.
    """
    plt.figure(figsize=(12, 8))

    # Line plot with markers, one line per run
    sns.lineplot(
        data=frame,
        x="iteration",
        y=value_col,
        hue="run",
        marker="o",
        markersize=8,
        linewidth=2,
    )

    # Overlay a scatter plot to highlight the individual points
    sns.scatterplot(
        data=frame,
        x="iteration",
        y=value_col,
        hue="run",
        s=100,
        alpha=0.6,
        legend=False,
    )

    plt.title(title, fontsize=16)
    plt.xlabel("Iteration", fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.xticks(frame["iteration"].unique())  # ticks exactly at each iteration
    plt.legend(title="Run", title_fontsize=12)

    _annotate_points(frame, value_col, fontsize=9)

    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.show()


def _plot_single_run(
    frame: pd.DataFrame,
    value_col: str,
    *,
    title: str,
    ylabel: str,
    filename: str,
    color: str,
    zero_line: bool = False,
) -> None:
    """Plot ``value_col`` against iteration for a frame holding a single run.

    If ``zero_line`` is true, a dashed horizontal line is drawn at y=0. The
    figure is saved to ``filename`` at 300 dpi and then shown.
    """
    plt.figure(figsize=(10, 6))

    # plt.plot connects points in row order, so sort to get a monotone line.
    ordered = frame.sort_values("iteration")
    plt.plot(
        ordered["iteration"],
        ordered[value_col],
        "o-",
        linewidth=2,
        markersize=10,
        color=color,
    )

    plt.title(title, fontsize=16)
    plt.xlabel("Iteration", fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.xticks(frame["iteration"].unique())

    _annotate_points(frame, value_col, fontsize=10)

    if zero_line:
        # Help visualize positive vs negative values
        plt.axhline(y=0, color="gray", linestyle="--", alpha=0.5)

    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.show()


if not os.path.exists(input_file):
    print(f"Input file {input_file} does not exist.")
else:
    # JSON text is UTF-8; don't depend on the platform default encoding.
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    best_scores, best_xentropy_scores = _collect_best_scores(data)

    # One row per (run, iteration) pair that had parsable metrics
    df = pd.DataFrame(best_scores)
    df_xentropy = pd.DataFrame(best_xentropy_scores)

    if df.empty:
        # With no rows the frames have no columns at all, and the seaborn
        # calls / pivots below would fail with an opaque KeyError/ValueError.
        print(f"No pareto frontier metrics found in {input_file}; nothing to plot.")
    else:
        # ---- Target values ------------------------------------------------
        _plot_runs(
            df,
            "best_target",
            title="Best Target Value for Each Iteration Across Runs",
            ylabel="Best Target Value",
            filename="epo_best_scores_by_iteration.png",
        )

        pivot_table = df.pivot(
            index="iteration", columns="run", values="best_target"
        ).round(2)
        print("Best target values by iteration and run:")
        print(pivot_table)

        if df["run"].nunique() == 1:
            _plot_single_run(
                df,
                "best_target",
                title="Best Target Value by Iteration",
                ylabel="Best Target Value",
                filename="epo_best_scores_single_run.png",
                color="blue",
                zero_line=True,
            )

        # ---- Cross-entropy values -------------------------------------------
        _plot_runs(
            df_xentropy,
            "best_xentropy",
            title="Best (Lowest) Cross-Entropy Value for Each Iteration Across Runs",
            ylabel="Best Cross-Entropy Value",
            filename="epo_best_xentropy_by_iteration.png",
        )

        xentropy_pivot_table = df_xentropy.pivot(
            index="iteration", columns="run", values="best_xentropy"
        ).round(2)
        print("Best (lowest) cross-entropy values by iteration and run:")
        print(xentropy_pivot_table)

        if df_xentropy["run"].nunique() == 1:
            _plot_single_run(
                df_xentropy,
                "best_xentropy",
                title="Best (Lowest) Cross-Entropy Value by Iteration",
                ylabel="Best Cross-Entropy Value",
                filename="epo_best_xentropy_single_run.png",
                color="green",
            )

# %%
