"""
Replications of models from Frankle et al. Lottery Ticket Hypothesis
"""

from typing import *

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from approaches.supsup.args import args
from approaches.supsup.models.builder import Builder


class MLP(nn.Module):
    """Perceptron with two hidden layers, expressed as 1x1 convolutions.

    Every layer is created through ``Builder.conv1x1``, so the input is
    reshaped to ``(batch, features, 1, 1)`` before the first layer. A single
    classifier layer sized for the largest entry of ``list__ncls`` is used.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...], nhid: int, drop1: float, drop2: float):
        """Build the layers.

        Args:
            list__ncls: Class counts; only the maximum is used, as the width
                of the classifier output (presumably one entry per task --
                confirm against the caller).
            inputsize: Input dimensions; the first three entries are
                multiplied together to get the flattened input width.
            nhid: Hidden width before scaling by ``args.width_mult``.
            drop1: Dropout probability applied to the flattened input.
            drop2: Dropout probability applied after each hidden activation.
        """
        super().__init__()
        builder = Builder()

        n_in = inputsize[0] * inputsize[1] * inputsize[2]
        n_hidden = int(args.width_mult * nhid)

        self.fc1 = builder.conv1x1(n_in, n_hidden, first_layer=True)
        self.fc2 = builder.conv1x1(n_hidden, n_hidden)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(drop1)
        self.drop2 = nn.Dropout(drop2)

        ncls_max = max(list__ncls)
        self.clf = builder.conv1x1(n_hidden, ncls_max, last_layer=True)
    # enddef

    def forward(self, x: Tensor, idx_task: int) -> Tensor:
        """Return the classifier output with shape ``(batch, channels)``.

        Args:
            x: Input batch; everything after the first dimension is flattened.
            idx_task: Not used by this method; kept for interface
                compatibility with callers that pass it.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        h = x.reshape(x.shape[0], -1, 1, 1)
        h = self.drop1(h)
        h = self.drop2(self.relu(self.fc1(h)))
        h = self.drop2(self.relu(self.fc2(h)))

        out = self.clf(h)

        # flatten(1) only collapses the trailing dims; a bare squeeze() would
        # also drop the batch dimension whenever the batch holds one sample.
        return out.flatten(1)
    # enddef


def compute_conv_output_size(Lin, kernel_size, stride=1, padding=0, dilation=1):
    """Return the output length of a convolution along one spatial dimension.

    Evaluates ``floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``
    and returns the result as a Python ``int``.

    Args:
        Lin: Input length along the dimension.
        kernel_size: Kernel extent along the dimension.
        stride: Step between kernel applications.
        padding: Padding added to each side of the input.
        dilation: Spacing between kernel elements.
    """
    numerator = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    steps = numerator / float(stride)
    return int(np.floor(steps + 1))


class AlexNet(nn.Module):
    """Three-convolution network followed by two hidden 1x1-conv layers.

    Kernel sizes of the first two convolutions are derived from the input
    size (``size // 8`` and ``size // 10``); the third uses a 2x2 kernel.
    Each convolution is followed by ReLU, dropout and a 2x2 max-pool. A
    single classifier layer sized for the largest entry of ``list__ncls``
    is used.
    """

    def __init__(self, list__ncls: List[int], inputsize: Tuple[int, ...], nhid: int, drop1: float, drop2: float):
        """Build the layers.

        Args:
            list__ncls: Class counts; only the maximum is used, as the width
                of the classifier output (presumably one entry per task --
                confirm against the caller).
            inputsize: ``inputsize[0]`` is read as the channel count and
                ``inputsize[1]`` as the spatial size; no other entry is
                read, so the spatial size is tracked as a single number.
            nhid: Hidden width before scaling by ``args.width_mult``.
            drop1: Dropout probability after the first two convolutions.
            drop2: Dropout probability after the third convolution and after
                each hidden layer.
        """
        super().__init__()
        builder = Builder()
        pad = 0

        nch, size = inputsize[0], inputsize[1]

        # Track the spatial size through each conv (+ 2x2 max-pool) stage.
        self.c1 = builder.myconv(size // 8, nch, 64, padding=pad, first_layer=True)
        s = compute_conv_output_size(size, size // 8, padding=pad)
        s = s // 2

        self.c2 = builder.myconv(size // 10, 64, 128, padding=pad)
        s = compute_conv_output_size(s, size // 10, padding=pad)
        s = s // 2

        self.c3 = builder.myconv(2, 128, 256, padding=pad)
        s = compute_conv_output_size(s, 2, padding=pad)
        s = s // 2

        self.smid = s
        self.maxpool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(drop1)
        self.drop2 = nn.Dropout(drop2)

        n_hidden = int(args.width_mult * nhid)

        # NOTE(review): c1-c3 are requested with unscaled channel counts
        # (64/128/256) while this input width is scaled by args.width_mult.
        # The two only agree when width_mult == 1 or Builder.myconv applies
        # the multiplier itself -- confirm against Builder.
        self.fc1 = builder.conv1x1(int(args.width_mult * 256 * self.smid ** 2), n_hidden)
        self.fc2 = builder.conv1x1(n_hidden, n_hidden)

        ncls_max = max(list__ncls)
        self.clf = builder.conv1x1(n_hidden, ncls_max, last_layer=True)
    # enddef

    def forward(self, x: Tensor, idx_task: int) -> Tensor:
        """Return the classifier output with shape ``(batch, channels)``.

        Args:
            x: Input batch fed directly to the first convolution.
            idx_task: Not used by this method; kept for interface
                compatibility with callers that pass it.
        """
        h = x
        h = self.maxpool(self.drop1(self.relu(self.c1(h))))
        h = self.maxpool(self.drop1(self.relu(self.c2(h))))
        h = self.maxpool(self.drop2(self.relu(self.c3(h))))
        # reshape (not view) so activations in a non-default memory layout
        # (e.g. channels-last) can still be flattened.
        h = h.reshape(h.shape[0], -1, 1, 1)
        h = self.drop2(self.relu(self.fc1(h)))
        h = self.drop2(self.relu(self.fc2(h)))

        out = self.clf(h)

        # flatten(1) only collapses the trailing dims; a bare squeeze() would
        # also drop the batch dimension whenever the batch holds one sample.
        return out.flatten(1)
    # enddef


class Swish(nn.Module):
    """Element-wise activation that multiplies the input by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)``."""
        gate = torch.sigmoid(x)
        return x * gate


class LeNet(nn.Module):
    """784 -> 300 -> 100 -> ``args.output_size`` network of 1x1 convolutions.

    Hidden widths are scaled by ``args.width_mult``. The input is expected to
    hold ``28 * 28`` values per sample.
    """

    def __init__(self):
        super().__init__()
        builder = Builder()
        n_hidden1 = int(300 * args.width_mult)
        n_hidden2 = int(100 * args.width_mult)
        self.linear = nn.Sequential(
            builder.conv1x1(28 * 28, n_hidden1, first_layer=True),
            nn.ReLU(),
            builder.conv1x1(n_hidden1, n_hidden2),
            nn.ReLU(),
            builder.conv1x1(n_hidden2, args.output_size, last_layer=True),
            )

    def forward(self, x):
        """Return the network output with shape ``(batch, channels)``."""
        # reshape (not view) so non-contiguous inputs are accepted as well.
        out = x.reshape(x.size(0), 28 * 28, 1, 1)
        out = self.linear(out)
        # flatten(1) only collapses the trailing dims; a bare squeeze() would
        # also drop the batch dimension whenever the batch holds one sample.
        return out.flatten(1)


class FC1024(nn.Module):
    """784 -> 1024 -> 1024 -> ``args.output_size`` network of 1x1 convolutions.

    Hidden widths are scaled by ``args.width_mult``. The input is expected to
    hold ``28 * 28`` values per sample.
    """

    def __init__(self):
        super().__init__()
        builder = Builder()
        n_hidden = int(args.width_mult * 1024)
        self.linear = nn.Sequential(
            builder.conv1x1(28 * 28, n_hidden, first_layer=True),
            nn.ReLU(),
            builder.conv1x1(n_hidden, n_hidden),
            nn.ReLU(),
            builder.conv1x1(n_hidden, args.output_size, last_layer=True),
            )

    def forward(self, x):
        """Return the network output with shape ``(batch, channels)``."""
        # reshape (not view) so non-contiguous inputs are accepted as well.
        out = x.reshape(x.size(0), 28 * 28, 1, 1)
        out = self.linear(out)
        # flatten(1) only collapses the trailing dims; a bare squeeze() would
        # also drop the batch dimension whenever the batch holds one sample.
        return out.flatten(1)


class BNNet(nn.Module):
    """784 -> 2048 -> 2048 -> ``args.output_size`` network with batch norm and Swish.

    Each hidden 1x1 convolution is followed by a layer from
    ``Builder.batchnorm`` and a ``Swish`` activation. Hidden widths are scaled
    by ``args.width_mult``. The input is expected to hold ``28 * 28`` values
    per sample.
    """

    def __init__(self):
        super().__init__()
        builder = Builder()
        dim = 2048
        n_hidden = int(dim * args.width_mult)
        self.linear = nn.Sequential(
            builder.conv1x1(28 * 28, n_hidden, first_layer=True),
            builder.batchnorm(n_hidden),
            Swish(),
            builder.conv1x1(n_hidden, n_hidden),
            builder.batchnorm(n_hidden),
            Swish(),
            builder.conv1x1(n_hidden, args.output_size, last_layer=True),
            )

    def forward(self, x):
        """Return the network output with shape ``(batch, channels)``."""
        # reshape (not view) so non-contiguous inputs are accepted as well.
        out = x.reshape(x.size(0), 28 * 28, 1, 1)
        out = self.linear(out)
        # flatten(1) only collapses the trailing dims; a bare squeeze() would
        # also drop the batch dimension whenever the batch holds one sample.
        return out.flatten(1)
