from typing import List

from matplotlib import pyplot as plt
import pandas as pd

import os

# Module-wide Matplotlib defaults, applied as a side effect of importing this
# file: sans-serif font family, base font size 14, axes titles 16 and axes
# labels 14.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.size': 14,
    'axes.titlesize': 16,
    'axes.labelsize': 14,
})

def plot_data(
    dfs: List[pd.DataFrame],
    labels: List[str],
    title: str,
    x_label: str,
    y_label: str,
    column: str,
    fig_size: tuple[int, int] = (10, 5),
) -> "plt.Figure":
    """Draw one line per DataFrame on a single set of axes.

    Each DataFrame's ``column`` is plotted against its ``'Epoch'`` column.

    Args:
        dfs: DataFrames to plot; each must contain ``'Epoch'`` and ``column``.
        labels: Legend entry for each DataFrame, paired positionally with
            ``dfs``. Pairing uses ``zip``, so surplus entries on either side
            are ignored.
        title: Axes title.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        column: Name of the column plotted on the y-axis.
        fig_size: Figure size passed to ``plt.subplots`` as ``figsize``.

    Returns:
        The newly created figure.
    """
    fig, ax = plt.subplots(figsize=fig_size)
    for df, label in zip(dfs, labels):
        ax.plot(df['Epoch'], df[column], label=label)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()

    return fig

def plot_data_ncol(
    dfs: List[pd.DataFrame],
    labels: List[str],
    titles: List[str],
    x_labels: List[str],
    y_labels: List[str],
    columns: List[str],
    fig_size: tuple[int, int] = (10, 5),
    ylim: tuple[int, int] | None = None,
    xlim: tuple[int, int] | None = None,
    xticks: int = 5,
    centered_legened: bool = True,
) -> "plt.Figure":
    """Plot several columns side by side, one subplot per column.

    A column whose name contains ``"time"`` (case-insensitive) is drawn as a
    bar chart with one bar per DataFrame showing that column's mean. Every
    other column is drawn as one line per DataFrame against ``'Epoch'``.

    Args:
        dfs: DataFrames to plot.
        labels: Name of each DataFrame, paired positionally with ``dfs``
            (``zip`` semantics). Used as bar categories and legend entries.
        titles: Title of each subplot; its length sets the number of subplots.
        x_labels: x-axis label of each subplot.
        y_labels: y-axis label of each subplot. When all entries are equal the
            label is only written on the first subplot.
        columns: Column plotted in each subplot.
        fig_size: Figure size passed to ``plt.subplots`` as ``figsize``.
        ylim: If truthy, applied to every subplot with ``set_ylim``.
        xlim: If truthy, applied to every subplot with ``set_xlim``.
        xticks: Step between x ticks on line subplots.
        centered_legened: If true, place one figure-level legend at the bottom
            centre; otherwise place the legend inside the subplot that carries
            the labelled lines.

    Returns:
        The newly created figure.
    """
    # squeeze=False always returns a 2-D array of axes, so indexing also works
    # when only one subplot is requested.
    fig, axes_grid = plt.subplots(
        ncols=len(titles), figsize=fig_size, squeeze=False
    )
    axes = axes_grid[0]

    # Repeating an identical y-label on every subplot is redundant.
    distinct_y_labels = len(set(y_labels)) > 1

    # Lines are labelled on the first line subplot only, so the figure-level
    # legend lists every DataFrame exactly once.
    legend_ax = None

    for i, (title, x_label, y_label, column) in enumerate(
        zip(titles, x_labels, y_labels, columns)
    ):
        axis = axes[i]

        if 'time' in column.lower():
            for df, label in zip(dfs, labels):
                axis.bar(label, df[column].mean(), width=0.5)
            # Rotate the category labels without overwriting their text.
            axis.tick_params(axis='x', labelrotation=-90)
        else:
            if legend_ax is None:
                legend_ax = axis
            line_kwargs_labelled = axis is legend_ax
            for df, label in zip(dfs, labels):
                if line_kwargs_labelled:
                    axis.plot(df['Epoch'], df[column], label=label)
                else:
                    axis.plot(df['Epoch'], df[column])
                # int() keeps range() working when epochs are stored as floats.
                last_epoch = int(df['Epoch'].iloc[-1])
                axis.set_xticks(range(0, last_epoch + 2, xticks))

        axis.set_xlabel(x_label)
        if distinct_y_labels or i == 0:
            axis.set_ylabel(y_label)
        axis.set_title(title)

        if ylim:
            axis.set_ylim(ylim)
        if xlim:
            axis.set_xlim(xlim)

    if centered_legened:
        fig.legend(loc='lower center', ncols=5, bbox_to_anchor=(0.5, -0.1))
    else:
        target_ax = legend_ax if legend_ax is not None else axes[0]
        target_ax.legend(loc='lower right')
    fig.tight_layout()

    return fig

def plot_numerical_method_to_spike_rate(
    families: list,
    labels: list,
    title: str,
    x_label: str,
    y_label: str,
    fig_size: tuple[int, int] = (10, 5),
    ylim: tuple[int, int] | None = None,
    xlim: tuple[int, int] | None = None,
) -> "plt.Figure":
    """Plot the final validation spike rate of each family member.

    Each family is a sequence of DataFrames. For every DataFrame the last
    value of its ``'Validation Spike-Rate'`` column is taken, and the values
    are plotted against the categories ``"Order 1"``, ``"Order 2"``, ... in
    the order the DataFrames appear in the family.

    Args:
        families: Sequences of DataFrames, one sequence per plotted line.
        labels: Legend entry for each family, paired positionally with
            ``families`` (``zip`` semantics).
        title: Axes title.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        fig_size: Figure size passed to ``plt.subplots`` as ``figsize``.
        ylim: If truthy, applied with ``set_ylim``.
        xlim: If truthy, applied with ``set_xlim``.

    Returns:
        The newly created figure, returned so callers can save or further
        customise it, in line with the other plotting helpers in this module.
    """
    fig, ax = plt.subplots(1, 1, figsize=fig_size)

    for label, family in zip(labels, families):
        x = [f"Order {i+1}" for i in range(len(family))]
        y = [df['Validation Spike-Rate'].iloc[-1] for df in family]
        ax.plot(x, y, label=label)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)

    if ylim:
        ax.set_ylim(ylim)
    if xlim:
        ax.set_xlim(xlim)

    fig.legend(loc='lower center', ncols=5, bbox_to_anchor=(0.5, -0.125))
    fig.tight_layout()

    return fig

def read_dataframes(dir: str) -> List[pd.DataFrame]:
    """Load one CSV per sub-directory of ``dir``.

    Every sub-directory ``<dir>/<name>`` is expected to hold a file called
    ``<name>.csv``. Entries of ``dir`` that are not directories are skipped.

    Args:
        dir: Path of the directory whose sub-directories are scanned.

    Returns:
        The loaded DataFrames, ordered by sub-directory name.
    """
    frames: List[pd.DataFrame] = []
    for entry in sorted(os.listdir(dir)):
        subdir = os.path.join(dir, entry)
        if not os.path.isdir(subdir):
            # Plain files next to the run folders are ignored.
            continue
        csv_path = os.path.join(subdir, f'{entry}.csv')
        frames.append(pd.read_csv(csv_path))
    return frames