import os

import numpy as np
import matplotlib.pyplot as plt


# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old names, so prefer the new name when present
# and fall back to the legacy name on older matplotlib versions.
_GRID_STYLE = (
    "seaborn-v0_8-whitegrid"
    if "seaborn-v0_8-whitegrid" in plt.style.available
    else "seaborn-whitegrid"
)
plt.style.use(_GRID_STYLE)


def smooth(scalars: list[float], weight: float) -> list[float]:
    """
    Smooth scalars using an exponential moving average.

    The first value seeds the running average, so the first smoothed value
    equals ``scalars[0]``. Every subsequent value is
    ``previous_smoothed * weight + (1 - weight) * point``.

    Args:
        scalars (list[float]): The scalars to smooth. May be empty; any
            indexable sequence (list, tuple, 1-D numpy array) works.
        weight (float): The weight given to the running average. 0 leaves
            the values unchanged; values closer to 1 smooth more heavily.

    Returns:
        list[float]: The smoothed scalars, one per input value.
    """
    # len() rather than truthiness so 1-D numpy arrays are still accepted.
    if len(scalars) == 0:
        return []

    last = scalars[0]
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)
    return smoothed


def _write(path: str) -> None:
    """
    Save the current matplotlib figure as PNG and PDF, then close it.

    Args:
        path (str): Output path without extension; ``.png`` and ``.pdf`` are
            appended. Missing parent directories are created.

    Returns:
        None
    """
    directory = os.path.dirname(path)
    # dirname() is "" for a bare filename and os.makedirs("") raises;
    # exist_ok also avoids a race if the directory appears concurrently.
    if directory:
        os.makedirs(directory, exist_ok=True)

    try:
        for extension in ("png", "pdf"):
            plt.savefig(f"{path}.{extension}", bbox_inches="tight", dpi=300)
    finally:
        # Always release the figure so a failed save does not leak its
        # contents into the next plot drawn on the implicit current figure.
        plt.close()


def plot_credit_score_distribution(
    credit_score_distribution: tuple[tuple[float, ...], ...],
    title: str,
    path: str,
) -> None:
    """
    Plot the credit score distribution for each group and save it.

    Each group's values are drawn as one line, with credit scores on the
    x-axis numbered from 1.

    Args:
        credit_score_distribution (tuple[tuple[float, ...], ...]): One
            sequence of values per group; index ``i`` is credit score ``i + 1``.
        title (str): The title of the plot; also used as the output file name.
        path (str): The directory to save the plot in.

    Returns:
        None
    """
    # Track the longest group so the x-axis covers every plotted point
    # (and so empty input does not depend on a leaked loop variable).
    n_scores = 0
    for group, distribution in enumerate(credit_score_distribution):
        plt.plot(
            range(1, len(distribution) + 1),
            distribution,
            label=f"Group {group + 1}",
            linewidth=3,
        )
        n_scores = max(n_scores, len(distribution))
    plt.ylim([0, 1])
    if n_scores > 1:
        # A single score would give identical (singular) x-limits.
        plt.xlim([1, n_scores])
    plt.xlabel("Credit Score")
    plt.ylabel("% of Population")
    plt.title(title.replace("_", " ").title() + " Distribution")
    plt.legend(prop={"size": 14})
    _write(f"{path}/{title}")


def plot_distribution(
    distributions: tuple[tuple[float, ...], ...],
    title: str,
    path: str,
) -> None:
    """
    Plot the normalized distribution for each group and save it.

    Each group's values are divided by their sum before plotting. When
    ``title`` contains ``"cons"``, the x-axis is rescaled to ``[0, 1]``;
    otherwise it is the bin index.

    Args:
        distributions (tuple[tuple[float, ...], ...]): One sequence of
            (unnormalized) values per group. All groups are expected to have
            the same length as the first one.
        title (str): The title of the plot; also used as the output file name.
        path (str): The directory to save the plot in.

    Returns:
        None
    """
    n_bins = len(distributions[0]) if len(distributions) else 0
    x_values = np.arange(0, n_bins)
    if "cons" in title and n_bins > 1:
        # Map bin indices onto [0, 1]; a single bin would divide by zero.
        x_values = x_values / (n_bins - 1)

    for group, distribution in enumerate(distributions):
        values = np.asarray(distribution, dtype=float)
        total_population = values.sum()
        # An all-zero group cannot be normalized; plot it as-is (zeros).
        if total_population:
            values = values / total_population
        plt.plot(
            x_values,
            values,
            label=f"Group {group + 1}",
            linewidth=3,
        )

    pretty_title = title.replace("_", " ").title()
    plt.xlabel(pretty_title)
    plt.ylabel("% of Population")
    if n_bins > 1:
        # Fewer than two bins would give singular or empty x-limits.
        plt.xlim([0, x_values[-1]])
    plt.ylim(bottom=0)
    plt.title(pretty_title + " Distribution")
    plt.legend(prop={"size": 14})
    _write(f"{path}/{title}")


def plot_cumulative_metric(
    cumulative_metric: list[list[float]],
    title: str,
    path: str,
) -> None:
    """
    Plot the cumulative metric per step for each group and save it.

    Args:
        cumulative_metric (list[list[float]]): One series per group; element
            ``i`` is the metric value at step ``i``.
        title (str): The title of the plot (also the y-axis label and the
            output file name).
        path (str): The directory to save the plot in.

    Returns:
        None
    """
    # Size the x-axis from the longest series so no group is clipped.
    n_steps = 0
    for group, values in enumerate(cumulative_metric):
        plt.plot(
            range(len(values)),
            values,
            label=f"Group {group + 1}",
            linewidth=3,
        )
        n_steps = max(n_steps, len(values))

    pretty_title = title.replace("_", " ").title()
    plt.xlabel("Step")
    plt.ylabel(pretty_title)
    if n_steps:
        # Right limit is the step count (one step past the last point),
        # matching the original layout.
        plt.xlim(left=0, right=n_steps)
    plt.ylim(bottom=0)
    plt.title(pretty_title)
    plt.legend(prop={"size": 14})
    _write(f"{path}/{title}")


def plot_acceptance_ratio(
    accepted_ratio: list[float],
    title: str,
    path: str,
) -> None:
    """
    Plot the accepted ratio for each group as a bar plot and save it.

    Bars are placed at x = 1..n and labeled "Group 1".."Group n". Each bar
    takes its color from the active style's color cycle.

    Args:
        accepted_ratio (list[float]): The accepted ratio for each group,
            plotted on a fixed [0, 1] y-axis.
        title (str): The title of the plot; also used as the output file name.
        path (str): The directory to save the plot in.

    Returns:
        None
    """
    n_groups = len(accepted_ratio)
    x_values = np.arange(1, n_groups + 1)
    x_labels = [f"Group {i + 1}" for i in range(n_groups)]
    plt.bar(
        x_values,
        accepted_ratio,
        # One color per bar from the style's cycle, so groups match the
        # line colors used by the other plots.
        color=plt.rcParams["axes.prop_cycle"].by_key()["color"],
    )
    plt.xlabel("Group")
    plt.ylabel("Accepted Ratio")
    plt.ylim([0, 1])
    plt.title(title.replace("_", " ").title())
    plt.xticks(x_values, x_labels)
    _write(f"{path}/{title}")
