from collections import OrderedDict
import torch
import torch.nn as nn
from pytorchcv.model_provider import get_model as ptcv_get_model
from models import replace_bn_with_brn


class Classifier(nn.Module):
    """Truncated MobileNet classifier whose BatchNorm layers are swapped out.

    The backbone is pytorchcv's ``mobilenet_w1``. Only the tail of its feature
    extractor (part of ``stage4``, all of ``stage5`` and a 4x4 average pool) is
    kept, followed by a new ``Linear(1024, n_classes)`` output layer. After
    construction, ``models.replace_bn_with_brn`` is applied to the network
    (presumably BatchNorm -> Batch Renormalization, judging by the
    ``r_max``/``d_max`` arguments).
    """

    def __init__(self, n_classes, weights_file=None):
        """Build the network and optionally restore saved weights.

        Args:
            n_classes: Number of outputs of the final linear layer.
            weights_file: Anything ``torch.load`` accepts (path or file-like
                object) containing a ``state_dict`` for this module, or
                ``None`` to keep the freshly built weights.
        """
        super().__init__()
        self.network = self._define_classifier(n_classes)
        # The replacement runs before loading, so a checkpoint saved from this
        # class (presumably with the replaced layers) matches the module tree.
        replace_bn_with_brn(self.network, momentum=0.01, r_d_max_inc_step=4.1e-05, r_max=1.25, d_max=0.5)
        if weights_file is not None:
            self._load_weights(weights_file)

    def _load_weights(self, weights_file):
        """Load a ``state_dict`` from *weights_file* into this module (strict)."""
        # map_location="cpu" lets checkpoints saved from CUDA tensors load on
        # CPU-only machines. load_state_dict copies the values into the
        # module's existing parameters, so their device is unaffected.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weights_file, map_location="cpu")
        self.load_state_dict(state_dict)

    def forward(self, x):
        """Return ``self.network(x)``: the outputs of the linear head (no softmax)."""
        return self.network(x)

    @staticmethod
    def _define_classifier(n_classes):
        """Build the truncated pytorchcv ``mobilenet_w1`` used by this class.

        Args:
            n_classes: Number of outputs of the new linear layer.

        Returns:
            The modified pytorchcv model. ``net.features`` holds only
            ``stage4`` (``unit5.pw_conv`` followed by ``unit6``), ``stage5``
            and ``pool``. All other feature stages are dropped.
        """
        # pretrained=True makes pytorchcv fetch its pretrained weights.
        net = ptcv_get_model("mobilenet_w1", pretrained=True)
        # Presumably the tail's spatial size is 4x4 here, so the pool output
        # flattens to the 1024 features the linear layer expects. TODO confirm
        # against the expected input resolution.
        net.features.final_pool = torch.nn.AvgPool2d(4)
        net.output = torch.nn.Linear(1024, n_classes)
        # Keep only the pointwise conv of stage4.unit5. Its depthwise part is
        # dropped, so inputs presumably are activations that would normally
        # enter that pointwise conv (e.g. stored latent activations). Verify
        # with callers.
        net.features.stage4.unit5 = net.features.stage4.unit5.pw_conv
        # Build the nn.Sequential containers exactly like this: their layout
        # defines the state_dict keys (e.g. "network.features.stage4.0.*"),
        # and changing it breaks existing checkpoints.
        net.features.stage4 = nn.Sequential(
            net.features.stage4.unit5,
            net.features.stage4.unit6
        )
        net.features = nn.Sequential(OrderedDict([
            ('stage4', net.features.stage4),
            ('stage5', net.features.stage5),
            ('pool', net.features.final_pool)
        ]))

        return net


class ClassifierBN(nn.Module):
    """Truncated MobileNet classifier that keeps standard BatchNorm layers.

    The architecture is the same truncated pytorchcv ``mobilenet_w1`` that
    ``Classifier`` builds: part of ``stage4``, all of ``stage5``, a 4x4
    average pool and a new ``Linear(1024, n_classes)`` output layer. Unlike
    ``Classifier``, no ``replace_bn_with_brn`` call is made.
    """

    def __init__(self, n_classes, weights_file=None):
        """Build the network and optionally restore saved weights.

        Args:
            n_classes: Number of outputs of the final linear layer.
            weights_file: Anything ``torch.load`` accepts (path or file-like
                object) containing a ``state_dict`` for this module, or
                ``None`` to keep the freshly built weights.
        """
        super().__init__()
        self.network = self._define_classifier(n_classes)
        # BatchNorm layers are deliberately left as-is (no BN -> BRN swap).
        if weights_file is not None:
            print("Load pretrained core50 classifier")
            self._load_weights(weights_file)

    def _load_weights(self, weights_file):
        """Load a ``state_dict`` from *weights_file* into this module (strict)."""
        # map_location="cpu" lets checkpoints saved from CUDA tensors load on
        # CPU-only machines. load_state_dict copies the values into the
        # module's existing parameters, so their device is unaffected.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(weights_file, map_location="cpu")
        self.load_state_dict(state_dict)

    def forward(self, x):
        """Return ``self.network(x)``: the outputs of the linear head (no softmax)."""
        return self.network(x)

    @staticmethod
    def _define_classifier(n_classes):
        """Build the truncated pytorchcv ``mobilenet_w1`` used by this class.

        Args:
            n_classes: Number of outputs of the new linear layer.

        Returns:
            The modified pytorchcv model. ``net.features`` holds only
            ``stage4`` (``unit5.pw_conv`` followed by ``unit6``), ``stage5``
            and ``pool``. All other feature stages are dropped.
        """
        # pretrained=True makes pytorchcv fetch its pretrained weights.
        net = ptcv_get_model("mobilenet_w1", pretrained=True)
        # Presumably the tail's spatial size is 4x4 here, so the pool output
        # flattens to the 1024 features the linear layer expects. TODO confirm
        # against the expected input resolution.
        net.features.final_pool = torch.nn.AvgPool2d(4)
        net.output = torch.nn.Linear(1024, n_classes)
        # Keep only the pointwise conv of stage4.unit5. Its depthwise part is
        # dropped, so inputs presumably are activations that would normally
        # enter that pointwise conv (e.g. stored latent activations). Verify
        # with callers.
        net.features.stage4.unit5 = net.features.stage4.unit5.pw_conv
        # Build the nn.Sequential containers exactly like this: their layout
        # defines the state_dict keys (e.g. "network.features.stage4.0.*"),
        # and changing it breaks existing checkpoints.
        net.features.stage4 = nn.Sequential(
            net.features.stage4.unit5,
            net.features.stage4.unit6
        )
        net.features = nn.Sequential(OrderedDict([
            ('stage4', net.features.stage4),
            ('stage5', net.features.stage5),
            ('pool', net.features.final_pool)
        ]))

        return net
