import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import Parameter
import numpy as np

# Device string shared by the layers in this file: the GPU when CUDA is usable,
# otherwise the CPU.
if torch.cuda.is_available():
    dev = 'cuda'
else:
    dev = 'cpu'


class HHtrans(nn.Module):
    """Apply a Householder reflection built from ``v`` to the rows of ``s``.

    The reflection matrix is ``H = I - 2 * v v^T / ||v||^2``.

    Args:
        v0_option: If truthy, ``v`` is treated as a batch of vectors and one
            reflection is built per row (batched mode). Otherwise a single
            reflection is built from ``v`` and shared by every row of ``s``.
    """

    def __init__(self, v0_option):
        # Initialise the nn.Module machinery before assigning any attribute.
        super(HHtrans, self).__init__()
        self.v0_option = v0_option

    def forward(self, i, v, s, H):
        """Build reflection number ``i`` and apply it to ``s``.

        Args:
            i: Index of the reflection; ``str(i)`` is the key written into ``H``.
            v: 2-D tensor whose second dimension ``K`` is the reflection size.
                In batched mode each row yields its own reflection. In
                non-batched mode the division by the per-row squared norm
                only makes sense for a single row (shape ``(1, K)``), which is
                how ``Linear.forward`` calls it.
            s: 2-D tensor of shape ``(batch, K)`` to be transformed.
            H: Dict that receives the reflection matrix under ``str(i)``
                (mutated in place).

        Returns:
            Tensor with the same shape as ``s`` holding ``H s`` for every row.
        """
        K = v.shape[1]
        # Build the identity next to ``v`` instead of on a global device so the
        # module also works when the model does not live on the default device.
        eye = torch.eye(K, dtype=v.dtype, device=v.device)
        norm_sq = torch.sum(v * v, 1)

        if self.v0_option:
            # One outer product per sample: (B, K, 1) x (B, 1, K) -> (B, K, K).
            vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))
            H[str(i)] = eye - 2 * vvT / norm_sq.view(-1, 1, 1)
            s_new = torch.bmm(H[str(i)], s.unsqueeze(2)).squeeze(2)
        else:
            # Single shared reflection: (K, 1) x (1, K) -> (K, K).
            vvT = torch.mm(v.T, v)
            H[str(i)] = eye - 2 * vvT / norm_sq
            s_new = torch.mm(H[str(i)], s.T).T

        return s_new

class Conv2d(nn.Conv2d):
    """2-D convolution whose input channels are scaled by learned Gaussian noise.

    Each input channel is multiplied by ``1 + r`` where ``r`` is zero-mean
    Gaussian noise with per-channel variance ``alpha = exp(log_alpha_)``.
    When ``num_HH > 0`` the noise is additionally passed through ``num_HH``
    input-dependent Householder reflections before being applied.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
            bias, groups: Forwarded to ``nn.Conv2d``.
        in_features: Flattened size (``C * H * W``) of one input sample; this is
            the input width of the first layer that produces the Householder
            vectors.
        out_features: Stored on the instance; not used by this class.
        num_HH: Number of Householder reflections. Defaults to the global
            ``args.num_HH``.
        alpha_init: Initial value of every ``alpha``. Defaults to
            ``args.droprate / (1 - args.droprate)`` from the global ``args``.
    """

    def __init__(self, in_channels, out_channels, in_features=1, out_features=1, kernel_size=3, stride=1,
                 padding=0, dilation=1, bias=True, groups=1, num_HH=None,
                 alpha_init=None):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

        # Resolve the defaults at construction time rather than at class
        # definition time, so importing this module does not require ``args``.
        if num_HH is None:
            num_HH = args.num_HH
        if alpha_init is None:
            alpha_init = args.droprate / (1 - args.droprate)

        log_alpha_init = (torch.ones(in_channels) * alpha_init).log()
        self.log_alpha_ = nn.Parameter(log_alpha_init)
        self.num_HH = num_HH
        self.in_channels = in_channels
        self.in_features = in_features
        self.out_features = out_features

        # Batched mode: forward() feeds one vector per sample and composes the
        # reflections with torch.bmm, and kl_reg() reads U as (M, K, K).
        self.HHTrans = HHtrans(True)
        self.v_layers = nn.ModuleList()

        self.v_layers.append(nn.Linear(in_features, in_channels))
        for i in range(0, self.num_HH - 1):
            self.v_layers.append(nn.Linear(in_channels, in_channels))

    def q_s_HHTrans(self, s, x, H):
        """Push the noise ``s['0']`` through ``num_HH`` Householder reflections.

        Args:
            s: Dict holding the initial noise under ``'0'``; the intermediate
                results are written under ``'1'`` .. ``str(num_HH)``.
            x: Input from which the Householder vectors are computed by
                ``v_layers`` followed by a leaky ReLU.
            H: Dict that receives each reflection matrix (mutated in place).

        Returns:
            The noise after the last reflection, or ``s['0']`` when
            ``num_HH`` is not positive.
        """
        v = {}
        if self.num_HH > 0:
            v['0'] = x
            for i in range(0, self.num_HH):
                v[str(i + 1)] = self.v_layers[i](v[str(i)])
                v[str(i + 1)] = F.leaky_relu(v[str(i + 1)])
                s[str(i + 1)] = self.HHTrans(i + 1, v[str(i + 1)], s[str(i)], H)
            return s[str(self.num_HH)]
        return s['0']

    def forward(self, x):
        """Convolve ``x`` after multiplying its channels by sampled noise.

        Also records the input's spatial size in ``self.H`` / ``self.W`` and,
        when ``num_HH > 0``, the product of the reflections in ``self.U``;
        ``kl_reg`` reads these afterwards.
        """
        self.H, self.W = x.shape[2], x.shape[3]
        alpha = self.log_alpha_.exp()
        H = {}
        # Sample on the input's device instead of a global default device.
        noise = torch.randn(x.shape[0], self.in_channels, device=x.device)
        r = {'0': torch.sqrt(alpha) * noise}

        if self.num_HH > 0:
            v0 = x.reshape(x.shape[0], -1)
            r_K = self.q_s_HHTrans(r, v0, H)
            scale = 1 + r_K
            # U = H_K @ H_{K-1} @ ... @ H_1, one product per sample.
            self.U = H[str(self.num_HH)]
            for i in reversed(range(1, self.num_HH)):
                self.U = torch.bmm(self.U, H[str(i)])
        else:
            scale = 1 + r['0']

        # Broadcast the per-(sample, channel) factor over the spatial dims.
        X_noised = x * scale.unsqueeze(-1).unsqueeze(-1)

        return F.conv2d(X_noised, self.weight,
                        self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def kl_reg(self):
        """Return the KL regulariser for the most recent forward pass.

        Must be called after ``forward``, which sets ``self.H``, ``self.W``
        and (with Householder reflections) ``self.U``.
        """
        alpha = self.log_alpha_.exp()

        if self.num_HH > 0:
            M, K = self.U.shape[:2]
            kl = torch.log((1 + torch.sum(alpha * self.U ** 2, dim=-1)) / alpha.unsqueeze(0).expand(M, K))
            kl = torch.sum(kl, dim=-1)
            kl = kl * self.H * self.W
            return kl.mean() / 2

        # No reflections: cubic-polynomial approximation in alpha.
        # NOTE(review): the constants look like the variational-dropout KL
        # approximation of Kingma et al. (2015) -- confirm.
        c1 = 1.16145124
        c2 = -1.50204118
        c3 = 0.58629921
        negative_kl = 0.5 * self.log_alpha_ + c1 * alpha + c2 * alpha ** 2 + c3 * alpha ** 3
        kl = -negative_kl * (self.H * self.W)
        return kl.sum()

class AlexNet(nn.Module):
    """AlexNet-style classifier built from the noisy ``Conv2d`` / ``Linear`` layers.

    The layer sizes depend on the global ``args.dataset``: ``'stl10'`` selects
    the configuration sized for 3x96x96 inputs, anything else the one sized
    for 3x32x32 inputs.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        if args.dataset == 'stl10':
            self.conv1 = Conv2d(3, 64, kernel_size=11, stride=4, padding=5,
                                in_features=3 * 96 * 96
                                , out_features=64 * 12 * 12)
            self.conv2 = Conv2d(64, 192, kernel_size=5, padding=2, in_features=64 * 12 * 12, out_features=192 * 6 * 6)
            self.conv3 = Conv2d(192, 384, kernel_size=3, padding=1, in_features=192 * 6 * 6, out_features=384 * 6 * 6,
                                )
            self.conv4 = Conv2d(384, 256, kernel_size=3, padding=1, in_features=384 * 6 * 6, out_features=256 * 6 * 6,
                                )
            self.conv5 = Conv2d(256, 256, kernel_size=3, padding=1, in_features=256 * 6 * 6, out_features=256 * 3 * 3,
                                )
            # 256 channels x 3 x 3 after the final pooling.
            hidden_size = 2304
        else:
            self.conv1 = Conv2d(3, 64, kernel_size=11, stride=4, padding=5,
                                in_features=3 * 32 * 32
                                , out_features=64 * 4 * 4)
            self.conv2 = Conv2d(64, 192, kernel_size=5, padding=2, in_features=64 * 4 * 4, out_features=192 * 2 * 2)
            self.conv3 = Conv2d(192, 384, kernel_size=3, padding=1, in_features=192 * 2 * 2, out_features=384 * 2 * 2)
            self.conv4 = Conv2d(384, 256, kernel_size=3, padding=1, in_features=384 * 2 * 2, out_features=256 * 2 * 2)
            self.conv5 = Conv2d(256, 256, kernel_size=3, padding=1, in_features=256 * 2 * 2, out_features=256 * 1 * 1)
            # 256 channels x 1 x 1 after the final pooling.
            hidden_size = 256

        # Linear takes the run configuration as its third positional argument;
        # omitting it raises TypeError.
        self.fc = Linear(hidden_size, num_classes, args)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        x = F.relu(self.conv1(x), inplace=True)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x), inplace=True)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv3(x), inplace=True)
        x = F.relu(self.conv4(x), inplace=True)
        x = F.relu(self.conv5(x), inplace=True)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        # Flatten everything except the batch dimension for the classifier.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def alexnet(**kwargs):
    r"""Build an :class:`AlexNet`, forwarding every keyword argument to it.

    Architecture reference: the `"One weird trick..."
    <https://arxiv.org/abs/1404.5997>`_ paper.
    """
    return AlexNet(**kwargs)


class Linear(nn.Module):
    """Fully connected layer whose inputs are scaled by learned Gaussian noise.

    Every input feature is multiplied by ``1 + r`` where ``r`` is zero-mean
    Gaussian noise with per-feature variance ``alpha = exp(log_alpha_)``.
    When ``args.method == 'ours'`` and ``args.num_HH > 0`` the noise is first
    passed through ``num_HH`` Householder reflections.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        args: Configuration object; the attributes read are ``v0_option``,
            ``num_HH``, ``lowrank``, ``method`` and (for the default
            ``alpha_init``) ``droprate``.
        v0_dim: Stored on the instance; not used by this class.
        alpha_init: Initial value of every ``alpha``. Defaults to
            ``args.droprate / (1 - args.droprate)`` computed from the ``args``
            passed to this constructor.
        last_number_channel: Stored on the instance; not used by this class.
    """

    def __init__(self, in_features, out_features, args, v0_dim=None, alpha_init=None,
                 last_number_channel=None):
        super(Linear, self).__init__()
        # Derive the default from the ``args`` argument; a default written in
        # the signature would be evaluated once, at definition time, against
        # whatever global ``args`` happens to exist.
        if alpha_init is None:
            alpha_init = args.droprate / (1 - args.droprate)

        self.in_features = in_features
        self.out_features = out_features
        self.args = args
        self.v0_dim = v0_dim
        self.last_number_channel = last_number_channel
        self.W = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(1, out_features))

        self.v0_option = args.v0_option

        log_alpha = (torch.ones(self.in_features) * alpha_init).log()
        self.log_alpha_ = nn.Parameter(log_alpha)

        self.num_HH = args.num_HH
        self.HHtrans = HHtrans(self.v0_option)

        # One network per reflection, each producing a Householder vector.
        self.v_layers = nn.ModuleList()

        for i in range(0, self.num_HH):
            if args.lowrank:
                self.v_layers.append(nn.Sequential(nn.Linear(in_features, 10),
                                                   nn.ReLU(),
                                                   nn.Linear(10, in_features)))
            else:
                self.v_layers.append(nn.Linear(in_features, in_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``W`` (Kaiming uniform) and ``bias`` (uniform in +-1/sqrt(fan_in))."""
        init.kaiming_uniform_(self.W, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def q_s_HHtrans(self, s, v0, H):
        """Push the noise ``s['0']`` through ``num_HH`` Householder reflections.

        Args:
            s: Dict holding the initial noise under ``'0'``; intermediate
                results are written under ``'1'`` .. ``str(num_HH)``.
            v0: Input from which the Householder vectors are computed by
                ``v_layers`` followed by a leaky ReLU.
            H: Dict that receives each reflection matrix (mutated in place).

        Returns:
            The noise after the last reflection, or ``s['0']`` when
            ``num_HH`` is not positive.
        """
        v = {}
        if self.num_HH > 0:
            v['0'] = v0
            for i in range(0, self.num_HH):
                v[str(i + 1)] = self.v_layers[i](v[str(i)])
                v[str(i + 1)] = F.leaky_relu(v[str(i + 1)])
                s[str(i + 1)] = self.HHtrans(i + 1, v[str(i + 1)], s[str(i)], H)
            return s[str(self.num_HH)]
        return s['0']

    def forward(self, x):
        """Return ``(x * noise) @ W.T + bias``.

        With Householder reflections enabled, the product of the reflections
        is stored in ``self.U`` for ``kl_reg``.
        """
        alpha = self.log_alpha_.exp()
        H = {}

        # Sample on the input's device instead of a global default device.
        r = {'0': torch.sqrt(alpha) * torch.randn(x.size(), device=x.device)}

        if self.args.method == 'ours' and self.num_HH > 0:
            if self.v0_option:
                # One Householder vector per sample.
                v0 = x
            else:
                # A single Householder vector from the batch mean.
                v0 = torch.mean(x, 0).unsqueeze(0)
            scale = 1 + self.q_s_HHtrans(r, v0, H)

            # U = H_K @ H_{K-1} @ ... @ H_1 (batched when v0_option is set).
            matmul = torch.bmm if self.v0_option else torch.mm
            self.U = H[str(self.num_HH)]
            for i in reversed(range(1, self.num_HH)):
                self.U = matmul(self.U, H[str(i)])
        else:
            scale = 1 + r['0']

        X_noised = x * scale
        activation = F.linear(X_noised, self.W)
        return activation + self.bias

    def kl_reg(self):
        """Return the KL regulariser of this layer.

        With Householder reflections the value depends on ``self.U`` and so
        requires a prior ``forward`` call. Otherwise a cubic-polynomial
        approximation in ``alpha`` is returned, so callers that sum the
        regularisers of all layers always receive a tensor.
        """
        alpha = self.log_alpha_.exp()

        if self.args.method == 'ours' and self.num_HH > 0:
            M, K = self.U.shape[:2]
            if not self.v0_option:
                # U is a single (K, K) matrix shared by the whole batch.
                M = 1
            kl = torch.log((1 + torch.sum(alpha * self.U ** 2, dim=-1)) / alpha.unsqueeze(0).expand(M, K))
            kl = torch.sum(kl, dim=-1)
            kl = kl * self.out_features
            return kl.mean() / 2

        # No reflections: same polynomial as Conv2d.kl_reg uses in this case.
        # NOTE(review): scaling by out_features mirrors the branch above
        # (Conv2d scales by H * W instead) -- confirm the intended weighting.
        c1 = 1.16145124
        c2 = -1.50204118
        c3 = 0.58629921
        negative_kl = 0.5 * self.log_alpha_ + c1 * alpha + c2 * alpha ** 2 + c3 * alpha ** 3
        kl = -negative_kl * self.out_features
        return kl.sum()


class Learner(nn.Module):
    """Combine the cross-entropy loss with the KL regularisers of a network.

    Args:
        net: Network whose direct children may expose ``kl_reg()`` and
            ``log_alpha_``.
        num_batches: Stored on the instance; not used by this class.
        num_samples: Divisor applied to the summed KL term.
    """

    def __init__(self, net, num_batches, num_samples):
        super(Learner, self).__init__()
        self.num_batches = num_batches
        self.num_samples = num_samples
        self.net = net

    def forward(self, input, target, kl_weight=0.1):
        """Compute the loss terms for one batch.

        Args:
            input: Logits passed to ``F.cross_entropy``.
            target: Class targets; must not require gradients.
            kl_weight: Weight of the KL term in the returned training loss.

        Returns:
            Tuple ``(cross_entropy, cross_entropy + kl_weight * kl, elbo,
            alphas)`` where ``elbo = -cross_entropy - kl`` and ``alphas`` holds
            ``exp(log_alpha_)`` of every direct child that has ``log_alpha_``.
        """
        assert not target.requires_grad
        kl = 0.0
        alpha_ = []
        for module in self.net.children():
            if hasattr(module, 'kl_reg'):
                reg = module.kl_reg()
                # A layer may report no regulariser by returning None; adding
                # None to the running sum would raise TypeError.
                if reg is not None:
                    kl = kl + reg
            if hasattr(module, 'log_alpha_'):
                alpha_.append(module.log_alpha_.exp())
        kl = kl / self.num_samples
        cross_entropy = F.cross_entropy(input, target)
        elbo = - cross_entropy - kl

        return cross_entropy, cross_entropy + kl_weight * kl, elbo, alpha_


