"""
Replications of models from Frankle et al. Lottery Ticket Hypothesis
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import *
from approaches.supsup.models.builder import Builder
from approaches.supsup.args import args


class Swish(nn.Module):
    """Swish activation: each element is scaled by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)``, computed element-wise."""
        gate = torch.sigmoid(x)
        return x * gate


class LeNet(nn.Module):
    """Three-layer network assembled from ``Builder.conv1x1`` layers.

    Layer widths are ``28 * 28 -> int(300 * args.width_mult) ->
    int(100 * args.width_mult) -> args.output_size`` with a ReLU after each
    of the first two layers. Every sample is reshaped to ``(28 * 28, 1, 1)``
    before entering the stack.

    NOTE(review): presumably ``Builder.conv1x1`` returns a 1x1 convolution,
    making the stack equivalent to dense layers -- confirm against Builder.
    """

    def __init__(self):
        super().__init__()
        builder = Builder()
        hidden_first = int(300 * args.width_mult)
        hidden_second = int(100 * args.width_mult)
        self.linear = nn.Sequential(
            builder.conv1x1(28 * 28, hidden_first, first_layer=True),
            nn.ReLU(),
            builder.conv1x1(hidden_first, hidden_second),
            nn.ReLU(),
            builder.conv1x1(hidden_second, args.output_size, last_layer=True),
        )

    def forward(self, x):
        """Run the network on a batch whose samples each hold 28 * 28 values.

        Args:
            x: Tensor that can be viewed as ``(batch, 28 * 28, 1, 1)``.

        Returns:
            Tensor with the batch dimension first and all remaining
            dimensions flattened into one.
        """
        out = x.view(x.size(0), 28 * 28, 1, 1)
        out = self.linear(out)
        # Flatten everything after the batch axis instead of a bare squeeze():
        # squeeze() would also remove the batch axis for a single-sample batch
        # (and the class axis when there is one output), yielding a tensor
        # that downstream per-sample code cannot index along dim 1.
        return out.flatten(1)


class FC1024(nn.Module):
    """Three-layer network with two hidden layers of ``int(args.width_mult * 1024)`` units.

    Layer widths are ``28 * 28 -> hidden -> hidden -> args.output_size`` with
    a ReLU after each of the first two layers. Every sample is reshaped to
    ``(28 * 28, 1, 1)`` before entering the stack.

    NOTE(review): presumably ``Builder.conv1x1`` returns a 1x1 convolution,
    making the stack equivalent to dense layers -- confirm against Builder.
    """

    def __init__(self):
        super().__init__()
        builder = Builder()
        hidden = int(args.width_mult * 1024)
        self.linear = nn.Sequential(
            builder.conv1x1(28 * 28, hidden, first_layer=True),
            nn.ReLU(),
            builder.conv1x1(hidden, hidden),
            nn.ReLU(),
            builder.conv1x1(hidden, args.output_size, last_layer=True),
        )

    def forward(self, x):
        """Run the network on a batch whose samples each hold 28 * 28 values.

        Args:
            x: Tensor that can be viewed as ``(batch, 28 * 28, 1, 1)``.

        Returns:
            Tensor with the batch dimension first and all remaining
            dimensions flattened into one.
        """
        out = x.view(x.size(0), 28 * 28, 1, 1)
        out = self.linear(out)
        # Flatten everything after the batch axis instead of a bare squeeze():
        # squeeze() would also remove the batch axis for a single-sample batch
        # (and the class axis when there is one output).
        return out.flatten(1)


class FC2048(nn.Module):
    """Three-layer network with 2048-wide hidden layers and per-task heads.

    Layer widths are ``ch * w * h -> int(args.width_mult * 2048) ->
    int(args.width_mult * 2048) -> args.output_size`` with a ReLU after each
    of the first two layers. One ``nn.Linear(args.output_size, ncls)`` head is
    created per entry of ``list__ncls`` and stored in ``self.fc``.

    NOTE(review): the heads are registered (and therefore appear in the
    parameters / state dict) but ``forward`` does not apply them; the call
    was disabled in the original code. Confirm whether that is intended.

    Args:
        ch: Number of input channels.
        w: Input width.
        h: Input height.
        list__ncls: Number of classes for each task head, in task order.
    """

    def __init__(self, ch: int, w: int, h: int, list__ncls: List[int]):
        # Initialise nn.Module first so attribute assignment goes through a
        # fully set-up module rather than relying on pre-init behaviour.
        super().__init__()
        self.ch, self.w, self.h = ch, w, h

        hidden = int(args.width_mult * 2048)
        builder = Builder()
        self.linear = nn.Sequential(
            builder.conv1x1(ch * w * h, hidden, first_layer=True),
            nn.ReLU(),
            builder.conv1x1(hidden, hidden),
            nn.ReLU(),
            builder.conv1x1(hidden, args.output_size, last_layer=True),
        )
        self.fc = nn.ModuleList(
            [nn.Linear(args.output_size, ncls) for ncls in list__ncls]
        )

    def forward(self, x, idx_task: int):
        """Run the shared trunk on a batch of inputs.

        Args:
            x: Tensor that can be viewed as ``(batch, ch * w * h, 1, 1)``.
            idx_task: Index of the task head in ``self.fc``; it must be a
                valid index even though the head is not applied.

        Returns:
            Trunk output with the batch dimension first and all remaining
            dimensions flattened into one.

        Raises:
            IndexError: If ``idx_task`` is not a valid index into ``self.fc``.
        """
        out = x.view(x.size(0), self.ch * self.w * self.h, 1, 1)
        out = self.linear(out)
        # The head lookup is kept only so an invalid idx_task still fails
        # loudly; applying the head (``head(out)``) is deliberately left
        # disabled, matching the original behaviour.
        _ = self.fc[idx_task]
        # Flatten everything after the batch axis instead of a bare squeeze():
        # squeeze() would also remove the batch axis for a single-sample batch.
        return out.flatten(1)


class BNNet(nn.Module):
    """Three-layer network with batch normalisation and Swish activations.

    Layer widths are ``28 * 28 -> int(2048 * args.width_mult) ->
    int(2048 * args.width_mult) -> args.output_size``. Each of the first two
    layers is followed by ``Builder.batchnorm`` and a ``Swish`` activation.
    Every sample is reshaped to ``(28 * 28, 1, 1)`` before entering the stack.

    NOTE(review): presumably ``Builder.conv1x1`` / ``Builder.batchnorm``
    return a 1x1 convolution and a 2-D batch norm -- confirm against Builder.
    """

    def __init__(self):
        super().__init__()
        builder = Builder()
        width = int(2048 * args.width_mult)
        self.linear = nn.Sequential(
            builder.conv1x1(28 * 28, width, first_layer=True),
            builder.batchnorm(width),
            Swish(),
            builder.conv1x1(width, width),
            builder.batchnorm(width),
            Swish(),
            builder.conv1x1(width, args.output_size, last_layer=True),
        )

    def forward(self, x):
        """Run the network on a batch whose samples each hold 28 * 28 values.

        Args:
            x: Tensor that can be viewed as ``(batch, 28 * 28, 1, 1)``.

        Returns:
            Tensor with the batch dimension first and all remaining
            dimensions flattened into one.
        """
        out = x.view(x.size(0), 28 * 28, 1, 1)
        out = self.linear(out)
        # Flatten everything after the batch axis instead of a bare squeeze():
        # squeeze() would also remove the batch axis for a single-sample batch
        # (e.g. single-sample inference).
        return out.flatten(1)
