import torch
import torch.nn as nn
import dgl
import torch.nn.functional as F
from pooling.sagpool import SAGPoolBlock

"""
    GCN: Graph Convolutional Networks
    Thomas N. Kipf, Max Welling, Semi-Supervised Classification with Graph Convolutional Networks (ICLR 2017)
    http://arxiv.org/abs/1609.02907
"""
from layers.gcn_layer import GCNLayer
from layers.mlp_readout_layer import MLPReadout


class GCNNet(nn.Module):
    """GCN graph classifier with an optional class-centroid ("cluster") branch.

    Node features are embedded, passed through a stack of ``GCNLayer``s, read
    out into one vector per graph and classified by ``MLPReadout``.  When
    ``net_params['cluster']`` is truthy the model additionally keeps one
    running-mean centroid per class (see ``update_cluster``) and mixes the
    nearest centroid into the readout before classification.
    """

    def __init__(self, net_params):
        """Build the network from a parameter dict.

        Keys read from ``net_params``: ``in_dim``, ``hidden_dim``, ``out_dim``,
        ``n_classes``, ``in_feat_dropout``, ``dropout``, ``L`` (number of GCN
        layers), ``cluster``, ``readout``, ``batch_norm``, ``residual`` and,
        only when ``cluster`` is truthy, ``device``.
        """
        super().__init__()
        in_dim = net_params['in_dim']
        hidden_dim = net_params['hidden_dim']
        out_dim = net_params['out_dim']
        self.n_classes = net_params['n_classes']
        in_feat_dropout = net_params['in_feat_dropout']
        dropout = net_params['dropout']
        n_layers = net_params['L']
        self.cluster = net_params['cluster']
        self.readout = net_params['readout']
        self.batch_norm = net_params['batch_norm']
        self.residual = net_params['residual']

        self.embedding_h = nn.Linear(in_dim, hidden_dim)
        self.in_feat_dropout = nn.Dropout(in_feat_dropout)

        # (L - 1) hidden -> hidden layers followed by one hidden -> out layer.
        self.layers = nn.ModuleList([
            GCNLayer(hidden_dim, hidden_dim, F.relu, dropout, self.batch_norm, self.residual)
            for _ in range(n_layers - 1)
        ])
        self.layers.append(GCNLayer(hidden_dim, out_dim, F.relu, dropout, self.batch_norm, self.residual))
        self.MLP_layer = MLPReadout(out_dim, self.n_classes)

        if self.cluster:
            device = net_params['device']
            # NOTE: this is a plain dict, so nn.Module does not register the
            # ParameterList inside it: the centroids are absent from
            # parameters()/state_dict() and are not moved by .to(); that is why
            # the device is passed explicitly here.
            #   'num'  - number of graphs accumulated into each centroid
            #   'repr' - running-mean representation of each class
            #   'std'  - last per-class mean of the rescaled cosine similarity
            #            between member graphs and their centroid
            self.cluster_info = {
                'num': [0] * self.n_classes,
                'repr': nn.ParameterList([
                    torch.nn.Parameter(torch.rand(out_dim, device=device))
                    for _ in range(self.n_classes)
                ]),
                'std': [torch.tensor(0.0, device=device) for _ in range(self.n_classes)],
            }
            # Pooled per-graph representation of the latest forward() batch.
            self.graph_repr = None
            self.cos_sim = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
            self.sim = torch.nn.CosineSimilarity(dim=0, eps=1e-08)
            # `mlp` is not used by forward() in this file; it is kept so the
            # module's parameters and state_dict layout stay unchanged.
            self.mlp = nn.Linear(out_dim * 2, out_dim)
            self.mlp_cluster = nn.Linear(out_dim, out_dim)
            self.pool = SAGPoolBlock(in_dim=out_dim)

    def update_cluster(self, batch_labels, use_gt=False):
        """Fold the latest batch into the class centroids and compute cluster terms.

        Must be called after ``forward`` because it reads ``self.graph_repr``.

        Args:
            batch_labels: with ``use_gt=False`` a ``(batch, n_classes)`` score
                tensor whose row-wise argmax is the class id; with
                ``use_gt=True`` the class ids themselves (assumed to be a 1-D
                tensor with one id per graph - TODO confirm against callers).
            use_gt: treat ``batch_labels`` as class ids instead of scores.

        Returns:
            ``(total_neg_loss, total_pos_loss)``: the rescaled cosine
            similarity ``0.5 * (1 + cos)`` summed over all pairs of non-empty
            centroids and divided by ``n_classes * (n_classes - 1) / 2``, and
            the mean over classes of the member-to-centroid similarity.
            ``(None, None)`` if either term is exactly zero.

        Raises:
            RuntimeError: if no forward pass has populated ``graph_repr`` yet.
        """
        if self.graph_repr is None:
            raise RuntimeError("update_cluster() requires a prior forward() call "
                               "to populate graph_repr")

        # Compare class ids directly.  Reconstructing them through float
        # arithmetic on the scores is not exact (e.g. -3.7 + (1 - (-3.7)) != 1
        # in float32), which silently dropped graphs from every cluster.  No
        # gradient is lost: the ids are only used to build index tensors.
        labels = batch_labels if use_gt else torch.argmax(batch_labels, dim=1)
        device = batch_labels.device
        info = self.cluster_info

        for i in range(self.n_classes):
            members = (labels == i).nonzero().reshape(-1)
            if members.numel() == 0:
                info['std'][i] = torch.tensor(0.0, device=device)
                continue

            member_repr = self.graph_repr[members]
            centroid = info['repr'][i]
            seen = info['num'][i]
            total = seen + members.numel()
            if seen == 0:
                # First observation replaces the random initialisation.
                new_mean = member_repr.mean(dim=0)
            else:
                # Running mean over every graph ever assigned to this class.
                new_mean = (centroid * seen + member_repr.sum(dim=0)) / total
            # Overwrite the parameter's storage outside of autograd.
            centroid.data = new_mean.detach()
            info['num'][i] = total

            sims = self.cos_sim(member_repr, centroid.repeat(member_repr.shape[0], 1))
            info['std'][i] = (0.5 * (1 + sims)).mean()

        total_neg_loss = torch.tensor(0.0, device=device)
        total_pos_loss = torch.tensor(0.0, device=device)
        for i in range(self.n_classes):
            if info['num'][i] == 0:
                continue
            for j in range(i + 1, self.n_classes):
                if info['num'][j] == 0:
                    continue
                pair_sim = self.sim(info['repr'][i], info['repr'][j])
                total_neg_loss += 0.5 * (1 + pair_sim)

        pos_loss = sum(info['std']) / len(info['std'])
        total_pos_loss += pos_loss[0] if pos_loss.size() == torch.Size([1]) else pos_loss

        if total_neg_loss == 0 or total_pos_loss == 0:
            return None, None
        total_neg_loss = total_neg_loss / (self.n_classes * (self.n_classes - 1) / 2)
        return total_neg_loss, total_pos_loss

    def _nearest_cluster_repr(self, hg):
        """Return, for each row of ``hg``, the centroid with the highest cosine similarity.

        Classes that have not accumulated any graph yet get a similarity of -1.
        """
        reprs = self.cluster_info['repr']
        n_graphs = hg.shape[0]
        sims = []
        for i in range(self.n_classes):
            if self.cluster_info['num'][i] == 0:
                sims.append(-torch.ones(n_graphs, device=hg.device))
            else:
                sims.append(self.cos_sim(hg, reprs[i].repeat(n_graphs, 1)))
        class_idxs = torch.argmax(torch.stack(sims), dim=0)
        # Row lookup selects the same vectors as summing one-hot-masked copies
        # of every centroid, without materialising an
        # (n_classes, out_dim, batch) tensor.
        centroids = torch.stack([reprs[i] for i in range(self.n_classes)])
        return centroids[class_idxs]

    def forward(self, g, h, e):
        """Compute class scores for the graphs in ``g``.

        Args:
            g: DGL graph (presumably a batched graph - TODO confirm); the final
                node features are written to ``g.ndata['h']``.
            h: node input features, fed to ``embedding_h``.
            e: unused by this model.

        Returns:
            The output of ``MLP_layer`` applied to the per-graph readout.

        Side effects:
            When clustering is enabled, stores a copy of the third output of
            ``self.pool(g, h)`` in ``self.graph_repr``.
        """
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        for conv in self.layers:
            h = conv(g, h)

        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        else:
            # "mean" and any unrecognised value both use the mean readout.
            hg = dgl.mean_nodes(g, 'h')

        if self.cluster:
            _, _, shg = self.pool(g, h)
            self.graph_repr = shg.clone()

            centroids = self._nearest_cluster_repr(hg)
            # The forward value is the nearest centroid, while the gradient
            # flows to hg as if this step were the identity.
            clusters = hg + (centroids - hg).detach()
            clusters = F.relu(self.mlp_cluster(clusters))
            # Residual combination; `self.mlp` (sized for a concatenation of
            # hg and clusters) is not applied here.
            hg = hg + clusters

        return self.MLP_layer(hg)

    def loss(self, pred, label):
        """Return the cross-entropy loss between ``pred`` scores and ``label`` targets."""
        return F.cross_entropy(pred, label)
