import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models


class BasicBlock(nn.Module):
    """Pre-activation residual unit: BN -> ReLU -> 3x3 conv, twice, plus a shortcut.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the first 3x3 convolution and, when one is created,
            of the 1x1 shortcut convolution.
        dropRate: dropout probability applied between the two convolutions;
            values <= 0 disable dropout.

    The shortcut is the identity whenever ``in_planes == out_planes``
    (regardless of ``stride``); otherwise it is a bias-free 1x1 convolution
    stored in ``convShortcut``.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        if self.equalInOut:
            # Identity shortcut: no projection module is registered.
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)`` for the input feature map ``x``."""
        # Both branches of the block start from the same BN+ReLU activation.
        activated = self.relu1(self.bn1(x))
        residual = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            residual = F.dropout(residual, p=self.droprate, training=self.training)
        residual = self.conv2(residual)
        if self.equalInOut:
            # Identity shortcut uses the raw (pre-activation) input.
            shortcut = x
        else:
            # Projection shortcut is applied to the activated input.
            shortcut = self.convShortcut(activated)
        return torch.add(shortcut, residual)

class NetworkBlock(nn.Module):
    """Sequential stack of ``nb_layers`` blocks built by the ``block`` factory.

    Args:
        nb_layers: number of blocks; converted with ``int()`` so a float such
            as ``2.0`` is accepted.
        in_planes: input channels of the first block.
        out_planes: output channels of every block (and input channels of all
            blocks after the first).
        block: callable invoked as ``block(in_planes, out_planes, stride, dropRate)``
            that returns an ``nn.Module``.
        stride: stride handed to the first block; later blocks get stride 1.
        dropRate: dropout rate forwarded unchanged to every block.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Build the ``nn.Sequential`` holding the blocks, in order."""
        layers = []
        for i in range(int(nb_layers)):
            is_first = (i == 0)
            # Real conditional expressions: the caller's first-layer values are
            # forwarded as given, even when they are falsy (e.g. 0), so an
            # invalid configuration surfaces in the block instead of being
            # silently replaced.
            layers.append(block(in_planes if is_first else out_planes,
                                out_planes,
                                stride if is_first else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through every block in sequence."""
        return self.layer(x)

class WideResNet(nn.Module):
    """Wide residual network: a 3x3 stem, three stages of ``BasicBlock``s, and a linear head.

    Args:
        depth: total depth; must satisfy ``(depth - 4) % 6 == 0``. Each of the
            three stages contains ``(depth - 4) // 6`` blocks.
        num_classes: size of the classifier output.
        widen_factor: multiplier applied to the stage widths 16 / 32 / 64.
        dropRate: dropout rate forwarded to every block.
        output_mode: ``'logprobs'`` applies ``log_softmax`` over the class
            dimension; any other value (e.g. ``'logits'``) returns the raw
            output of the final linear layer.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0,
                 output_mode='logprobs'):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert((depth - 4) % 6 == 0)
        # Integer division: the assertion above guarantees an exact result,
        # and the per-stage block count should be an int, not a float.
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]
        self.output_mode = output_mode

        # Weight initialisation: Kaiming-normal convolutions, unit-scale /
        # zero-shift batch norms, zero bias on linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x, output_all_layers=False):
        """Classify ``x``.

        Args:
            x: image batch with 3 channels. The fixed 8x8 average pool followed
                by ``view(-1, nChannels)`` expects the final feature map to be
                8x8 (e.g. 32x32 inputs); with a larger map the leftover spatial
                positions would be folded into the batch dimension.
            output_all_layers: when true, also return the intermediate
                activations.

        Returns:
            The class scores (log-probabilities or raw logits depending on
            ``output_mode``), or ``(scores, hidden)`` when ``output_all_layers``
            is true. ``hidden`` holds six tensors: the outputs of the stem
            convolution and of the three stages, the final BN+ReLU feature map,
            and the pooled, flattened feature vector.
        """
        out = x
        all_hidden_layers = []
        for module in (self.conv1, self.block1, self.block2, self.block3):
            out = module(out)
            all_hidden_layers.append(out)
        out = self.relu(self.bn1(out))
        all_hidden_layers.append(out)
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        all_hidden_layers.append(out)
        out = self.fc(out)
        if self.output_mode == 'logprobs':
            # ``out`` is always (batch, num_classes) here; naming the class
            # dimension avoids the deprecated implicit-dim code path.
            out = F.log_softmax(out, dim=1)
        if output_all_layers:
            return out, all_hidden_layers
        return out

class WideResNetEnsemble(nn.Module):
    """Ensemble of independently initialised ``WideResNet`` members.

    Args:
        depth, num_classes, widen_factor, dropRate: forwarded to every member.
        num_ensemble: number of members to create.

    Members are built with their default ``output_mode`` and are therefore
    expected to return log-probabilities.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0,
                 num_ensemble=8):
        super(WideResNetEnsemble, self).__init__()
        self.models = nn.ModuleList()
        self.num_ensemble = num_ensemble
        for _ in range(num_ensemble):
            model = WideResNet(depth, num_classes, widen_factor=widen_factor, dropRate=dropRate)
            self.models.append(model)

    def forward(self, x):
        """Return the log of the members' averaged class probabilities.

        Computed as ``logsumexp(log_probs) - log(M)``, which equals
        ``log(mean(exp(log_probs)))`` but stays finite when a member assigns
        a probability small enough to underflow after ``exp``.
        """
        member_log_probs = torch.stack([model(x) for model in self.models])
        num_members = member_log_probs.shape[0]
        return torch.logsumexp(member_log_probs, dim=0) - math.log(num_members)




class TempCNN(nn.Module):
    """Small CNN that predicts a bounded scalar (``min_temp`` + ReLU6 of a linear head).

    Args:
        width: number of channels of every convolution.
        num_layers: number of 3x3 convolutions (each followed by ReLU in ``forward``).
        in_channels: channels of the input image.
        num_classes: number of outputs of the linear head.
        min_temp: lower bound added to the ReLU6-clipped head output, so the
            result lies in ``[min_temp, min_temp + 6]``.
        train_fc_only: when true, gradients are disabled for all convolutions.
    """

    def __init__(self, width, num_layers, in_channels=3, num_classes=1,
                 min_temp=1.0, train_fc_only=False):
        super(TempCNN, self).__init__()
        # Channel count seen at the input of each conv, plus the final width.
        channel_plan = [in_channels] + [width] * num_layers
        convs = [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
                 for c_in, c_out in zip(channel_plan, channel_plan[1:])]
        if train_fc_only:
            for conv in convs:
                conv.requires_grad_(False)
        self.conv_layers = nn.ModuleList(convs)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.fc = nn.Linear(width*7*7, num_classes)
        self.min_temp = min_temp
        # Bias start value: with a zero weight contribution the output is
        # min_temp + relu6(1 - min_temp).
        with torch.no_grad():
            self.fc.bias.fill_(1.0 - self.min_temp)

    def forward(self, x, **kwargs):
        """Return ``min_temp + relu6(fc(features))``; extra keyword arguments are ignored."""
        hidden = x
        for conv in self.conv_layers:
            hidden = F.relu(conv(hidden))
        pooled = torch.flatten(self.avgpool(hidden), 1)
        return self.min_temp + F.relu6(self.fc(pooled))

class TempConst(nn.Module):
    """Input-independent scalar: one learnable value broadcast over the batch.

    Args:
        min_temp: lower bound of the output; the learnable offset starts at
            ``1 - min_temp``.
    """

    def __init__(self, min_temp=1.0):
        super(TempConst, self).__init__()
        self.min_temp = min_temp
        self.temp = nn.Parameter(torch.tensor([1.0 - self.min_temp]))

    def forward(self, x):
        """Return a ``(x.shape[0], 1)`` tensor filled with ``min_temp + relu6(temp)``.

        Only the batch size of ``x`` is used; its values are ignored.
        """
        batch_size = x.shape[0]
        temperature = self.min_temp + F.relu6(self.temp)
        ones = torch.ones((batch_size, 1), device=self.temp.device)
        return temperature * ones


class TempResNet(nn.Module):
    """torchvision ResNet-18 with a single-output head, producing ``min_temp + relu6(.)``.

    Args:
        min_temp: lower bound added to the ReLU6-clipped network output.
        pretrained: forwarded to ``torchvision.models.resnet18``. When true the
            backbone is built with 1000 classes and its ``fc`` layer is then
            replaced by a fresh ``nn.Linear(512, 1)``.
        train_fc_only: when true, gradients are enabled only for the ``fc`` layer.
        zero_init_residual: forwarded to ``torchvision.models.resnet18``.

    NOTE(review): newer torchvision releases deprecate the ``pretrained=``
    argument in favour of ``weights=`` -- confirm against the pinned version.
    """

    def __init__(self, min_temp=1.0, pretrained=False, train_fc_only=False, zero_init_residual=False):
        super(TempResNet, self).__init__()
        self.min_temp = min_temp
        backbone_classes = 1000 if pretrained else 1
        self.model = torchvision.models.resnet18(
            pretrained=pretrained,
            num_classes=backbone_classes,
            zero_init_residual=zero_init_residual,
        )
        if pretrained:
            # Swap the 1000-way classifier for a scalar head.
            self.model.fc = nn.Linear(512, 1)
        with torch.no_grad():
            self.model.fc.bias.fill_(1.0 - self.min_temp)
        if train_fc_only:
            # Freeze everything, then re-enable gradients for the head only.
            self.model.requires_grad_(False)
            head = self.model.fc
            head.requires_grad_(True)

    def forward(self, x):
        """Return ``min_temp + relu6(model(x))``."""
        raw = self.model(x)
        return self.min_temp + F.relu6(raw)


class TempLinearOnReps(nn.Module):
    """Bounded scalar predicted by linear heads on selected representations.

    Args:
        inds: indices into the ``all_layers`` sequence passed to ``forward``;
            ``inds[i]`` selects the representation fed to the i-th head.
            Defaults to an empty list.
        in_dims: flattened feature size of each selected representation; one
            ``nn.Linear(in_dim, 1)`` head is created per entry. Defaults to an
            empty list.
        min_temp: lower bound added to the ReLU6-clipped average head output.
    """

    def __init__(self, inds=None, in_dims=None, min_temp=1.0):
        super(TempLinearOnReps, self).__init__()
        # ``None`` sentinels instead of mutable defaults, so instances built
        # without arguments do not share one list object.
        inds = [] if inds is None else inds
        in_dims = [] if in_dims is None else in_dims
        self.lin_layers = nn.ModuleList([nn.Linear(in_dim, 1) for in_dim in in_dims])
        self.min_temp = min_temp
        self.inds = inds
        with torch.no_grad():
            for lin_layer in self.lin_layers:
                lin_layer.bias.fill_(1.0 - self.min_temp)

    def forward(self, all_layers):
        """Return ``min_temp + relu6(mean_i head_i(flatten(all_layers[inds[i]])))``.

        Each selected representation is flattened from dimension 1 onwards.
        ``inds`` must be non-empty: the average divides by ``len(inds)``.
        """
        out = 0.
        for i, layer_index in enumerate(self.inds):
            out = out + self.lin_layers[i](all_layers[layer_index].flatten(start_dim=1))
        return self.min_temp + F.relu6(out / len(self.inds))


class TempNNOnReps(nn.Module):
    """Bounded scalar predicted by small MLPs on selected representations.

    Args:
        inds: indices into the ``all_layers`` sequence passed to ``forward``;
            ``inds[i]`` selects the representation fed to the i-th MLP.
            Defaults to an empty list.
        in_dims: flattened feature size of each selected representation; one
            MLP is created per entry. Defaults to an empty list.
        depth: number of linear layers per MLP; must be at least 2.
        width: hidden width of every MLP.
        min_temp: lower bound added to the ReLU6-clipped average MLP output.
    """

    def __init__(self, inds=None, in_dims=None, depth=2, width=512, min_temp=1.0):
        super(TempNNOnReps, self).__init__()
        # ``None`` sentinels instead of mutable defaults, so instances built
        # without arguments do not share one list object.
        inds = [] if inds is None else inds
        in_dims = [] if in_dims is None else in_dims
        nets = []
        assert(depth >= 2)
        for in_dim in in_dims:
            # Hidden layers are created before the input/output layers; this
            # fixes the order in which random initial weights are drawn.
            mid_layers = []
            for _ in range(depth - 2):
                mid_layers += [nn.Linear(width, width), nn.ReLU()]
            nets.append(nn.Sequential(
                nn.Linear(in_dim, width),
                nn.ReLU(),
                *mid_layers,
                nn.Linear(width, 1)
            ))
        self.nets = nn.ModuleList(nets)
        self.min_temp = min_temp
        self.inds = inds
        self.depth, self.width = depth, width
        with torch.no_grad():
            for net in self.nets:
                net[-1].bias.fill_(1.0 - self.min_temp)

    def forward(self, all_layers):
        """Return ``min_temp + relu6(mean_i net_i(flatten(all_layers[inds[i]])))``.

        Each selected representation is flattened from dimension 1 onwards.
        ``inds`` must be non-empty: the average divides by ``len(inds)``.
        """
        out = 0.
        for i, layer_index in enumerate(self.inds):
            out = out + self.nets[i](all_layers[layer_index].flatten(start_dim=1))
        return self.min_temp + F.relu6(out / len(self.inds))


class CalibrationMatrixScaling(nn.Module):
    """Calibration by a learned affine map of the logits followed by ``log_softmax``.

    Args:
        num_classes: number of classes (size of the last logits dimension).

    The weight matrix starts as the identity.
    NOTE(review): the bias keeps ``nn.Linear``'s default random initialisation,
    so the initial map is not exactly the identity -- confirm this is intended.
    """

    def __init__(self, num_classes=100):
        super(CalibrationMatrixScaling, self).__init__()
        self.scale_layer = nn.Linear(num_classes, num_classes)
        with torch.no_grad():
            self.scale_layer.weight.copy_(torch.eye(num_classes))

    def forward(self, logits, all_layers):
        """Return ``(log_probs, None)``; ``all_layers`` is accepted but unused.

        The softmax is taken over the last dimension, which is the class
        dimension produced by the linear layer.
        """
        return F.log_softmax(self.scale_layer(logits), dim=-1), None


class CalibrationMonotone(nn.Module):
    """Input-independent monotone calibration map applied element-wise to logits.

    The map is the mean over ``num_temps`` units of
    ``leaky_relu(logit - bias_k) / temp_k`` with ``temp_k = min_temp + relu6(raw_k)``.

    Args:
        min_temp: lower bound of every temperature.
        num_temps: number of (temperature, bias) units.
        neg_slope: negative slope of the leaky ReLU.
        temp_init_increment: spacing of the initial temperatures; before
            clipping, unit ``k`` starts at ``1 + (k - 1) * temp_init_increment``.
        output_mode: ``'logprobs'`` applies ``log_softmax`` over dimension 1;
            any other value returns the transformed logits.
    """

    def __init__(self, min_temp=1.0,
                 num_temps=5, neg_slope=0.5, temp_init_increment=0.5,
                 output_mode='logprobs'):
        super(CalibrationMonotone, self).__init__()
        self.temps = nn.Parameter(torch.zeros(num_temps))
        self.biases = nn.Parameter(torch.zeros(num_temps))
        self.min_temp = min_temp
        self.neg_slope = neg_slope
        self.output_mode = output_mode
        # Staggered start values, stepping by ``temp_init_increment`` as the
        # sibling calibration classes do. ``min_temp`` is subtracted because
        # ``forward`` adds it back after the ReLU6.
        init_values = []
        base_temp = 1.0 - temp_init_increment
        for _ in range(num_temps):
            init_values.append(base_temp - self.min_temp)
            base_temp += temp_init_increment
        with torch.no_grad():
            self.temps.copy_(torch.tensor(init_values))

    def forward(self, logits, all_layers):
        """Return ``(calibrated, temps)``; ``all_layers`` is accepted but unused.

        ``logits`` is treated as ``B x C``; ``temps`` is returned with shape
        ``1 x 1 x num_temps``.
        """
        # temps = 1 x 1 x num_temps, biases = 1 x 1 x num_temps, logits = B x C x 1
        temps = self.min_temp + F.relu6(self.temps.view([1, 1, -1]))
        biases = self.biases.view([1, 1, -1])
        shifted = logits.view([*logits.shape, 1]) - biases
        # one-hidden-layer monotone calibration model
        out = (F.leaky_relu(shifted, negative_slope=self.neg_slope) / temps).mean(dim=2)
        if self.output_mode == 'logprobs':
            return F.log_softmax(out, dim=1), temps
        return out, temps


class CalibrationNNOnReps(nn.Module):
    """Monotone calibration map whose temperatures and biases depend on representations.

    For every selected representation a shared MLP trunk feeds two linear
    heads that emit ``num_temps`` raw temperatures and ``num_temps`` biases.
    The head outputs are averaged over the selected representations and
    parameterise a one-hidden-layer monotone map applied to the logits.

    Args:
        inds: indices into the ``all_layers`` sequence passed to ``forward``.
            Defaults to an empty list.
        in_dims: flattened feature size of each selected representation; one
            trunk and one pair of heads are created per entry. Defaults to an
            empty list.
        depth: ``depth - 1`` linear layers form each trunk; must be at least 2.
        width: hidden width of the trunks.
        min_temp: lower bound of every temperature.
        num_temps: number of (temperature, bias) units.
        neg_slope: negative slope used by the ``'leaky_relu'`` activation.
        temp_init_increment: step between the initial bias values of
            successive temperature heads.
        output_mode: ``'logprobs'`` applies ``log_softmax`` over dimension 1;
            any other value returns the transformed logits.
        activation: ``'leaky_relu'`` or ``'softplus'``.
    """

    def __init__(self, inds=None, in_dims=None, depth=2, width=512, min_temp=1.0,
                 num_temps=5, neg_slope=0.5, temp_init_increment=0.5, output_mode='logprobs',
                 activation='leaky_relu'):
        super(CalibrationNNOnReps, self).__init__()
        # ``None`` sentinels instead of mutable defaults, so instances built
        # without arguments do not share one list object.
        inds = [] if inds is None else inds
        in_dims = [] if in_dims is None else in_dims
        reps, temp_heads, bias_heads = [], [], []
        assert(depth >= 2)
        for in_dim in in_dims:
            # Creation order (hidden layers, input layer, temp head, bias head)
            # fixes the order in which random initial weights are drawn.
            mid_layers = []
            for _ in range(depth - 2):
                mid_layers += [nn.Linear(width, width), nn.ReLU()]
            reps.append(nn.Sequential(
                nn.Linear(in_dim, width),
                nn.ReLU(),
                *mid_layers
            ))
            temp_heads.append(nn.Linear(width, num_temps))
            bias_heads.append(nn.Linear(width, num_temps))
        self.reps = nn.ModuleList(reps)
        self.temp_heads, self.bias_heads = nn.ModuleList(temp_heads), nn.ModuleList(bias_heads)
        self.min_temp = min_temp
        self.inds = inds
        self.depth, self.width = depth, width
        self.neg_slope = neg_slope
        self.output_mode = output_mode
        self.activation = activation
        with torch.no_grad():
            # NOTE(review): the start value is stepped per *head* (one head per
            # representation), so all ``num_temps`` entries of a head share the
            # same initial bias. The sibling classes stagger per temperature
            # unit instead -- confirm which behaviour is intended.
            base_temp = 1.0 - temp_init_increment
            for th in self.temp_heads:
                th.bias.fill_(base_temp - self.min_temp)
                base_temp += temp_init_increment

    def forward(self, logits, all_layers):
        """Return ``(calibrated, temps)``.

        Args:
            logits: treated as ``B x C``.
            all_layers: sequence of representations indexed by ``self.inds``;
                each one is flattened from dimension 1 onwards.

        Returns:
            The calibrated scores (log-probabilities when ``output_mode`` is
            ``'logprobs'``) and the temperatures of shape ``B x 1 x num_temps``.

        Raises:
            ValueError: if ``self.activation`` is not a supported name.
        """
        temps, biases = 0., 0.
        for i, layer_index in enumerate(self.inds):
            reps = self.reps[i](all_layers[layer_index].flatten(start_dim=1))
            temps = temps + self.temp_heads[i](reps)
            biases = biases + self.bias_heads[i](reps)
        num_reps = len(self.inds)
        # temps = B x 1 x num_temps, biases = B x 1 x num_temps, logits = B x C x 1
        temps = self.min_temp + F.relu6(temps.view([temps.shape[0], 1, -1]) / num_reps)
        biases = biases.view([biases.shape[0], 1, -1]) / num_reps
        shifted = logits.view([*logits.shape, 1]) - biases
        # one-hidden-layer monotone calibration model
        if self.activation == 'leaky_relu':
            out = (F.leaky_relu(shifted, negative_slope=self.neg_slope) / temps).mean(dim=2)
        elif self.activation == 'softplus':
            out = F.softplus(shifted / temps).mean(dim=2)
        else:
            raise ValueError(
                "unsupported activation {!r}; expected 'leaky_relu' or 'softplus'".format(self.activation)
            )
        if self.output_mode == 'logprobs':
            return F.log_softmax(out, dim=1), temps
        return out, temps


class CalibrationNNOnRepsUnshared(nn.Module):
    """Monotone calibration map with one independent MLP per temperature unit.

    Each of the ``num_temps`` MLPs reads the same single representation and
    emits two numbers: a raw temperature (output 0) and a bias (output 1).

    Args:
        inds: list with exactly one index into the ``all_layers`` sequence
            passed to ``forward``.
        in_dims: list with exactly one entry, the flattened feature size of
            that representation.
        depth: number of linear layers per MLP; must be at least 2.
        width: hidden width of every MLP.
        min_temp: lower bound of every temperature.
        num_temps: number of (temperature, bias) units, i.e. of MLPs.
        neg_slope: negative slope of the leaky ReLU.
        temp_init_increment: step between the initial temperature biases of
            successive MLPs.
        output_mode: ``'logprobs'`` applies ``log_softmax`` over dimension 1;
            any other value returns the transformed logits.

    Raises:
        AssertionError: if ``depth < 2`` or ``inds`` / ``in_dims`` do not each
            hold exactly one element (which includes leaving them unset).
    """

    def __init__(self, inds=None, in_dims=None, depth=2, width=512, min_temp=1.0,
                 num_temps=5, neg_slope=0.5, temp_init_increment=0.5, output_mode='logprobs'):
        super(CalibrationNNOnRepsUnshared, self).__init__()
        # ``None`` sentinels instead of mutable default lists.
        inds = [] if inds is None else inds
        in_dims = [] if in_dims is None else in_dims
        assert(depth >= 2)
        assert(len(inds) == 1 and len(in_dims) == 1)
        nets = []
        for _ in range(num_temps):
            # Hidden layers are created before the input/output layers; this
            # fixes the order in which random initial weights are drawn.
            mid_layers = []
            for _ in range(depth - 2):
                mid_layers += [nn.Linear(width, width), nn.ReLU()]
            nets.append(nn.Sequential(
                nn.Linear(in_dims[0], width),
                nn.ReLU(),
                *mid_layers,
                nn.Linear(width, 2)
            ))
        self.nets = nn.ModuleList(nets)
        self.min_temp = min_temp
        self.inds = inds
        self.depth, self.width = depth, width
        self.neg_slope = neg_slope
        self.output_mode = output_mode
        with torch.no_grad():
            # Stagger the initial temperature bias (output 0) across the MLPs;
            # ``min_temp`` is subtracted because ``forward`` adds it back.
            base_temp = 1.0 - temp_init_increment
            for net in self.nets:
                net[-1].bias[0].fill_(base_temp - self.min_temp)
                base_temp += temp_init_increment

    def forward(self, logits, all_layers):
        """Return ``(calibrated, temps)``.

        Args:
            logits: treated as ``B x C``.
            all_layers: sequence of representations; only
                ``all_layers[self.inds[0]]`` is used, flattened from
                dimension 1 onwards.

        Returns:
            The calibrated scores (log-probabilities when ``output_mode`` is
            ``'logprobs'``) and the temperatures of shape ``B x 1 x num_temps``.
        """
        # The representation is the same for every MLP, so flatten it once.
        flat_rep = all_layers[self.inds[0]].flatten(start_dim=1)
        # stacked = B x 2 x num_temps (row 0: raw temperatures, row 1: biases)
        stacked = torch.stack([net(flat_rep) for net in self.nets], dim=2)
        # temps = B x 1 x num_temps, biases = B x 1 x num_temps, logits = B x C x 1
        temps = self.min_temp + F.relu6(stacked[:, 0:1, :])
        biases = stacked[:, 1:2, :]
        shifted = logits.view([*logits.shape, 1]) - biases
        # one-hidden-layer monotone calibration model
        out = (F.leaky_relu(shifted, negative_slope=self.neg_slope) / temps).mean(dim=2)
        if self.output_mode == 'logprobs':
            return F.log_softmax(out, dim=1), temps
        return out, temps
