# coding=utf-8
import numpy as np
import torch
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh
from utils import data_loader, sparse_mx_to_torch_sparse_tensor
from normalization import fetch_normalization, gcn, get_diag
import os
import pickle as pkl

# forked from https://github.com/DropEdge/DropEdge

class Sampler:
    """Sampling the input graph data.

    Loads a dataset through ``data_loader`` (with "NoNorm" so no adjacency
    normalization is applied at load time) and exposes several samplers that
    return a (normalized adjacency tensor, feature tensor) pair for training,
    plus accessors for the test/validation graphs and the labels/indexes.
    """

    def __init__(self, dataset, data_path="data", task_type="full", with_diag=False):
        # NOTE: ``with_diag`` is currently unused; kept for API compatibility.
        self.dataset = dataset
        self.data_path = data_path
        (self.adj,
         self.train_adj,
         self.features,
         self.train_features,
         self.labels,
         self.idx_train,
         self.idx_val,
         self.idx_test,
         self.degree,
         self.learning_type) = data_loader(dataset, data_path, "NoNorm", False, task_type)

        # Convert features to torch tensors once, up front.
        self.features = torch.FloatTensor(self.features).float()
        self.train_features = torch.FloatTensor(self.train_features).float()

        self.labels_torch = torch.LongTensor(self.labels)
        self.idx_train_torch = torch.LongTensor(self.idx_train)
        self.idx_val_torch = torch.LongTensor(self.idx_val)
        self.idx_test_torch = torch.LongTensor(self.idx_test)

        # vertex_sampler cache: positions (within idx_train) of label-1 and
        # label-0 training nodes. np.where returns a tuple, hence the [0].
        self.pos_train_idx = np.where(self.labels[self.idx_train] == 1)[0]
        self.neg_train_idx = np.where(self.labels[self.idx_train] == 0)[0]

        self.nfeat = self.features.shape[1]
        self.nclass = int(self.labels.max().item() + 1)
        # Normalized-adjacency caches keyed by normalization name only
        # (not by cuda flag or kwargs).
        self.trainadj_cache = {}
        self.adj_cache = {}
        # Lazily computed edge-sampling distribution for degree_sampler.
        self.degree_p = None

    def _preprocess_adj(self, normalization, adj, cuda, kwargs=None):
        """Normalize ``adj`` and convert it to a (dense or sparse) float tensor.

        :param normalization: name passed to ``fetch_normalization``.
        :param adj: adjacency matrix to normalize.
        :param cuda: move the result to the GPU when True.
        :param kwargs: extra options forwarded to ``fetch_normalization``.
        :return: a dense ``torch.Tensor`` if the normalizer returned an
            ``np.ndarray``, otherwise a torch sparse tensor.
        """
        adj_normalizer = fetch_normalization(normalization, kwargs)

        r_adj = adj_normalizer(adj)
        if isinstance(r_adj, np.ndarray):
            r_adj = torch.tensor(r_adj).float()
        else:
            r_adj = sparse_mx_to_torch_sparse_tensor(r_adj).float()
        if cuda:
            r_adj = r_adj.cuda()
        return r_adj

    def _preprocess_fea(self, fea, cuda):
        """Return ``fea`` on the GPU when ``cuda`` is True, else unchanged."""
        if cuda:
            return fea.cuda()
        return fea

    def MLP_sampler(self, normalization, cuda, kwargs):
        """Return the training graph with spectral features appended.

        Reads ``kwargs['nfreq']``. When it is 0, the plain training features
        are returned. Otherwise the eigendecomposition of ``gcn(self.adj)`` is
        loaded through ``get_diag``, or computed and pickled to
        ``<datapath>/<dataset>D.p``. A slice of the eigenvector columns is
        then concatenated to the training features.

        Keys read from ``kwargs``: 'nfreq', 'nfreq_inf', 'dataset',
        'datapath', 'MLP_norm'.
        """
        assert normalization == 'MLP', 'normalization should be MLP for MLP sampler'
        n = kwargs['nfreq']
        r_adj = self._preprocess_adj(normalization, self.train_adj, cuda, kwargs)
        if n == 0:
            # No spectral features requested: skip the expensive eigendecomposition.
            return r_adj, self._preprocess_fea(self.train_features, cuda)

        n_inf = kwargs['nfreq_inf']
        vp = get_diag(kwargs['dataset'], kwargs)
        if vp is None:
            print('must diagonalize')
            # k == N on a dense matrix: eigsh falls back to a full dense solve.
            vp = eigsh(gcn(self.adj).toarray(), self.adj.shape[0])
            cache_path = os.path.join(kwargs['datapath'], kwargs['dataset'] + 'D.p')
            with open(cache_path, 'wb+') as fh:
                pkl.dump(vp, fh)
            print('done')

        # vp is (eigenvalues, eigenvectors); columns of vp[1] are eigenvectors.
        # Assumes the get_diag cache has the same layout - TODO confirm.
        eigvecs = vp[1]
        if n_inf == 0:
            spectral = torch.FloatTensor(eigvecs[:, -n:]).float()
        else:
            spectral = torch.FloatTensor(eigvecs[:, -n:-n_inf]).float()
        if kwargs['MLP_norm']:
            spectral /= torch.sum(spectral, dim=1, keepdim=True)

        fea = torch.cat((self.train_features, spectral), 1)
        return r_adj, self._preprocess_fea(fea, cuda)

    def stub_sampler(self, normalization, cuda, kwargs):
        """Return the original training graph and features (no sampling).

        The normalized adjacency is cached per ``normalization`` name. The
        cache does not account for ``cuda`` or ``kwargs`` changes.
        """
        if normalization in self.trainadj_cache:
            r_adj = self.trainadj_cache[normalization]
        else:
            r_adj = self._preprocess_adj(normalization, self.train_adj, cuda, kwargs)
            self.trainadj_cache[normalization] = r_adj
        fea = self._preprocess_fea(self.train_features, cuda)
        return r_adj, fea

    def random_self_sampler(self, normalization, cuda, kwargs):
        """Normalize the training graph with a gamma drawn uniformly per call.

        Draws gamma from U(kwargs['gamma1'], kwargs['gamma2']) and stores it
        in ``kwargs['gamma']``. This mutates the caller's dict in place.
        """
        assert normalization == 'gcn_new_self', "normalization should be gcn_new_self to randomize over gamma"
        gamma1 = kwargs['gamma1']
        gamma2 = kwargs['gamma2']
        kwargs['gamma'] = np.random.uniform(gamma1, gamma2)

        r_adj = self._preprocess_adj(normalization, self.train_adj, cuda, kwargs)
        fea = self._preprocess_fea(self.train_features, cuda)
        return r_adj, fea

    def randomedge_sampler(self, percent, normalization, cuda, kwargs):
        """Randomly drop edges, preserving ``percent`` of the stored entries.

        ``self.train_adj`` is accessed through ``.row/.col/.data`` (COO form).
        With ``percent >= 1.0`` the unsampled graph is returned.
        """
        if percent >= 1.0:
            return self.stub_sampler(normalization, cuda, kwargs)

        nnz = self.train_adj.nnz
        keep = np.random.permutation(nnz)[:int(nnz * percent)]
        r_adj = sp.coo_matrix((self.train_adj.data[keep],
                               (self.train_adj.row[keep],
                                self.train_adj.col[keep])),
                              shape=self.train_adj.shape)
        r_adj = self._preprocess_adj(normalization, r_adj, cuda, kwargs)
        fea = self._preprocess_fea(self.train_features, cuda)
        return r_adj, fea

    def randomedge_2hop_sampler(self, percent, normalization, cuda, nhop=100, kwargs=None):
        """Add up to ``nhop`` random 2-hop edges, then keep ``percent`` of all edges.

        Each iteration picks a random existing entry (intermediary, start),
        then a random entry (final, intermediary), and adds the entry
        (start, final) with value 1. The augmentation is applied to local
        copies, so ``self.train_adj`` is not modified. Sampling draws from
        the augmented edge set.
        """
        if percent >= 1.0 and nhop == 0:
            return self.stub_sampler(normalization, cuda, kwargs)

        rows = self.train_adj.row
        cols = self.train_adj.col
        data = self.train_adj.data
        nnz = self.train_adj.nnz

        for _ in range(nhop):
            idx = np.random.randint(0, nnz)
            start_node = cols[idx]
            intermediary_node = rows[idx]
            possible_final_nodes = rows[cols == intermediary_node]
            if len(possible_final_nodes) > 0:
                final_node = possible_final_nodes[np.random.randint(len(possible_final_nodes))]
                # np.append returns new arrays; the stored matrix is untouched.
                rows = np.append(rows, start_node)
                cols = np.append(cols, final_node)
                data = np.append(data, 1)

        n_edges = len(data)
        keep = np.random.permutation(n_edges)[:int(n_edges * percent)]
        r_adj = sp.coo_matrix((data[keep], (rows[keep], cols[keep])),
                              shape=self.train_adj.shape)

        r_adj = self._preprocess_adj(normalization, r_adj, cuda, kwargs)
        fea = self._preprocess_fea(self.train_features, cuda)
        return r_adj, fea

    def vertex_sampler(self, percent, normalization, cuda, kwargs=None):
        """Randomly drop vertices and return the induced training subgraph.

        Keeps ``int(0.9 * percent * #pos)`` label-1 nodes and
        ``int(0.1 * percent * #neg)`` label-0 nodes.

        :return: (normalized adjacency, features, selected node indices).
            With ``percent >= 1.0`` the full training graph is returned,
            together with all node indices.
        """
        if percent >= 1.0:
            r_adj, fea = self.stub_sampler(normalization, cuda, kwargs)
            return r_adj, fea, np.arange(self.train_features.shape[0])

        # Side effect: later get_test_set calls take the full-graph (inductive) path.
        self.learning_type = "inductive"
        pos_count = len(self.pos_train_idx)
        neg_count = len(self.neg_train_idx)
        pos_perm = np.random.permutation(pos_count)
        neg_perm = np.random.permutation(neg_count)
        pos_keep = int(0.9 * percent * pos_count)
        neg_keep = int(0.1 * percent * neg_count)
        pos_samples = self.pos_train_idx[pos_perm[:pos_keep]]
        neg_samples = self.neg_train_idx[neg_perm[:neg_keep]]
        all_samples = np.concatenate((pos_samples, neg_samples))

        # NOTE(review): the samples are positions within idx_train but are used
        # as node indices of train_adj. This assumes train_adj is indexed by
        # training position - confirm.
        # COO matrices are not reliably subscriptable, so slice a CSR view.
        r_adj = self.train_adj.tocsr()[all_samples, :][:, all_samples]
        r_fea = self.train_features[all_samples, :]
        r_adj = self._preprocess_adj(normalization, r_adj, cuda, kwargs)
        r_fea = self._preprocess_fea(r_fea, cuda)
        return r_adj, r_fea, all_samples

    def degree_sampler(self, percent, normalization, cuda, kwargs=None):
        """Randomly drop edges according to degree.

        Edges are kept with probability proportional to
        ``train_adj.multiply(degree)``, so high-degree entries are less
        likely to be dropped. With ``percent >= 1.0`` the unsampled graph is
        returned.
        """
        if percent >= 1.0:
            return self.stub_sampler(normalization, cuda, kwargs)
        if self.degree_p is None:
            degree_adj = self.train_adj.multiply(self.degree)
            # NOTE(review): degree_p is indexed against train_adj.data below.
            # This assumes multiply() preserves the COO entry order and count - confirm.
            self.degree_p = degree_adj.data / (1.0 * np.sum(degree_adj.data))
        nnz = self.train_adj.nnz
        preserve_nnz = int(nnz * percent)
        keep = np.random.choice(nnz, preserve_nnz, replace=False, p=self.degree_p)
        r_adj = sp.coo_matrix((self.train_adj.data[keep],
                               (self.train_adj.row[keep],
                                self.train_adj.col[keep])),
                              shape=self.train_adj.shape)
        r_adj = self._preprocess_adj(normalization, r_adj, cuda, kwargs)
        fea = self._preprocess_fea(self.train_features, cuda)
        return r_adj, fea

    def get_test_set(self, normalization, cuda, kwargs=None):
        """Return the graph and features to evaluate on.

        - 'MLP' normalization delegates to ``MLP_sampler``.
        - A transductive task reuses the training graph.
        - Otherwise the full graph ``self.adj`` is used, cached per
          normalization name.

        ``kwargs`` defaults to a fresh empty dict for the MLP/stub paths
        (avoiding a shared mutable default). On the full-graph path it is
        forwarded as given.
        """
        opts = {} if kwargs is None else kwargs
        if normalization == 'MLP':
            return self.MLP_sampler(normalization, cuda, opts)

        if self.learning_type == "transductive":
            return self.stub_sampler(normalization, cuda, opts)

        if normalization in self.adj_cache:
            r_adj = self.adj_cache[normalization]
        else:
            r_adj = self._preprocess_adj(normalization, self.adj, cuda, kwargs)
            self.adj_cache[normalization] = r_adj
        fea = self._preprocess_fea(self.features, cuda)
        return r_adj, fea

    def get_val_set(self, normalization, cuda, kwargs=None):
        """
        Return the validation set. Only for the inductive task.
        Currently behaves the same as get_test_set.
        """
        return self.get_test_set(normalization, cuda, kwargs)

    def get_label_and_idxes(self, cuda):
        """Return (labels, idx_train, idx_val, idx_test) as LongTensors, on GPU if ``cuda``."""
        if cuda:
            return self.labels_torch.cuda(), self.idx_train_torch.cuda(), self.idx_val_torch.cuda(), self.idx_test_torch.cuda()
        return self.labels_torch, self.idx_train_torch, self.idx_val_torch, self.idx_test_torch
