import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import Parameter
import numpy as np
# Default device for the whole module: the GPU when torch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    dev = 'cuda'
else:
    dev = 'cpu'


class HHtrans(nn.Module):
    """Householder reflection applied to a batch of vectors.

    ``forward`` builds ``H = I - 2 v v^T / (v^T v)`` from ``v``, stores it in
    the caller-supplied dict ``H`` under the key ``str(i)`` and returns ``H``
    applied to every row of ``s``.

    Args:
        v0_option: if truthy, ``v`` holds one reflection vector per sample
            (``(batch, K)``) and ``H[str(i)]`` is a ``(batch, K, K)`` stack of
            matrices; otherwise ``v`` is expected to be a single row
            ``(1, K)`` shared by the whole batch and ``H[str(i)]`` is
            ``(K, K)``.
    """

    def __init__(self, v0_option):
        super(HHtrans, self).__init__()
        self.v0_option = v0_option

    def forward(self, i, v, s, H):
        """Reflect the rows of ``s`` with the Householder matrix defined by ``v``.

        Args:
            i: index used to build the key ``str(i)`` in ``H``.
            v: reflection vector(s), ``(batch, K)`` or ``(1, K)`` (see class
                docstring). NOTE(review): an all-zero ``v`` divides by zero.
            s: ``(batch, K)`` vectors to reflect.
            H: dict that receives the reflection matrix (mutated in place).

        Returns:
            ``(batch, K)`` tensor whose rows are ``H @ s_row``.
        """
        K = v.shape[1]
        # Build the identity where the inputs live (device and dtype) rather
        # than on the global default device, so the layer also works on other
        # devices and in non-float32 precision.
        eye = torch.eye(K, device=v.device, dtype=v.dtype)
        if self.v0_option:
            vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))        # (B, K, K)
            norm_sq = torch.sum(v * v, 1).view(-1, 1, 1)            # (B, 1, 1)
            H[str(i)] = eye - 2 * vvT / norm_sq
            return torch.bmm(H[str(i)], s.unsqueeze(2)).squeeze(2)
        # Shared reflection: with v of shape (1, K), v.T @ v is the (K, K)
        # outer product and norm_sq has a single entry that broadcasts.
        vvT = torch.mm(v.T, v)
        norm_sq = torch.sum(v * v, 1)
        H[str(i)] = eye - 2 * vvT / norm_sq
        return torch.mm(H[str(i)], s.T).T


class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` with a 3x3 default kernel and ``bias`` ordered before ``groups``.

    Positional order is ``(in_channels, out_channels, kernel_size, stride,
    padding, dilation, bias, groups)``. The arguments are forwarded to
    ``nn.Conv2d`` by keyword so the reordering cannot be mixed up. ``forward``
    calls ``F.conv2d`` directly with the stored hyper-parameters.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=0, dilation=1, bias=True, groups=1):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        return F.conv2d(
            x,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )


class Linear(nn.Module):
    """Fully connected layer whose input is scaled by multiplicative noise.

    Each call draws ``r = sqrt(alpha) * eps`` with ``eps ~ N(0, I)`` and feeds
    ``x * (1 + r)`` to the affine map ``x @ W.T + bias``. When
    ``args.method == 'ours'`` and ``args.num_HH > 0``, the noise is first
    passed through ``num_HH`` Householder reflections. Their vectors are
    produced from the input by ``self.v_layers``. The product of those
    reflections is cached in ``self.U`` and used by :meth:`kl_reg`.

    Args:
        in_features: width of the (flattened) input.
        out_features: width of the output.
        args: config object; reads ``v0_option``, ``method``, ``num_HH`` and
            ``lowrank``.
        v0_dim: stored but not used by this class.
        alpha: initial noise variance; ``log(alpha)`` becomes the trainable
            ``log_alpha_``.
        after_cnn: the input is a flattened conv feature map. With
            ``method == 'ours'``, one alpha per channel is learned and shared
            by every spatial position of that channel.
        last_number_channel: channel count of that feature map.
    """

    def __init__(self, in_features, out_features, args, v0_dim=None, alpha=1., after_cnn=False, last_number_channel=None):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.args = args
        self.v0_dim = v0_dim  # stored only; not read by this class
        self.last_number_channel = last_number_channel
        self.W = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(1, out_features))

        self.v0_option = args.v0_option
        self.after_cnn = after_cnn

        # One log-variance per input feature, or one per channel when alpha is
        # shared across the spatial positions of a flattened feature map.
        n_alpha = self.last_number_channel if self._per_channel() else self.in_features
        self.log_alpha_ = nn.Parameter((torch.ones(n_alpha) * alpha).log())

        self.num_HH = args.num_HH
        self.HHtrans = HHtrans(self.v0_option)

        # Networks producing the Householder vectors: v_{i+1} = layer_i(v_i).
        self.v_layers = nn.ModuleList()
        if self.after_cnn:
            self.v_layers.append(nn.Linear(self.in_features, self.last_number_channel))
            for _ in range(1, self.num_HH):
                self.v_layers.append(nn.Linear(self.last_number_channel, self.last_number_channel))
        else:
            for _ in range(self.num_HH):
                if args.lowrank:
                    # Low-rank variant: bottleneck through 10 hidden units.
                    self.v_layers.append(nn.Sequential(nn.Linear(in_features, 10),
                                                       nn.ReLU(),
                                                       nn.Linear(10, in_features)))
                else:
                    self.v_layers.append(nn.Linear(in_features, in_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``W`` and ``bias`` the way ``nn.Linear`` does."""
        init.kaiming_uniform_(self.W, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def _per_channel(self):
        """Return True when one alpha (and one noise value) covers a whole channel."""
        return self.args.method == 'ours' and self.after_cnn

    def _spatial_size(self):
        """Number of flattened inputs that share one per-channel alpha.

        For example, a 64 x 7 x 7 feature map gives 3136 // 64 = 49.
        """
        return self.in_features // self.last_number_channel

    def _apply_noise(self, x, s):
        """Multiply ``x`` by the noise ``s``, broadcasting per-channel noise over space."""
        if not self._per_channel():
            return x * s
        batch = x.shape[0]
        x = x.reshape(batch, self.last_number_channel, -1)
        return (x * s.unsqueeze(-1)).reshape(batch, -1)

    def tile(self, A, dim, n_tile):
        """Repeat every slice of ``A`` along ``dim`` ``n_tile`` times, in place order.

        ``[a, b]`` with ``n_tile=2`` becomes ``[a, a, b, b]``.
        """
        return torch.repeat_interleave(A, n_tile, dim=dim)

    def q_s_HHtrans(self, s, v0, H):
        """Apply the chain of ``num_HH`` Householder reflections to ``s['0']``.

        Args:
            s: dict holding the initial vectors under ``'0'``; entries
                ``'1'`` .. ``str(num_HH)`` are added in place.
            v0: input fed to the first layer of ``self.v_layers``.
            H: dict receiving each reflection matrix under ``str(i)``.

        Returns:
            ``s[str(num_HH)]``, or ``s['0']`` when ``num_HH <= 0``.
        """
        if self.num_HH <= 0:
            return s['0']
        v = v0
        for i in range(self.num_HH):
            v = F.leaky_relu(self.v_layers[i](v))
            s[str(i + 1)] = self.HHtrans(i + 1, v, s[str(i)], H)
        return s[str(self.num_HH)]

    def forward(self, x, input=None):
        """Scale ``x`` by fresh noise and apply the affine map.

        Args:
            x: ``(batch, in_features)`` input. With per-channel alpha,
                ``in_features`` must be ``last_number_channel * spatial``.
            input: unused; kept for call-site compatibility.

        Returns:
            ``(batch, out_features)`` tensor ``(x * s) @ W.T + bias``.
        """
        alpha = self.log_alpha_.exp()
        r = {}
        H = {}
        if self._per_channel():
            noise_shape = (x.shape[0], self.last_number_channel)
        else:
            noise_shape = x.size()
        # Sampled first and then moved to x's device, so seeded runs keep
        # the same noise stream.
        r['0'] = torch.sqrt(alpha) * torch.randn(noise_shape).to(x.device)

        if self.args.method == 'ours' and self.num_HH > 0:
            # v0_option: one reflection per sample; otherwise a single
            # reflection for the whole batch, seeded by the batch mean.
            v0 = x if self.v0_option else x.mean(0, keepdim=True)
            s_K = 1 + self.q_s_HHtrans(r, v0, H)
            # U = H_K @ ... @ H_1; '@' dispatches to mm / bmm by rank.
            U = H[str(self.num_HH)]
            for i in range(self.num_HH - 1, 0, -1):
                U = U @ H[str(i)]
            self.U = U
        else:
            s_K = 1 + r['0']

        x_noised = self._apply_noise(x, s_K)
        return F.linear(x_noised, self.W) + self.bias

    def kl_reg(self):
        """Return this layer's KL regulariser as a scalar tensor.

        For each noise dimension k the term is
        ``log((1 + sum_j alpha_j * U[k, j]**2) / alpha_k)``. The terms are
        summed over k, then multiplied by the number of weights that share
        each alpha. The result is averaged over ``U``'s leading batch axis
        when ``v0_option`` produced one ``U`` per sample, and halved.

        When the Householder path is active, ``self.U`` comes from the most
        recent ``forward`` call, so ``forward`` must run first.
        """
        log_alpha = self.log_alpha_
        alpha = log_alpha.exp()
        if self.num_HH > 0 and self.args.method == "ours":
            # alpha broadcasts over U's last axis: U is (K, K) or (B, K, K).
            quad = torch.sum(alpha * self.U ** 2, dim=-1)
        else:
            # No reflections: U = I, matching the uncorrelated noise used by
            # forward(). This path must also return a tensor, because the
            # caller adds up every layer's kl_reg().
            quad = alpha
        kl = torch.sum(torch.log1p(quad) - log_alpha, dim=-1)
        kl = kl * self.out_features
        if self._per_channel():
            kl = kl * self._spatial_size()
        return kl.mean() / 2


class LeNetVSD(nn.Module):
    """Two stride-2 conv layers followed by two noisy ``Linear`` layers.

    ``l1`` takes the conv output flattened to ``64 * 7 * 7`` features. It is
    built with ``after_cnn=True`` and ``last_number_channel=64``. ``l2`` maps
    128 hidden units to 10 outputs. Both layers start from
    ``alpha = droprate / (1 - droprate)``.
    """

    def __init__(self, args):
        super().__init__()
        # The construction order fixes both self.children() order and the
        # order in which random initialisers draw numbers.
        self.conv1 = Conv2d(3, 32, stride=2)
        self.conv2 = Conv2d(32, 64, stride=2)
        self.args = args

        alpha0 = args.droprate / (1 - args.droprate)
        self.l1 = Linear(64 * 7 * 7, 128, args, v0_dim=None, alpha=alpha0,
                         after_cnn=True, last_number_channel=64)
        self.l2 = Linear(128, 10, args, v0_dim=None, alpha=alpha0)
        self._init_weights()

    def forward(self, input):
        h = input.to(dev)
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h))
        # Two unpadded 3x3 stride-2 convs take a 32x32 input down to 7x7.
        h = h.reshape(h.shape[0], -1)
        h = F.relu(self.l1(h))
        return self.l2(h)

    def _init_weights(self):
        # Only direct children exposing ``weight`` are re-initialised. The
        # custom Linear layers store their matrix as ``W``, so they keep their
        # own init. NOTE(review): confirm this is intended.
        gain = nn.init.calculate_gain('relu')
        for child in self.children():
            if not hasattr(child, 'weight'):
                continue
            nn.init.xavier_uniform_(child.weight, gain=gain)


class Learner(nn.Module):
    """Loss wrapper that combines cross-entropy with the network's KL terms.

    Args:
        net: model whose *direct* children are scanned for ``kl_reg()`` and
            ``log_alpha_``.
        num_batches: stored but not used by this class.
        num_samples: divisor applied to the summed KL (presumably the
            training-set size; confirm against the caller).
    """

    def __init__(self, net, num_batches, num_samples):
        super(Learner, self).__init__()
        self.num_batches = num_batches
        self.num_samples = num_samples
        self.net = net

    def forward(self, input, target, kl_weight=0.1):
        """Compute the training loss and diagnostics.

        Args:
            input: logits passed to ``F.cross_entropy``.
            target: class targets; must not require gradients.
            kl_weight: multiplier on the KL term in the returned loss.

        Returns:
            ``(cross_entropy, cross_entropy + kl_weight * kl, elbo, alphas)``.
            Here ``kl`` is the summed ``kl_reg()`` divided by ``num_samples``,
            ``elbo = -cross_entropy - kl``, and ``alphas`` lists
            ``exp(log_alpha_)`` for every direct child that has one.

        Raises:
            ValueError: if ``target`` requires gradients.
        """
        # An explicit check instead of ``assert``, so it still runs under
        # ``python -O``.
        if target.requires_grad:
            raise ValueError("target must not require gradients")
        kl = 0.0
        alpha_ = []
        for module in self.net.children():
            if hasattr(module, 'kl_reg'):
                kl = kl + module.kl_reg()
            if hasattr(module, 'log_alpha_'):
                alpha_.append(module.log_alpha_.exp())
        kl = kl / self.num_samples
        cross_entropy = F.cross_entropy(input, target)
        elbo = - cross_entropy - kl
        return cross_entropy, cross_entropy + kl_weight * kl, elbo, alpha_


