import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import cal_similarity_graph, top_k, apply_non_linearity


class MLP(nn.Module):
    """Small feed-forward network: ``l1`` -> ReLU -> ``l3``.

    A third linear layer, ``l2`` (hidden -> hidden), is constructed and
    registered as a submodule but is not applied in ``forward``. It is kept
    so the module's parameters and ``state_dict`` keys stay the same.

    Args:
        in_channel: Size of each input feature vector.
        out_channel: Size of each output vector.
        hidden: Width of the hidden representation.
    """

    def __init__(self, in_channel, out_channel, hidden):
        super().__init__()
        # Layers are created in the order l1, l2, l3 so that random
        # initialisation draws from the RNG in a fixed sequence.
        self.l1 = nn.Linear(in_features=in_channel, out_features=hidden)
        self.l2 = nn.Linear(in_features=hidden, out_features=hidden)
        self.l3 = nn.Linear(in_features=hidden, out_features=out_channel)

    def forward(self, x):
        """Return ``l3(relu(l1(x)))``; ``l2`` is intentionally skipped."""
        activated = F.relu(self.l1(x))
        return self.l3(activated)


class GCNLayer(nn.Module):
    """Single graph-convolution step: linear projection, then ``A_hat @ (.)``.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each projected feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_features=in_channels, out_features=out_channels)

    def forward(self, x, A_hat):
        """Project ``x`` with ``self.lin`` and left-multiply by ``A_hat``.

        Args:
            x: Feature matrix fed to the linear layer.
            A_hat: Matrix multiplied with the projected features via
                ``torch.mm``. Presumably a (normalised) adjacency matrix;
                that is not enforced here - confirm against the caller.

        Returns:
            The result of ``torch.mm(A_hat, self.lin(x))``.
        """
        projected = self.lin(x)
        return torch.mm(A_hat, projected)


class GCN(nn.Module):
    """Two stacked ``GCNLayer`` modules with a ReLU in between.

    Args:
        in_channel: Size of each input feature vector.
        out_channel: Size of each output vector.
        hidden: Width of the intermediate representation.
    """

    def __init__(self, in_channel, out_channel, hidden):
        super().__init__()
        self.gcn1 = GCNLayer(in_channel, hidden)
        self.gcn2 = GCNLayer(hidden, out_channel)

    def forward(self, x, A_hat):
        """Run ``gcn1`` -> ReLU -> ``gcn2``, passing the same ``A_hat`` to both layers."""
        hidden_repr = F.relu(self.gcn1(x, A_hat))
        return self.gcn2(hidden_repr, A_hat)


class Diag(nn.Module):
    """Learnable per-feature scaling of the last dimension.

    Feature ``j`` of the input is multiplied by ``W[j]``. This is the same
    linear map as ``input @ torch.diag(W)``, but it is computed as a broadcast
    element-wise product, so no dense ``input_size x input_size`` matrix is
    built. Non-finite values also stay in their own column, whereas the matmul
    form turns ``inf * 0`` into ``NaN`` across the row.

    Args:
        input_size: Number of features; ``W`` is initialised to all ones, so
            a freshly constructed layer is the identity.
    """

    def __init__(self, input_size):
        super().__init__()
        self.W = nn.Parameter(torch.ones(input_size))
        self.input_size = input_size

    def forward(self, input):
        """Scale each feature of ``input`` by the matching entry of ``W``.

        Args:
            input: Tensor whose last dimension equals ``len(W)``. The name
                shadows the builtin but is kept for keyword-argument
                compatibility with existing callers.

        Returns:
            A tensor of the same shape as ``input``.

        Raises:
            RuntimeError: If the last dimension of ``input`` does not match
                ``W``. Without this check broadcasting would silently accept
                a trailing dimension of size 1, which the matrix product
                formulation rejects.
        """
        n_features = self.W.shape[0]
        if input.dim() == 0 or input.shape[-1] != n_features:
            raise RuntimeError(
                f"Diag expected last dimension {n_features}, "
                f"got input of shape {tuple(input.shape)}"
            )
        return input * self.W


class MLP_Diag(nn.Module):
    """Graph learner built from a stack of ``Diag`` layers.

    ``forward`` embeds the input features with the ``Diag`` stack,
    L2-normalises each row, and hands the result to the ``utils`` helpers
    ``cal_similarity_graph``, ``top_k`` and ``apply_non_linearity``.

    Args:
        nlayers: Number of ``Diag`` layers to stack.
        isize: Feature dimension passed to every ``Diag`` layer.
        k: ``top_k`` is called with ``k + 1``.
        knn_metric: Stored on the instance; not read anywhere in this class.
        non_linearity: Forwarded as the second argument of
            ``apply_non_linearity``.
        i: Forwarded as the third argument of ``apply_non_linearity``.
        mlp_act: Activation applied between layers: ``"relu"`` or ``"tanh"``.
            Any other value results in no activation being applied.
    """

    def __init__(self, nlayers, isize, k, knn_metric, non_linearity, i, mlp_act):
        super().__init__()

        self.i = i
        self.layers = nn.ModuleList([Diag(isize) for _ in range(nlayers)])
        self.k = k
        self.knn_metric = knn_metric
        self.non_linearity = non_linearity
        self.mlp_act = mlp_act

    def internal_forward(self, h):
        """Apply every layer in order, with ``mlp_act`` between layers.

        No activation follows the final layer. With zero layers ``h`` is
        returned unchanged.
        """
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            h = layer(h)
            if idx == last:
                continue
            if self.mlp_act == "relu":
                h = F.relu(h)
            elif self.mlp_act == "tanh":
                # torch.tanh replaces the legacy F.tanh alias, which warns
                # about deprecation on older PyTorch releases.
                h = torch.tanh(h)
        return h

    def forward(self, features):
        """Compute the learned similarity matrix for ``features``.

        Args:
            features: Input passed to the ``Diag`` stack; rows are normalised
                along ``dim=1``, so it is presumably (nodes, features) -
                confirm against the caller.

        Returns:
            Whatever ``apply_non_linearity`` returns for the processed
            similarities.
        """
        embeddings = self.internal_forward(features)
        embeddings = F.normalize(embeddings, dim=1, p=2)
        similarities = cal_similarity_graph(embeddings)
        # NOTE(review): presumably k + 1 keeps room for each row's
        # self-similarity - confirm against utils.top_k.
        similarities = top_k(similarities, self.k + 1)
        similarities = apply_non_linearity(similarities, self.non_linearity, self.i)
        return similarities


class Linear(nn.Module):
    """Graph learner built from a stack of square ``nn.Linear`` layers.

    ``forward`` embeds the input features with the linear stack,
    L2-normalises each row, and hands the result to the ``utils`` helpers
    ``cal_similarity_graph``, ``top_k`` and ``apply_non_linearity``.

    Args:
        nlayers: Number of ``nn.Linear(isize, isize)`` layers to stack.
        isize: Input and output feature dimension of every layer.
        k: ``top_k`` is called with ``k + 1``.
        knn_metric: Stored on the instance; not read anywhere in this class.
        non_linearity: Forwarded as the second argument of
            ``apply_non_linearity``.
        i: Forwarded as the third argument of ``apply_non_linearity``.
        mlp_act: Activation applied between layers: ``"relu"`` or ``"tanh"``.
            Any other value results in no activation being applied.
    """

    def __init__(self, nlayers, isize, k, knn_metric, non_linearity, i, mlp_act):
        super().__init__()

        self.i = i
        self.layers = nn.ModuleList([nn.Linear(isize, isize) for _ in range(nlayers)])
        self.k = k
        self.knn_metric = knn_metric
        self.non_linearity = non_linearity
        self.mlp_act = mlp_act

    def internal_forward(self, h):
        """Apply every layer in order, with ``mlp_act`` between layers.

        No activation follows the final layer. With zero layers ``h`` is
        returned unchanged.
        """
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            h = layer(h)
            if idx == last:
                continue
            if self.mlp_act == "relu":
                h = F.relu(h)
            elif self.mlp_act == "tanh":
                # torch.tanh replaces the legacy F.tanh alias, which warns
                # about deprecation on older PyTorch releases.
                h = torch.tanh(h)
        return h

    def forward(self, features):
        """Compute the learned similarity matrix for ``features``.

        Args:
            features: Input passed to the linear stack; rows are normalised
                along ``dim=1``, so it is presumably (nodes, features) -
                confirm against the caller.

        Returns:
            Whatever ``apply_non_linearity`` returns for the processed
            similarities.
        """
        embeddings = self.internal_forward(features)
        embeddings = F.normalize(embeddings, dim=1, p=2)
        similarities = cal_similarity_graph(embeddings)
        # NOTE(review): presumably k + 1 keeps room for each row's
        # self-similarity - confirm against utils.top_k.
        similarities = top_k(similarities, self.k + 1)
        similarities = apply_non_linearity(similarities, self.non_linearity, self.i)
        return similarities