import torch
import sys
import argparse

from erdos.model_erdos import Erdos

class EGNORIG(torch.nn.Module):
    """Adapter exposing an ``Erdos`` network through a query/corpus ``forward``.

    The constructor overwrites a fixed set of hyperparameters on ``conf``
    (see ``init_relevant_args``) and then builds ``Erdos(conf)``. Only the
    corpus-side inputs of ``forward`` are consumed; the query-side arguments
    exist solely to keep the call signature.
    """

    def __init__(self, conf, gmn_config=None):
        """Pin the hyperparameters on ``conf`` and build the wrapped network.

        Args:
            conf: Configuration object with ``model``, ``training`` and
                ``dataset`` attribute groups. It is mutated in place.
            gmn_config: Accepted but never read by this class.
        """
        super().__init__()
        self.conf = conf
        # Must run before Erdos is constructed: it receives the mutated conf.
        self.init_relevant_args()
        self.net = Erdos(self.conf)
        self.device = self.conf.training.device

    def init_relevant_args(self):
        """Write the fixed model/training/dataset settings onto ``self.conf``.

        Reads ``conf.model.EQ`` to choose between two (hidden1, heads) presets
        and copies ``conf.dataset.max_node_set_size`` into
        ``conf.dataset.max_set_size``. Every other value is a constant.
        Presumably these are the fields ``Erdos`` reads -- confirm against
        ``erdos.model_erdos``.
        """
        model_cfg = self.conf.model
        training_cfg = self.conf.training
        dataset_cfg = self.conf.dataset

        # Only the width and head count depend on EQ; the depth is shared.
        if model_cfg.EQ:
            model_cfg.hidden1 = 10
            model_cfg.heads = 2
        else:
            model_cfg.hidden1 = 64
            model_cfg.heads = 8
        model_cfg.numlayers = 4

        model_cfg.hidden2 = 1
        model_cfg.momentum = 0.1
        model_cfg.deltas = 1
        model_cfg.concat = True
        model_cfg.transform_dim = 16
        model_cfg.temp = 0.05
        model_cfg.gossip_temp = 0.1
        model_cfg.gumbel_sinkhorn_niters = 20
        model_cfg.noise_factor = 0
        model_cfg.use_threshold = True
        model_cfg.use_sigmoid_score = False
        model_cfg.neural_hard_adj = False
        model_cfg.return_tuple = False
        model_cfg.tensor_dim = 10
        model_cfg.speed = True

        training_cfg.diracs_N = 1
        training_cfg.diracs_effective_range = 0.15
        training_cfg.penalty_coeff = 4.0

        dataset_cfg.max_set_size = dataset_cfg.max_node_set_size

    def forward(self,
                query_batch_data,  # unused
                query_batch_data_node_sizes,  # unused
                query_batch_data_edge_sizes,  # unused
                query_batch_adj,  # unused
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,  # unused
                corpus_batch_adj,  # unused
                diagnostic_mode=False):
        """Run the wrapped network on the corpus batch.

        Args:
            query_batch_data: Unused.
            query_batch_data_node_sizes: Unused.
            query_batch_data_edge_sizes: Unused.
            query_batch_adj: Unused.
            corpus_batch_data: Batch forwarded as the first argument of
                ``self.net``.
            corpus_batch_data_node_sizes: Split sizes passed to
                ``torch.split`` to cut ``out['pre_norm_x'][0]`` into one
                chunk per graph.
            corpus_batch_data_edge_sizes: Unused.
            corpus_batch_adj: Unused.
            diagnostic_mode: Unused.

        Returns:
            In training mode, ``(batch_loss, scores)`` where ``batch_loss``
            is ``out['loss'][0]`` and ``scores`` is a tensor on
            ``self.device`` holding one summed value per chunk. In eval
            mode, the raw dictionary returned by ``self.net``.
        """
        # Original author's note: DIRACS is already applied inside this call.
        out = self.net(corpus_batch_data, None,
                       penalty_coefficient=self.conf.training.penalty_coeff)

        # The network output is used as-is. An earlier loop added every
        # "sequence"-tagged value to itself in place (it iterated ``out``
        # while also accumulating into ``out``), which doubled the reported
        # loss and metrics; it has been removed.
        if not self.training:
            return out

        batch_loss = out["loss"][0]

        per_graph = torch.split(out["pre_norm_x"][0],
                                corpus_batch_data_node_sizes)
        scores = []
        for chunk in per_graph:
            # A chunk that is entirely zero is flipped to all ones, so its
            # score becomes the chunk length instead of 0.
            if (chunk == 0).all():
                chunk = 1 - chunk
            scores.append(chunk.sum())

        # torch.tensor copies the scalar values into a new tensor, so the
        # returned scores are not connected to the autograd graph.
        return batch_loss, torch.tensor(scores, device=self.device)
       