# -*- coding: utf-8 -*-
# @File : high_layers.py
# @Author : 王军
# @Time : 2022/10/7 9:49
# @Software : PyCharm
import torch.nn.functional as F
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

class HypergraphContrast(nn.Module):
    """Hypergraph encoder with an optional infomax-style contrast head.

    Nodes are softly assigned to ``h_num_nodes`` learnable hyperedges via the
    product of two learned embedding matrices.  Features are propagated
    node -> hyperedge -> node, and the input is added back as a residual.
    When a negative sample is given, a ``Discriminator`` scores positive and
    negative encodings against an ``AvgReadout`` summary of the positive one.

    Args:
        num_nodes: Number of nodes; must equal the size of the ``s`` axis of
            the inputs (see ``encoder``).
        latdim: Feature size handed to the discriminator.  Presumably this
            must match the ``f`` axis of the inputs — confirm against callers.
        device: Device on which the learnable embeddings are created.
        h_num_nodes: Number of hyperedges (default 32).
        h_dim: Width of the node / hyperedge embeddings used to build the
            incidence matrix (default 16).
    """

    def __init__(self, num_nodes, latdim, device, h_num_nodes=32, h_dim=16):
        super(HypergraphContrast, self).__init__()
        self.num_nodes = num_nodes
        self.h_num_nodes = h_num_nodes
        self.h_dim = h_dim
        self.act1 = nn.LeakyReLU()
        # Move the raw tensor first, then wrap it.  Calling ``.to(device)`` on
        # an ``nn.Parameter`` that actually has to move returns a plain tensor,
        # which would silently no longer be registered as a module parameter.
        self.emb_nodes = nn.Parameter(torch.randn(num_nodes, self.h_dim).to(device))
        self.emb_h = nn.Parameter(torch.randn(self.h_dim, self.h_num_nodes).to(device))
        self.readout = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(latdim=latdim)

    def encoder(self, x):
        """Propagate ``x`` node -> hyperedge -> node and add a residual.

        ``x`` is indexed ``bfst`` in the einsums below, with ``s`` of size
        ``num_nodes``.  The result has the same shape as ``x``.
        """
        residual = x
        # Non-negative node-to-hyperedge affinities, shape (num_nodes, h_num_nodes).
        adj = F.relu(torch.einsum('nd,dh->nh', self.emb_nodes, self.emb_h))
        # adj_norm: each node's weights over hyperedges sum to 1.
        # tpadj_norm: each hyperedge's weights over nodes sum to 1.
        adj_norm = F.softmax(adj, dim=-1)
        tpadj_norm = F.softmax(adj.transpose(1, 0), dim=-1)
        hyper_x = self.act1(torch.einsum('sh,bfst->bfht', adj_norm, x))
        ret_x = self.act1(torch.einsum('hs,bfht->bfst', tpadj_norm, hyper_x))
        # Out-of-place add: do not mutate the activation output in place.
        return ret_x + residual

    def forward(self, pos_x, neg_x=None):
        """Encode ``pos_x`` and, if ``neg_x`` is given, discriminate the pair.

        Args:
            pos_x: Positive input, laid out as expected by ``encoder``.
            neg_x: Optional negative (corrupted) input of the same layout.

        Returns:
            ``(pos_, ret)`` where ``pos_`` is the encoded ``pos_x`` and
            ``ret`` is the discriminator output, or ``None`` when ``neg_x``
            is ``None``.
        """
        pos_ = self.encoder(pos_x)
        if neg_x is None:
            # The readout summary is only consumed by the discriminator.
            return pos_, None
        neg_ = self.encoder(neg_x)
        score = self.readout(pos_)
        ret = self.disc(score, pos_, neg_)
        return pos_, ret

# Hypergraph Infomax Discriminator
class Discriminator(nn.Module):
    """Bilinear discriminator scoring node embeddings against a summary vector.

    Args:
        latdim: Size of the feature axis of both the embeddings and the
            summary; used for both inputs of the ``nn.Bilinear`` scorer.
    """

    def __init__(self, latdim):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(latdim, latdim, 1)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        """Xavier-uniform the weight and zero the bias of an ``nn.Bilinear``."""
        if isinstance(m, nn.Bilinear):
            # nn.init functions already run under no_grad; no ``.data`` needed.
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    @staticmethod
    def _to_3d(h):
        """Drop trailing size-1 axes until ``h`` has at most three dims."""
        while h.dim() > 3 and h.size(-1) == 1:
            h = h.squeeze(-1)
        return h

    def forward(self, score, h_pos, h_neg):
        """Score positive and negative embeddings against ``score``.

        Args:
            score: Summary tensor, (batch, latdim).
            h_pos: Positive embeddings, (batch, latdim, nodes) or
                (batch, latdim, nodes, 1).
            h_neg: Negative embeddings, same shape as ``h_pos``.

        Returns:
            Logits of shape (batch, 2 * nodes): positive scores followed by
            negative scores along the last axis.
        """
        # Only strip trailing singleton axes.  A bare ``squeeze()`` would also
        # drop the batch axis when batch == 1 and break the transposes below.
        h_pos = self._to_3d(h_pos)
        h_neg = self._to_3d(h_neg)
        # Broadcast the per-sample summary across the node axis.
        score = torch.unsqueeze(score, -1).expand_as(h_pos)
        # Put features last, as nn.Bilinear contracts over the final axis.
        score = score.transpose(1, 2).contiguous()
        h_pos = h_pos.transpose(1, 2).contiguous()
        h_neg = h_neg.transpose(1, 2).contiguous()
        sc_pos = torch.squeeze(self.f_k(h_pos, score), -1)
        sc_neg = torch.squeeze(self.f_k(h_neg, score), -1)
        return torch.cat((sc_pos, sc_neg), dim=-1)
# Hypergraph Infomax AvgReadout
class AvgReadout(nn.Module):
    """Parameter-free readout that averages away the last two axes.

    For a 4-D input indexed ``(b, f, s, t)`` it averages over ``t`` and then
    over ``s``, returning a ``(b, f)`` tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, embeds):
        per_slot = embeds.mean(dim=-1)  # collapse the last axis first
        return per_slot.mean(dim=2)     # then the (new) third axis