import scipy
from scipy import sparse
from scipy.spatial import Delaunay
from scipy.spatial import KDTree
from torch_geometric.data import Data
import h5py
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
import numpy as np
import os
import pickle
import torch
import networkx as nx
import torch_geometric as tgeo


def get_mask(x, domain='unit_square'):
    """Return an (N, 1) float mask that is 0.0 on boundary nodes and 1.0 elsewhere.

    Args:
        x (ndarray): Node coordinates; only columns 0 and 1 are inspected.
        domain (str): ``'unit_square'`` zeroes nodes lying exactly on the edges of
            [0, 1]^2. ``'cutout'`` additionally zeroes nodes on the segment
            y == 0.5 for 0.25 <= x <= 0.75 and on the segments x == 0.25 /
            x == 0.75 for 0 <= y <= 0.5. Any other value prints a message and
            returns an all-ones mask.

    Returns:
        ndarray: Mask of shape (N, 1).
    """
    mask = np.ones((x.shape[0], 1))
    if domain not in ('unit_square', 'cutout'):
        print("Wrong domain name.")
        return mask

    with_cutout = domain == 'cutout'
    for row, point in enumerate(x):
        px, py = point[0], point[1]
        # Exact float comparisons: nodes are expected to sit exactly on the lines.
        on_outer_edge = px in (0.0, 1.0) or py in (0.0, 1.0)
        on_cutout_edge = with_cutout and (
            (0.25 <= px <= 0.75 and py == 0.5)
            or (0.0 <= py <= 0.5 and px in (0.25, 0.75))
        )
        if on_outer_edge or on_cutout_edge:
            mask[row] = 0.0
    return mask


def read_pickle(keys, path="./"):
    """Load ``<path>/<key>.pkl`` for every key in ``keys`` into a dict.

    If any requested pickle file is missing, ``write_all_pickle(path)`` is
    called first to regenerate the pickles from ``<path>/sim.h5``.

    Args:
        keys (iterable of str): Pickle names without the ``.pkl`` suffix.
        path (str): Directory holding the pickle files.

    Returns:
        dict: ``{key: unpickled object}``.

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    def pickle_path(name):
        return os.path.join(path, name + '.pkl')

    missing = [p for p in map(pickle_path, keys) if not os.path.exists(p)]
    if missing:
        write_all_pickle(path)

    loaded = {}
    for key in keys:
        with open(pickle_path(key), 'rb') as handle:
            loaded[key] = pickle.load(handle)
    return loaded


def write_all_pickle(path):
    """Dump every top-level entry of ``<path>/sim.h5`` to ``<path>/<key>.pkl``.

    Each entry is converted with ``np.array`` before pickling. When the file
    contains a ``bd_points`` entry, ``<path>/bcs_dicts.pkl`` is also written:
    a list with one ``{'boundary': [bd_points[i], [0]]}`` dict per row ``i``.

    Args:
        path (str): Directory containing ``sim.h5``; pickles are written there.
    """
    # Open read-only explicitly: h5py < 3.0 defaulted to mode 'a', which may
    # create or modify the file, while this function only needs to read it.
    with h5py.File(os.path.join(path, 'sim.h5'), 'r') as f:
        for k in f.keys():
            tempdata = np.array(f[k])
            with open(os.path.join(path, k + ".pkl"), "wb") as output_file:
                pickle.dump(tempdata, output_file)
            if k == 'bd_points':
                # Reuse the array just read instead of reading the dataset again.
                bcs_dicts = [{'boundary': [tempdata[i], [0]]} for i in range(tempdata.shape[0])]
                with open(os.path.join(path, 'bcs_dicts.pkl'), "wb") as output_file:
                    pickle.dump(bcs_dicts, output_file)

def neighbors_from_delaunay(tri):
    """Return, for each of the N nodes of a scipy Delaunay triangulation, the list of its neighbor indices.

    Uses ``tri.vertex_neighbor_vertices`` = ``(indptr, indices)``: the
    neighbors of node k are ``indices[indptr[k]:indptr[k + 1]]``.
    """
    indptr, indices = tri.vertex_neighbor_vertices
    return [list(indices[start:stop]) for start, stop in zip(indptr[:-1], indptr[1:])]


def is_near(x, y, eps=1.0e-16):
    """Return True if point ``x`` lies within Euclidean distance ``eps`` of any point in ``y``.

    With the default ``eps`` of 1e-16 this is effectively an exact-match test.
    """
    x = np.array(x)
    candidates = np.array(y)
    return any(np.linalg.norm(x - cand) < eps for cand in candidates)


def get_edge_index(x, max_dist=400):
    """Return the directed edge list of the Delaunay triangulation of ``x``.

    Every Delaunay edge appears in both directions. Self-loops and edges longer
    than ``max_dist`` are dropped.

    Args:
        x (ndarray): Node coordinates, shape (N, dim).
        max_dist (float): Maximum Euclidean edge length kept. The default of 400
            is mesh specific (it used to be hard-coded); pass a value suited to
            your mesh.

    Returns:
        ndarray: Integer array of shape (2, E). Row 0 holds source indices and
        row 1 target indices, grouped by ascending source node. The shape is
        (2, 0) when no edge survives.
    """
    x = np.asarray(x)
    tri = Delaunay(x)
    # Neighbors of node k are indices[indptr[k]:indptr[k + 1]].
    indptr, indices = tri.vertex_neighbor_vertices
    sources = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))
    targets = np.asarray(indices)
    lengths = np.linalg.norm(x[targets] - x[sources], axis=1)
    keep = (sources != targets) & (lengths <= max_dist)
    return np.vstack((sources[keep], targets[keep])).astype(np.int64)

def get_knn_edge_index(x, k):
    """Return a (2, k*N) integer edge list linking every node to its k nearest other nodes.

    The KD-tree query asks for the 2nd..(k+1)-th nearest neighbors, skipping
    the 1st, which is the query point itself. Row 0 repeats each source index
    k times; row 1 holds the matching neighbor indices.
    """
    num_nodes = x.shape[0]
    _, neighbor_ids = KDTree(x).query(x, k=list(range(2, k + 2)))
    sources = np.repeat(np.arange(num_nodes), k)
    return np.vstack((sources, neighbor_ids.reshape(-1))).astype(int)

def get_weights(diffusion, pos, edges):
    """Return one scalar weight per edge: ``|| D r / (r^T D^{-1} r) ||``.

    Here ``r = pos[dst] - pos[src]`` and ``D = diag(diffusion[src])``.

    Args:
        diffusion (ndarray): Per-node diagonal entries of the diffusion tensor
            (row i is turned into a diagonal matrix).
        pos (ndarray): Node coordinates, shape (N, 2).
        edges (ndarray): Edge list of shape (2, E); row 0 is source, row 1 target.

    Returns:
        ndarray: Weights of shape (E,).
    """
    # !!! Note: This only works for diagonal diffusion! !!!
    sources = edges[0]
    offsets = pos[edges[1]] - pos[sources]
    per_edge = np.zeros((len(offsets), 2))
    for e, (node, r) in enumerate(zip(sources, offsets)):
        d_mat = np.diag(diffusion[node])
        per_edge[e] = np.matmul(d_mat, r) / np.linalg.multi_dot((r.T, np.linalg.inv(d_mat), r))
    return np.linalg.norm(per_edge, axis=1)

def get_multi_weights(diffusion, pos, edges):
    """Return per-edge weight vectors ``|D r / (r^T r)|`` (element-wise absolute value).

    Here ``r = pos[dst] - pos[src]`` and ``D = diag(diffusion[src])``.

    Args:
        diffusion (ndarray): Per-node diagonal entries of the diffusion tensor.
        pos (ndarray): Node coordinates, shape (N, 2).
        edges (ndarray): Edge list of shape (2, E); row 0 is source, row 1 target.

    Returns:
        ndarray: Weights of shape (E, 2).
    """
    # !!! Note: This only works for diagonal diffusion! !!!
    sources = edges[0]
    offsets = pos[edges[1]] - pos[sources]
    per_edge = np.zeros((len(offsets), 2))
    for e, (node, r) in enumerate(zip(sources, offsets)):
        d_mat = np.diag(diffusion[node])
        per_edge[e] = np.matmul(d_mat, r) / np.matmul(r.T, r)
    return np.abs(per_edge)

def get_HK_weights(diffusion, pos, edges):
    """Return per-edge weights ``|1 / (r^T K^{-1} r)|`` as an (E, 1) array.

    Here ``r = pos[dst] - pos[src]`` and ``K = diffusion[src].reshape(2, 2)``,
    the full 2x2 diffusion tensor of the source node.

    Args:
        diffusion (ndarray): Per-node diffusion tensors flattened to 4 entries.
        pos (ndarray): Node coordinates, shape (N, 2).
        edges (ndarray): Edge list of shape (2, E); row 0 is source, row 1 target.
    """
    # !!! Note: This only works for 2D diffusion only! !!!
    sources = edges[0]
    offsets = pos[edges[1]] - pos[sources]
    per_edge = np.zeros((len(offsets), 1))
    for e, (node, r) in enumerate(zip(sources, offsets)):
        tensor = diffusion[node].reshape(2, 2)
        per_edge[e] = 1 / (np.linalg.multi_dot([r.T, np.linalg.inv(tensor), r]))
    return np.abs(per_edge)

def get_gft_eig_vec(edge_index):
    """Fully eigendecompose the unweighted graph Laplacian defined by ``edge_index``.

    Nodes are the unique ids in ``edge_index`` (sorted, so Laplacian rows and
    columns follow ascending node id). Directed duplicates collapse into one
    undirected edge because an ``nx.Graph`` is used.

    Args:
        edge_index (ndarray): Edge list, shape (2, E).

    Returns:
        tuple: ``(eigenvalues, eigenvectors)`` from ``np.linalg.eigh``, with
        eigenvectors as columns.
    """
    graph = nx.Graph()
    graph.add_nodes_from(np.unique(edge_index.reshape(-1)))
    graph.add_edges_from(edge_index.T)
    laplacian = nx.laplacian_matrix(graph).toarray()
    eigvals, eigvecs = np.linalg.eigh(laplacian)
    return eigvals, eigvecs

def get_weighted_gft_eig_vec(edge_index, weights):
    """Return all eigenvectors (as columns) of the weighted graph Laplacian built from ``edge_index``.

    Args:
        edge_index (ndarray): Edge list, shape (2, E).
        weights (ndarray): One weight per edge, stacked under ``edge_index`` as
            a third row, so it must be shaped (E,) or (1, E).

    Returns:
        ndarray: Eigenvector matrix from ``np.linalg.eigh`` (ascending eigenvalues).
    """
    graph = nx.Graph()
    graph.add_nodes_from(np.unique(edge_index.reshape(-1)))
    # Rows are (u, v, w). vstack may cast node ids to float; since 0.0 == 0 and
    # hash(0.0) == hash(0), the edges still attach to the integer nodes.
    weighted_edges = np.vstack((edge_index, weights)).T
    graph.add_weighted_edges_from(weighted_edges)
    laplacian = nx.laplacian_matrix(graph).toarray()
    return np.linalg.eigh(laplacian)[1]

def get_weighted_sparse_gft_eig_vec(edge_index, weights, modes=100):
    """Return the ``modes`` smallest-magnitude eigenpairs of the weighted graph Laplacian.

    Args:
        edge_index (ndarray): Edge list, shape (2, E).
        weights (array_like): One weight per edge. Any shape holding E values
            is accepted, e.g. (E,), (E, 1) or (1, E).
        modes (int): Number of eigenpairs to compute (``eigsh`` needs it to be
            smaller than the number of nodes).

    Returns:
        tuple: ``(vals, vecs)`` from ``scipy.sparse.linalg.eigsh(..., which='SM')``.
    """
    # Import the submodule explicitly: ``from scipy import sparse`` alone is
    # not guaranteed to load ``sparse.linalg`` on every SciPy version.
    from scipy.sparse.linalg import eigsh

    # Flatten so (E,), (E, 1) and (1, E) weights all stack as a single row.
    flat_weights = np.asarray(weights, dtype=float).reshape(-1)
    G = nx.Graph()
    G.add_nodes_from(np.unique(edge_index.reshape(-1)))
    G.add_weighted_edges_from(np.vstack((edge_index, flat_weights)).T)
    lap = nx.laplacian_matrix(G)
    vals, vecs = eigsh(lap, k=modes, which='SM')
    return vals, vecs

def to_torch_sparse(x):
    """Convert a dense array-like ``x`` to a float32 torch sparse COO tensor.

    Only non-zero entries are stored, and the result has the same shape as
    ``x``. An all-zero input yields an empty sparse tensor of that shape.
    """
    dense = torch.Tensor(x)
    # (ndim, nnz) index matrix; nnz == 0 for an all-zero input, which
    # sparse_coo_tensor handles, so no special case is needed.
    indices = torch.nonzero(dense).t()
    values = dense[tuple(indices)]
    # Use the real shape: a hard-coded [n, n] size is wrong for non-square
    # input, and the legacy torch.sparse.FloatTensor constructor is deprecated.
    return torch.sparse_coo_tensor(indices, values, dense.shape)

def generate_torchgeom_dataset(data, num_sims, bd_conditions, withDiff, fibers, edges, vecs, vals, modes):
    """Returns dataset that can be used to train our model.

    Builds one torch_geometric ``Data`` object per simulation ``0 .. num_sims-1``.

    Args:
        data (dict): Data dictionary with keys t, x, u and, depending on the
            flags below, bcs_dicts, diffusion, fibers. Every entry is indexed
            by simulation first.
        num_sims (int): Number of simulations to convert.
        bd_conditions (str): which bd conditions to use. ``'none'`` attaches no
            boundary data. Any other value attaches ``data['bcs_dicts'][sim]``.
            ``'neumann'`` additionally builds the Neumann boundary matrix,
            overwrites the boundary entries of ``u[0]`` with it and attaches it
            as ``bd_constraint``.
        withDiff (str): attach ``data['diffusion'][sim]`` when equal to ``'withDiff'``.
        fibers (bool): attach ``data['fibers'][sim]`` when ``== True``.
        edges (sequence): per-simulation edge index arrays (presumably (2, E)).
        vecs (sequence): per-simulation eigenvector matrices; the first
            ``modes`` columns are kept and stored transposed as ``redfor``.
        vals (sequence): per-simulation eigenvalues; the first ``modes`` are
            stored as ``redvals``.
        modes (int): Number of spectral modes to keep.
    Returns:
        dataset (list): Array of torchgeometric Data objects.
    """
    dataset = []

    for sim_ind in range(num_sims):
        print("{} / {}".format(sim_ind + 1, num_sims))

        bd_data_to_load = {}
        u = data['u'][sim_ind]
        # !!! This is a temporary fix! !!!
        if len(u.shape) < 3:
            u = np.expand_dims(u, axis=-1)

        if bd_conditions != 'none':
            bd_points = data['bcs_dicts'][sim_ind]
            bd_data_to_load['bcs_dicts'] = bd_points
            if bd_conditions == 'neumann':
                # Work on a copy: ``u`` may be a view into data['u'], and the
                # in-place boundary update below would otherwise mutate the
                # caller's data (and compound if this function is run again).
                u = np.array(u, copy=True)
                # TODO: This assumes u is in R^1
                a = get_neumann_boundary_matrix(data['x'][sim_ind], set(bd_points['boundary'][0]))
                a_csr = a.tocsr()  # convert once instead of once per boundary entry
                for bc_inds, field_inds in bd_points.values():
                    u[0, bc_inds, field_inds] = a_csr.dot(np.squeeze(u[0, :, field_inds]))
                bd_data_to_load['bd_constraint'] = a
        if withDiff == 'withDiff':
            bd_data_to_load['diffusion'] = torch.Tensor(data['diffusion'][sim_ind])

        if fibers == True:
            bd_data_to_load['fibers'] = torch.Tensor(data['fibers'][sim_ind])

        full_eig = torch.Tensor(vecs[sim_ind][:, :modes])
        full_vals = torch.Tensor(vals[sim_ind][:modes])

        tg_data = Data(
            edge_index=torch.Tensor(edges[sim_ind]).long(),
            pos=torch.Tensor(data['x'][sim_ind]),
            sim_ind=torch.tensor(sim_ind, dtype=torch.long),
            x_shape=data['x'][sim_ind].shape[0],
            x=torch.Tensor(u[0, :, :]),
            y=torch.Tensor(u).transpose(0, 1),
            t=torch.Tensor(data['t'][sim_ind]),
            redfor=full_eig.transpose(0, 1),
            redvals=full_vals,
            **bd_data_to_load
        )
        dataset.append(tg_data)

    return dataset


def get_parameters_count(model):
    """Return the total number of scalar elements in the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def weights_init(m):
    """Initialise a ``torch.nn.Linear`` module in place; other modules are left untouched.

    Weights use Xavier-normal init with gain ~5/3 (the value
    ``torch.nn.init.calculate_gain('tanh')`` returns; presumably chosen for
    tanh activations — confirm). The bias, when present, is zeroed.
    Intended for ``model.apply(weights_init)``.
    """
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_normal_(m.weight.data, gain=1.66666667)
        # Linear(..., bias=False) has bias None; skip it instead of crashing.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias.data)


def get_warmup_scheduler(a):
    """Return a linear warm-up factor function.

    The returned callable maps an epoch index to ``epoch / a`` while
    ``epoch <= a`` and to ``1.0`` afterwards, ramping from 0 to 1 over ``a``
    epochs. It is presumably meant as ``lr_lambda`` for
    ``torch.optim.lr_scheduler.LambdaLR``. When ``a <= 0`` there is no
    warm-up and the factor is always ``1.0``.
    """
    def scheduler(epoch):
        if a <= 0:
            # No warm-up; the ramp formula would divide by zero at epoch 0.
            return 1.0
        if epoch <= a:
            return epoch * 1.0 / a
        return 1.0
    return scheduler


def plot_graph(coords):
    """Draw the Delaunay mesh of 2-D ``coords`` with node markers and the unit-square outline on the current axes."""
    simplices = Delaunay(coords).simplices.copy()
    xs, ys = coords[:, 0], coords[:, 1]
    plt.triplot(xs, ys, simplices)
    plt.plot(xs, ys, 'o')
    for level in (0, 1):
        plt.hlines(level, 0, 1)
    for level in (0, 1):
        plt.vlines(level, 0, 1)


def get_masked_triang(x, y, max_radius):
    """Triangulate points (x, y) and mask triangles whose longest edge exceeds ``max_radius``.

    matplotlib computes a Delaunay triangulation when no triangles are given.
    Masking long triangles hides spurious elements spanning holes or concave parts.

    Returns:
        matplotlib.tri.Triangulation: The triangulation with its mask set.
    """
    triang = mtri.Triangulation(x, y)
    corner_x = x[triang.triangles]
    corner_y = y[triang.triangles]
    # Edge vectors between consecutive corners of each triangle.
    dx = corner_x - np.roll(corner_x, 1, axis=1)
    dy = corner_y - np.roll(corner_y, 1, axis=1)
    longest_edge = np.sqrt(dx**2 + dy**2).max(axis=1)
    triang.set_mask(longest_edge > max_radius)
    return triang


def plot_triang_grid(ax, coords, values):
    """Draw a filled contour of nodal ``values`` on ``ax`` using 11 levels spanning [0, 1].

    Returns:
        The contour set from ``ax.tricontourf`` (usable as a colorbar mappable).
    """
    triang = get_masked_triang(coords[:, 0], coords[:, 1], max_radius=1.0)
    return ax.tricontourf(triang, values, levels=np.linspace(0.0, 1.0, 11))


def plot_grid(coords, save_path):
    """Draw the masked triangulation of 2-D ``coords`` and save it to ``save_path/grid.png``.

    The directory ``save_path`` must already exist.
    """
    triang = get_masked_triang(coords[:, 0], coords[:, 1], max_radius=1.0)
    plt.triplot(triang, 'ko-', linewidth=0.1, ms=0.5)
    plt.savefig(os.path.join(save_path, 'grid.png'))


def plot_fields(t, coords, fields, save_path=None):
    """Draw every field side by side for each time point, optionally saving one PNG per time point.

    Args:
        t (ndarray): Time points.
        coords (ndarray): Coordinates of nodes.
        fields (dict): keys - field names.
            values - ndarrays with shape (time, num_nodes, 1).
        save_path (str): If given, the figure for time ``tj`` is saved as
            ``save_path/t={tj:.4f}.png`` (the directory is created if needed).
    """
    names = list(fields.keys())
    num_fields = len(names)

    fig, ax = plt.subplots(1, num_fields, figsize=(6*num_fields, 6))
    if num_fields == 1:
        ax = [ax]
    else:
        ax = ax.reshape(-1)

    # Draw the first time step once so each panel gets a colorbar; the
    # colorbars stay while the axes contents are redrawn below.
    mappables = [plot_triang_grid(ax[i], coords, fields[name][0].squeeze()) for i, name in enumerate(names)]
    for im in mappables:
        fig.colorbar(im, ax=ax)

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

    for j, tj in enumerate(t):
        for i, (key, field) in enumerate(fields.items()):
            ax[i].cla()
            mappables[i] = plot_triang_grid(ax[i], coords, field[j].squeeze())
            ax[i].set_aspect('equal')
            ax[i].set_title("Field {:s} at time {:.6f}".format(key, tj))

        # Save once, after every panel shows time tj. Saving inside the
        # per-field loop wrote partially updated figures to the same file.
        if save_path is not None:
            fig.savefig(os.path.join(save_path, 't={:.4f}.png'.format(tj)))


def concatenate_bcs_dicts(batch):
    """Merge the per-sample boundary-condition dicts of a collated batch into one dict.

    For every key, ``batch['bcs_dicts'][key][0]`` is read as a sequence of
    per-sample boundary index arrays. Each array is offset by the node counts
    (``batch['x_shape']``) of the preceding samples, and all are concatenated.
    Field indices are taken from the first entry of ``batch['bcs_dicts'][key][-1]``
    (assumed identical for all samples).

    Returns:
        dict: ``{key: [int64 array of global boundary indices, [[field_id], ...]]}``.
    """
    node_counts = batch['x_shape'].numpy()
    per_key = batch['bcs_dicts']
    merged = {}
    for key in per_key.keys():  # all dicts have the same keys
        field_ids = np.unique(per_key[key][-1][0].numpy())
        offset = 0
        global_inds = []
        for sample, sample_inds in enumerate(per_key[key][0]):
            global_inds.extend(sample_inds + offset)
            offset += node_counts[sample]
        merged[key] = [
            np.array(global_inds, dtype=np.int64),
            [[fid] for fid in field_ids],
        ]
    return merged


def find_boundary(triangulation):
    """Return the set of vertex indices on the outer boundary of a 2-D scipy Delaunay triangulation.

    A triangle edge is on the boundary when it has no neighboring triangle
    (``neighbors[i][k] == -1``). Edge k is opposite vertex k, so its endpoints
    are vertices (k+1) % 3 and (k+2) % 3.
    """
    boundary_vertices = set()
    for simplex, adjacent in zip(triangulation.simplices, triangulation.neighbors):
        for k in range(3):
            if adjacent[k] != -1:
                continue
            boundary_vertices.add(simplex[(k + 1) % 3])
            boundary_vertices.add(simplex[(k + 2) % 3])
    return boundary_vertices


def get_neumann_normal_directions(points):
    """Estimate an outward normal vector at every boundary vertex of a 2-D point set.

    The boundary comes from the Delaunay triangulation of ``points`` (via
    ``find_boundary``). At each boundary vertex, the two closest neighboring
    boundary vertices define two edges. The unit normal of each edge is
    oriented away from the centroid of ``points``, and the two normals are
    averaged with weights equal to the edge lengths.

    Args:
        points (array_like): Node coordinates, shape (N, 2). Integer
            coordinates are accepted and converted to float.

    Returns:
        list of ndarray: One length-2 normal per boundary vertex, in the
        iteration order of the set returned by ``find_boundary``. The weighted
        average of two unit vectors is generally not unit length.
    """
    # Work in float: with integer input the in-place normalisation below
    # (``n /= norm``) raised a casting error.
    points = np.asarray(points, dtype=float)

    # Create Delaunay triangulation
    tri = Delaunay(points)

    # Find points on the boundary
    bd_points = find_boundary(tri)

    # The indices of neighboring vertices of vertex k are indices[indptr[k]:indptr[k+1]].
    indptr, indices = tri.vertex_neighbor_vertices
    normal_directions = []
    centroid = np.mean(points, 0)
    for this_bd_pt in bd_points:
        v = points[this_bd_pt]
        # neighbors of this boundary point that are also on the boundary, nearest first
        idx = [i for i in indices[indptr[this_bd_pt]:indptr[this_bd_pt + 1]] if i in bd_points]
        idx.sort(key=lambda j: np.linalg.norm(v - points[j]))

        # Edges to the two closest boundary neighbors; normals are perpendicular to them.
        e1 = v - points[idx[0]]
        e2 = v - points[idx[1]]

        to_centroid = centroid - v  # used only to orient each normal outward

        unit_normals = []
        for e in (e1, e2):
            n = np.array([-e[1], e[0]])
            if np.dot(to_centroid, n) > 0:  # pointing "inward": flip it
                n = -n
            n /= np.linalg.norm(n)
            unit_normals.append(n)
        n1, n2 = unit_normals

        w1 = np.linalg.norm(e1)
        w2 = np.linalg.norm(e2)
        normal_directions.append((w1 * n1 + w2 * n2) / (w1 + w2))

    return normal_directions


def get_neumann_boundary_matrix(points, bd_points):
    """Build a sparse matrix expressing each boundary value through nearby interior values.

    Row ``i`` corresponds to the i-th element of ``bd_points`` in its
    iteration order. For that boundary point, the KD-tree neighbors are looked
    up and the boundary ones are discarded. A weight is then computed for each
    remaining interior neighbor from the estimated normal direction, and the
    weights of each row are normalised to sum to 1.

    Args:
        points (ndarray): Node coordinates, shape (N, 2) (2-D is assumed by
            ``get_neumann_normal_directions``).
        bd_points (set): Indices of boundary nodes; used with set difference
            (``set(...) - bd_points``), so it must be a ``set``.

    Returns:
        scipy.sparse.coo_matrix: Shape ``(len(bd_points), N)``.
    """
    k = 10  # neighborhood size for the KD-tree query
    tree = KDTree(points)
    # NOTE(review): scipy documents a list-valued ``k`` as the k-th neighbors
    # "starting from 1", but list(range(k)) is [0, ..., 9] and includes 0.
    # Confirm this is intended (range(1, k + 1) may have been meant).
    # ``dd`` (distances) is unused. If there are fewer than k points, scipy
    # marks missing neighbors with index N, which would fail when indexing
    # ``points`` below.
    dd, ii = tree.query(points[list(bd_points), :], k=list(range(k)))
    # NOTE(review): this list is ordered by the boundary set that
    # get_neumann_normal_directions computes itself (find_boundary on a fresh
    # Delaunay triangulation). Below it is indexed by position ``i`` in
    # ``bd_points``. The two sets and their iteration orders are not
    # guaranteed to match, so verify that each normal is paired with the
    # right point.
    normal_directions = get_neumann_normal_directions(points)

    i_idx = []
    j_idx = []
    data = []
    for i, this_bd_pt in enumerate(bd_points):
        int_pts_idx = list(set(ii[i]) - bd_points)  # remove boundary points
        rel_positions = points[int_pts_idx, :] - points[this_bd_pt, :]
        rel_pos_norms = np.asarray([np.linalg.norm(rel_positions[j]) for j in range(rel_positions.shape[0])])
        # Unit direction vectors from the boundary point to each interior neighbor.
        v_mat = rel_positions / rel_pos_norms[:, np.newaxis]  # \bbm{V}_i
        # Left pseudo-inverse (V^T V)^{-1} V^T. NOTE(review): V^T V is singular
        # when fewer than two non-collinear interior neighbors remain.
        v_tild_mat = np.matmul(np.linalg.inv(np.matmul(v_mat.T, v_mat)), v_mat.T)  # \tilde{\bbm{V}}_i
        rpe_tild_norm = np.matmul(normal_directions[i][np.newaxis, :], v_tild_mat)  # c_i
        wt_entries = rpe_tild_norm / rel_pos_norms
        # NOTE(review): with a single interior neighbor, np.squeeze gives a 0-d
        # array and .tolist() a float, so len(bd_pt_data) below would fail.
        bd_pt_data = np.squeeze(wt_entries / np.sum(wt_entries)).tolist()

        # The comprehension's ``k`` is local to it (Python 3); the outer k is untouched.
        i_idx += [i for k in range(len(bd_pt_data))]
        j_idx += int_pts_idx
        data += bd_pt_data

    sp_matrix = sparse.coo_matrix((data, (i_idx, j_idx)), shape=(len(bd_points), points.shape[0]))

    return sp_matrix


def concatenate_neumann_bd_constraint_matrices(matrix_list, expected_shape):
    """Horizontally stack scipy sparse matrices into one torch sparse COO tensor.

    Args:
        matrix_list (list): scipy sparse matrices with equal row counts, in any
            sparse format.
        expected_shape (sequence of int): Size of the returned tensor; it must
            contain every stacked entry.

    Returns:
        torch.Tensor: Sparse COO tensor with float32 values.
    """
    # format='coo' is required: with format=None scipy returns CSR/CSC for
    # all-CSR/all-CSC inputs, and those have no .row/.col attributes.
    a_coo = sparse.hstack(matrix_list, format='coo')
    # Build the (2, nnz) index array in numpy rather than from a list of arrays.
    indices = torch.from_numpy(np.vstack((a_coo.row, a_coo.col)).astype(np.int64))
    values = torch.from_numpy(a_coo.data.astype(np.float32))
    return torch.sparse_coo_tensor(indices, values, expected_shape)
