import torch.nn as nn
import torch.nn.functional as F
from typing import Iterable

from xad.models.bases import ADNN


class CNN28(ADNN):
    """Two-stage CNN for 28x28 inputs.

    ``forward`` reshapes its input to ``(-1, C, 28, 28)`` with ``C = 1`` when ``grayscale``
    else ``3``, applies two ``Conv(5x5, pad 2) -> BatchNorm -> LeakyReLU -> MaxPool(2)``
    stages (28 -> 14 -> 7), then ``Linear(32*7*7, 64) -> BatchNorm -> LeakyReLU ->
    Linear(64, rep_dim)``.

    Args:
        rep_dim: width of the final linear layer.
        bias: enables bias in conv/linear layers and the BatchNorm affine parameters.
        clf: if True, a ``Linear(rep_dim, 1)`` head is applied on top of ``fc2``.
        grayscale: 1 input channel if True, else 3.
    """

    def __init__(self, rep_dim=32, bias=False, clf=False, grayscale=True):
        super().__init__()
        in_channels = 1 if grayscale else 3
        self.clf = clf
        self.grayscale = grayscale
        self.rep_dim = rep_dim

        # One pooling module, reused after both conv stages.
        self.pool = nn.MaxPool2d(2, 2)

        # Modules are created in the same order as before so parameter init and
        # state_dict key order are unchanged.
        self.conv1 = nn.Conv2d(in_channels, 16, 5, bias=bias, padding=2)
        self.bn2d1 = nn.BatchNorm2d(16, eps=1e-04, affine=bias)
        self.conv2 = nn.Conv2d(16, 32, 5, bias=bias, padding=2)
        self.bn2d2 = nn.BatchNorm2d(32, eps=1e-04, affine=bias)
        self.fc1 = nn.Linear(32 * 7 * 7, 64, bias=bias)
        self.bn1d1 = nn.BatchNorm1d(64, eps=1e-04, affine=bias)
        self.fc2 = nn.Linear(64, self.rep_dim, bias=bias)

        if self.clf:
            self.linear = nn.Linear(self.rep_dim, 1)

    def forward(self, x, return_encoding=False):
        """Run the network on ``x``.

        Returns ``out``, or ``(out, encoding)`` when ``return_encoding`` is True.
        Without ``clf``: ``out`` is the ``fc2`` output and ``encoding`` is the ``fc1``
        output (before BatchNorm/activation). With ``clf``: ``out`` is the 1-unit head
        output and ``encoding`` is the ``fc2`` output.
        """
        h = x.reshape(-1, 1 if self.grayscale else 3, 28, 28)
        for conv, norm in ((self.conv1, self.bn2d1), (self.conv2, self.bn2d2)):
            h = self.pool(F.leaky_relu(norm(conv(h))))
        h = h.reshape(int(h.size(0)), -1)

        encoding = self.fc1(h)
        out = self.fc2(F.leaky_relu(self.bn1d1(encoding)))
        if self.clf:
            encoding, out = out, self.linear(out)
        return (out, encoding) if return_encoding else out


class GenericCNN(ADNN):
    """Configurable CNN: a stack of conv blocks followed by a stack of fully-connected blocks.

    Each conv block is ``Conv2d -> BatchNorm2d -> LeakyReLU -> MaxPool2d(2, 2) -> Dropout``.
    Each fc block is ``Linear -> BatchNorm1d -> LeakyReLU -> Dropout``, except the last one,
    which is a bare ``Linear`` so the network output is not normalized or activated.

    Args:
        bias: enables bias in conv/linear layers and the BatchNorm affine parameters.
        clf: if True, a ``Linear(fc_layer[-1], 1)`` head is applied in ``forward``.
        grayscale: 1 input channel if True, else 3.
        dropout: dropout probability after every conv block and every hidden fc block.
        conv_layer: output channels of each conv block.
        fc_layer: output width of each fc block. May be empty, in which case the
            flattened conv features are the network output.
        ksize: convolution kernel size. Padding is ``ksize // 2``.
        spatial_dims: side length of the square input images.

    Raises:
        ValueError: if the conv/pool stack shrinks the feature map to zero size.
    """

    def __init__(self, bias=False, clf=False, grayscale=True, dropout=0.0,
                 conv_layer: Iterable[int] = (16, 32, ), fc_layer: Iterable[int] = (64, 32, ), ksize: int = 3,
                 spatial_dims: int = 28):
        super().__init__()
        self.clf = clf
        self.grayscale = grayscale
        self.dropout = dropout
        self.ksize = ksize
        padding = ksize // 2

        self.conv_layer_dims = [1 if grayscale else 3, *conv_layer]
        conv_modules = []
        for dim_in, dim_out in zip(self.conv_layer_dims, self.conv_layer_dims[1:]):
            conv_modules.extend((
                nn.Conv2d(dim_in, dim_out, ksize, bias=bias, padding=padding, ),
                nn.BatchNorm2d(dim_out, eps=1e-04, affine=bias),
                nn.LeakyReLU(inplace=True),
                nn.MaxPool2d(2, 2),
                nn.Dropout(dropout),
            ))
        self.conv_layer = nn.Sequential(*conv_modules)

        # Side length of the square feature map after each conv block.
        # With padding = ksize // 2 the conv preserves the size only for odd ksize;
        # an even ksize grows it by one. So the conv output is computed explicitly
        # (Conv2d formula with stride 1) instead of being assumed equal to the input.
        self.spatial_dims = [spatial_dims]
        for _ in self.conv_layer_dims[1:]:
            conv_out = self.spatial_dims[-1] + 2 * padding - ksize + 1
            self.spatial_dims.append((conv_out - 2) // 2 + 1)  # MaxPool2d(kernel=2, stride=2)
        if min(self.spatial_dims) < 1:
            raise ValueError(
                f"{len(self.conv_layer_dims) - 1} conv blocks with ksize={ksize} reduce a "
                f"{spatial_dims}x{spatial_dims} input to an empty feature map "
                f"(sizes per block: {self.spatial_dims})"
            )

        self.fc_layer_dims = [self.conv_layer_dims[-1] * self.spatial_dims[-1] ** 2, *fc_layer]
        fc_modules = [nn.Flatten()]
        for dim_in, dim_out in zip(self.fc_layer_dims, self.fc_layer_dims[1:]):
            fc_modules.extend((
                nn.Linear(dim_in, dim_out, bias=bias),
                nn.BatchNorm1d(dim_out, eps=1e-04, affine=bias),
                nn.LeakyReLU(inplace=True),
                nn.Dropout(dropout),
            ))
        if len(self.fc_layer_dims) > 1:
            # The final Linear is the output layer, so drop its BatchNorm/LeakyReLU/Dropout.
            # Guarded so that an empty fc_layer keeps the Flatten instead of deleting it
            # and then failing on an empty container.
            del fc_modules[-3:]
        self.fc_layer = nn.Sequential(*fc_modules)

        if self.clf:
            self.linear = nn.Linear(self.fc_layer_dims[-1], 1)

    def forward(self, x, return_encoding=False):
        """Return the network output for ``x``.

        ``x`` is fed to the conv stack as-is (no reshape). It presumably must have the
        channel count and ``spatial_dims x spatial_dims`` size given at construction.

        Raises:
            NotImplementedError: if ``return_encoding`` is True.
        """
        if return_encoding:
            # Reject before running the model so that a failed call does not update
            # BatchNorm running statistics.
            raise NotImplementedError("GenericCNN does not support return_encoding=True")
        x = self.fc_layer(self.conv_layer(x))
        if self.clf:
            x = self.linear(x)
        return x


class GenericCNN28(GenericCNN):  # legacy
    """Legacy name for :class:`GenericCNN`; adds no behavior of its own."""
    # NOTE(review): presumably kept so older code/configs referring to this name still
    # resolve — confirm no callers remain before removing.
    pass


class CNN32(ADNN):
    """Three-stage CNN for 32x32 inputs.

    ``forward`` reshapes its input to ``(-1, C, 32, 32)`` with ``C = 1`` when ``grayscale``
    else ``3``, applies three ``Conv(5x5, pad 2) -> BatchNorm -> LeakyReLU -> MaxPool(2)``
    stages (32 -> 16 -> 8 -> 4), then ``Linear(128*4*4, 512) -> BatchNorm -> LeakyReLU ->
    Linear(512, rep_dim)``.

    Args:
        rep_dim: width of the final linear layer.
        bias: enables bias in conv/linear layers and the BatchNorm affine parameters.
        clf: if True, a ``Linear(rep_dim, 1)`` head is applied on top of ``fc2``.
        grayscale: 1 input channel if True, else 3.
    """

    def __init__(self, rep_dim=256, bias=False, clf=False, grayscale=False):
        super().__init__()
        in_channels = 1 if grayscale else 3
        self.clf = clf
        self.grayscale = grayscale
        self.rep_dim = rep_dim

        # One pooling module, reused after every conv stage.
        self.pool = nn.MaxPool2d(2, 2)

        # Creation order is preserved so parameter init and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_channels, 32, 5, bias=bias, padding=2)
        self.bn2d1 = nn.BatchNorm2d(32, eps=1e-04, affine=bias)
        self.conv2 = nn.Conv2d(32, 64, 5, bias=bias, padding=2)
        self.bn2d2 = nn.BatchNorm2d(64, eps=1e-04, affine=bias)
        self.conv3 = nn.Conv2d(64, 128, 5, bias=bias, padding=2)
        self.bn2d3 = nn.BatchNorm2d(128, eps=1e-04, affine=bias)
        self.fc1 = nn.Linear(128 * 4 * 4, 512, bias=bias)
        self.bn1d1 = nn.BatchNorm1d(512, eps=1e-04, affine=bias)
        self.fc2 = nn.Linear(512, self.rep_dim, bias=bias)

        if self.clf:
            self.linear = nn.Linear(self.rep_dim, 1)

    def forward(self, x, return_encoding=False):
        """Run the network on ``x``.

        Returns ``out``, or ``(out, encoding)`` when ``return_encoding`` is True.
        Without ``clf``: ``out`` is the ``fc2`` output and ``encoding`` is the ``fc1``
        output (before BatchNorm/activation). With ``clf``: ``out`` is the head output
        and ``encoding`` is the ``fc2`` output.
        """
        h = x.reshape(-1, 1 if self.grayscale else 3, 32, 32)
        stages = (
            (self.conv1, self.bn2d1),
            (self.conv2, self.bn2d2),
            (self.conv3, self.bn2d3),
        )
        for conv, norm in stages:
            h = self.pool(F.leaky_relu(norm(conv(h))))
        h = h.reshape(int(h.size(0)), -1)

        encoding = self.fc1(h)
        out = self.fc2(F.leaky_relu(self.bn1d1(encoding)))
        if self.clf:
            encoding, out = out, self.linear(out)
        return (out, encoding) if return_encoding else out


class CNN32_MCLF(CNN32):
    """Multi-class variant of :class:`CNN32` with a ``Linear(rep_dim, classes)`` head.

    The base network is always built with ``bias=True`` and without its own 1-unit head.
    The multi-class head is then installed as ``self.linear``, and ``clf`` is switched on
    so that ``CNN32.forward`` applies it.

    Because ``bias`` and ``clf`` are fixed here, passing either one raises ``TypeError``.
    This applies whether it is passed by keyword or positionally through ``*args``.
    """

    def __init__(self, classes: int, *args, **kwargs):
        super().__init__(*args, bias=True, clf=False, **kwargs)
        self.clf = True
        self.linear = nn.Linear(self.rep_dim, classes)


class CNN64(ADNN):
    """Strided-conv encoder for 64x64 inputs.

    Four ``Conv(3x3, stride 2, pad 1) -> BatchNorm -> LeakyReLU`` stages
    (64 -> 32 -> 16 -> 8 -> 4 spatially; 32/64/128/256 channels) are followed by
    ``Flatten -> Linear(256*4*4, 1024) -> BatchNorm -> LeakyReLU -> Linear(1024, rep_dim)``.

    Args:
        rep_dim: width of the final linear layer.
        bias: enables bias in conv/linear layers and the BatchNorm affine parameters.
        clf: if True, a ``Linear(rep_dim, 1)`` head is applied on top of the encoder.
        grayscale: 1 input channel if True, else 3.
    """

    def __init__(self, rep_dim=256, bias=False, clf=False, grayscale=False):
        super().__init__()
        self.clf = clf
        self.rep_dim = rep_dim
        self.pool = nn.MaxPool2d(2, 2)  # not used in forward

        channels = (1 if grayscale else 3, 32, 64, 128, 256)
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=bias))
            layers.append(nn.BatchNorm2d(c_out, affine=bias))
            layers.append(nn.LeakyReLU(inplace=True))
        layers.extend([
            nn.Flatten(),
            nn.Linear(4096, 1024, bias=bias),  # 256 channels * 4 * 4
            nn.BatchNorm1d(1024, affine=bias),
            nn.LeakyReLU(),
            nn.Linear(1024, self.rep_dim, bias=bias),
        ])
        # Same module order/indices as a literal Sequential, so state_dict keys match.
        self.encoder = nn.Sequential(*layers)

        if self.clf:
            self.linear = nn.Linear(self.rep_dim, 1)

    def forward(self, x, return_encoding=False):
        """Encode ``x`` and optionally apply the ``clf`` head.

        Returns ``logits``, or ``(logits, encoding)`` when ``return_encoding`` is True.
        ``encoding`` is the encoder output. ``logits`` equals it when ``clf`` is False.
        """
        encoding = self.encoder(x)
        if not self.clf:
            return (encoding, encoding) if return_encoding else encoding
        logits = self.linear(encoding)
        return (logits, encoding) if return_encoding else logits


class FFN64(ADNN):
    """Small fully-connected encoder for flattened 64x64 inputs.

    The encoder is ``Flatten -> Linear(C*64*64, 32) -> BatchNorm -> ReLU -> Linear(32, rep_dim)``.
    ``C`` is 1 when ``grayscale`` and 3 otherwise.

    Args:
        rep_dim: width of the encoder output.
        bias: enables bias in linear layers and the BatchNorm affine parameters.
        clf: if True, a ``BatchNorm -> ReLU -> Linear(rep_dim, 1)`` head is applied to
            the encoder output.
        grayscale: 1 input channel if True, else 3.
        fcn: stored as ``self.fcn``; not read anywhere in this class.
    """

    def __init__(self, rep_dim=32, bias=False, clf=False, grayscale=False, fcn=False):
        super().__init__()
        self.clf = clf
        self.fcn = fcn
        self.rep_dim = rep_dim
        self.pool = nn.MaxPool2d(2, 2)  # not used in forward

        flat_dim = (1 if grayscale else 3) * 64 * 64
        hidden = 32
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, hidden, bias=bias),
            nn.BatchNorm1d(hidden, affine=bias),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, self.rep_dim, bias=bias),
        )

        if self.clf:
            # The in-place ReLU acts on the BatchNorm output, not on the encoding itself.
            self.linear = nn.Sequential(
                nn.BatchNorm1d(self.rep_dim, affine=bias),
                nn.ReLU(inplace=True),
                nn.Linear(self.rep_dim, 1),
            )

    def forward(self, x, return_encoding=False):
        """Encode ``x`` and optionally apply the ``clf`` head.

        Returns ``logits``, or ``(logits, encoding)`` when ``return_encoding`` is True.
        ``encoding`` is the encoder output. ``logits`` equals it when ``clf`` is False.
        """
        encoding = self.encoder(x)
        if self.clf:
            logits = self.linear(encoding)
        else:
            logits = encoding
        if not return_encoding:
            return logits
        return logits, encoding
