import numpy as np
import torch
import torch.nn.functional as F
import scipy.io as sio
import os

def one_hot_encode(x, n_classes):
    """
    Turn integer class labels into one-hot rows.

    : x: label(s) used to index rows of an n_classes x n_classes identity matrix
    : n_classes: number of classes, i.e. the width of each one-hot row
    : return: Numpy float array whose row i is the one-hot vector for x[i]
    """
    identity = np.eye(n_classes)
    # Fancy-indexing the identity matrix picks out one basis row per label.
    return identity[x]


def load_network_data(name):
    """
    Load a graph dataset stored at ./data/<name>.mat.

    : name: dataset file stem
    : return: (A, X, Y) taken from the 'network', 'attrb' and 'group' entries of the
      .mat file. For 'cs' and 'photo', 'group' is flattened into a vector of integer
      labels and converted to an int32 one-hot matrix with max(label) + 1 columns.
    """
    mat_path = './data/' + name + '.mat'
    contents = sio.loadmat(mat_path)
    features = contents['attrb']
    adjacency = contents['network']
    labels = contents['group']
    if name in ('cs', 'photo'):
        # These datasets store plain class ids rather than a label matrix.
        flat = labels.flatten()
        labels = one_hot_encode(flat, flat.max() + 1).astype(np.int32)
    return adjacency, features, labels


def random_planetoid_splits(num_classes, y, train_num, seed, num_val=500, num_test=1000):
    """
    Randomly split nodes into a class-balanced training set plus validation/test sets.

    : num_classes: classes 0 .. num_classes - 1 are sampled
    : y: tensor of integer class labels, one per node (assumed 1-D)
    : train_num: training nodes taken per class (fewer if a class is smaller)
    : seed: makes the split reproducible; NumPy's global RNG is also seeded with it.
            ``None`` falls back to torch's global RNG (non-reproducible).
    : num_val: number of validation nodes drawn from the non-training nodes (default 500)
    : num_test: number of test nodes drawn after the validation nodes (default 1000)
    : return: (train_index, val_index, test_index) as LongTensors
    """
    # Kept for callers that may rely on this global side effect.
    np.random.seed(seed)
    # The permutations below come from torch, which np.random.seed does not affect.
    # A dedicated generator makes the split depend on `seed` without touching
    # torch's global RNG state.
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    indices = []
    for i in range(num_classes):
        index = (y == i).nonzero().view(-1)
        index = index[torch.randperm(index.size(0), generator=generator)]
        indices.append(index)

    train_index = torch.cat([i[:train_num] for i in indices], dim=0)

    rest_index = torch.cat([i[train_num:] for i in indices], dim=0)
    rest_index = rest_index[torch.randperm(rest_index.size(0), generator=generator)]

    val_index = rest_index[:num_val]
    test_index = rest_index[num_val:num_val + num_test]

    return train_index, val_index, test_index


def get_train_data(labels, tr_num, val_num, seed):
    """
    Draw a per-class random train/validation split; every remaining row becomes test.

    : labels: 2-D label matrix with one row per sample; a sample's class is its row
              argmax. A dense numpy array or a SciPy sparse matrix both work.
    : tr_num: training samples drawn per class (fewer if the class is smaller)
    : val_num: validation samples drawn per class from those left after training
    : seed: seed for NumPy's global RNG (np.random.seed is called, so global state changes)
    : return: (idx_train, idx_val, idx_test) as shuffled torch.LongTensor index vectors;
              the three sets are disjoint and together cover every row of labels.
    """
    np.random.seed(seed)
    # np.asarray(...).reshape(-1) flattens the (n, 1) np.matrix returned by a SciPy
    # sparse matrix's argmax(1); for a dense 2-D array it is a no-op.
    labels_vec = np.asarray(labels.argmax(1)).reshape(-1)
    labels_num = labels_vec.max() + 1

    idx_train = []
    idx_val = []
    for label_idx in range(labels_num):
        pos0 = np.flatnonzero(labels_vec == label_idx)
        pos0 = np.random.permutation(pos0)
        idx_train.append(pos0[:tr_num])
        idx_val.append(pos0[tr_num:tr_num + val_num])

    # np.concatenate copes with classes smaller than tr_num + val_num. The former
    # np.array(...).flatten() built a ragged array there (an error on recent NumPy).
    idx_train = np.concatenate(idx_train)
    idx_val = np.concatenate(idx_val)
    idx_test = np.setdiff1d(np.arange(labels.shape[0]), np.union1d(idx_train, idx_val))

    idx_train = torch.LongTensor(np.random.permutation(idx_train))
    idx_val = torch.LongTensor(np.random.permutation(idx_val))
    idx_test = torch.LongTensor(np.random.permutation(idx_test))

    return idx_train, idx_val, idx_test


def get_sub_emb(h, adj):
    """
    Aggregate value vectors with adjacency-masked, row-L2-normalized attention scores.

    : h: indexable holding (query, key, value) matrices at positions 0, 1, 2
         (assumed 2-D with one row per node — TODO confirm against callers)
    : adj: dense (n, n) adjacency matrix; self-loops are added internally
    : return: a @ v, where a is the masked, row-normalized attention matrix
    """
    q, k, v = h[0], h[1], h[2]

    # Self-loops are built on adj's own device. The previous hard-coded .cuda(0)
    # failed on CPU and on any other GPU. The default dtype keeps the original
    # type promotion.
    adj = adj + torch.eye(adj.shape[0], device=adj.device)

    # k_m[j] = sum_i adj[i, j] * k[i] / rowsum(adj)[j]. For a symmetric adj this is
    # the mean key over j's neighbourhood, including j itself via the self-loop.
    k_m = torch.mm((adj / adj.sum(1)).T, k)

    # Scores are kept only where adj is non-zero (weighted by the edge value).
    a = torch.mm(q, k_m.T).mul(adj)

    # L2-normalize each row of scores.
    a = F.normalize(a, dim=1)

    z = torch.mm(a, v)

    return z


def get_multi_sub_emb(h, adj):
    """
    Two-direction, adjacency-masked scaled dot-product aggregation.

    : h: indexable holding (query, key, value) matrices at positions 0, 1, 2
         (assumed 2-D with one row per node — TODO confirm against callers)
    : adj: dense (n, n) adjacency matrix; self-loops are added internally
    : return: tensor stacking two results along a new leading dim of size 2:
              [0] uses q against neighbourhood-aggregated keys,
              [1] uses k against neighbourhood-aggregated queries.
    """
    q, k, v = h[0], h[1], h[2]

    # Self-loops are built on adj's own device. The previous hard-coded .cuda(0)
    # failed on CPU and on any other GPU.
    adj = adj + torch.eye(adj.shape[0], device=adj.device)

    # agg[j, i] = adj[i, j] / rowsum(adj)[j]. It is computed once and shared by both
    # directions; for a symmetric adj, agg @ m averages m over each closed neighbourhood.
    agg = (adj / adj.sum(1)).T
    k_m = torch.mm(agg, k)
    q_m = torch.mm(agg, q)

    # Plain float scale: torch.sqrt on the integer tensor torch.tensor(d) fails on
    # older torch versions.
    scale = q.shape[1] ** 0.5

    a1 = torch.mm(q, k_m.T) / scale
    a1 = a1.mul(adj)
    z1 = torch.mm(a1, v)

    a2 = torch.mm(k, q_m.T) / scale
    a2 = a2.mul(adj)
    z2 = torch.mm(a2, v)

    return torch.stack((z1, z2))
