import torch
import torch.nn as nn
from config import args

class Explainer(nn.Module):
    """Learn a soft edge mask that explains one node prediction of ``model``.

    A small MLP (``self.elayers``) scores every edge listed in a COO adjacency
    matrix from the concatenated embeddings of its two endpoints and of the
    explained node. The scores are turned into gates in (0, 1), scattered into
    a dense mask, multiplied into the adjacency matrix, and the wrapped model
    is re-run on that masked adjacency.

    Configuration is read from the module-level ``args`` object
    (``hiddens``/``hidden1``, ``concat``, ``coff_size``, ``weight_decay``,
    ``coff_ent``, ``sample_bias``, ``miGroudTruth``, ``budget``).

    NOTE(review): the inline comments of the previous revision referred to
    TensorFlow ops, so this looks like a port of a TF implementation; verify
    numerical parity against that reference if it matters.
    """

    def __init__(self, model, **kwargs):
        """Build the edge-scoring MLP around an already constructed ``model``.

        Args:
            model: the model to explain. It must expose a ``device`` attribute
                and be callable as ``model((x, masked_adj))``.
            **kwargs: forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)

        # The MLP input size is derived from the hidden sizes configured for
        # the explained model. Fall back to ``args.hidden1`` when
        # ``args.hiddens`` is missing or is not a '-'-separated list of ints.
        try:
            hiddens = [int(s) for s in args.hiddens.split('-')]
        except (AttributeError, TypeError, ValueError):
            hiddens = [args.hidden1]

        # Three embeddings are concatenated per edge (both endpoints plus the
        # explained node), hence the factor 3.
        embedding_dim = sum(hiddens) if args.concat else hiddens[0]
        input_dim = embedding_dim * 3

        self.device = model.device
        self.elayers = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # NOTE(review): if ``model`` is an nn.Module this assignment registers
        # it as a sub-module, so ``self.parameters()`` also yields the model's
        # parameters. Confirm the optimizer is only given ``self.elayers``.
        self.model = model
        self.mask_act = 'sigmoid'  # not read anywhere inside this class
        self.params = []           # not read anywhere inside this class

        self.softmax = nn.Softmax(dim=0)

        self.coeffs = {
            "size": args.coff_size,
            "weight_decay": args.weight_decay,
            "ent": args.coff_ent
        }

    @staticmethod
    def _edge_index(adj):
        """Return the ``(2, nnz)`` int64 index tensor of COO matrix ``adj``.

        The indices are converted straight to int64. Routing them through a
        float32 tensor first would corrupt node ids above 2**24.
        """
        rows = torch.as_tensor(adj.row, dtype=torch.int64)
        cols = torch.as_tensor(adj.col, dtype=torch.int64)
        return torch.stack([rows, cols], dim=0)

    def _masked_adj(self, mask, adj):
        """Apply a symmetrised ``mask`` to ``adj`` and drop the diagonal.

        Args:
            mask: dense 2-D tensor of edge gates (it is transposed and
                multiplied element-wise with the dense adjacency).
            adj: COO-style matrix exposing ``row``, ``col``, ``data`` and
                ``shape`` (presumably a ``scipy.sparse.coo_matrix`` — confirm
                against the caller).

        Returns:
            Dense float32 tensor ``adj * (mask + mask.T) / 2`` with a zeroed
            diagonal, on ``self.device``.

        Side effects:
            Stores the dense, unmasked adjacency in ``self.adj``.
        """
        mask = mask.to(self.device)
        sym_mask = (mask + mask.T) / 2

        sparse_adj = torch.sparse_coo_tensor(
            indices=self._edge_index(adj),
            values=torch.as_tensor(adj.data),
            size=tuple(adj.shape),
        )
        # coalesce() sums duplicate (row, col) entries before densifying.
        dense_adj = sparse_adj.coalesce().to_dense().to(torch.float32).to(self.device)
        self.adj = dense_adj

        masked_adj = dense_adj * sym_mask

        # Zero the diagonal so self-loops never appear in the explanation.
        num_nodes = dense_adj.shape[0]
        off_diagonal = 1.0 - torch.eye(num_nodes, dtype=torch.float32, device=self.device)
        return masked_adj * off_diagonal

    def concrete_sample(self, log_alpha, beta=1.0, training=True):
        """Turn edge logits into gates in (0, 1).

        Args:
            log_alpha: tensor of edge logits.
            beta: temperature dividing the noisy logits (training only).
            training: when true, add logistic noise built from uniform samples
                in ``[args.sample_bias, 1 - args.sample_bias)`` before the
                sigmoid; when false, return ``sigmoid(log_alpha)``.

        Returns:
            Tensor of gates with the same shape as ``log_alpha``.
        """
        if not training:
            return torch.sigmoid(log_alpha)

        bias = args.sample_bias
        # Drawn on CPU and then moved, so the random stream under a fixed
        # ``torch.manual_seed`` does not depend on ``self.device``.
        random_noise = torch.empty(log_alpha.shape, dtype=torch.float32).uniform_(bias, 1.0 - bias)
        random_noise = random_noise.to(self.device)
        logistic_noise = torch.log(random_noise) - torch.log(1.0 - random_noise)
        return torch.sigmoid((logistic_noise + log_alpha) / beta)

    def forward(self, inputs, training=False):
        """Score every edge, mask the adjacency and re-run the wrapped model.

        Args:
            inputs: tuple ``(x, adj, nodeid, embed, tmp)`` where ``x`` is the
                feature tensor passed to the model, ``adj`` is a COO-style
                matrix (``row``/``col``/``data``/``shape``), ``nodeid``
                indexes the explained node, ``embed`` holds per-node
                embeddings (tensor or array-like) and ``tmp`` is the sampling
                temperature.
            training: forwarded to ``concrete_sample``.

        Returns:
            Softmax over the model output row of ``nodeid``.

        Side effects:
            Sets ``self.tmp``, ``self.values`` (raw edge logits),
            ``self.mask`` (dense gate matrix), ``self.masked_adj`` and, via
            ``_masked_adj``, ``self.adj``.
        """
        x, adj, nodeid, embed, tmp = inputs
        x = x.to(self.device)
        self.tmp = tmp

        # Convert and move the embeddings once instead of per gather.
        if not isinstance(embed, torch.Tensor):
            embed = torch.as_tensor(embed)
        embed = embed.to(self.device)

        edge_index = self._edge_index(adj).to(self.device)
        src_emb = embed[edge_index[0]]
        dst_emb = embed[edge_index[1]]
        # Broadcast the explained node's embedding to one copy per edge.
        self_emb = embed[nodeid].unsqueeze(0).expand(src_emb.shape[0], -1)

        h = torch.cat([src_emb, dst_emb, self_emb], dim=-1)
        h = self.elayers(h)
        self.values = torch.reshape(h, [-1])

        values = self.concrete_sample(self.values, beta=tmp, training=training)

        # torch.sparse_coo_tensor works on any device; the legacy
        # torch.sparse.FloatTensor constructor is deprecated and CPU-typed.
        sparse_mask = torch.sparse_coo_tensor(
            indices=edge_index,
            values=values,
            size=tuple(adj.shape),
        )
        mask = sparse_mask.coalesce().to_dense().to(torch.float32)
        masked_adj = self._masked_adj(mask, adj)

        self.mask = mask
        self.masked_adj = masked_adj

        output = self.model((x, masked_adj))
        node_pred = output[nodeid, :]
        return self.softmax(node_pred)

    def loss(self, pred, pred_label, label, node_idx, adj_tensor=None):
        """Compute the explanation loss for the last ``forward`` call.

        Args:
            pred: class probabilities returned by ``forward``.
            pred_label: labels predicted by the original model, indexed by
                node.
            label: per-node label scores; only used (via argmax over the last
                axis) when ``args.miGroudTruth`` is set.
            node_idx: index of the explained node.
            adj_tensor: unused; kept for interface compatibility.

        Returns:
            Tuple ``(loss, pred_loss, size_loss)`` where ``loss`` is the sum
            of the prediction loss, the size penalty, the weight-norm penalty
            and the mask-entropy penalty.
        """
        if args.miGroudTruth:
            target = torch.argmax(label.to(torch.float32), dim=-1)[node_idx]
        else:
            target = pred_label[node_idx]

        # The small epsilon keeps log() finite when the probability is 0.
        logit = pred[target] + 1e-6
        pred_loss = -torch.log(logit)

        mask_total = torch.sum(self.mask)
        if args.budget <= 0:
            size_loss = self.coeffs["size"] * mask_total
        else:
            # Only mask mass above the budget is penalised.
            size_loss = self.coeffs["size"] * torch.relu(mask_total - args.budget)

        # Squeeze the mask into [1 - scale, scale] so both logs stay finite.
        scale = 0.99
        mask = self.mask * (2 * scale - 1.0) + (1.0 - scale)
        mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
        mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)

        weight_norms = sum(
            torch.norm(parameter)
            for name, parameter in self.elayers.named_parameters()
            if "weight" in name
        )
        l2norm = self.coeffs['weight_decay'] * weight_norms
        loss = pred_loss + size_loss + l2norm + mask_ent_loss

        # NOTE: the TensorFlow original carried an optional, unused
        # "connectivity" term (enabled by args.budget > 0 and
        # args.coff_connect > 0) that compared the mask scores of the two
        # highest-ranked columns per adjacency row. It has not been ported,
        # which is why ``adj_tensor`` is unused here.

        return loss, pred_loss, size_loss
