from abc import abstractmethod

import torch
import torch.nn as nn
from torch_geometric.nn import TopKPooling
from gcip.modules.gnn.nn.my_topk_pool import MyTopKPooling
from gcip.utils.graph import remove_nodes_from_batch_mask
import gcip.utils.io as pb_io

class TopKGNN(nn.Module):
    """Optional GNN followed by a ``MyTopKPooling`` node-selection step.

    The pooling layer produces a node mask and a per-node score. The batch
    is reduced to the masked nodes via ``remove_nodes_from_batch_mask`` and
    the remaining node features are multiplied by the score.

    Args:
        gnn: Either an ``int`` or a module-like object. An ``int`` is used
            directly as ``in_channels`` of the pooling layer and no GNN is
            run (the pooling layer then receives ``attn=None``). Any other
            object is stored, must expose ``output_size`` and is called on
            a clone of the batch to produce the ``attn`` input.
        ratio: Forwarded to ``MyTopKPooling``.
        min_score: Forwarded to ``MyTopKPooling``. A float equal to ``0.0``
            is replaced by ``None``; an integer ``0`` is passed through
            unchanged.
        multiplier: Forwarded to ``MyTopKPooling``.
    """

    def __init__(self,
                 gnn,
                 ratio=0.5,
                 min_score=None,
                 multiplier=1.0
                 ):
        super().__init__()

        if isinstance(gnn, int):
            # Only the feature width was given: there is no GNN to run.
            self.gnn = None
            in_channels = gnn
        else:
            self.gnn = gnn
            in_channels = gnn.output_size

        self.ratio = ratio
        # A float zero threshold is stored as None instead of 0.0.
        if isinstance(min_score, float) and min_score == 0.0:
            min_score = None
        self.min_score = min_score
        self.multiplier = multiplier

        self.pool = MyTopKPooling(in_channels=in_channels,
                                  ratio=self.ratio,
                                  min_score=self.min_score,
                                  multiplier=self.multiplier,
                                  nonlinearity=torch.tanh)

    @staticmethod
    def kwargs(cfg, gnn):
        """Build the constructor keyword arguments from a config object.

        Args:
            cfg: Object exposing ``model.ratio``, ``model.min_score`` and
                ``model.multiplier``.
            gnn: Value passed through unchanged under the ``'gnn'`` key.

        Returns:
            dict with the keys ``'gnn'``, ``'ratio'``, ``'min_score'`` and
            ``'multiplier'``.
        """
        return {
            'gnn': gnn,
            'ratio': cfg.model.ratio,
            'min_score': cfg.model.min_score,
            'multiplier': cfg.model.multiplier,
        }

    # NOTE: this method is a complete implementation, so it is deliberately
    # not marked abstract.
    def forward(self, batch, inplace=False, **kwargs):
        """Select nodes with the pooling layer and rescale their features.

        Args:
            batch: Object exposing ``x``, ``edge_index``, ``batch``,
                ``edge_attr`` and ``clone()``.
            inplace: If true, the pooling layer reads its inputs from
                ``batch`` itself instead of from a clone. Node removal is
                applied to a clone in either case.
            **kwargs: Accepted and ignored.

        Returns:
            The object returned by ``remove_nodes_from_batch_mask`` with its
            ``x`` multiplied by the pooling score.
        """
        # The GNN is handed a clone of the batch (presumably so it cannot
        # alter the caller's batch -- confirm against the GNN in use).
        logits = None if self.gnn is None else self.gnn(batch.clone())

        batch_ = batch if inplace else batch.clone()
        mask, score = self.pool(x=batch_.x,
                                edge_index=batch_.edge_index,
                                batch=batch_.batch,
                                edge_attr=batch_.edge_attr,
                                attn=logits)

        batch_ = remove_nodes_from_batch_mask(mask=mask,
                                              mode='keep',
                                              batch=batch_.clone(),
                                              relabel_nodes=True)

        # One score per row of x: reshape to a column for the row-wise scale.
        # Assumes score has one entry per kept node -- TODO confirm.
        batch_.x = batch_.x * score.view(-1, 1)
        return batch_
