from typing import *

import numpy as np
import torch
import torch.nn as nn

from approaches.ucl.bayes_layer import BayesianConv2D, BayesianLinear


class ModelUCLAlexNet(nn.Module):
    """AlexNet-style multi-head classifier built from Bayesian layers.

    Three ``BayesianConv2D`` stages (conv -> ReLU -> 2x2 max-pool -> dropout)
    feed two ``BayesianLinear`` hidden layers (linear -> ReLU -> dropout),
    followed by one plain ``nn.Linear`` output head per entry of ``taskcla``.

    Args:
        taskcla: Sequence of ``(task_id, n_classes)`` pairs. One head with
            ``n_classes`` outputs is built per pair, in list order.
        inputsize: ``(channels, size, _)``. Only the first two entries are
            read; the third is ignored, so the input is effectively assumed
            to be square (NOTE(review): confirm callers pass square images).
            Conv kernel sizes are derived from ``size`` (``size // 8``,
            ``size // 10``, then 2), so very small inputs yield degenerate
            (zero) kernel sizes.
        ratio: Forwarded unchanged as ``ratio=`` to every Bayesian layer.
        nhid: Width of both hidden fully-connected layers and input width of
            every output head.
        drop1: Dropout probability applied after the first two conv stages.
        drop2: Dropout probability applied after the third conv stage and
            after each hidden fully-connected layer.
    """

    def __init__(self, taskcla: List[Tuple[int, int]], inputsize: Tuple[int, ...],
                 ratio: float, nhid: int, drop1: float, drop2: float):
        super().__init__()

        ncha, size, _ = inputsize
        # Not read anywhere inside this class; kept as a public attribute
        # (presumably for the surrounding UCL training code -- verify).
        self.ratio = ratio

        self.taskcla = taskcla

        # ``s`` tracks the spatial side length after each conv + 2x2 max-pool,
        # assuming (as the size computation does) stride 1 and no padding,
        # so fc1 can be sized to the flattened conv3 output.
        # Module creation order is preserved to keep state_dict /
        # parameters() ordering compatible with existing checkpoints.
        self.conv1 = BayesianConv2D(ncha, 64, kernel_size=size // 8, ratio=ratio)
        s = compute_conv_output_size(size, size // 8)
        s = s // 2
        self.conv2 = BayesianConv2D(64, 128, kernel_size=size // 10, ratio=ratio)
        s = compute_conv_output_size(s, size // 10)
        s = s // 2
        self.conv3 = BayesianConv2D(128, 256, kernel_size=2, ratio=ratio)
        s = compute_conv_output_size(s, 2)
        s = s // 2
        self.maxpool = torch.nn.MaxPool2d(2)
        self.relu = torch.nn.ReLU()

        self.drop1 = torch.nn.Dropout(drop1)
        self.drop2 = torch.nn.Dropout(drop2)
        self.fc1 = BayesianLinear(256 * s * s, nhid, ratio=ratio)
        self.fc2 = BayesianLinear(nhid, nhid, ratio=ratio)
        # Initialised empty and never touched inside this class; kept as a
        # public attribute for external code.
        self.old_weight_norm = []

        # One classification head per (task_id, n_classes) entry, stored in
        # the same order as ``taskcla``.
        self.last = torch.nn.ModuleList()
        for _, n_classes in self.taskcla:
            self.last.append(torch.nn.Linear(nhid, n_classes))
        # endfor
    # enddef

    def forward(self, x: torch.Tensor, sample: bool = False) -> List[torch.Tensor]:
        """Run the shared trunk once and return the output of every head.

        Args:
            x: Input batch; its first dimension is treated as the batch size.
            sample: Forwarded to every Bayesian layer call (presumably selects
                sampled vs. mean weights -- see the Bayesian layer docs).

        Returns:
            A list with one tensor per entry of ``taskcla``, in the same
            order, each produced by the corresponding head.
        """
        h = self.drop1(self.maxpool(self.relu(self.conv1(x, sample))))
        h = self.drop1(self.maxpool(self.relu(self.conv2(h, sample))))
        h = self.drop2(self.maxpool(self.relu(self.conv3(h, sample))))

        h = h.view(x.size(0), -1)
        h = self.drop2(self.relu(self.fc1(h, sample)))
        h = self.drop2(self.relu(self.fc2(h, sample)))

        # Heads were appended positionally in __init__, so iterate them
        # positionally too. Indexing ``self.last`` by the task id would pick
        # the wrong head (or raise IndexError) whenever task ids are not
        # exactly 0..N-1 in order.
        return [head(h) for head in self.last]
    # enddef


def compute_conv_output_size(Lin, kernel_size, stride=1, padding=0, dilation=1):
    """Return the output length of a convolution along a single spatial axis.

    Uses the standard convolution size formula
    ``floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``
    and returns the result as a Python ``int``.
    """
    # Room left for the kernel to slide once padding is added and the
    # dilated kernel footprint is removed.
    span = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    steps = span / float(stride)
    return int(np.floor(steps + 1))
