import torch.nn as nn
import torch
import numpy as np
from .FeatureExtractorText import make_res_block_encoder_feature_extractor
import torch.nn.functional as F


class Flatten(nn.Module):
    """Reshape a tensor to (batch, -1), merging every dimension after the first.

    Uses ``view``, so the input must be viewable (e.g. contiguous) in that shape.
    """

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)

class ResidualBlockEncoder(nn.Module):
    """1D residual block: BN-ReLU-Conv1x1-BN-ReLU-Conv(kernelsize) on the main path.

    The output is ``shortcut + 0.3 * main_path``. The shortcut is ``downsample(x)`` when a
    ``downsample`` module is given, otherwise ``x`` itself. In that case the main path must
    preserve the input's shape.

    Args:
        channels_in: channels of the input tensor.
        channels_out: channels produced by the main path (and the shortcut).
        kernelsize, stride, padding, dilation: parameters of the second convolution.
        downsample: module mapping ``x`` to the main path's output shape, or ``None``.
    """

    def __init__(self, channels_in, channels_out, kernelsize, stride, padding, dilation, downsample):
        super(ResidualBlockEncoder, self).__init__()
        self.bn1 = nn.BatchNorm1d(channels_in)
        self.conv1 = nn.Conv1d(channels_in, channels_in, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm1d(channels_in)
        self.conv2 = nn.Conv1d(channels_in, channels_out, kernel_size=kernelsize, stride=stride, padding=padding, dilation=dilation)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        # Explicit None check: relying on nn.Module truthiness (Sequential.__len__) is fragile.
        if self.downsample is not None:
            residual = self.downsample(x)
        out = residual + 0.3*out
        return out


def make_res_block_encoder(channels_in, channels_out, kernelsize, stride, padding, dilation):
    """Build a ``ResidualBlockEncoder`` wrapped in a one-element ``nn.Sequential``.

    A learned Conv1d+BatchNorm1d shortcut is created whenever the main path can change the
    input's shape. The cases are: a stride other than 1, a channel change, a dilation
    other than 1, or kernel/padding settings that do not preserve the sequence length.
    Otherwise the identity shortcut is used.

    Returns:
        ``nn.Sequential`` whose only child (key ``"0"``) is the residual block.
    """
    # With stride 1, conv2 yields length L + 2*padding - dilation*(kernelsize - 1); the identity
    # shortcut only lines up with the main path when that equals L.
    preserves_length = 2 * padding == dilation * (kernelsize - 1)
    downsample = None
    if stride != 1 or channels_in != channels_out or dilation != 1 or not preserves_length:
        downsample = nn.Sequential(nn.Conv1d(channels_in, channels_out,
                                             kernel_size=kernelsize,
                                             stride=stride,
                                             padding=padding,
                                             dilation=dilation),
                                   nn.BatchNorm1d(channels_out))
    block = ResidualBlockEncoder(channels_in, channels_out, kernelsize, stride, padding, dilation, downsample)
    return nn.Sequential(block)

class ResidualBlock2dConv(nn.Module):
    """2D residual block: BN-ReLU-Conv1x1-Dropout2d-BN-ReLU-Conv(kernelsize)-Dropout2d.

    The output is ``a * shortcut + b * main_path``. The shortcut is ``downsample(x)`` when
    given, otherwise ``x``. Both convolutions are bias-free, and both dropouts use p=0.5.

    Args:
        channels_in, channels_out: input / output channel counts.
        kernelsize, stride, padding, dilation: parameters of the second convolution.
        downsample: module mapping ``x`` to the main path's output shape, or ``None``.
        a, b: weights of the shortcut and main path in the final sum.
    """

    def __init__(self, channels_in, channels_out, kernelsize, stride, padding, dilation, downsample, a=1, b=1):
        super(ResidualBlock2dConv, self).__init__()
        self.conv1 = nn.Conv2d(channels_in, channels_in, kernel_size=1, stride=1, padding=0, dilation=dilation, bias=False)
        self.dropout1 = nn.Dropout2d(p=0.5, inplace=False)
        self.bn1 = nn.BatchNorm2d(channels_in)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm2d(channels_in)
        self.conv2 = nn.Conv2d(channels_in, channels_out, kernel_size=kernelsize, stride=stride, padding=padding, dilation=dilation, bias=False)
        self.dropout2 = nn.Dropout2d(p=0.5, inplace=False)
        self.downsample = downsample
        self.a = a
        self.b = b

    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.dropout1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.dropout2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out = self.a*residual + self.b*out
        return out


def make_res_block_feature_extractor(in_channels, out_channels, kernelsize, stride, padding, dilation, a_val=2.0, b_val=0.3):
    """Build a ``ResidualBlock2dConv`` with a learned Conv2d+BatchNorm2d shortcut.

    The shortcut convolution uses the same kernel/stride/padding/dilation as the main
    path's second convolution. Its output therefore always matches the main path's shape.

    Returns:
        ``nn.Sequential`` whose only child (key ``"0"``) is the residual block.
    """
    # The previous guard ``stride != 2 or in_channels != out_channels`` produced a projection in
    # every case except stride == 2 with equal channels. There an identity shortcut cannot match
    # the strided main path, and ``residual + out`` fails. Always projecting fixes that case and
    # leaves every other configuration (and its state_dict keys) unchanged.
    downsample = nn.Sequential(nn.Conv2d(in_channels, out_channels,
                                         kernel_size=kernelsize,
                                         padding=padding,
                                         stride=stride,
                                         dilation=dilation),
                               nn.BatchNorm2d(out_channels))
    block = ResidualBlock2dConv(in_channels, out_channels, kernelsize, stride, padding, dilation, downsample,
                                a=a_val, b=b_val)
    return nn.Sequential(block)


def make_res_layers_feature_extractor(args, a=1.0, b=1.0):
    """Stack ``args.num_layers_img`` 2D residual blocks into one ``nn.Sequential``.

    Block ``k`` maps ``(k + 1) * args.DIM_img`` channels to
    ``min(k + 2, args.num_layers_img) * args.DIM_img`` channels. The final block therefore
    keeps its width. Every block uses ``args.kernelsize_enc_img``, ``args.enc_stride_img`` and
    ``args.enc_padding_img``, with dilation 1; ``a``/``b`` weight the shortcut and main path.

    NOTE(review): the final block has equal input and output widths; confirm the shortcut
    chosen by make_res_block_feature_extractor matches the strided main path in that case.
    """
    depth = args.num_layers_img
    layers = [
        make_res_block_feature_extractor(
            (idx + 1) * args.DIM_img,
            min(idx + 2, depth) * args.DIM_img,
            kernelsize=args.kernelsize_enc_img,
            stride=args.enc_stride_img,
            padding=args.enc_padding_img,
            dilation=1,
            a_val=a,
            b_val=b,
        )
        for idx in range(depth)
    ]
    return nn.Sequential(*layers)


class FeatureExtractorImg(nn.Module):
    """Image encoder: a strided stem convolution followed by four strided 2D residual blocks.

    Channel widths grow DIM_img -> 2*DIM_img -> 3*DIM_img -> 4*DIM_img -> 5*DIM_img. ``a`` and
    ``b`` are passed to every residual block as the weights of its shortcut and main path.
    With the default stem settings, a 64x64 input reduces to a 1x1 map (assumed input size;
    confirm against callers).
    """

    def __init__(self, a, b, image_channels=3, DIM_img=128, kernelsize_enc_img=3, enc_stride_img=2, enc_padding_img=1):
        super(FeatureExtractorImg, self).__init__()
        self.a = a
        self.b = b
        self.conv1 = nn.Conv2d(image_channels, DIM_img,
                               kernel_size=kernelsize_enc_img,
                               stride=enc_stride_img,
                               padding=enc_padding_img,
                               dilation=1,
                               bias=False)

        def widening_block(mult_in, mult_out, pad):
            # Every block is a 4x4, stride-2 residual block; only widths and padding vary.
            return make_res_block_feature_extractor(mult_in * DIM_img, mult_out * DIM_img, kernelsize=4, stride=2,
                                                    padding=pad, dilation=1, a_val=self.a, b_val=self.b)

        self.resblock1 = widening_block(1, 2, pad=1)
        self.resblock2 = widening_block(2, 3, pad=1)
        self.resblock3 = widening_block(3, 4, pad=1)
        self.resblock4 = widening_block(4, 5, pad=0)

    def forward(self, x):
        features = self.conv1(x)
        for stage in (self.resblock1, self.resblock2, self.resblock3, self.resblock4):
            features = stage(features)
        return features


class ClfImgMNIST(nn.Module):
    """Three-conv digit classifier for single-channel images with 10 sigmoid outputs.

    The linear layer expects exactly 128 flattened features. Presumably the inputs are 1x28x28,
    which the three unpadded 4x4/stride-2 convs reduce to 128x1x1 — confirm against callers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5, inplace=False)
        # 10 outputs, one per digit class.
        self.linear = nn.Linear(in_features=128, out_features=10, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h = x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = self.relu(conv(h))
        dropped = self.dropout(h)
        flat = dropped.view(dropped.size(0), -1)
        return self.sigmoid(self.linear(flat))

class ClfImgMMNIST(nn.Module):
    """
    Digit classifier for 3-channel 28x28 images, returning 10 raw (un-squashed) scores.
    Roughly based on the encoder from:
    https://colab.research.google.com/github/smartgeometry-ucl/dl4g/blob/master/variational_autoencoder.ipynb
    """
    def __init__(self):
        super().__init__()
        conv_stage = [
            nn.Conv2d(3, 10, kernel_size=4, stride=2, padding=1),     # (3, 28, 28) -> (10, 14, 14)
            nn.Dropout2d(0.5),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=4, stride=2, padding=1),    # -> (20, 7, 7)
            nn.Dropout2d(0.5),
            nn.ReLU(),
        ]
        dense_stage = [
            Flatten(),                                                # -> (980,)
            nn.Linear(980, 128),                                      # -> (128,)
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(128, 10),                                       # -> (10,)
        ]
        # A single Sequential keeps the parameter names "encoder.<index>.*".
        self.encoder = nn.Sequential(*conv_stage, *dense_stage)

    def forward(self, x):
        return self.encoder(x)


def actvn(x):
    """Apply a leaky ReLU with negative slope 0.2 (the activation used by ResnetBlock)."""
    return F.leaky_relu(x, negative_slope=0.2)


class ClfImgTransMMNIST(nn.Module):
    """ResNet-style 10-way classifier for 3-channel images of the hard-coded size 28.

    A 3x3 stem maps 3 -> nf channels. It is followed by a ResnetBlock and then
    log2(28 / 7) = 2 stages of (3x3 stride-2 average pool, ResnetBlock that doubles the
    width). That brings a 28x28 map down to s0 x s0 = 7x7 with ``nf0`` channels. The result
    is flattened, passed through ``actvn``, and projected to 10 raw scores.
    """

    def __init__(self):
        super().__init__()
        self.s0 = 7
        self.nf = 64
        self.nf_max = 1024
        input_size = 28

        n_down = int(np.log2(input_size / self.s0))
        self.nf0 = min(self.nf_max, self.nf * 2**n_down)

        def width(level):
            # Channel count at pyramid level `level`, capped at nf_max.
            return min(self.nf * 2**level, self.nf_max)

        # Blocks are built before the stem/fc so parameter initialisation order is unchanged.
        stages = [ResnetBlock(self.nf, self.nf)]
        for level in range(n_down):
            stages.append(nn.AvgPool2d(3, stride=2, padding=1))
            stages.append(ResnetBlock(width(level), width(level + 1)))

        self.conv_img = nn.Conv2d(3, self.nf, 3, padding=1)
        self.resnet = nn.Sequential(*stages)
        self.fc = nn.Linear(self.nf0 * self.s0 * self.s0, 10)

    def forward(self, x):
        features = self.resnet(self.conv_img(x))
        flat = features.view(x.size(0), self.nf0 * self.s0 * self.s0)
        return self.fc(actvn(flat))


class ResnetBlock(nn.Module):
    """Residual block computing ``shortcut(x) + 0.1 * conv_1(actvn(conv_0(actvn(x))))``.

    Both convs are 3x3 with padding 1 and stride 1, so spatial size is preserved. The
    shortcut is a bias-free 1x1 conv when ``fin != fout`` and the identity otherwise. The
    hidden width defaults to ``min(fin, fout)``. ``is_bias`` only controls the bias of
    ``conv_1``; ``conv_0`` always has a bias.
    """

    def __init__(self, fin, fout, fhidden=None, is_bias=True):
        super().__init__()
        self.is_bias = is_bias
        self.learned_shortcut = fin != fout
        self.fin = fin
        self.fout = fout
        self.fhidden = min(fin, fout) if fhidden is None else fhidden

        self.conv_0 = nn.Conv2d(fin, self.fhidden, 3, stride=1, padding=1)
        self.conv_1 = nn.Conv2d(self.fhidden, fout, 3, stride=1, padding=1, bias=is_bias)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0, bias=False)

    def forward(self, x):
        skip = self._shortcut(x)
        hidden = self.conv_0(actvn(x))
        hidden = self.conv_1(actvn(hidden))
        return skip + 0.1 * hidden

    def _shortcut(self, x):
        return self.conv_s(x) if self.learned_shortcut else x


class ClfImgSVHN(nn.Module):
    """Four-conv digit classifier for 3-channel images with 10 sigmoid outputs.

    Each stage is conv -> dropout -> batchnorm -> ReLU. The linear layer expects 128
    flattened features. Presumably the inputs are 3x32x32, which the convs reduce to
    128x1x1 — confirm against callers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1, dilation=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, dilation=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1, dilation=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=0, dilation=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5, inplace=False)
        # 10 outputs, one per digit class.
        self.linear = nn.Linear(in_features=128, out_features=10, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        h = x
        for conv, norm in stages:
            h = self.relu(norm(self.dropout(conv(h))))
        flat = h.view(h.size(0), -1)
        return self.sigmoid(self.linear(flat))


class ClfText(nn.Module):
    """1D-conv text classifier with 10 sigmoid outputs.

    Input is transposed on its last two axes before the first Conv1d, so it is expected as
    (batch, length, num_features). The linear layer expects ``2 * dim`` flattened features.
    The two strided residual blocks must therefore reduce the length to 1 (TODO confirm the
    input length used by callers).
    """

    def __init__(self, num_features, dim):
        super().__init__()
        self.conv1 = nn.Conv1d(num_features, 2 * dim, kernel_size=1)
        self.resblock_1 = make_res_block_encoder(2 * dim, 3 * dim, kernelsize=4, stride=2, padding=1,
                                                 dilation=1)
        self.resblock_4 = make_res_block_encoder(3 * dim, 2 * dim, kernelsize=4, stride=2, padding=0,
                                                 dilation=1)
        self.dropout = nn.Dropout(p=0.5, inplace=False)
        # 10 outputs, one per class.
        self.linear = nn.Linear(in_features=2 * dim, out_features=10, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channels_first = x.transpose(-2, -1)
        features = self.resblock_4(self.resblock_1(self.conv1(channels_first)))
        features = self.dropout(features)
        logits = self.linear(features.view(features.size(0), -1))
        return self.sigmoid(logits)


class ClfMNISTLabel(nn.Module):
    """Identity module: ``forward`` returns its argument unchanged (no parameters)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        passthrough = x
        return passthrough

class ClfCelebAImg(nn.Module):
    """CelebA image classifier: FeatureExtractorImg -> dropout -> linear -> 40 sigmoid outputs.

    Each of the 40 outputs is squashed independently by a sigmoid.

    Args:
        a_img, b_img: shortcut / main-path weights forwarded to FeatureExtractorImg.
        num_layers_img: multiplier for the linear layer's input width
            (``num_layers_img * DIM_img``).
        DIM_img: base channel width, forwarded to FeatureExtractorImg.

    NOTE(review): FeatureExtractorImg always emits ``5 * DIM_img`` channels, and the
    flattened features must equal ``num_layers_img * DIM_img``. So ``num_layers_img`` must be
    5, and the extractor's output must be 1x1 spatially (presumably 64x64 inputs — confirm).
    """

    def __init__(self, a_img=2.0, b_img=0.3, num_layers_img=5, DIM_img=128):
        super(ClfCelebAImg, self).__init__()
        # Forward DIM_img: previously the extractor always used its default width (128), so any
        # other DIM_img left the linear layer's input size out of sync with the features.
        self.feature_extractor = FeatureExtractorImg(a=a_img, b=b_img, DIM_img=DIM_img)
        self.dropout = nn.Dropout(p=0.5, inplace=False)
        self.linear = nn.Linear(in_features=num_layers_img*DIM_img, out_features=40, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_img):
        h = self.feature_extractor(x_img)
        h = self.dropout(h)
        h = h.view(h.size(0), -1)
        h = self.linear(h)
        out = self.sigmoid(h)
        return out

    def get_activations(self, x_img):
        """Return the raw feature-extractor output (before dropout, linear and sigmoid)."""
        h = self.feature_extractor(x_img)
        return h

class ClfCelebAAttribute(nn.Module):
    """Identity module: ``forward`` hands back its argument untouched (no parameters)."""

    def __init__(self):
        super().__init__()

    def forward(self, y):
        unchanged = y
        return unchanged

class ClfCelebAText(nn.Module):
    """CelebA text classifier: strided Conv1d stem, six residual blocks, 40 sigmoid outputs.

    Input is transposed on its last two axes before the stem, so it is expected as
    (batch, length, num_features). Block widths grow DIM_text -> 2*DIM_text -> ... -> 7*DIM_text.
    The linear layer expects ``num_layers_text * DIM_text`` flattened features. With the
    default ``num_layers_text=7`` that matches 7*DIM_text channels at a final length of 1
    (TODO confirm the input length used by callers).
    """

    def __init__(self, num_features=71, DIM_text=128, enc_padding_text=1, num_layers_text=7):
        super(ClfCelebAText, self).__init__()
        self.conv1 = nn.Conv1d(num_features, DIM_text,
                               kernel_size=3, stride=2, padding=enc_padding_text, dilation=1)

        def widening_block(mult_in, mult_out, pad):
            # Every block is kernel 4, stride 2; only widths and padding vary.
            return make_res_block_encoder_feature_extractor(mult_in * DIM_text, mult_out * DIM_text,
                                                            kernelsize=4, stride=2, padding=pad, dilation=1)

        self.resblock_1 = widening_block(1, 2, pad=1)
        self.resblock_2 = widening_block(2, 3, pad=1)
        self.resblock_3 = widening_block(3, 4, pad=1)
        self.resblock_4 = widening_block(4, 5, pad=1)
        self.resblock_5 = widening_block(5, 6, pad=1)
        self.resblock_6 = widening_block(6, 7, pad=0)
        self.dropout = nn.Dropout(p=0.5, inplace=False)
        self.linear = nn.Linear(in_features=num_layers_text*DIM_text, out_features=40, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_text):
        out = self.conv1(x_text.transpose(-2, -1))
        stages = (self.resblock_1, self.resblock_2, self.resblock_3,
                  self.resblock_4, self.resblock_5, self.resblock_6)
        for stage in stages:
            out = stage(out)
        h = self.dropout(out)
        h = h.view(h.size(0), -1)
        return self.sigmoid(self.linear(h))