import torch
import torch.nn as nn
import torch.nn.functional as F

from graph_learning.module import Module, ModuleConfig, register_module, get_module
from graph_learning.utils import dict_merge_rec
from graph_learning.dataset.graph import gl_unbatch

@ModuleConfig.register('pairwise-node-cls-decoder',
                       help='[Decoder] pairwise node classification.')
class PairNodeClassificationDecoderModuleConfig(ModuleConfig):
    """Configuration for the pairwise node decoder (``PairNodeClassificationDecoder``).

    Registered under the name ``'pairwise-node-cls-decoder'``. On construction it
    records the global device and resolves the ``classifier`` option into a module
    via ``get_module``.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

        self.device = context.global_.device

        # NOTE(review): ``self.classifier`` is presumably populated by the base
        # class from the ``--classifier`` argument; here it is replaced by the
        # module that ``get_module`` resolves for it — confirm against ModuleConfig.
        self.classifier = get_module(context, self.classifier)

    @property
    def builder(self):
        """Return the class used to build the decoder from this config."""
        return PairNodeClassificationDecoder

    @classmethod
    def define_parser(cls, parser):
        """Register the decoder's command-line options on ``parser``."""
        super().define_parser(parser)
        parser.add_argument('--classifier',
                            help='classification layer, map hidden representation from graph encoder to logits.')

        # The help previously omitted the 'reg' choice.
        parser.add_argument('--out-layer', choices=['cos', 'fc', 'reg'], default='cos',
                            help='output computing: cosine similarity (cos) / '
                                 'fully connected classifier (fc) / regression (reg)')
        parser.add_argument('--input-size', type=int,
                            help='node representation size fed to the fc/reg output head.')
        parser.add_argument('--dropout', type=float,
                            help='dropout probability inside the fc output head.')

class PairNodeClassificationDecoder(nn.Module):
    """Score node pairs from encoder node representations and compute a pair loss.

    The output mode is selected by ``out_layer``:

    * ``'cos'``: rows are L2-normalised and the dot product (cosine similarity) of
      the two nodes of a pair is used as a logit for ``BCEWithLogitsLoss``.
      Positive pairs (``data.pair_pos()``) are labelled 1, negative pairs
      (``data.pair_neg()``) 0.
    * ``'fc'``: an MLP over ``[|a - b|, (a + b) / 2, max(a, b)]`` produces two
      class logits trained with ``CrossEntropyLoss`` (same positive/negative labels).
    * ``'reg'``: a linear layer over the same pair features regresses a scalar
      target (``data.pair_labels()``) with ``MSELoss`` for pairs ``data.pair()``.

    Args:
        classifier: optional module applied to ``hidden`` first. If it returns a
            tuple, only the first element is used. ``None`` uses ``hidden`` as-is.
        out_layer: one of ``'cos'``, ``'fc'``, ``'reg'``.
        input_size: size of each node representation fed to the pair head (the
            ``'fc'``/``'reg'`` heads take ``3 * input_size`` features).
        dropout: dropout probability inside the ``'fc'`` MLP.

    Raises:
        ValueError: if ``out_layer`` is not a supported mode.
    """

    _MODES = ('cos', 'fc', 'reg')

    def __init__(self, classifier,
                 out_layer,
                 input_size,
                 dropout):
        super().__init__()

        # Fail fast: an unknown mode would otherwise leave ``loss_func`` unset and
        # only surface as an AttributeError deep inside the first forward pass.
        if out_layer not in self._MODES:
            raise ValueError(
                f"unknown out_layer {out_layer!r}; expected one of {self._MODES}")

        self.classifier = classifier

        self.out_layer_model = out_layer

        if out_layer == 'cos':
            self.loss_func = nn.BCEWithLogitsLoss()
        elif out_layer == 'fc':
            self.loss_func = nn.CrossEntropyLoss()
            self.out_layer = nn.Sequential(
                nn.Linear(3*input_size, input_size),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(input_size, 2))
        else:  # 'reg'
            self.loss_func = nn.MSELoss()
            self.out_layer = nn.Linear(3*input_size, 1)

    def forward(self, data, hidden):
        """Run the decoder; on failure print the shape of ``hidden`` and re-raise."""
        try:
            return self._forward(data, hidden)
        except Exception:
            # Debug aid only. ``getattr`` keeps the print itself from raising and
            # masking the original error when ``hidden`` has no ``shape``.
            print(getattr(hidden, 'shape', type(hidden)))
            raise

    def _node_embeddings(self, hidden):
        """Apply the optional classifier; in ``'cos'`` mode L2-normalise each row."""
        out = hidden if self.classifier is None else self.classifier(hidden)
        if isinstance(out, tuple):
            out = out[0]
        if self.out_layer_model == 'cos':
            out = F.normalize(out, p=2, dim=-1)
        return out

    @staticmethod
    def _pair_features(first, second):
        """Concatenate symmetric pair features ``[|a - b|, (a + b) / 2, max(a, b)]``."""
        return torch.cat([torch.abs(first - second),
                          (first + second) / 2,
                          torch.max(first, second)], -1)

    def _forward(self, data, hidden):
        """Score the pairs in ``data`` and return the loss and per-pair outputs.

        Returns:
            dict with ``'loss'``, ``'losses'`` (``{'pair_cls': loss}``) and
            ``'outputs'``. ``'outputs'`` holds ``'labels'`` plus ``'probs'`` for
            ``'cos'``/``'fc'`` (probability of the positive class) or ``'preds'``
            for ``'reg'``.
        """
        mode = self.out_layer_model
        out = self._node_embeddings(hidden)

        # Row 0 / row 1 of the pair index select the first / second node of each
        # pair (assumes a (2, num_pairs) index tensor — TODO confirm in the dataset).
        if mode == 'reg':
            pair_index = data.pair()
        else:
            pair_pos = data.pair_pos()
            pair_neg = data.pair_neg()
            pair_index = torch.cat([pair_pos, pair_neg], -1)

        nodes_first = torch.index_select(out, 0, pair_index[0, :])
        nodes_second = torch.index_select(out, 0, pair_index[1, :])

        if mode == 'cos':
            # Dot product of normalised rows == cosine similarity, used as the logit.
            # NOTE(review): these logits are bounded to [-1, 1], so sigmoid
            # probabilities stay within roughly [0.27, 0.73].
            scores = torch.sum(nodes_first * nodes_second, dim=-1)
            outputs = {'probs': torch.sigmoid(scores)}
        elif mode == 'fc':
            scores = self.out_layer(self._pair_features(nodes_first, nodes_second))
            # Probability of the positive class (index 1).
            outputs = {'probs': F.softmax(scores, -1)[:, 1]}
        else:  # 'reg'
            scores = self.out_layer(
                self._pair_features(nodes_first, nodes_second)).squeeze(-1)
            outputs = {'preds': scores}

        if mode == 'reg':
            label = data.pair_labels()
            # MSELoss needs targets matching the prediction dtype/device.
            target = label.to(device=scores.device, dtype=scores.dtype)
        else:
            # BCEWithLogitsLoss needs float targets; CrossEntropyLoss needs long
            # class indices.
            label_dtype = scores.dtype if mode == 'cos' else torch.long
            label = torch.cat([
                torch.ones(pair_pos.shape[1], dtype=label_dtype, device=scores.device),
                torch.zeros(pair_neg.shape[1], dtype=label_dtype, device=scores.device),
            ])
            target = label

        loss = self.loss_func(scores, target)
        outputs['labels'] = label

        return {'loss': loss, 'losses': {'pair_cls': loss}, 'outputs': outputs}
