import torch
import sys
import argparse
import numpy as np

from scattering.baseline_scat import Scat

def getclicnum(adjmatrix, dis, walkerstart=0, thresholdloopnodes=50):
    """Greedily grow a clique along the node ranking given by ``dis``; return its size.

    Nodes are visited in descending order of ``dis``. The node at rank
    ``walkerstart`` seeds the clique (all higher-ranked nodes are skipped).
    Each later node up to rank ``min(thresholdloopnodes, n) - 1`` is kept
    only if the selected set still passes the clique test below.

    Args:
        adjmatrix: ``n x n`` adjacency matrix exposing ``.shape`` and ``.dot``
            (e.g. a SciPy sparse matrix or a NumPy array). The clique test
            assumes a symmetric 0/1 matrix with an empty diagonal.
        dis: torch tensor holding one score per node (flattened before use);
            higher means more likely to belong to the clique.
        walkerstart: 0-based rank of the seed node.
        thresholdloopnodes: only the top ``thresholdloopnodes`` ranked nodes
            are considered.

    Returns:
        Size of the clique found, as a NumPy float scalar (``.item()`` gives
        a Python float).

    Raises:
        IndexError: if ``walkerstart`` is not a valid rank.
    """
    # reshape(-1) flattens like the old squeeze() did, but also keeps a
    # single-node input 1-d (squeeze() would produce an unindexable 0-d
    # tensor). Move the ranking to the host once instead of per lookup.
    _, order = torch.sort(dis.reshape(-1), descending=True)
    order = order.cpu().numpy()
    n_considered = min(thresholdloopnodes, adjmatrix.shape[0])

    # 0/1 membership indicator (float32, like the former ``0. * indices``).
    chosen = np.zeros(order.shape[0], dtype=np.float32)
    chosen[order[walkerstart]] = 1.0
    # Column view sharing memory with ``chosen``, so it tracks every update.
    x = chosen.reshape(-1, 1)
    for rank in range(walkerstart + 1, n_considered):
        node = order[rank]
        chosen[node] = 1.0
        size = x.sum()
        # For an indicator x of a node set S with |S| = k, x^T A x sums A over
        # ordered pairs in S. With a symmetric 0/1 A and zero diagonal this
        # equals k*(k-1) exactly when S is a clique.
        violation = size ** 2 - size - np.sum(x * adjmatrix.dot(x))
        if violation >= 1e-4:
            chosen[node] = 0.0  # not a clique any more: drop the candidate
    return np.sum(chosen)


class SCAT(torch.nn.Module):
    """Clique-finding model wrapping the scattering network ``Scat``.

    In training mode ``forward`` returns ``(mean per-graph loss, per-graph sum
    of network outputs)``. In eval mode it returns ``(best greedy clique size
    per graph, per-graph count of outputs above 0.5)``. Only the corpus graphs
    are used.
    """

    def __init__(self, conf, gmn_config=None):
        super().__init__()
        self.conf = conf
        # Fills in conf.model before the network is built from conf.
        self.init_relevant_args()
        self.net = Scat(self.conf, gmn_config)
        self.device = self.conf.training.device

    def init_relevant_args(self):
        """Overwrite ``self.conf.model`` with this model's fixed hyperparameters.

        The hidden size and layer count depend on ``conf.model.EQ``.
        ``Numofwalkers`` (number of greedy start ranks tried at eval) is copied
        from ``conf.model.decoder_steps``.
        """
        self.conf.model.use_smoo = False
        self.conf.model.smooth = 0.1
        self.conf.model.input_dim = 1
        self.conf.model.output_dim = 1
        if self.conf.model.EQ:
            self.conf.model.hidden_dim = 10
            self.conf.model.n_layers = 5
        else:
            self.conf.model.hidden_dim = 8
            self.conf.model.n_layers = 2
        self.conf.model.moment = 1
        self.conf.model.penalty_coefficient = 0.1
        self.conf.model.Numofwalkers = self.conf.model.decoder_steps
        # Max number of top-ranked nodes the greedy decoder examines.
        self.conf.model.SampLength = 300

    @staticmethod
    def _count_hard_indicators(out):
        """Return how many entries of ``out`` exceed 0.5, as a Python float.

        Thresholds a copy, so ``out`` itself is left untouched.
        """
        return (out > 0.5).to(out.dtype).sum().item()

    def forward(self,
                query_batch_data,
                query_batch_data_node_sizes,
                query_batch_data_edge_sizes,
                query_batch_adj,
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                corpus_batch_adj,
                diagnostic_mode=False):
        """Run the network on each corpus graph of the batch.

        The query_* arguments, ``corpus_batch_data_edge_sizes`` and
        ``diagnostic_mode`` are accepted for interface compatibility but
        unused. See the class docstring for the return values in each mode.
        """
        if self.training:
            return self._forward_train(corpus_batch_data,
                                       corpus_batch_data_node_sizes,
                                       corpus_batch_adj)
        return self._forward_eval(corpus_batch_data,
                                  corpus_batch_data_node_sizes,
                                  corpus_batch_adj)

    def _forward_train(self, corpus_batch_data, corpus_batch_data_node_sizes,
                       corpus_batch_adj):
        """Return the mean per-graph loss and the detached per-graph output sums."""
        num_graphs = len(corpus_batch_data_node_sizes)
        batch_loss = 0
        out_sums = []
        for j in range(num_graphs):
            out, adjmatrix = self.net(corpus_batch_data[j],
                                      corpus_batch_data_node_sizes[j],
                                      corpus_batch_adj[j])
            batch_loss += self.net.compute_loss(adjmatrix, out)
            out_sums.append(out.sum())
        batch_loss = batch_loss / num_graphs
        # stack + detach: no per-element host round trip, and (like the
        # previous torch.tensor copy) the sums are cut from the autograd graph.
        return batch_loss, torch.stack(out_sums).detach().to(self.device)

    def _forward_eval(self, corpus_batch_data, corpus_batch_data_node_sizes,
                      corpus_batch_adj):
        """Return best greedy clique size and >0.5 output count per graph."""
        num_walkers = self.conf.model.Numofwalkers
        samp_length = self.conf.model.SampLength
        sum_pi_cliques = []
        solver_cliques = []
        for j in range(len(corpus_batch_data_node_sizes)):
            out, _, adj = self.net(corpus_batch_data[j],
                                   corpus_batch_data_node_sizes[j],
                                   corpus_batch_adj[j])
            if (out == 0).all():
                # Degenerate all-zero output: fall back to uniform scores.
                out = torch.ones_like(out, device=self.device)

            # Threshold a copy: the greedy decoder below must rank nodes by
            # the raw scores, not by the binarised 0/1 indicators.
            sum_pi_cliques.append(self._count_hard_indicators(out))

            n_starts = min(num_walkers, adj.get_shape()[0])
            pred_sizes = [
                getclicnum(adj, out, walkerstart=start,
                           thresholdloopnodes=samp_length).item()
                for start in range(n_starts)
            ]
            # No start rank to try (empty graph or zero walkers) -> size 0.
            solver_cliques.append(max(pred_sizes, default=0.0))

        return (torch.tensor(solver_cliques, device=self.device),
                torch.tensor(sum_pi_cliques, device=self.device))

    