import math
import torch.nn as nn
from layers import BBB_Linear, BBB_Conv2d
from layers import BBB_LRT_Linear, BBB_LRT_Conv2d
from layers import FlattenLayer, ModuleWrapper


class BBB2Fc(ModuleWrapper):
    """Bayesian fully-connected network with two hidden layers.

    The layers are registered as
    ``flatten -> fc1 -> act1 -> fc2 -> act2 -> classifier``, where every
    linear layer is one of the Bayesian linear classes imported from
    ``layers`` (chosen by ``layer_type``).

    NOTE(review): no ``forward`` is defined in this class; it is presumably
    inherited from ``ModuleWrapper`` and may depend on the order in which the
    child modules are registered in ``__init__`` — keep that order unchanged.
    TODO confirm against ``ModuleWrapper``.

    Args:
        outputs: Number of output units of the final ``classifier`` layer
            (also stored as ``self.num_classes``).
        inputs: Presumably the number of input channels; the flattened input
            size is ``inputs * dim * dim``.
        dim: Presumably the side length of a square input; see ``inputs``.
        priors: Prior configuration passed unchanged to every Bayesian linear
            layer.
        layer_type: ``'lrt'`` selects ``BBB_LRT_Linear``; ``'bbb'`` selects
            ``BBB_Linear``.
        activation_type: ``'softplus'`` (``nn.Softplus``) or ``'relu'``
            (``nn.ReLU``).
        hidden_size: Width of both hidden layers (keyword-only). Defaults to
            1024, the value previously hard-coded.

    Raises:
        ValueError: If ``layer_type`` or ``activation_type`` is not one of the
            supported values (``layer_type`` is checked first).
    """

    def __init__(self, outputs, inputs, dim, priors, layer_type='lrt',
                 activation_type='softplus', *, hidden_size=1024):
        super(BBB2Fc, self).__init__()

        self.num_classes = outputs
        self.layer_type = layer_type
        self.priors = priors
        self.hidden_size = hidden_size

        # Only a linear layer class is needed: this network has no conv layers.
        if layer_type == 'lrt':
            BBBLinear = BBB_LRT_Linear
        elif layer_type == 'bbb':
            BBBLinear = BBB_Linear
        else:
            raise ValueError("Undefined layer_type")

        # Keep the activation *class* so each hidden layer gets its own
        # freshly constructed activation module below.
        if activation_type == 'softplus':
            self.act = nn.Softplus
        elif activation_type == 'relu':
            self.act = nn.ReLU
        else:
            raise ValueError("Only softplus or relu supported")

        in_features = inputs * dim * dim

        # Registration order is preserved deliberately (see class NOTE).
        self.flatten = FlattenLayer(in_features)
        self.fc1 = BBBLinear(in_features, hidden_size, bias=True, priors=self.priors)
        self.act1 = self.act()

        self.fc2 = BBBLinear(hidden_size, hidden_size, bias=True, priors=self.priors)
        self.act2 = self.act()

        self.classifier = BBBLinear(hidden_size, outputs, bias=True, priors=self.priors)

    def reset_priors_networks(self):
        """Call ``reset_prior()`` on every Bayesian linear layer."""
        self.fc1.reset_prior()
        self.fc2.reset_prior()
        self.classifier.reset_prior()

    def reset_params(self):
        """Call ``reset_parameters()`` on every Bayesian linear layer."""
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        self.classifier.reset_parameters()