import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F

class Encoder(nn.Module):
    """
    Encodes a batch of nodes using the 'convolutional' GraphSage approach.

    The aggregated neighbour representation (optionally concatenated with
    the nodes' own features) is passed through a learned affine map,
    optional dropout, an optional (leaky) ReLU, and is finally L2-normalised
    row-wise.

    features    -- callable mapping a 1-D LongTensor of node ids to a 2-D
                   feature tensor (only called when ``gcn`` is False)
    input_dim   -- width of one feature row; the weight has ``input_dim``
                   rows when ``gcn`` is True and ``2 * input_dim`` otherwise
    output_dim  -- width of the produced embedding
    adj_lists   -- mapping (indexable by int node id) to that node's
                   neighbours
    aggregator  -- callable ``aggregator(nodes, neighbours)`` returning the
                   aggregated neighbour features
                   (assumed 2-D, one row per node -- TODO confirm)
    device      -- device on which the parameters and the node-id tensor
                   are created
    base_model  -- optional object; stored as ``self.base_model`` only when
                   it is not None and not otherwise used by this class
    gcn         -- if True, use the aggregator output alone (no concat)
    relu        -- apply ``nn.ReLU``; takes precedence over ``leaky_relu``
    leaky_relu  -- apply ``nn.LeakyReLU(0.2)`` when ``relu`` is False
    dropout     -- apply ``nn.Dropout(0.5)`` before the activation
    """
    def __init__(self, features, input_dim,
            output_dim, adj_lists, aggregator, device,
            base_model=None, gcn=False,
            relu=False, leaky_relu=False,
            dropout=False):
        super(Encoder, self).__init__()

        self.features = features
        self.input_dim = input_dim
        self.adj_lists = adj_lists
        self.aggregator = aggregator
        # Identity check: ``!= None`` would invoke a custom ``__ne__``.
        if base_model is not None:
            self.base_model = base_model

        self.gcn = gcn
        self.output_dim = output_dim
        self.device = device

        # Without gcn the input is [self features | neighbour features].
        in_features = input_dim if self.gcn else 2 * input_dim
        self.w = nn.Parameter(
                torch.empty(in_features, output_dim, device=self.device))
        self.b = nn.Parameter(torch.zeros(output_dim, device=self.device))

        init.xavier_uniform_(self.w)

        # giving preference for relu over leaky relu
        if relu:
            self.relu = nn.ReLU()
        elif leaky_relu:
            self.relu = nn.LeakyReLU(0.2)
        else:
            self.relu = None

        self.dropout = nn.Dropout(0.5) if dropout else None

        # TODO: have a normalize option

    def forward(self, nodes):
        """
        Generates embeddings for a batch of nodes.

        nodes     -- iterable of node ids (list, tuple, array or 1-D
                     tensor); every element must be convertible with int()

        Returns the row-wise L2-normalised output of the layer.
        """
        nodes_list = [int(node) for node in nodes]
        neighbours = [self.adj_lists[node] for node in nodes_list]
        # Call the module itself (not ``.forward``) so registered hooks run.
        neigh_feats = self.aggregator(nodes_list, neighbours)

        if self.gcn:
            combined = neigh_feats
        else:
            # Build the index tensor from the normalised ids so that any
            # iterable input ends up as a LongTensor on ``self.device``.
            node_idx = torch.tensor(
                    nodes_list, dtype=torch.long, device=self.device)
            self_feats = self.features(node_idx)
            combined = torch.cat([self_feats, neigh_feats], dim=1)

        output = torch.mm(combined, self.w) + self.b

        if self.dropout is not None:
            output = self.dropout(output)

        if self.relu is not None:
            output = self.relu(output)

        return F.normalize(output)


class AttnEncoder(nn.Module):
    """
    Encodes a batch of nodes by delegating entirely to the aggregator.

    This encoder owns no weights of its own: the aggregator output is
    passed through optional dropout, an optional (leaky) ReLU, and is then
    L2-normalised row-wise.

    features    -- stored as ``self.features``; not used by ``forward``
    adj_lists   -- mapping (indexable by int node id) to that node's
                   neighbours
    aggregator  -- callable ``aggregator(nodes, neighbours)`` returning the
                   node representations
    device      -- stored as ``self.device``; not used by ``forward``
    base_model  -- optional object; stored as ``self.base_model`` only when
                   it is not None and not otherwise used by this class
    gcn         -- stored as ``self.gcn``; not used by ``forward``
    relu        -- apply ``nn.ReLU``; takes precedence over ``leaky_relu``
    leaky_relu  -- apply ``nn.LeakyReLU(0.2)`` when ``relu`` is False
    dropout     -- apply ``nn.Dropout(0.5)`` before the activation
    """
    def __init__(self, features, adj_lists, aggregator,
                 device,
            base_model=None, gcn=False,
            relu=False, leaky_relu=False,
            dropout=False):
        super(AttnEncoder, self).__init__()

        self.features = features
        self.adj_lists = adj_lists
        self.aggregator = aggregator
        # Identity check: ``!= None`` would invoke a custom ``__ne__``.
        if base_model is not None:
            self.base_model = base_model

        self.gcn = gcn
        self.device = device

        # giving preference for relu over leaky relu
        if relu:
            self.relu = nn.ReLU()
        elif leaky_relu:
            self.relu = nn.LeakyReLU(0.2)
        else:
            self.relu = None

        self.dropout = nn.Dropout(0.5) if dropout else None

    def forward(self, nodes):
        """
        Generates embeddings for a batch of nodes.

        nodes     -- iterable of node ids; every element must be
                     convertible with int()

        Returns the row-wise L2-normalised aggregator output.
        """
        nodes_list = [int(node) for node in nodes]
        neighbours = [self.adj_lists[node] for node in nodes_list]
        # Call the module itself (not ``.forward``) so registered hooks run.
        output = self.aggregator(nodes_list, neighbours)

        if self.dropout is not None:
            output = self.dropout(output)

        if self.relu is not None:
            output = self.relu(output)

        return F.normalize(output)


class RCGNEncoder(nn.Module):
    """
    Encodes a batch of nodes as ``W h_v^(l-1) + a_v^(l)``.

    When ``add_weight`` is True the nodes' own features are projected with
    a learned weight and added to the aggregator output; otherwise the
    aggregator output is used on its own. The result goes through optional
    dropout, an optional (leaky) ReLU, and row-wise L2 normalisation.

    features    -- callable mapping a 1-D LongTensor of node ids to a 2-D
                   feature tensor (only called when ``add_weight`` is True)
    input_dim   -- width of one feature row (rows of the weight)
    output_dim  -- columns of the weight; the aggregator output must be
                   addable to a ``(len(nodes), output_dim)`` tensor
    adj_lists   -- mapping (indexable by int node id) to that node's
                   neighbours
    aggregator  -- callable ``aggregator(nodes, neighbours)`` returning the
                   aggregated neighbour representation
    device      -- device on which the parameters and the node-id tensor
                   are created
    base_model  -- optional object; stored as ``self.base_model`` only when
                   it is not None and not otherwise used by this class
    gcn         -- accepted for signature parity with the other encoders;
                   ignored
    relu        -- apply ``nn.ReLU``; takes precedence over ``leaky_relu``
    leaky_relu  -- apply ``nn.LeakyReLU(0.2)`` when ``relu`` is False
    dropout     -- apply ``nn.Dropout(0.5)`` before the activation
    add_weight  -- create ``w``/``b`` and add the projected self features
    """
    def __init__(self, features, input_dim,
            output_dim, adj_lists, aggregator, device,
            base_model=None, gcn=False,
            relu=False, leaky_relu=False,
            dropout=False, add_weight=True):
        super(RCGNEncoder, self).__init__()

        self.features = features
        self.input_dim = input_dim
        self.adj_lists = adj_lists
        self.aggregator = aggregator
        # Identity check: ``!= None`` would invoke a custom ``__ne__``.
        if base_model is not None:
            self.base_model = base_model

        self.output_dim = output_dim
        self.device = device
        self.add_weight = add_weight
        if add_weight:
            self.w = nn.Parameter(
                    torch.empty(input_dim, output_dim, device=self.device))
            # NOTE(review): ``b`` is registered but never applied in
            # forward(); kept so existing state_dicts keep loading --
            # confirm whether the bias is meant to be added.
            self.b = nn.Parameter(torch.zeros(output_dim, device=self.device))

            init.xavier_uniform_(self.w)

        # giving preference for relu over leaky relu
        if relu:
            self.relu = nn.ReLU()
        elif leaky_relu:
            self.relu = nn.LeakyReLU(0.2)
        else:
            self.relu = None

        self.dropout = nn.Dropout(0.5) if dropout else None

        # TODO: have a normalize option

    def forward(self, nodes):
        """
        Generates embeddings for a batch of nodes.

        nodes     -- iterable of node ids (list, tuple, array or 1-D
                     tensor); every element must be convertible with int()

        Returns the row-wise L2-normalised output of the layer.
        """
        nodes_list = [int(node) for node in nodes]
        neighbours = [self.adj_lists[node] for node in nodes_list]
        # Call the module itself (not ``.forward``) so registered hooks run.
        neigh_feats = self.aggregator(nodes_list, neighbours)

        # Wh_{v}^(l-1) + a_{v}^{l}
        if self.add_weight:
            # Self features are only needed on this branch, so the lookup
            # is skipped entirely when add_weight is False. The index
            # tensor is built from the normalised ids so any iterable
            # input ends up as a LongTensor on ``self.device``.
            node_idx = torch.tensor(
                    nodes_list, dtype=torch.long, device=self.device)
            self_feats = self.features(node_idx)
            output = torch.mm(self_feats, self.w) + neigh_feats
        else:
            output = neigh_feats

        if self.dropout is not None:
            output = self.dropout(output)

        if self.relu is not None:
            output = self.relu(output)

        return F.normalize(output)


class DynamicEncoder(nn.Module):
    """
    GraphSage-style 'convolutional' encoder whose aggregator additionally
    receives an externally supplied representation (``inp_rep``).

    The aggregated neighbour representation (optionally concatenated with
    the nodes' own features) is passed through a learned affine map,
    optional dropout, an optional (leaky) ReLU, and is finally L2-normalised
    with ``F.normalize`` defaults.

    features    -- callable mapping a 1-D LongTensor of node ids to a
                   feature tensor (only called when ``gcn`` is False)
    input_dim   -- width of one feature row; the weight has ``input_dim``
                   rows when ``gcn`` is True and ``2 * input_dim`` otherwise
    output_dim  -- width of the produced embedding
    adj_lists   -- mapping (indexable by int node id) to that node's
                   neighbours
    aggregator  -- callable ``aggregator(nodes, neighbours, inp_rep)``
                   returning the aggregated neighbour features
    device      -- device on which the parameters and the node-id tensor
                   are created
    base_model  -- optional object; stored as ``self.base_model`` only when
                   it is not None and not otherwise used by this class
    gcn         -- if True, use the aggregator output alone (no concat)
    relu        -- apply ``nn.ReLU``; takes precedence over ``leaky_relu``
    leaky_relu  -- apply ``nn.LeakyReLU(0.2)`` when ``relu`` is False
    dropout     -- apply ``nn.Dropout(0.5)`` before the activation
    """
    def __init__(self, features, input_dim,
            output_dim, adj_lists, aggregator, device,
            base_model=None, gcn=False,
            relu=False, leaky_relu=False,
            dropout=False):
        super(DynamicEncoder, self).__init__()

        self.features = features
        self.input_dim = input_dim
        self.adj_lists = adj_lists
        self.aggregator = aggregator
        # Identity check: ``!= None`` would invoke a custom ``__ne__``.
        if base_model is not None:
            self.base_model = base_model

        self.gcn = gcn
        self.output_dim = output_dim
        self.device = device

        # Without gcn the input is [self features | neighbour features].
        in_features = input_dim if self.gcn else 2 * input_dim
        self.w = nn.Parameter(
                torch.empty(in_features, output_dim, device=self.device))
        self.b = nn.Parameter(torch.zeros(output_dim, device=self.device))

        init.xavier_uniform_(self.w)

        # giving preference for relu over leaky relu
        if relu:
            self.relu = nn.ReLU()
        elif leaky_relu:
            self.relu = nn.LeakyReLU(0.2)
        else:
            self.relu = None

        self.dropout = nn.Dropout(0.5) if dropout else None

        # TODO: have a normalize option

    def forward(self, nodes, inp_rep=None):
        """
        Generates embeddings for a batch of nodes.

        nodes     -- iterable of node ids (list, tuple, array or 1-D
                     tensor); every element must be convertible with int()
        inp_rep   -- passed through unchanged as the aggregator's third
                     argument (its meaning is defined by the aggregator)

        Returns the L2-normalised output of the layer.
        """
        nodes_list = [int(node) for node in nodes]
        neigh_nodes = [self.adj_lists[node] for node in nodes_list]
        neigh_feats = self.aggregator(nodes_list, neigh_nodes, inp_rep)

        if self.gcn:
            combined = neigh_feats
        else:
            # Build the index tensor from the normalised ids so that any
            # iterable input ends up as a LongTensor on ``self.device``.
            node_idx = torch.tensor(
                    nodes_list, dtype=torch.long, device=self.device)
            self_feats = self.features(node_idx)
            combined = torch.cat([self_feats, neigh_feats], dim=1)

        output = torch.matmul(combined, self.w) + self.b

        if self.dropout is not None:
            output = self.dropout(output)

        if self.relu is not None:
            output = self.relu(output)

        return F.normalize(output)

