import math
import logging
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.nn.pytorch import GATConv, SAGEConv, APPNPConv, GINConv
from dgl import function as fn
from dgl.nn.pytorch.softmax import edge_softmax
from dgl.nn.pytorch.utils import Identity
from dgl.utils import expand_as_pair
from torch.autograd import Variable
import time


def sparse_matrix_to_torch(X):
    """Convert a SciPy sparse matrix into a PyTorch sparse COO tensor.

    Args:
        X: any SciPy sparse matrix/array (anything exposing ``tocoo()``).

    Returns:
        torch.Tensor: sparse COO tensor with the same shape as ``X``, int64
        indices and float32 values (values are cast from ``X``'s dtype).
    """
    coo = X.tocoo()
    # torch's sparse COO layout requires int64 indices; scipy often uses int32.
    # ``astype`` copies, so the tensor never aliases X's buffers.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data.astype(np.float32))
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
    return torch.sparse_coo_tensor(indices, values, coo.shape)

class GNN_Models(object):
    """Full-batch node-classification trainer for the GNN models in this file.

    Builds one of ``GCN_model`` ('gcn'), ``NoG_model`` ('nog') or
    ``MGCN_model`` ('mgcn') from ``model_name``, preprocesses the inputs into
    a DGL graph plus torch tensors, and trains with Adam and validation-based
    model selection via :meth:`fit`.
    """

    def __init__(self, adj_matrix, features, labels, tvt_nids, cuda=0,
                 hidden_size=64, n_layers=1, epochs=200, seed=-1, lr=1e-2,
                 weight_decay=5e-4, dropout=0.5, log=True, name='debug',
                 save_path='test_result/', activation='relu', model_name='gcn',
                 save_model=0, feat_normalize=1, balance=0):
        """Store hyper-parameters, preprocess the data and build the model.

        Args:
            adj_matrix: SciPy sparse adjacency matrix.
            features: node features (SciPy sparse, torch.FloatTensor or array-like).
            labels: 1-D node labels (np.ndarray or torch tensor).
            tvt_nids: (train, val, test) node-index collections.
            cuda: GPU index; negative (or no CUDA available) selects the CPU.
            seed: RNG seed; only values > 0 are applied (0 and -1 leave the RNGs unseeded).
            activation: 'relu' | 'tanh' | 'sigmoid' | 'identity'; anything else -> ReLU.
            model_name: 'gcn' | 'nog' | 'mgcn'.
            balance: MGCN mixing weight for the GCN branch (1 - balance for the MLP branch).
        """
        self.lr = lr
        self.weight_decay = weight_decay
        self.n_epochs = epochs
        self.save_path = save_path
        self.model_name = model_name
        self.feat_normalize = feat_normalize
        if log:
            self.logger = self.get_logger(name)
        else:
            # Falls back to the root logger; with default logging config its
            # WARNING level hides the info records below.
            self.logger = logging.getLogger()
        # force the device to CPU when CUDA is not available
        if not torch.cuda.is_available():
            cuda = -1
        self.device = torch.device(f'cuda:{cuda}' if cuda >= 0 else 'cpu')
        # Snapshot of the constructor arguments for the log. It must be taken
        # before any further locals are bound, or they would be logged too.
        all_vars = locals()
        self.log_parameters(all_vars)
        self.save_model = save_model
        if seed > 0:
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        activation_factories = {
            'relu': nn.ReLU,
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
            'identity': nn.Identity,
        }
        act = activation_factories.get(activation, nn.ReLU)()

        self.load_data(adj_matrix, features, labels, tvt_nids)

        if model_name == 'gcn':
            self.model = GCN_model(self.features.size(1),
                                   hidden_size,
                                   self.n_class,
                                   n_layers,
                                   act,
                                   dropout)
        elif model_name == 'nog':
            # NOTE(review): 'nog' always uses F.relu and ignores `activation`;
            # kept as-is to preserve existing results.
            self.model = NoG_model(self.features.size(1),
                                   hidden_size,
                                   self.n_class,
                                   n_layers,
                                   F.relu,
                                   dropout)
        elif model_name == 'mgcn':
            self.model = MGCN_model(self.features.size(1),
                                    hidden_size,
                                    self.n_class,
                                    n_layers,
                                    act,
                                    dropout,
                                    balance)
        else:
            # No model is built here; fit() will fail unless a subclass sets self.model.
            self.logger.warning(f'Unknown model_name {model_name!r}; no model was built')

    def load_data(self, adj, features, labels, tvt_nids):
        """Preprocess inputs into tensors and a bidirected DGL graph with self-loops.

        Sets ``self.features``, ``self.labels``, ``self.train_nid``,
        ``self.val_nid``, ``self.test_nid``, ``self.n_class``, ``self.adj``
        (CSR with unit diagonal), ``self.G`` (on ``self.device``) and
        ``self.G.ndata['norm']`` (in-degree ** -0.5, shape (N, 1)).
        The caller's ``adj`` is not modified.
        """
        # features -> torch.FloatTensor
        if sp.issparse(features):
            features = torch.FloatTensor(features.toarray())
        if isinstance(features, torch.FloatTensor):
            self.features = features
        else:
            self.features = torch.FloatTensor(features)
        if self.feat_normalize:
            # row-wise L1 normalisation
            self.features = F.normalize(self.features, p=1, dim=1)

        # labels (torch.LongTensor) and train/validation/test node ids
        if isinstance(labels, np.ndarray):
            labels = torch.LongTensor(labels)
        self.labels = labels
        assert len(labels.size()) == 1
        self.train_nid = tvt_nids[0]
        self.val_nid = tvt_nids[1]
        self.test_nid = tvt_nids[2]
        # NOTE(review): counts distinct labels, so it assumes labels are 0..C-1
        # with every class present.
        self.n_class = len(torch.unique(self.labels))

        assert sp.issparse(adj)
        # Copy before setdiag: previously a coo_matrix input was edited in place.
        adj = sp.coo_matrix(adj, copy=True)
        adj.setdiag(1)
        adj = sp.csr_matrix(adj)
        self.adj = adj
        # dgl.from_scipy is the documented replacement for DGLGraph(data)
        self.G = dgl.from_scipy(self.adj)
        self.G = dgl.to_bidirected(self.G).to(self.device)

        # symmetric normalisation factor D^{-1/2}; zero-degree guard is defensive
        degs = self.G.in_degrees().float()
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 0
        norm = norm.to(self.device)
        self.G.ndata['norm'] = norm.unsqueeze(1)

    def fit(self):
        """Train ``self.model`` full-batch and return the early-stopped test accuracy.

        Each epoch takes one optimiser step on the training nodes, then
        re-runs the model in eval mode (no dropout) to score the validation
        set. When validation accuracy improves, test accuracy is recorded and,
        if ``save_model`` is set, the whole model is saved.

        Returns:
            float: test accuracy at the epoch with the best validation accuracy
            (0.0 if validation accuracy never rose above 0).
        """
        features = self.features.to(self.device)
        labels = self.labels.to(self.device)
        model = self.model.to(self.device)
        if self.model_name == 'appnpx':
            # NOTE(review): __init__ builds no 'appnpx' model; this branch presumably
            # serves subclasses whose model exposes `mlp_layers`.
            optimizer = torch.optim.Adam([{"params": model.mlp_layers[0].parameters(),
                                           'lr': self.lr,
                                           'weight_decay': self.weight_decay},
                                          {'params': model.mlp_layers[1].parameters(),
                                           'lr': self.lr}])
        else:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=self.lr,
                                         weight_decay=self.weight_decay)

        best_val_acc = 0.
        # Initialised so the return/log below cannot hit an unbound name when
        # validation accuracy never improves (or epochs == 0).
        test_acc = 0.

        for epoch in range(self.n_epochs):
            model.train()
            logits = model(self.G, features)
            loss = F.nll_loss(logits[self.train_nid], labels[self.train_nid])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # validate with the updated weights and dropout disabled
            model.eval()
            with torch.no_grad():
                logits_eval = model(self.G, features)
                # Use eval logits; previously this used the dropout-perturbed,
                # pre-step training logits.
                val_loss = F.nll_loss(logits_eval[self.val_nid], labels[self.val_nid])
            val_acc = self.eval_node_cls(logits_eval[self.val_nid], labels[self.val_nid])

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                test_acc = self.eval_node_cls(logits_eval[self.test_nid], labels[self.test_nid])
                self.logger.info('Epoch [{:3}/{}]: loss {:.4f}, val acc {:.4f}, test acc {:.4f}'
                                 .format(epoch+1, self.n_epochs, val_loss.item(), val_acc, test_acc))
                if self.save_model:
                    torch.save(model, self.save_path+'best_model.pt')
            else:
                self.logger.info('Epoch [{:3}/{}]: loss {:.4f}, val acc {:.4f}'
                                 .format(epoch+1, self.n_epochs, val_loss.item(), val_acc))

        if self.save_model:
            torch.save(model, self.save_path+'last_model.pt')
        # final test result without early stopping (explicitly in eval mode)
        model.eval()
        with torch.no_grad():
            logits_eval = model(self.G, features)
        test_acc_final = self.eval_node_cls(logits_eval[self.test_nid], labels[self.test_nid])
        self.logger.info('Final test acc with early stop: {:.4f}, without early stop: {:.4f}'
                         .format(test_acc, test_acc_final))
        return test_acc

    def log_parameters(self, all_vars):
        """Log the constructor arguments, skipping ``self`` and the bulky data inputs.

        ``all_vars`` is left unmodified.
        """
        excluded = {'self', 'adj_matrix', 'features', 'labels', 'tvt_nids'}
        params = {key: value for key, value in all_vars.items() if key not in excluded}
        self.logger.info(f'Parameters: {params}')

    @staticmethod
    def eval_node_cls(nc_logits, labels):
        """Return the fraction of rows whose argmax over dim 1 equals ``labels``."""
        preds = torch.argmax(nc_logits, dim=1)
        correct = torch.sum(preds == labels)
        acc = correct.item() / len(labels)
        return acc

    @staticmethod
    def get_logger(name):
        """Return logger ``name`` with a fresh console handler and, when
        ``name`` is not None, a file handler writing to ``t-{name}.log``."""
        logger = logging.getLogger(name)
        # Drop handlers from earlier runs, closing them so file descriptors of
        # old FileHandlers are released (clear() alone leaked them).
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()
        logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s - %(message)s')
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        if name is not None:
            fh = logging.FileHandler(f't-{name}.log')
            fh.setFormatter(formatter)
            fh.setLevel(logging.DEBUG)
            logger.addHandler(fh)
        return logger







class GCN_model(nn.Module):
    """Stack of ``GCNLayer``s for node classification.

    Layout: an input layer (in_feats -> n_hidden, no dropout), ``n_hop - 1``
    hidden layers (n_hidden -> n_hidden), and an output layer
    (n_hidden -> n_classes, no activation). Hidden and output layers apply
    dropout to their inputs.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_hop,
                 activation,
                 dropout):
        super(GCN_model, self).__init__()
        # build in input -> hidden -> output order (fixes parameter init order)
        stack = [GCNLayer(in_feats, n_hidden, activation, 0.)]
        stack.extend(GCNLayer(n_hidden, n_hidden, activation, dropout)
                     for _ in range(n_hop - 1))
        stack.append(GCNLayer(n_hidden, n_classes, None, dropout))
        self.layers = nn.ModuleList(stack)

    def _propagate(self, g, features):
        """Feed ``features`` through every layer in order; return the raw output."""
        out = features
        for gcn_layer in self.layers:
            out = gcn_layer(g, out)
        return out

    def get_embs(self, g, features):
        """Return the final-layer outputs, L2-normalised per row."""
        return F.normalize(self._propagate(g, features), p=2, dim=1)

    def forward(self, g, features):
        """Return per-node log-probabilities over classes."""
        return F.log_softmax(self._propagate(g, features), dim=1)









class GCNLayer(nn.Module):
    """One graph-convolution layer with symmetric degree normalisation.

    Computes ``act(norm * sum_{u -> v} (norm_u * dropout(h_u) W) + b)``.
    ``norm`` is read from ``g.ndata['norm']``, which GNN_Models.load_data
    fills with in-degree ** -0.5 as an (N, 1) tensor.

    Args:
        in_feats / out_feats: input and output feature sizes.
        activation: callable applied to the output, or a falsy value for none.
        dropout: dropout rate on the layer input; a falsy rate disables it.
        bias: whether to add a learnable bias.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 activation,
                 dropout,
                 bias=True):
        super(GCNLayer, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.bias = None
        self.activation = activation
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias uniformly in +-1/sqrt(out_feats)."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, g, h):
        """Apply the layer to node features ``h`` on graph ``g``."""
        if self.dropout:
            h = self.dropout(h)
        h = torch.mm(h, self.weight)
        # local_scope discards the temporary 'h' feature on exit. Previously,
        # writing and popping ndata['h'] destroyed any existing 'h' on g.
        with g.local_scope():
            # normalise by sqrt(src degree), sum over in-edges, then by sqrt(dst degree)
            g.ndata['h'] = h * g.ndata['norm']
            # copy_u replaces the deprecated copy_src
            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            h = g.ndata['h'] * g.ndata['norm']
        if self.bias is not None:
            h = h + self.bias
        if self.activation:
            h = self.activation(h)
        return h






class NoG_model(nn.Module):
    """Graph-free MLP baseline with the same call signature as the GNN models.

    ``n_layers`` ``NoGLayer``s (the first maps in_feats -> n_hidden, the rest
    n_hidden -> n_hidden), each followed by ``activation``. Then dropout and a
    final ``nn.Linear`` (n_hidden -> n_classes). The graph argument is only
    forwarded to the NoGLayers, which ignore it.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super(NoG_model, self).__init__()
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.n_layers = n_layers
        self.layers = nn.ModuleList()
        widths_in = [in_feats] + [n_hidden] * (n_layers - 1)
        rates = [0.] + [dropout] * (n_layers - 1)
        # list multiplication by a non-positive count yields [], matching range()
        for width, rate in zip(widths_in, rates):
            self.layers.append(NoGLayer(width, n_hidden, None, rate))
        self.layers.append(nn.Linear(n_hidden, n_classes))

    def forward(self, g, features):
        """Return per-node log-probabilities over classes."""
        hidden = features
        for depth in range(self.n_layers):
            hidden = self.activation(self.layers[depth](g, hidden))
        classifier = self.layers[-1]
        return F.log_softmax(classifier(self.dropout(hidden)), dim=1)


class NoGLayer(nn.Module):
    """Dense, graph-free layer: ``act(dropout(h) @ W + b)``.

    Takes the same constructor and ``forward(g, h)`` arguments as
    ``GCNLayer`` so it can be swapped in, but never reads ``g``.

    Args:
        in_feats / out_feats: input and output feature sizes.
        activation: callable applied to the output, or a falsy value for none.
        dropout: dropout rate on the layer input; a falsy rate disables it.
        bias: whether to add a learnable bias.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 activation,
                 dropout,
                 bias=True):
        super(NoGLayer, self).__init__()
        self.weight = nn.Parameter(torch.empty(in_feats, out_feats))
        self.bias = nn.Parameter(torch.empty(out_feats)) if bias else None
        self.activation = activation
        # 0. serves as the "no dropout" sentinel checked in forward()
        self.dropout = nn.Dropout(p=dropout) if dropout else 0.
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias uniformly from +-1/sqrt(out_feats)."""
        bound = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, g, h):
        """Transform node features ``h``; ``g`` is unused."""
        out = self.dropout(h) if self.dropout else h
        out = out.mm(self.weight)
        if self.bias is not None:
            out = out + self.bias
        return self.activation(out) if self.activation else out















class MGCN_model(nn.Module):
    """GCN stack blended with a two-layer MLP branch.

    The output logits are ``balance * gcn(x) + (1 - balance) * mlp(x)``, where
    the GCN branch has the same layout as ``GCN_model``. The MLP branch is
    ``lin2(dropout(activation(lin1(x))))``.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_hop,
                 activation,
                 dropout,
                 balance):
        super(MGCN_model, self).__init__()
        # GCN branch first, then MLP branch (keeps parameter init order)
        gcn_stack = [GCNLayer(in_feats, n_hidden, activation, 0.)]
        gcn_stack.extend(GCNLayer(n_hidden, n_hidden, activation, dropout)
                         for _ in range(n_hop - 1))
        gcn_stack.append(GCNLayer(n_hidden, n_classes, None, dropout))
        self.layers = nn.ModuleList(gcn_stack)
        self.lin1 = nn.Linear(in_feats, n_hidden)
        self.lin2 = nn.Linear(n_hidden, n_classes)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.balance = balance

    def _propagate(self, g, features):
        """Run the GCN branch only; return its raw output."""
        out = features
        for gcn_layer in self.layers:
            out = gcn_layer(g, out)
        return out

    def get_embs(self, g, features):
        """Return GCN-branch outputs, L2-normalised per row (MLP branch unused)."""
        return F.normalize(self._propagate(g, features), p=2, dim=1)

    def forward(self, g, features):
        """Return per-node log-probabilities of the blended logits."""
        # MLP branch is evaluated before the GCN branch, as in the original
        mlp_out = self.lin2(self.dropout(self.activation(self.lin1(features))))
        gcn_out = self._propagate(g, features)
        blended = self.balance * gcn_out + (1 - self.balance) * mlp_out
        return F.log_softmax(blended, dim=1)


