import scipy
from scipy import sparse
from scipy.spatial import Delaunay
from scipy.spatial import KDTree
from torch_geometric.data import Data
import h5py
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
import numpy as np
import os
import pickle
import torch
import networkx as nx
import open3d as o3d
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add
import natsort


def get_mask(x, domain='unit_square'):
    """Build a column mask that is 0.0 on boundary nodes and 1.0 elsewhere.

    Boundary membership is decided by exact floating-point equality of the
    first two coordinates with the boundary values.

    Args:
        x (ndarray): Node coordinates; columns 0 and 1 are used.
        domain (str): ``'unit_square'`` masks nodes with a coordinate equal to
            0.0 or 1.0. ``'cutout'`` additionally masks the segment
            ``0.25 <= x0 <= 0.75, x1 == 0.5`` and the two segments
            ``x0 in {0.25, 0.75}, 0.0 <= x1 <= 0.5``. Any other name prints a
            message and yields an all-ones mask.

    Returns:
        ndarray: Mask of shape ``(x.shape[0], 1)``.
    """
    mask = np.ones((x.shape[0], 1))
    if domain not in ('unit_square', 'cutout'):
        print("Wrong domain name.")
        return mask
    if x.shape[0] == 0:
        return mask

    x0, x1 = x[:, 0], x[:, 1]
    # Outer boundary of the unit square (shared by both domains).
    on_boundary = (x0 == 0.0) | (x0 == 1.0) | (x1 == 0.0) | (x1 == 1.0)
    if domain == 'cutout':
        # Top edge of the cutout.
        on_boundary |= (0.25 <= x0) & (x0 <= 0.75) & (x1 == 0.5)
        # Left and right edges of the cutout.
        on_boundary |= (0.0 <= x1) & (x1 <= 0.5) & ((x0 == 0.25) | (x0 == 0.75))
    mask[on_boundary] = 0.0
    return mask


def read_pickle(keys, path="./"):
    """Load one pickled object per key from ``<path>/<key>.pkl``.

    If at least one of the expected files is missing,
    ``write_all_pickle(path)`` is called before loading.

    Args:
        keys: Names of the objects to load; each maps to ``<key>.pkl``.
        path (str): Directory containing the pickle files.

    Returns:
        dict: Mapping ``key -> unpickled object``.
    """
    # NOTE: unpickling runs arbitrary code from the file; only use on trusted data.
    pickle_paths = [os.path.join(path, name + '.pkl') for name in keys]
    if any(not os.path.exists(p) for p in pickle_paths):
        write_all_pickle(path)

    loaded = {}
    for name, pickle_path in zip(keys, pickle_paths):
        with open(pickle_path, 'rb') as handle:
            loaded[name] = pickle.load(handle)
    return loaded

def write_all_pickle(path):
    """Convert ``<path>/sim.h5`` into one pickle file per field.

    Every top-level group of the HDF5 file is visited in
    ``natsort.os_sorted`` order and its ``t``, ``x``, ``u``, ``fibers`` and
    ``diffusion`` datasets are read into NumPy arrays. For each field the
    list of per-group arrays is pickled to ``<path>/<field>.pkl``.

    Args:
        path (str): Directory containing ``sim.h5``; the pickles are written
            to the same directory.
    """
    field_names = ('t', 'x', 'u', 'fibers', 'diffusion')
    collected = {name: [] for name in field_names}
    # Open read-only explicitly: the file is only read here, and older h5py
    # releases do not default to read-only when no mode is given.
    with h5py.File(os.path.join(path, 'sim.h5'), 'r') as f:
        for group_name in natsort.os_sorted(f.keys()):
            group = f[group_name]
            for name in field_names:
                collected[name].append(np.asarray(group[name]))
    # The data is fully in memory now, so the HDF5 file can be closed before writing.
    for name in field_names:
        with open(os.path.join(path, name + ".pkl"), "wb") as output_file:
            pickle.dump(collected[name], output_file)

def reduced_write_all_pickle(path, vars, data):
    """Pickle ``data[name]`` to ``<path>/<name>.pkl`` for every name in ``vars``.

    Args:
        path (str): Existing directory that receives the pickle files.
        vars: Iterable of keys of ``data`` to write.
        data (dict): Mapping from key to the object to pickle.
    """
    for name in vars:
        # Look the payload up first so a missing key fails before a file is created.
        payload = data[name]
        target = os.path.join(path, name + ".pkl")
        with open(target, "wb") as handle:
            pickle.dump(payload, handle)


def neighbors_from_delaunay(tri):
    """Return, for every node, the list of indices of its neighbouring nodes.

    Args:
        tri: Object exposing ``vertex_neighbor_vertices`` as an
            ``(indptr, indices)`` pair, where the neighbours of node ``i`` are
            ``indices[indptr[i]:indptr[i + 1]]``.

    Returns:
        list: One list of neighbour indices per node (``len(indptr) - 1``
        entries; the inner lists may have different lengths).
    """
    indptr, indices = tri.vertex_neighbor_vertices
    return [
        list(indices[start:stop])
        for start, stop in zip(indptr[:-1], indptr[1:])
    ]


def is_near(x, y, eps=1.0e-16):
    """Tell whether ``x`` lies within ``eps`` (Euclidean norm) of any entry of ``y``.

    Args:
        x: A single point (array-like).
        y: Iterable of points to compare against.
        eps (float): Strict upper bound on the distance.

    Returns:
        bool: True if some ``yi`` satisfies ``||x - yi|| < eps``, else False.
    """
    point = np.array(x)
    candidates = np.array(y)
    return any(np.linalg.norm(point - cand) < eps for cand in candidates)


def get_edge_index(x, max_dist=400):
    """Build a directed edge list from the Delaunay triangulation of ``x``.

    Every pair of nodes that share a Delaunay edge contributes both
    ``(i, j)`` and ``(j, i)``. Self loops and edges longer than ``max_dist``
    are dropped.

    Args:
        x (ndarray): Node coordinates, one row per node.
        max_dist (float): Maximum allowed Euclidean edge length. This is mesh
            specific; the default of 400 is the value that used to be
            hard-coded.

    Returns:
        ndarray: Integer array of shape ``(2, E)`` with source nodes in row 0
        and target nodes in row 1 (shape ``(2, 0)`` when no edge survives).
    """
    pts = np.asarray(x)
    tri = Delaunay(pts)
    # Neighbours of node i are indices[indptr[i]:indptr[i + 1]].
    indptr, indices = tri.vertex_neighbor_vertices
    sources = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))
    targets = np.asarray(indices)
    lengths = np.linalg.norm(pts[sources] - pts[targets], axis=1)
    # "not greater than" (rather than "<=") keeps the original treatment of NaN lengths.
    keep = (sources != targets) & ~(lengths > max_dist)
    return np.vstack((sources[keep], targets[keep]))

def get_knn_edge_index(x, k):
    """Build a directed k-nearest-neighbour edge list with a KD-tree.

    For every node the 2nd through (k+1)-th nearest points are used as
    targets, i.e. the closest hit (normally the node itself) is skipped.

    Args:
        x (ndarray): Node coordinates, one row per node.
        k (int): Number of neighbours per node.

    Returns:
        ndarray: Integer array of shape ``(2, k * n_nodes)``; row 0 holds the
        source node (each repeated ``k`` times), row 1 the neighbour indices.
    """
    n_nodes = x.shape[0]
    tree = KDTree(x)
    _, nbr_idx = tree.query(x, k=list(range(2, k + 2)))
    sources = np.repeat(np.arange(n_nodes), k)
    edge_index = np.vstack((sources, np.asarray(nbr_idx).reshape(-1)))
    return edge_index.astype(int)

def get_weights(diffusion, pos, edges):
    """Compute one scalar weight per edge from a diagonal diffusion tensor.

    For an edge with offset ``r = pos[target] - pos[source]`` and
    ``D = diag(diffusion[source])`` the weight is
    ``|| D r / (r^T D^{-1} r) ||``.

    Note: this only works for diagonal diffusion in two dimensions
    (``diffusion[node]`` is the length-2 diagonal of the tensor).

    Args:
        diffusion (ndarray): Per-node diagonal entries of the diffusion tensor.
        pos (ndarray): Node coordinates.
        edges (ndarray): Edge index of shape ``(2, E)`` (row 0 sources, row 1 targets).

    Returns:
        ndarray: Weights of shape ``(E,)``.
    """
    src, dst = edges[0], edges[1]
    offsets = pos[dst] - pos[src]
    weights = np.zeros((len(offsets), 2))
    for row, (node, r) in enumerate(zip(src, offsets)):
        d_mat = np.diag(diffusion[node])
        quad_form = np.linalg.multi_dot((r.T, np.linalg.inv(d_mat), r))
        weights[row, :] = (np.matmul(d_mat, r) / quad_form).T
    return np.linalg.norm(weights, axis=1)

def get_multi_weights(diffusion, pos, edges):
    """Compute a two-component weight per edge from a diagonal diffusion tensor.

    For an edge with offset ``r = pos[target] - pos[source]`` and
    ``D = diag(diffusion[source])`` the weight vector is ``|D r| / (r . r)``
    (absolute value taken component-wise).

    Note: this only works for diagonal diffusion in two dimensions.

    Args:
        diffusion (ndarray): Per-node diagonal entries of the diffusion tensor.
        pos (ndarray): Node coordinates.
        edges (ndarray): Edge index of shape ``(2, E)`` (row 0 sources, row 1 targets).

    Returns:
        ndarray: Non-negative weights of shape ``(E, 2)``.
    """
    src, dst = edges[0], edges[1]
    offsets = pos[dst] - pos[src]
    weights = np.zeros((len(offsets), 2))
    for row, (node, r) in enumerate(zip(src, offsets)):
        d_mat = np.diag(diffusion[node])
        squared_length = np.matmul(r.T, r)
        weights[row, :] = (np.matmul(d_mat, r) / squared_length).T
    return np.abs(weights)


def get_HK_weights(diffusion, pos, edges):
    """Compute one weight per edge as ``1 / (r^T D^{-1} r)`` in the x-y plane.

    ``r`` is the edge offset ``pos[target] - pos[source]`` with its last
    component dropped, and ``D`` is the upper-left 2x2 block of the source
    node's diffusion tensor (stored as 9 values, reshaped to 3x3).

    Note: 2D diffusion only; the comment in the original code states that the
    full 3x3 tensor is singular in that case, hence the 2x2 block.

    Args:
        diffusion (ndarray): Per-node diffusion tensors, 9 entries per node.
        pos (ndarray): Node coordinates.
        edges (ndarray): Edge index of shape ``(2, E)`` (row 0 sources, row 1 targets).

    Returns:
        ndarray: Non-negative weights of shape ``(E, 1)``.
    """
    src, dst = edges[0], edges[1]
    offsets = pos[dst] - pos[src]
    weights = np.zeros((len(offsets), 1))
    for row, (node, offset) in enumerate(zip(src, offsets)):
        tensor = diffusion[node].reshape(3, 3)
        planar = offset[:-1]
        quad_form = np.linalg.multi_dot([planar.T, np.linalg.inv(tensor[:-1, :-1]), planar])
        weights[row] = 1 / quad_form
    return np.abs(weights)

def get_gft_eig_vec(edge_index):
    """Eigendecompose the (unweighted) graph Laplacian of an edge list.

    The graph is undirected; nodes are inserted in sorted order of the ids
    occurring in ``edge_index`` before the edges are added.

    Args:
        edge_index (ndarray): Edge index of shape ``(2, E)``.

    Returns:
        tuple: ``(eigenvalues, eigenvectors)`` from ``np.linalg.eigh`` applied
        to the dense Laplacian.
    """
    graph = nx.Graph()
    graph.add_nodes_from(np.unique(edge_index.reshape(-1)))
    graph.add_edges_from(edge_index.T)
    laplacian = nx.laplacian_matrix(graph)
    eigvals, eigvecs = np.linalg.eigh(laplacian.toarray())
    return eigvals, eigvecs

def get_weighted_gft_eig_vec(edge_index, weights):
    """Return the eigenvectors of the weighted graph Laplacian.

    The graph is undirected, so if an edge appears in both directions the
    weight added last is the one that is kept.

    Args:
        edge_index (ndarray): Edge index of shape ``(2, E)``.
        weights (ndarray): One weight per edge; stacked below ``edge_index``
            to form ``(source, target, weight)`` triples.

    Returns:
        ndarray: Eigenvector matrix from ``np.linalg.eigh`` applied to the
        dense Laplacian (columns are eigenvectors).
    """
    graph = nx.Graph()
    graph.add_nodes_from(np.unique(edge_index.reshape(-1)))
    weighted_edges = np.vstack((edge_index, weights)).T
    graph.add_weighted_edges_from(weighted_edges)
    laplacian = nx.laplacian_matrix(graph)
    eigvecs = np.linalg.eigh(laplacian.toarray())[1]
    return eigvecs

def get_weighted_sparse_gft_eig_vec(edge_index, weights, modes=200):
    """Compute a few eigenpairs of the weighted graph Laplacian with a sparse solver.

    The graph is undirected, so if an edge appears in both directions the
    weight added last is the one that is kept.

    Args:
        edge_index (ndarray): Edge index of shape ``(2, E)``.
        weights (ndarray): Edge weights; ``weights.T`` is stacked below
            ``edge_index``, so a column of shape ``(E, 1)`` is presumably
            expected -- TODO confirm against callers.
        modes (int): Number of eigenpairs requested from ``eigsh``.

    Returns:
        tuple: ``(vals, vecs, adj, lap)`` -- eigenvalues, eigenvectors, sparse
        adjacency matrix and sparse Laplacian of the graph.
    """
    G = nx.Graph()
    nodes_id = np.unique(edge_index.reshape(-1))
    G.add_nodes_from(nodes_id)
    G.add_weighted_edges_from(np.vstack((edge_index, weights.T)).T)
    lap = nx.laplacian_matrix(G)
    # nx.adj_matrix was an alias that NetworkX 3.0 removed; adjacency_matrix
    # is the supported name and returns the same matrix.
    adj = nx.adjacency_matrix(G)
    # Shift-invert around 0 targets the smallest eigenvalues; the constant
    # start vector is passed to eigsh as v0.
    # NOTE(review): a graph Laplacian is singular, so factorising at sigma=0
    # may be fragile -- confirm this is intended.
    start_vector = np.ones(len(nodes_id)) / len(nodes_id)
    vals, vecs = sparse.linalg.eigsh(lap, k=modes, sigma=0, v0=start_vector)
    return vals, vecs, adj, lap

def get_multi_weighted_gft_eig_vec(edge_index, weights):
    """Return Laplacian eigenvectors for two weightings of the same graph.

    Two undirected graphs share the node set and edge list; the first uses
    ``weights[0]`` and the second ``weights[1]`` as edge weights.

    Args:
        edge_index (ndarray): Edge index of shape ``(2, E)``.
        weights: Indexable with ``0`` and ``1``, each giving one weight per edge.

    Returns:
        tuple: ``(vecsx, vecsy)`` -- eigenvector matrices (from
        ``np.linalg.eigh`` on the dense Laplacians) for the first and second
        weighting respectively.
    """
    node_ids = np.unique(edge_index.reshape(-1))
    graphs = []
    for component in (0, 1):
        graph = nx.Graph()
        graph.add_nodes_from(node_ids)
        graph.add_weighted_edges_from(np.vstack((edge_index, weights[component])).T)
        graphs.append(graph)

    laplacians = [nx.laplacian_matrix(graph) for graph in graphs]
    vecsx = np.linalg.eigh(laplacians[0].toarray())[1]
    vecsy = np.linalg.eigh(laplacians[1].toarray())[1]
    return vecsx, vecsy

def new_res(vars_to_read, data_path, new_res):
    """Down-sample every simulation with a voxel grid and cache the result.

    The reduced data is stored as pickles in ``<data_path>/<new_res>/``. If
    that directory already holds a pickle for every requested variable, the
    cached data is loaded instead of being recomputed.

    For each simulation the node coordinates ``x`` are voxel down-sampled
    with Open3D; the first original point index of every voxel is kept, and
    every variable except ``'t'`` is restricted to those indices.

    Args:
        vars_to_read: Names of the variables to load; must include ``'x'``.
        data_path (str): Directory with the full-resolution pickles.
        new_res: Voxel size passed to Open3D; also used as cache folder name.

    Returns:
        dict: Mapping from variable name to the list of per-simulation arrays
        at the reduced resolution.
    """
    ### Only works for 2D
    print('Changing Resolution of original data to '+str(new_res))
    lower_res_path = os.path.join(data_path, str(new_res))

    # Only trust the cache when every requested pickle is present. A bare
    # directory check would send a partially written cache into read_pickle,
    # which then tries to rebuild from a sim.h5 that does not exist there.
    cache_complete = os.path.isdir(lower_res_path) and all(
        os.path.exists(os.path.join(lower_res_path, k + '.pkl')) for k in vars_to_read)
    if cache_complete:
        return read_pickle(vars_to_read, lower_res_path)

    orig_data = read_pickle(vars_to_read, data_path)
    # Shallow copy: the per-variable lists are shared with orig_data and are
    # updated in place, one simulation at a time.
    new_data = orig_data.copy()
    for idx, idat in enumerate(orig_data['x']):
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(idat)
        _, _, pcd_list = pcd.voxel_down_sample_and_trace(
            new_res, pcd.get_min_bound(), pcd.get_max_bound(), False)
        # Keep the first original point index that fell into each voxel.
        first_values = [sublist[0] for sublist in pcd_list]
        for k in vars_to_read:
            if k != 't':
                new_data[k][idx] = orig_data[k][idx][first_values]
    os.makedirs(lower_res_path, exist_ok=True)
    reduced_write_all_pickle(lower_res_path, vars_to_read, new_data)

    return new_data


def to_torch_sparse(x):
    """Convert a dense array to a torch sparse COO tensor.

    Args:
        x: Dense array-like accepted by ``torch.Tensor`` (e.g. a NumPy array).

    Returns:
        torch.Tensor: Sparse COO tensor with the same shape as ``x`` holding
        its non-zero entries (an empty sparse tensor if ``x`` is all zeros).
    """
    dense = torch.Tensor(x)
    # torch.nonzero gives one row per non-zero entry; transpose to the
    # (ndim, nnz) layout expected by sparse_coo_tensor.
    indices = torch.nonzero(dense).t()
    values = dense[tuple(indices)]
    # Use the real shape of the input: the previous [n, n] size was only
    # correct for square matrices.
    return torch.sparse_coo_tensor(indices, values, dense.shape)


class h1(MessagePassing):
    """Parameter-free message-passing layer built on feature differences.

    Each edge produces the message ``x_j - x_i`` and the messages are combined
    with the ``'add'`` aggregation of ``MessagePassing``.
    """
    def __init__(self, in_channels, out_channels):
        # in_channels and out_channels are accepted but not used by this layer.
        super().__init__(aggr='add') # 'add' (sum) aggregation of the messages.

    def forward(self, x, edge_index, ew):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        # ew is passed through to message(), which does not use it.

        return self.propagate(edge_index, x=x, ew=ew)

    def message(self, x_i, x_j, ew):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]

        # The message is the per-edge feature difference, shape [E, in_channels].
        return x_j - x_i
def graph_grad(sig, edges):
    """Sum, per source node, the absolute signal differences along its edges.

    TODO: make this work without an explicit edge list.

    Args:
        sig (ndarray): Signal values indexed by node along axis 0.
        edges (ndarray): Edge index of shape ``(2, E)`` (row 0 sources, row 1 targets).

    Returns:
        torch.Tensor: ``scatter_add`` of ``|sig[target] - sig[source]|`` over
        the source indices along dimension 0.
    """
    src, dst = edges[0], edges[1]
    abs_diff = np.abs(sig[dst] - sig[src])
    return scatter_add(torch.tensor(abs_diff), torch.tensor(src), dim=0)

def weighted_graph_grad_old(sig, edges, edge_weights):
    """Sum, per source node, the weighted absolute signal differences along its edges.

    Each edge contributes ``|sig[target] - sig[source]| * sqrt(edge_weight)``.

    Args:
        sig (ndarray): Signal values indexed by node along axis 0.
        edges (ndarray): Edge index of shape ``(2, E)`` (row 0 sources, row 1 targets).
        edge_weights (ndarray): One weight per edge.

    Returns:
        torch.Tensor: Per-node sums with a trailing singleton dimension appended.
    """
    src, dst = edges[0], edges[1]
    scaled = np.abs(sig[dst] - sig[src]).squeeze() * np.sqrt(edge_weights)
    summed = scatter_add(torch.tensor(scaled), torch.tensor(src), dim=0)
    return summed.unsqueeze(dim=-1)

def weighted_graph_grad(sig, edges, edge_weights):
    """Return the signed, weighted signal difference along every edge.

    Args:
        sig (ndarray): Signal values indexed by node along axis 0.
        edges (ndarray): Edge index of shape ``(2, E)`` (row 0 sources, row 1 targets).
        edge_weights (ndarray): Edge weights; their square root scales the differences.

    Returns:
        ndarray: ``(sig[target] - sig[source]).squeeze() * sqrt(edge_weights)``.
    """
    src, dst = edges[0], edges[1]
    difference = (sig[dst] - sig[src]).squeeze()
    return difference * np.sqrt(edge_weights)
def generate_torchgeom_dataset(data, num_sims, withDiff, fibers, edges, vecs, vals, weights, modes, grad):
    """Build one torch_geometric ``Data`` object per simulation.

    For every simulation the signal ``u`` is rescaled with
    ``(u + 85) / 105``; the first time step becomes the input ``x`` and the
    first 10 time steps become the target ``y``.

    Args:
        data (dict): Must provide ``'u'``, ``'x'`` and ``'t'``, each indexable
            by simulation index.
        num_sims (int): Number of simulations to convert (indices
            ``0 .. num_sims - 1``).
        withDiff: If equal to the string ``'withDiff'``, a one-element
            placeholder tensor is stored as ``diffusion``.
        fibers: If ``True``, a one-element placeholder tensor is stored as
            ``fibers``.
        edges: Per-simulation edge indices (stored as ``edge_index``).
        vecs: Per-simulation eigenvector matrices; the first ``modes`` columns
            are stored transposed as ``redfor``.
        vals: Per-simulation eigenvalue arrays; the first ``modes`` entries
            are stored as ``redvals``.
        weights: Per-simulation edge weights (stored as ``edge_weights``).
        modes (int): Number of eigenpairs to keep.
        grad: If ``True``, ``grad_y`` is computed with ``weighted_graph_grad``
            on the target window and stored; otherwise no ``grad_y``
            attribute is set.

    Returns:
        list: One ``Data`` object per simulation.
    """
    u_min, u_max = -85, 20   # bounds used to rescale u
    n_steps = 10             # number of leading time steps kept in y / t
    length_scale = 1000      # divisor applied to positions and times

    dataset = []

    for sim_ind in range(num_sims):
        print("{} / {}".format(sim_ind + 1, num_sims))

        u = (data['u'][sim_ind] - u_min) / (u_max - u_min)
        # !!! This is a temporary fix! !!!
        if len(u.shape) < 3:
            u = np.expand_dims(u, axis=-1)
        target_window = u[:, :n_steps, :]

        fields = dict(
            edge_index=torch.Tensor(edges[sim_ind]).long(),
            edge_weights=torch.Tensor(weights[sim_ind]),
            pos=torch.Tensor(data['x'][sim_ind][:, :-1] / length_scale),
            sim_ind=torch.tensor(sim_ind, dtype=torch.long),
            x_shape=data['x'][sim_ind].shape[0],
            x=torch.Tensor(u[:, 0, :]),
            y=torch.Tensor(target_window),
        )

        # grad_y is only attached when requested. Previously it was referenced
        # unconditionally, which raised a NameError whenever grad was not True.
        if grad == True:
            fields['grad_y'] = torch.Tensor(
                weighted_graph_grad(target_window, edges[sim_ind], weights[sim_ind]))

        fields['t'] = torch.Tensor(data['t'][sim_ind][:n_steps] / length_scale)
        fields['redfor'] = torch.Tensor(vecs[sim_ind][:, :modes]).transpose(0, 1)
        fields['redvals'] = torch.Tensor(vals[sim_ind][:modes])

        # Placeholders only: the real diffusion / fibre data is not attached.
        # torch.zeros(1) keeps the one-element shape of the former
        # torch.Tensor(1) but is initialised, so the content is deterministic.
        if withDiff == 'withDiff':
            fields['diffusion'] = torch.zeros(1)
        if fibers == True:
            fields['fibers'] = torch.zeros(1)

        dataset.append(Data(**fields))

    return dataset


def get_parameters_count(model):
    """Count the trainable scalar parameters of a model.

    Args:
        model: Object whose ``parameters()`` yields tensors.

    Returns:
        int: Total number of elements over all parameters that have
        ``requires_grad`` set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def weights_init(m):
    """Initialise ``torch.nn.Linear`` modules in place; leave other modules untouched.

    Weights get Xavier-normal initialisation with gain 1.66666667 (5/3) and
    the bias, if the layer has one, is set to zero. Intended to be used as
    ``model.apply(weights_init)``.

    Args:
        m (torch.nn.Module): Module to initialise.
    """
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_normal_(m.weight.data, gain=1.66666667)
        # Linear(..., bias=False) has m.bias set to None; skip it instead of crashing.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias.data)


def get_warmup_scheduler(a):
    """Create a linear warm-up multiplier function.

    Args:
        a: Number of warm-up epochs. ``0`` disables the warm-up.

    Returns:
        callable: ``scheduler(epoch)`` returning ``epoch / a`` while
        ``epoch <= a`` and ``1.0`` afterwards (always ``1.0`` when ``a`` is 0).
    """
    def scheduler(epoch):
        # a == 0 means "no warm-up"; guarding it avoids a division by zero at epoch 0.
        if a != 0 and epoch <= a:
            return epoch * 1.0 / a
        return 1.0
    return scheduler


def plot_graph(coords):
    """Draw the Delaunay triangulation of ``coords`` on the current pyplot axes.

    The nodes are drawn as markers and the outline of the unit square
    ``[0, 1] x [0, 1]`` is added on top.

    Args:
        coords (ndarray): Node coordinates; columns 0 and 1 are plotted.
    """
    triangulation = Delaunay(coords)
    xs, ys = coords[:, 0], coords[:, 1]
    plt.triplot(xs, ys, triangulation.simplices.copy())
    plt.plot(xs, ys, 'o')
    # Outline of the unit square: bottom/top first, then left/right.
    for level in (0, 1):
        plt.hlines(level, 0, 1)
    for level in (0, 1):
        plt.vlines(level, 0, 1)


def get_masked_triang(x, y, max_radius):
    """Triangulate the points and mask triangles that have an overly long edge.

    Args:
        x (ndarray): x-coordinates of the nodes.
        y (ndarray): y-coordinates of the nodes.
        max_radius (float): Triangles whose longest edge exceeds this value
            are masked out.

    Returns:
        matplotlib.tri.Triangulation: Triangulation with the mask applied.
    """
    triang = mtri.Triangulation(x, y)
    corner_x = x[triang.triangles]
    corner_y = y[triang.triangles]
    # Rolling by one along the corner axis pairs each corner with the previous
    # one, which yields the three edge vectors of every triangle.
    edge_dx = corner_x - np.roll(corner_x, 1, axis=1)
    edge_dy = corner_y - np.roll(corner_y, 1, axis=1)
    longest_edge = np.sqrt(edge_dx**2 + edge_dy**2).max(axis=1)
    triang.set_mask(longest_edge > max_radius)
    return triang


def plot_triang_grid(ax, coords, values):
    """Draw a filled contour plot of nodal values on a masked triangulation.

    Uses 51 contour levels evenly spaced over ``[0.0, 1.0]`` and masks
    triangles whose longest edge exceeds 1.0.

    Args:
        ax: Matplotlib axes to draw on.
        coords (ndarray): Node coordinates; columns 0 and 1 are used.
        values: Nodal values passed to ``tricontourf``.

    Returns:
        The object returned by ``ax.tricontourf``.
    """
    triang = get_masked_triang(coords[:, 0], coords[:, 1], max_radius=1.0)
    contour_levels = np.linspace(0.0, 1.0, 51)
    return ax.tricontourf(triang, values, levels=contour_levels)


def plot_grid(coords, save_path):
    """Plot the masked triangulation of the nodes and save it as ``grid.png``.

    Args:
        coords (ndarray): Node coordinates; columns 0 and 1 are used.
        save_path (str): Directory in which ``grid.png`` is written.
    """
    triang = get_masked_triang(coords[:, 0], coords[:, 1], max_radius=1.0)
    plt.triplot(triang, 'ko-', linewidth=0.1, ms=0.5)
    target_file = os.path.join(save_path, 'grid.png')
    plt.savefig(target_file)


def plot_fields(t, coords, fields, save_path=None):
    """Plot every field at every time point, one subplot per field.

    Args:
        t (ndarray): Time points.
        coords (ndarray): Coordinates of nodes.
        fields (dict): keys - field names.
            values - ndarrays with shape (time, num_nodes, 1).
        save_path (str): If given, the figure is saved once per time point as
            ``save_path/t=<time>.png`` (time formatted with 4 decimals). The
            directory is created when it does not exist.
    """
    num_fields = len(fields)

    fig, ax = plt.subplots(1, num_fields, figsize=(6*num_fields, 6))
    # plt.subplots returns a bare Axes for a single subplot; normalise to a sequence.
    ax = [ax] if num_fields == 1 else ax.reshape(-1)

    field_items = list(fields.items())

    # Draw the first frame once to create the colorbars, each attached to the
    # subplot of its own field.
    for axis, (_, field) in zip(ax, field_items):
        im = plot_triang_grid(axis, coords, field[0].squeeze())
        fig.colorbar(im, ax=axis)

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

    for j, tj in enumerate(t):
        for axis, (key, field) in zip(ax, field_items):
            axis.cla()
            plot_triang_grid(axis, coords, field[j].squeeze())
            axis.set_aspect('equal')
            axis.set_title("Field {:s} at time {:.6f}".format(key, tj))

        # Save once per time point, after all subplots have been redrawn.
        if save_path is not None:
            fig.savefig(os.path.join(save_path, 't={:.4f}.png'.format(tj)))


def concatenate_bcs_dicts(batch):
    """Merge per-graph boundary-condition indices of a batch into global indices.

    For every key of ``batch['bcs_dicts']`` the per-graph node indices
    (first element of the entry) are shifted by the number of nodes of all
    preceding graphs and concatenated. The field ids are the unique values of
    the first item of the entry's last element.

    Args:
        batch: Mapping with ``'x_shape'`` (tensor of node counts per graph)
            and ``'bcs_dicts'`` (mapping from key to such an entry).

    Returns:
        dict: ``key -> [int64 ndarray of global node indices,
        list of single-element lists with the field ids]``.
    """
    node_counts = batch['x_shape'].numpy()
    per_key = batch['bcs_dicts']

    merged = {}
    for key in per_key.keys():  # all dicts have the same keys
        entry = per_key[key]
        # all dicts have the same field ids for the given key
        field_ids = np.unique(entry[-1][0].numpy())

        global_inds = []
        offset = 0
        for graph_no, local_inds in enumerate(entry[0]):
            global_inds.extend(local_inds + offset)
            offset += node_counts[graph_no]

        merged[key] = [
            np.array(global_inds, dtype=np.int64),
            [[field_id] for field_id in field_ids],
        ]
    return merged


def find_boundary(triangulation):
    """Collect the vertices that lie on the boundary of a triangulation.

    A triangle edge is on the boundary when the neighbour opposite one of the
    triangle's vertices is ``-1``; the two remaining vertices of that
    triangle are then boundary vertices.

    Args:
        triangulation: Object with ``neighbors`` and ``simplices`` arrays
            (three entries per triangle).

    Returns:
        set: Indices of the boundary vertices.
    """
    boundary = set()
    for tri_no in range(len(triangulation.neighbors)):
        adjacent = triangulation.neighbors[tri_no]
        corners = triangulation.simplices[tri_no]
        for opposite in range(3):
            if adjacent[opposite] != -1:
                continue
            # The edge facing corner `opposite` has no neighbouring triangle.
            boundary.add(corners[(opposite + 1) % 3])
            boundary.add(corners[(opposite + 2) % 3])
    return boundary


def get_neumann_normal_directions(points):
    """Estimate a normal direction at every boundary point of a 2D point set.

    The points are Delaunay-triangulated and the boundary vertices are taken
    from ``find_boundary``. For each boundary vertex the two closest
    neighbouring boundary vertices define two edges; the unit normals of
    these edges, oriented away from the centroid of all points, are averaged
    with the edge lengths as weights.

    Args:
        points (ndarray): Node coordinates; columns 0 and 1 are used.

    Returns:
        list: One normal vector (ndarray of length 2) per boundary vertex, in
        the iteration order of the boundary set.
    """
    tri = Delaunay(points)
    bd_points = find_boundary(tri)

    # The neighbours of vertex k are indices[indptr[k]:indptr[k + 1]].
    indptr, indices = tri.vertex_neighbor_vertices
    centroid = np.mean(points, 0)

    def unit_normal_away_from(edge, to_center):
        # Perpendicular of the edge, flipped if it points towards the centroid.
        dx, dy = edge[0], edge[1]
        normal = np.array([-dy, dx])
        if np.dot(to_center, normal) > 0:
            normal = np.array([dy, -dx])
        normal /= np.linalg.norm(normal)
        return normal

    normal_directions = []
    for bd_idx in bd_points:
        vertex = points[bd_idx]

        # Neighbours that are themselves on the boundary, closest first.
        bd_neighbors = [j for j in indices[indptr[bd_idx]:indptr[bd_idx + 1]] if j in bd_points]
        bd_neighbors = sorted(bd_neighbors, key=lambda j: np.linalg.norm(vertex - points[j]))

        # Edges towards the two closest boundary neighbours.
        edge1 = vertex - points[bd_neighbors[0]]
        edge2 = vertex - points[bd_neighbors[1]]

        to_center = centroid - vertex  # helps to orient the normals
        normal1 = unit_normal_away_from(edge1, to_center)
        normal2 = unit_normal_away_from(edge2, to_center)

        len1 = np.linalg.norm(edge1)
        len2 = np.linalg.norm(edge2)
        normal_directions.append((len1 * normal1 + len2 * normal2) / (len1 + len2))

    return normal_directions


def get_neumann_boundary_matrix(points, bd_points):
    """Assemble a sparse matrix with one row of interior-node weights per boundary point.

    For every boundary point, nearby non-boundary points are looked up with a
    KD-tree, and weights are derived from the unit directions towards them,
    the boundary normal and the distances. Each row is normalised so that its
    entries sum to one.

    Args:
        points (ndarray): Node coordinates, one row per node.
        bd_points (set): Indices of the boundary nodes. It must be a set,
            because it is subtracted from a set of neighbour indices.

    Returns:
        scipy.sparse.coo_matrix: Matrix of shape
        ``(len(bd_points), points.shape[0])``.
    """
    k = 10
    tree = KDTree(points)
    # NOTE(review): the list passed as k is [0, 1, ..., 9]; SciPy documents
    # k-th neighbour lists as starting from 1 -- confirm that including 0 is
    # intended.
    dd, ii = tree.query(points[list(bd_points), :], k=list(range(k)))
    # NOTE(review): the normals are produced by iterating a boundary set that
    # get_neumann_normal_directions computes itself; below they are matched to
    # bd_points by position, which assumes both sets iterate in the same
    # order -- confirm.
    normal_directions = get_neumann_normal_directions(points)

    # COO triplets (row, column, value) of the result matrix.
    i_idx = []
    j_idx = []
    data = []
    for i, this_bd_pt in enumerate(bd_points):
        int_pts_idx = list(set(ii[i]) - bd_points)  # remove boundary points
        rel_positions = points[int_pts_idx, :] - points[this_bd_pt, :]
        rel_pos_norms = np.asarray([np.linalg.norm(rel_positions[j]) for j in range(rel_positions.shape[0])])
        # Rows of v_mat are the unit vectors towards the interior neighbours.
        v_mat = rel_positions / rel_pos_norms[:, np.newaxis]  # \bbm{V}_i
        # (V^T V)^{-1} V^T
        v_tild_mat = np.matmul(np.linalg.inv(np.matmul(v_mat.T, v_mat)), v_mat.T)  # \tilde{\bbm{V}}_i
        rpe_tild_norm = np.matmul(normal_directions[i][np.newaxis, :], v_tild_mat)  # c_i
        wt_entries = rpe_tild_norm / rel_pos_norms
        # Normalise the weights of this boundary point so that they sum to one.
        bd_pt_data = np.squeeze(wt_entries / np.sum(wt_entries)).tolist()

        # One row index per weight (the comprehension's k is local to it).
        i_idx += [i for k in range(len(bd_pt_data))]
        j_idx += int_pts_idx
        data += bd_pt_data

    sp_matrix = sparse.coo_matrix((data, (i_idx, j_idx)), shape=(len(bd_points), points.shape[0]))

    return sp_matrix


def concatenate_neumann_bd_constraint_matrices(matrix_list, expected_shape):
    """Stack sparse matrices horizontally and return them as a torch sparse tensor.

    Args:
        matrix_list: Sequence of SciPy sparse matrices with equal row counts.
        expected_shape: Shape given to the resulting torch tensor.

    Returns:
        torch.Tensor: Sparse COO tensor with float32 values.
    """
    # Request COO explicitly so .row/.col exist whatever format the inputs have.
    a_coo = sparse.hstack(matrix_list, format='coo')
    # Stack the indices into one ndarray first: building a tensor from a
    # Python list of ndarrays takes PyTorch's slow path and triggers a warning.
    indices = torch.from_numpy(np.vstack((a_coo.row, a_coo.col)).astype(np.int64))
    values = torch.from_numpy(a_coo.data.astype(np.float32))
    return torch.sparse_coo_tensor(indices, values, expected_shape)
