# %%
import json
import os
import re
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict

import matplotlib.pyplot as plt
import numpy as np
import openai
from matplotlib import cm
from matplotlib.figure import Figure

# Debug switch: when True, the main script loop stops after processing and
# plotting the first experiment group.
DEBUG: bool = True


# One parsed "penalty=... xentropy=... target=..." line of a pareto-frontier
# dump, as produced by extract_metrics. Declared with the functional TypedDict
# form; it has the same three required float keys as a class-based declaration.
Metric = TypedDict(
    "Metric",
    {
        "penalty": float,  # value following "penalty="
        "xentropy": float,  # value following "xentropy=" (cross entropy)
        "target": float,  # value following "target="
    },
)


# A float as Python prints it: optional sign, then plain decimal, scientific
# notation, or inf/nan. Unlike a bare [\d.]+ pattern this never captures a lone
# "." or "-" (which would crash float()) and does not reject "1e-05" or "-inf".
_FLOAT_PATTERN = r"[-+]?(?:inf|nan|\d+(?:\.\d*)?(?:[eE][-+]?\d+)?|\.\d+(?:[eE][-+]?\d+)?)"
_METRIC_LINE_RE = re.compile(
    rf"penalty=({_FLOAT_PATTERN})\s+xentropy=({_FLOAT_PATTERN})\s+target=({_FLOAT_PATTERN})"
)


def extract_metrics(text: str) -> List["Metric"]:
    """Extract metrics from text containing pareto frontier data.

    Only lines starting with ``penalty=`` are considered. Each such line must
    contain ``penalty=<float> xentropy=<float> target=<float>`` at its start;
    anything after the target value (e.g. the generated text) is ignored.
    Lines that do not match are skipped.

    Args:
        text: Multi-line frontier dump.

    Returns:
        One Metric dict per matching line, in input order.
    """
    metrics: List[Metric] = []

    for line in text.strip().split("\n"):
        if not line.startswith("penalty="):
            continue
        match = _METRIC_LINE_RE.match(line)
        if match is None:
            continue
        penalty, xentropy, target = (float(group) for group in match.groups())
        metrics.append({"penalty": penalty, "xentropy": xentropy, "target": target})

    return metrics


def get_experiment_key(experiment: Dict[str, Any]) -> Tuple[str, str, str]:
    """Build the grouping key for an experiment record.

    Experiments with equal user message, system message type and conditions
    share a key. Missing fields default to "" (messages) or [] (conditions).
    The conditions are stringified so the key stays hashable.
    """
    return (
        experiment.get("user_message", ""),
        experiment.get("system_message_type", ""),
        str(experiment.get("conditions", [])),
    )


def _build_batch_prompt(batch: List[str], criteria: str, quantifier: str) -> str:
    """Build the user prompt asking for one PASS/FAIL verdict per string in *batch*."""
    numbered = "\n".join(f"{i + 1}. {string}" for i, string in enumerate(batch))
    return (
        f"Evaluate whether or not each of the following strings satisfies {quantifier} "
        "of the following criteria:\n"
        f"Criteria: {criteria}\n\n"
        f"Strings to evaluate:\n{numbered}\n\n"
        'For each string, please provide only a pass/fail verdict (just the word "PASS" or "FAIL").\n\n'
        # JSON mode (response_format=json_object) only yields a top-level JSON
        # *object*, so ask for the array wrapped under a known key.
        f'Format your response as a JSON object with a "results" array containing exactly '
        f"{len(batch)} objects, one per string, in the same order:\n"
        '{"results": [{"verdict": "PASS or FAIL"}, {"verdict": "PASS or FAIL"}, ...]}\n'
    )


def _coerce_result(entry: Any, index: int) -> Dict[str, Any]:
    """Return *entry* as a dict that has a "verdict" key (a bare string becomes the verdict)."""
    if isinstance(entry, dict):
        if "verdict" in entry:
            return entry
        print(f"Missing verdict in result {index}: {entry}")
        return {**entry, "verdict": "Error: Missing verdict"}
    if isinstance(entry, str):
        return {"verdict": entry}
    print(f"Missing verdict in result {index}: {entry}")
    return {"verdict": "Error: Missing verdict"}


def _normalize_batch_results(parsed: Any, expected: int) -> List[Dict[str, Any]]:
    """Coerce a parsed JSON reply into exactly *expected* dicts, each with a "verdict" key."""
    if isinstance(parsed, list):
        entries = parsed
    elif isinstance(parsed, dict):
        numbered_keys = [f"verdict{i + 1}" for i in range(expected)]
        list_values = [value for value in parsed.values() if isinstance(value, list)]
        if isinstance(parsed.get("results"), list):
            entries = parsed["results"]
        elif isinstance(parsed.get("verdicts"), list):
            entries = parsed["verdicts"]
        elif numbered_keys and all(key in parsed for key in numbered_keys):
            entries = [parsed[key] for key in numbered_keys]
        elif len(list_values) == 1:
            # Some other wrapper key (e.g. {"evaluations": [...]}): use its array.
            entries = list_values[0]
        else:
            # A bare {"verdict": ...} object is a single result.
            entries = [parsed]
    else:
        print(f"Unexpected response format: {type(parsed)}")
        entries = [{"verdict": "Error: Unexpected response format"} for _ in range(expected)]

    results = [_coerce_result(entry, i) for i, entry in enumerate(entries)]
    if len(results) != expected:
        print(f"Result count mismatch: got {len(results)}, expected {expected}")
        # Truncate extras, pad shortfalls, so results stay aligned with inputs.
        results = results[:expected] + [
            {"verdict": "Error: Missing result"} for _ in range(expected - len(results))
        ]
    return results


def batch_evaluate_strings_with_gpt4o(
    strings: List[str],
    criteria: str,
    system_message_type: str,
    api_key: str,
    batch_size: int = 5,  # Reduced batch size for better reliability
    max_retries: int = 2,  # Add retry logic
) -> List[Dict[str, Any]]:
    """
    Evaluate multiple strings using OpenAI's API, ``batch_size`` strings per request.

    Args:
        strings: List of strings to evaluate
        criteria: The criteria to evaluate against
        system_message_type: Quantifier applied to the criteria. "one" and "xor" are
            expanded to "exactly one" / "exactly an odd number of"; any other value
            is inserted into the prompt verbatim.
        api_key: OpenAI API key
        batch_size: Number of strings to evaluate in a single API call (must be >= 1)
        max_retries: Maximum number of retries for a batch whose request raises or
            whose reply is not valid JSON

    Returns:
        One dict per input string, in input order, each with a "verdict" key: the
        model's answer (normally "PASS" or "FAIL") or an "Error..." string when no
        usable answer was obtained. Its length always equals ``len(strings)``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    quantifier = {"one": "exactly one", "xor": "exactly an odd number of"}.get(
        system_message_type, system_message_type
    )

    batches = [strings[i : i + batch_size] for i in range(0, len(strings), batch_size)]
    if not batches:
        return []

    client = openai.OpenAI(api_key=api_key)
    all_results: List[Dict[str, Any]] = []

    for batch_idx, batch in enumerate(batches):
        print(f"Processing batch {batch_idx+1}/{len(batches)} with {len(batch)} strings")

        messages = [
            {
                "role": "system",
                "content": "You are an expert evaluator. Provide concise, accurate assessments.",
            },
            {"role": "user", "content": _build_batch_prompt(batch, criteria, quantifier)},
        ]

        batch_results = None
        # Verdict used for every string if all attempts fail; updated per failure.
        failure_verdict = "Error: All retry attempts failed"

        for attempt in range(max_retries + 1):
            if attempt > 0:
                time.sleep(2)  # brief back-off before retrying
                print(f"Retry attempt {attempt} for batch {batch_idx+1}")
            try:
                response = client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=messages,
                    response_format={"type": "json_object"},
                )
                content = response.choices[0].message.content
                parsed = json.loads(content)
            except json.JSONDecodeError as e:
                print(f"JSON decode error: {e}")
                print(f"Response content: {content}")
                failure_verdict = "Error: Invalid JSON response"
                continue
            except Exception as e:
                print(f"Error evaluating batch: {e}")
                failure_verdict = f"Error during evaluation: {str(e)}"
                continue

            print(f"Raw API response: {parsed}")
            batch_results = _normalize_batch_results(parsed, len(batch))
            break

        if batch_results is None:
            batch_results = [{"verdict": failure_verdict} for _ in batch]

        all_results.extend(batch_results)

    return all_results


def evaluate_run_with_gpt4o(
    strings: List[str], system_message_type: str, conditions: List[str], api_key: str
) -> Dict[str, Any]:
    """
    Evaluate an entire run based on its final strings, system message type, and conditions.

    Args:
        strings: List of strings from the last iteration
        system_message_type: Quantifier applied to the conditions. "one" and "xor" are
            expanded to "exactly one" / "exactly an odd number of"; any other value
            is inserted into the prompt verbatim.
        conditions: List of conditions to evaluate against
        api_key: OpenAI API key

    Returns:
        The model's parsed JSON reply, guaranteed to contain a "verdict" key. If the
        request or JSON parsing fails, or the reply is malformed, "verdict" holds an
        "Error..." string instead of "PASS"/"FAIL".
    """
    client = openai.OpenAI(api_key=api_key)

    quantifier = {"one": "exactly one", "xor": "exactly an odd number of"}.get(
        system_message_type, system_message_type
    )
    strings_text = "\n".join(f"{i + 1}. {s}" for i, s in enumerate(strings))
    conditions_text = "\n".join(f"- {c}" for c in conditions)

    # Each condition goes on its own line below the header.
    prompt = (
        "Evaluate whether this run met the specified conditions based on the following information:\n\n"
        f"Evaluate whether or not {quantifier} of the following conditions are met:\n"
        f"{conditions_text}\n\n"
        "Generated strings from the final iteration:\n"
        f"{strings_text}\n\n"
        'Please provide only a pass/fail verdict for the entire run (just the word "PASS" or "FAIL").\n\n'
        "Format your response as JSON:\n"
        '{"verdict": "PASS or FAIL"}\n'
    )

    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert evaluator. Provide concise, accurate assessments.",
                },
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
        )
        result = json.loads(response.choices[0].message.content)
    except Exception as e:
        # Best-effort: report the failure as the verdict rather than raising.
        print(f"Error evaluating run: {e}")
        return {"verdict": f"Error during evaluation: {str(e)}"}

    if not isinstance(result, dict):
        print(f"Unexpected response format: {type(result)}")
        return {"verdict": "Error: Unexpected response format"}
    if "verdict" not in result:
        print(f"Missing verdict in result: {result}")
        return {**result, "verdict": "Error: Missing verdict"}
    return result


def extract_generated_string(line: str) -> Optional[str]:
    """Return the generated text embedded in a frontier ``penalty=...`` line.

    The text starts right after the first quote character (``'`` or ``"``) on
    the line and ends before the last ``[`` after it. That ``[`` is presumably
    the token-id list printed after the text; confirm against the EPO output
    format. Without such a ``[``, the text runs to the end of the line. Trailing
    whitespace and a closing quote matching the opening one are removed.

    Returns None when the line contains no quote character.
    """
    quote_positions = [pos for pos in (line.find("'"), line.find('"')) if pos != -1]
    if not quote_positions:
        return None
    start = min(quote_positions)
    opening_quote = line[start]

    end = line.rfind("[")
    if end <= start:
        # No "[" after the text (rfind may return -1): keep the rest of the line.
        end = len(line)

    extracted = line[start + 1 : end].rstrip()
    if extracted.endswith(opening_quote):
        extracted = extracted[:-1]
    return extracted


def run_color(experiment: Dict[str, Any]) -> str:
    """Pick a plot color for a run from the GPT verdicts stored on it.

    Returns "green" if any verdict is PASS (case-insensitive), "red" if the run
    was evaluated but nothing passed, and "gray" if it carries no evaluations
    (e.g. no API key was set or no strings were extracted).
    """
    evaluations = experiment.get("evaluations", [])
    if not evaluations:
        return "gray"
    has_pass = any(
        isinstance(evaluation.get("verdict"), str)
        and evaluation["verdict"].strip().upper() == "PASS"
        for evaluation in evaluations
    )
    return "green" if has_pass else "red"


def run_line_style(run_idx: int) -> Tuple[str, float]:
    """Return (linestyle, alpha): solid/opaque for the first 5 runs, dashed/fainter after."""
    return ("-", 0.9) if run_idx < 5 else ("--", 0.7)


# Hardcoded file path - change this to point to your JSON file
input_file = "results/epo_experiments/epo_results_20250404_023011.json"
output_dir = "pareto_plots_grouped"

# OpenAI API key, read from the OPENAI_API_KEY environment variable. When it is
# empty, strings are written without GPT evaluation and runs are plotted in gray.
openai_api_key = os.environ.get("OPENAI_API_KEY", "")

# Check if the input file exists
if not os.path.exists(input_file):
    print(f"Input file {input_file} does not exist.")
else:
    # Parse the JSON file (iterated below as a list of experiment dicts)
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    os.makedirs(output_dir, exist_ok=True)

    # Group experiments sharing user message, system message type and conditions
    experiment_groups = defaultdict(list)
    for experiment in data:
        experiment_groups[get_experiment_key(experiment)].append(experiment)

    print(f"Found {len(experiment_groups)} unique experiment groups")

    for group_idx, (key, experiments) in enumerate(experiment_groups.items()):
        user_message, system_message_type, _ = key
        conditions = experiments[0].get("conditions", [])
        run_count = len(experiments)

        print(f"\nGroup {group_idx + 1}: {user_message}")
        print(f"System Message Type: {system_message_type}")
        print(f"Run Count: {run_count}")

        group_dir = os.path.join(output_dir, f"group_{group_idx + 1}")
        os.makedirs(group_dir, exist_ok=True)

        # Save group metadata plus each run's final strings (and, when an API
        # key is set, their GPT verdicts) to a text file
        with open(os.path.join(group_dir, "metadata.txt"), "w", encoding="utf-8") as f:
            f.write(f"Group {group_idx + 1}\n")
            f.write(f"User Message: {user_message}\n")
            f.write(f"System Message Type: {system_message_type}\n")
            f.write(f"Run Count: {run_count}\n\n")
            f.write("Conditions:\n")
            for i, condition in enumerate(conditions):
                f.write(f"  {i + 1}. {condition}\n")

            f.write("\nFinal Output Strings (Last Iteration):\n")

            for run_idx, experiment in enumerate(experiments):
                frontiers = experiment.get("frontiers", [])
                if not frontiers:
                    continue
                f.write(f"\nRun {run_idx + 1}:\n")

                text = frontiers[-1].get("text", "No output text available")
                generated_strings = []
                for line in text.strip().split("\n"):
                    if line.startswith("penalty="):
                        extracted = extract_generated_string(line)
                        if extracted is not None:
                            generated_strings.append(extracted)

                if not generated_strings:
                    f.write("No generated strings found in the output.\n")
                elif openai_api_key:
                    criteria = (
                        " ".join(conditions)
                        if conditions
                        else "Evaluate if the string is coherent and follows the expected format."
                    )
                    print(
                        f"Evaluating {len(generated_strings)} strings from run {run_idx + 1} with GPT-4o in batch..."
                    )
                    evaluations = batch_evaluate_strings_with_gpt4o(
                        generated_strings,
                        criteria,
                        system_message_type,
                        openai_api_key,
                        batch_size=5,  # Reduced batch size for better reliability
                    )
                    # Keep the verdicts on the experiment so the plots can color by them
                    experiment["evaluations"] = evaluations

                    print(f"Evaluations for run {run_idx + 1}:")
                    for i, (string, evaluation) in enumerate(zip(generated_strings, evaluations)):
                        print(f"  String {i+1}: {string}")
                        print(f"  Evaluation: {evaluation}")
                        f.write(f"{i + 1}. {string}\n")
                        f.write(f"   Evaluation: {evaluation.get('verdict', 'Unknown')}\n")
                else:
                    # No API key: write the strings without evaluation
                    for i, string in enumerate(generated_strings):
                        f.write(f"{i + 1}. {string}\n")

                f.write("\n" + "-" * 80 + "\n")

        # Metrics plot: best target score and minimum cross entropy per iteration
        fig1, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        for run_idx, experiment in enumerate(experiments):
            iterations = []
            best_targets = []
            min_xentropies = []
            for frontier in experiment.get("frontiers", []):
                metrics = extract_metrics(frontier.get("text", ""))
                if metrics:
                    iterations.append(frontier.get("iteration", -1))
                    best_targets.append(max(metric["target"] for metric in metrics))
                    min_xentropies.append(min(metric["xentropy"] for metric in metrics))

            color = run_color(experiment)
            line_style, alpha = run_line_style(run_idx)
            for axis, values in ((ax1, best_targets), (ax2, min_xentropies)):
                axis.plot(
                    iterations,
                    values,
                    marker="o",
                    linestyle=line_style,
                    color=color,
                    alpha=alpha,
                    label=f"Run {run_idx + 1}",
                )

        ax1.set_xlabel("Iteration")
        ax1.set_ylabel("Best Target Score")
        ax1.set_title("Iteration vs Best Target Score")
        ax1.grid(True)

        ax2.set_xlabel("Iteration")
        ax2.set_ylabel("Minimum Cross Entropy")
        ax2.set_title("Iteration vs Minimum Cross Entropy")
        ax2.grid(True)

        # Shared legend centred below the subplots
        handles, labels = ax1.get_legend_handles_labels()
        fig1.legend(
            handles,
            labels,
            loc="upper center",
            bbox_to_anchor=(0.5, 0.05),
            fancybox=True,
            shadow=True,
            ncol=min(5, run_count),
        )

        title_message = str(user_message)
        if len(title_message) > 50:
            title_message = f"{title_message[:50]}..."
        fig1.suptitle(f"Pareto Frontier Analysis: {title_message}", fontsize=14)

        fig1.tight_layout(rect=[0, 0, 1, 0.95])
        fig1.subplots_adjust(bottom=0.2)  # Make room for the legend

        metrics_path = os.path.join(group_dir, "metrics.png")
        fig1.savefig(metrics_path, dpi=300, bbox_inches="tight")
        print(f"Metrics plot saved to {metrics_path}")

        plt.show()
        plt.close(fig1)  # otherwise figures accumulate across groups

        # Final pareto frontier (last iteration) of each run
        fig2, ax = plt.subplots(figsize=(10, 8))

        for run_idx, experiment in enumerate(experiments):
            frontiers = experiment.get("frontiers", [])
            if not frontiers:
                continue
            metrics = extract_metrics(frontiers[-1].get("text", ""))
            if not metrics:
                continue

            # Sort points by cross entropy so the frontier is drawn left to right
            order = np.argsort([metric["xentropy"] for metric in metrics])
            line_style, alpha = run_line_style(run_idx)
            ax.plot(
                [metrics[j]["xentropy"] for j in order],
                [metrics[j]["target"] for j in order],
                marker="o",
                linestyle=line_style,
                color=run_color(experiment),
                alpha=alpha,
                label=f"Run {run_idx + 1}",
            )

        ax.set_xlabel("Cross Entropy")
        ax.set_ylabel("Target Score")
        ax.set_title("Final Pareto Frontiers from Each Run")
        ax.grid(True)
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc="best")

        fig2.tight_layout()

        frontier_path = os.path.join(group_dir, "pareto_frontiers.png")
        fig2.savefig(frontier_path, dpi=300, bbox_inches="tight")
        print(f"Pareto frontier plot saved to {frontier_path}")

        plt.show()
        plt.close(fig2)

        print("Conditions:")
        for i, condition in enumerate(conditions):
            print(f"  {i + 1}. {condition}")
        print("\n")
        # While debugging, stop after the first group
        if DEBUG:
            break
    print(f"All plots saved to {os.path.abspath(output_dir)}")