import torch
import torch_geometric as pyg
import numpy as np

class SingleLayerGeneralGNN(torch.nn.Module):
    '''
    Single layer of the whole metagraph GNN.

    The layer runs the background GNN to obtain supernode embeddings, optionally runs the metagraph GNN
    on the concatenation of (projected) supernode and label embeddings, and scores metagraph edges by
    cosine similarity.
    '''

    #  Flags used when the caller does not pass ``params``; both stages stay enabled.
    _DEFAULT_PARAMS = {'ignore_label_embeddings': False, 'zero_shot': False}

    def __init__(self, background_gnn, metagraph_gnn, deepset_module, label_mlp=torch.nn.Identity,
                 input_mlp=torch.nn.Identity, params=None):
        '''
        :param background_gnn: For passing messages between the original (sub)graphs and supernodes.
        :param metagraph_gnn: For passing messages between the task nodes and pooled subgraph representations (back).
        :param deepset_module: For aggregating (sub)graph embeddings ("level 2" aggregation). Stored, but not
                               called anywhere in this class.
        :param label_mlp: Module (or module class taking no constructor arguments) applied to the label
                          embeddings before they enter the metagraph.
        :param input_mlp: Module (or module class taking no constructor arguments) applied to the supernode
                          embeddings before they enter the metagraph.
        :param params: Mapping with the boolean entries 'ignore_label_embeddings' and 'zero_shot'.
                       If None, both default to False.
        '''
        super().__init__()
        self.bg_gnn = background_gnn
        self.deepset_module = deepset_module
        self.metagraph_gnn = metagraph_gnn
        self.cos = torch.nn.CosineSimilarity(dim=1)
        #  The defaults are classes, not instances: calling the class on a tensor would build a new module
        #  instead of transforming the tensor, so classes are instantiated here.
        self.label_mlp = self._as_module(label_mlp)
        self.input_mlp = self._as_module(input_mlp)
        #  forward_metagraph() always reads self.params, so it must exist even when no params are given.
        self.params = dict(self._DEFAULT_PARAMS) if params is None else params
        #  Learnable log-scale; exp() of the initial value is 1 / 0.07.
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    @staticmethod
    def _as_module(module_or_class):
        #  Instantiate module classes (e.g. the torch.nn.Identity default); pass instances through untouched.
        if isinstance(module_or_class, type):
            return module_or_class()
        return module_or_class

    def decode(self, input_x, label_x, metagraph_edge_index, edgelist_bipartite=False):
        '''
        Score every metagraph edge by the cosine similarity of its two endpoint embeddings.

        :param input_x: As returned by the forward_metagraph() method.
        :param label_x: As returned by the forward_metagraph() method.
        :param metagraph_edge_index: Edge index whose row 0 holds source indices and row 1 target indices.
        :param edgelist_bipartite: Whether edgelist is bipartite, i.e. both left and right side are numbered from 0.
        :return: One score per edge. In the bipartite case the cosine similarity is rescaled to [0, 1];
                 otherwise it is multiplied by exp(logit_scale).
        '''
        src = metagraph_edge_index[0, :]
        dst = metagraph_edge_index[1, :]
        if edgelist_bipartite:
            #  Sources index input_x and targets index label_x independently.
            return (self.cos(input_x[src], label_x[dst]) + 1) / 2
        #  Labels are numbered after the inputs, so both sides index the concatenated matrix.
        x = torch.cat((input_x, label_x))
        return self.cos(x[src], x[dst]) * self.logit_scale.exp()

    def forward_background_graph(self, graphs_batched, second_pooling_mapping):
        '''
        Forward pass of the background GNN.

        :param graphs_batched: Batch object providing x, edge_index, edge_attr, edge_index_supernode,
                               supernode and ptr.
        :param second_pooling_mapping: Currently unused; kept so existing callers keep working.
        :return: Whatever the background GNN returns for the supernodes (supernode embeddings).
        '''
        #  Shift each graph's supernode index by the first n entries of ptr
        #  (presumably per-graph node offsets as produced by PyG batching - TODO confirm).
        supernode_idx = graphs_batched.supernode + graphs_batched.ptr[:-1]
        return self.bg_gnn(x=graphs_batched.x, edge_index=graphs_batched.edge_index,
                           edge_attr=graphs_batched.edge_attr,
                           supernode_edge_index=graphs_batched.edge_index_supernode,
                           supernode_idx=supernode_idx)

    def forward_metagraph(self, supernode_x, label_x, metagraph_edge_index, metagraph_edge_attr):
        '''
        Forward pass on the graph embedding <-> task bipartite metagraph.

        :param supernode_x: output of forward_background_graph() - matrix of pooled (sub)graph embeddings.
        :param label_x: matrix of label embeddings - either generated by BERT (previous step) or output from
                        previous SingleLayerGeneralGNN.
        :param metagraph_edge_index: edge_index of the metagraph, forwarded unchanged to the metagraph GNN.
        :param metagraph_edge_attr: edge_attr of the metagraph, forwarded unchanged to the metagraph GNN.
        :return: Updated pooled embeddings and task embeddings of the bipartite graph.
        '''
        #  Here we assume that class embeddings have the same shape as the pooled subgraph embeddings
        #  once both have been projected.
        supernode_x_proj = self.input_mlp(supernode_x)
        label_x_proj = self.label_mlp(label_x)
        if self.params['ignore_label_embeddings']:
            #  Replace label information by float32 zeros, allocated directly on the labels' device.
            label_x_proj = torch.zeros_like(label_x_proj, dtype=torch.float)

        n_inputs = supernode_x_proj.shape[0]
        x = torch.cat((supernode_x_proj, label_x_proj))
        if not self.params['zero_shot']:
            #  start_right tells the metagraph GNN where the label rows begin in x.
            x = self.metagraph_gnn(x=x, edge_index=metagraph_edge_index, edge_attr=metagraph_edge_attr,
                                   start_right=n_inputs)

        input_x_mg = x[:n_inputs]
        label_x_mg = x[n_inputs:]

        #  Sanity check: the metagraph GNN must not change the number of rows.
        assert len(input_x_mg) == len(supernode_x_proj)
        assert len(label_x_mg) == len(label_x_proj)

        return input_x_mg, label_x_mg

    def forward(self, graph, label_x, y_true_matrix, metagraph_edge_index, metagraph_edge_attr, query_set_mask):
        '''
        Params as returned by the batching function.

        Edge scores are reshaped to the shape of y_true_matrix, so there must be exactly one metagraph edge
        per entry of y_true_matrix, in matching order. query_set_mask is reshaped to
        (-1, y_true_matrix.shape[1]) and a row counts as a query row when its first entry equals 1.

        :return: y_true_matrix, y_pred_matrix (for the query set only!)
        '''
        x_supernodes = self.forward_background_graph(graph, None)
        x_input, x_label = self.forward_metagraph(x_supernodes, label_x, metagraph_edge_index, metagraph_edge_attr)

        edge_scores = self.decode(x_input, x_label, metagraph_edge_index, edgelist_bipartite=False)
        y_pred_matrix = edge_scores.reshape(y_true_matrix.shape)

        row_mask = query_set_mask.reshape(-1, y_true_matrix.shape[1])[:, 0] == 1
        qry_idx = torch.where(row_mask)[0]
        return y_true_matrix[qry_idx, :], y_pred_matrix[qry_idx, :]


class SingleLayerSimpleEncoderGNN(torch.nn.Module):
    '''
    Plain encoder used temporarily for debugging the node classification dataset.

    Runs ``gnn1`` over the batch and picks out the embedding of each graph's center node. When ``n_class``
    is given, a three-layer MLP head maps those embeddings to ``n_class`` outputs; when it is None the raw
    center-node embeddings are returned.
    '''

    def __init__(self, gnn1, emb_dim=128, n_class=None):
        '''
        :param gnn1: Callable invoked as gnn1(x=..., edge_index=..., edge_attr=None).
        :param emb_dim: Input width of the classification head (only used when n_class is given).
        :param n_class: Number of outputs of the classification head, or None for no head.
        '''
        super().__init__()
        self.gnn1 = gnn1
        if n_class is not None:
            #  The loss is only attached alongside the classification head.
            self.loss = torch.nn.CrossEntropyLoss()
            hidden_dim = 2 * emb_dim
            head_layers = [
                torch.nn.Linear(emb_dim, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(emb_dim, n_class),
            ]
            self.proj_final = torch.nn.Sequential(*head_layers)
        else:
            self.proj_final = None

    def forward(self, graph_batch):
        '''
        :param graph_batch: Batch object providing x, edge_index, center_node and ptr.
        :return: Center-node embeddings, passed through the classification head if one was configured.
        '''
        node_embs = self.gnn1(x=graph_batch.x, edge_index=graph_batch.edge_index, edge_attr=None)
        #  Shift each graph's center node by the first n entries of ptr
        #  (presumably per-graph node offsets as produced by PyG batching - TODO confirm).
        center_idx = graph_batch.center_node + graph_batch.ptr[:-1]
        center_embs = node_embs[center_idx]
        head = self.proj_final
        return center_embs if head is None else head(center_embs)

