import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
import seaborn as sns 
from sklearn.metrics import confusion_matrix

def plot_tensorboard_multi_line(wavs, labels, title=None):
    """Plot several series on shared axes and convert the figure to an image.

    Args:
        wavs: Series to plot, laid out as [n_wavs, n_timepoints].
        labels: Legend labels, one string per entry of ``wavs`` (indexed by
            position).
        title: Optional plot title; skipped when falsy.

    Returns:
        The image array produced by ``plot_to_tensorboard`` for the figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        for i, wav in enumerate(wavs):
            ax.plot(wav, label=labels[i], alpha=0.6)
        ax.set_xlabel("time")
        ax.set_ylabel("value")
        ax.legend()
        if title:
            ax.set_title(title)
        fig.tight_layout()

        # plot_to_tensorboard draws the canvas itself, so no extra draw() here.
        return plot_to_tensorboard(fig)
    finally:
        # Close exactly this figure. A bare plt.close() acts on pyplot's
        # *current* figure, which is some unrelated figure once `fig` has
        # already been closed by plot_to_tensorboard.
        plt.close(fig)

def plot_tensorboard_line(wav, title=None):
    """Plot a single series as a line and convert the figure to an image.

    Args:
        wav: Values handed directly to ``Axes.plot``.
        title: Optional plot title; skipped when falsy.

    Returns:
        The image array produced by ``plot_to_tensorboard`` for the figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        ax.plot(wav)
        ax.set_xlabel("time")
        ax.set_ylabel("voltage")
        if title:
            ax.set_title(title)
        fig.tight_layout()

        # plot_to_tensorboard draws the canvas itself, so no extra draw() here.
        return plot_to_tensorboard(fig)
    finally:
        # Close exactly this figure. A bare plt.close() acts on pyplot's
        # *current* figure, which is some unrelated figure once `fig` has
        # already been closed by plot_to_tensorboard.
        plt.close(fig)

def plot_tensorboard_spectrogram(spec):
    """Render a 2-D tensor as a heatmap image.

    Args:
        spec: 2-D tensor exposing ``transpose``, ``detach`` and ``cpu``
            (presumably a ``torch.Tensor`` -- confirm with callers). Axis 0
            ends up on the x axis ("Frames") and axis 1 on the y axis
            ("Channels").

    Returns:
        The image array produced by ``plot_to_tensorboard`` for the figure.
    """
    # imshow puts rows on the y axis, so swap the axes first, then drop the
    # autograd graph and move the data to host memory for Matplotlib.
    spec = spec.transpose(1, 0).detach().cpu()

    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(spec, aspect="auto", origin="lower",
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        fig.tight_layout()

        # plot_to_tensorboard draws the canvas itself, so no extra draw() here.
        return plot_to_tensorboard(fig)
    finally:
        # Close exactly this figure. A bare plt.close() acts on pyplot's
        # *current* figure, which is some unrelated figure once `fig` has
        # already been closed by plot_to_tensorboard.
        plt.close(fig)

def plot_tensorboard_matrix(matrix, vmin=-4, vmax=4):
    """Render a 2-D tensor as a diverging ("bwr") heatmap image.

    Args:
        matrix: (columns, rows) shaped tensor exposing ``transpose``,
            ``detach`` and ``cpu``. Axis 0 ends up on the x axis
            ("Timesteps") and axis 1 on the y axis ("Emb dim").
        vmin: Lower bound of the colour scale (defaults to the previously
            hard-coded -4).
        vmax: Upper bound of the colour scale (defaults to the previously
            hard-coded 4).

    Returns:
        The image array produced by ``plot_to_tensorboard`` for the figure.
    """
    # imshow puts rows on the y axis, so swap the axes first, then drop the
    # autograd graph and move the data to host memory for Matplotlib.
    matrix = matrix.transpose(1, 0).detach().cpu()

    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(matrix, aspect="auto", origin="lower",
                       interpolation='none', cmap="bwr",
                       vmin=vmin, vmax=vmax)
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Timesteps")
        ax.set_ylabel("Emb dim")
        fig.tight_layout()

        # plot_to_tensorboard draws the canvas itself, so no extra draw() here.
        return plot_to_tensorboard(fig)
    finally:
        # Close exactly this figure. A bare plt.close() acts on pyplot's
        # *current* figure, which is some unrelated figure once `fig` has
        # already been closed by plot_to_tensorboard.
        plt.close(fig)


def plot_tensorboard_cm(cm, n_classes, vmin=0, vmax=256, title=None):
    """Render a confusion matrix as an annotated heatmap image.

    Args:
        cm: Square matrix of counts with ``n_classes`` rows and columns.
            Cells are annotated with the integer format ``"d"``, so the
            values must be integers.
        n_classes: Number of classes; rows and columns are labelled
            ``0 .. n_classes - 1``.
        vmin: Lower bound of the colour scale.
        vmax: Upper bound of the colour scale.
        title: Optional plot title; skipped when ``None``.

    Returns:
        The image array produced by ``plot_to_tensorboard`` for the figure.
    """
    class_ids = range(n_classes)
    df_cm = pd.DataFrame(cm, index=class_ids, columns=class_ids)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        sns.heatmap(df_cm, ax=ax, vmin=vmin, vmax=vmax, annot=True, fmt="d")
        ax.set_xlabel("pred")
        ax.set_ylabel("true")
        if title is not None:
            ax.set_title(title)

        # plot_to_tensorboard draws the canvas itself, so no extra draw() here.
        return plot_to_tensorboard(fig)
    finally:
        # Close exactly this figure. A bare plt.close() acts on pyplot's
        # *current* figure, which is some unrelated figure once `fig` has
        # already been closed by plot_to_tensorboard.
        plt.close(fig)

def plot_tensorboard_loss_sample(loss_sample, title=None):
    """Plot per-sample losses in original and ascending order as an image.

    Args:
        loss_sample: Iterable of loss values; drawn once as given and once
            after ``sorted``.
        title: Optional plot title; skipped when ``None``.

    Returns:
        The image array produced by ``plot_to_tensorboard`` for the figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        ax.plot(loss_sample, label="original")
        ax.plot(sorted(loss_sample), alpha=0.5, label="sorted")
        ax.set_xlabel("sample")
        ax.set_ylabel("loss")
        if title is not None:
            ax.set_title(title)
        ax.legend()

        # plot_to_tensorboard draws the canvas itself, so no extra draw() here.
        return plot_to_tensorboard(fig)
    finally:
        # Close exactly this figure. A bare plt.close() acts on pyplot's
        # *current* figure, which is some unrelated figure once `fig` has
        # already been closed by plot_to_tensorboard.
        plt.close(fig)

def plot_tensorboard_loss_sample_hist(loss_sample, title=None):
    """Plot a 50-bin histogram of per-sample losses as an image.

    Args:
        loss_sample: Loss values handed directly to ``Axes.hist``.
        title: Optional plot title; skipped when ``None``.

    Returns:
        The image array produced by ``plot_to_tensorboard`` for the figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        ax.hist(loss_sample, bins=50)
        ax.set_xlabel("loss")
        ax.set_ylabel("count")
        if title is not None:
            ax.set_title(title)

        # plot_to_tensorboard draws the canvas itself, so no extra draw() here.
        return plot_to_tensorboard(fig)
    finally:
        # Close exactly this figure. A bare plt.close() acts on pyplot's
        # *current* figure, which is some unrelated figure once `fig` has
        # already been closed by plot_to_tensorboard.
        plt.close(fig)

def plot_to_tensorboard(fig):
    """Rasterize a Matplotlib figure into a channels-first image array.

    Adapted from https://martin-mundt.com/tensorboard-figures/

    Args:
        fig: Matplotlib figure whose canvas provides ``buffer_rgba`` (the
            Agg-based canvases do).

    Returns:
        ``numpy.ndarray`` of shape ``(3, height, width)`` holding the RGB
        pixels as floats scaled into the 0-1 range.

    The figure is closed before returning, including when rendering fails.
    """
    try:
        # Render so the canvas buffer reflects the current figure contents.
        fig.canvas.draw()

        # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib
        # releases deprecated and then removed; it also avoids the deprecated
        # binary mode of np.fromstring. The buffer already carries its own
        # (height, width, 4) shape, so no reshape via get_width_height() is
        # needed.
        rgba = np.asarray(fig.canvas.buffer_rgba())

        # Drop alpha and normalize into the 0-1 range for TensorBoard(X).
        img = rgba[..., :3] / 255.0

        # HWC -> CHW: the writer API expects colour channels in the first dim.
        return np.transpose(img, (2, 0, 1))
    finally:
        plt.close(fig)

