import copy
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv
from torch_geometric.nn import VGAE
from torch_geometric.utils import to_undirected, add_self_loops, remove_self_loops, negative_sampling, subgraph
    

class GIN_NodeWeightEncoder(torch.nn.Module):
    """Two stacked GIN layers that emit a small score vector per node.

    The first layer maps ``dataset.num_features`` input features to ``dim``
    hidden channels; the second maps ``dim`` to 3 channels when ``add_mask``
    equals ``True`` and to 2 channels otherwise. Each layer is followed by
    ReLU and then batch normalisation.

    Args:
        dataset: object exposing ``num_features`` (input feature width).
        dim: hidden width of both GIN MLPs.
        add_mask: selects the 3-channel output head instead of the
            2-channel one.
    """

    def __init__(self, dataset, dim, add_mask=False):
        super().__init__()

        in_channels = dataset.num_features
        self.conv1 = GINConv(
            Sequential(Linear(in_channels, dim), ReLU(), Linear(dim, dim))
        )
        self.bn1 = torch.nn.BatchNorm1d(dim)

        # Only the output width depends on the flag. The comparison is an
        # equality test rather than a truthiness test so that non-bool flag
        # values keep selecting the same head as before.
        out_channels = 3 if add_mask == True else 2  # noqa: E712
        self.conv5 = GINConv(
            Sequential(Linear(dim, dim), ReLU(), Linear(dim, out_channels))
        )
        self.bn5 = torch.nn.BatchNorm1d(out_channels)

    def forward(self, data):
        """Return the batch-normalised output of the second GIN layer.

        Args:
            data: object exposing ``x`` (node features) and ``edge_index``.
        """
        features, edges = data.x, data.edge_index
        hidden = self.bn1(F.relu(self.conv1(features, edges)))
        return self.bn5(F.relu(self.conv5(hidden, edges)))


class ViewGenerator(VGAE):
    """Graph-view generator that samples perturbed copies of an input graph.

    The wrapped encoder is built as ``encoder(dataset, dim, add_mask)`` and
    handed to ``VGAE``. Two families of entry points exist:

    * ``forward`` treats the raw encoder output as per-node logits and draws
      a hard Gumbel-softmax sample that decides which nodes are kept,
      dropped or (with ``add_mask``) feature-masked.
    * the ``sample_*`` methods go through ``self.encode`` / ``self.decoder``
      and pick edges by Bernoulli sampling.

    NOTE(review): ``self.encode`` is inherited; presumably it needs the
    encoder to return a ``(mu, logstd)`` pair, which is a different contract
    from the single logits tensor ``forward`` consumes -- confirm which
    encoder is meant to be used with the ``sample_*`` methods.

    Every method works on a deep copy of its input graph.
    """

    def __init__(self, dataset, dim, encoder, add_mask=False):
        """Build the encoder and initialise the underlying ``VGAE``.

        Args:
            dataset: forwarded to ``encoder``.
            dim: forwarded to ``encoder``.
            encoder: callable (typically a class) invoked as
                ``encoder(dataset, dim, add_mask)``.
            add_mask: when equal to ``True``, ``forward`` additionally reads
                column 2 of the encoder output as a feature-mask decision.
        """
        built_encoder = encoder(dataset, dim, add_mask)
        super().__init__(encoder=built_encoder)
        self.add_mask = add_mask

    def sample_view(self, data):
        """Sample a whole new edge set from the dense decoder output.

        Returns:
            ``(z, recovered, data)`` where ``recovered`` is the dense decoder
            output rescaled so that it sums to ``data.num_edges`` and
            ``data`` is a copy whose ``edge_index`` holds the sampled edges,
            made undirected, plus one self-loop per node.
        """
        data = copy.deepcopy(data)
        z = self.encode(data)

        # Decode once and reuse the result for both the normaliser and the
        # rescaled probabilities.
        dense_prob = self.decoder.forward_all(z)
        recovered = dense_prob * (data.num_edges / float(dense_prob.sum()))

        # Rescaling may push entries above 1, which torch.bernoulli rejects;
        # clamp only the tensor used for sampling so the returned
        # ``recovered`` keeps its rescaled values.
        edge_selected = torch.bernoulli(recovered.clamp(0.0, 1.0)).bool()

        edge_index = edge_selected.nonzero(as_tuple=False).T
        edge_index = to_undirected(edge_index)
        # Pass num_nodes explicitly so that trailing nodes that drew no edge
        # still receive their self-loop (same as sample_subgraph_view).
        edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0]
        data.edge_index = edge_index
        return z, recovered, data

    def sample_partial_view(self, data):
        """Sample edges from the existing edges plus negative samples.

        Decoder scores of the candidate edges are rescaled so they sum to
        ``data.num_edges // 2`` before Bernoulli sampling.

        Returns:
            ``(z, None, data)`` where ``data`` is a copy whose ``edge_index``
            holds the sampled edges, made undirected, without self-loops.
        """
        data = copy.deepcopy(data)
        z = self.encode(data)
        edge_index = data.edge_index

        neg_edge_index = negative_sampling(edge_index)
        joint_edge_index = torch.cat((edge_index, neg_edge_index), dim=1)
        joint_edge_index = remove_self_loops(joint_edge_index)[0]

        wanted_num_edges = data.num_edges // 2
        edge_weights = self.decoder.forward(z, joint_edge_index)
        # Out-of-place rescale: leaves the decoder output tensor untouched.
        edge_weights = edge_weights * (wanted_num_edges / edge_weights.sum())

        # The rescaled weights can exceed 1; torch.bernoulli needs [0, 1].
        edge_selected = torch.bernoulli(edge_weights.clamp(0.0, 1.0)).bool()

        edge_index = joint_edge_index[:, edge_selected]
        edge_index = to_undirected(edge_index)
        edge_index = remove_self_loops(edge_index)[0]

        data.edge_index = edge_index
        return z, None, data

    def sample_partial_view_recon(self, data, neg_edge_index):
        """Sample edges from existing plus negative edges, without rescaling.

        Args:
            data: input graph (deep-copied).
            neg_edge_index: negative edges to use as extra candidates, or
                ``None`` to draw them with ``negative_sampling``.

        Returns:
            ``(z, neg_edge_index, data)``; ``neg_edge_index`` is the set that
            was actually used, so a caller can pass it back in later.
        """
        data = copy.deepcopy(data)
        z = self.encode(data)

        edge_index = data.edge_index

        # Identity test: ``== None`` on a tensor goes through tensor
        # comparison machinery instead of a plain None check.
        if neg_edge_index is None:
            neg_edge_index = negative_sampling(edge_index)

        joint_edge_index = torch.cat((edge_index, neg_edge_index), dim=1)
        edge_weights = self.decoder.forward(z, joint_edge_index)
        edge_selected = torch.bernoulli(edge_weights).bool()

        edge_index = joint_edge_index[:, edge_selected]
        edge_index = to_undirected(edge_index)
        data.edge_index = edge_index
        return z, neg_edge_index, data

    def sample_subgraph_view(self, data):
        """Keep a Bernoulli-sampled subset of the existing edges.

        Returns:
            ``(z, recovered_all, data)`` where ``recovered_all`` is the dense
            decoder output and ``data`` is a copy whose ``edge_index`` holds
            the surviving edges, made undirected, plus one self-loop per
            node.
        """
        data = copy.deepcopy(data)
        z = self.encode(data)
        edge_index = data.edge_index

        recovered_all = self.decoder.forward_all(z)
        recovered = self.decode(z, edge_index)
        edge_selected = torch.bernoulli(recovered).bool()
        edge_index = edge_index[:, edge_selected]
        edge_index = to_undirected(edge_index)

        edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0]

        data.edge_index = edge_index
        return z, recovered_all, data

    def forward(self, data_in, requires_grad):
        """Draw a node-level view: keep, drop or feature-mask each node.

        Column 0 of the hard Gumbel-softmax sample marks nodes that are kept
        as they are; with ``add_mask``, column 2 marks nodes that are kept
        but whose features are overwritten with a scalar token.

        Args:
            data_in: input graph (deep-copied, left untouched).
            requires_grad: value assigned to ``requires_grad`` of the float
                feature tensor used to build the output features.

        Returns:
            ``(keep_sample, data)`` where ``keep_sample`` has one entry per
            node (non-zero for retained nodes) and ``data`` is the modified
            copy: features of dropped nodes are multiplied by zero and
            ``edge_index`` / ``edge_attr`` are restricted by ``subgraph`` to
            the retained nodes.
        """
        data = copy.deepcopy(data_in)

        x, edge_index = data.x, data.edge_index
        edge_attr = data.edge_attr

        data.x = data.x.float()
        x = x.float()
        x.requires_grad = requires_grad

        p = self.encoder(data)
        sample = F.gumbel_softmax(p, hard=True)

        # Equality (not truthiness) is kept on purpose so that non-bool flag
        # values are interpreted exactly as before.
        use_mask = self.add_mask == True  # noqa: E712

        real_sample = sample[:, 0]
        attr_mask_sample = None
        if use_mask:
            attr_mask_sample = sample[:, 2]
            keep_sample = real_sample + attr_mask_sample
        else:
            keep_sample = real_sample

        keep_idx = torch.nonzero(keep_sample, as_tuple=False).view(-1)
        edge_index, edge_attr = subgraph(
            keep_idx, edge_index, edge_attr, num_nodes=data.num_nodes
        )
        # Zero the feature rows of every node that was not retained.
        x = x * keep_sample.view(-1, 1)

        if use_mask:
            # Masked nodes get every feature replaced by one scalar: the
            # mean over all entries of the (detached) input features.
            token = data.x.detach().mean()
            x[attr_mask_sample.bool()] = token

        data.x = x
        data.edge_index = edge_index
        if data.edge_attr is not None:
            data.edge_attr = edge_attr

        return keep_sample, data
    

class ViewLearner(torch.nn.Module):
    """Produce one logit per edge from the embeddings of its two endpoints.

    ``encoder`` is called as ``encoder(batch, x, edge_index)`` and must
    return a pair whose second item holds the node embeddings. For each
    edge the embeddings of its two endpoints are concatenated and passed
    through a two-layer MLP with a single output unit.

    Args:
        encoder: node-embedding callable, stored as ``self.encoder``.
        emb_dim: width of one node embedding; the MLP input is twice that.
        mlp_edge_model_dim: hidden width of the edge MLP.
    """

    def __init__(self, encoder, emb_dim, mlp_edge_model_dim=64):
        super().__init__()
        self.encoder = encoder
        self.input_dim = emb_dim

        pair_dim = self.input_dim * 2  # two endpoint embeddings side by side
        self.mlp_edge_model = Sequential(
            Linear(pair_dim, mlp_edge_model_dim),
            ReLU(),
            Linear(mlp_edge_model_dim, 1),
        )
        self.init_emb()

    def init_emb(self):
        """Xavier-initialise weights and zero the biases of all ``Linear`` layers.

        Walks ``self.modules()``, so ``Linear`` layers of every registered
        sub-module are re-initialised, not just those of the edge MLP.
        """
        linear_layers = (m for m in self.modules() if isinstance(m, Linear))
        for layer in linear_layers:
            torch.nn.init.xavier_uniform_(layer.weight.data)
            if layer.bias is None:
                continue
            layer.bias.data.fill_(0.0)

    def forward(self, batch, x, edge_index):
        """Return the edge logits, one row per column of ``edge_index``."""
        _, node_emb = self.encoder(batch, x, edge_index)
        endpoint_pairs = torch.cat(
            [node_emb[edge_index[0]], node_emb[edge_index[1]]], 1
        )
        return self.mlp_edge_model(endpoint_pairs)