import torch
import torch.nn as nn
import torch.nn.functional as F
from graph_learning.module.modules.layers.classifier import MLP
from graph_learning.module import ModuleConfig, get_module
from graph_learning.utils import merge_dicts, dict_merge_rec

def prepare_hiddens(xs):
    """Flatten a sequence of encoder hidden outputs into tensors plus one mask.

    Each element of ``xs`` may be:

    * a tensor ``x`` -- it gets an all-True mask of length ``x.size(0)``;
    * a 2-tuple ``(x, mask)`` or a 3-tuple ``(x, _, mask)`` -- the middle
      element of a 3-tuple is ignored; a ``None`` mask is replaced by an
      all-True mask as above;
    * a list -- processed recursively and its tensors spliced in place.

    A tuple payload ``x`` may itself be a list of tensors; it is spliced into
    the result as-is and its (already 2-D) mask is used unchanged.

    Args:
        xs: non-empty sequence of hidden outputs as described above.

    Returns:
        ``(tensors, mask)`` where ``tensors`` is a flat list and ``mask`` is
        the concatenation, along the last dim, of every element's mask
        (1-D masks are promoted to shape ``(N, 1)`` first).

    Raises:
        ValueError: if ``xs`` is empty or contains a tuple whose length is
            not 2 or 3.
    """
    if not xs:
        # torch.cat([]) would fail later with a much less helpful message.
        raise ValueError('prepare_hiddens() needs at least one hidden output')

    xs_ret, masks_ret = [], []
    for x in xs:
        if isinstance(x, list):
            # Nested group: flatten recursively; its mask is already 2-D.
            x, mask = prepare_hiddens(x)
        else:
            mask = None
            if isinstance(x, tuple):
                if len(x) == 2:
                    x, mask = x
                elif len(x) == 3:
                    x, _, mask = x
                else:
                    raise ValueError(
                        'hidden tuples must be (x, mask) or (x, _, mask), '
                        f'got a tuple of length {len(x)}')
            if mask is None:
                mask = torch.ones(x.size(0), dtype=torch.bool, device=x.device)
            if mask.dim() == 1:
                # One column per hidden so masks can be concatenated on -1.
                mask = mask.unsqueeze(1)

        if isinstance(x, list):
            xs_ret.extend(x)
        else:
            xs_ret.append(x)
        masks_ret.append(mask)

    return xs_ret, torch.cat(masks_ret, -1)

class FusionModuleConfig(ModuleConfig):
    """Base config for encoders that fuse the outputs of several sub-encoders.

    After the base ``ModuleConfig`` initialisation, each entry of
    ``self.encoders`` (presumably the names given via ``--encoders`` --
    confirm against ``ModuleConfig``) is replaced by the result of
    ``get_module(context, name)``.
    """

    def __init__(self, args, context):
        super().__init__(args, context)
        # Resolve the configured sub-encoder names, keeping their order.
        encoder_names = self.encoders
        self.encoders = [get_module(context, encoder_name)
                         for encoder_name in encoder_names]

    @property
    def builder(self):
        """Module class instantiated from this config."""
        return FusionModule

    @classmethod
    def define_parser(cls, parser):
        """Register the options shared by every fusion encoder."""
        super().define_parser(parser)
        parser.add_argument('--encoders', nargs='+',
                            help='sub-encoders')
        parser.add_argument('--freeze', action='store_true',
                            help='freeze sub-encoders')

class FusionModule(nn.Module):
    """Run every sub-encoder on the same input and fuse their hidden states.

    Subclasses must override :meth:`fusion`. It receives the ``'hidden'``
    value of each sub-encoder result, in encoder order, and must return a
    dict with ``'hidden'`` and ``'outputs'`` keys.

    Args:
        encoders: iterable of sub-encoder modules, wrapped in an
            ``nn.ModuleList``.
        freeze: if true, every sub-encoder parameter gets
            ``requires_grad = False``.
    """

    def __init__(self, encoders, freeze=False):
        super().__init__()
        self.encoders = nn.ModuleList(encoders)
        if freeze:
            # Same effect as setting requires_grad = False on each parameter.
            self.encoders.requires_grad_(False)

    def fusion(self, *xs):
        """Fuse the sub-encoder hiddens ``xs``; subclasses must implement."""
        raise NotImplementedError

    def forward(self, data, feature):
        """Encode ``(data, feature)`` with every sub-encoder, then fuse.

        Returns:
            dict with ``'hidden'`` taken from the fusion result, and
            ``'outputs'``: the sub-encoders' ``'outputs'`` combined via
            ``merge_dicts``, then recursively merged with the fusion's
            ``'outputs'`` via ``dict_merge_rec``.
        """
        results = [enc(data, feature) for enc in self.encoders]
        hiddens = [res['hidden'] for res in results]
        sub_outputs = merge_dicts([res['outputs'] for res in results])
        fused = self.fusion(*hiddens)
        return {
            'hidden': fused['hidden'],
            'outputs': dict_merge_rec(sub_outputs, fused['outputs']),
        }

@ModuleConfig.register('plain-fusion',
                       help='[Encoder] plain fusion layer')
class PlainFusionConfig(FusionModuleConfig):
    """Config for ``PlainFusion``: concatenate the sub-encoder hiddens."""

    @property
    def builder(self):
        """Module class instantiated from this config."""
        return PlainFusion

    @classmethod
    def define_parser(cls, parser):
        """Register the plain-fusion options on top of the shared ones."""
        super().define_parser(parser)
        # 'cat' is the only strategy. Defaulting to it keeps an omitted --typ
        # from yielding typ=None, which PlainFusion.fusion has no branch for.
        parser.add_argument('--typ', choices=['cat'], default='cat',
                            help='fusion strategy; only plain concatenation '
                                 '(cat) is supported for now')
        parser.add_argument('--aux-hiddens', action='store_true',
                            help='also return the hidden representations '
                                 'before fusion')

class PlainFusion(FusionModule):
    """Fuse sub-encoder hiddens by concatenating them along the last dim.

    Args:
        encoders: sub-encoders, forwarded to ``FusionModule``.
        typ: fusion strategy; only ``'cat'`` is supported.
        aux_hiddens: if true, ``fusion`` returns ``([h, *xs], None, mask)``
            instead of just ``h``. This holds the fused hidden followed by
            the flattened per-encoder hiddens, and a mask whose leading
            all-True column belongs to ``h``.
        freeze: disable gradients for the sub-encoder parameters.
    """

    def __init__(self, encoders, typ='cat', aux_hiddens=False, freeze=False):
        super().__init__(encoders, freeze)
        self.typ = typ
        self.aux_hiddens = aux_hiddens

    def fusion(self, *xs):
        """Concatenate the flattened hiddens and build the hidden output.

        Raises:
            ValueError: if ``self.typ`` is not a supported strategy.
        """
        xs, mask = prepare_hiddens(xs)

        if self.typ == 'cat':
            h = torch.cat(xs, -1)
        else:
            # Fail loudly instead of with an UnboundLocalError on ``h``.
            raise ValueError(
                f"unsupported fusion type {self.typ!r}; expected 'cat'")

        if not self.aux_hiddens:
            return {'hidden': h,
                    'outputs': {}}

        # The fused hidden is always valid, so it gets an all-True column.
        mask_f = torch.ones(mask.size(0), 1, dtype=torch.bool,
                            device=mask.device)
        return {'hidden': ([h, *xs], None, torch.cat([mask_f, mask], -1)),
                'outputs': {}}

@ModuleConfig.register('moe-weighting',
                       help='[Encoder] compute MoE weights.')
class MoEWModuleConfig(FusionModuleConfig):
    """Config for ``MoEW``: per-sample mixture-of-experts gating weights.

    The inherited ``__init__`` is used as-is; the former override only
    forwarded to ``super()``.
    """

    @property
    def builder(self):
        """Module class instantiated from this config."""
        return MoEW

    @classmethod
    def define_parser(cls, parser):
        """Register the MoE-weighting options on top of the shared ones."""
        super().define_parser(parser)
        parser.add_argument('--in-sizes', type=int, nargs='+',
                            help='last-dim size of each fused hidden, in '
                                 'order (one gate projection per entry)')
        # NOTE(review): MoEW accepts hidden_size but does not use it.
        parser.add_argument('--hidden-size', type=int,
                            help='hidden size')

class MoEW(FusionModule):
    """Compute mixture-of-experts gating weights over sub-encoder hiddens.

    Each flattened hidden ``xs[i]`` is projected to ``k = len(in_sizes)``
    logits by its own linear layer. The projections are summed, masked,
    scaled by ``gate_scale`` and softmaxed over the ``k`` experts.
    ``fusion`` returns the tuple ``(xs, gate, mask)`` rather than a single
    fused tensor.

    Args:
        encoders: sub-encoders, forwarded to ``FusionModule``.
        in_sizes: last-dim size of each flattened hidden, in order.
        hidden_size: accepted for config compatibility; not used.
        freeze: disable gradients for the sub-encoder parameters.
    """

    # Logits are multiplied by this before the softmax (sharpens the gate).
    gate_scale = 10.0

    def __init__(self, encoders,
                 in_sizes, hidden_size, freeze):
        # Let the base class handle freezing instead of duplicating its loop.
        super().__init__(encoders, freeze)

        k = len(in_sizes)
        # Registration order (fc_gs.0, fc_gs.1, ...) matches in_sizes.
        self.fc_gs = nn.ModuleList([nn.Linear(size, k) for size in in_sizes])

    def fusion(self, *xs):
        """Return ``(xs, gate, mask)`` with ``gate`` softmaxed over experts.

        Masked-out experts get zero weight. A row in which every expert is
        masked gets an all-zero gate instead of NaN.

        Raises:
            ValueError: if the number of flattened hiddens does not match
                the number of gate projections.
        """
        xs, mask = prepare_hiddens(xs)
        if len(xs) != len(self.fc_gs):
            raise ValueError(
                f'MoEW expects {len(self.fc_gs)} hiddens, got {len(xs)}')

        gate = sum(fc_g(x) for x, fc_g in zip(xs, self.fc_gs))

        # A softmax over an all -inf row is NaN (in forward and backward), so
        # only mask rows that keep at least one valid expert; zero the rest.
        has_valid = mask.any(-1, keepdim=True)
        gate = gate.masked_fill(~mask & has_valid, float('-inf'))
        gate = F.softmax(gate * self.gate_scale, dim=-1)
        gate = gate.masked_fill(~has_valid, 0.0)

        return {'hidden': (xs, gate, mask),
                'outputs': {}}
