import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from xmeta.utils.evaluation import accuracy


def plot_eigenvalues(mat, ax=None,
                     min_eigen_ratio: float = 1.0e-9,
                     ):
    """Plot the eigenvalues of ``mat`` in descending order.

    The real parts of the eigenvalues are sorted from largest to smallest and
    drawn as a line, together with a dotted horizontal line at zero. The
    function also prints how many eigenvalues fall below
    ``min_eigen_ratio`` times the mean of the positive eigenvalues.

    Args:
        mat: Square matrix whose eigenvalues are plotted.
        ax: Matplotlib axes to draw into. When ``None``, a new figure with a
            single subplot is created.
        min_eigen_ratio: Threshold, relative to the mean positive eigenvalue,
            below which an eigenvalue is counted as "below threshold".

    Returns:
        The axes the eigenvalues were drawn into.
    """
    # Only eigenvalues are needed, so skip computing eigenvectors. Imaginary
    # parts (possible for non-symmetric input) are discarded.
    eigen_values = np.sort(np.linalg.eigvals(mat).real)[::-1]

    positive = eigen_values[eigen_values > 0]
    if positive.size:
        min_positive_ev = positive.mean() * min_eigen_ratio
        n_small_ev = int((eigen_values < min_positive_ev).sum())
    else:
        # No positive eigenvalues: the mean would be NaN (and every comparison
        # with NaN is False), but every eigenvalue is below the threshold.
        n_small_ev = len(eigen_values)
    print(f'{n_small_ev} eigenvalues are below threshold')

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
    ax.plot(eigen_values)
    ax.hlines(y=0, xmin=0, xmax=len(eigen_values), linestyles='dotted')
    return ax


def show_task(task):
    """Display every image of a task in a grid with one row per class.

    Each image is placed in the row of its class and titled with the class
    id. The grid has as many columns as the largest class; unused cells of
    smaller classes are hidden.

    Args:
        task: Pair ``(images, labels)``. Each element of ``images`` must
            support ``.numpy()``. The transpose below assumes channel-first
            ``(C, H, W)`` images with pixel values in ``[0, 255]`` — TODO
            confirm. ``labels`` holds one integer class id per image (read
            via ``.item()``), presumably in ``0..n_way-1``.

    Raises:
        ValueError: If the task contains no labels.
    """
    images, labels = task[0], task[1]
    class_ids = [int(label.item()) for label in labels]
    if not class_ids:
        raise ValueError('task contains no images')
    n_way = max(class_ids) + 1

    # Group images by class; each row gets its own list.
    plt_data = [[] for _ in range(n_way)]
    for img, cls in zip(images, class_ids):
        npimg = img.numpy().astype(np.uint8).transpose((1, 2, 0))
        plt_data[cls].append(npimg)

    # Size the grid from the largest class so unbalanced tasks still fit, and
    # keep ``axes`` 2-D even when there is only one row or one column.
    n_col = max(len(cls_data) for cls_data in plt_data)
    _, axes = plt.subplots(n_way, n_col, squeeze=False)
    for cls, cls_data in enumerate(plt_data):
        title = str(cls)
        for ii in range(n_col):
            ax = axes[cls, ii]
            if ii >= len(cls_data):
                ax.axis('off')  # this class has fewer images than n_col
                continue
            ax.imshow(cls_data[ii])
            ax.set_xticks([], [])
            ax.set_yticks([], [])
            ax.set_title(title, fontsize=6)

    plt.tight_layout()


def show_predictions(task, predictions):
    """Display task images per class, titled with their label and prediction.

    First prints the value returned by ``accuracy(predictions, task[1])``.
    It then draws one grid row per class. Each cell is titled
    ``"<label>, <prediction>"``, in blue when the prediction matches the
    label and in red otherwise. The grid has as many columns as the largest
    class; unused cells are hidden.

    Args:
        task: Pair ``(images, labels)``. Each element of ``images`` must
            support ``.numpy()``. The transpose below assumes channel-first
            ``(C, H, W)`` images with pixel values in ``[0, 255]`` — TODO
            confirm. ``labels`` holds one integer class id per image (read
            via ``.item()``), presumably in ``0..n_way-1``.
        predictions: Per-image class scores. The predicted class is taken
            with ``argmax(dim=1)`` and reshaped to ``labels.shape``, so this
            is presumably a torch tensor of shape ``(n_img, n_classes)`` —
            TODO confirm.

    Raises:
        ValueError: If the task contains no labels.
    """
    acc = accuracy(predictions, task[1])
    print(f'accuracy: {acc}')

    images, labels = task[0], task[1]
    class_ids = [int(label.item()) for label in labels]
    if not class_ids:
        raise ValueError('task contains no images')
    n_way = max(class_ids) + 1

    predicted = predictions.argmax(dim=1).view(labels.shape)
    pred_ids = [int(pred.item()) for pred in predicted]

    # Group (image, prediction) pairs by true class; each row gets its own list.
    plt_data = [[] for _ in range(n_way)]
    for img, cls, pred in zip(images, class_ids, pred_ids):
        npimg = img.numpy().astype(np.uint8).transpose((1, 2, 0))
        plt_data[cls].append((npimg, pred))

    # Size the grid from the largest class so unbalanced tasks still fit, and
    # keep ``axes`` 2-D even when there is only one row or one column.
    n_col = max(len(cls_data) for cls_data in plt_data)
    _, axes = plt.subplots(n_way, n_col, squeeze=False)
    for cls, cls_data in enumerate(plt_data):
        for ii in range(n_col):
            ax = axes[cls, ii]
            if ii >= len(cls_data):
                ax.axis('off')  # this class has fewer images than n_col
                continue
            img, pred = cls_data[ii]
            color = 'blue' if cls == pred else 'red'
            ax.imshow(img)
            ax.set_xticks([], [])
            ax.set_yticks([], [])
            ax.set_title(f'{cls}, {pred}', fontsize=6, color=color)

    plt.tight_layout()


def plot_task_scores(df: pd.DataFrame, n_col: int = 7,
                     xlim=(0, 99), ylim=(-15, 15)):
    """Plot per-training-task scores for each test task in a subplot grid.

    One subplot is drawn per row of ``df``. It shows the row's
    ``train_task_score`` sequence against its index, with thin guide lines at
    y=0 and at the midpoint of ``xlim``. Each subplot is titled with the row's
    ``test_accuracy``. Grid cells left over after the last row are hidden.

    Args:
        df: One row per test task, with columns ``train_task_score`` (a
            sequence of scores) and ``test_accuracy``. Column names must be
            valid identifiers because rows are read via ``itertuples``.
        n_col: Number of subplot columns.
        xlim: x-axis limits shared by all subplots.
        ylim: y-axis limits shared by all subplots.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    n_test_task = len(df)
    if n_test_task == 0:
        raise ValueError('df contains no test tasks')
    n_row = int(np.ceil(n_test_task / n_col))

    # squeeze=False keeps ``axes`` 2-D even for a single row or column.
    _, axes = plt.subplots(n_row, n_col, squeeze=False)
    for ii, row in enumerate(df.itertuples()):
        r, c = divmod(ii, n_col)
        ax = axes[r, c]
        scores = row.train_task_score
        # Derive x from the plotted sequence so x and y always match in length.
        ax.plot(range(len(scores)), scores)
        ax.set_ylim(ylim)
        ax.set_xlim(xlim)
        ax.hlines(0, *xlim, linewidth=0.5)
        ax.vlines(np.mean(xlim), *ylim, linewidth=0.5)
        ax.set_title(str(row.test_accuracy), fontsize=6)

    # Hide the unused cells when n_test_task is not a multiple of n_col.
    for jj in range(n_test_task, n_row * n_col):
        axes.flat[jj].axis('off')
    plt.tight_layout()




 