#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : mee.py
# Author : Anonymous1
# Email  : anonymous1@anon
#
# Distributed under terms of the MIT license.

import torch
import torch.nn as nn

import dgl
from dgl import DGLGraph

from megraph.utils import check_len, residual_when_same_shape_on_list
from megraph.torch_utils import get_index_of_fused_groups
from megraph.representation import MultiFeatures

__all__ = ["MeeLayer", "MeeLayerV1"]


class MeeLayer(nn.Module):
    def __init__(
        self,
        intra_convs,
        inter_convs,
        pooling=None,
        inter_fusion_fcs=None,
        cross_update_method="conv",
        start_height=0,
        end_height=None,
        vertical_first=False,
        stem_beta=1.0,
        branch_beta=0.5,
        keep_beta=1.0,
        cross_beta=1.0,
        dropout=None,
        dropout_after_residual=False,
    ):
        super(MeeLayer, self).__init__()
        self.intra_convs = intra_convs
        if cross_update_method == "pool":
            inter_convs = [None] * len(inter_convs)
        if cross_update_method == "conv":
            pooling = None
        self.inter_convs = inter_convs
        self.pooling = pooling
        self.inter_fusion_fcs = inter_fusion_fcs
        self.cross_update_method = cross_update_method
        self.start_height = start_height
        self.end_height = end_height
        self.vertical_first = vertical_first
        self.stem_beta = stem_beta
        self.branch_beta = branch_beta
        self.keep_beta = keep_beta
        self.cross_beta = cross_beta
        self.dropout = dropout
        self.dropout_after_residual = dropout_after_residual

    def get_conv(self, convs, i, height):
        if i < 0 or i >= height:
            raise ValueError("index must be in [0, height)")
        return convs[i * len(convs) // height]

    def get_fusion_fc(self, fusion_fcs, i, height):
        if i < 0 or i >= height:
            raise ValueError("index must be in [0, height)")
        if fusion_fcs is None:
            return None
        return fusion_fcs[i * len(fusion_fcs) // height]

    def _cross_update(
        self,
        conv,
        fusion_fc,
        xg: DGLGraph,
        yg: DGLGraph,
        inter_g: DGLGraph,
        cluster: torch.Tensor,  # from larger graph to smaller graph
        x: MultiFeatures,
        y: MultiFeatures,
    ):
        nx = x.nodes_features
        ny = y.nodes_features

        # "combine" means new_ny is obtained by conv, new_nx is obtained by unpool
        if self.cross_update_method in ["conv", "combine"]:
            # Compute new feature by inter conv
            x_ns, y_ns = xg.batch_num_nodes(), yg.batch_num_nodes()
            x_idx, y_idx = get_index_of_fused_groups(x_ns, y_ns, device=nx.device)
            cat_xy = torch.cat([nx, ny], dim=0)
            nz = torch.zeros_like(cat_xy)
            nz[x_idx] = nx
            nz[y_idx] = ny

            z = MultiFeatures([None, nz, None])
            z = conv(inter_g, z)
            nz = z.nodes_features
            new_nx = nz[x_idx]
            new_ny = nz[y_idx]
        if self.cross_update_method in ["pool", "combine"]:
            # Use unpool instead of inter conv
            new_nx = self.pooling.unpool(ny, yg, inter_g, cluster)
        if self.cross_update_method in ["pool"]:
            # Use pool instead of inter conv
            new_ny = self.pooling.pool(nx, yg, cluster)
        # Fuse old and new feature
        if fusion_fc is None:
            x = x.replace_nodes_features(nx * self.keep_beta + new_nx * self.cross_beta)
            y = y.replace_nodes_features(ny * self.keep_beta + new_ny * self.cross_beta)
        else:

            def fuse(xs):
                w = fusion_fc(xs)
                return sum([x * w[:, i : i + 1] for i, x in enumerate(xs)])

            x = x.replace_nodes_features(fuse([nx, new_nx]))
            y = y.replace_nodes_features(fuse([ny, new_ny]))
        return x, y

    def _rightward_update(self, height, inds, intra_graphs, xs):
        res = []
        for i in range(height):
            x = xs[i]
            if inds[0] <= i <= inds[-1]:
                x = self.get_conv(self.intra_convs, i, height)(intra_graphs[i], x)
            res.append(x)
        return res

    def _vertical_update(self, height, inds, intra_graphs, inter_graphs, clusters, xs):
        # inplace update
        for i in inds:
            if i + 1 >= height:
                continue
            xs[i], xs[i + 1] = self._cross_update(
                self.get_conv(self.inter_convs, i, height - 1),
                self.get_fusion_fc(self.inter_fusion_fcs, i, height),
                intra_graphs[i],
                intra_graphs[i + 1],
                inter_graphs[i],
                clusters[i],
                xs[i],
                xs[i + 1],
            )

    def _residual(self, xs, ys, stem_beta=None, branch_beta=None):
        if stem_beta is None:
            stem_beta = self.stem_beta
        if branch_beta is None:
            branch_beta = self.branch_beta
        return [
            x.residual_when_same_shape(y, stem_beta, branch_beta)
            for x, y in zip(xs, ys)
        ]

    def forward(self, height, intra_graphs, inter_graphs, clusters, features):
        check_len(intra_graphs, height)
        check_len(inter_graphs, height - 1)
        check_len(clusters, height - 1)
        check_len(features, height)
        sh = self.start_height
        eh = height - 1 if self.end_height is None else self.end_height
        xs = features

        if not self.vertical_first:
            xs = self._rightward_update(height, [sh, eh], intra_graphs, xs)

        # xs = self._residual(features, xs)
        # Vertical downward and upward updates
        inds = list(range(sh, eh))
        self._vertical_update(height, inds, intra_graphs, inter_graphs, clusters, xs)
        if len(inds) > 1:  # only do second round when there are more than one heights.
            inds = inds[::-1]
            self._vertical_update(
                height, inds, intra_graphs, inter_graphs, clusters, xs
            )

        if self.vertical_first:
            xs = self._rightward_update(height, [sh, eh], intra_graphs, xs)
        if not self.dropout_after_residual:
            xs = [x.apply_fn(self.dropout) for x in xs]
        xs = self._residual(features, xs)
        if self.dropout_after_residual:
            xs = [x.apply_fn(self.dropout) for x in xs]
        return xs


# Belows to be deprecated
class MeeLayerV1(nn.Module):
    def __init__(
        self,
        intra_conv,
        inter_conv,
        n_hidden,
        branch_beta=0.5,
        act=None,
        dropout=None,
        use_residual=True,
        enable_bn=True,
    ):
        super(MeeLayer, self).__init__()
        self.intra_conv = intra_conv
        self.inter_conv = inter_conv
        self.act = act
        self.branch_beta = branch_beta
        self.dropout = nn.Dropout(p=dropout) if dropout is not None else None
        self.use_residual = use_residual
        self.intra_bn = nn.BatchNorm1d(n_hidden) if enable_bn else None
        self.inter_bn = nn.BatchNorm1d(n_hidden) if enable_bn else None

    def _cross_update(self, g, x, y, n):
        z = torch.cat([x, y], dim=0)
        z = self.inter_conv(g, z)
        if self.inter_bn:
            z = self.inter_bn(z)
        if self.act is not None:
            z = self.act(z)
        return x + z[:n] * self.branch_beta, y + z[n:] * self.branch_beta

    def _rightward_update(self, intra_graphs, xs):
        ys = []
        for x, g in zip(xs, intra_graphs):
            y = self.intra_conv(g, x)
            if self.intra_bn:
                y = self.intra_bn(y)
            if self.act is not None:
                y = self.act(y)
            ys.append(y)
        return ys

    def _vertical_update(self, inds, inter_graphs, num_nodes, xs):
        for i in inds:
            xs[i], xs[i + 1] = self._cross_update(
                inter_graphs[i], xs[i], xs[i + 1], num_nodes[i]
            )
        return xs

    def _residual(self, xs, ys):
        return residual_when_same_shape_on_list(
            xs, ys, x_beta=self.stem_beta, y_beta=self.branch_beta
        )

    # This is Mee V1, to be removed
    def forward(self, height, num_nodes, intra_graphs, inter_graphs, features):
        check_len(num_nodes, height)
        check_len(intra_graphs, height)
        check_len(inter_graphs, height - 1)
        check_len(features, height)
        xs = features
        xs = self._rightward_update(intra_graphs, xs)
        inds = list(range(height - 1))
        xs = self._vertical_update(inds, inter_graphs, num_nodes, xs)
        inds = inds[::-1]
        xs = self._vertical_update(inds, inter_graphs, num_nodes, xs)
        if self.use_residual:
            xs = self._residual(features, xs)
        if self.dropout is not None:
            xs = [self.dropout(x) for x in xs]
        return xs
