import torch
import torch.nn as nn
from config import args

class Explainer(nn.Module):
    """Learn a soft edge mask over a dense ``nodesize x nodesize`` adjacency.

    A small MLP scores every ordered node pair from the concatenation of the
    two node embeddings. The scores are turned into gates in [0, 1] by
    :meth:`concrete_sample`, symmetrised, multiplied into the adjacency and
    fed to the wrapped ``model``.

    Args:
        model: The model to explain. It must expose a ``device`` attribute and
            be callable on a ``(x, adj)`` tuple.
        nodesize: Number of nodes; fixes the size of the edge mask.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, model, nodesize, **kwargs):
        super().__init__(**kwargs)

        # The MLP input width follows the hidden sizes configured for the
        # explained network; fall back to a single hidden size when
        # ``args.hiddens`` is missing or not a dash-separated list of ints.
        try:
            hiddens = [int(s) for s in args.hiddens.split('-')]
        except (AttributeError, TypeError, ValueError):
            hiddens = [args.hidden1]

        # Two node embeddings are concatenated per candidate edge, hence "* 2".
        if args.concat:
            input_dim = sum(hiddens) * 2  # or just times 3?
        else:
            input_dim = hiddens[-1] * 2
        self.device = model.device

        self.elayers = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # Row-major enumeration of all (i, j) pairs: entry k = i * nodesize + j
        # has row[k] == i and col[k] == j. Stored as float32, as before.
        rc = torch.arange(0, nodesize).unsqueeze(0).repeat([nodesize, 1]).to(torch.float32)
        self.row = rc.T.reshape(-1).to(self.device)
        self.col = rc.reshape(-1).to(self.device)

        self.nodesize = nodesize
        self.model = model
        self.softmax = nn.Softmax(dim=-1)

        # Ones everywhere except the diagonal (for masking diagonal entries).
        # Not used elsewhere in this class; left on the CPU.
        self.diag_mask = torch.ones((nodesize, nodesize), dtype=torch.float32) - torch.eye(nodesize)

        self.mask_act = 'sigmoid'

    def concrete_sample(self, log_alpha, beta=1.0, training=True):
        """Turn edge logits into gates in [0, 1].

        Args:
            log_alpha: Tensor of edge logits.
            beta: Temperature dividing the noisy logits (training only).
            training: If truthy, add logistic noise built from uniform samples
                before the sigmoid; otherwise return ``sigmoid(log_alpha)``.

        Returns:
            Tensor with the same shape as ``log_alpha``.
        """
        if not training:
            return torch.sigmoid(log_alpha)

        debug_var = 0.0
        bias = 0.0
        # Sample on the CPU and then move, so the random stream matches the
        # previous implementation regardless of the target device.
        noise = torch.empty(log_alpha.shape, dtype=torch.float32).uniform_(debug_var, 1.0 - debug_var)
        noise = (bias + noise).to(self.device)
        # log(u) - log(1 - u) maps uniform samples to logistic noise.
        logistic_noise = torch.log(noise) - torch.log(1.0 - noise)
        return torch.sigmoid((logistic_noise + log_alpha) / beta)

    def forward(self, inputs, training=None):
        """Mask the adjacency and run the wrapped model on it.

        Args:
            inputs: Tuple ``(x, embed, adj, tmp, label)``. ``x``, ``adj`` and
                ``label`` may be tensors or anything ``torch.tensor`` accepts;
                ``embed`` may be a tensor or array-like and is indexed by node
                id; ``tmp`` is the temperature passed to
                :meth:`concrete_sample`.
            training: Forwarded to :meth:`concrete_sample`; the default
                ``None`` selects the noise-free path.

        Returns:
            Softmax (over the last dimension) of the model output.

        Side effects:
            Stores ``label`` (argmax of the given label), ``tmp``, ``values``
            (raw edge logits), ``mask`` (gates before symmetrisation) and
            ``masked_adj`` on the instance; :meth:`loss` reads ``label`` and
            ``mask``.
        """
        x, embed, adj, tmp, label = inputs

        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        if not isinstance(adj, torch.Tensor):
            adj = torch.tensor(adj)
        if not isinstance(label, torch.Tensor):
            label = torch.tensor(label)
        x = x.to(self.device)
        adj = adj.to(self.device)

        self.label = torch.argmax(label.to(torch.float32), dim=-1)
        self.tmp = tmp

        # Cast the stored float indices directly on the target device instead
        # of bouncing through a CPU LongTensor on every call.
        row = self.row.to(device=self.device, dtype=torch.int64)
        col = self.col.to(device=self.device, dtype=torch.int64)

        # Convert array-like embeddings once, before indexing, so lists and
        # numpy arrays are handled uniformly and gathered on the device.
        if not isinstance(embed, torch.Tensor):
            embed = torch.as_tensor(embed, dtype=torch.float32)
        embed = embed.to(self.device)

        # One row per ordered node pair: [embed[i], embed[j]].
        pair_features = torch.cat([embed[row], embed[col]], dim=-1)
        edge_logits = self.elayers(pair_features)

        self.values = edge_logits.reshape(-1)
        gates = self.concrete_sample(self.values, beta=tmp, training=training)

        # row/col enumerate the full grid in row-major order, so scattering
        # the gates at (row, col) is exactly a reshape to (nodesize, nodesize).
        mask = gates.reshape(self.nodesize, self.nodesize).to(torch.float32)
        self.mask = mask

        sym_mask = (mask + mask.T) / 2
        self.masked_adj = torch.mul(adj, sym_mask)

        # Cast before requires_grad_ so integer node features do not raise;
        # the model receives a batch of one.
        x = x.detach().to(torch.float32).requires_grad_(True).unsqueeze(0)
        adj = self.masked_adj.unsqueeze(0).to(torch.float32)

        output = self.model((x, adj))
        return self.softmax(output)

    def loss(self, pred, pred_label):
        """Prediction loss plus size and entropy regularisers on the mask.

        Must be called after :meth:`forward`, which sets ``self.label`` and
        ``self.mask``.

        Args:
            pred: prediction made by current model; only ``pred[0]`` is used.
            pred_label: the label predicted by the original model. Currently
                unused; kept so existing callers keep working.

        Returns:
            ``-log(pred[0][label]) + coff_size * sum(mask)
            + coff_ent * mean(entropy(mask))``.
        """
        pred_reduce = pred[0]
        logit = pred_reduce[self.label]
        pred_loss = -torch.log(logit)

        # size
        # NOTE(review): self.mask already holds sigmoid outputs from
        # concrete_sample, so with mask_act == "sigmoid" the activation is
        # applied a second time here - confirm this is intended.
        mask = self.mask
        if self.mask_act == "sigmoid":
            mask = torch.sigmoid(mask)
        elif self.mask_act == "ReLU":
            mask = nn.functional.relu(mask)
        size_loss = args.coff_size * torch.sum(mask)

        # entropy
        # Squeeze values from [0, 1] into [0.005, 0.995] so both logs are finite.
        mask = mask * 0.99 + 0.005
        mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
        mask_ent_loss = args.coff_ent * torch.mean(mask_ent)

        return pred_loss + size_loss + mask_ent_loss
