##############################
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch.nn import Linear
import torch.nn
from torch.nn import Linear, ReLU
from torch_geometric.nn import GINConv
from torch.nn import Linear, ReLU
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d as BN
from torch_scatter import scatter_min, scatter_max, scatter_add
from torch_geometric.utils import remove_self_loops,  add_remaining_self_loops, dropout_adj
from erdos.utils import get_mask
from torch_geometric.nn.norm.graph_size_norm import GraphSizeNorm

class Erdos(torch.nn.Module):
    """GIN-based network producing per-node probabilities and a per-graph loss.

    The network is a stack of ``GINConv`` layers with residual connections,
    graph-size normalisation and batch normalisation, followed by two linear
    layers that emit one score per node.  Scores are min-max normalised
    inside each graph, and the resulting probabilities are turned into a loss
    that combines the "expected distance between clique and graph" with the
    expected edge weight of the graph.

    The configuration object must expose ``model.hidden1``, ``model.hidden2``,
    ``model.momentum``, ``model.deltas``, ``model.numlayers``, ``model.heads``,
    ``model.concat`` and ``training.device``.
    """

    def __init__(self, erdos_config):
        """Build all layers from ``erdos_config``.

        Args:
            erdos_config: configuration object; see the class docstring for
                the attributes that are read.
        """
        super().__init__()
        self.conf = erdos_config
        model_conf = self.conf.model
        self.hidden1 = model_conf.hidden1
        self.hidden2 = model_conf.hidden2
        self.momentum = model_conf.momentum
        self.deltas = model_conf.deltas
        self.numlayers = model_conf.numlayers
        self.heads = model_conf.heads
        self.concat = model_conf.concat
        # Kept for callers that read it; forward() derives devices from its
        # input tensors instead.
        self.device = self.conf.training.device

        width = self.heads * self.hidden1

        # Submodules are registered and created in the same order as before
        # so state-dict layout and seeded initialisation stay unchanged.
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        for _ in range(self.numlayers - 1):
            self.bns.append(BN(width, momentum=self.momentum))
        for _ in range(self.numlayers - 1):
            self.convs.append(self._make_gin(width, width))

        # NOTE(review): unlike every other BatchNorm here, bn1 is built
        # without the configured momentum. Left as is to preserve behaviour.
        self.bn1 = BN(width)
        self.conv1 = self._make_gin(self.hidden2, width)

        if self.concat:
            self.lin1 = Linear(width, self.hidden1)
        else:
            # NOTE(review): the conv stack emits heads*hidden1 features, so
            # this branch looks usable only when heads == 1 -- confirm.
            self.lin1 = Linear(self.hidden1, self.hidden1)
        self.lin2 = Linear(self.hidden1, 1)
        self.gnorm = GraphSizeNorm()

    def _make_gin(self, in_features, width):
        """Return a ``GINConv`` (trainable eps) wrapping a 2-layer MLP + BN."""
        mlp = Sequential(
            Linear(in_features, width),
            ReLU(),
            Linear(width, width),
            ReLU(),
            BN(width, momentum=self.momentum),
        )
        return GINConv(mlp, train_eps=True)

    @staticmethod
    def _pairwise_prodsums(probs, batch, num_graphs):
        """Return ``(sum of probs in graph g) ** 2 / 2`` for every graph.

        This equals summing the full outer product ``p_i * p_j`` over all
        ordered pairs of one graph (including ``i == j``) and halving it,
        but needs only one segment sum instead of an n x n intermediate per
        graph.  Graph ids without any node yield 0.

        Args:
            probs: per-node values with one entry per node (e.g. ``[N, 1]``).
            batch: integer tensor of length N assigning nodes to graphs.
            num_graphs: number of graphs in the batch.

        Returns:
            Tensor of shape ``[num_graphs]``.
        """
        flat = probs.reshape(-1)
        totals = torch.zeros(num_graphs, dtype=flat.dtype, device=flat.device)
        totals = totals.index_add(0, batch, flat)
        return totals * totals / 2

    def reset_parameters(self):
        """Re-initialise the parameters of every conv, norm and linear layer."""
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.bn1.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
        """Compute node probabilities and the loss for a batch of graphs.

        Args:
            data: batch object providing ``x``, ``edge_index`` and ``batch``.
            edge_dropout: if not ``None``, probability passed to
                ``dropout_adj``; self loops are re-added afterwards.
            penalty_coefficient: weight of the expected-distance term.

        Returns:
            dict mapping a name to ``[value, kind]`` where ``kind`` is
            ``"hist"`` or ``"sequence"``.  ``"loss"`` holds the mean loss
            over graphs and ``"output"`` the per-node probabilities.
        """
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1

        if edge_dropout is not None:
            # Dummy edge attributes, created next to edge_index so the call
            # does not depend on the configured device string.
            edge_ones = torch.ones(
                edge_index.shape[1], device=edge_index.device
            ).long()
            edge_index = dropout_adj(
                edge_index,
                edge_attr=edge_ones,
                p=edge_dropout,
                force_undirected=True,
            )[0]
            edge_index = add_remaining_self_loops(
                edge_index, num_nodes=batch.shape[0]
            )[0]

        no_loop_index, _ = remove_self_loops(edge_index)
        no_loop_row, no_loop_col = no_loop_index

        x = x.unsqueeze(-1)
        mask = get_mask(x, edge_index, 1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))
        x = x * mask
        x = self.gnorm(x)
        x = self.bn1(x)

        # Residual GIN layers; the mask is updated through get_mask at every
        # layer before being applied.
        for conv, bn in zip(self.convs, self.bns):
            x = x + F.leaky_relu(conv(x, edge_index))
            mask = get_mask(mask, edge_index, 1).to(x.dtype)
            x = x * mask
            x = self.gnorm(x)
            x = bn(x)

        x = F.leaky_relu(self.lin1(x))
        x = x * mask

        x = F.leaky_relu(self.lin2(x))
        x = x * mask

        # Per-graph min and max, broadcast back to nodes.  Only rows
        # 0..num_graphs-1 are ever selected, so that is the size requested.
        batch_max = scatter_max(x, batch, 0, dim_size=num_graphs)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=num_graphs)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        # Min-max normalise; the epsilon avoids 0/0 when max == min.
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        probs = x
        # Despite its name this is the same normalised tensor as `probs`.
        pre_norm_x = x

        # Terms for the expected distance between clique and graph.
        pairwise_prodsums = self._pairwise_prodsums(probs, batch, num_graphs)

        # Loss terms.
        self_sums = scatter_add(probs * probs, batch, 0, dim_size=num_graphs)
        # Every undirected edge appears in both directions, hence the /2.
        expected_weight_G = scatter_add(
            probs[no_loop_row] * probs[no_loop_col],
            batch[no_loop_row],
            0,
            dim_size=num_graphs,
        ) / 2.
        # NOTE(review): this is (sum p)^2 / 2 - sum p^2, whereas the sum over
        # unordered pairs would be ((sum p)^2 - sum p^2) / 2.  Kept exactly
        # as before -- confirm whether the unhalved sum p^2 is intended.
        expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) - self_sums) / 1.
        expected_distance = expected_clique_weight - expected_weight_G

        # Loss.
        expected_loss = (
            penalty_coefficient * expected_distance * 0.5
            - 0.5 * expected_weight_G
        )
        loss = expected_loss

        retdict = {}
        retdict["output"] = [probs.squeeze(-1), "hist"]
        retdict["pre_norm_x"] = [pre_norm_x.squeeze(-1), "hist"]
        retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
        retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [expected_clique_weight.mean(), "sequence"]
        retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
        retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss

        return retdict

    def __repr__(self):
        """Return the bare class name."""
        return self.__class__.__name__