import torch
from torch.nn import Identity
from utils import gaussian_dropout
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from torchvision.models import densenet121, densenet169, densenet201
from collections import OrderedDict


class BaseNet(torch.nn.Module):
    """Torchvision ResNet/DenseNet used as a feature extractor for small images.

    The wrapped network is created without pretrained weights. Its final
    linear layer is replaced by an identity, so ``forward`` returns whatever
    the network feeds into that layer; the length of that feature vector is
    stored in ``fc_in_features``. The stem (first convolution and max
    pooling) is replaced by 3x3, stride-1 layers so that small inputs such as
    CIFAR images are not downsampled immediately.

    Args:
        model: Architecture name, one of the keys of ``_FC_IN_FEATURES``.
        in_channels: Number of channels of the input images.

    Raises:
        ValueError: If ``model`` is not a supported architecture name.
    """

    # Input width of the (removed) final linear layer for each architecture.
    _FC_IN_FEATURES = {
        'resnet18': 512,
        'resnet34': 512,
        'resnet50': 2048,
        'resnet101': 2048,
        'resnet152': 2048,
        'densenet121': 1024,
        'densenet169': 1664,
        'densenet201': 1920,
    }

    def __init__(self, model, in_channels=3):
        super().__init__()

        # Raise instead of ``assert`` so the check also runs under ``python -O``.
        if model not in self._FC_IN_FEATURES:
            raise ValueError(
                'Unsupported model {!r}; expected one of {}'.format(
                    model, sorted(self._FC_IN_FEATURES)))

        constructors = {
            'resnet18': resnet18,
            'resnet34': resnet34,
            'resnet50': resnet50,
            'resnet101': resnet101,
            'resnet152': resnet152,
            'densenet121': densenet121,
            'densenet169': densenet169,
            'densenet201': densenet201,
        }
        # Calling the constructor without arguments gives randomly initialised
        # weights in both the legacy ``pretrained=`` and the newer ``weights=``
        # torchvision APIs, and avoids the deprecation warning of the former.
        self._model = constructors[model]()
        self.fc_in_features = self._FC_IN_FEATURES[model]

        # skip last linear layer and fix fast downsampling for smaller CIFAR images
        if 'resnet' in model:
            self._model.fc = Identity()

            self._model.conv1 = torch.nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self._model.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        else:
            self._model.classifier = Identity()

            # Swap the stem layers *inside* the wrapped DenseNet. Assigning a
            # new ``features`` module on this wrapper would leave the stock
            # stem in use, because ``forward`` only calls ``self._model``.
            # NOTE(review): this changes the effective DenseNet architecture
            # and removes the unused ``features.*`` entries from this module's
            # state dict, so DenseNet checkpoints saved before this fix will
            # not load.
            stem = OrderedDict([
                ('conv0', torch.nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False)),
                ('norm0', torch.nn.BatchNorm2d(64)),
                ('relu0', torch.nn.ReLU(inplace=True)),
                ('pool0', torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)),
            ])
            for layer_name, layer in stem.items():
                # Re-assigning an existing child keeps its position in the
                # ``Sequential``, so the layer order is unchanged.
                setattr(self._model.features, layer_name, layer)

    def forward(self, x):
        """Return the output of the wrapped network (final linear layer removed)."""
        return self._model(x)


class FrequentistNet(torch.nn.Module):
    """Deterministic classifier: a ``BaseNet`` backbone followed by one linear layer.

    Args:
        num_classes: Number of output logits.
        model: Architecture name forwarded to ``BaseNet``.
        drop_rate: Accepted but not used by this class.
        in_channels: Number of input image channels, forwarded to ``BaseNet``.
    """

    def __init__(self, num_classes, model, drop_rate=0.2, in_channels=3):
        super().__init__()

        backbone = BaseNet(model, in_channels)
        self._base = backbone
        # The classifier width follows the feature size reported by the backbone.
        self._fc = torch.nn.Linear(
            in_features=backbone.fc_in_features,
            out_features=num_classes,
        )

    def forward(self, x, *args, **kwargs):
        """Return the logits for ``x``; any extra arguments are ignored."""
        features = self._base(x)
        return self._fc(features)


class GaussianDropoutNet(torch.nn.Module):
    """Classifier whose logits are passed through the ``gaussian_dropout`` helper.

    The helper is imported from ``utils`` and is not visible here; it is
    called with the backbone features, the clean logits, the drop rate and the
    final linear layer.

    Args:
        num_classes: Number of output logits.
        model: Architecture name forwarded to ``BaseNet``.
        drop_rate: Stored as ``p`` and handed to ``gaussian_dropout``.
        in_channels: Number of input image channels, forwarded to ``BaseNet``.
    """

    def __init__(self, num_classes, model, drop_rate=0.2, in_channels=3):
        super().__init__()

        backbone = BaseNet(model, in_channels)
        self._base = backbone
        self._fc = torch.nn.Linear(
            in_features=backbone.fc_in_features,
            out_features=num_classes,
        )

        self.p = drop_rate
        # Number of draws returned by ``mc_gaussian_dropout``.
        self.N = 25

    def _draw_logits(self, x):
        """Run one full forward pass and return one ``gaussian_dropout`` output."""
        features = self._base(x)
        clean_logits = self._fc(features)
        return gaussian_dropout(features, clean_logits, self.p, self._fc)

    def mc_gaussian_dropout(self, x):
        """Return the draws stacked along a new leading dimension.

        At least one draw is always produced, also when ``N`` is below 1.
        """
        draws = [self._draw_logits(x)]
        for _ in range(self.N - 1):
            draws.append(self._draw_logits(x))
        return torch.stack(draws, dim=0)

    def forward(self, x, mc=False, *args, **kwargs):
        """Return one draw, or the stacked draws when ``mc`` is true."""
        if not mc:
            return self._draw_logits(x)
        return self.mc_gaussian_dropout(x)


class MonteCarloDropoutNet(torch.nn.Module):
    """Classifier with dropout applied to the features before the last layer.

    Dropout is called with ``training=True``, so it stays active regardless of
    the module's train/eval mode.

    Args:
        num_classes: Number of output logits.
        model: Architecture name forwarded to ``BaseNet``.
        drop_rate: Dropout probability, stored as ``p``.
        in_channels: Number of input image channels, forwarded to ``BaseNet``.
    """

    def __init__(self, num_classes, model, drop_rate=0.2, in_channels=3):
        super().__init__()

        self._base = BaseNet(model, in_channels)
        self._fc = torch.nn.Linear(in_features=self._base.fc_in_features, out_features=num_classes)

        # Number of Monte Carlo samples returned by ``mc_dropout``.
        self.N = 25
        self.p = drop_rate

    def _dropout_logits(self, features):
        """Apply always-on dropout to ``features`` and return the logits."""
        dropped = torch.nn.functional.dropout(features, p=self.p, training=True)
        return self._fc(dropped)

    def mc_dropout(self, x):
        """Return Monte Carlo dropout logits stacked along a new first dimension.

        Only the dropout mask is random here, so the backbone is evaluated
        once and its features are reused for every sample instead of
        repeating the full forward pass per sample. At least one sample is
        always drawn, also when ``N`` is below 1.
        """
        features = self._base(x)
        num_samples = max(self.N, 1)
        # Collect the samples and stack once rather than growing a tensor with
        # ``torch.cat`` on every iteration.
        samples = [self._dropout_logits(features) for _ in range(num_samples)]
        return torch.stack(samples, dim=0)

    def forward(self, x, mc=False, *args, **kwargs):
        """Return one dropout sample of the logits, or the stacked samples if ``mc``."""
        if mc:
            return self.mc_dropout(x)
        return self._dropout_logits(self._base(x))


class BBBNet(torch.nn.Module):
    """Classifier that samples its logits from a predicted Gaussian.

    Two linear heads on top of the ``BaseNet`` features produce a mean and a
    log-variance per logit; the returned logits are drawn with
    ``reparameterize``.

    Args:
        num_classes: Number of output logits.
        model: Architecture name forwarded to ``BaseNet``.
        drop_rate: Stored as ``p``; not used by the methods of this class.
        in_channels: Number of input image channels, forwarded to ``BaseNet``.
    """

    def __init__(self, num_classes, model, drop_rate=0.2, in_channels=3):
        super().__init__()

        self._base = BaseNet(model, in_channels)
        self._fc_mu = torch.nn.Linear(in_features=self._base.fc_in_features, out_features=num_classes)
        self._fc_log_var = torch.nn.Linear(in_features=self._base.fc_in_features, out_features=num_classes)

        # Number of Monte Carlo samples returned by ``mc_bbb``.
        self.N = 25
        self.p = drop_rate

    @staticmethod
    def reparameterize(mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps`` standard normal."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def mc_bbb(self, x):
        """Return sampled logits stacked along a new first dimension.

        The backbone features, the mean and the log-variance do not change
        between samples, so they are computed once; only the noise is redrawn.
        At least one sample is always drawn, also when ``N`` is below 1.
        """
        features = self._base(x)
        mu = self._fc_mu(features)
        log_var = self._fc_log_var(features)
        num_samples = max(self.N, 1)
        # Stack once instead of growing a tensor with ``torch.cat`` per sample.
        samples = [self.reparameterize(mu, log_var) for _ in range(num_samples)]
        return torch.stack(samples, dim=0)

    def forward(self, x, mc=False, *args, **kwargs):
        """Return one sample of the logits, or the stacked samples if ``mc``."""
        if mc:
            return self.mc_bbb(x)
        features = self._base(x)
        return self.reparameterize(self._fc_mu(features), self._fc_log_var(features))


class SWAGNet(torch.nn.Module):
    """Classifier whose last linear layer can be sampled from a diagonal Gaussian.

    ``train_swag`` records the last-layer weights and biases after every
    optimizer step of one epoch and stores their element-wise mean and
    variance. ``sample`` draws a weight/bias pair from those statistics.

    Args:
        num_classes: Number of output logits.
        model: Architecture name forwarded to ``BaseNet``.
        in_channels: Number of input image channels, forwarded to ``BaseNet``.
    """

    def __init__(self, num_classes, model, in_channels=3):
        super().__init__()

        self._base = BaseNet(model, in_channels)
        self._fc = torch.nn.Linear(in_features=self._base.fc_in_features, out_features=num_classes)

        # Register the statistics as buffers so that ``.to(device)`` and dtype
        # casts move them together with the layer they describe. They are
        # non-persistent to keep the keys of ``state_dict()`` unchanged.
        self.register_buffer('_fc_w_mu', torch.zeros_like(self._fc.weight.data), persistent=False)
        self.register_buffer('_fc_b_mu', torch.zeros_like(self._fc.bias.data), persistent=False)
        self.register_buffer('_fc_w_var', torch.ones_like(self._fc.weight.data), persistent=False)
        self.register_buffer('_fc_b_var', torch.ones_like(self._fc.bias.data), persistent=False)

        # Number of Monte Carlo samples returned by ``mc_swag``.
        self.N = 25

    def train_swag(self, train_loader, optimizer_net, device):
        """
        This function performs minibatch SWA-Gaussian-diagonal for one epoch.

        Args:
            train_loader: Iterable of ``(data, target)`` batches.
            optimizer_net: Optimizer that is stepped once per batch.
            device: Device the batches are moved to.
        """
        weight_snapshots = []
        bias_snapshots = []
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer_net.zero_grad()

            # The backbone is evaluated without gradients; only the last layer
            # receives a gradient from this loss.
            with torch.no_grad():
                fc = self._base(data)

            logits = self._fc(fc)

            loss = torch.nn.functional.cross_entropy(logits, target)
            loss.backward()
            optimizer_net.step()
            # ``detach()`` alone returns a view of the live parameter, which
            # the optimizer updates in place; without ``clone()`` every stored
            # snapshot would end up equal to the final weights.
            weight_snapshots.append(self._fc.weight.detach().clone())
            bias_snapshots.append(self._fc.bias.detach().clone())

        weights = torch.stack(weight_snapshots, dim=0)
        biases = torch.stack(bias_snapshots, dim=0)

        # average weights
        self._fc_w_mu = torch.mean(weights, dim=0)
        self._fc_b_mu = torch.mean(biases, dim=0)
        # Var[w] = E[w^2] - E[w]^2. Rounding can push this slightly below
        # zero, which would turn into NaN in the log taken by ``sample``.
        self._fc_w_var = (torch.mean(weights**2, dim=0) - self._fc_w_mu**2).clamp(min=0)
        self._fc_b_var = (torch.mean(biases**2, dim=0) - self._fc_b_mu**2).clamp(min=0)

        # updating batch norm statistics not necessary as only the last layer is implemented using SWAG

    @staticmethod
    def reparameterize(mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps`` standard normal."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def sample(self):
        """Draw a ``(weight, bias)`` pair for the last layer from the stored statistics."""
        # The small constant keeps the log finite when a variance is zero.
        weight = self.reparameterize(self._fc_w_mu, (self._fc_w_var + 1e-6).log())
        bias = self.reparameterize(self._fc_b_mu, (self._fc_b_var + 1e-6).log())
        return weight, bias

    def _sampled_logits(self, features):
        """Return the logits of ``features`` under one freshly sampled last layer."""
        weight, bias = self.sample()
        return torch.nn.functional.linear(features, weight, bias)

    def mc_swag(self, x):
        """Return logits for sampled last layers, stacked along a new first dimension.

        Only the last layer is sampled, so the backbone features are computed
        once and reused for every sample. At least one sample is always
        drawn, also when ``N`` is below 1.
        """
        features = self._base(x)
        num_samples = max(self.N, 1)
        samples = [self._sampled_logits(features) for _ in range(num_samples)]
        return torch.stack(samples, dim=0)

    def forward(self, x, train=True, mc=False, *args, **kwargs):
        """Return logits for ``x``.

        With ``train`` true (the default) the trained last layer is used. With
        ``train`` false a sampled last layer is used: one sample, or the
        stacked samples when ``mc`` is true.
        """
        if train:
            return self._fc(self._base(x))
        if mc:
            return self.mc_swag(x)
        return self._sampled_logits(self._base(x))
