from torch import nn
import torch.nn.functional as F
from grp_modules.model.resnet.blocks.utils import LambdaLayer  # pylint: disable=relative-beyond-top-level
import torch


class BasicBlockPost(nn.Module):
    """Post-activation ResNet basic block whose convolutions are widened by ``avg_factor``.

    Every convolution produces ``planes * state["avg_factor"]`` channels.  The ``reduce`` event
    (``event.optional.reduce``) is called before every middle convolution when
    ``state["reduce_everytime"]`` is set, and always once after the (optional) residual addition.
    It is presumably expected to fold the ``avg_factor`` channel groups back to ``planes``
    channels; the middle convolutions' input sizes assume exactly that.

    Args:
        state: block configuration.  Keys read: ``avg_factor``, ``block_depth`` (>= 2),
            ``reduce_everytime``, ``option`` (``'A'`` or ``'B'``), ``expansion``, ``reduce``
            and ``short``.
        event: supplies the ``normalization_layer`` / ``activation_layer`` factories and the
            ``optional`` hooks ``before_relu``, ``before_addition`` and ``reduce``.
        in_planes: number of input channels.
        planes: base channel count; convolutions output ``planes * avg_factor`` channels.
        *args: ignored.
        stride: stride of the first convolution and of the shortcut.
        **kwargs: ignored.

    Raises:
        ValueError: if a shortcut is needed, ``state["short"]`` is set and ``state["option"]``
            is neither ``'A'`` nor ``'B'``.
    """

    def __init__(self, state, event, in_planes, planes, *args, stride=1, **kwargs):
        del args, kwargs  # unused
        super().__init__()
        self.event = event
        self.state = state

        avg_factor = state["avg_factor"]
        wide_planes = planes * avg_factor  # channel count of every conv output (before reduction)
        n_middle = state["block_depth"] - 1

        # averaging in feature dimension
        self.conv0 = nn.Conv2d(in_planes, wide_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.norm0 = event.normalization_layer(wide_planes)

        assert state["block_depth"] >= 2
        self.middle_acts = nn.ModuleList([event.activation_layer() for _ in range(n_middle)])
        # With reduce_everytime the input has already been reduced to ``planes`` channels; otherwise
        # the wide tensor is kept and a grouped conv keeps the ``avg_factor`` groups independent.
        middle_in = planes if state["reduce_everytime"] else wide_planes
        middle_groups = 1 if state["reduce_everytime"] else avg_factor
        self.middle_convs = nn.ModuleList([
            nn.Conv2d(middle_in, wide_planes, kernel_size=3, stride=1, padding=1, groups=middle_groups,
                      bias=False)
            for _ in range(n_middle)])
        self.middle_norms = nn.ModuleList([event.normalization_layer(wide_planes) for _ in range(n_middle)])
        self.act1 = self.relu = event.activation_layer()

        if stride != 1 or in_planes != planes or avg_factor > 1:
            if state["option"] == 'A':
                # Cifar-10 ResNet paper uses option A: a parameter-free shortcut that subsamples
                # spatially by ``stride`` (matching conv0's output size) and zero-pads the channel
                # dimension from ``in_planes`` up to ``wide_planes``.
                pad_total = wide_planes - in_planes
                pad_front = pad_total // 2
                pad_back = pad_total - pad_front
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(x[:, :, ::stride, ::stride], (0, 0, 0, 0, pad_front, pad_back),
                                    "constant", 0))
            elif state["option"] == 'B':
                # Projection shortcut: 1x1 conv + BatchNorm to the widened channel count.
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, state["expansion"] * wide_planes, kernel_size=1, stride=stride,
                              bias=False),
                    nn.BatchNorm2d(state["expansion"] * wide_planes)
                )
            elif state["short"]:
                # Fail at construction instead of with an AttributeError inside forward().
                raise ValueError(f"Unknown shortcut option: {state['option']!r}")
        else:
            self.shortcut = nn.Sequential()

        if state["reduce"] == "wsum" and avg_factor > 1:
            # One learnable weight row per reduce call site: rows 0..block_depth-2 are paired with the
            # middle layers (used only when reduce_everytime), the last row drives the final reduce.
            self.reduce_weights = nn.Parameter(torch.ones(size=(state["block_depth"], avg_factor)),
                                               requires_grad=True)
        else:
            self.reduce_weights = [None for _ in range(len(self.middle_norms))]

    def forward(self, x):
        """Run conv0/norm0, the middle act-(reduce)-conv-norm stages, shortcut, reduce, activation."""
        out = self.norm0(self.conv0(x))

        for act, conv, norm, reduce_weight in zip(self.middle_acts, self.middle_convs, self.middle_norms,
                                                  self.reduce_weights):
            self.event.optional.before_relu(out)
            out = act(out)
            if self.state["reduce_everytime"]:
                out = self.event.optional.reduce(out, reduce_weight)
            out = norm(conv(out))

        if self.state["short"]:
            sc = self.shortcut(x)
            # NOTE(review): before_addition hooks may modify ``sc`` in place (``residual_factor`` in
            # this module uses ``*=``).  With the identity shortcut ``sc`` *is* ``x``, which autograd
            # has saved for conv0's backward -- confirm such hooks are only used with projecting shortcuts.
            self.event.optional.before_addition(out, sc)
            out = out + sc

        # Last weight row when weighted reduction is learned; otherwise the list holds only None.
        out = self.event.optional.reduce(out, self.reduce_weights[-1])
        self.event.optional.before_relu(out)
        return self.act1(out)


def reduce(state, event, out, reduce_weight):
    """Collapse the ``avg_factor`` channel groups of ``out`` into a single group.

    ``out`` is split along dim 1 into ``state["avg_factor"]`` equally sized groups, which are
    combined according to ``state["reduce"]``:

    * ``"sum"``  -- plain elementwise sum of the groups;
    * ``"wsum"`` -- weighted sum using ``reduce_weight`` normalised to sum to 1.  Every
      ``state["plot.steps"]`` steps the normalised weights are passed to
      ``event.optional.plot_scalar2d``.

    Args:
        state: configuration; reads ``avg_factor``, ``reduce`` and, for ``"wsum"``,
            ``step`` and ``plot.steps``.
        event: only used for the optional plotting hook of ``"wsum"``.
        out: tensor whose dim 1 holds ``avg_factor`` concatenated groups (any rank >= 2).
        reduce_weight: 1-D weights of length ``avg_factor``; only used for ``"wsum"``.

    Returns:
        ``out`` itself when ``avg_factor`` <= 1, otherwise a tensor with
        ``out.shape[1] // avg_factor`` channels.

    Raises:
        ValueError: if the channel count is not divisible by ``avg_factor`` or
            ``state["reduce"]`` is not a known reduction.
    """
    avg_factor = state["avg_factor"]
    if avg_factor <= 1:
        return out

    channels = out.shape[1]
    if channels % avg_factor:
        raise ValueError(f"Cannot split {channels} channels into {avg_factor} equal groups")
    # Stack the groups on a new trailing axis so they can be combined elementwise.
    groups = torch.stack(out.split(channels // avg_factor, dim=1), -1)

    mode = state["reduce"]
    if mode == "sum":
        return groups.sum(-1)
    if mode == "wsum":
        reduce_weight = reduce_weight / reduce_weight.sum()
        # Broadcast the weights over every axis except the trailing group axis.
        weighted = groups.mul(reduce_weight.view((1,) * out.dim() + (-1,)))
        if state["step"] % state["plot.steps"] == 0:
            event.optional.plot_scalar2d(reduce_weight, title="Reduce weights")
        return weighted.sum(-1)
    # Previously the exception object was *returned* and silently propagated as the output.
    raise ValueError("Reduce not known")


def residual_factor(state, out, sc):
    """Scale the shortcut tensor ``sc`` in place by ``state["residual_factor"]``.

    Registered as a ``before_addition`` hook; ``out`` is not used.  Nothing is returned, so the
    scaling only takes effect because ``sc`` is mutated in place (``*=`` on a tensor).

    NOTE(review): when the block's shortcut is the identity, ``sc`` is the block input itself;
    an in-place update there may conflict with autograd's saved tensors -- confirm.
    """
    factor = state["residual_factor"]
    if factor == 1:
        return
    sc *= factor


def pre_resblockchain(state, event):
    """Build the stem placed before the residual blocks: conv -> normalization -> activation.

    The convolution maps ``state.all["dataset.num_channels"]`` input channels to
    ``state["first_conv_filters"]`` channels with the configured kernel size, stride and padding
    and no bias.

    Returns:
        A 3-tuple ``(conv, norm, activation)``.
    """
    in_channels = state.all["dataset.num_channels"]
    filters = state["first_conv_filters"]
    stem_conv = nn.Conv2d(
        in_channels,
        filters,
        kernel_size=state["first_conv_kernel_size"],
        stride=state["first_conv_stride"],
        padding=state["first_conv_padding"],
        bias=False,
    )
    stem_norm = event.normalization_layer(filters)
    stem_act = event.activation_layer()
    return stem_conv, stem_norm, stem_act


def post_resblockchain(in_planes):
    """Return the layers appended after the residual blocks -- none for this model.

    ``in_planes`` is accepted for interface compatibility with the ``post_resblockchain`` event
    and is ignored.  A fresh empty list is returned on every call.
    """
    del in_planes
    tail_layers = []
    return tail_layers


def register(mf):
    """Register the ``model.resnet`` configuration defaults and this module's event handlers on ``mf``."""
    mf.set_scope("model.resnet")

    defaults = {
        "short": True,
        "concat": False,
        "option": "A",
        "expansion": 1,
        "block_depth": 2,
        "avg_factor": 1,
        "residual_factor": 1,
        "reduce": "sum",
        "reduce_everytime": True,

        # first (stem) convolution
        "first_conv_filters": 16,
        "first_conv_kernel_size": 3,
        "first_conv_padding": 1,
        "first_conv_stride": 1,
        "first_conv_max_pool": False,
    }
    mf.register_defaults(defaults)

    # (event name, handler, whether it is registered with unique=True)
    handlers = (
        ("resblock", BasicBlockPost, True),
        ("pre_resblockchain", pre_resblockchain, True),
        ("post_resblockchain", post_resblockchain, True),
        ("reduce", reduce, True),
        ("before_addition", residual_factor, False),
    )
    for event_name, handler, is_unique in handlers:
        if is_unique:
            mf.register_event(event_name, handler, unique=True)
        else:
            mf.register_event(event_name, handler)
