from pathlib import Path
from typing import Any

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import Normalize


def save_combined_radar_chart(
    categories: dict[str, Any], save_path: str | Path
) -> None:
    categories = {k: v for k, v in categories.items() if v}
    if not all(categories.values()):
        raise Exception("No data to plot")
    labels = np.array(
        list(next(iter(categories.values())).keys())
    )  # We use the first category to get the keys
    num_vars = len(labels)
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
    angles += angles[
        :1
    ]  # Add the first angle to the end of the list to ensure the polygon is closed

    # Create radar chart
    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
    ax.set_theta_offset(np.pi / 2)  # type: ignore
    ax.set_theta_direction(-1)  # type: ignore
    ax.spines["polar"].set_visible(False)  # Remove border

    # Define a custom normalization to start the color from the middle
    norm = Normalize(
        vmin=0, vmax=max([max(val.values()) for val in categories.values()])
    )  # We use the maximum of all categories for normalization

    cmap = plt.cm.get_cmap("nipy_spectral", len(categories))  # type: ignore

    colors = [cmap(i) for i in range(len(categories))]

    for i, (cat_name, cat_values) in enumerate(
        categories.items()
    ):  # Iterating through each category (series)
        values = np.array(list(cat_values.values()))
        values = np.concatenate((values, values[:1]))  # Ensure the polygon is closed

        ax.fill(angles, values, color=colors[i], alpha=0.25)  # Draw the filled polygon
        ax.plot(angles, values, color=colors[i], linewidth=2)  # Draw polygon
        ax.plot(
            angles,
            values,
            "o",
            color="white",
            markersize=7,
            markeredgecolor=colors[i],
            markeredgewidth=2,
        )  # Draw points

        # Draw legend
        legend = ax.legend(
            handles=[
                mpatches.Patch(color=color, label=cat_name, alpha=0.25)
                for cat_name, color in zip(categories.keys(), colors)
            ],
            loc="upper left",
            bbox_to_anchor=(0.7, 1.3),
        )

        # Adjust layout to make room for the legend
        plt.tight_layout()

    lines, labels = plt.thetagrids(
        np.degrees(angles[:-1]), (list(next(iter(categories.values())).keys()))
    )  # We use the first category to get the keys

    highest_score = 7

    # Set y-axis limit to 7
    ax.set_ylim(top=highest_score)

    # Move labels away from the plot
    for label in labels:
        label.set_position(
            (label.get_position()[0], label.get_position()[1] + -0.05)
        )  # adjust 0.1 as needed

    # Move radial labels away from the plot
    ax.set_rlabel_position(180)  # type: ignore

    ax.set_yticks([])  # Remove default yticks

    # Manually create gridlines
    for y in np.arange(0, highest_score + 1, 1):
        if y != highest_score:
            ax.plot(
                angles, [y] * len(angles), color="gray", linewidth=0.5, linestyle=":"
            )
        # Add labels for manually created gridlines
        ax.text(
            angles[0],
            y + 0.2,
            str(int(y)),
            color="black",
            size=9,
            horizontalalignment="center",
            verticalalignment="center",
        )

    plt.savefig(save_path, dpi=300)  # Save the figure as a PNG file
    plt.close()  # Close the figure to free up memory


def save_single_radar_chart(
    category_dict: dict[str, int], save_path: str | Path
) -> None:
    labels = np.array(list(category_dict.keys()))
    values = np.array(list(category_dict.values()))

    num_vars = len(labels)

    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()

    angles += angles[:1]
    values = np.concatenate((values, values[:1]))

    colors = ["#1f77b4"]

    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
    ax.set_theta_offset(np.pi / 2)  # type: ignore
    ax.set_theta_direction(-1)  # type: ignore

    ax.spines["polar"].set_visible(False)

    lines, labels = plt.thetagrids(
        np.degrees(angles[:-1]), (list(category_dict.keys()))
    )

    highest_score = 7

    # Set y-axis limit to 7
    ax.set_ylim(top=highest_score)

    for label in labels:
        label.set_position((label.get_position()[0], label.get_position()[1] + -0.05))

    ax.fill(angles, values, color=colors[0], alpha=0.25)
    ax.plot(angles, values, color=colors[0], linewidth=2)

    for i, (angle, value) in enumerate(zip(angles, values)):
        ha = "left"
        if angle in {0, np.pi}:
            ha = "center"
        elif np.pi < angle < 2 * np.pi:
            ha = "right"
        ax.text(
            angle,
            value - 0.5,
            f"{value}",
            size=10,
            horizontalalignment=ha,
            verticalalignment="center",
            color="black",
        )

    ax.set_yticklabels([])

    ax.set_yticks([])

    if values.size == 0:
        return

    for y in np.arange(0, highest_score, 1):
        ax.plot(angles, [y] * len(angles), color="gray", linewidth=0.5, linestyle=":")

    for angle, value in zip(angles, values):
        ax.plot(
            angle,
            value,
            "o",
            color="white",
            markersize=7,
            markeredgecolor=colors[0],
            markeredgewidth=2,
        )

    plt.savefig(save_path, dpi=300)  # Save the figure as a PNG file
    plt.close()  # Close the figure to free up memory


def save_combined_bar_chart(categories: dict[str, Any], save_path: str | Path) -> None:
    if not all(categories.values()):
        raise Exception("No data to plot")

    # Convert dictionary to DataFrame
    df = pd.DataFrame(categories)

    # Create a grouped bar chart
    df.plot(kind="bar", figsize=(10, 7))

    plt.title("Performance by Category for Each Agent")
    plt.xlabel("Category")
    plt.ylabel("Performance")

    plt.savefig(save_path, dpi=300)  # Save the figure as a PNG file
    plt.close()  # Close the figure to free up memory
