import matplotlib.pyplot as plt
from scipy import interpolate
import numpy as np

def decompress_time_data(data, space_dim, output_dim, time_range):
    """
    Transform the time-dependent data to a point cloud format.

    From [x1, x2, ..., xd, y(t1), y(t2), ..., y(tm)] to
    [x1, x2, ..., xd, t1, y(t1)],
    [x1, x2, ..., xd, t2, y(t2)],
    ...,
    [x1, x2, ..., xd, tm, y(tm)].

    For vector-valued outputs (output_dim = k > 1) the trailing columns of
    each row must be interleaved by time step:
    [y_1(t1), ..., y_k(t1), y_1(t2), ..., y_k(t2), ..., y_k(tm)].

    Args:
        data: The time-dependent data, a 2-D array-like of shape
            (n_samples, space_dim + m * output_dim).
        space_dim: The spatial dimension (i.e., d). May be 0 when the data
            carries no spatial coordinates.
        output_dim: The output dimension (output_dim > 1
            means that y is a vector-valued function).
        time_range: The time range (i.e., [t1, tm]). The m time stamps are
            evenly spaced over it, both endpoints included.

    Returns:
        Point cloud format data (a Numpy array) of shape
        (n_samples * m, space_dim + 1 + output_dim). Rows are grouped by
        sample: all m time stamps of sample 0 first, then sample 1, etc.

    Raises:
        ValueError: If the number of non-spatial columns is not a multiple
            of output_dim.
    """
    data = np.asarray(data)
    n_samples = data.shape[0]
    n_values = data.shape[1] - space_dim

    t_len = n_values // output_dim
    if t_len * output_dim != n_values:
        raise ValueError("Data shape is not multiple of output_dim")

    t = np.linspace(time_range[0], time_range[1], t_len)

    # Each sample's spatial coordinates are repeated once per time stamp,
    # while the time axis is tiled once per sample. This gives row order
    # (sample 0, t1..tm), (sample 1, t1..tm), ...
    # Slicing with :space_dim also handles space_dim == 0, which yields
    # no spatial columns.
    x_part = np.repeat(data[:, :space_dim], t_len, axis=0)
    t_part = np.tile(t, n_samples)
    # The trailing columns are interleaved by time step, so a row-major
    # reshape yields one row of output_dim values per (sample, time) pair.
    y_part = data[:, space_dim:].reshape(-1, output_dim)

    return np.column_stack([x_part, t_part, y_part])

def plot_lines(data, path, xlabel="", ylabel="", labels=None, xlog=False, ylog=False, title='', sort_=False):
    """
    Plot one or more curves that share a common x series and save the figure.

    Args:
        data: Sequence of series. data[0] holds the x values and each
            data[i] (i >= 1) holds the y values of one curve.
        path: Destination passed to ``savefig``.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        labels: Optional legend labels, one per curve (len(data) - 1
            entries). A legend is drawn only if at least one curve ends up
            with a label that matplotlib would show.
        xlog: Use a logarithmic x axis.
        ylog: Use a logarithmic y axis.
        title: Title of the plot.
        sort_: If True, every curve is reordered by ascending x before
            plotting (useful when data[0] is not monotonic).
    """
    if labels is None:
        labels = ["" for _ in range(len(data) - 1)]

    # Use a dedicated figure. A bare ``plt.cla()`` would clear whatever axes
    # the caller currently has active (or create a stray figure if none).
    fig, ax = plt.subplots()
    try:
        x = data[0]
        order = None
        if sort_:
            # The same permutation applies to every curve, so compute it once.
            x = np.asarray(x)
            order = np.argsort(x)
            x = x[order]

        for i, y in enumerate(data[1:]):
            if order is not None:
                y = np.asarray(y)[order]
            ax.plot(x, y, label=labels[i])

        # Skip the legend when nothing is labelled, to avoid matplotlib's
        # "No artists with labels found" warning and an empty legend box.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        if ylog:
            ax.set_yscale('log')
        if xlog:
            ax.set_xscale('log')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

def plot_heatmap(x, y, z, path=None, vmin=None, vmax=None, num=100, title='', xlabel='x', ylabel='y', show=False, pde=None):
    '''
    Plot a heat map of scattered 3-D data (x, y, z).

    The scattered samples are resampled onto a regular num x num grid that
    spans the bounding box of (x, y). Cubic interpolation is used first.
    Grid cells where it yields NaN (outside the convex hull of the samples,
    per scipy's griddata) are filled with nearest-neighbour values.

    Args:
        x, y: 1-D sequences of sample coordinates.
        z: 1-D sequence of sample values, same length as x and y.
        path: If given, the figure is saved to this path.
        vmin, vmax: Optional colour-scale limits passed to imshow.
        num: Number of grid points along each axis.
        title, xlabel, ylabel: Plot annotations.
        show: If True, display the figure with plt.show().
        pde: Unused; kept for interface compatibility.
    '''
    points = np.array([x, y]).T
    values = np.array(z)
    x_min, x_max = np.min(x), np.max(x)
    y_min, y_max = np.min(y), np.max(y)

    xx, yy = np.meshgrid(np.linspace(x_min, x_max, num), np.linspace(y_min, y_max, num))

    vals = interpolate.griddata(points, values, (xx, yy), method='cubic')
    missing = np.isnan(vals)
    if missing.any():
        # Only pay for the nearest-neighbour pass when there is a gap to fill.
        vals_0 = interpolate.griddata(points, values, (xx, yy), method='nearest')
        vals[missing] = vals_0[missing]

    # imshow draws row 0 at the top, so flip rows to put y_min at the bottom.
    vals = vals[::-1, :]

    # Dedicated figure, created after the (possibly failing) interpolation.
    # A bare plt.cla() would clobber the caller's current axes.
    fig, ax = plt.subplots()
    try:
        image = ax.imshow(vals, extent=[x_min, x_max, y_min, y_max], aspect='auto', interpolation='bicubic', vmin=vmin, vmax=vmax)
        ax.set_autoscale_on(False)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.colorbar(image, ax=ax)
        if path:
            fig.savefig(path)
        if show:
            plt.show()
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)