import logging

import torchvision
from torch import nn

from .layers import EFATConv, EFATPool, EFATDense

class Conv2dSame(nn.Conv2d):
    """``nn.Conv2d`` whose padding defaults to "same" for odd kernel sizes.

    When ``padding`` is None, each spatial dimension is padded by
    ``dilation * (kernel_size - 1) // 2``. For ``stride == 1`` this keeps the
    output height/width equal to the input's. ``kernel_size`` and ``dilation``
    may each be an int or an (h, w) pair, as accepted by ``nn.Conv2d``.
    An explicit ``padding`` is forwarded to ``nn.Conv2d`` unchanged.

    Raises:
        ValueError: if ``padding`` is None and a kernel dimension is not a
            positive odd integer (no symmetric "same" padding exists), or if
            ``dilation`` does not describe two dimensions.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=None,
        dilation=1,
        groups=1,
        bias=True
    ):
        if padding is None:
            padding = self._same_padding(kernel_size, dilation)

        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias
        )

    @staticmethod
    def _same_padding(kernel_size, dilation):
        """Return the (h, w) padding that gives "same" output at stride 1."""

        def as_pair(value):
            # Mirror nn.Conv2d: an int applies to both dims, else an iterable pair.
            if isinstance(value, int):
                return (value, value)
            try:
                return tuple(value)
            except TypeError:
                return (value,)  # not a pair; rejected by the length checks below

        kernels = as_pair(kernel_size)
        dilations = as_pair(dilation)
        if len(kernels) != 2 or any(
            not isinstance(k, int) or k < 1 or k % 2 == 0 for k in kernels
        ):
            raise ValueError(
                f'Unsupported padding for kernel size {kernel_size}.'
            )
        if len(dilations) != 2:
            raise ValueError(f'Unsupported dilation {dilation}.')
        # The dilated kernel spans dilation * (k - 1) + 1 pixels; pad half of
        # the excess on each side so the output size matches the input.
        return tuple(d * (k - 1) // 2 for k, d in zip(kernels, dilations))


class Block(nn.Module):
    """Convolution -> BatchNorm2d -> ReLU (-> optional Dropout) unit.

    The convolution is an ``EFATConv`` that wraps ``Conv2dSame``. Inputs and
    outputs are ``(features, edge_index)`` pairs; ``edge_index`` is returned
    unchanged.

    Args:
        inputs: number of input channels passed to ``EFATConv``.
        outputs: number of output channels (also the BatchNorm2d width).
        k: kernel size for the wrapped ``Conv2dSame``.
        stride: stride for the wrapped ``Conv2dSame``.
        dropout: if truthy, apply ``nn.Dropout`` after the activation.
        bias: whether the wrapped convolution has a bias term.
        inplace: forwarded to ``nn.ReLU``.
        EFAT_kwargs: extra keyword arguments forwarded to ``EFATConv``
            (None means no extra arguments).
        dropout_p: dropout probability used when ``dropout`` is truthy.
    """

    def __init__(
            self,
            inputs,
            outputs,
            k,
            stride,
            dropout,
            bias=False,
            inplace=False,
            EFAT_kwargs=None,
            dropout_p=0.2,
    ):
        super().__init__()
        # Avoid a shared mutable default argument.
        if EFAT_kwargs is None:
            EFAT_kwargs = {}

        self.conv = EFATConv(
            inputs,
            outputs,
            Conv2dSame,
            {
                'kernel_size': k,
                'stride': stride,
                'bias': bias,
            },
            **EFAT_kwargs,
        )
        self.norm = nn.BatchNorm2d(outputs)
        self.act = nn.ReLU(inplace)

        self.dropout = dropout
        if self.dropout:
            self.drop = nn.Dropout(p=dropout_p)

    def forward(self, batch):
        """Apply the block to a ``(features, edge_index)`` pair.

        The whole pair is handed to ``EFATConv``; the remaining layers act on
        its output. Returns ``(out, edge_index)``.
        """
        _, edge_index = batch  # also checks that batch is a 2-item pair

        out = self.conv(batch)
        out = self.norm(out)
        out = self.act(out)

        if self.dropout:
            out = self.drop(out)
        return out, edge_index


class CifarNetEFG(nn.Module):
    """Small CIFAR-style CNN built from ``Block`` layers plus an EFAT classifier.

    The feature extractor stacks one ``Block`` per entry of the class-level
    ``filters`` / ``kernels`` / ``strides`` / ``dropout`` lists, starting from
    3 input channels. It is followed by an ``EFATPool`` wrapping
    ``nn.AdaptiveAvgPool2d(output_size=1)`` and an ``EFATDense`` wrapping
    ``nn.Linear`` that maps to ``num_classes`` outputs.
    """

    name = 'cifarnet'

    # Larger alternative configuration, kept for reference:
    #   filters = [64, 64, 128, 128, 128, 192, 192, 192]
    #   kernels = [3, 3, 3, 3, 3, 3, 3, 3]
    #   strides = [2, 2, 2, 1, 2, 1, 2, 1]
    #   dropout = [True, True, True, True, True, True, True, True]
    filters = [32, 32, 64, 64, 80, 80]
    kernels = [3, 3, 3, 3, 3, 3]
    strides = [2, 2, 2, 1, 2, 2]
    dropout = [True, True, True, True, True, True]

    def __init__(
            self,
            config,  # unused; accepted to keep constructor signatures consistent
            num_classes
    ):
        """Build the network.

        Raises:
            ValueError: if the per-layer spec lists differ in length or are
                empty (``zip`` would otherwise silently drop layers, and an
                empty spec leaves no channel count for the classifier).
        """
        super().__init__()
        specs = (self.filters, self.kernels, self.strides, self.dropout)
        if len({len(spec) for spec in specs}) != 1:
            raise ValueError(
                'filters, kernels, strides and dropout must have the same length.'
            )
        if not self.filters:
            raise ValueError('At least one layer must be configured.')

        channels = 3  # input channels of the first Block
        layers = []
        for k, outputs, stride, dropout in zip(
                self.kernels, self.filters, self.strides, self.dropout
        ):
            layers.append(Block(channels, outputs, k, stride, dropout))
            channels = outputs

        self.layers = nn.Sequential(*layers)
        # Was an unconditional print(); keep the summary available at DEBUG.
        logging.getLogger(__name__).debug('%s', self.layers)

        # classifier: pool to 1x1, then a linear layer over the last Block's channels
        self.pool = EFATPool(
            channels, nn.AdaptiveAvgPool2d, {'output_size': 1}
        )
        self.fc = EFATDense(
            channels, num_classes, nn.Linear, {}
        )

    def forward(self, batch):
        """Run the network on ``batch``.

        Args:
            batch: object exposing ``.x`` (features fed to the first Block) and
                ``.edge_index`` (passed alongside the features to every EFAT
                layer). Presumably a PyG-style Data/Batch object; TODO confirm.

        Returns:
            Whatever ``self.fc`` returns for the pooled features.
        """
        features = self.layers((batch.x, batch.edge_index))
        pooled = self.pool(features)
        # The pool is configured with output_size=1; drop the two trailing
        # size-1 spatial dims before the dense layer.
        flat = pooled.squeeze(-1).squeeze(-1)
        return self.fc((flat, batch.edge_index))

