import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn import GraphConv

from hadamard import hadamard_transform_cuda, hadamard_transform_torch

# Small constant added to the positive score before taking the log in the
# contrastive loss, so nodes whose positive score is 0 do not yield log(0).
sigma = 1.0e-6

def rff_transform(embedding, w):
    """Map embeddings to random Fourier features for the given frequencies.

    Args:
        embedding: 2-D tensor of shape (n, d).
        w: 2-D tensor of shape (d, D) holding the random frequencies.

    Returns:
        Tensor of shape (n, 2 * D): ``[cos(embedding @ w), sin(embedding @ w)]``
        scaled by ``1 / sqrt(D)``.
    """
    num_freq = w.size(1)
    proj = embedding.mm(w)
    features = torch.cat((proj.cos(), proj.sin()), dim=1)
    return np.sqrt(1 / num_freq) * features


def sorf_transform(embedding, out_dim, temp):
    """Map embeddings to Structured Orthogonal Random Features (SORF).

    Draws ``out_dim // in_dim`` independent blocks. Each block projects the
    input as ``D3 * H(D2 * H(D1 * H(x)))``, where H is the normalized Hadamard
    transform and Di are random +/-1 sign vectors. It then scales by
    ``sqrt(in_dim) / sqrt(temp)``. The cos/sin features of all stacked
    projections are returned.

    Args:
        embedding: 2-D tensor of shape (num_nodes, in_dim).
            NOTE(review): the Hadamard helpers presumably require ``in_dim`` to
            be a power of two; confirm against the ``hadamard`` module.
        out_dim: requested number of random frequencies. Only
            ``(out_dim // in_dim) * in_dim`` frequencies are actually drawn.
        temp: temperature; projections are divided by ``sqrt(temp)``.

    Returns:
        Tensor of shape (num_nodes, 2 * num_freq), where
        ``num_freq = (out_dim // in_dim) * in_dim``. It is scaled by
        ``1 / sqrt(num_freq)``, so the inner product of two rows is the mean
        cosine over the drawn frequencies.

    Raises:
        ValueError: if ``out_dim < in_dim`` (not even one block can be drawn).
    """
    in_dim = embedding.size(1)
    num_blocks = out_dim // in_dim
    if num_blocks < 1:
        raise ValueError(
            f"sorf_transform needs out_dim >= embedding dim ({in_dim}), "
            f"got out_dim={out_dim}"
        )
    tau = np.sqrt(temp)
    device = embedding.device

    # Comparing a torch.device with the string 'cpu' is always False, which
    # sent CPU tensors to the CUDA kernel. Compare the device *type* instead.
    if device.type == 'cpu':
        hadamard_transform = hadamard_transform_torch
    else:
        hadamard_transform = hadamard_transform_cuda

    # The innermost Hadamard transform does not depend on the random signs,
    # so compute it once rather than once per block.
    hx = hadamard_transform(embedding, normalize=True)

    blocks = []
    for _ in range(num_blocks):
        # Three random +/-1 sign vectors. They are drawn on the default (CPU)
        # generator and then moved, as before, so seeded runs are unchanged.
        signs = torch.randint(0, 2, (3, in_dim)).float().to(device)
        signs = 2 * signs - 1

        x = hadamard_transform(signs[0] * hx, normalize=True)
        x = hadamard_transform(signs[1] * x, normalize=True)
        blocks.append(signs[2] * x * np.sqrt(in_dim) / tau)
    proj = torch.cat(blocks, dim=1)

    # Normalize by the number of frequencies actually drawn. This equals
    # sqrt(out_dim) whenever out_dim is a multiple of in_dim.
    return torch.cat([torch.cos(proj), torch.sin(proj)], dim=1) / np.sqrt(proj.size(1))

class LogReg(nn.Module):
    """Linear probe: a single affine layer from ``hid_dim`` features to ``out_dim`` scores."""

    def __init__(self, hid_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(hid_dim, out_dim)

    def forward(self, x):
        """Return the unnormalized scores ``fc(x)``."""
        return self.fc(x)


class MLP(nn.Module):
    """Two-layer perceptron: Linear -> optional BatchNorm -> ReLU -> Linear.

    ``forward`` accepts and ignores a leading positional argument, so this
    encoder can be called as ``encoder(graph, x)`` just like the graph encoders.
    """

    def __init__(self, nfeat, nhid, nclass, use_bn=True):
        super().__init__()

        self.layer1 = nn.Linear(nfeat, nhid, bias=True)
        self.layer2 = nn.Linear(nhid, nclass, bias=True)

        # The norm layer is built regardless of use_bn, so the parameter layout
        # is the same either way; use_bn only controls whether it is applied.
        self.bn = nn.BatchNorm1d(nhid)
        self.use_bn = use_bn
        self.act_fn = nn.ReLU()

    def forward(self, _, x):
        hidden = self.layer1(x)
        hidden = self.bn(hidden) if self.use_bn else hidden
        return self.layer2(self.act_fn(hidden))

class GCN(nn.Module):
    """Stack of ``n_layers`` DGL ``GraphConv`` layers with ReLU between them.

    Args:
        in_dim: input feature size.
        hid_dim: feature size of the intermediate (hidden) layers.
        out_dim: feature size produced by the last layer.
        n_layers: number of GraphConv layers; must be >= 1. With a single
            layer, it maps ``in_dim`` directly to ``out_dim``.
        use_ln: if True, normalize each hidden layer's output before its ReLU.
            Despite the name, the normalization is ``BatchNorm1d``, not
            ``LayerNorm``.

    Raises:
        ValueError: if ``n_layers < 1``.
    """

    def __init__(self, in_dim, hid_dim, out_dim, n_layers, use_ln=False):
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"GCN needs n_layers >= 1, got {n_layers}")

        self.n_layers = n_layers
        self.convs = nn.ModuleList()
        self.use_ln = use_ln
        self.lns = nn.ModuleList()

        if n_layers == 1:
            # A lone layer must produce out_dim features. Previously it
            # produced hid_dim features and out_dim was silently ignored.
            self.convs.append(GraphConv(in_dim, out_dim, norm='both'))
        else:
            self.convs.append(GraphConv(in_dim, hid_dim, norm='both'))
            for _ in range(n_layers - 2):
                self.convs.append(GraphConv(hid_dim, hid_dim, norm='both'))
            # One norm per hidden (non-final) layer.
            for _ in range(n_layers - 1):
                self.lns.append(nn.BatchNorm1d(hid_dim))
            self.convs.append(GraphConv(hid_dim, out_dim, norm='both'))

    def forward(self, graph, x):
        """Apply all layers to node features ``x`` on ``graph``.

        The last layer has no activation and no normalization.
        """
        for i in range(self.n_layers - 1):
            x = self.convs[i](graph, x)
            if self.use_ln:
                x = self.lns[i](x)
            x = F.relu(x)

        return self.convs[-1](graph, x)

class Model(nn.Module):
    """Contrastive node-embedding model.

    Nodes are encoded with a GCN, or with an MLP when ``use_mlp`` is set. Each
    node's positive score is the mean of ``exp(cos_sim / temp)`` over its
    incoming edges. Its negative score is the sum of ``exp(cos_sim / temp)``
    over all nodes. The negative score is computed exactly
    (``loss='infonce'``) or approximated with random Fourier features
    (``'rff'``) or structured orthogonal random features (``'sorf'``).

    Args:
        in_dim, hid_dim, out_dim: encoder input / hidden / output sizes.
        rff_dim: number of random frequencies for the 'rff' / 'sorf' losses.
        num_layers: number of GCN layers (unused when ``use_mlp`` is True).
        temp: softmax temperature.
        loss: one of ``'infonce'``, ``'rff'``, ``'sorf'``.
        use_mlp: use the MLP encoder instead of the GCN.
    """

    def __init__(self, in_dim, hid_dim, out_dim, rff_dim, num_layers, temp, loss='sorf', use_mlp=False):
        super(Model, self).__init__()
        if use_mlp:
            self.encoder = MLP(in_dim, hid_dim, out_dim, use_bn=True)
        else:
            self.encoder = GCN(in_dim, hid_dim, out_dim, num_layers)

        self.temp = temp
        self.out_dim = out_dim
        self.rff_dim = rff_dim
        self.loss = loss

    def get_embedding(self, graph, feat):
        """Return the encoder output, detached, for evaluation.

        Runs under ``torch.no_grad`` so no autograd graph is built. The
        train/eval mode is not changed here; call ``eval()`` first if
        BatchNorm should use its running statistics.
        """
        with torch.no_grad():
            h = self.encoder(graph, feat)

        return h.detach()

    def pos_score(self, graph, h):
        """Per-node mean of ``exp(cos_sim / temp)`` over incoming edges.

        Self-loops are removed first, so a node is not its own positive.
        NOTE(review): DGL's builtin mean reducer presumably gives 0 to nodes
        left with no incoming edges; ``forward`` adds ``sigma`` to keep the log
        finite in that case.

        Returns:
            1-D tensor with one score per node.
        """
        # remove_self_loop returns a new graph, so the ndata/edata written
        # below do not leak into the caller's graph.
        graph = graph.remove_self_loop()
        graph.ndata['z'] = F.normalize(h, dim=-1)
        # The element-wise product of the endpoints' unit vectors, summed over
        # the feature axis below, is their cosine similarity.
        graph.apply_edges(fn.u_mul_v('z', 'z', 'sim'))

        graph.edata['sim'] = torch.exp((graph.edata['sim'].sum(1)) / self.temp)
        graph.update_all(fn.copy_e('sim', 'm'), fn.mean('m', 'pos'))

        return graph.ndata['pos']

    def neg_score(self, h, rff_dim=None):
        """Per-node sum over all nodes (including itself) of ``exp(cos_sim / temp)``.

        Args:
            h: node embeddings, one row per node.
            rff_dim: number of random frequencies for 'rff' / 'sorf'.
                Defaults to ``self.rff_dim``.

        Returns:
            1-D tensor with one score per node. For 'rff' / 'sorf' this is a
            random estimate that changes on every call.

        Raises:
            ValueError: if ``self.loss`` is not 'infonce', 'rff' or 'sorf'.
        """
        if rff_dim is None:
            rff_dim = self.rff_dim
        z = F.normalize(h, dim=-1)

        if self.loss == 'infonce':
            # Exact O(N^2) all-pairs similarity matrix.
            neg_sim = torch.exp(torch.mm(z, z.t().contiguous()) / self.temp)
            return neg_sim.sum(1)

        if self.loss == 'rff':
            # Frequencies ~ N(0, 1/temp). phi(a).phi(b) then approximates
            # exp(-||a-b||^2 / (2 temp)), which equals exp(a.b/temp - 1/temp)
            # for unit vectors; hence the exp(1/temp) factor below.
            w = torch.randn(z.size(1), rff_dim).to(z.device) / np.sqrt(self.temp)
            feats = rff_transform(z, w)
        elif self.loss == 'sorf':
            feats = sorf_transform(z, rff_dim, self.temp)
        else:
            raise ValueError(
                f"unknown loss {self.loss!r}; expected 'infonce', 'rff' or 'sorf'"
            )

        # sum_j phi(z_i).phi(z_j) == phi(z_i).(sum_j phi(z_j)): linear in N.
        neg_sum = torch.sum(feats, dim=0, keepdim=True)
        return np.exp(1 / self.temp) * (torch.sum(feats * neg_sum, dim=1))

    def forward(self, graph, feat):
        """Return the mean contrastive loss ``-log(pos / neg)`` over nodes."""
        h = self.encoder(graph, feat)

        pos_score = self.pos_score(graph, h)
        neg_score = self.neg_score(h, self.rff_dim)

        # sigma keeps the log finite when a node's positive score is 0.
        # NOTE(review): with 'rff' / 'sorf', neg_score is a Monte-Carlo
        # estimate and could in rare cases be <= 0, making this log NaN.
        loss = - torch.log((pos_score + sigma) / neg_score).mean()

        return loss
