from typing import *

import torch
import torch.nn as nn

from approaches.ucl.bayes_layer import BayesianLinear


class ModelUCLMLP(nn.Module):
    """MLP with two ``BayesianLinear`` hidden layers and one linear head per task.

    The shared trunk is ``fc1`` -> ReLU -> ``fc2`` -> ReLU, with dropout on the
    flattened input (``drop1``) and after each hidden activation (``drop2``).
    Every entry of ``taskcla`` gets its own ``torch.nn.Linear`` head in
    ``self.last``.

    Args:
        taskcla: ``(t, n)`` pairs, one per task; ``n`` is the output width of
            that task's head. Heads are appended in ``taskcla`` order, and
            ``forward`` indexes ``self.last`` by ``t``. The first elements are
            therefore expected to be ``0..len(taskcla)-1`` in order.
            NOTE(review): this is inferred from the indexing; confirm against
            the caller.
        inputsize: Unpacked as ``(channels, height, width)``. The flattened
            input feature count fed to ``fc1`` is the product of the three.
        ratio: Forwarded unchanged to both ``BayesianLinear`` layers.
            NOTE(review): its meaning is defined by ``BayesianLinear``.
        nhid: Width of both hidden layers.
        drop1: Dropout probability on the flattened input.
        drop2: Dropout probability after each hidden ReLU.
    """

    def __init__(self, taskcla: List[Tuple[int, int]], inputsize: Tuple[int, ...],
                 ratio: float, nhid: int, drop1: float, drop2: float):
        super().__init__()

        # Use both spatial dims: the previous `size * size` silently assumed
        # square inputs; for square inputs this is identical.
        ncha, height, width = inputsize
        self.ratio = ratio

        self.taskcla = taskcla

        self.relu = torch.nn.ReLU()

        self.drop1 = torch.nn.Dropout(drop1)
        self.drop2 = torch.nn.Dropout(drop2)
        self.fc1 = BayesianLinear(ncha * height * width, nhid, ratio=ratio)
        self.fc2 = BayesianLinear(nhid, nhid, ratio=ratio)
        # Not read in this class.
        # NOTE(review): presumably maintained by the training approach; verify.
        self.old_weight_norm = []

        self.last = torch.nn.ModuleList()
        # `merge_last` (input width 2 * n) is built here but not used by `forward`.
        # NOTE(review): presumably consumed elsewhere; verify before removing.
        self.merge_last = torch.nn.ModuleList()
        for t, n in self.taskcla:
            self.last.append(torch.nn.Linear(nhid, n))
            self.merge_last.append(torch.nn.Linear(n * 2, n))
        # endfor
    # enddef

    def forward(self, x: torch.Tensor,
                sample: bool = False) -> List[torch.Tensor]:
        """Run the shared trunk, then every task head.

        Args:
            x: Batch whose first dimension is the batch size. All remaining
                dimensions are flattened, and their product must equal
                ``fc1``'s input size.
            sample: Passed through to ``fc1`` and ``fc2``.
                NOTE(review): presumably selects sampled weights inside
                ``BayesianLinear``; confirm.

        Returns:
            One tensor per entry of ``taskcla``, in ``taskcla`` order. The
            entry for ``(t, _)`` is ``self.last[t]`` applied to the trunk
            features.
        """
        # `reshape` rather than `view`: `view` raises on non-contiguous inputs
        # (e.g. permuted tensors); `reshape` yields the same values either way.
        h = self.drop1(x.reshape(x.size(0), -1))
        h = self.drop2(self.relu(self.fc1(h, sample)))
        h = self.drop2(self.relu(self.fc2(h, sample)))

        return [self.last[t](h) for t, _ in self.taskcla]
    # enddef
