import matplotlib.pyplot as plt
import torch
import itertools
import numpy as np
from matplotlib.image import NonUniformImage

from chip.utils.sinogram import Sinogram


def plot_radon_transform(radon_transform, fig=None, axis=None):
    """
    Plot the radon transform of an image.

    The array is drawn transposed, so the first dimension (angles) runs along
    the horizontal axis and the second dimension (pixels) along the vertical
    axis. The horizontal extent is fixed to ``(-dx, 180 + dx)``.

    :param radon_transform: 2D array of shape (n_angles, n_pixels)
    :param fig: figure to draw into. When omitted, the figure owning ``axis``
        is used, or a new 10x10 figure is created if ``axis`` is omitted too.
    :param axis: axes to draw into. When omitted, a single subplot is added
        to ``fig``.
    :return: tuple ``(fig, axis)`` of the figure and the axes drawn on
    """
    if axis is None:
        if fig is None:
            fig = plt.figure(figsize=(10, 10))
        axis = fig.add_subplot(111)
    elif fig is None:
        # Reuse the figure that owns the supplied axes rather than opening
        # (and returning) an unrelated empty figure.
        fig = axis.figure

    n_angles, n_pixels = radon_transform.shape[0], radon_transform.shape[1]

    # dx pads the horizontal (angle) axis by half an angular step, so it has
    # to be derived from the number of angles, not the number of pixels.
    dx = 0.5 * 180.0 / n_angles
    # Vertical padding, kept exactly as before.
    dy = 0.5 / n_pixels

    axis.imshow(radon_transform.T, cmap=plt.cm.Greys_r,
                extent=(-dx, 180.0 + dx, -dy, n_pixels + dy), aspect='auto', interpolation='none')

    return fig, axis


def plot_reconstruction(tr, target, prior, threshold=0.2):
    """
    Plot an iterative reconstruction next to its target and prior.

    Five panels are drawn on a 3x2 grid: target, reconstruction, absolute
    difference between reconstruction and target, prior, and a binary map of
    "uncertain" pixels. The sixth panel is left blank.

    :param tr: object exposing ``get_img(circle_crop=True)``, which must return
        a tensor; it is detached and moved to the CPU before plotting
    :param target: reference image; it is subtracted from the CPU prediction,
        so it has to be compatible with it (presumably a CPU tensor of the
        same shape -- confirm against callers)
    :param prior: prior image, shown as is
    :param threshold: a pixel is marked uncertain when its predicted value lies
        strictly between ``threshold`` and ``1 - threshold``. This presumably
        expects predictions in the range [0, 1] -- confirm against callers.
    :return: tuple ``(fig, axes)`` where ``axes`` is the 3x2 array of axes
    """
    with torch.no_grad():
        prediction = tr.get_img(circle_crop=True).detach().to('cpu')
        fig, axes = plt.subplots(3, 2, figsize=(10, 15))
        ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = axes
        ax1.set_title("Target")
        ax1.imshow(target, cmap="gray")
        ax2.set_title("Iterative Reconstruction")
        ax2.imshow(prediction, cmap="gray")
        ax3.set_title("Difference")
        ax3.imshow(torch.abs(prediction - target), cmap="gray")
        ax4.set_title("Prior")
        ax4.imshow(prior, cmap='gray')

        # Show the most uncertain pixels: values that are neither close to 0
        # nor close to 1.
        uncertainty = torch.logical_and(prediction > threshold, prediction < 1 - threshold)
        ax5.set_title("Uncertainty map")
        ax5.imshow(uncertainty, cmap='gray')

        # Nothing is drawn in the last panel; hide its empty frame and ticks.
        ax6.set_axis_off()

    return fig, axes

def plot_comparison(predicted, target, prior):
    """
    Draw prior, target, prediction and their absolute error on a 2x2 grid.

    :param predicted: predicted image
    :param target: reference image; ``predicted - target`` must be a valid
        tensor expression because the difference goes through ``torch.abs``
    :param prior: prior image
    :return: tuple ``(fig, axes)`` where ``axes`` is the 2x2 array of axes
    """
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    prior_ax, target_ax, predicted_ax, diff_ax = axes.ravel()

    # The first three panels show their input unchanged.
    panels = (
        (prior_ax, "Prior", prior),
        (target_ax, "Target", target),
        (predicted_ax, "Predicted", predicted),
    )
    for panel, title, image in panels:
        panel.set_title(title)
        panel.imshow(image, cmap="gray")

    # The last panel shows the element-wise absolute error.
    diff_ax.set_title("Difference")
    abs_error = torch.abs(predicted - target)
    diff_ax.imshow(abs_error, cmap="gray")

    fig.subplots_adjust()
    fig.tight_layout()
    return fig, axes


def plot_sinogram(sinogram, fig=None, ax=None):
    """
    Plot a sinogram as a gray-scale image.

    :param sinogram: either a ``Sinogram``, whose ``angles`` are used as the
        x coordinates and whose ``sinogram`` tensor holds the data, or a plain
        2D tensor, for which the row index is used as the x coordinate. In
        both cases the second dimension is plotted along the y axis.
    :param fig: figure to draw into. When omitted, the figure owning ``ax`` is
        used, or a new 8x6 figure is created if ``ax`` is omitted too.
    :param ax: axes to draw into. When omitted, a single subplot is added to
        ``fig``.
    :return: tuple ``(fig, ax)`` of the figure and the axes drawn on
    """
    # Honour whichever of fig / ax the caller supplied instead of discarding
    # it as soon as the other one is missing.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        else:
            ax = fig.add_subplot(111)
    elif fig is None:
        fig = ax.figure

    y = np.arange(0, sinogram.shape[1])
    if isinstance(sinogram, Sinogram):
        x = sinogram.angles.detach().cpu().numpy()
        A = sinogram.sinogram.detach().cpu().numpy().T
    else:
        x = np.arange(0, sinogram.shape[0])
        A = sinogram.detach().cpu().numpy().T

    # NonUniform Grid, linear x-axis
    # NOTE(review): the x range is fixed to 0..180, which presumably expects
    # the x coordinates to be angles in degrees -- confirm for both branches.
    im = NonUniformImage(ax, interpolation='nearest', extent=(0, 180, 0, len(y)), cmap='gray')
    im.set_data(x, y, A)
    ax.add_image(im)
    ax.set_xlim(0, 180)
    ax.set_ylim(0, len(y))
    ax.set_title('Sinogram')
    return fig, ax

def plot_tomogram_overview(tomogram, rows=5, cols=4):
    """
    Plot an overview grid of evenly spaced layers of a tomogram.

    Up to ``rows * cols`` layer indices are picked evenly between the first
    and the last layer. Each picked layer is drawn once; when the tomogram has
    fewer layers than the grid has panels, the remaining panels are left blank.

    :param tomogram: 3D array whose first dimension indexes the layers
    :param rows: number of rows of the panel grid
    :param cols: number of columns of the panel grid
    :return: tuple ``(fig, axes)`` where ``axes`` is the ``rows`` x ``cols``
        array of axes
    """
    # The unpacking also enforces that the tomogram is three-dimensional.
    num_layers, _, _ = tomogram.shape

    # With fewer layers than panels, the integer linspace repeats indices;
    # np.unique keeps each layer once (and keeps them in ascending order).
    layers = np.unique(np.linspace(0, num_layers - 1, rows * cols, dtype=int))
    fig, axes = plt.subplots(rows, cols, figsize=(2. * cols, 2. * rows), squeeze=False)

    panels = axes.ravel()
    for layer, ax in zip(layers, panels):
        ax.imshow(tomogram[layer], cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"Layer {layer}")

    # Hide the frames of panels that did not receive a layer.
    for ax in panels[len(layers):]:
        ax.set_axis_off()

    fig.subplots_adjust()
    fig.tight_layout()
    return fig, axes