import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.utils import negative_sampling

from utils.others import visualize_tsne, visualize_hist, visualize_features


class PretrainModel(nn.Module):
    """Pre-training wrapper combining an encoder, a VQ module and two decoders.

    ``forward`` encodes an augmented graph, passes the node embeddings through
    ``vq`` and returns a dict of four losses: feature reconstruction, topology
    (link) reconstruction, and the two losses produced by ``vq``. The losses
    are returned unweighted and unsummed.

    Args:
        encoder: callable ``encoder(x, edge_index, edge_attr) -> h``.
        vq: callable ``vq(h, field) -> (query, (contrastive_loss, field_loss))``.
        feat_recon_decoder: callable applied to ``query[:bs]``; its output is
            compared with ``x[:bs]`` via MSE, so it must produce that shape.
        topo_recon_decoder: callable ``decoder(z, edge_index, sigmoid=True)``;
            assumed to return one probability in [0, 1] per edge — TODO confirm.
    """

    def __init__(self, encoder, vq, feat_recon_decoder, topo_recon_decoder):
        super().__init__()
        self.encoder = encoder
        self.vq = vq

        self.feat_recon_decoder = feat_recon_decoder
        self.topo_recon_decoder = topo_recon_decoder

    def save_encoder(self, path):
        """Save only the encoder's ``state_dict`` to ``path``."""
        torch.save(self.encoder.state_dict(), path)

    def save_vq(self, path):
        """Save only the VQ module's ``state_dict`` to ``path``."""
        torch.save(self.vq.state_dict(), path)

    def feat_recon_loss(self, h, x, bs):
        """Return the MSE between decoded ``h[:bs]`` and the target ``x[:bs]``.

        Only the first ``bs`` rows contribute. Presumably these are the seed
        nodes of a sampled mini-batch; verify against the data loader.
        """
        x_recon = self.feat_recon_decoder(h[:bs])
        return F.mse_loss(x_recon, x[:bs])

    def topo_recon_loss(self, z, pos_edge_index, neg_edge_index=None, ratio=1.0):
        """Binary link-reconstruction loss over positive and negative edges.

        Args:
            z: node representations fed to ``topo_recon_decoder``. ``z.size(0)``
                is used as the node count for negative sampling.
            pos_edge_index: observed edges, stored along dim 1
                (``pos_edge_index.size(1)`` is the edge count).
            neg_edge_index: optional explicit negative edges. When ``None``,
                negatives are drawn with ``negative_sampling`` against the FULL
                ``pos_edge_index``, so edges dropped by ``ratio`` are never
                used as negatives. One negative is drawn per kept positive.
            ratio: fraction of positive edges to keep. At least one edge is
                kept when ``ratio != 1.0``, and ``0.0`` disables the loss.

        Returns:
            Scalar ``-mean(log p_pos) - mean(log(1 - p_neg))``. A zero tensor
            is returned when ``ratio == 0.0`` or when there are no positive
            edges; an empty mean would otherwise be NaN and poison the total.
        """
        if ratio == 0.0:
            return z.new_zeros(())

        num_edges = pos_edge_index.size(1)
        if num_edges == 0:
            return z.new_zeros(())

        # Keep the complete edge set so negative sampling can exclude every
        # true edge, not just the subsampled ones.
        full_edge_index = pos_edge_index

        if ratio != 1.0:
            num_pos_edges = max(int(num_edges * ratio), 1)
            perm = torch.randperm(num_edges, device=pos_edge_index.device)
            pos_edge_index = pos_edge_index[:, perm[:num_pos_edges]]

        if neg_edge_index is None:
            neg_edge_index = negative_sampling(
                full_edge_index,
                num_nodes=z.size(0),
                num_neg_samples=pos_edge_index.size(1),
            )

        # 1e-15 keeps log() finite if the decoder saturates to exactly 0 or 1.
        pos_prob = self.topo_recon_decoder(z, pos_edge_index, sigmoid=True)
        neg_prob = self.topo_recon_decoder(z, neg_edge_index, sigmoid=True)
        pos_loss = -torch.log(pos_prob + 1e-15).mean()
        neg_loss = -torch.log(1 - neg_prob + 1e-15).mean()

        return pos_loss + neg_loss

    def forward(self, aug_g, g, topo_recon_ratio, bs):
        """Compute the pre-training losses for one (augmented) graph batch.

        Args:
            aug_g: ``(x, edge_index, edge_attr)`` of the augmented graph; this
                is what the encoder sees.
            g: ``(x, edge_index, edge_attr, field)`` of the original graph.
                Its features and edges are the reconstruction targets, and
                ``field`` is forwarded to ``vq``.
            topo_recon_ratio: fraction of the original edges used for the
                topology loss (see ``topo_recon_loss``).
            bs: number of leading nodes whose features are reconstructed.

        Returns:
            dict with keys ``'feat_recon_loss'``, ``'topo_recon_loss'``,
            ``'contrastive_loss'`` and ``'field_loss'``.
        """
        x, edge_index, edge_attr = aug_g[0], aug_g[1], aug_g[2]
        orig_x, orig_edge_index, field = g[0], g[1], g[3]

        h = self.encoder(x, edge_index, edge_attr)
        query, (contrastive_loss, field_loss) = self.vq(h, field)

        # Reconstruction targets come from the ORIGINAL (un-augmented) graph.
        # Feature loss uses only the first ``bs`` nodes; the topology loss uses
        # every node in ``query``.
        feat_recon_loss = self.feat_recon_loss(query, orig_x, bs)
        topo_recon_loss = self.topo_recon_loss(query, orig_edge_index, ratio=topo_recon_ratio)

        return {
            'feat_recon_loss': feat_recon_loss,
            'topo_recon_loss': topo_recon_loss,
            'contrastive_loss': contrastive_loss,
            'field_loss': field_loss,
        }
