import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy import interpolate


def scipy_to_torch_sparse(A, device):
    """
    Build a torch sparse tensor in CSR layout from a scipy sparse matrix.

    Args:
        A: A scipy sparse matrix in any format; non-CSR inputs are
            converted with ``tocsr()`` first.
        device: Device on which the resulting tensor is created.

    Returns:
        A ``torch.sparse_csr_tensor`` with int64 row-pointer / column
        index arrays and float32 values, of the same shape as ``A``.
    """
    csr = A if A.format == "csr" else A.tocsr()

    crow_indices = torch.as_tensor(csr.indptr, dtype=torch.int64)
    col_indices = torch.as_tensor(csr.indices, dtype=torch.int64)
    nonzeros = torch.as_tensor(csr.data, dtype=torch.float32)
    return torch.sparse_csr_tensor(
        crow_indices,
        col_indices,
        nonzeros,
        torch.Size(csr.shape),
        device=device,
    )


def decompress_time_data(data, space_dim, output_dim, time_range):
    """
    Transform the time-dependent data to a point cloud format.

    From [x1, x2, ..., xd, y(t1), y(t2), ..., y(tm)] to
    [x1, x2, ..., xd, t1, y(t1)],
    [x1, x2, ..., xd, t2, y(t2)],
    ...,
    [x1, x2, ..., xd, tm, y(tm)].

    When ``output_dim > 1`` the value columns of each input row are
    expected to be interleaved time-major, i.e.
    ``y_1(t1), ..., y_k(t1), y_1(t2), ..., y_k(t2), ...``.

    Args:
        data: The time-dependent data, a 2D Numpy array with one row per
            spatial point and ``space_dim + m * output_dim`` columns.
        space_dim: The spatial dimension (i.e., d). May be 0 for pure
            time series without spatial coordinates.
        output_dim: The output dimension (output_dim > 1
            means that y is a vector-valued function). Must be >= 1.
        time_range: The time range (i.e., [t1, tm]); the m time stamps
            are evenly spaced and include both endpoints.

    Returns:
        Point cloud format data (a Numpy array) of shape
        ``(n_points * m, space_dim + 1 + output_dim)``. The m time steps
        of each input point are contiguous rows.

    Raises:
        ValueError: If ``output_dim`` is not positive, or the number of
            value columns is not a multiple of ``output_dim``.
    """
    if output_dim < 1:
        raise ValueError("output_dim must be a positive integer")

    n_points = data.shape[0]
    n_value_cols = data.shape[1] - space_dim
    t_len = n_value_cols // output_dim

    if t_len * output_dim != n_value_cols:
        raise ValueError("Data shape is not multiple of output_dim")

    t = np.linspace(time_range[0], time_range[1], t_len)

    # Every spatial coordinate is repeated once per time step (point-major
    # order), while the time stamps cycle once per point. Building these
    # with repeat/tile also works when space_dim == 0.
    columns = [np.repeat(data[:, i], t_len) for i in range(space_dim)]
    columns.append(np.tile(t, n_points))

    # Output component i lives in every output_dim-th value column;
    # flattening the (n_points, t_len) slice row-major matches the order
    # of the spatial/time columns above.
    for i in range(output_dim):
        columns.append(data[:, space_dim + i::output_dim].reshape(-1))
    return np.stack(columns).T


def plot_heatmap(
    x, y, z, path, vmin=None, vmax=None,
    title="", xlabel="x", ylabel="y"
):
    '''
    Plot a heat map for the point cloud of a
    2-dimensional function z = f(x, y) and save it to ``path``.

    The scattered samples are resampled onto a regular 50 x 50 grid
    spanning the bounding box of (x, y) using cubic interpolation.
    Grid cells where cubic interpolation yields NaN (outside the convex
    hull of the samples) are filled with nearest-neighbour values.

    Args:
        x, y, z: 1D Numpy arrays of equal length (sample coordinates and
            sampled values).
        path: Destination passed to ``Figure.savefig``.
        vmin, vmax: Optional color-scale limits; either may be given
            independently (None lets matplotlib choose).
        title, xlabel, ylabel: Figure annotations.
    '''
    x_min, x_max = np.min(x), np.max(x)
    y_min, y_max = np.min(y), np.max(y)

    # Regular evaluation grid (np.linspace's default of 50 samples per
    # axis). Rows are flipped so row 0 holds y_max, which matches imshow's
    # default origin="upper" combined with the extent used below.
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_min, x_max), np.linspace(y_min, y_max))
    grid_y = grid_y[::-1, :]

    points = np.array([x, y]).T
    values = np.array(z)
    vals = interpolate.griddata(
        points, values, (grid_x, grid_y), method='cubic')
    missing = np.isnan(vals)
    if missing.any():
        vals_0 = interpolate.griddata(
            points, values, (grid_x, grid_y), method='nearest')
        vals[missing] = vals_0[missing]

    # Use a dedicated Figure/Axes instead of pyplot's implicit "current"
    # state so the caller's figures are never cleared or leaked.
    fig, ax = plt.subplots()
    try:
        image = ax.imshow(
            vals,
            extent=[x_min, x_max, y_min, y_max],
            aspect="equal", interpolation="bicubic",
            vmin=vmin, vmax=vmax,
        )
        ax.set_autoscale_on(False)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.colorbar(image, ax=ax)
        fig.savefig(path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
