import torch
import os
import numpy as np
import imageio
import matplotlib.pyplot as plt


def sample_boundary(N, sdim=1, length=1.0, epsilon=1e-4, device="cpu"):
    """sample boundary points within a small range

    Draws points uniformly from the thin bands [-1 - epsilon, -1 + epsilon]
    and [1 - epsilon, 1 + epsilon], then multiplies them by ``length`` so the
    bands sit near -length and +length.

    :param N: total number of points. The left band gets N // 2 points and the
        right band gets the remaining N - N // 2, so exactly N points are
        returned even when N is odd.
    :param sdim: spatial dimension; only 1 is implemented.
    :param length: scale factor applied to the sampled coordinates.
    :param epsilon: half-width of each band before scaling.
    :param device: torch device on which the samples are allocated.
    :return: tensor of shape (N, 1) with the left-band points first.
    :raises NotImplementedError: if sdim is not 1.
    """
    if sdim != 1:
        raise NotImplementedError
    n_left = N // 2
    # For odd N the extra point goes to the right band instead of being dropped.
    n_right = N - n_left
    coords_left = (torch.rand(n_left, 1, device=device) * 2 - 1) * epsilon - 1.
    coords_right = (torch.rand(n_right, 1, device=device) * 2 - 1) * epsilon + 1.
    return torch.cat([coords_left, coords_right], dim=0) * length


def sample_uniform(resolution, sdim=1, length=1.0):
    """Return a regular grid of cell-centred points scaled by ``length``.

    ``sdim == 1`` yields a (resolution, 1) grid and ``sdim == 2`` a
    (resolution, resolution, 2) grid. Any other value raises
    NotImplementedError.
    """
    if sdim not in (1, 2):
        raise NotImplementedError
    grid = sample_uniform_1D(resolution) if sdim == 1 else sample_uniform_2D(resolution)
    return grid * length


def sample_random(N, sdim=1, length=1.0, device="cpu"):
    """Draw N uniformly random points in the normalized domain, scaled by ``length``.

    Dispatches to the 1D or 2D random sampler according to ``sdim`` and
    forwards ``device`` to it. Other dimensions raise NotImplementedError.
    """
    if sdim == 1:
        points = sample_random_1D(N, device=device)
    elif sdim == 2:
        points = sample_random_2D(N, device=device)
    else:
        raise NotImplementedError
    return points * length


def sample_uniform_1D(resolution: int, normalize=True):
    """Cell-centred 1D grid of ``resolution`` points with shape (resolution, 1).

    The raw points are 0.5, 1.5, ..., resolution - 0.5. With ``normalize``
    they are mapped into (-1, 1) via x / resolution * 2 - 1.
    """
    centers = torch.linspace(0.5, resolution - 0.5, resolution)[:, None]
    return centers / resolution * 2 - 1 if normalize else centers


def sample_random_1D(N: int, normalize=True, resolution: int=None, device="cpu"):
    """Draw N uniform random 1D points with shape (N, 1) on ``device``.

    With ``normalize`` the samples lie in [-1, 1). Otherwise the [0, 1)
    samples are multiplied by ``resolution``, which must then be given.
    """
    unit = torch.rand(N, 1, device=device)
    if not normalize:
        return unit * resolution
    return unit * 2 - 1


def sample_uniform_2D(resolution: int, normalize=True):
    """Cell-centred 2D grid with shape (resolution, resolution, 2).

    Entry [i, j] holds (x_i, y_j) ('ij' indexing). Both axes use the same
    centres 0.5, ..., resolution - 0.5. With ``normalize`` these are mapped
    into (-1, 1) via c / resolution * 2 - 1.
    """
    axis = torch.linspace(0.5, resolution - 0.5, resolution)
    grid_x, grid_y = torch.meshgrid(axis, axis, indexing='ij')
    grid = torch.stack((grid_x, grid_y), dim=-1)
    if not normalize:
        return grid
    return grid / resolution * 2 - 1


def sample_random_2D(N: int, normalize=True, resolution: int=None, device="cpu"):
    """sample N uniformly random 2D points

    :param N: number of points.
    :param normalize: if True, samples lie in [-1, 1); otherwise the [0, 1)
        samples are multiplied by ``resolution``.
    :param resolution: scale factor used when ``normalize`` is False.
    :param device: torch device on which the samples are allocated. This
        mirrors ``sample_random_1D``, and ``sample_random`` passes it as a
        keyword for the 2D case.
    :return: tensor of shape (N, 2).
    """
    coords = torch.rand(N, 2, device=device)
    if normalize:
        coords = coords * 2 - 1
    else:
        coords = coords * resolution
    return coords


def draw_scalar_field2D(arr, vmin=None, vmax=None):
    """Render a 2D array as a 3x3-inch ``matshow`` image with a colorbar.

    ``vmin``/``vmax`` are forwarded to ``matshow`` to fix the color range.
    Returns the Figure; the caller is responsible for closing it.
    """
    figure, axes = plt.subplots(figsize=(3, 3))
    image = axes.matshow(arr, vmin=vmin, vmax=vmax)
    # fraction/pad keep the colorbar slim and close to the image.
    figure.colorbar(image, ax=axes, fraction=0.046, pad=0.04)
    figure.tight_layout()
    return figure


def draw_scalar_field1D(x, y, y_max=None, y_gt=None):
    """Plot y against x, optionally overlaid on a faint red ground-truth curve.

    If ``y_max`` is given, the y-axis is clamped to [0, y_max]. The axes use
    an equal aspect ratio. Returns the Figure; the caller closes it.
    """
    figure, axes = plt.subplots()
    # Ground truth is added first so the main curve is drawn on top of it.
    if y_gt is not None:
        axes.plot(x, y_gt, color='red', alpha=0.2)
    axes.plot(x, y)
    if y_max is not None:
        axes.set_ylim(0, y_max)
    axes.set_aspect('equal')
    figure.tight_layout()
    return figure


def draw_vector_field2D(u, v, tag=None, to_array=False):
    """Draw a 2D vector field as a quiver plot.

    :param u: first vector component; arrow (i, j) is anchored at the grid
        index (i, j) produced by ``np.indices(u.shape)``.
    :param v: second vector component; must have the same shape as ``u``.
    :param tag: optional text placed at data coordinates (-1, -1).
    :param to_array: if True, rasterize via ``figure2array`` and return the
        image array instead of the Figure.
    :return: the Figure, or the array produced by ``figure2array``.
    :raises ValueError: if ``u`` and ``v`` have different shapes.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if u.shape != v.shape:
        raise ValueError(
            f"u and v must have the same shape, got {u.shape} and {v.shape}")
    indices = np.indices(u.shape)
    fig, ax = plt.subplots(figsize=(3, 3))
    # With scale_units='width', scale=u.shape[0] makes a unit-length vector
    # span 1/u.shape[0] of the axes width, i.e. roughly one grid cell.
    ax.quiver(indices[0], indices[1], u, v, scale=u.shape[0], scale_units='width')
    if tag is not None:
        ax.text(-1, -1, tag, fontsize=12)
    if not to_array:
        return fig
    return figure2array(fig)


def save_figure(fig, save_path, close=True):
    """Save ``fig`` to ``save_path`` with a tight bounding box.

    :param fig: the matplotlib Figure to save.
    :param save_path: output file path; the format is inferred by matplotlib.
    :param close: if True, close ``fig`` afterwards to release its memory.
    """
    # Save the figure we were given. `plt.savefig` would write whichever
    # figure pyplot considers current, which need not be `fig`.
    fig.savefig(save_path, bbox_inches='tight')
    if close:
        plt.close(fig)


def figure2array(fig):
    """Rasterize ``fig`` into an RGB uint8 array of shape (H, W, 3) and close it.

    :param fig: a matplotlib Figure whose canvas provides ``buffer_rgba``
        (Agg-based canvases do).
    :return: np.ndarray of dtype uint8 in device-pixel resolution. It is an
        independent copy, not a view of the canvas buffer.
    """
    fig.canvas.draw()  # render so the canvas buffer is up to date
    # `tostring_rgb` is deprecated/removed in recent Matplotlib releases.
    # `buffer_rgba` already has the true (H, W, 4) pixel shape, so there is no
    # reshape via `get_width_height()`, which reports logical rather than
    # physical pixels on HiDPI canvases.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha and copy so the result does not alias the canvas memory.
    image = rgba[..., :3].copy()
    plt.close(fig)  # close this figure, not whichever one happens to be current
    return image


def frames2gif(src_dir, save_path, fps=24):
    """Assemble every ``.png`` in ``src_dir`` into an animation at ``save_path``.

    Frames are ordered by plain lexicographic filename sort, so numbered
    frames should be zero-padded to play in order.
    """
    frames = []
    for fname in sorted(os.listdir(src_dir)):
        if fname.endswith('.png'):
            frames.append(imageio.imread(os.path.join(src_dir, fname)))
    imageio.mimsave(save_path, frames, fps=fps)


def ensure_dir(path):
    """
    create directory ``path`` (including missing parents) if it does not exist
    :param path: directory path to create
    :return: None
    """
    # exist_ok=True removes the check-then-create race. Otherwise another
    # process could create the directory between an existence check and
    # makedirs, and the call would fail with FileExistsError.
    os.makedirs(path, exist_ok=True)


def ensure_dirs(paths):
    """
    create one or several directories via ensure_dir
    :param paths: a single path, or a list/tuple of paths
    :return: None
    """
    # Only explicit list/tuple containers are expanded. A str is iterable but
    # represents a single path, so it falls through to the else branch.
    if isinstance(paths, (list, tuple)):
        for path in paths:
            ensure_dir(path)
    else:
        ensure_dir(paths)


def calculate_params_size(model):
    """Return the total storage of ``model``'s parameters in bytes.

    Sums element count times bytes-per-element over ``model.parameters()``.
    Buffers are not included.
    """
    return sum(p.nelement() * p.element_size() for p in model.parameters())
