import torch
import torch.nn as nn
from torchvision.models import inception
from .nets_utils import EmbeddingRecorder


class BasicConv2d(nn.Module):
    """Conv2d (without bias) -> BatchNorm2d -> in-place ReLU.

    Args:
        input_channels: channels of the incoming tensor.
        output_channels: channels produced by the convolution.
        **kwargs: forwarded verbatim to ``nn.Conv2d`` (kernel_size, stride,
            padding, ...).
    """

    def __init__(self, input_channels, output_channels, **kwargs):
        super().__init__()
        # The conv bias is dropped because the BatchNorm that follows carries its
        # own learnable shift. The attribute names below are state_dict keys, so
        # they (and their registration order) must stay as they are.
        self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


# Inception module A: parallel branches whose outputs keep the input's spatial size ("same" padding).
class InceptionA(nn.Module):
    """Size-preserving Inception module "A".

    Four branches read the same input and their outputs are concatenated along
    the channel dimension, giving ``64 + 64 + 96 + pool_features`` channels.
    Every branch is padded so the spatial size of the input is kept.

    Args:
        input_channels: channels of the incoming feature map.
        pool_features: channels produced by the pooling branch.

    NOTE(review): the pooling branch projects with a 3x3 convolution, while the
    branch's own description ("pool -> 1x1") suggests a 1x1 projection. It is
    kept as 3x3 so existing checkpoints keep matching parameter shapes.
    """

    def __init__(self, input_channels, pool_features):
        super().__init__()
        cin = input_channels

        # x -> 1x1
        self.branch1x1 = BasicConv2d(cin, 64, kernel_size=1)

        # x -> 1x1 -> 5x5
        self.branch5x5 = nn.Sequential(
            BasicConv2d(cin, 48, kernel_size=1),
            BasicConv2d(48, 64, kernel_size=5, padding=2),
        )

        # x -> 1x1 -> 3x3 -> 3x3
        self.branch3x3 = nn.Sequential(
            BasicConv2d(cin, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, padding=1),
        )

        # x -> 3x3 avg pool (stride 1) -> 3x3 projection (see NOTE above)
        self.branchpool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(cin, pool_features, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Branch outputs are stacked channel-wise in a fixed order.
        return torch.cat(
            [
                self.branch1x1(x),
                self.branch5x5(x),
                self.branch3x3(x),
                self.branchpool(x),
            ],
            dim=1,
        )


# Inception module B: grid-size reduction (downsampling via stride-2 branches),
# built from factorization into smaller convolutions.
class InceptionB(nn.Module):
    """Grid-reduction module: every branch uses stride 2 without padding.

    A spatial size H becomes ``(H - 3) // 2 + 1``. Output channels are
    ``384 + 96 + input_channels``, since the max-pool branch keeps its input
    channels.

    Following the paper, the pooling layer and the convolution branches run in
    parallel with stride 2 and their filter banks are concatenated. This avoids
    a representational bottleneck when shrinking the grid.

    Args:
        input_channels: channels of the incoming feature map.
    """

    def __init__(self, input_channels):
        super().__init__()
        cin = input_channels

        # x -> 3x3 (stride 2)
        self.branch3x3 = BasicConv2d(cin, 384, kernel_size=3, stride=2)

        # x -> 1x1 -> 3x3 -> 3x3 (stride 2)
        self.branch3x3stack = nn.Sequential(
            BasicConv2d(cin, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=2),
        )

        # x -> 3x3 max pool (stride 2)
        self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        return torch.cat(
            [
                self.branch3x3(x),
                self.branch3x3stack(x),
                self.branchpool(x),
            ],
            dim=1,
        )


# Factorizing Convolutions with Large Filter Size
class InceptionC(nn.Module):
    """Size-preserving module with 7x7 convolutions factorized into 7x1 + 1x7.

    Replacing an n x n convolution by an n x 1 followed by a 1 x n convolution
    cuts computation, and the saving grows with n. Output channels are always
    ``4 * 192 = 768``.

    Args:
        input_channels: channels of the incoming feature map.
        channels_7x7: intermediate width of the factorized 7x7 branches.
    """

    def __init__(self, input_channels, channels_7x7):
        super().__init__()
        cin = input_channels
        width = channels_7x7

        # x -> 1x1
        self.branch1x1 = BasicConv2d(cin, 192, kernel_size=1)

        # x -> 1x1 -> 7x1 -> 1x7
        self.branch7x7 = nn.Sequential(
            BasicConv2d(cin, width, kernel_size=1),
            BasicConv2d(width, width, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(width, 192, kernel_size=(1, 7), padding=(0, 3)),
        )

        # x -> 1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7
        self.branch7x7stack = nn.Sequential(
            BasicConv2d(cin, width, kernel_size=1),
            BasicConv2d(width, width, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(width, width, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(width, width, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(width, 192, kernel_size=(1, 7), padding=(0, 3)),
        )

        # x -> 3x3 avg pool (stride 1) -> 1x1
        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(cin, 192, kernel_size=1),
        )

    def forward(self, x):
        return torch.cat(
            [
                self.branch1x1(x),
                self.branch7x7(x),
                self.branch7x7stack(x),
                self.branch_pool(x),
            ],
            dim=1,
        )


class InceptionD(nn.Module):
    """Second grid-reduction module (stride 2, no padding on the final layers).

    A spatial size H becomes ``(H - 3) // 2 + 1``. Output channels are
    ``320 + 192 + input_channels``, since the avg-pool branch keeps its input
    channels.

    Args:
        input_channels: channels of the incoming feature map.
    """

    def __init__(self, input_channels):
        super().__init__()
        cin = input_channels

        # x -> 1x1 -> 3x3 (stride 2)
        self.branch3x3 = nn.Sequential(
            BasicConv2d(cin, 192, kernel_size=1),
            BasicConv2d(192, 320, kernel_size=3, stride=2),
        )

        # x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (stride 2)
        self.branch7x7 = nn.Sequential(
            BasicConv2d(cin, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        )

        # x -> 3x3 avg pool (stride 2)
        self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        return torch.cat(
            [
                self.branch3x3(x),
                self.branch7x7(x),
                self.branchpool(x),
            ],
            dim=1,
        )


# Inception module E: expanded filter banks; spatial size is preserved ("same" padding).
class InceptionE(nn.Module):
    """Size-preserving module with expanded (split 1x3 / 3x1) filter banks.

    The module is intended for the coarsest grid, where high-dimensional
    representations matter most. Output channels are always
    ``320 + 2*384 + 2*384 + 192 = 2048``.

    Args:
        input_channels: channels of the incoming feature map.
    """

    def __init__(self, input_channels):
        super().__init__()
        cin = input_channels

        # x -> 1x1
        self.branch1x1 = BasicConv2d(cin, 320, kernel_size=1)

        # x -> 1x1, then 1x3 and 3x1 applied in parallel to that result
        self.branch3x3_1 = BasicConv2d(cin, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # x -> 1x1 -> 3x3, then 1x3 and 3x1 applied in parallel to that result
        self.branch3x3stack_1 = BasicConv2d(cin, 448, kernel_size=1)
        self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # x -> 3x3 avg pool (stride 1) -> 1x1
        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(cin, 192, kernel_size=1),
        )

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        reduced = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            (self.branch3x3_2a(reduced), self.branch3x3_2b(reduced)), dim=1
        )

        deep = self.branch3x3stack_2(self.branch3x3stack_1(x))
        out_stack = torch.cat(
            (self.branch3x3stack_3a(deep), self.branch3x3stack_3b(deep)), dim=1
        )

        out_pool = self.branch_pool(x)

        return torch.cat((out_1x1, out_3x3, out_stack, out_pool), dim=1)


class InceptionV3_32x32(nn.Module):
    """Inception-v3 adapted to small images (3-channel 32x32 or 1-channel 28x28).

    The stem has no striding or max-pooling. For 1-channel inputs the first
    convolution uses ``padding=3``, which turns a 28x28 image into a 32x32 map.
    The grid then goes 32 -> 30 (stem) -> 14 (Mixed_6a) -> 6 (Mixed_7a) -> 1
    (global average pool).

    Args:
        channel: number of input channels.
        num_classes: size of the final linear classifier.
        record_embedding: forwarded to ``EmbeddingRecorder``, which receives the
            flattened pooled features right before ``linear``.
        no_grad: if True, ``forward`` runs with autograd disabled.
    """

    def __init__(self, channel, num_classes, record_embedding=False, no_grad=False):
        super().__init__()
        first_padding = 3 if channel == 1 else 1

        # Stem: plain convolutions; only Conv2d_4a_3x3 is unpadded (32 -> 30).
        self.Conv2d_1a_3x3 = BasicConv2d(channel, 32, kernel_size=3, padding=first_padding)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)

        # Size-preserving "A" modules.
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)

        # Grid reduction.
        self.Mixed_6a = InceptionB(288)

        # Factorized 7x7 "C" modules.
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)

        # Grid reduction.
        self.Mixed_7a = InceptionD(768)

        # Expanded-filter-bank "E" modules on the 6x6 grid.
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)

        # Head: global average pool over the 6x6 grid, dropout, linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout2d()
        self.linear = nn.Linear(2048, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final classification layer."""
        return self.linear

    def forward(self, x):
        with torch.set_grad_enabled(not self.no_grad):
            pipeline = (
                # stem: 32 -> 30
                self.Conv2d_1a_3x3,
                self.Conv2d_2a_3x3,
                self.Conv2d_2b_3x3,
                self.Conv2d_3b_1x1,
                self.Conv2d_4a_3x3,
                # 30 -> 30
                self.Mixed_5b,
                self.Mixed_5c,
                self.Mixed_5d,
                # 30 -> 14: efficient grid-size reduction
                self.Mixed_6a,
                # 14 -> 14: the 1x7 / 7x1 factorization works well on medium grids
                self.Mixed_6b,
                self.Mixed_6c,
                self.Mixed_6d,
                self.Mixed_6e,
                # 14 -> 6: efficient grid-size reduction
                self.Mixed_7a,
                # 6 -> 6: expanded filter banks on the coarsest grid
                self.Mixed_7b,
                self.Mixed_7c,
                # 6 -> 1
                self.avgpool,
                self.dropout,
            )
            out = x
            for stage in pipeline:
                out = stage(out)

            features = out.view(out.size(0), -1)
            features = self.embedding_recorder(features)
            logits = self.linear(features)
        return logits


class InceptionV3_224x224(inception.Inception3):
    """torchvision's ``Inception3`` plus an embedding hook and an optional no-grad mode.

    Args:
        channel: number of input channels. When it is not 3, ``Conv2d_1a_3x3`` is
            replaced by a freshly initialized block that accepts ``channel`` inputs.
        num_classes: size of the classification head ``fc``.
        record_embedding: forwarded to ``EmbeddingRecorder``, which receives the
            flattened pooled features right before ``fc``.
        no_grad: if True, ``_forward`` runs with autograd disabled.
        **kwargs: forwarded unchanged to ``torchvision.models.inception.Inception3``.
    """

    def __init__(self, channel: int, num_classes: int, record_embedding: bool = False,
                 no_grad: bool = False, **kwargs):
        super().__init__(num_classes=num_classes, **kwargs)
        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        if channel != 3:
            # torchvision.models.inception has no module-level ``conv_block``: that
            # name is only a local inside Inception3.__init__. BasicConv2d is the
            # block class torchvision uses for this layer.
            self.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
        self.no_grad = no_grad

    def get_last_layer(self):
        """Return the final classification layer."""
        return self.fc

    def _forward(self, x):
        """Run the backbone and heads; return ``(logits, aux)``.

        ``aux`` is None unless the model is training and has ``AuxLogits``.
        Features pass through ``embedding_recorder`` right before ``fc``.
        The shape comments below follow torchvision and assume a 299x299 input;
        a 224x224 input gives smaller spatial sizes with the same channel counts.
        """
        with torch.set_grad_enabled(not self.no_grad):
            # N x 3 x 299 x 299
            x = self.Conv2d_1a_3x3(x)
            # N x 32 x 149 x 149
            x = self.Conv2d_2a_3x3(x)
            # N x 32 x 147 x 147
            x = self.Conv2d_2b_3x3(x)
            # N x 64 x 147 x 147
            x = self.maxpool1(x)
            # N x 64 x 73 x 73
            x = self.Conv2d_3b_1x1(x)
            # N x 80 x 73 x 73
            x = self.Conv2d_4a_3x3(x)
            # N x 192 x 71 x 71
            x = self.maxpool2(x)
            # N x 192 x 35 x 35
            x = self.Mixed_5b(x)
            # N x 256 x 35 x 35
            x = self.Mixed_5c(x)
            # N x 288 x 35 x 35
            x = self.Mixed_5d(x)
            # N x 288 x 35 x 35
            x = self.Mixed_6a(x)
            # N x 768 x 17 x 17
            x = self.Mixed_6b(x)
            # N x 768 x 17 x 17
            x = self.Mixed_6c(x)
            # N x 768 x 17 x 17
            x = self.Mixed_6d(x)
            # N x 768 x 17 x 17
            x = self.Mixed_6e(x)
            # N x 768 x 17 x 17
            aux = None
            if self.AuxLogits is not None:
                # The auxiliary classifier only runs in training mode.
                if self.training:
                    aux = self.AuxLogits(x)
            # N x 768 x 17 x 17
            x = self.Mixed_7a(x)
            # N x 1280 x 8 x 8
            x = self.Mixed_7b(x)
            # N x 2048 x 8 x 8
            x = self.Mixed_7c(x)
            # N x 2048 x 8 x 8
            # Adaptive average pooling
            x = self.avgpool(x)
            # N x 2048 x 1 x 1
            x = self.dropout(x)
            # N x 2048 x 1 x 1
            x = torch.flatten(x, 1)
            # N x 2048
            x = self.embedding_recorder(x)
            x = self.fc(x)
            # N x 1000 (num_classes)
            return x, aux


def _inception_v3_google_url():
    """Return the download URL of torchvision's ImageNet Inception-v3 checkpoint.

    Newer torchvision releases expose it through ``Inception_V3_Weights`` and no
    longer ship ``inception.model_urls``. Older releases only have ``model_urls``.
    """
    weights_enum = getattr(inception, "Inception_V3_Weights", None)
    if weights_enum is not None:
        return weights_enum.IMAGENET1K_V1.url
    return inception.model_urls["inception_v3_google"]


def InceptionV3(channel: int, num_classes: int, im_size, record_embedding: bool = False, no_grad: bool = False,
                pretrained: bool = False):
    """Build an Inception-v3 network matching the input geometry.

    Args:
        channel: number of input channels.
        num_classes: number of output classes.
        im_size: image size, indexed as ``(im_size[0], im_size[1])``.
        record_embedding: forwarded to the network's ``EmbeddingRecorder``.
        no_grad: if True, the network's forward pass runs with autograd disabled.
        pretrained: load torchvision's ImageNet ("inception_v3_google") weights.
            Only 224x224 inputs are accepted in this mode.

    Returns:
        ``InceptionV3_224x224`` when ``pretrained`` is set or the size is 224x224.
        Otherwise ``InceptionV3_32x32`` for 1-channel 28x28 or 3-channel 32x32 inputs.

    Raises:
        NotImplementedError: ``pretrained`` with a size other than 224x224, or an
            unsupported (channel, size) combination.
    """
    if pretrained:
        if im_size[0] != 224 or im_size[1] != 224:
            raise NotImplementedError("torchvison pretrained models only accept inputs with size of 224*224")
        # Build with the ImageNet layout so the checkpoint loads strictly, then
        # adapt the input stem and the classification heads afterwards.
        net = InceptionV3_224x224(channel=3, num_classes=1000, record_embedding=record_embedding, no_grad=no_grad)

        from torch.hub import load_state_dict_from_url
        # NOTE(review): torchvision's own inception_v3() builder also enables
        # ``transform_input`` when loading these weights; this path does not.
        # Confirm that the input preprocessing matches.
        state_dict = load_state_dict_from_url(_inception_v3_google_url(), progress=True)
        net.load_state_dict(state_dict)

        if channel != 3:
            # ``inception.conv_block`` does not exist at module level;
            # BasicConv2d is the block torchvision uses for this layer.
            net.Conv2d_1a_3x3 = inception.BasicConv2d(channel, 32, kernel_size=3, stride=2)
        if num_classes != 1000:
            net.fc = nn.Linear(net.fc.in_features, num_classes)
            # Keep the auxiliary classifier's output size consistent with the main head.
            if net.AuxLogits is not None:
                net.AuxLogits.fc = nn.Linear(net.AuxLogits.fc.in_features, num_classes)

    elif im_size[0] == 224 and im_size[1] == 224:
        net = InceptionV3_224x224(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                  no_grad=no_grad)
    elif (channel == 1 and im_size[0] == 28 and im_size[1] == 28) or (
            channel == 3 and im_size[0] == 32 and im_size[1] == 32):
        net = InceptionV3_32x32(channel=channel, num_classes=num_classes, record_embedding=record_embedding,
                                no_grad=no_grad)
    else:
        raise NotImplementedError("Network Architecture for current dataset has not been implemented.")

    return net
