from cmath import log
import torch

from layers import *

import torch.nn.functional as F
from torch.nn import BatchNorm1d as BatchNorm
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.utils import to_dense_adj

from metrics import calculate_erank, calculate_patch_sim, cal_variance
from collections import defaultdict
# from layers import create_label_induced_negative_graph_sparse, NormLayer

def log_metrics(metrics, x, pos, cal_erank=False):
    """Append representation statistics of ``x`` to ``metrics``.

    Three values are appended, in this order, under keys suffixed with ``pos``:
    ``sim<pos>`` (``calculate_patch_sim``), ``erank<pos>`` (``calculate_erank``,
    or ``0.`` when ``cal_erank`` is false) and ``var<pos>`` (``cal_variance``).

    Args:
        metrics: mapping of key -> list; indexed with ``[]`` and appended to,
            so it must auto-create missing lists (e.g. ``defaultdict(list)``).
        x: the representation passed unchanged to each metric helper.
        pos: tag appended to every key.
        cal_erank: compute the effective rank only when true.

    Returns:
        The same ``metrics`` object, mutated in place.
    """
    metrics[f'sim{pos}'].append(calculate_patch_sim(x))

    # Effective rank is optional; a placeholder keeps the lists aligned.
    if cal_erank:
        erank_value = calculate_erank(x)
    else:
        erank_value = 0.
    metrics[f'erank{pos}'].append(erank_value)

    metrics[f'var{pos}'].append(cal_variance(x))
    return metrics


def undirect_to_direct(adj, p=0.5):
    """Make an undirected adjacency directed by dropping one direction per pair.

    For every off-diagonal pair ``i > j`` exactly one of ``adj[i, j]`` and
    ``adj[j, i]`` is zeroed. The lower-triangular entry ``(i, j)`` is kept with
    probability ``p``; otherwise the upper entry ``(j, i)`` is kept instead.
    Diagonal entries are never touched.

    Args:
        adj: square adjacency (sparse or dense, floating dtype).
        p: probability of keeping the lower-triangular direction of each pair.

    Returns:
        A sparse COO tensor with the selected directions. ``adj`` itself is
        not modified.
    """
    dense = adj.to_dense()
    # True (with probability 1 - p) -> zero the lower entry (i, j), i > j.
    drop_lower = torch.tril(torch.rand_like(dense) > p, diagonal=-1)
    # Mirror the complement onto the upper triangle so that exactly one
    # direction of each pair is zeroed.
    drop_upper = torch.triu((~drop_lower).T, diagonal=1)
    drop_mask = torch.logical_or(drop_lower, drop_upper)
    # Out-of-place fill: for a dense input ``to_dense()`` returns ``adj``
    # itself, so an in-place fill would silently modify the caller's tensor.
    return dense.masked_fill(drop_mask, 0).to_sparse()


def dropedge(adj, p=0.5):
    """Randomly subsample the stored entries (edges) of a sparse adjacency.

    NOTE: despite the function name, ``p`` is the fraction of edges that are
    *kept*. ``int(p * nnz)`` stored entries are chosen uniformly without
    replacement, and all others are discarded.

    Args:
        adj: sparse COO adjacency; it is coalesced before sampling.
        p: fraction of stored entries to keep.

    Returns:
        A new sparse COO tensor of the same size holding only the kept
        entries, in random order.
    """
    coalesced = adj.coalesce()
    edge_index = coalesced.indices()
    num_edges = len(edge_index[0])
    keep_count = int(p * num_edges)

    permutation = torch.randperm(num_edges, device=edge_index.device)
    kept = permutation[:keep_count]

    return torch.sparse_coo_tensor(
        edge_index[:, kept],
        coalesced.values()[kept],
        size=coalesced.size(),
    )


class DeepGCN(nn.Module):
    """Stack of ``GraphConv`` layers with selectable propagation variants.

    The network has ``args.nlayer - 1`` hidden ``GraphConv`` layers of width
    ``args.hid``, followed by an output ``GraphConv`` to ``nclass``.
    ``args.norm_mode`` selects the variant:

    * ``'Sign'``  -- subtract ``neg_weight`` times a row-softmaxed negated
      cosine-similarity graph (existing edges masked out) from ``adj``, and
      blend consecutive hidden layers.
    * ``'Label'`` -- subtract ``neg_weight`` times a row-softmaxed graph from
      ``create_label_induced_negative_graph_sparse``, and blend consecutive
      hidden layers.
    * ``'res'``   -- only blend consecutive hidden layers.
    * ``'drop'``  -- during training, keep a random ``norm_scale`` fraction of
      the edges of ``adj`` (via ``dropedge``).
    * any other value -- apply ``NormLayer`` after every hidden ``GraphConv``.

    Blending computes ``x_i <- norm_scale * x_{i-1} + (1 - norm_scale) * x_i``
    for every hidden layer after the first.
    """

    # Modes in which NormLayer is NOT applied after each hidden GraphConv.
    _SKIP_NORM_MODES = ('Sign', 'Label', 'res', 'drop')
    # Modes that blend each hidden layer's output with the previous one.
    _RESIDUAL_MODES = ('Sign', 'Label', 'res')

    def __init__(self, args, nfeat, nclass, **kwargs):
        """Build the layers from ``args``.

        Args:
            args: config object; reads nlayer, hid, norm_mode, dropout,
                norm_scale, neg_weight, use_layer_norm (and whatever
                ``NormLayer`` reads).
            nfeat: input feature dimension.
            nclass: output dimension.
            **kwargs: accepted and ignored.
        """
        super(DeepGCN, self).__init__()
        assert args.nlayer >= 1

        self.hidden_layers = nn.ModuleList([
            GraphConv(nfeat if i == 0 else args.hid, args.hid, args.norm_mode)
            for i in range(args.nlayer - 1)
        ])
        # With a single layer the output layer consumes the raw features.
        self.out_layer = GraphConv(nfeat if args.nlayer == 1 else args.hid, nclass, args.norm_mode)

        self.dropout = nn.Dropout(p=args.dropout)
        self.dropout_rate = args.dropout  # stored; not read in this class
        self.relu = nn.ReLU(True)
        self.norm = NormLayer(args)
        # Blend weight of the previous layer in residual modes; in 'drop'
        # mode, the fraction of edges kept by dropedge.
        self.norm_scale = args.norm_scale
        self.norm_mode = args.norm_mode
        # Weight of the subtracted negative graph ('Sign' / 'Label').
        self.neg_wei = args.neg_weight
        self.layer_norm = nn.LayerNorm(args.hid)
        self.use_layer_norm = args.use_layer_norm

    def _build_propagation_adj(self, x, adj, labels, train_mask):
        """Return the matrix every GraphConv propagates over for ``norm_mode``.

        Modes not handled below return ``adj`` unchanged.
        """
        if self.norm_mode == 'Sign':
            norm_x = nn.functional.normalize(x, dim=1)
            # Negated cosine similarity: the most dissimilar pairs get the
            # largest logits.
            sim = -torch.mm(norm_x, norm_x.T)
            # Mask out existing edges so the negative graph uses only non-edges.
            # NOTE(review): `size(1) == 2` looks intended to detect an
            # edge_index of shape (2, E), which would be `size(0) == 2`. As
            # written it only fires for 2-node graphs, and adding `sim` to an
            # edge_index below would not be meaningful anyway. Confirm intent.
            if adj.size(1) == 2:
                sim[adj[0], adj[1]] = float('-inf')
            else:
                sim.masked_fill_(adj.to_dense() > 1e-5, float('-inf'))
            # NOTE(review): a row in which every entry is masked (-inf)
            # becomes NaN after softmax.
            sim = nn.functional.softmax(sim, dim=1)
            return -self.neg_wei * sim + adj

        if self.norm_mode == 'Label':
            neg = create_label_induced_negative_graph_sparse(labels, train_mask)
            if labels.size(0) > 10000:
                # Large graphs: stay sparse to avoid an N x N dense tensor.
                neg = torch.sparse.softmax(neg, dim=1)
            else:
                # NOTE(review): torch.sparse.softmax ignores unstored entries,
                # whereas the dense softmax also gives weight to zeros, so the
                # two branches normalise differently. Confirm this is intended.
                neg = nn.functional.softmax(neg.to_dense(), dim=1)
            return -self.neg_wei * neg + adj

        if self.norm_mode == 'drop' and self.training:
            # DropEdge is a training-time augmentation; evaluation uses the
            # full graph so predictions are deterministic.
            return dropedge(adj, self.norm_scale)

        return adj

    def forward(self, x, adj, labels, train_mask, cal_erank=False, cal_metrics=False):
        """Run the network.

        Args:
            x: node feature matrix fed to the first layer.
            adj: adjacency passed to every GraphConv (after mode-specific
                preparation).
            labels, train_mask: forwarded to ``NormLayer``; in 'Label' mode
                also used to build the negative graph.
            cal_erank, cal_metrics: currently unused (no metrics are logged);
                kept for interface compatibility.

        Returns:
            ``(output, metrics)``; ``metrics`` is an empty ``defaultdict(list)``.
        """
        metrics = defaultdict(list)
        adj = self._build_propagation_adj(x, adj, labels, train_mask)

        x_old = None
        for i, layer in enumerate(self.hidden_layers):
            x = self.dropout(x)
            x = layer(x, adj)
            if self.norm_mode not in self._SKIP_NORM_MODES:
                x = self.norm(x, adj, labels, train_mask)
            if self.use_layer_norm:
                x = self.layer_norm(x)
            x = self.relu(x)

            if self.norm_mode in self._RESIDUAL_MODES:
                # The first hidden layer only seeds x_old; later layers blend.
                if i > 0:
                    x = self.norm_scale * x_old + (1 - self.norm_scale) * x
                x_old = x

        x = self.dropout(x)
        x = self.out_layer(x, adj)
        return x, metrics
