import warnings

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix

def _to_class_indices(values) -> np.ndarray:
    """
    Convert labels or model outputs to a 1-D integer vector of class indices.

    Accepted inputs:
      * 1-D class ids (int or float; floats are rounded to the nearest id),
      * a column vector of shape (n, 1), treated like the 1-D case,
      * a 2-D array of per-class scores / probabilities / one-hot rows of
        shape (n, k) with k > 1, reduced with ``argmax`` over axis 1.
    """
    values = np.asarray(values)
    if values.ndim == 2 and values.shape[1] == 1:
        # A single-column output would otherwise be argmax-ed to all zeros.
        values = values[:, 0]
    if values.ndim > 1:
        # Take the argmax of the raw scores. Rounding first would turn rows
        # such as [0.3, 0.45, 0.25] into all zeros and always yield class 0.
        return np.argmax(values, axis=1)
    if np.issubdtype(values.dtype, np.floating):
        # issubdtype matches every float width (e.g. float32 model outputs),
        # unlike `dtype == float`, which only matches float64.
        values = np.round(values)
    return values.astype(int)


def plot_confusion_matrix(y_true: np.ndarray,
                          y_pred: np.ndarray,
                          labels: list,
                          title: str = 'Confusion matrix',
                          cmap: str = 'Blues') -> plt.Figure:
    """
    Plot a confusion matrix as an annotated heatmap.

    ``y_true`` and ``y_pred`` may each be 1-D class ids (int or float; floats
    are rounded), a column vector of shape (n, 1), or a 2-D array of
    per-class scores / one-hot rows, which is reduced with argmax.

    :param y_true: True labels
    :param y_pred: Predicted labels, or per-class scores/probabilities
    :param labels: Display names for the classes, in the row/column order of
        the matrix. If they cannot be used as the DataFrame index/columns
        (e.g. wrong length), a warning is issued and integer positions are
        shown instead.
    :param title: Title of the plot
    :param cmap: Color map passed to ``seaborn.heatmap``
    :return: The matplotlib Figure containing the heatmap
    """
    y_true = _to_class_indices(y_true)
    y_pred = _to_class_indices(y_pred)

    cm = confusion_matrix(y_true, y_pred).astype(int)
    try:
        df_cm = pd.DataFrame(cm, index=labels, columns=labels)
    except (ValueError, TypeError) as err:
        # Best effort: still draw the matrix, just without the class names.
        # A typical cause is a class missing from both y_true and y_pred,
        # which makes the matrix smaller than len(labels).
        warnings.warn(
            f"Could not apply labels to confusion matrix ({err}); "
            "falling back to integer indices.",
            stacklevel=2,
        )
        df_cm = pd.DataFrame(cm)

    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(df_cm, annot=True, cmap=cmap, fmt='g', cbar=False, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')

    # --- Scale the annotation font to the cell size ---
    # Measure the axes in display pixels and convert to points
    # (1 pt = 1/72 inch). Data coordinates cannot be used here: every heatmap
    # cell is exactly 1 data unit, and seaborn inverts the y-axis, which makes
    # a data-space height negative and pins the result to the minimum size.
    axes_px = ax.get_window_extent()
    n_rows, n_cols = df_cm.shape
    px_to_pt = 72.0 / fig.dpi
    cell_width_pt = abs(axes_px.width) / n_cols * px_to_pt
    cell_height_pt = abs(axes_px.height) / n_rows * px_to_pt
    # Use the smaller cell dimension, clamped to a readable 8-16 pt range.
    font_size = max(8.0, min(16.0, min(cell_width_pt, cell_height_pt)))

    for text in ax.texts:
        text.set_fontsize(font_size)
        text.set_weight('bold')

    return fig

def plot_continous_confussion_matrix(y_true: np.ndarray,
                                     y_pred: np.ndarray,
                                     labels: list,
                                     title: str = 'Confusion matrix',
                                     cmap: str = 'Blues') -> plt.Figure:
    """
    Plot continuous class predictions against the true class labels.

    Each sample is drawn as a point at (prediction, true class). Dashed
    vertical lines mark the decision boundaries halfway between consecutive
    class ids (0.5, 1.5, ...). A point whose x-position falls outside the
    band of its own row was misclassified when rounding the prediction.

    :param y_true: True class ids, presumably 0 .. len(labels) - 1
    :param y_pred: Continuous predictions; shape (n,) or (n, 1)
    :param labels: Display names for the classes, used as y-axis tick labels.
        If None, three classes are assumed and no tick labels are set.
    :param title: Title of the plot
    :param cmap: Currently unused; accepted for signature parity with
        ``plot_confusion_matrix``.
    :return: The matplotlib Figure containing the scatter plot
    """
    # Flatten (n, 1) model outputs/targets to 1-D. Unlike squeeze(), this
    # also keeps a single-sample (1, 1) input one-dimensional.
    y_pred = np.asarray(y_pred).reshape(-1)
    y_true = np.asarray(y_true).reshape(-1)

    # One tick (and one band between boundaries) per class.
    n_classes = 3 if labels is None else len(labels)

    fig, ax = plt.subplots(figsize=(10, 7))
    sns.scatterplot(x=y_pred, y=y_true, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')

    # Class names on the y-axis, first class at the top.
    ax.set_yticks(range(n_classes))
    if labels is not None:
        ax.set_yticklabels(labels)
    ax.invert_yaxis()

    # Decision boundaries lie halfway between consecutive class ids.
    for boundary in np.arange(n_classes - 1) + 0.5:
        ax.axvline(x=boundary, color='black', linestyle='--')

    return fig
