from sklearn.metrics import zero_one_loss
import torch
import numpy as np
import matplotlib.pyplot as plt

import open3d as o3d

from helpers import figure2array
from torchgp import sample_surface, sample_volume

def sample_mesh(V, F, N, distrib=None):
    """Sample ``N`` points from a mesh, choosing the sampler from the element size.

    Elements with 3 vertex indices are sampled with ``sample_surface``; elements
    with 4 vertex indices are sampled with ``sample_volume``.

    Args:
        V: vertex array, forwarded unchanged to the torchgp sampler.
        F: element index array; ``F.shape[1]`` (indices per element) selects
            the sampler.
        N: number of samples, forwarded to the sampler.
        distrib: optional sampling distribution, forwarded unchanged.

    Returns:
        Whatever the selected torchgp sampler returns.

    Raises:
        NotImplementedError: if ``F.shape[1]`` is neither 3 nor 4. Previously
            this case crashed with ``UnboundLocalError``.
    """
    verts_per_elem = F.shape[1]
    if verts_per_elem == 3:
        return sample_surface(V, F, N, distrib)
    if verts_per_elem == 4:
        return sample_volume(V, F, N, distrib)
    raise NotImplementedError(
        f"sample_mesh supports elements with 3 or 4 vertices, got {verts_per_elem}"
    )
    

def sample_boundary(N, sdim, epsilon=1e-4, device='cpu'):
    """Sample points in a thin band around the boundary points -1 and +1.

    Only ``sdim == 1`` is supported. Points are drawn uniformly from
    ``[-1 - epsilon, -1 + epsilon]`` and ``[1 - epsilon, 1 + epsilon]``.

    Args:
        N: total number of points to return. When ``N`` is odd, the extra point
            goes to the ``+1`` side. Previously odd ``N`` silently returned
            ``N - 1`` points.
        sdim: spatial dimension; must be 1.
        epsilon: half-width of the sampling band around each boundary point.
        device: torch device for the returned tensor.

    Returns:
        Tensor of shape ``(N, 1)``. The first ``N // 2`` rows lie near ``-1`` and
        the remaining rows lie near ``+1``.

    Raises:
        NotImplementedError: for any ``sdim`` other than 1.
    """
    if sdim != 1:
        raise NotImplementedError
    n_left = N // 2
    n_right = N - n_left
    # rand in [0, 1) -> [-1, 1) -> scaled to +-epsilon, then shifted to each endpoint.
    coords_left = (torch.rand(n_left, 1, device=device) * 2 - 1) * epsilon - 1.
    coords_right = (torch.rand(n_right, 1, device=device) * 2 - 1) * epsilon + 1.
    return torch.cat([coords_left, coords_right], dim=0)


def sample_random(N, sdim, device='cpu'):
    """Draw ``N`` uniformly random points of dimension ``sdim``.

    Dispatches to ``sample_random_{1,2,3}D`` using their default settings
    (``normalize=True``).

    Raises:
        NotImplementedError: if ``sdim`` is not 1, 2 or 3.
    """
    samplers = (
        (1, sample_random_1D),
        (2, sample_random_2D),
        (3, sample_random_3D),
    )
    for dim, sampler in samplers:
        if sdim == dim:
            return sampler(N, device=device)
    raise NotImplementedError


def sample_uniform(N, sdim, device='cpu'):
    """Return a regular cell-centred grid of points of dimension ``sdim``.

    Note that ``N`` is the per-axis resolution, not the total point count: the
    result has ``N ** sdim`` rows. Dispatches to ``sample_uniform_{1,2,3}D``
    with their default ``normalize=True``.

    Raises:
        NotImplementedError: if ``sdim`` is not 1, 2 or 3.
    """
    builders = (
        (1, sample_uniform_1D),
        (2, sample_uniform_2D),
        (3, sample_uniform_3D),
    )
    for dim, builder in builders:
        if sdim == dim:
            return builder(N, device=device)
    raise NotImplementedError


def sample_uniform_1D(resolution: int, normalize=True, device='cpu'):
    """Return the centres of ``resolution`` equal cells as a ``(resolution, 1)`` tensor.

    Raw centres are ``0.5, 1.5, ..., resolution - 0.5``. With ``normalize`` they
    are mapped into ``(-1, 1)`` via ``x / resolution * 2 - 1``.
    """
    centers = torch.linspace(0.5, resolution - 0.5, resolution, device=device)[:, None]
    if not normalize:
        return centers
    return centers / resolution * 2 - 1


def sample_random_1D(N: int, normalize=True, resolution: int = None, device='cpu'):
    """Draw ``N`` uniformly random 1-D points.

    Args:
        N: number of points.
        normalize: if True, return samples in ``[-1, 1)``; otherwise return
            samples in ``[0, resolution)``.
        resolution: upper bound of the un-normalized range; required when
            ``normalize`` is False.
        device: torch device for the returned tensor.

    Returns:
        Tensor of shape ``(N, 1)``.

    Raises:
        TypeError: if ``normalize`` is False and ``resolution`` is None
            (previously an opaque ``Tensor * None`` error).
    """
    if not normalize and resolution is None:
        raise TypeError("resolution must be provided when normalize=False")
    coords = torch.rand(N, 1, device=device)
    if normalize:
        return coords * 2 - 1
    return coords * resolution


def sample_uniform_2D(resolution: int, normalize=True, device='cpu'):
    """Return a ``resolution x resolution`` cell-centred grid as ``(resolution**2, 2)`` points.

    Both axes use centres ``0.5, ..., resolution - 0.5``. Points are ordered with
    'ij' indexing, so the first coordinate varies slowest. With ``normalize``,
    coordinates are mapped into ``(-1, 1)`` via ``x / resolution * 2 - 1``.
    """
    axis = torch.linspace(0.5, resolution - 0.5, resolution, device=device)
    grid_x, grid_y = torch.meshgrid(axis, axis, indexing='ij')
    points = torch.stack((grid_x, grid_y), dim=-1).reshape(resolution ** 2, 2)
    return points / resolution * 2 - 1 if normalize else points


def sample_random_2D(N: int, normalize=True, resolution: int = None, device='cpu'):
    """Draw ``N`` uniformly random 2-D points.

    Args:
        N: number of points.
        normalize: if True, return samples in ``[-1, 1)^2``; otherwise return
            samples in ``[0, resolution)^2``.
        resolution: upper bound of the un-normalized range; required when
            ``normalize`` is False.
        device: torch device for the returned tensor.

    Returns:
        Tensor of shape ``(N, 2)``.

    Raises:
        TypeError: if ``normalize`` is False and ``resolution`` is None
            (previously an opaque ``Tensor * None`` error).
    """
    if not normalize and resolution is None:
        raise TypeError("resolution must be provided when normalize=False")
    coords = torch.rand(N, 2, device=device)
    if normalize:
        return coords * 2 - 1
    return coords * resolution


def sample_uniform_3D(resolution: int, normalize=True, device='cpu'):
    """Return a ``resolution^3`` cell-centred grid as ``(resolution**3, 3)`` points.

    Every axis uses centres ``0.5, ..., resolution - 0.5``. Points are ordered
    with 'ij' indexing, so the last coordinate varies fastest. With
    ``normalize``, coordinates are mapped into ``(-1, 1)`` via
    ``x / resolution * 2 - 1``.
    """
    axis = torch.linspace(0.5, resolution - 0.5, resolution, device=device)
    mesh = torch.meshgrid(axis, axis, axis, indexing='ij')
    points = torch.stack(mesh, dim=-1).reshape(resolution ** 3, 3)
    if normalize:
        points = points / resolution * 2 - 1
    return points


def sample_random_3D(N: int, normalize=True, resolution: int = None, device='cpu'):
    """Draw ``N`` uniformly random 3-D points.

    Args:
        N: number of points.
        normalize: if True, return samples in ``[-1, 1)^3``; otherwise return
            samples in ``[0, resolution)^3``.
        resolution: upper bound of the un-normalized range; required when
            ``normalize`` is False.
        device: torch device for the returned tensor.

    Returns:
        Tensor of shape ``(N, 3)``.

    Raises:
        TypeError: if ``normalize`` is False and ``resolution`` is None
            (previously an opaque ``Tensor * None`` error).
    """
    if not normalize and resolution is None:
        raise TypeError("resolution must be provided when normalize=False")
    coords = torch.rand(N, 3, device=device)
    if normalize:
        return coords * 2 - 1
    return coords * resolution


def write_pointcloud_to_file(filename, points, color=None):
    """Write ``points`` (and optional per-point colours) to ``filename`` via Open3D.

    Open3D point clouds hold 3-D coordinates. Inputs with fewer than 3 columns
    (e.g. 1-D or 2-D samples) are zero-padded up to 3 columns. Previously only
    2-column input was padded, so 1-D points produced an (n, 2) array.

    Args:
        filename: output path, passed to ``o3d.io.write_point_cloud``.
        points: array of shape ``(num_points, d)`` with ``d <= 3``.
        color: optional per-point colours, passed to ``Vector3dVector`` as-is.

    Raises:
        ValueError: if ``points`` has more than 3 columns.
    """
    num_points, dim = points.shape[0], points.shape[1]
    if dim > 3:
        raise ValueError(f"points must have at most 3 columns, got {dim}")
    if dim < 3:
        points = np.hstack([points, np.zeros((num_points, 3 - dim))])
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    if color is not None:
        pcd.colors = o3d.utility.Vector3dVector(color)
    o3d.io.write_point_cloud(filename, pcd)


def draw_deformation_field3D(arr, vmin=None, vmax=None, color=None, plane_height=None, sphere_center=None, sphere_radius=None, hide_axis = False):
    """Scatter-plot 3-D points in a fixed ``[-3, 3]^3`` view with optional overlays.

    Args:
        arr: point array indexed as ``arr[:, 0]``, ``arr[:, 1]``, ``arr[:, 2]``
            (x, y, z per row).
        vmin: currently unused; kept for signature compatibility.
        vmax: currently unused; kept for signature compatibility.
        color: per-point colour values mapped through the ``'cool'`` colormap;
            defaults to ``x + y + z``.
        plane_height: if given, draw a translucent green horizontal plane at
            ``z = plane_height``.
        sphere_center: ``(x, y, z)`` centre of the reference sphere.
        sphere_radius: radius of the reference sphere. The sphere is drawn only
            when both ``sphere_center`` and ``sphere_radius`` are given.
        hide_axis: if True, hide the axes.

    Returns:
        The created matplotlib ``Figure``.
    """
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    if color is None:
        color = arr[:, 0] + arr[:, 1] + arr[:, 2]
    ax.scatter(arr[:, 0], arr[:, 1], arr[:, 2], c=color, cmap='cool', s=0.2)
    ax.set_xlim3d(-3, 3)
    ax.set_ylim3d(-3, 3)
    ax.set_zlim3d(-3, 3)
    if hide_axis:
        ax.set_axis_off()
    if plane_height is not None:
        # Integer grid covering the same [-3, 3] extent as the axis limits.
        X, Y = np.meshgrid(np.arange(-3, 4), np.arange(-3, 4))
        Z = plane_height * np.ones_like(X)
        ax.plot_surface(X, Y, Z, alpha=0.2, color='green')
    if sphere_radius is not None and sphere_center is not None:
        # float() accepts list/tuple/ndarray/tensor elements alike.
        cx, cy, cz = (float(c) for c in sphere_center)
        u = np.linspace(0, 2 * np.pi, 10)
        v = np.linspace(0, np.pi, 10)
        # Translate to the requested centre. Previously sphere_center was
        # ignored and the sphere was always drawn at the origin.
        x = np.outer(np.cos(u), np.sin(v)) * sphere_radius + cx
        y = np.outer(np.sin(u), np.sin(v)) * sphere_radius + cy
        z = np.outer(np.ones(np.size(u)), np.cos(v)) * sphere_radius + cz
        ax.plot_surface(x, y, z, linewidth=0.0, alpha=0.1, color='blue')
    return fig


def draw_deformation_field2D(arr, vmin=None, vmax=None, color=None, plane_height=None, circle_center=None, circle_radius=None, scatter_radius=0.2, hide_axis = False, cmap = 'cool'):
    """Scatter-plot 2-D points in a fixed window with optional reference overlays.

    The view is fixed to x in [-2.5, 2.5] and y in [-3.5, 1.5].

    Args:
        arr: point array indexed as ``arr[:, 0]`` (x) and ``arr[:, 1]`` (y).
        vmin: currently unused; kept for signature compatibility.
        vmax: currently unused; kept for signature compatibility.
        color: per-point colour values; defaults to ``x + y``.
        plane_height: if given, draw a green horizontal line at this y.
        circle_center: centre of an orange reference circle.
        circle_radius: radius of the reference circle. The circle, together
            with an equal aspect ratio, is added only when both circle
            arguments are given.
        scatter_radius: passed to ``scatter`` as its ``s`` (marker size) argument.
        hide_axis: if True, hide the axes; otherwise draw orange reference
            lines at y = 1 and y = -1.
        cmap: colormap for the scatter colours.

    Returns:
        The created matplotlib ``Figure``.
    """
    fig, ax = plt.subplots(figsize=(3, 3))
    point_colors = color if color is not None else arr[:, 0] + arr[:, 1]
    ax.scatter(arr[:, 0], arr[:, 1], c=point_colors, cmap=cmap, s=scatter_radius)
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-3.5, 1.5)
    if plane_height is not None:
        ax.axhline(y=plane_height, color='green', linestyle='-', alpha=0.5)
    if circle_radius is not None and circle_center is not None:
        outline = plt.Circle(circle_center, circle_radius, color='orange', fill=False, linewidth=0.5)
        ax.set_aspect('equal', adjustable='datalim')
        ax.add_patch(outline)
    if hide_axis:
        ax.set_axis_off()
        return fig
    for reference_y in (1.0, -1.0):
        ax.axhline(y=reference_y, color='orange', linestyle='-', alpha=0.5)
    return fig


def draw_scalar_field2D(arr, vmin=None, vmax=None):
    """Render a 2-D array with ``matshow`` and attach a colour bar.

    Args:
        arr: 2-D array to display.
        vmin: lower colour limit, passed to ``matshow``.
        vmax: upper colour limit, passed to ``matshow``.

    Returns:
        The created matplotlib ``Figure``.
    """
    figure, axes = plt.subplots(figsize=(3, 3))
    image = axes.matshow(arr, vmin=vmin, vmax=vmax)
    colorbar_opts = dict(fraction=0.046, pad=0.04)
    figure.colorbar(image, ax=axes, **colorbar_opts)
    figure.tight_layout()
    return figure


def draw_vector_field2D(u, v, tag=None, to_array=False):
    """Quiver-plot a vector field given as two same-shaped component arrays.

    Arrows are anchored at ``np.indices(u.shape)[0]`` (plot x) and
    ``np.indices(u.shape)[1]`` (plot y), so the first array axis maps to the
    horizontal plot axis.

    Args:
        u: first vector component at each grid index.
        v: second vector component, same shape as ``u``.
        tag: optional text drawn at data position ``(-1, -1)``.
        to_array: if True, return ``figure2array(fig)`` instead of the figure.

    Returns:
        The matplotlib ``Figure``, or ``figure2array(fig)`` when ``to_array``.

    Raises:
        ValueError: if ``u`` and ``v`` differ in shape. This replaces an
            ``assert`` that is stripped under ``python -O``.
    """
    if u.shape != v.shape:
        raise ValueError(f"u and v must have the same shape, got {u.shape} and {v.shape}")
    indices = np.indices(u.shape)
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.quiver(indices[0], indices[1], u, v, scale=u.shape[0], scale_units='width')
    if tag is not None:
        ax.text(-1, -1, tag, fontsize=12)
    if not to_array:
        return fig
    return figure2array(fig)


if __name__ == '__main__':
    # Smoke test: grid samplers must return resolution**sdim points inside (-1, 1).
    grid2d = sample_uniform_2D(10)
    assert grid2d.shape == (100, 2), grid2d.shape
    grid3d = sample_uniform_3D(10)
    assert grid3d.shape == (1000, 3), grid3d.shape
    assert float(grid2d.abs().max()) < 1 and float(grid3d.abs().max()) < 1
     