from collections import OrderedDict

import torch


class CACNN(torch.nn.Module):
    """Stack of 3x3 convolutions over 8x8 planes, optionally followed by a linear head.

    Args:
        dims: Channel counts ``[in_channels, hidden..., out_channels]``. One
            ``Conv2d(kernel_size=3, padding=1)`` is created per consecutive
            pair, named ``conv0``, ``conv1``, ... with ``relu0``, ``relu1``, ...
            after them.
        keep_conv: If True, the last convolution is NOT followed by a ReLU and
            ``forward`` returns the flattened conv features concatenated with
            the linear head's output (when a head exists).
        l_dim: Output size of the linear head. ``l_dim <= 0`` disables the head,
            in which case ``forward`` returns the flattened conv features.
    """

    def __init__(self, dims, keep_conv, l_dim):
        super().__init__()
        assert isinstance(dims, list)
        assert len(dims) >= 2, f"Dimensions should contain input and output dimensions. dims: {dims}"

        layers = OrderedDict()
        last_idx = len(dims) - 2
        for idx, (in_chan, out_chan) in enumerate(zip(dims[:-1], dims[1:])):
            layers[f"conv{idx}"] = torch.nn.Conv2d(in_chan, out_chan, 3, padding=1)
            # The final ReLU is omitted when keep_conv is set, so the raw
            # (un-activated) conv features are what gets returned/concatenated.
            if idx < last_idx or not keep_conv:
                layers[f"relu{idx}"] = torch.nn.ReLU()

        self.hidden_layers = torch.nn.Sequential(layers)
        # Padding=1 with a 3x3 kernel preserves the 8x8 spatial size, hence
        # dims[-1] * 8 * 8 flattened features feed the linear head.
        self.linear_layer = torch.nn.Linear(dims[-1] * 64, l_dim) if l_dim > 0 else None
        self.keep_conv = keep_conv

    def forward(self, x, out_feature=False):
        """Run the conv stack (and linear head, if any) on ``x``.

        Args:
            x: Tensor whose elements are reshaped to
                ``(-1, in_channels, 8, 8)``, where ``in_channels == dims[0]``
                (e.g. ``(N, in_channels * 64)`` or ``(N, in_channels, 8, 8)``).
                NOTE(review): the meaning of the planes is not visible here.
            out_feature: Unsupported; must be False.

        Returns:
            ``(N, dims[-1] * 64)`` conv features when there is no linear head;
            ``(N, l_dim)`` head output when ``keep_conv`` is False;
            ``(N, dims[-1] * 64 + l_dim)`` (features, then head output) when
            ``keep_conv`` is True and a head exists.
        """
        assert not out_feature
        # Derive the channel count from the first conv instead of hard-coding
        # 12, so the reshape always agrees with dims[0].
        in_channels = self.hidden_layers[0].in_channels
        features = self.hidden_layers(x.reshape(-1, in_channels, 8, 8))
        features = features.flatten(1)
        if self.linear_layer is None:
            return features
        head = self.linear_layer(features)
        if self.keep_conv:
            return torch.cat([features, head], dim=-1)
        return head
