import torch
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU, Parameter
import torch.nn.functional as F
from torch_geometric.nn.inits import reset
from typing import List, Optional, Set, Callable, get_type_hints
from torch_geometric.utils import (negative_sampling, remove_self_loops,
                                   add_self_loops)
from torch_geometric.typing import Adj, Size
from torch import Tensor

from torch.nn import Sequential, Linear, LeakyReLU
from torch_geometric.nn import GINConv, global_add_pool, global_mean_pool, GCNConv
import torch.nn as nn
from model.encoder import *


class GAE(torch.nn.Module):
    r"""Auto-encoder wrapper that pairs an encoder with a decoder.

    The decoder is called as ``decoder(z, edge_index)`` and is expected to
    return a pair ``(structure_prediction, attribute_prediction)``.

    Args:
        encoder: Module that produces the latent representation.
        decoder: Module that decodes latents; may be ``None`` if only
            :meth:`encode` is used.
        r: Initial value of the learnable scalar parameter ``self.r``.
    """

    def __init__(self, encoder, decoder=None, r=10.0):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Allocate the learnable scalar; reset_parameters() writes its value.
        self.r = Parameter(torch.empty(1))
        GAE.reset_parameters(self, r)

    def reset_parameters(self, r):
        r"""Sets ``self.r`` to ``r`` and resets encoder and decoder.

        Args:
            r: Value written into the scalar parameter ``self.r``.
        """
        # Fill in place without recording the write in autograd. Previously
        # `r` was ignored and the parameter kept uninitialized memory.
        with torch.no_grad():
            self.r.fill_(float(r))
        reset(self.encoder)
        reset(self.decoder)

    def encode(self, *args, **kwargs):
        r"""Runs the encoder and computes node-wise latent variables."""
        return self.encoder(*args, **kwargs)

    def decode(self, *args, **kwargs):
        r"""Runs the decoder and computes edge probabilities."""
        return self.decoder(*args, **kwargs)

    def _struct_l1(self, z, adj, edge_index):
        r"""Decodes ``edge_index`` and scores it against ``adj`` with L1.

        Returns:
            ``(loss, pred_attr)``: the mean absolute error between the
            flattened structure prediction and ``adj`` gathered at the given
            edges, plus the decoder's attribute prediction.
        """
        pred_struct, pred_attr = self.decoder(z, edge_index)
        gt_struct = adj[edge_index[0], edge_index[1]]
        return F.l1_loss(pred_struct.view(-1), gt_struct), pred_attr

    def recon_loss(self, z, adj, pos_edge_index, neg_edge_index=None, x=None, loss_type='l2', beta=1.0):
        r"""Computes structure and attribute reconstruction losses.

        Args:
            z: Latent representation passed to the decoder.
            adj: Matrix indexed as ``adj[row, col]`` to obtain the target
                value for every edge.
            pos_edge_index: Edge index of the positive edges.
            neg_edge_index: Edge index of the negative edges. When ``None``,
                as many negatives as there are positive edges are drawn with
                ``negative_sampling``.
            x: Target attributes compared with the decoder's attribute
                prediction. Required despite the ``None`` default.
            loss_type: Accepted for interface compatibility; currently not
                used (the structure loss is always L1 and the attribute
                loss is always the squared error).
            beta: Accepted for interface compatibility; currently not used.

        Returns:
            A tuple ``(pos_loss_struct, neg_loss_struct, loss_attr)`` where
            the first two are L1 losses over the positive / negative edges
            and ``loss_attr`` is the squared attribute error averaged over
            ``dim=1``.

        Raises:
            TypeError: If ``x`` is ``None``.
        """
        if x is None:
            raise TypeError(
                "recon_loss() requires the target attributes `x`, got None")

        # positive
        pos_loss_struct, pred_pos_attr = self._struct_l1(z, adj, pos_edge_index)
        loss_attr = ((pred_pos_attr - x) ** 2).mean(dim=1)

        if neg_edge_index is None:
            neg_edge_index = negative_sampling(
                pos_edge_index, z.size(0),
                num_neg_samples=len(pos_edge_index[0]))
            neg_edge_index = neg_edge_index.to(z.device)

        # negative (the attribute prediction of this pass is not used)
        neg_loss_struct, _ = self._struct_l1(z, adj, neg_edge_index)

        return pos_loss_struct, neg_loss_struct, loss_attr
    
    

class GraphConv(torch.nn.Module):
    """Parameter-free propagation step: left-multiplies features by ``adj``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, adj):
        """Return the matrix product ``adj @ x`` computed with ``torch.mm``."""
        return torch.mm(adj, x)
    


class RBF(nn.Module):
    """Sum of Gaussian (RBF) kernels evaluated at several bandwidths.

    Args:
        device: Device on which the bandwidth multipliers are created.
        n_kernels: Number of kernels that are summed.
        mul_factor: Base of the geometric bandwidth ladder; the multipliers
            are ``mul_factor ** (i - n_kernels // 2)`` for
            ``i = 0 .. n_kernels - 1``.
        bandwidth: Fixed base bandwidth. When ``None`` it is derived from
            the input on every call (see :meth:`get_bandwidth`).
    """

    def __init__(self, device, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        exponents = (torch.arange(n_kernels) - n_kernels // 2).to(device)
        # Registered as a buffer so the multipliers follow the module through
        # .to()/.cuda()/.double(); non-persistent so state_dict keys (and
        # therefore existing checkpoints) are unaffected.
        self.register_buffer(
            "bandwidth_multipliers", mul_factor ** exponents, persistent=False)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the base bandwidth used for ``L2_distances``.

        If no fixed bandwidth was configured, this is the sum of all entries
        divided by ``n * (n - 1)`` with ``n = L2_distances.shape[0]``. The
        value is detached from the autograd graph. The denominator is zero
        when ``n == 1``.
        """
        if self.bandwidth is None:
            n_samples = L2_distances.shape[0]
            return L2_distances.detach().sum() / (n_samples ** 2 - n_samples)

        return self.bandwidth

    def forward(self, X):
        """Return the kernel matrix of ``X`` with itself, summed over all bandwidths."""
        L2_distances = torch.cdist(X, X) ** 2
        # One bandwidth per kernel, broadcast against the distance matrix.
        scales = self.get_bandwidth(L2_distances) * self.bandwidth_multipliers
        return torch.exp(-L2_distances[None, ...] / scales[:, None, None]).sum(dim=0)


class MMDLoss(nn.Module):
    """Discrepancy between two sample sets computed from kernel block means.

    The kernel is an ``RBF`` instance built for ``device``. ``forward``
    evaluates it on the row-wise stack of both sets and returns
    ``mean(K_xx) - 2 * mean(K_xy) + mean(K_yy)``.
    """

    def __init__(self,device):
        super().__init__()
        self.kernel = RBF(device)

    def forward(self, X, Y):
        """Return ``mean(K_xx) - 2 * mean(K_xy) + mean(K_yy)`` for ``X`` and ``Y``."""
        n_x = X.shape[0]
        gram = self.kernel(torch.vstack([X, Y]))

        # The first n_x rows/columns of the joint kernel matrix belong to X,
        # the remaining ones to Y.
        within_x = gram[:n_x, :n_x].mean()
        cross = gram[:n_x, n_x:].mean()
        within_y = gram[n_x:, n_x:].mean()
        return within_x - 2 * cross + within_y
    